From cf4875710ba48ac1a9d298429a29c79bf23c1eda Mon Sep 17 00:00:00 2001 From: zachary62 Date: Fri, 11 Apr 2025 15:23:21 -0400 Subject: [PATCH] update the pocketflow design --- pocketflow/__init__.py | 12 +- tests/test_async_flow.py | 78 +++++++++- tests/test_flow_basic.py | 250 +++++++++++++++++++++------------ tests/test_flow_composition.py | 101 +++++++++---- 4 files changed, 312 insertions(+), 129 deletions(-) diff --git a/pocketflow/__init__.py b/pocketflow/__init__.py index 6f4f533..a7203df 100644 --- a/pocketflow/__init__.py +++ b/pocketflow/__init__.py @@ -37,17 +37,17 @@ class BatchNode(Node): def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])] class Flow(BaseNode): - def __init__(self,start): super().__init__(); self.start=start - def start(self,start): self.start=start; return start + def __init__(self,start=None): super().__init__(); self.start_node=start + def start(self,start): self.start_node=start; return start def get_next_node(self,curr,action): nxt=curr.successors.get(action or "default") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") return nxt def _orch(self,shared,params=None): - curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None + curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) return last_action - def _run(self,shared): pr=self.prep(shared); self._orch(shared); return self.post(shared,pr,None) + def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o) def post(self,shared,prep_res,exec_res): return exec_res class BatchFlow(Flow): @@ -81,10 +81,10 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode): class AsyncFlow(Flow,AsyncNode): async def _orch_async(self,shared,params=None): - curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None + curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) return last_action - async def _run_async(self,shared): p=await self.prep_async(shared); await self._orch_async(shared); return await self.post_async(shared,p,None) + async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o) async def post_async(self,shared,prep_res,exec_res): return exec_res class AsyncBatchFlow(AsyncFlow,BatchFlow): diff --git a/tests/test_async_flow.py b/tests/test_async_flow.py index db6bdf3..38a9fbe 100644 --- a/tests/test_async_flow.py +++ b/tests/test_async_flow.py @@ -6,7 +6,6 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from pocketflow import Node, AsyncNode, AsyncFlow - class AsyncNumberNode(AsyncNode): """ Simple async node that sets 'current' to a given number. @@ -29,7 +28,6 @@ class AsyncNumberNode(AsyncNode): # Return a condition for the flow return "number_set" - class AsyncIncrementNode(AsyncNode): """ Demonstrates incrementing the 'current' value asynchronously. @@ -42,6 +40,37 @@ class AsyncIncrementNode(AsyncNode): await asyncio.sleep(0.01) # simulate async I/O return "done" +class AsyncSignalNode(AsyncNode): + """ An async node that returns a specific signal string from post_async. """ + def __init__(self, signal="default_async_signal"): + super().__init__() + self.signal = signal + + # No prep needed usually if just signaling + async def prep_async(self, shared_storage): + await asyncio.sleep(0.01) # Simulate async work + + async def post_async(self, shared_storage, prep_result, exec_result): + # Store the signal in shared storage for verification + shared_storage['last_async_signal_emitted'] = self.signal + await asyncio.sleep(0.01) # Simulate async work + print(self.signal) + return self.signal # Return the specific action string + +class AsyncPathNode(AsyncNode): + """ An async node to indicate which path was taken in the outer flow. """ + def __init__(self, path_id): + super().__init__() + self.path_id = path_id + + async def prep_async(self, shared_storage): + await asyncio.sleep(0.01) # Simulate async work + shared_storage['async_path_taken'] = self.path_id + + # post_async implicitly returns None (for default transition out if needed) + async def post_async(self, shared_storage, prep_result, exec_result): + await asyncio.sleep(0.01) + # Return None by default class TestAsyncNode(unittest.TestCase): """ @@ -149,6 +178,51 @@ class TestAsyncFlow(unittest.TestCase): self.assertEqual(shared_storage["path"], "positive", "Should have taken the positive branch") + def test_async_composition_with_action_propagation(self): + """ + Test AsyncFlow branches based on action from nested AsyncFlow's last node. + """ + async def run_test(): + shared_storage = {} + + # 1. Define an inner async flow ending with AsyncSignalNode + # Use existing AsyncNumberNode which should return None from post_async implicitly + inner_start_node = AsyncNumberNode(200) + inner_end_node = AsyncSignalNode("async_inner_done") # post_async -> "async_inner_done" + inner_start_node - "number_set" >> inner_end_node + # Inner flow will execute start->end, Flow exec returns "async_inner_done" + inner_flow = AsyncFlow(start=inner_start_node) + + # 2. Define target async nodes for the outer flow branches + path_a_node = AsyncPathNode("AsyncA") # post_async -> None + path_b_node = AsyncPathNode("AsyncB") # post_async -> None + + # 3. Define the outer async flow starting with the inner async flow + outer_flow = AsyncFlow(start=inner_flow) + + # 4. Define branches FROM the inner_flow object based on its returned action + inner_flow - "async_inner_done" >> path_b_node # This path should be taken + inner_flow - "other_action" >> path_a_node # This path should NOT be taken + + # 5. Run the outer async flow and capture the last action + # Execution: inner_start -> inner_end -> path_b + last_action_outer = await outer_flow.run_async(shared_storage) + + # 6. Return results for assertion + return shared_storage, last_action_outer + + # Run the async test function + shared_storage, last_action_outer = asyncio.run(run_test()) + + # 7. Assert the results + # Check state after inner flow execution + self.assertEqual(shared_storage.get('current'), 200) # From AsyncNumberNode + self.assertEqual(shared_storage.get('last_async_signal_emitted'), "async_inner_done") + # Check that the correct outer path was taken + self.assertEqual(shared_storage.get('async_path_taken'), "AsyncB") + # Check the action returned by the outer flow. The last node executed was + # path_b_node, which returns None from its post_async method. + self.assertIsNone(last_action_outer) if __name__ == '__main__': unittest.main() diff --git a/tests/test_flow_basic.py b/tests/test_flow_basic.py index e62c543..e7b6ddd 100644 --- a/tests/test_flow_basic.py +++ b/tests/test_flow_basic.py @@ -1,156 +1,224 @@ +# tests/test_flow_basic.py import unittest import sys from pathlib import Path +import warnings sys.path.insert(0, str(Path(__file__).parent.parent)) from pocketflow import Node, Flow +# --- Node Definitions --- +# Nodes intended for default transitions (>>) should NOT return a specific +# action string from post. Let it return None by default. +# Nodes intended for conditional transitions (-) MUST return the action string. + class NumberNode(Node): def __init__(self, number): super().__init__() self.number = number - def prep(self, shared_storage): shared_storage['current'] = self.number + # post implicitly returns None - used for default transition class AddNode(Node): def __init__(self, number): super().__init__() self.number = number - def prep(self, shared_storage): shared_storage['current'] += self.number + # post implicitly returns None - used for default transition class MultiplyNode(Node): def __init__(self, number): super().__init__() self.number = number - def prep(self, shared_storage): shared_storage['current'] *= self.number + # post implicitly returns None - used for default transition class CheckPositiveNode(Node): + # This node IS designed for conditional branching + def prep(self, shared_storage): + pass def post(self, shared_storage, prep_result, proc_result): + # MUST return the specific action string for branching if shared_storage['current'] >= 0: return 'positive' else: return 'negative' class NoOpNode(Node): - def prep(self, shared_storage): - # Do nothing, just pass - pass - -class TestNode(unittest.TestCase): - def test_single_number(self): + # Just a placeholder node + pass # post implicitly returns None + +class EndSignalNode(Node): + # A node specifically to return a value when it's the end + def __init__(self, signal="finished"): + super().__init__() + self.signal = signal + def post(self, shared_storage, prep_result, exec_result): + return self.signal # Return a specific signal + +# --- Test Class --- +class TestFlowBasic(unittest.TestCase): + + def test_start_method_initialization(self): + """Test initializing flow with start() after creation.""" shared_storage = {} - start = NumberNode(5) - pipeline = Flow(start=start) - pipeline.run(shared_storage) + n1 = NumberNode(5) + pipeline = Flow() + pipeline.start(n1) + last_action = pipeline.run(shared_storage) self.assertEqual(shared_storage['current'], 5) + # NumberNode.post returns None (default) + self.assertIsNone(last_action) - def test_sequence(self): - """ - Test a simple linear pipeline: - NumberNode(5) -> AddNode(3) -> MultiplyNode(2) + def test_start_method_chaining(self): + """Test fluent chaining using start().next()...""" + shared_storage = {} + pipeline = Flow() + # Chain: NumberNode -> AddNode -> MultiplyNode + # All use default transitions (post returns None) + pipeline.start(NumberNode(5)).next(AddNode(3)).next(MultiplyNode(2)) + last_action = pipeline.run(shared_storage) + self.assertEqual(shared_storage['current'], 16) + # Last node (MultiplyNode) post returns None + self.assertIsNone(last_action) - Expected result: - (5 + 3) * 2 = 16 - """ + def test_sequence_with_rshift(self): + """Test a simple linear pipeline using >>""" shared_storage = {} n1 = NumberNode(5) n2 = AddNode(3) n3 = MultiplyNode(2) - # Chain them in sequence using the >> operator - n1 >> n2 >> n3 - - pipeline = Flow(start=n1) - pipeline.run(shared_storage) + pipeline = Flow() + # All default transitions (post returns None) + pipeline.start(n1) >> n2 >> n3 + last_action = pipeline.run(shared_storage) self.assertEqual(shared_storage['current'], 16) + # Last node (n3: MultiplyNode) post returns None + self.assertIsNone(last_action) def test_branching_positive(self): - """ - Test a branching pipeline with positive route: - start = NumberNode(5) - check = CheckPositiveNode() - if 'positive' -> AddNode(10) - if 'negative' -> AddNode(-20) - - Since we start with 5, - check returns 'positive', - so we add 10. Final result = 15. - """ + """Test positive branch: CheckPositiveNode returns 'positive'""" shared_storage = {} - start = NumberNode(5) - check = CheckPositiveNode() - add_if_positive = AddNode(10) - add_if_negative = AddNode(-20) + start_node = NumberNode(5) # post -> None + check_node = CheckPositiveNode() # post -> 'positive' or 'negative' + add_if_positive = AddNode(10) # post -> None + add_if_negative = AddNode(-20) # post -> None (won't run) - start >> check + pipeline = Flow() + # start -> check (default); check branches on 'positive'/'negative' + pipeline.start(start_node) >> check_node + check_node - "positive" >> add_if_positive + check_node - "negative" >> add_if_negative - # Use the new dash operator for condition - check - "positive" >> add_if_positive - check - "negative" >> add_if_negative + # Execution: start_node -> check_node -> add_if_positive + last_action = pipeline.run(shared_storage) + self.assertEqual(shared_storage['current'], 15) # 5 + 10 + # Last node executed was add_if_positive, its post returns None + self.assertIsNone(last_action) - pipeline = Flow(start=start) - pipeline.run(shared_storage) - - self.assertEqual(shared_storage['current'], 15) - - def test_negative_branch(self): - """ - Same branching pipeline, but starting with -5. - That should return 'negative' from CheckPositiveNode - and proceed to add_if_negative, i.e. add -20. - - Final result: (-5) + (-20) = -25. - """ + def test_branching_negative(self): + """Test negative branch: CheckPositiveNode returns 'negative'""" shared_storage = {} - start = NumberNode(-5) - check = CheckPositiveNode() - add_if_positive = AddNode(10) - add_if_negative = AddNode(-20) + start_node = NumberNode(-5) # post -> None + check_node = CheckPositiveNode() # post -> 'positive' or 'negative' + add_if_positive = AddNode(10) # post -> None (won't run) + add_if_negative = AddNode(-20) # post -> None - # Build the flow - start >> check - check - "positive" >> add_if_positive - check - "negative" >> add_if_negative + pipeline = Flow() + pipeline.start(start_node) >> check_node + check_node - "positive" >> add_if_positive + check_node - "negative" >> add_if_negative - pipeline = Flow(start=start) - pipeline.run(shared_storage) + # Execution: start_node -> check_node -> add_if_negative + last_action = pipeline.run(shared_storage) + self.assertEqual(shared_storage['current'], -25) # -5 + -20 + # Last node executed was add_if_negative, its post returns None + self.assertIsNone(last_action) - # Should have gone down the 'negative' branch - self.assertEqual(shared_storage['current'], -25) - - def test_cycle_until_negative(self): - """ - Demonstrate a cyclical pipeline: - Start with 10, check if positive -> subtract 3, then go back to check. - Repeat until the number becomes negative, at which point pipeline ends. - """ + def test_cycle_until_negative_ends_with_signal(self): + """Test cycle, ending on a node that returns a signal""" shared_storage = {} - n1 = NumberNode(10) - check = CheckPositiveNode() - subtract3 = AddNode(-3) - no_op = NoOpNode() # Dummy node for the 'negative' branch + n1 = NumberNode(10) # post -> None + check = CheckPositiveNode() # post -> 'positive' or 'negative' + subtract3 = AddNode(-3) # post -> None + end_node = EndSignalNode("cycle_done") # post -> "cycle_done" - # Build the cycle: - # n1 -> check -> if 'positive': subtract3 -> back to check - n1 >> check + pipeline = Flow() + pipeline.start(n1) >> check + # Branching from CheckPositiveNode check - 'positive' >> subtract3 - subtract3 >> check - - # Attach a no-op node on the negative branch to avoid warning - check - 'negative' >> no_op + check - 'negative' >> end_node # End on negative branch + # After subtracting, go back to check (default transition) + subtract3 >> check - pipeline = Flow(start=n1) - pipeline.run(shared_storage) + # Execution: n1->check->sub3->check->sub3->check->sub3->check->sub3->check->end_node + last_action = pipeline.run(shared_storage) + self.assertEqual(shared_storage['current'], -2) # 10 -> 7 -> 4 -> 1 -> -2 + # Last node executed was end_node, its post returns "cycle_done" + self.assertEqual(last_action, "cycle_done") - # final result should be -2: (10 -> 7 -> 4 -> 1 -> -2) - self.assertEqual(shared_storage['current'], -2) + def test_flow_ends_warning_default_missing(self): + """Test warning when default transition is needed but not found""" + shared_storage = {} + # Node that returns a specific action from post + class ActionNode(Node): + def post(self, *args): return "specific_action" + start_node = ActionNode() + next_node = NoOpNode() + + pipeline = Flow() + pipeline.start(start_node) + # Define successor only for the specific action + start_node - "specific_action" >> next_node + + # Make start_node return None instead, triggering default search + start_node.post = lambda *args: None + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + # Run flow. start_node runs, post returns None. + # Flow looks for "default", but only "specific_action" exists. + last_action = pipeline.run(shared_storage) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, UserWarning)) + # Warning message should indicate "default" wasn't found + self.assertIn("Flow ends: 'None' not found in ['specific_action']", str(w[-1].message)) + # Last action is from start_node's post + self.assertIsNone(last_action) + + def test_flow_ends_warning_specific_missing(self): + """Test warning when specific action is returned but not found""" + shared_storage = {} + # Node that returns a specific action from post + class ActionNode(Node): + def post(self, *args): return "specific_action" + start_node = ActionNode() + next_node = NoOpNode() + + pipeline = Flow() + pipeline.start(start_node) + # Define successor only for "default" + start_node >> next_node # same as start_node.next(next_node, "default") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + # Run flow. start_node runs, post returns "specific_action". + # Flow looks for "specific_action", but only "default" exists. + last_action = pipeline.run(shared_storage) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, UserWarning)) + # Warning message should indicate "specific_action" wasn't found + self.assertIn("Flow ends: 'specific_action' not found in ['default']", str(w[-1].message)) + # Last action is from start_node's post + self.assertEqual(last_action, "specific_action") if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/test_flow_composition.py b/tests/test_flow_composition.py index 277e719..544bd08 100644 --- a/tests/test_flow_composition.py +++ b/tests/test_flow_composition.py @@ -1,38 +1,62 @@ +# tests/test_flow_composition.py import unittest -import asyncio +import asyncio # Keep import, might be needed if other tests use it indirectly import sys from pathlib import Path - sys.path.insert(0, str(Path(__file__).parent.parent)) from pocketflow import Node, Flow -# Simple example Nodes +# --- Existing Nodes --- class NumberNode(Node): def __init__(self, number): super().__init__() self.number = number - def prep(self, shared_storage): shared_storage['current'] = self.number + # post implicitly returns None class AddNode(Node): def __init__(self, number): super().__init__() self.number = number - def prep(self, shared_storage): shared_storage['current'] += self.number + # post implicitly returns None class MultiplyNode(Node): def __init__(self, number): super().__init__() self.number = number - def prep(self, shared_storage): shared_storage['current'] *= self.number + # post implicitly returns None +# --- New Nodes for Action Propagation Test --- +class SignalNode(Node): + """A node that returns a specific signal string from its post method.""" + def __init__(self, signal="default_signal"): + super().__init__() + self.signal = signal + # No prep needed usually if just signaling + def post(self, shared_storage, prep_result, exec_result): + # Store the signal in shared storage for verification + shared_storage['last_signal_emitted'] = self.signal + return self.signal # Return the specific action string + +class PathNode(Node): + """A node to indicate which path was taken in the outer flow.""" + def __init__(self, path_id): + super().__init__() + self.path_id = path_id + def prep(self, shared_storage): + shared_storage['path_taken'] = self.path_id + # post implicitly returns None + +# --- Test Class --- class TestFlowComposition(unittest.TestCase): + + # --- Existing Tests (Unchanged) --- def test_flow_as_node(self): """ 1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2). @@ -41,18 +65,11 @@ class TestFlowComposition(unittest.TestCase): Expected final result in shared_storage['current']: (5 + 10) * 2 = 30. """ shared_storage = {} - - # Inner flow f1 f1 = Flow(start=NumberNode(5)) f1 >> AddNode(10) >> MultiplyNode(2) - - # f2 starts with f1 f2 = Flow(start=f1) - - # Wrapper flow f3 to ensure proper execution f3 = Flow(start=f2) f3.run(shared_storage) - self.assertEqual(shared_storage['current'], 30) def test_nested_flow(self): @@ -64,19 +81,12 @@ class TestFlowComposition(unittest.TestCase): Expected final result: (5 + 3) * 4 = 32. """ shared_storage = {} - - # Build the inner flow inner_flow = Flow(start=NumberNode(5)) inner_flow >> AddNode(3) - - # Build the middle flow, whose start is the inner flow middle_flow = Flow(start=inner_flow) middle_flow >> MultiplyNode(4) - - # Wrapper flow to ensure proper execution wrapper_flow = Flow(start=middle_flow) wrapper_flow.run(shared_storage) - self.assertEqual(shared_storage['current'], 32) def test_flow_chaining_flows(self): @@ -88,23 +98,54 @@ class TestFlowComposition(unittest.TestCase): Expected final result: (10 + 10) * 2 = 40. """ shared_storage = {} - - # flow1 numbernode = NumberNode(10) numbernode >> AddNode(10) flow1 = Flow(start=numbernode) - - # flow2 flow2 = Flow(start=MultiplyNode(2)) - - # Chain flow1 to flow2 - flow1 >> flow2 - - # Wrapper flow to ensure proper execution + flow1 >> flow2 # Default transition based on flow1 returning None wrapper_flow = Flow(start=flow1) wrapper_flow.run(shared_storage) - self.assertEqual(shared_storage['current'], 40) + def test_composition_with_action_propagation(self): + """ + Test that an outer flow can branch based on the action returned + by the last node's post() within an inner flow. + """ + shared_storage = {} + + # 1. Define an inner flow that ends with a node returning a specific action + inner_start_node = NumberNode(100) # current = 100, post -> None + inner_end_node = SignalNode("inner_done") # post -> "inner_done" + inner_start_node >> inner_end_node + # Inner flow will execute start->end, and the Flow's execution will return "inner_done" + inner_flow = Flow(start=inner_start_node) + + # 2. Define target nodes for the outer flow branches + path_a_node = PathNode("A") # post -> None + path_b_node = PathNode("B") # post -> None + + # 3. Define the outer flow starting with the inner flow + outer_flow = Flow() + outer_flow.start(inner_flow) # Use the start() method + + # 4. Define branches FROM the inner_flow object based on its returned action + inner_flow - "inner_done" >> path_b_node # This path should be taken + inner_flow - "other_action" >> path_a_node # This path should NOT be taken + + # 5. Run the outer flow and capture the last action + # Execution: inner_start -> inner_end -> path_b + last_action_outer = outer_flow.run(shared_storage) + + # 6. Assert the results + # Check state after inner flow execution + self.assertEqual(shared_storage.get('current'), 100) + self.assertEqual(shared_storage.get('last_signal_emitted'), "inner_done") + # Check that the correct outer path was taken + self.assertEqual(shared_storage.get('path_taken'), "B") + # Check the action returned by the outer flow. The last node executed was + # path_b_node, which returns None from its post method. + self.assertIsNone(last_action_outer) + if __name__ == '__main__': unittest.main() \ No newline at end of file