diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py
index dc1db63..79250b5 100644
--- a/minillmflow/__init__.py
+++ b/minillmflow/__init__.py
@@ -6,95 +6,99 @@ class BaseNode:
     def add_successor(self, node, cond="default"):
         if cond in self.successors:
             warnings.warn(f"Overwriting existing successor for '{cond}'")
         self.successors[cond] = node; return node
-    def preprocess(self, s): return None
-    def process(self, s, p): return None
-    def _process(self, s, p): return self.process(s, p)
-    def postprocess(self, s, pr, r): return "default"
-    def _run(self, s):
-        pr = self.preprocess(s)
-        r = self._process(s, pr)
-        return self.postprocess(s, pr, r)
-    def run(self, s):
-        if self.successors: warnings.warn("Has successors; use Flow.run() instead.")
-        return self._run(s)
+    def prep(self, shared): return None
+    def exec(self, shared, prep_res): return None
+    def _exec(self, shared, prep_res): return self.exec(shared, prep_res)
+    def post(self, shared, prep_res, exec_res): return "default"
+    def _run(self, shared):
+        prep_res = self.prep(shared)
+        exec_res = self._exec(shared, prep_res)
+        return self.post(shared, prep_res, exec_res)
+    def run(self, shared):
+        if self.successors:
+            warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.")
+        return self._run(shared)
     def __rshift__(self, other): return self.add_successor(other)
     def __sub__(self, cond):
         if isinstance(cond, str): return _ConditionalTransition(self, cond)
         raise TypeError("Condition must be a string")
 
 class _ConditionalTransition:
-    def __init__(self, src, c): self.src, self.c = src, c
-    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c)
+    def __init__(self, src, cond): self.src, self.cond = src, cond
+    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.cond)
 
 class Node(BaseNode):
     def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
-    def process_after_fail(self, s, d, e): raise e
-    def _process(self, s, d):
+    def process_after_fail(self, shared, prep_res, exc): raise exc
+    def _exec(self, shared, prep_res):
         for i in range(self.max_retries):
-            try: return super()._process(s, d)
+            try: return super()._exec(shared, prep_res)
             except Exception as e:
-                if i == self.max_retries - 1: return self.process_after_fail(s, d, e)
+                if i == self.max_retries - 1: return self.process_after_fail(shared, prep_res, e)
 
 class BatchNode(Node):
-    def preprocess(self, s): return []
-    def process(self, s, item): return None
-    def _process(self, s, items): return [super(Node, self)._process(s, i) for i in items]
+    def prep(self, shared): return []
+    def exec(self, shared, item): return None
+    def _exec(self, shared, items): return [super(Node, self)._exec(shared, i) for i in items]
 
 class BaseFlow(BaseNode):
     def __init__(self, start_node): super().__init__(); self.start_node = start_node
-    def get_next_node(self, curr, c):
-        nxt = curr.successors.get(c)
-        if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}")
+    def get_next_node(self, curr, cond):
+        nxt = curr.successors.get(cond)
+        if nxt is None and curr.successors:
+            warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
         return nxt
 
 class Flow(BaseFlow):
-    def _process(self, s, p=None):
-        curr, p = self.start_node, (p if p is not None else self.params.copy())
+    def _exec(self, shared, params=None):
+        curr = self.start_node
+        p = params if params else self.params.copy()
         while curr:
             curr.set_params(p)
-            c = curr._run(s)
+            c = curr._run(shared)
             curr = self.get_next_node(curr, c)
-    def process(self, s, pr): raise NotImplementedError("Use Flow._process(...) instead")
+    def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")
 
 class BaseBatchFlow(BaseFlow):
-    def preprocess(self, s): return []
+    def prep(self, shared): return []
 
 class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, s):
-        pr = self.preprocess(s)
-        for d in pr:
+    def _run(self, shared):
+        prep_res = self.prep(shared)
+        for d in prep_res:
             mp = self.params.copy(); mp.update(d)
-            self._process(s, mp)
-        return self.postprocess(s, pr, None)
+            self._exec(shared, mp)
+        return self.post(shared, prep_res, None)
 
 class AsyncNode(Node):
-    def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async")
-    async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default"
-    async def run_async(self, s):
-        if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.")
-        return await self._run_async(s)
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        r = self._process(s, pr)
-        return await self.postprocess_async(s, pr, r)
-    def _run(self, s): raise RuntimeError("AsyncNode requires async execution")
+    def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async")
+    async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
+    async def run_async(self, shared):
+        if self.successors:
+            warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.")
+        return await self._run_async(shared)
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        exec_res = self._exec(shared, prep_res)
+        return await self.post_async(shared, prep_res, exec_res)
+    def _run(self, shared): raise RuntimeError("AsyncNode requires async execution")
 
 class AsyncFlow(BaseFlow, AsyncNode):
-    async def _process_async(self, s, p=None):
-        curr, p = self.start_node, (p if p else self.params.copy())
+    async def _exec_async(self, shared, params=None):
+        curr, p = self.start_node, (params if params else self.params.copy())
         while curr:
             curr.set_params(p)
-            c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s)
+            c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
             curr = self.get_next_node(curr, c)
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        await self._process_async(s)
-        return await self.postprocess_async(s, pr, None)
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        await self._exec_async(shared)
+        return await self.post_async(shared, prep_res, None)
 
 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        for d in pr:
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        for d in prep_res:
             mp = self.params.copy(); mp.update(d)
-            await self._process_async(s, mp)
-        return await self.postprocess_async(s, pr, None)
\ No newline at end of file
+            await self._exec_async(shared, mp)
+        return await self.post_async(shared, prep_res, None)
\ No newline at end of file
diff --git a/tests/test_async_batch_flow.py b/tests/test_async_batch_flow.py
index 76ddfcf..e0e911e 100644
--- a/tests/test_async_batch_flow.py
+++ b/tests/test_async_batch_flow.py
@@ -7,7 +7,7 @@ sys.path.append(str(Path(__file__).parent.parent))
 from minillmflow import AsyncNode, BatchAsyncFlow
 
 class AsyncDataProcessNode(AsyncNode):
-    def process(self, 
shared_storage, prep_result): + def exec(self, shared_storage, prep_result): key = self.params.get('key') data = shared_storage['input_data'][key] if 'results' not in shared_storage: @@ -15,14 +15,14 @@ class AsyncDataProcessNode(AsyncNode): shared_storage['results'][key] = data return data - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): await asyncio.sleep(0.01) # Simulate async work key = self.params.get('key') shared_storage['results'][key] = proc_result * 2 # Double the value return "processed" class AsyncErrorNode(AsyncNode): - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): key = self.params.get('key') if key == 'error_key': raise ValueError(f"Async error processing key: {key}") @@ -35,7 +35,7 @@ class TestAsyncBatchFlow(unittest.TestCase): def test_basic_async_batch_processing(self): """Test basic async batch processing with multiple keys""" class SimpleTestAsyncBatchFlow(BatchAsyncFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -59,7 +59,7 @@ class TestAsyncBatchFlow(unittest.TestCase): def test_empty_async_batch(self): """Test async batch processing with empty input""" class EmptyTestAsyncBatchFlow(BatchAsyncFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -74,7 +74,7 @@ class TestAsyncBatchFlow(unittest.TestCase): def test_async_error_handling(self): """Test error handling during async batch processing""" class ErrorTestAsyncBatchFlow(BatchAsyncFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -93,7 +93,7 
@@ class TestAsyncBatchFlow(unittest.TestCase): def test_nested_async_flow(self): """Test async batch processing with nested flows""" class AsyncInnerNode(AsyncNode): - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): key = self.params.get('key') if 'intermediate_results' not in shared_storage: shared_storage['intermediate_results'] = {} @@ -102,7 +102,7 @@ class TestAsyncBatchFlow(unittest.TestCase): return "next" class AsyncOuterNode(AsyncNode): - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): key = self.params.get('key') if 'results' not in shared_storage: shared_storage['results'] = {} @@ -111,7 +111,7 @@ class TestAsyncBatchFlow(unittest.TestCase): return "done" class NestedAsyncBatchFlow(BatchAsyncFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] # Create inner flow @@ -138,7 +138,7 @@ class TestAsyncBatchFlow(unittest.TestCase): def test_custom_async_parameters(self): """Test async batch processing with additional custom parameters""" class CustomParamAsyncNode(AsyncNode): - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): key = self.params.get('key') multiplier = self.params.get('multiplier', 1) await asyncio.sleep(0.01) @@ -148,7 +148,7 @@ class TestAsyncBatchFlow(unittest.TestCase): return "done" class CustomParamAsyncBatchFlow(BatchAsyncFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{ 'key': k, 'multiplier': i + 1 diff --git a/tests/test_async_flow.py b/tests/test_async_flow.py index 7811b2e..bda8ec5 100644 --- a/tests/test_async_flow.py +++ b/tests/test_async_flow.py @@ -11,19 +11,19 @@ class AsyncNumberNode(AsyncNode): 
""" Simple async node that sets 'current' to a given number. Demonstrates overriding .process() (sync) and using - postprocess_async() for the async portion. + post_async() for the async portion. """ def __init__(self, number): super().__init__() self.number = number - def process(self, shared_storage, data): + def exec(self, shared_storage, data): # Synchronous work is allowed inside an AsyncNode, - # but final 'condition' is determined by postprocess_async(). + # but final 'condition' is determined by post_async(). shared_storage['current'] = self.number return "set_number" - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): # Possibly do asynchronous tasks here await asyncio.sleep(0.01) # Return a condition for the flow @@ -34,11 +34,11 @@ class AsyncIncrementNode(AsyncNode): """ Demonstrates incrementing the 'current' value asynchronously. """ - def process(self, shared_storage, data): + def exec(self, shared_storage, data): shared_storage['current'] = shared_storage.get('current', 0) + 1 return "incremented" - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): await asyncio.sleep(0.01) # simulate async I/O return "done" @@ -105,18 +105,18 @@ class TestAsyncFlow(unittest.TestCase): """ Demonstrate a branching scenario where we return different conditions. For example, you could have an async node that - returns "go_left" or "go_right" in postprocess_async, but here + returns "go_left" or "go_right" in post_async, but here we'll keep it simpler for demonstration. 
""" class BranchingAsyncNode(AsyncNode): - def process(self, shared_storage, data): + def exec(self, shared_storage, data): value = shared_storage.get("value", 0) shared_storage["value"] = value # We'll decide branch based on whether 'value' is positive return None - async def postprocess_async(self, shared_storage, prep_result, proc_result): + async def post_async(self, shared_storage, prep_result, proc_result): await asyncio.sleep(0.01) if shared_storage["value"] >= 0: return "positive_branch" @@ -124,12 +124,12 @@ class TestAsyncFlow(unittest.TestCase): return "negative_branch" class PositiveNode(Node): - def process(self, shared_storage, data): + def exec(self, shared_storage, data): shared_storage["path"] = "positive" return None class NegativeNode(Node): - def process(self, shared_storage, data): + def exec(self, shared_storage, data): shared_storage["path"] = "negative" return None diff --git a/tests/test_batch_flow.py b/tests/test_batch_flow.py index cacade0..cd2463b 100644 --- a/tests/test_batch_flow.py +++ b/tests/test_batch_flow.py @@ -6,7 +6,7 @@ sys.path.append(str(Path(__file__).parent.parent)) from minillmflow import Node, BatchFlow, Flow class DataProcessNode(Node): - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): key = self.params.get('key') data = shared_storage['input_data'][key] if 'results' not in shared_storage: @@ -14,7 +14,7 @@ class DataProcessNode(Node): shared_storage['results'][key] = data * 2 class ErrorProcessNode(Node): - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): key = self.params.get('key') if key == 'error_key': raise ValueError(f"Error processing key: {key}") @@ -29,7 +29,7 @@ class TestBatchFlow(unittest.TestCase): def test_basic_batch_processing(self): """Test basic batch processing with multiple keys""" class SimpleTestBatchFlow(BatchFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': 
k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -53,7 +53,7 @@ class TestBatchFlow(unittest.TestCase): def test_empty_input(self): """Test batch processing with empty input dictionary""" class EmptyTestBatchFlow(BatchFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -68,7 +68,7 @@ class TestBatchFlow(unittest.TestCase): def test_single_item(self): """Test batch processing with single item""" class SingleItemBatchFlow(BatchFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -88,7 +88,7 @@ class TestBatchFlow(unittest.TestCase): def test_error_handling(self): """Test error handling during batch processing""" class ErrorTestBatchFlow(BatchFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] shared_storage = { @@ -107,21 +107,21 @@ class TestBatchFlow(unittest.TestCase): def test_nested_flow(self): """Test batch processing with nested flows""" class InnerNode(Node): - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): key = self.params.get('key') if 'intermediate_results' not in shared_storage: shared_storage['intermediate_results'] = {} shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1 class OuterNode(Node): - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): key = self.params.get('key') if 'results' not in shared_storage: shared_storage['results'] = {} shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2 class NestedBatchFlow(BatchFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{'key': k} for k in shared_storage['input_data'].keys()] # 
Create inner flow @@ -148,7 +148,7 @@ class TestBatchFlow(unittest.TestCase): def test_custom_parameters(self): """Test batch processing with additional custom parameters""" class CustomParamNode(Node): - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): key = self.params.get('key') multiplier = self.params.get('multiplier', 1) if 'results' not in shared_storage: @@ -156,7 +156,7 @@ class TestBatchFlow(unittest.TestCase): shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier class CustomParamBatchFlow(BatchFlow): - def preprocess(self, shared_storage): + def prep(self, shared_storage): return [{ 'key': k, 'multiplier': i + 1 diff --git a/tests/test_batch_node.py b/tests/test_batch_node.py index 1c861cd..8f3a145 100644 --- a/tests/test_batch_node.py +++ b/tests/test_batch_node.py @@ -10,7 +10,7 @@ class ArrayChunkNode(BatchNode): super().__init__() self.chunk_size = chunk_size - def preprocess(self, shared_storage): + def prep(self, shared_storage): # Get array from shared storage and split into chunks array = shared_storage.get('input_array', []) chunks = [] @@ -19,20 +19,20 @@ class ArrayChunkNode(BatchNode): chunks.append((i, end)) return chunks - def process(self, shared_storage, chunk_indices): + def exec(self, shared_storage, chunk_indices): start, end = chunk_indices array = shared_storage['input_array'] # Process the chunk and return its sum chunk_sum = sum(array[start:end]) return chunk_sum - def postprocess(self, shared_storage, prep_result, proc_result): + def post(self, shared_storage, prep_result, proc_result): # Store chunk results in shared storage shared_storage['chunk_results'] = proc_result return "default" class SumReduceNode(Node): - def process(self, shared_storage, data): + def exec(self, shared_storage, data): # Get chunk results from shared storage and sum them chunk_results = shared_storage.get('chunk_results', []) total = sum(chunk_results) @@ -48,7 +48,7 @@ class 
TestBatchNode(unittest.TestCase): } chunk_node = ArrayChunkNode(chunk_size=10) - chunks = chunk_node.preprocess(shared_storage) + chunks = chunk_node.prep(shared_storage) self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)]) diff --git a/tests/test_flow_basic.py b/tests/test_flow_basic.py index f852613..3f59744 100644 --- a/tests/test_flow_basic.py +++ b/tests/test_flow_basic.py @@ -10,7 +10,7 @@ class NumberNode(Node): super().__init__() self.number = number - def process(self, shared_storage, data): + def exec(self, shared_storage, data): shared_storage['current'] = self.number class AddNode(Node): @@ -18,7 +18,7 @@ class AddNode(Node): super().__init__() self.number = number - def process(self, shared_storage, data): + def exec(self, shared_storage, data): shared_storage['current'] += self.number class MultiplyNode(Node): @@ -26,18 +26,18 @@ class MultiplyNode(Node): super().__init__() self.number = number - def process(self, shared_storage, data): + def exec(self, shared_storage, data): shared_storage['current'] *= self.number class CheckPositiveNode(Node): - def postprocess(self, shared_storage, prep_result, proc_result): + def post(self, shared_storage, prep_result, proc_result): if shared_storage['current'] >= 0: return 'positive' else: return 'negative' class NoOpNode(Node): - def process(self, shared_storage, data): + def exec(self, shared_storage, data): # Do nothing, just pass pass diff --git a/tests/test_flow_composition.py b/tests/test_flow_composition.py index 0555723..bac8fa0 100644 --- a/tests/test_flow_composition.py +++ b/tests/test_flow_composition.py @@ -12,7 +12,7 @@ class NumberNode(Node): super().__init__() self.number = number - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): shared_storage['current'] = self.number class AddNode(Node): @@ -20,7 +20,7 @@ class AddNode(Node): super().__init__() self.number = number - def process(self, shared_storage, prep_result): + def exec(self, 
shared_storage, prep_result): shared_storage['current'] += self.number class MultiplyNode(Node): @@ -28,7 +28,7 @@ class MultiplyNode(Node): super().__init__() self.number = number - def process(self, shared_storage, prep_result): + def exec(self, shared_storage, prep_result): shared_storage['current'] *= self.number class TestFlowComposition(unittest.TestCase):