refactor

parent 5f95f23bc9
commit 812c041bb4
@@ -1,184 +1,100 @@
-import asyncio
-import warnings
-
-
-class BaseNode:
-    # preprocess(): this is for compute intensive preparation tasks, before the LLM call
-    # process(): this is for the LLM call, and should be idempotent for retries
-    # postprocess(): this is to summarize the result and return the condition for the successor node
-    def __init__(self):
-        self.params, self.successors = {}, {}
-
-    def set_params(self, params): # make sure params is immutable
-        self.params = params # must be immutable during pre/post/process
-
-    def add_successor(self, node, condition="default"):
-        if condition in self.successors:
-            warnings.warn(f"Overwriting existing successor for condition '{condition}'")
-        self.successors[condition] = node # maps condition -> successor node
-        return node
-
-    def preprocess(self, shared_storage):
-        return None # will be passed to process() and postprocess()
-
-    def process(self, shared_storage, prep_result):
-        return None # will be passed to postprocess()
-
-    def _process(self, shared_storage, prep_result):
-        # Could have retry logic or other wrap logic
-        return self.process(shared_storage, prep_result)
-
-    def postprocess(self, shared_storage, prep_result, proc_result):
-        return "default" # condition for next node
-
-    def _run(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-        proc_result = self._process(shared_storage, prep_result)
-        return self.postprocess(shared_storage, prep_result, proc_result)
-
-    def run(self, shared_storage):
-        if self.successors:
-            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
-        return self._run(shared_storage)
-
-    def __rshift__(self, other):
-        # chaining: node1 >> node2
-        return self.add_successor(other)
-
-    def __sub__(self, condition):
-        # condition-based chaining: node - "some_condition" >> next_node
-        if isinstance(condition, str):
-            return _ConditionalTransition(self, condition)
-        raise TypeError("Condition must be a string")

+import asyncio, warnings
+
+class BaseNode:
+    def __init__(self): self.params, self.successors = {}, {}
+    def set_params(self, params): self.params = params
+    def add_successor(self, node, cond="default"):
+        if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
+        self.successors[cond] = node; return node
+    def preprocess(self, s): return None
+    def process(self, s, p): return None
+    def _process(self, s, p): return self.process(s, p)
+    def postprocess(self, s, pr, r): return "default"
+    def _run(self, s):
+        pr = self.preprocess(s)
+        r = self._process(s, pr)
+        return self.postprocess(s, pr, r)
+    def run(self, s):
+        if self.successors: warnings.warn("Has successors; use Flow.run() instead.")
+        return self._run(s)
+    def __rshift__(self, other): return self.add_successor(other)
+    def __sub__(self, cond):
+        if isinstance(cond, str): return _ConditionalTransition(self, cond)
+        raise TypeError("Condition must be a string")
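For orientation while reading the diff: both versions keep the same three-hook lifecycle (preprocess, then process, then postprocess, with postprocess returning the condition that selects the successor). A minimal subclass sketch; SummarizeNode and its body are hypothetical, not part of this commit:

    class SummarizeNode(BaseNode):
        def preprocess(self, s):          # compute-intensive prep before the LLM call
            return s["document"][:2000]
        def process(self, s, p):          # the (idempotent) LLM call; faked here
            return "summary of: " + p[:20]
        def postprocess(self, s, pr, r):  # store the result, return the successor condition
            s["summary"] = r
            return "default"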
-class _ConditionalTransition:
-    def __init__(self, source_node, condition):
-        self.source_node = source_node
-        self.condition = condition
-
-    def __rshift__(self, target_node):
-        return self.source_node.add_successor(target_node, self.condition)

+class _ConditionalTransition:
+    def __init__(self, src, c): self.src, self.c = src, c
+    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c)
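The `>>` and `- "condition" >>` wiring is unchanged in behavior by this refactor: `node - "cond"` builds a `_ConditionalTransition` whose `>>` registers the successor under that condition. An illustrative sketch (the three node instances are hypothetical):

    review, publish, revise = ReviewNode(), PublishNode(), ReviseNode()
    review - "approved" >> publish   # followed when review's postprocess() returns "approved"
    review - "rejected" >> revise    # followed when it returns "rejected"
    revise >> review                 # plain >> registers under the "default" condition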
-class Node(BaseNode):
-    def __init__(self, max_retries=1):
-        super().__init__()
-        self.max_retries = max_retries
-
-    def process_after_fail(self, shared_storage, data, exc):
-        raise exc
-
-    def _process(self, shared_storage, data):
-        for attempt in range(self.max_retries):
-            try:
-                return super()._process(shared_storage, data)
-            except Exception as e:
-                if attempt == self.max_retries - 1:
-                    return self.process_after_fail(shared_storage, data, e)

+class Node(BaseNode):
+    def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
+    def process_after_fail(self, s, d, e): raise e
+    def _process(self, s, d):
+        for i in range(self.max_retries):
+            try: return super()._process(s, d)
+            except Exception as e:
+                if i == self.max_retries - 1: return self.process_after_fail(s, d, e)
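Node adds bounded retries around process(), which is why the old comments insist process() be idempotent; process_after_fail() is the last-chance hook once all max_retries attempts have failed. A hedged sketch (FlakyLLMNode and call_llm are illustrative only):

    class FlakyLLMNode(Node):
        def process(self, s, p):
            return call_llm(p)               # hypothetical helper that may raise
        def process_after_fail(self, s, d, e):
            return "fallback answer"         # swallow the final failure instead of re-raising

    node = FlakyLLMNode(max_retries=3)       # process() is attempted up to 3 times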
-class BatchNode(Node):
-    def preprocess(self, shared_storage):
-        # return an iterable of items, one for each run
-        return []
-
-    def process(self, shared_storage, item): # process() is called for each item
-        return None
-
-    def _process(self, shared_storage, items):
-        results = []
-        for item in items:
-            # Here, 'item' is passed in place of 'prep_result' from the BaseNode's perspective.
-            r = super()._process(shared_storage, item)
-            results.append(r)
-        return results

+class BatchNode(Node):
+    def preprocess(self, s): return []
+    def process(self, s, item): return None
+    def _process(self, s, items): return [super(Node, self)._process(s, i) for i in items]
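Note the new one-liner uses `super(Node, self)._process`, starting the MRO lookup above Node (zero-argument super() is not usable inside a comprehension's scope), so each item dispatches straight to BaseNode._process, bypassing Node's retry loop. A minimal subclass sketch with hypothetical names:

    class EmbedChunks(BatchNode):
        def preprocess(self, s):
            return s["chunks"]       # one process() call per item in this iterable
        def process(self, s, item):
            return len(item)         # stand-in for a real per-chunk embedding call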
-class AsyncNode(Node):
-    def postprocess(self, shared_storage, prep_result, proc_result):
-        raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
-
-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
-        await asyncio.sleep(0) # trivial async pause (no-op)
-        return "default"
-
-    async def run_async(self, shared_storage):
-        if self.successors:
-            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
-        return await self._run_async(shared_storage)
-
-    async def _run_async(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-        proc_result = self._process(shared_storage, prep_result)
-        return await self.postprocess_async(shared_storage, prep_result, proc_result)
-
-    def _run(self, shared_storage):
-        raise RuntimeError("AsyncNode requires asynchronous execution. Use 'await node.run_async()' if inside an async function, or 'asyncio.run(node.run_async())' if in synchronous code.")

(The new layout moves AsyncNode and AsyncFlow below the Flow classes; their compressed replacements appear further down in this diff.)
-class BaseFlow(BaseNode):
-    def __init__(self, start_node):
-        super().__init__()
-        self.start_node = start_node
-
-    def get_next_node(self, current_node, condition):
-        next_node = current_node.successors.get(condition, None)
-
-        if next_node is None and current_node.successors:
-            warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
-
-        return next_node

+class BaseFlow(BaseNode):
+    def __init__(self, start_node): super().__init__(); self.start_node = start_node
+    def get_next_node(self, curr, c):
+        nxt = curr.successors.get(c)
+        if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}")
+        return nxt
-class Flow(BaseFlow):
-    def _process(self, shared_storage, params=None):
-        current_node = self.start_node
-        params = params if params is not None else self.params.copy()
-
-        while current_node:
-            current_node.set_params(params)
-            condition = current_node._run(shared_storage)
-            current_node = self.get_next_node(current_node, condition)
-
-    def process(self, shared_storage, prep_result):
-        raise NotImplementedError("Flow should not process directly")

+class Flow(BaseFlow):
+    def _process(self, s, p=None):
+        curr, p = self.start_node, (p if p is not None else self.params.copy())
+        while curr:
+            curr.set_params(p)
+            c = curr._run(s)
+            curr = self.get_next_node(curr, c)
+    def process(self, s, pr): raise NotImplementedError("Use Flow._process(...) instead")
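Since a Flow is itself a BaseNode, flow.run() drives the whole graph from start_node, threading each node's returned condition through get_next_node() until no successor matches. A short sketch reusing the hypothetical nodes wired above:

    flow = Flow(start_node=review)
    shared = {"document": "long text ..."}
    flow.run(shared)                 # review -> publish or revise, following the returned conditions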
-class AsyncFlow(BaseFlow, AsyncNode):
-    async def _process_async(self, shared_storage, params=None):
-        current_node = self.start_node
-        params = params if params is not None else self.params.copy()
-
-        while current_node:
-            current_node.set_params(params)
-
-            if hasattr(current_node, "run_async") and callable(current_node.run_async):
-                condition = await current_node._run_async(shared_storage)
-            else:
-                condition = current_node._run(shared_storage)
-
-            current_node = self.get_next_node(current_node, condition)
-
-    async def _run_async(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-        await self._process_async(shared_storage)
-        return await self.postprocess_async(shared_storage, prep_result, None)
-class BaseBatchFlow(BaseFlow):
-    def preprocess(self, shared_storage):
-        return [] # return an iterable of parameter dictionaries

+class BaseBatchFlow(BaseFlow):
+    def preprocess(self, s): return []
-class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-
-        for param_dict in prep_result:
-            merged_params = self.params.copy()
-            merged_params.update(param_dict)
-            self._process(shared_storage, params=merged_params)
-
-        return self.postprocess(shared_storage, prep_result, None)

+class BatchFlow(BaseBatchFlow, Flow):
+    def _run(self, s):
+        pr = self.preprocess(s)
+        for d in pr:
+            mp = self.params.copy(); mp.update(d)
+            self._process(s, mp)
+        return self.postprocess(s, pr, None)
+
+class AsyncNode(Node):
+    def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async")
+    async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default"
+    async def run_async(self, s):
+        if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.")
+        return await self._run_async(s)
+    async def _run_async(self, s):
+        pr = self.preprocess(s)
+        r = self._process(s, pr)
+        return await self.postprocess_async(s, pr, r)
+    def _run(self, s): raise RuntimeError("AsyncNode requires async execution")
+
+class AsyncFlow(BaseFlow, AsyncNode):
+    async def _process_async(self, s, p=None):
+        curr, p = self.start_node, (p if p else self.params.copy())
+        while curr:
+            curr.set_params(p)
+            c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s)
+            curr = self.get_next_node(curr, c)
+    async def _run_async(self, s):
+        pr = self.preprocess(s)
+        await self._process_async(s)
+        return await self.postprocess_async(s, pr, None)
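The async variants keep preprocess() and process() synchronous and move only the postprocess step into a coroutine. A hedged usage sketch (NotifyNode is hypothetical):

    class NotifyNode(AsyncNode):
        async def postprocess_async(self, s, pr, r):
            await asyncio.sleep(0.1)         # stand-in for a real awaitable, e.g. an HTTP call
            return "default"

    asyncio.run(NotifyNode().run_async({}))  # from sync code; inside a coroutine, 'await node.run_async(shared)'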
-class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-
-        for param_dict in prep_result:
-            merged_params = self.params.copy()
-            merged_params.update(param_dict)
-            await self._process_async(shared_storage, params=merged_params)
-
-        return await self.postprocess_async(shared_storage, prep_result, None)

+class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
+    async def _run_async(self, s):
+        pr = self.preprocess(s)
+        for d in pr:
+            mp = self.params.copy(); mp.update(d)
+            await self._process_async(s, mp)
+        return await self.postprocess_async(s, pr, None)
@@ -0,0 +1,176 @@
+import unittest
+import asyncio
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from minillmflow import AsyncNode, BatchAsyncFlow
+
+
+class AsyncDataProcessNode(AsyncNode):
+    def process(self, shared_storage, prep_result):
+        key = self.params.get('key')
+        data = shared_storage['input_data'][key]
+        if 'results' not in shared_storage:
+            shared_storage['results'] = {}
+        shared_storage['results'][key] = data
+        return data
+
+    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+        await asyncio.sleep(0.01)  # Simulate async work
+        key = self.params.get('key')
+        shared_storage['results'][key] = proc_result * 2  # Double the value
+        return "processed"
+
+
+class AsyncErrorNode(AsyncNode):
+    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+        key = self.params.get('key')
+        if key == 'error_key':
+            raise ValueError(f"Async error processing key: {key}")
+        return "processed"
+
+
+class TestAsyncBatchFlow(unittest.TestCase):
+    def setUp(self):
+        self.process_node = AsyncDataProcessNode()
+
+    def test_basic_async_batch_processing(self):
+        """Test basic async batch processing with multiple keys"""
+        class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = SimpleTestAsyncBatchFlow(start_node=self.process_node)
+        asyncio.run(flow.run_async(shared_storage))
+
+        expected_results = {
+            'a': 2,  # 1 * 2
+            'b': 4,  # 2 * 2
+            'c': 6   # 3 * 2
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_empty_async_batch(self):
+        """Test async batch processing with empty input"""
+        class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {}
+        }
+
+        flow = EmptyTestAsyncBatchFlow(start_node=self.process_node)
+        asyncio.run(flow.run_async(shared_storage))
+
+        self.assertEqual(shared_storage.get('results', {}), {})
+
+    def test_async_error_handling(self):
+        """Test error handling during async batch processing"""
+        class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'normal_key': 1,
+                'error_key': 2,
+                'another_key': 3
+            }
+        }
+
+        flow = ErrorTestAsyncBatchFlow(start_node=AsyncErrorNode())
+
+        with self.assertRaises(ValueError):
+            asyncio.run(flow.run_async(shared_storage))
+
+    def test_nested_async_flow(self):
+        """Test async batch processing with nested flows"""
+        class AsyncInnerNode(AsyncNode):
+            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+                key = self.params.get('key')
+                if 'intermediate_results' not in shared_storage:
+                    shared_storage['intermediate_results'] = {}
+                shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
+                await asyncio.sleep(0.01)
+                return "next"
+
+        class AsyncOuterNode(AsyncNode):
+            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+                key = self.params.get('key')
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
+                await asyncio.sleep(0.01)
+                return "done"
+
+        class NestedAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        # Create inner flow
+        inner_node = AsyncInnerNode()
+        outer_node = AsyncOuterNode()
+        inner_node - "next" >> outer_node
+
+        shared_storage = {
+            'input_data': {
+                'x': 1,
+                'y': 2
+            }
+        }
+
+        flow = NestedAsyncBatchFlow(start_node=inner_node)
+        asyncio.run(flow.run_async(shared_storage))
+
+        expected_results = {
+            'x': 4,  # (1 + 1) * 2
+            'y': 6   # (2 + 1) * 2
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_custom_async_parameters(self):
+        """Test async batch processing with additional custom parameters"""
+        class CustomParamAsyncNode(AsyncNode):
+            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+                key = self.params.get('key')
+                multiplier = self.params.get('multiplier', 1)
+                await asyncio.sleep(0.01)
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
+                return "done"
+
+        class CustomParamAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{
+                    'key': k,
+                    'multiplier': i + 1
+                } for i, k in enumerate(shared_storage['input_data'].keys())]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = CustomParamAsyncBatchFlow(start_node=CustomParamAsyncNode())
+        asyncio.run(flow.run_async(shared_storage))
+
+        expected_results = {
+            'a': 1 * 1,  # first item, multiplier = 1
+            'b': 2 * 2,  # second item, multiplier = 2
+            'c': 3 * 3   # third item, multiplier = 3
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+
+if __name__ == '__main__':
+    unittest.main()