# tests/test_flow_basic.py
"""Basic wiring tests for pocketflow ``Flow``.

Covers start-node registration, default (``>>`` / ``.next()``) transitions,
conditional (``node - "action" >> target``) transitions, a cycle, and the
warning raised when a flow stops because no successor matches the action.
"""
import sys
import unittest
import warnings
from pathlib import Path

# Make the package importable when this file is run from inside tests/.
sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, Flow


# --- Node Definitions ---
# Convention used throughout this file:
#   * nodes chained with a default transition (>>) do not override post,
#     so the action they report is None;
#   * nodes used for conditional transitions (-) return the action string
#     from post explicitly.

class NumberNode(Node):
    """Store a fixed number under ``shared_storage['current']`` during prep."""

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # post is not overridden: this node is meant for default transitions.
        shared_storage['current'] = self.number


class AddNode(Node):
    """Add a fixed number to ``shared_storage['current']`` during prep."""

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # post is not overridden: this node is meant for default transitions.
        shared_storage['current'] += self.number


class MultiplyNode(Node):
    """Multiply ``shared_storage['current']`` by a fixed number during prep."""

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # post is not overridden: this node is meant for default transitions.
        shared_storage['current'] *= self.number


class CheckPositiveNode(Node):
    """Branching node: reports 'positive' or 'negative' from post."""

    def prep(self, shared_storage):
        pass

    def post(self, shared_storage, prep_result, proc_result):
        # Zero is treated as positive (>= 0).
        return 'positive' if shared_storage['current'] >= 0 else 'negative'


class NoOpNode(Node):
    """Placeholder node that overrides nothing."""


class EndSignalNode(Node):
    """Terminal node whose post returns a configurable signal string."""

    def __init__(self, signal="finished"):
        super().__init__()
        self.signal = signal

    def post(self, shared_storage, prep_result, exec_result):
        return self.signal


# --- Test Class ---

class TestFlowBasic(unittest.TestCase):
    """Exercise Flow construction, transitions and end-of-flow warnings."""

    @staticmethod
    def _run_branching_flow(initial):
        """Run NumberNode(initial) >> CheckPositiveNode with two branches.

        The 'positive' branch adds 10 and the 'negative' branch adds -20.
        Returns a ``(shared, last_action)`` tuple, where ``last_action`` is
        whatever ``Flow.run`` returned.
        """
        shared = {}
        origin = NumberNode(initial)
        checker = CheckPositiveNode()
        on_positive = AddNode(10)
        on_negative = AddNode(-20)

        flow = Flow()
        # origin -> checker by default; checker then branches on its action.
        flow.start(origin) >> checker
        checker - "positive" >> on_positive
        checker - "negative" >> on_negative

        return shared, flow.run(shared)

    def test_start_method_initialization(self):
        """Test initializing flow with start() after creation."""
        shared = {}
        only_node = NumberNode(5)
        flow = Flow()
        flow.start(only_node)

        result = flow.run(shared)

        self.assertEqual(shared['current'], 5)
        # The single node does not override post, so no action is reported.
        self.assertIsNone(result)

    def test_start_method_chaining(self):
        """Test fluent chaining using start().next()..."""
        shared = {}
        flow = Flow()
        first, second, third = NumberNode(5), AddNode(3), MultiplyNode(2)

        # NumberNode -> AddNode -> MultiplyNode, all default transitions.
        flow.start(first).next(second).next(third)

        result = flow.run(shared)

        self.assertEqual(shared['current'], 16)  # (5 + 3) * 2
        self.assertIsNone(result)

    def test_sequence_with_rshift(self):
        """Test a simple linear pipeline using >>"""
        shared = {}
        number = NumberNode(5)
        add = AddNode(3)
        multiply = MultiplyNode(2)

        flow = Flow()
        flow.start(number) >> add >> multiply

        result = flow.run(shared)

        self.assertEqual(shared['current'], 16)  # (5 + 3) * 2
        self.assertIsNone(result)

    def test_branching_positive(self):
        """Test positive branch: CheckPositiveNode returns 'positive'"""
        shared, result = self._run_branching_flow(5)

        self.assertEqual(shared['current'], 15)  # 5 + 10
        # The last executed node is an AddNode, which reports no action.
        self.assertIsNone(result)

    def test_branching_negative(self):
        """Test negative branch: CheckPositiveNode returns 'negative'"""
        shared, result = self._run_branching_flow(-5)

        self.assertEqual(shared['current'], -25)  # -5 + -20
        # The last executed node is an AddNode, which reports no action.
        self.assertIsNone(result)

    def test_cycle_until_negative_ends_with_signal(self):
        """Test cycle, ending on a node that returns a signal"""
        shared = {}
        seed = NumberNode(10)
        checker = CheckPositiveNode()
        minus_three = AddNode(-3)
        finisher = EndSignalNode("cycle_done")

        flow = Flow()
        flow.start(seed) >> checker
        checker - 'positive' >> minus_three
        checker - 'negative' >> finisher
        # Loop back to the check after every subtraction.
        minus_three >> checker

        result = flow.run(shared)

        # 10 -> 7 -> 4 -> 1 -> -2, then the negative branch ends the flow.
        self.assertEqual(shared['current'], -2)
        self.assertEqual(result, "cycle_done")

    def test_flow_ends_warning_default_missing(self):
        """Test warning when default transition is needed but not found"""
        shared = {}

        class ActionNode(Node):
            def post(self, *args):
                return "specific_action"

        def _no_action(*args):
            return None

        emitter = ActionNode()
        follower = NoOpNode()
        flow = Flow()
        flow.start(emitter)

        # Only the "specific_action" successor is registered ...
        emitter - "specific_action" >> follower
        # ... but post is replaced so the node reports None instead.
        emitter.post = _no_action

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            result = flow.run(shared)

            self.assertEqual(len(caught), 1)
            latest = caught[-1]
            self.assertTrue(issubclass(latest.category, UserWarning))
            # The message quotes the returned action ('None') together with
            # the successor keys that were actually registered.
            self.assertIn(
                "Flow ends: 'None' not found in ['specific_action']",
                str(latest.message),
            )

        # The reported action is the one returned by the patched post.
        self.assertIsNone(result)

    def test_flow_ends_warning_specific_missing(self):
        """Test warning when specific action is returned but not found"""
        shared = {}

        class ActionNode(Node):
            def post(self, *args):
                return "specific_action"

        emitter = ActionNode()
        follower = NoOpNode()
        flow = Flow()
        flow.start(emitter)

        # Only the default successor is registered; post reports
        # "specific_action", which has no matching successor.
        emitter >> follower

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            result = flow.run(shared)

            self.assertEqual(len(caught), 1)
            latest = caught[-1]
            self.assertTrue(issubclass(latest.category, UserWarning))
            # The message quotes the unmatched action and the registered keys.
            self.assertIn(
                "Flow ends: 'specific_action' not found in ['default']",
                str(latest.message),
            )

        # The reported action is the one returned by ActionNode.post.
        self.assertEqual(result, "specific_action")


if __name__ == '__main__':
    unittest.main()