batch node
This commit is contained in:
parent
b257169d73
commit
f802251246
|
|
@ -2,38 +2,33 @@ import asyncio
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
class BaseNode:
|
class BaseNode:
|
||||||
"""
|
# preprocess(): this is for compute intensive preparation tasks, before the LLM call
|
||||||
A base node that provides:
|
# process(): this is for the LLM call, and should be idempotent for retries
|
||||||
- preprocess()
|
# postprocess(): this is to summarize the result and return the condition for the successor node
|
||||||
- process()
|
|
||||||
- postprocess()
|
|
||||||
- run() -- just runs itself (no chaining)
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.parameters = {}
|
self.parameters, self.successors = {}, {}
|
||||||
self.successors = {}
|
|
||||||
|
|
||||||
def set_parameters(self, params):
|
def set_parameters(self, params): # make sure params is immutable
|
||||||
self.parameters.update(params)
|
self.parameters = params # must be immutable during pre/post/process
|
||||||
|
|
||||||
def add_successor(self, node, condition="default"):
|
def add_successor(self, node, condition="default"):
|
||||||
if condition in self.successors:
|
if condition in self.successors:
|
||||||
print(f"Warning: overwriting existing successor for condition '{condition}'")
|
warnings.warn(f"Overwriting existing successor for condition '{condition}'")
|
||||||
self.successors[condition] = node
|
self.successors[condition] = node # maps condition -> successor node
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def preprocess(self, shared_storage):
|
def preprocess(self, shared_storage):
|
||||||
return None
|
return None # will be passed to process() and postprocess()
|
||||||
|
|
||||||
def process(self, shared_storage, prep_result):
|
def process(self, shared_storage, prep_result):
|
||||||
return None
|
return None # will be passed to postprocess()
|
||||||
|
|
||||||
def _process(self, shared_storage, prep_result):
|
def _process(self, shared_storage, prep_result):
|
||||||
# Could have retry logic or other wrap logic
|
# Could have retry logic or other wrap logic
|
||||||
return self.process(shared_storage, prep_result)
|
return self.process(shared_storage, prep_result)
|
||||||
|
|
||||||
def postprocess(self, shared_storage, prep_result, proc_result):
|
def postprocess(self, shared_storage, prep_result, proc_result):
|
||||||
return "default"
|
return "default" # condition for next node
|
||||||
|
|
||||||
def run(self, shared_storage=None):
|
def run(self, shared_storage=None):
|
||||||
prep = self.preprocess(shared_storage)
|
prep = self.preprocess(shared_storage)
|
||||||
|
|
@ -41,42 +36,16 @@ class BaseNode:
|
||||||
return self.postprocess(shared_storage, prep, proc)
|
return self.postprocess(shared_storage, prep, proc)
|
||||||
|
|
||||||
def __rshift__(self, other):
|
def __rshift__(self, other):
|
||||||
"""
|
# chaining: node1 >> node2
|
||||||
For chaining with >> operator, e.g. node1 >> node2
|
|
||||||
"""
|
|
||||||
return self.add_successor(other)
|
return self.add_successor(other)
|
||||||
|
|
||||||
def __gt__(self, other):
|
|
||||||
"""
|
|
||||||
For chaining with > operator, e.g. node1 > "some_condition"
|
|
||||||
then >> node2
|
|
||||||
"""
|
|
||||||
if isinstance(other, str):
|
|
||||||
return _ConditionalTransition(self, other)
|
|
||||||
elif isinstance(other, BaseNode):
|
|
||||||
return self.add_successor(other)
|
|
||||||
raise TypeError("Unsupported operand type")
|
|
||||||
|
|
||||||
def __call__(self, condition):
|
|
||||||
"""
|
|
||||||
For node("condition") >> next_node syntax
|
|
||||||
"""
|
|
||||||
return _ConditionalTransition(self, condition)
|
|
||||||
|
|
||||||
def __sub__(self, condition):
|
def __sub__(self, condition):
|
||||||
"""
|
# condition-based chaining: node - "some_condition" >> next_node
|
||||||
For chaining with - operator, e.g. node - "some_condition" >> next_node
|
|
||||||
"""
|
|
||||||
if isinstance(condition, str):
|
if isinstance(condition, str):
|
||||||
return _ConditionalTransition(self, condition)
|
return _ConditionalTransition(self, condition)
|
||||||
raise TypeError("Condition must be a string")
|
raise TypeError("Condition must be a string")
|
||||||
|
|
||||||
|
|
||||||
class _ConditionalTransition:
|
class _ConditionalTransition:
|
||||||
"""
|
|
||||||
Helper for Node > 'condition' >> AnotherNode style
|
|
||||||
(and also Node - 'condition' >> AnotherNode now).
|
|
||||||
"""
|
|
||||||
def __init__(self, source_node, condition):
|
def __init__(self, source_node, condition):
|
||||||
self.source_node = source_node
|
self.source_node = source_node
|
||||||
self.condition = condition
|
self.condition = condition
|
||||||
|
|
@ -100,6 +69,22 @@ class Node(BaseNode):
|
||||||
if attempt == self.max_retries - 1:
|
if attempt == self.max_retries - 1:
|
||||||
return self.process_after_fail(shared_storage, data, e)
|
return self.process_after_fail(shared_storage, data, e)
|
||||||
|
|
||||||
|
class BatchNode(Node):
|
||||||
|
def preprocess(self, shared_storage):
|
||||||
|
# return an iterable of items, one for each run
|
||||||
|
return []
|
||||||
|
|
||||||
|
def process(self, shared_storage, item): # process() is called for each item
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _process(self, shared_storage, items):
|
||||||
|
results = []
|
||||||
|
for item in items:
|
||||||
|
# Here, 'item' is passed in place of 'prep_result' from the BaseNode's perspective.
|
||||||
|
r = super()._process(shared_storage, item)
|
||||||
|
results.append(r)
|
||||||
|
return results
|
||||||
|
|
||||||
class AsyncNode(Node):
|
class AsyncNode(Node):
|
||||||
"""
|
"""
|
||||||
A Node whose postprocess step is async.
|
A Node whose postprocess step is async.
|
||||||
|
|
@ -199,7 +184,10 @@ class AsyncFlow(BaseFlow):
|
||||||
return self.postprocess(shared_storage, prep_result, None)
|
return self.postprocess(shared_storage, prep_result, None)
|
||||||
|
|
||||||
def run(self, shared_storage=None):
|
def run(self, shared_storage=None):
|
||||||
return asyncio.run(self.run_async(shared_storage))
|
try:
|
||||||
|
return asyncio.run(self.run_async(shared_storage))
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e
|
||||||
|
|
||||||
class BaseBatchFlow(BaseFlow):
|
class BaseBatchFlow(BaseFlow):
|
||||||
"""
|
"""
|
||||||
|
|
@ -213,12 +201,6 @@ class BaseBatchFlow(BaseFlow):
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def post_batch_run(self, all_results):
|
|
||||||
"""
|
|
||||||
Hook for after the entire batch is done, to combine results, etc.
|
|
||||||
"""
|
|
||||||
return all_results
|
|
||||||
|
|
||||||
class BatchFlow(BaseBatchFlow, Flow):
|
class BatchFlow(BaseBatchFlow, Flow):
|
||||||
"""
|
"""
|
||||||
Synchronous batch flow: calls the flow repeatedly
|
Synchronous batch flow: calls the flow repeatedly
|
||||||
|
|
@ -243,10 +225,6 @@ class BatchFlow(BaseBatchFlow, Flow):
|
||||||
# Reset the parameters if needed
|
# Reset the parameters if needed
|
||||||
self.parameters = original_params
|
self.parameters = original_params
|
||||||
|
|
||||||
# Postprocess the entire batch
|
|
||||||
result = self.post_batch_run(all_results)
|
|
||||||
return self.postprocess(shared_storage, prep_result, result)
|
|
||||||
|
|
||||||
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
|
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
|
||||||
"""
|
"""
|
||||||
Asynchronous batch flow: calls the flow repeatedly in an async manner
|
Asynchronous batch flow: calls the flow repeatedly in an async manner
|
||||||
|
|
@ -265,8 +243,4 @@ class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
|
||||||
all_results.append(f"Finished async run with parameters: {param_dict}")
|
all_results.append(f"Finished async run with parameters: {param_dict}")
|
||||||
|
|
||||||
# Reset back to original parameters if needed
|
# Reset back to original parameters if needed
|
||||||
self.parameters = original_params
|
self.parameters = original_params
|
||||||
|
|
||||||
# Combine or process results at the end
|
|
||||||
result = self.post_batch_run(all_results)
|
|
||||||
return self.postprocess(shared_storage, prep_result, result)
|
|
||||||
Loading…
Reference in New Issue