# Source listing: pocketflow/tests/test_flow_basic.py (224 lines, 8.8 KiB, Python)

# tests/test_flow_basic.py
import unittest
import sys
from pathlib import Path
import warnings
sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, Flow
# --- Node Definitions ---
# Nodes used with default transitions (a >> b) should NOT return an action
# string from post; returning None (the default) selects the "default" edge.
# Nodes used with conditional transitions (a - "action" >> b) MUST return
# the matching action string from post.
class NumberNode(Node):
    """Seed ``shared_storage['current']`` with a fixed number.

    ``post`` is not overridden, so it returns None and the flow follows the
    default (``>>``) transition.
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # Overwrites any previous value rather than accumulating into it.
        seed = self.number
        shared_storage['current'] = seed
class AddNode(Node):
    """Add ``number`` to ``shared_storage['current']``.

    ``post`` is not overridden, so it returns None and the flow follows the
    default (``>>``) transition.
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # Requires 'current' to have been set by an earlier node.
        delta = self.number
        shared_storage['current'] += delta
class MultiplyNode(Node):
    """Multiply ``shared_storage['current']`` by ``number``.

    ``post`` is not overridden, so it returns None and the flow follows the
    default (``>>``) transition.
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # Requires 'current' to have been set by an earlier node.
        factor = self.number
        shared_storage['current'] *= factor
class CheckPositiveNode(Node):
    """Branching node that picks its successor from the sign of 'current'.

    ``post`` returns ``'positive'`` when ``shared_storage['current'] >= 0``
    and ``'negative'`` otherwise. Successors must be registered with
    ``node - 'positive' >> ...`` / ``node - 'negative' >> ...``.
    """

    def prep(self, shared_storage):
        """Nothing to prepare; the decision is made in post."""

    def post(self, shared_storage, prep_result, proc_result):
        # Zero counts as positive.
        return 'positive' if shared_storage['current'] >= 0 else 'negative'
class NoOpNode(Node):
    """Placeholder node that overrides none of the Node hooks.

    Its post therefore returns None, selecting the default transition.
    """
class EndSignalNode(Node):
    """Node whose post returns a configurable action string.

    Placed at the end of a flow so a test can assert on the value that
    ``Flow.run`` hands back.

    Args:
        signal: Action string returned from ``post`` (default "finished").
    """

    def __init__(self, signal="finished"):
        super().__init__()
        self.signal = signal

    def post(self, shared_storage, prep_result, exec_result):
        emitted = self.signal
        return emitted
# --- Test Class ---
class TestFlowBasic(unittest.TestCase):
    """Basic Flow behaviour: linear chains, branching, cycles and dead ends."""

    def _run_capturing_user_warnings(self, pipeline, shared_storage):
        """Run ``pipeline`` and return ``(last_action, user_warnings)``.

        Every warning is recorded (``simplefilter("always")``), but only
        ``UserWarning`` instances are returned. Unrelated warnings emitted
        during the run (e.g. deprecations from other modules) therefore
        cannot break the count-based assertions in the dead-end tests.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            last_action = pipeline.run(shared_storage)
        user_warnings = [w for w in caught if issubclass(w.category, UserWarning)]
        return last_action, user_warnings

    def test_start_method_initialization(self):
        """Test initializing flow with start() after creation."""
        shared_storage = {}
        n1 = NumberNode(5)
        pipeline = Flow()
        pipeline.start(n1)
        last_action = pipeline.run(shared_storage)
        self.assertEqual(shared_storage['current'], 5)
        # NumberNode.post returns None (default)
        self.assertIsNone(last_action)

    def test_start_method_chaining(self):
        """Test fluent chaining using start().next()..."""
        shared_storage = {}
        pipeline = Flow()
        # Chain: NumberNode -> AddNode -> MultiplyNode
        # All use default transitions (post returns None)
        pipeline.start(NumberNode(5)).next(AddNode(3)).next(MultiplyNode(2))
        last_action = pipeline.run(shared_storage)
        self.assertEqual(shared_storage['current'], 16)  # (5 + 3) * 2
        # Last node (MultiplyNode) post returns None
        self.assertIsNone(last_action)

    def test_sequence_with_rshift(self):
        """Test a simple linear pipeline using >>"""
        shared_storage = {}
        n1 = NumberNode(5)
        n2 = AddNode(3)
        n3 = MultiplyNode(2)
        pipeline = Flow()
        # All default transitions (post returns None)
        pipeline.start(n1) >> n2 >> n3
        last_action = pipeline.run(shared_storage)
        self.assertEqual(shared_storage['current'], 16)  # (5 + 3) * 2
        # Last node (n3: MultiplyNode) post returns None
        self.assertIsNone(last_action)

    def test_branching_positive(self):
        """Test positive branch: CheckPositiveNode returns 'positive'"""
        shared_storage = {}
        start_node = NumberNode(5)        # post -> None
        check_node = CheckPositiveNode()  # post -> 'positive' or 'negative'
        add_if_positive = AddNode(10)     # post -> None
        add_if_negative = AddNode(-20)    # post -> None (won't run)
        pipeline = Flow()
        # start -> check (default); check branches on 'positive'/'negative'
        pipeline.start(start_node) >> check_node
        check_node - "positive" >> add_if_positive
        check_node - "negative" >> add_if_negative
        # Execution: start_node -> check_node -> add_if_positive
        last_action = pipeline.run(shared_storage)
        self.assertEqual(shared_storage['current'], 15)  # 5 + 10
        # Last node executed was add_if_positive, its post returns None
        self.assertIsNone(last_action)

    def test_branching_negative(self):
        """Test negative branch: CheckPositiveNode returns 'negative'"""
        shared_storage = {}
        start_node = NumberNode(-5)       # post -> None
        check_node = CheckPositiveNode()  # post -> 'positive' or 'negative'
        add_if_positive = AddNode(10)     # post -> None (won't run)
        add_if_negative = AddNode(-20)    # post -> None
        pipeline = Flow()
        pipeline.start(start_node) >> check_node
        check_node - "positive" >> add_if_positive
        check_node - "negative" >> add_if_negative
        # Execution: start_node -> check_node -> add_if_negative
        last_action = pipeline.run(shared_storage)
        self.assertEqual(shared_storage['current'], -25)  # -5 + -20
        # Last node executed was add_if_negative, its post returns None
        self.assertIsNone(last_action)

    def test_cycle_until_negative_ends_with_signal(self):
        """Test cycle, ending on a node that returns a signal"""
        shared_storage = {}
        n1 = NumberNode(10)                     # post -> None
        check = CheckPositiveNode()             # post -> 'positive' or 'negative'
        subtract3 = AddNode(-3)                 # post -> None
        end_node = EndSignalNode("cycle_done")  # post -> "cycle_done"
        pipeline = Flow()
        pipeline.start(n1) >> check
        # Branching from CheckPositiveNode
        check - 'positive' >> subtract3
        check - 'negative' >> end_node  # End on negative branch
        # After subtracting, go back to check (default transition)
        subtract3 >> check
        # Execution: n1->check->sub3->check->sub3->check->sub3->check->sub3->check->end_node
        last_action = pipeline.run(shared_storage)
        self.assertEqual(shared_storage['current'], -2)  # 10 -> 7 -> 4 -> 1 -> -2
        # Last node executed was end_node, its post returns "cycle_done"
        self.assertEqual(last_action, "cycle_done")

    def test_flow_ends_warning_default_missing(self):
        """Test warning when default transition is needed but not found."""
        shared_storage = {}
        # NoOpNode.post returns None, so the flow needs a "default" successor,
        # but only "specific_action" is registered.
        start_node = NoOpNode()
        next_node = NoOpNode()
        pipeline = Flow()
        pipeline.start(start_node)
        start_node - "specific_action" >> next_node

        last_action, user_warnings = self._run_capturing_user_warnings(
            pipeline, shared_storage
        )

        self.assertEqual(len(user_warnings), 1)
        # The message reports the raw action returned by post (None), not
        # the "default" key it was resolved to.
        self.assertIn(
            "Flow ends: 'None' not found in ['specific_action']",
            str(user_warnings[-1].message),
        )
        # Last action is from start_node's post
        self.assertIsNone(last_action)

    def test_flow_ends_warning_specific_missing(self):
        """Test warning when specific action is returned but not found."""
        shared_storage = {}

        class ActionNode(Node):
            # post always selects the "specific_action" edge.
            def post(self, *args):
                return "specific_action"

        start_node = ActionNode()
        next_node = NoOpNode()
        pipeline = Flow()
        pipeline.start(start_node)
        # Define successor only for "default"
        start_node >> next_node  # same as start_node.next(next_node, "default")

        last_action, user_warnings = self._run_capturing_user_warnings(
            pipeline, shared_storage
        )

        # Exactly one UserWarning (the helper already filters by category).
        self.assertEqual(len(user_warnings), 1)
        self.assertIn(
            "Flow ends: 'specific_action' not found in ['default']",
            str(user_warnings[-1].message),
        )
        # Last action is from start_node's post
        self.assertEqual(last_action, "specific_action")
if __name__ == '__main__':
    # Allow running this file directly as well as via a test runner.
    unittest.main()