diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py
index d9c6725..dc1db63 100644
--- a/minillmflow/__init__.py
+++ b/minillmflow/__init__.py
@@ -1,184 +1,100 @@
-import asyncio
-import warnings
-
+import asyncio, warnings
+
 class BaseNode:
-    # preprocess(): this is for compute intensive preparation tasks, before the LLM call
-    # process(): this is for the LLM call, and should be idempotent for retries
-    # postprocess(): this is to summarize the result and retrun the condition for the successor node
-    def __init__(self):
-        self.params, self.successors = {}, {}
-
-    def set_params(self, params): # make sure params is immutable
-        self.params = params # must be immutable during pre/post/process
-
-    def add_successor(self, node, condition="default"):
-        if condition in self.successors:
-            warnings.warn(f"Overwriting existing successor for condition '{condition}'")
-        self.successors[condition] = node # maps condition -> successor node
-        return node
-
-    def preprocess(self, shared_storage):
-        return None # will be passed to process() and postprocess()
-
-    def process(self, shared_storage, prep_result):
-        return None # will be passed to postprocess()
-
-    def _process(self, shared_storage, prep_result):
-        # Could have retry logic or other wrap logic
-        return self.process(shared_storage, prep_result)
-
-    def postprocess(self, shared_storage, prep_result, proc_result):
-        return "default" # condition for next node
-
-    def _run(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-        proc_result = self._process(shared_storage, prep_result)
-        return self.postprocess(shared_storage, prep_result, proc_result)
-
-    def run(self, shared_storage):
-        if self.successors:
-            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
-        return self._run(shared_storage)
-
-    def __rshift__(self, other):
-        # chaining: node1 >> node2
-        return self.add_successor(other)
-
-    def __sub__(self, condition):
-        # condition-based chaining: node - "some_condition" >> next_node
-        if isinstance(condition, str):
-            return _ConditionalTransition(self, condition)
+    def __init__(self): self.params, self.successors = {}, {}
+    def set_params(self, params): self.params = params
+    def add_successor(self, node, cond="default"):
+        if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
+        self.successors[cond] = node; return node
+    def preprocess(self, s): return None
+    def process(self, s, p): return None
+    def _process(self, s, p): return self.process(s, p)
+    def postprocess(self, s, pr, r): return "default"
+    def _run(self, s):
+        pr = self.preprocess(s)
+        r = self._process(s, pr)
+        return self.postprocess(s, pr, r)
+    def run(self, s):
+        if self.successors: warnings.warn("Has successors; use Flow.run() instead.")
+        return self._run(s)
+    def __rshift__(self, other): return self.add_successor(other)
+    def __sub__(self, cond):
+        if isinstance(cond, str): return _ConditionalTransition(self, cond)
         raise TypeError("Condition must be a string")
 
 class _ConditionalTransition:
-    def __init__(self, source_node, condition):
-        self.source_node = source_node
-        self.condition = condition
-
-    def __rshift__(self, target_node):
-        return self.source_node.add_successor(target_node, self.condition)
+    def __init__(self, src, c): self.src, self.c = src, c
+    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c)
 
 class Node(BaseNode):
-    def __init__(self, max_retries=1):
-        super().__init__()
-        self.max_retries = max_retries
-
-    def process_after_fail(self, shared_storage, data, exc):
-        raise exc
-
-    def _process(self, shared_storage, data):
-        for attempt in range(self.max_retries):
-            try:
-                return super()._process(shared_storage, data)
+    def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
+    def process_after_fail(self, s, d, e): raise e
+    def _process(self, s, d):
+        for i in range(self.max_retries):
+            try: return super()._process(s, d)
             except Exception as e:
-                if attempt == self.max_retries - 1:
-                    return self.process_after_fail(shared_storage, data, e)
+                if i == self.max_retries - 1: return self.process_after_fail(s, d, e)
 
 class BatchNode(Node):
-    def preprocess(self, shared_storage):
-        # return an iterable of items, one for each run
-        return []
-
-    def process(self, shared_storage, item): # process() is called for each item
-        return None
-
-    def _process(self, shared_storage, items):
-        results = []
-        for item in items:
-            # Here, 'item' is passed in place of 'prep_result' from the BaseNode's perspective.
-            r = super()._process(shared_storage, item)
-            results.append(r)
-        return results
-
-class AsyncNode(Node):
-    def postprocess(self, shared_storage, prep_result, proc_result):
-        raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
-
-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
-        await asyncio.sleep(0) # trivial async pause (no-op)
-        return "default"
-
-    async def run_async(self, shared_storage):
-        if self.successors:
-            warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
-        return await self._run_async(shared_storage)
-
-    async def _run_async(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-        proc_result = self._process(shared_storage, prep_result)
-        return await self.postprocess_async(shared_storage, prep_result, proc_result)
-
-    def _run(self, shared_storage):
-        raise RuntimeError("AsyncNode requires asynchronous execution. Use 'await node.run_async()' if inside an async function, or 'asyncio.run(node.run_async())' if in synchronous code.")
+    def preprocess(self, s): return []
+    def process(self, s, item): return None
+    def _process(self, s, items): return [super(BatchNode, self)._process(s, i) for i in items]  # two-arg super keeps Node's per-item retries; bare super() is invalid in a comprehension
 
 class BaseFlow(BaseNode):
-    def __init__(self, start_node):
-        super().__init__()
-        self.start_node = start_node
+    def __init__(self, start_node): super().__init__(); self.start_node = start_node
+    def get_next_node(self, curr, c):
+        nxt = curr.successors.get(c)
+        if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}")
+        return nxt
 
-    def get_next_node(self, current_node, condition):
-        next_node = current_node.successors.get(condition, None)
-
-        if next_node is None and current_node.successors:
-            warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
-
-        return next_node
-
 class Flow(BaseFlow):
-    def _process(self, shared_storage, params=None):
-        current_node = self.start_node
-        params = params if params is not None else self.params.copy()
-
-        while current_node:
-            current_node.set_params(params)
-            condition = current_node._run(shared_storage)
-            current_node = self.get_next_node(current_node, condition)
-
-    def process(self, shared_storage, prep_result):
-        raise NotImplementedError("Flow should not process directly")
-
-class AsyncFlow(BaseFlow, AsyncNode):
-    async def _process_async(self, shared_storage, params=None):
-        current_node = self.start_node
-        params = params if params is not None else self.params.copy()
-
-        while current_node:
-            current_node.set_params(params)
-
-            if hasattr(current_node, "run_async") and callable(current_node.run_async):
-                condition = await current_node._run_async(shared_storage)
-            else:
-                condition = current_node._run(shared_storage)
-
-            current_node = self.get_next_node(current_node, condition)
-
-    async def _run_async(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-        await self._process_async(shared_storage)
-        return await self.postprocess_async(shared_storage, prep_result, None)
+    def _process(self, s, p=None):
+        curr, p = self.start_node, (p if p is not None else self.params.copy())
+        while curr:
+            curr.set_params(p)
+            c = curr._run(s)
+            curr = self.get_next_node(curr, c)
+    def process(self, s, pr): raise NotImplementedError("Flow should not process directly")
 
 class BaseBatchFlow(BaseFlow):
-    def preprocess(self, shared_storage):
-        return [] # return an iterable of parameter dictionaries
+    def preprocess(self, s): return []
 
 class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-
-        for param_dict in prep_result:
-            merged_params = self.params.copy()
-            merged_params.update(param_dict)
-            self._process(shared_storage, params=merged_params)
-
-        return self.postprocess(shared_storage, prep_result, None)
+    def _run(self, s):
+        pr = self.preprocess(s)
+        for d in pr:
+            mp = self.params.copy(); mp.update(d)
+            self._process(s, mp)
+        return self.postprocess(s, pr, None)
+
+class AsyncNode(Node):
+    def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async")
+    async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default"
+    async def run_async(self, s):
+        if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.")
+        return await self._run_async(s)
+    async def _run_async(self, s):
+        pr = self.preprocess(s)
+        r = self._process(s, pr)
+        return await self.postprocess_async(s, pr, r)
+    def _run(self, s): raise RuntimeError("AsyncNode requires async execution; use 'await node.run_async(...)' or 'asyncio.run(node.run_async(...))'")
+
+class AsyncFlow(BaseFlow, AsyncNode):
+    async def _process_async(self, s, p=None):
+        curr, p = self.start_node, (p if p is not None else self.params.copy())
+        while curr:
+            curr.set_params(p)
+            c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s)
+            curr = self.get_next_node(curr, c)
+    async def _run_async(self, s):
+        pr = self.preprocess(s)
+        await self._process_async(s)
+        return await self.postprocess_async(s, pr, None)
 
 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, shared_storage):
-        prep_result = self.preprocess(shared_storage)
-
-        for param_dict in prep_result:
-            merged_params = self.params.copy()
-            merged_params.update(param_dict)
-            await self._process_async(shared_storage, params=merged_params)
-
-        return await self.postprocess_async(shared_storage, prep_result, None)
\ No newline at end of file
+    async def _run_async(self, s):
+        pr = self.preprocess(s)
+        for d in pr:
+            mp = self.params.copy(); mp.update(d)
+            await self._process_async(s, mp)
+        return await self.postprocess_async(s, pr, None)
\ No newline at end of file
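For context, here is a minimal usage sketch of the condensed API above. The node classes, the shared dict contents, and the length threshold are illustrative, not part of this patch:

from minillmflow import Node, Flow

class GetLength(Node):
    def process(self, shared_storage, prep_result):
        # The work step; its return value feeds postprocess().
        return len(shared_storage['text'])
    def postprocess(self, shared_storage, prep_result, proc_result):
        # The returned string selects which successor edge to follow.
        return "long" if proc_result > 5 else "short"

class Label(Node):
    def __init__(self, text):
        super().__init__()
        self.text = text
    def process(self, shared_storage, prep_result):
        shared_storage['label'] = self.text

check = GetLength()
check - "long" >> Label("long text")    # conditional edge via __sub__ then __rshift__
check - "short" >> Label("short text")

shared = {'text': 'hello world'}
Flow(start_node=check).run(shared)
assert shared['label'] == "long text"   # len('hello world') == 11 -> "long"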
diff --git a/tests/test_async_batch_flow.py b/tests/test_async_batch_flow.py
new file mode 100644
index 0000000..76ddfcf
--- /dev/null
+++ b/tests/test_async_batch_flow.py
@@ -0,0 +1,176 @@
+import unittest
+import asyncio
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from minillmflow import AsyncNode, BatchAsyncFlow
+
+class AsyncDataProcessNode(AsyncNode):
+    def process(self, shared_storage, prep_result):
+        key = self.params.get('key')
+        data = shared_storage['input_data'][key]
+        if 'results' not in shared_storage:
+            shared_storage['results'] = {}
+        shared_storage['results'][key] = data
+        return data
+
+    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+        await asyncio.sleep(0.01)  # Simulate async work
+        key = self.params.get('key')
+        shared_storage['results'][key] = proc_result * 2  # Double the value
+        return "processed"
+
+class AsyncErrorNode(AsyncNode):
+    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+        key = self.params.get('key')
+        if key == 'error_key':
+            raise ValueError(f"Async error processing key: {key}")
+        return "processed"
+
+class TestAsyncBatchFlow(unittest.TestCase):
+    def setUp(self):
+        self.process_node = AsyncDataProcessNode()
+
+    def test_basic_async_batch_processing(self):
+        """Test basic async batch processing with multiple keys"""
+        class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = SimpleTestAsyncBatchFlow(start_node=self.process_node)
+        asyncio.run(flow.run_async(shared_storage))
+
+        expected_results = {
+            'a': 2,  # 1 * 2
+            'b': 4,  # 2 * 2
+            'c': 6   # 3 * 2
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_empty_async_batch(self):
+        """Test async batch processing with empty input"""
+        class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {}
+        }
+
+        flow = EmptyTestAsyncBatchFlow(start_node=self.process_node)
+        asyncio.run(flow.run_async(shared_storage))
+
+        self.assertEqual(shared_storage.get('results', {}), {})
+
+    def test_async_error_handling(self):
+        """Test error handling during async batch processing"""
+        class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'normal_key': 1,
+                'error_key': 2,
+                'another_key': 3
+            }
+        }
+
+        flow = ErrorTestAsyncBatchFlow(start_node=AsyncErrorNode())
+
+        with self.assertRaises(ValueError):
+            asyncio.run(flow.run_async(shared_storage))
+
+    def test_nested_async_flow(self):
+        """Test async batch processing with nested flows"""
+        class AsyncInnerNode(AsyncNode):
+            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+                key = self.params.get('key')
+                if 'intermediate_results' not in shared_storage:
+                    shared_storage['intermediate_results'] = {}
+                shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
+                await asyncio.sleep(0.01)
+                return "next"
+
+        class AsyncOuterNode(AsyncNode):
+            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+                key = self.params.get('key')
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
+                await asyncio.sleep(0.01)
+                return "done"
+
+        class NestedAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        # Create inner flow
+        inner_node = AsyncInnerNode()
+        outer_node = AsyncOuterNode()
+        inner_node - "next" >> outer_node
+
+        shared_storage = {
+            'input_data': {
+                'x': 1,
+                'y': 2
+            }
+        }
+
+        flow = NestedAsyncBatchFlow(start_node=inner_node)
+        asyncio.run(flow.run_async(shared_storage))
+
+        expected_results = {
+            'x': 4,  # (1 + 1) * 2
+            'y': 6   # (2 + 1) * 2
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_custom_async_parameters(self):
+        """Test async batch processing with additional custom parameters"""
+        class CustomParamAsyncNode(AsyncNode):
+            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+                key = self.params.get('key')
+                multiplier = self.params.get('multiplier', 1)
+                await asyncio.sleep(0.01)
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
+                return "done"
+
+        class CustomParamAsyncBatchFlow(BatchAsyncFlow):
+            def preprocess(self, shared_storage):
+                return [{
+                    'key': k,
+                    'multiplier': i + 1
+                } for i, k in enumerate(shared_storage['input_data'].keys())]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = CustomParamAsyncBatchFlow(start_node=CustomParamAsyncNode())
+        asyncio.run(flow.run_async(shared_storage))
+
+        expected_results = {
+            'a': 1 * 1,  # first item, multiplier = 1
+            'b': 2 * 2,  # second item, multiplier = 2
+            'c': 3 * 3   # third item, multiplier = 3
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
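The suite above can be run directly, since it appends the repository root to sys.path: python tests/test_async_batch_flow.py. It exercises the async batch path but not Node's retry wrapper; the sketch below shows how max_retries and the process_after_fail fallback interact. FlakyNode is hypothetical, not part of this patch:

from minillmflow import Node

class FlakyNode(Node):
    # Fails twice, then succeeds; process() may be retried, so it should be idempotent.
    def __init__(self, max_retries=3):
        super().__init__(max_retries)
        self.calls = 0
    def process(self, shared_storage, prep_result):
        self.calls += 1
        if self.calls < 3:
            raise RuntimeError("transient failure")
        shared_storage['ok'] = True
    def process_after_fail(self, shared_storage, data, exc):
        # Called after the last attempt; returning here swallows the exception.
        shared_storage['ok'] = False

shared = {}
FlakyNode().run(shared)               # third attempt succeeds
assert shared['ok'] is True

shared = {}
FlakyNode(max_retries=2).run(shared)  # retries exhausted, fallback runs
assert shared['ok'] is False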