update the pocketflow design
This commit is contained in:
parent
3c32908212
commit
cf4875710b
|
|
@ -37,17 +37,17 @@ class BatchNode(Node):
|
||||||
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
|
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
|
||||||
|
|
||||||
class Flow(BaseNode):
|
class Flow(BaseNode):
|
||||||
def __init__(self,start): super().__init__(); self.start=start
|
def __init__(self,start=None): super().__init__(); self.start_node=start
|
||||||
def start(self,start): self.start=start; return start
|
def start(self,start): self.start_node=start; return start
|
||||||
def get_next_node(self,curr,action):
|
def get_next_node(self,curr,action):
|
||||||
nxt=curr.successors.get(action or "default")
|
nxt=curr.successors.get(action or "default")
|
||||||
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||||
return nxt
|
return nxt
|
||||||
def _orch(self,shared,params=None):
|
def _orch(self,shared,params=None):
|
||||||
curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None
|
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
|
||||||
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
||||||
return last_action
|
return last_action
|
||||||
def _run(self,shared): pr=self.prep(shared); self._orch(shared); return self.post(shared,pr,None)
|
def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
|
||||||
def post(self,shared,prep_res,exec_res): return exec_res
|
def post(self,shared,prep_res,exec_res): return exec_res
|
||||||
|
|
||||||
class BatchFlow(Flow):
|
class BatchFlow(Flow):
|
||||||
|
|
@ -81,10 +81,10 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode):
|
||||||
|
|
||||||
class AsyncFlow(Flow,AsyncNode):
|
class AsyncFlow(Flow,AsyncNode):
|
||||||
async def _orch_async(self,shared,params=None):
|
async def _orch_async(self,shared,params=None):
|
||||||
curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None
|
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
|
||||||
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
||||||
return last_action
|
return last_action
|
||||||
async def _run_async(self,shared): p=await self.prep_async(shared); await self._orch_async(shared); return await self.post_async(shared,p,None)
|
async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
|
||||||
async def post_async(self,shared,prep_res,exec_res): return exec_res
|
async def post_async(self,shared,prep_res,exec_res): return exec_res
|
||||||
|
|
||||||
class AsyncBatchFlow(AsyncFlow,BatchFlow):
|
class AsyncBatchFlow(AsyncFlow,BatchFlow):
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from pathlib import Path
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
from pocketflow import Node, AsyncNode, AsyncFlow
|
from pocketflow import Node, AsyncNode, AsyncFlow
|
||||||
|
|
||||||
|
|
||||||
class AsyncNumberNode(AsyncNode):
|
class AsyncNumberNode(AsyncNode):
|
||||||
"""
|
"""
|
||||||
Simple async node that sets 'current' to a given number.
|
Simple async node that sets 'current' to a given number.
|
||||||
|
|
@ -29,7 +28,6 @@ class AsyncNumberNode(AsyncNode):
|
||||||
# Return a condition for the flow
|
# Return a condition for the flow
|
||||||
return "number_set"
|
return "number_set"
|
||||||
|
|
||||||
|
|
||||||
class AsyncIncrementNode(AsyncNode):
|
class AsyncIncrementNode(AsyncNode):
|
||||||
"""
|
"""
|
||||||
Demonstrates incrementing the 'current' value asynchronously.
|
Demonstrates incrementing the 'current' value asynchronously.
|
||||||
|
|
@ -42,6 +40,37 @@ class AsyncIncrementNode(AsyncNode):
|
||||||
await asyncio.sleep(0.01) # simulate async I/O
|
await asyncio.sleep(0.01) # simulate async I/O
|
||||||
return "done"
|
return "done"
|
||||||
|
|
||||||
|
class AsyncSignalNode(AsyncNode):
|
||||||
|
""" An async node that returns a specific signal string from post_async. """
|
||||||
|
def __init__(self, signal="default_async_signal"):
|
||||||
|
super().__init__()
|
||||||
|
self.signal = signal
|
||||||
|
|
||||||
|
# No prep needed usually if just signaling
|
||||||
|
async def prep_async(self, shared_storage):
|
||||||
|
await asyncio.sleep(0.01) # Simulate async work
|
||||||
|
|
||||||
|
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||||
|
# Store the signal in shared storage for verification
|
||||||
|
shared_storage['last_async_signal_emitted'] = self.signal
|
||||||
|
await asyncio.sleep(0.01) # Simulate async work
|
||||||
|
print(self.signal)
|
||||||
|
return self.signal # Return the specific action string
|
||||||
|
|
||||||
|
class AsyncPathNode(AsyncNode):
|
||||||
|
""" An async node to indicate which path was taken in the outer flow. """
|
||||||
|
def __init__(self, path_id):
|
||||||
|
super().__init__()
|
||||||
|
self.path_id = path_id
|
||||||
|
|
||||||
|
async def prep_async(self, shared_storage):
|
||||||
|
await asyncio.sleep(0.01) # Simulate async work
|
||||||
|
shared_storage['async_path_taken'] = self.path_id
|
||||||
|
|
||||||
|
# post_async implicitly returns None (for default transition out if needed)
|
||||||
|
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# Return None by default
|
||||||
|
|
||||||
class TestAsyncNode(unittest.TestCase):
|
class TestAsyncNode(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
@ -149,6 +178,51 @@ class TestAsyncFlow(unittest.TestCase):
|
||||||
self.assertEqual(shared_storage["path"], "positive",
|
self.assertEqual(shared_storage["path"], "positive",
|
||||||
"Should have taken the positive branch")
|
"Should have taken the positive branch")
|
||||||
|
|
||||||
|
def test_async_composition_with_action_propagation(self):
|
||||||
|
"""
|
||||||
|
Test AsyncFlow branches based on action from nested AsyncFlow's last node.
|
||||||
|
"""
|
||||||
|
async def run_test():
|
||||||
|
shared_storage = {}
|
||||||
|
|
||||||
|
# 1. Define an inner async flow ending with AsyncSignalNode
|
||||||
|
# Use existing AsyncNumberNode which should return None from post_async implicitly
|
||||||
|
inner_start_node = AsyncNumberNode(200)
|
||||||
|
inner_end_node = AsyncSignalNode("async_inner_done") # post_async -> "async_inner_done"
|
||||||
|
inner_start_node - "number_set" >> inner_end_node
|
||||||
|
# Inner flow will execute start->end, Flow exec returns "async_inner_done"
|
||||||
|
inner_flow = AsyncFlow(start=inner_start_node)
|
||||||
|
|
||||||
|
# 2. Define target async nodes for the outer flow branches
|
||||||
|
path_a_node = AsyncPathNode("AsyncA") # post_async -> None
|
||||||
|
path_b_node = AsyncPathNode("AsyncB") # post_async -> None
|
||||||
|
|
||||||
|
# 3. Define the outer async flow starting with the inner async flow
|
||||||
|
outer_flow = AsyncFlow(start=inner_flow)
|
||||||
|
|
||||||
|
# 4. Define branches FROM the inner_flow object based on its returned action
|
||||||
|
inner_flow - "async_inner_done" >> path_b_node # This path should be taken
|
||||||
|
inner_flow - "other_action" >> path_a_node # This path should NOT be taken
|
||||||
|
|
||||||
|
# 5. Run the outer async flow and capture the last action
|
||||||
|
# Execution: inner_start -> inner_end -> path_b
|
||||||
|
last_action_outer = await outer_flow.run_async(shared_storage)
|
||||||
|
|
||||||
|
# 6. Return results for assertion
|
||||||
|
return shared_storage, last_action_outer
|
||||||
|
|
||||||
|
# Run the async test function
|
||||||
|
shared_storage, last_action_outer = asyncio.run(run_test())
|
||||||
|
|
||||||
|
# 7. Assert the results
|
||||||
|
# Check state after inner flow execution
|
||||||
|
self.assertEqual(shared_storage.get('current'), 200) # From AsyncNumberNode
|
||||||
|
self.assertEqual(shared_storage.get('last_async_signal_emitted'), "async_inner_done")
|
||||||
|
# Check that the correct outer path was taken
|
||||||
|
self.assertEqual(shared_storage.get('async_path_taken'), "AsyncB")
|
||||||
|
# Check the action returned by the outer flow. The last node executed was
|
||||||
|
# path_b_node, which returns None from its post_async method.
|
||||||
|
self.assertIsNone(last_action_outer)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -1,156 +1,224 @@
|
||||||
|
# tests/test_flow_basic.py
|
||||||
import unittest
|
import unittest
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
from pocketflow import Node, Flow
|
from pocketflow import Node, Flow
|
||||||
|
|
||||||
|
# --- Node Definitions ---
|
||||||
|
# Nodes intended for default transitions (>>) should NOT return a specific
|
||||||
|
# action string from post. Let it return None by default.
|
||||||
|
# Nodes intended for conditional transitions (-) MUST return the action string.
|
||||||
|
|
||||||
class NumberNode(Node):
|
class NumberNode(Node):
|
||||||
def __init__(self, number):
|
def __init__(self, number):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
|
|
||||||
def prep(self, shared_storage):
|
def prep(self, shared_storage):
|
||||||
shared_storage['current'] = self.number
|
shared_storage['current'] = self.number
|
||||||
|
# post implicitly returns None - used for default transition
|
||||||
|
|
||||||
class AddNode(Node):
|
class AddNode(Node):
|
||||||
def __init__(self, number):
|
def __init__(self, number):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
|
|
||||||
def prep(self, shared_storage):
|
def prep(self, shared_storage):
|
||||||
shared_storage['current'] += self.number
|
shared_storage['current'] += self.number
|
||||||
|
# post implicitly returns None - used for default transition
|
||||||
|
|
||||||
class MultiplyNode(Node):
|
class MultiplyNode(Node):
|
||||||
def __init__(self, number):
|
def __init__(self, number):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
|
|
||||||
def prep(self, shared_storage):
|
def prep(self, shared_storage):
|
||||||
shared_storage['current'] *= self.number
|
shared_storage['current'] *= self.number
|
||||||
|
# post implicitly returns None - used for default transition
|
||||||
|
|
||||||
class CheckPositiveNode(Node):
|
class CheckPositiveNode(Node):
|
||||||
|
# This node IS designed for conditional branching
|
||||||
|
def prep(self, shared_storage):
|
||||||
|
pass
|
||||||
def post(self, shared_storage, prep_result, proc_result):
|
def post(self, shared_storage, prep_result, proc_result):
|
||||||
|
# MUST return the specific action string for branching
|
||||||
if shared_storage['current'] >= 0:
|
if shared_storage['current'] >= 0:
|
||||||
return 'positive'
|
return 'positive'
|
||||||
else:
|
else:
|
||||||
return 'negative'
|
return 'negative'
|
||||||
|
|
||||||
class NoOpNode(Node):
|
class NoOpNode(Node):
|
||||||
def prep(self, shared_storage):
|
# Just a placeholder node
|
||||||
# Do nothing, just pass
|
pass # post implicitly returns None
|
||||||
pass
|
|
||||||
|
class EndSignalNode(Node):
|
||||||
class TestNode(unittest.TestCase):
|
# A node specifically to return a value when it's the end
|
||||||
def test_single_number(self):
|
def __init__(self, signal="finished"):
|
||||||
|
super().__init__()
|
||||||
|
self.signal = signal
|
||||||
|
def post(self, shared_storage, prep_result, exec_result):
|
||||||
|
return self.signal # Return a specific signal
|
||||||
|
|
||||||
|
# --- Test Class ---
|
||||||
|
class TestFlowBasic(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_start_method_initialization(self):
|
||||||
|
"""Test initializing flow with start() after creation."""
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
start = NumberNode(5)
|
n1 = NumberNode(5)
|
||||||
pipeline = Flow(start=start)
|
pipeline = Flow()
|
||||||
pipeline.run(shared_storage)
|
pipeline.start(n1)
|
||||||
|
last_action = pipeline.run(shared_storage)
|
||||||
self.assertEqual(shared_storage['current'], 5)
|
self.assertEqual(shared_storage['current'], 5)
|
||||||
|
# NumberNode.post returns None (default)
|
||||||
|
self.assertIsNone(last_action)
|
||||||
|
|
||||||
def test_sequence(self):
|
def test_start_method_chaining(self):
|
||||||
"""
|
"""Test fluent chaining using start().next()..."""
|
||||||
Test a simple linear pipeline:
|
shared_storage = {}
|
||||||
NumberNode(5) -> AddNode(3) -> MultiplyNode(2)
|
pipeline = Flow()
|
||||||
|
# Chain: NumberNode -> AddNode -> MultiplyNode
|
||||||
|
# All use default transitions (post returns None)
|
||||||
|
pipeline.start(NumberNode(5)).next(AddNode(3)).next(MultiplyNode(2))
|
||||||
|
last_action = pipeline.run(shared_storage)
|
||||||
|
self.assertEqual(shared_storage['current'], 16)
|
||||||
|
# Last node (MultiplyNode) post returns None
|
||||||
|
self.assertIsNone(last_action)
|
||||||
|
|
||||||
Expected result:
|
def test_sequence_with_rshift(self):
|
||||||
(5 + 3) * 2 = 16
|
"""Test a simple linear pipeline using >>"""
|
||||||
"""
|
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
n1 = NumberNode(5)
|
n1 = NumberNode(5)
|
||||||
n2 = AddNode(3)
|
n2 = AddNode(3)
|
||||||
n3 = MultiplyNode(2)
|
n3 = MultiplyNode(2)
|
||||||
|
|
||||||
# Chain them in sequence using the >> operator
|
pipeline = Flow()
|
||||||
n1 >> n2 >> n3
|
# All default transitions (post returns None)
|
||||||
|
pipeline.start(n1) >> n2 >> n3
|
||||||
pipeline = Flow(start=n1)
|
|
||||||
pipeline.run(shared_storage)
|
|
||||||
|
|
||||||
|
last_action = pipeline.run(shared_storage)
|
||||||
self.assertEqual(shared_storage['current'], 16)
|
self.assertEqual(shared_storage['current'], 16)
|
||||||
|
# Last node (n3: MultiplyNode) post returns None
|
||||||
|
self.assertIsNone(last_action)
|
||||||
|
|
||||||
def test_branching_positive(self):
|
def test_branching_positive(self):
|
||||||
"""
|
"""Test positive branch: CheckPositiveNode returns 'positive'"""
|
||||||
Test a branching pipeline with positive route:
|
|
||||||
start = NumberNode(5)
|
|
||||||
check = CheckPositiveNode()
|
|
||||||
if 'positive' -> AddNode(10)
|
|
||||||
if 'negative' -> AddNode(-20)
|
|
||||||
|
|
||||||
Since we start with 5,
|
|
||||||
check returns 'positive',
|
|
||||||
so we add 10. Final result = 15.
|
|
||||||
"""
|
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
start = NumberNode(5)
|
start_node = NumberNode(5) # post -> None
|
||||||
check = CheckPositiveNode()
|
check_node = CheckPositiveNode() # post -> 'positive' or 'negative'
|
||||||
add_if_positive = AddNode(10)
|
add_if_positive = AddNode(10) # post -> None
|
||||||
add_if_negative = AddNode(-20)
|
add_if_negative = AddNode(-20) # post -> None (won't run)
|
||||||
|
|
||||||
start >> check
|
pipeline = Flow()
|
||||||
|
# start -> check (default); check branches on 'positive'/'negative'
|
||||||
|
pipeline.start(start_node) >> check_node
|
||||||
|
check_node - "positive" >> add_if_positive
|
||||||
|
check_node - "negative" >> add_if_negative
|
||||||
|
|
||||||
# Use the new dash operator for condition
|
# Execution: start_node -> check_node -> add_if_positive
|
||||||
check - "positive" >> add_if_positive
|
last_action = pipeline.run(shared_storage)
|
||||||
check - "negative" >> add_if_negative
|
self.assertEqual(shared_storage['current'], 15) # 5 + 10
|
||||||
|
# Last node executed was add_if_positive, its post returns None
|
||||||
|
self.assertIsNone(last_action)
|
||||||
|
|
||||||
pipeline = Flow(start=start)
|
def test_branching_negative(self):
|
||||||
pipeline.run(shared_storage)
|
"""Test negative branch: CheckPositiveNode returns 'negative'"""
|
||||||
|
|
||||||
self.assertEqual(shared_storage['current'], 15)
|
|
||||||
|
|
||||||
def test_negative_branch(self):
|
|
||||||
"""
|
|
||||||
Same branching pipeline, but starting with -5.
|
|
||||||
That should return 'negative' from CheckPositiveNode
|
|
||||||
and proceed to add_if_negative, i.e. add -20.
|
|
||||||
|
|
||||||
Final result: (-5) + (-20) = -25.
|
|
||||||
"""
|
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
start = NumberNode(-5)
|
start_node = NumberNode(-5) # post -> None
|
||||||
check = CheckPositiveNode()
|
check_node = CheckPositiveNode() # post -> 'positive' or 'negative'
|
||||||
add_if_positive = AddNode(10)
|
add_if_positive = AddNode(10) # post -> None (won't run)
|
||||||
add_if_negative = AddNode(-20)
|
add_if_negative = AddNode(-20) # post -> None
|
||||||
|
|
||||||
# Build the flow
|
pipeline = Flow()
|
||||||
start >> check
|
pipeline.start(start_node) >> check_node
|
||||||
check - "positive" >> add_if_positive
|
check_node - "positive" >> add_if_positive
|
||||||
check - "negative" >> add_if_negative
|
check_node - "negative" >> add_if_negative
|
||||||
|
|
||||||
pipeline = Flow(start=start)
|
# Execution: start_node -> check_node -> add_if_negative
|
||||||
pipeline.run(shared_storage)
|
last_action = pipeline.run(shared_storage)
|
||||||
|
self.assertEqual(shared_storage['current'], -25) # -5 + -20
|
||||||
|
# Last node executed was add_if_negative, its post returns None
|
||||||
|
self.assertIsNone(last_action)
|
||||||
|
|
||||||
# Should have gone down the 'negative' branch
|
def test_cycle_until_negative_ends_with_signal(self):
|
||||||
self.assertEqual(shared_storage['current'], -25)
|
"""Test cycle, ending on a node that returns a signal"""
|
||||||
|
|
||||||
def test_cycle_until_negative(self):
|
|
||||||
"""
|
|
||||||
Demonstrate a cyclical pipeline:
|
|
||||||
Start with 10, check if positive -> subtract 3, then go back to check.
|
|
||||||
Repeat until the number becomes negative, at which point pipeline ends.
|
|
||||||
"""
|
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
n1 = NumberNode(10)
|
n1 = NumberNode(10) # post -> None
|
||||||
check = CheckPositiveNode()
|
check = CheckPositiveNode() # post -> 'positive' or 'negative'
|
||||||
subtract3 = AddNode(-3)
|
subtract3 = AddNode(-3) # post -> None
|
||||||
no_op = NoOpNode() # Dummy node for the 'negative' branch
|
end_node = EndSignalNode("cycle_done") # post -> "cycle_done"
|
||||||
|
|
||||||
# Build the cycle:
|
pipeline = Flow()
|
||||||
# n1 -> check -> if 'positive': subtract3 -> back to check
|
pipeline.start(n1) >> check
|
||||||
n1 >> check
|
# Branching from CheckPositiveNode
|
||||||
check - 'positive' >> subtract3
|
check - 'positive' >> subtract3
|
||||||
subtract3 >> check
|
check - 'negative' >> end_node # End on negative branch
|
||||||
|
# After subtracting, go back to check (default transition)
|
||||||
# Attach a no-op node on the negative branch to avoid warning
|
subtract3 >> check
|
||||||
check - 'negative' >> no_op
|
|
||||||
|
|
||||||
pipeline = Flow(start=n1)
|
# Execution: n1->check->sub3->check->sub3->check->sub3->check->sub3->check->end_node
|
||||||
pipeline.run(shared_storage)
|
last_action = pipeline.run(shared_storage)
|
||||||
|
self.assertEqual(shared_storage['current'], -2) # 10 -> 7 -> 4 -> 1 -> -2
|
||||||
|
# Last node executed was end_node, its post returns "cycle_done"
|
||||||
|
self.assertEqual(last_action, "cycle_done")
|
||||||
|
|
||||||
# final result should be -2: (10 -> 7 -> 4 -> 1 -> -2)
|
def test_flow_ends_warning_default_missing(self):
|
||||||
self.assertEqual(shared_storage['current'], -2)
|
"""Test warning when default transition is needed but not found"""
|
||||||
|
shared_storage = {}
|
||||||
|
# Node that returns a specific action from post
|
||||||
|
class ActionNode(Node):
|
||||||
|
def post(self, *args): return "specific_action"
|
||||||
|
start_node = ActionNode()
|
||||||
|
next_node = NoOpNode()
|
||||||
|
|
||||||
|
pipeline = Flow()
|
||||||
|
pipeline.start(start_node)
|
||||||
|
# Define successor only for the specific action
|
||||||
|
start_node - "specific_action" >> next_node
|
||||||
|
|
||||||
|
# Make start_node return None instead, triggering default search
|
||||||
|
start_node.post = lambda *args: None
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
# Run flow. start_node runs, post returns None.
|
||||||
|
# Flow looks for "default", but only "specific_action" exists.
|
||||||
|
last_action = pipeline.run(shared_storage)
|
||||||
|
|
||||||
|
self.assertEqual(len(w), 1)
|
||||||
|
self.assertTrue(issubclass(w[-1].category, UserWarning))
|
||||||
|
# Warning message should indicate "default" wasn't found
|
||||||
|
self.assertIn("Flow ends: 'None' not found in ['specific_action']", str(w[-1].message))
|
||||||
|
# Last action is from start_node's post
|
||||||
|
self.assertIsNone(last_action)
|
||||||
|
|
||||||
|
def test_flow_ends_warning_specific_missing(self):
|
||||||
|
"""Test warning when specific action is returned but not found"""
|
||||||
|
shared_storage = {}
|
||||||
|
# Node that returns a specific action from post
|
||||||
|
class ActionNode(Node):
|
||||||
|
def post(self, *args): return "specific_action"
|
||||||
|
start_node = ActionNode()
|
||||||
|
next_node = NoOpNode()
|
||||||
|
|
||||||
|
pipeline = Flow()
|
||||||
|
pipeline.start(start_node)
|
||||||
|
# Define successor only for "default"
|
||||||
|
start_node >> next_node # same as start_node.next(next_node, "default")
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
# Run flow. start_node runs, post returns "specific_action".
|
||||||
|
# Flow looks for "specific_action", but only "default" exists.
|
||||||
|
last_action = pipeline.run(shared_storage)
|
||||||
|
|
||||||
|
self.assertEqual(len(w), 1)
|
||||||
|
self.assertTrue(issubclass(w[-1].category, UserWarning))
|
||||||
|
# Warning message should indicate "specific_action" wasn't found
|
||||||
|
self.assertIn("Flow ends: 'specific_action' not found in ['default']", str(w[-1].message))
|
||||||
|
# Last action is from start_node's post
|
||||||
|
self.assertEqual(last_action, "specific_action")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
@ -1,38 +1,62 @@
|
||||||
|
# tests/test_flow_composition.py
|
||||||
import unittest
|
import unittest
|
||||||
import asyncio
|
import asyncio # Keep import, might be needed if other tests use it indirectly
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
from pocketflow import Node, Flow
|
from pocketflow import Node, Flow
|
||||||
|
|
||||||
# Simple example Nodes
|
# --- Existing Nodes ---
|
||||||
class NumberNode(Node):
|
class NumberNode(Node):
|
||||||
def __init__(self, number):
|
def __init__(self, number):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
|
|
||||||
def prep(self, shared_storage):
|
def prep(self, shared_storage):
|
||||||
shared_storage['current'] = self.number
|
shared_storage['current'] = self.number
|
||||||
|
# post implicitly returns None
|
||||||
|
|
||||||
class AddNode(Node):
|
class AddNode(Node):
|
||||||
def __init__(self, number):
|
def __init__(self, number):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
|
|
||||||
def prep(self, shared_storage):
|
def prep(self, shared_storage):
|
||||||
shared_storage['current'] += self.number
|
shared_storage['current'] += self.number
|
||||||
|
# post implicitly returns None
|
||||||
|
|
||||||
class MultiplyNode(Node):
|
class MultiplyNode(Node):
|
||||||
def __init__(self, number):
|
def __init__(self, number):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
|
|
||||||
def prep(self, shared_storage):
|
def prep(self, shared_storage):
|
||||||
shared_storage['current'] *= self.number
|
shared_storage['current'] *= self.number
|
||||||
|
# post implicitly returns None
|
||||||
|
|
||||||
|
# --- New Nodes for Action Propagation Test ---
|
||||||
|
class SignalNode(Node):
|
||||||
|
"""A node that returns a specific signal string from its post method."""
|
||||||
|
def __init__(self, signal="default_signal"):
|
||||||
|
super().__init__()
|
||||||
|
self.signal = signal
|
||||||
|
# No prep needed usually if just signaling
|
||||||
|
def post(self, shared_storage, prep_result, exec_result):
|
||||||
|
# Store the signal in shared storage for verification
|
||||||
|
shared_storage['last_signal_emitted'] = self.signal
|
||||||
|
return self.signal # Return the specific action string
|
||||||
|
|
||||||
|
class PathNode(Node):
|
||||||
|
"""A node to indicate which path was taken in the outer flow."""
|
||||||
|
def __init__(self, path_id):
|
||||||
|
super().__init__()
|
||||||
|
self.path_id = path_id
|
||||||
|
def prep(self, shared_storage):
|
||||||
|
shared_storage['path_taken'] = self.path_id
|
||||||
|
# post implicitly returns None
|
||||||
|
|
||||||
|
# --- Test Class ---
|
||||||
class TestFlowComposition(unittest.TestCase):
|
class TestFlowComposition(unittest.TestCase):
|
||||||
|
|
||||||
|
# --- Existing Tests (Unchanged) ---
|
||||||
def test_flow_as_node(self):
|
def test_flow_as_node(self):
|
||||||
"""
|
"""
|
||||||
1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
|
1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
|
||||||
|
|
@ -41,18 +65,11 @@ class TestFlowComposition(unittest.TestCase):
|
||||||
Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
|
Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
|
||||||
"""
|
"""
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
|
|
||||||
# Inner flow f1
|
|
||||||
f1 = Flow(start=NumberNode(5))
|
f1 = Flow(start=NumberNode(5))
|
||||||
f1 >> AddNode(10) >> MultiplyNode(2)
|
f1 >> AddNode(10) >> MultiplyNode(2)
|
||||||
|
|
||||||
# f2 starts with f1
|
|
||||||
f2 = Flow(start=f1)
|
f2 = Flow(start=f1)
|
||||||
|
|
||||||
# Wrapper flow f3 to ensure proper execution
|
|
||||||
f3 = Flow(start=f2)
|
f3 = Flow(start=f2)
|
||||||
f3.run(shared_storage)
|
f3.run(shared_storage)
|
||||||
|
|
||||||
self.assertEqual(shared_storage['current'], 30)
|
self.assertEqual(shared_storage['current'], 30)
|
||||||
|
|
||||||
def test_nested_flow(self):
|
def test_nested_flow(self):
|
||||||
|
|
@ -64,19 +81,12 @@ class TestFlowComposition(unittest.TestCase):
|
||||||
Expected final result: (5 + 3) * 4 = 32.
|
Expected final result: (5 + 3) * 4 = 32.
|
||||||
"""
|
"""
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
|
|
||||||
# Build the inner flow
|
|
||||||
inner_flow = Flow(start=NumberNode(5))
|
inner_flow = Flow(start=NumberNode(5))
|
||||||
inner_flow >> AddNode(3)
|
inner_flow >> AddNode(3)
|
||||||
|
|
||||||
# Build the middle flow, whose start is the inner flow
|
|
||||||
middle_flow = Flow(start=inner_flow)
|
middle_flow = Flow(start=inner_flow)
|
||||||
middle_flow >> MultiplyNode(4)
|
middle_flow >> MultiplyNode(4)
|
||||||
|
|
||||||
# Wrapper flow to ensure proper execution
|
|
||||||
wrapper_flow = Flow(start=middle_flow)
|
wrapper_flow = Flow(start=middle_flow)
|
||||||
wrapper_flow.run(shared_storage)
|
wrapper_flow.run(shared_storage)
|
||||||
|
|
||||||
self.assertEqual(shared_storage['current'], 32)
|
self.assertEqual(shared_storage['current'], 32)
|
||||||
|
|
||||||
def test_flow_chaining_flows(self):
|
def test_flow_chaining_flows(self):
|
||||||
|
|
@ -88,23 +98,54 @@ class TestFlowComposition(unittest.TestCase):
|
||||||
Expected final result: (10 + 10) * 2 = 40.
|
Expected final result: (10 + 10) * 2 = 40.
|
||||||
"""
|
"""
|
||||||
shared_storage = {}
|
shared_storage = {}
|
||||||
|
|
||||||
# flow1
|
|
||||||
numbernode = NumberNode(10)
|
numbernode = NumberNode(10)
|
||||||
numbernode >> AddNode(10)
|
numbernode >> AddNode(10)
|
||||||
flow1 = Flow(start=numbernode)
|
flow1 = Flow(start=numbernode)
|
||||||
|
|
||||||
# flow2
|
|
||||||
flow2 = Flow(start=MultiplyNode(2))
|
flow2 = Flow(start=MultiplyNode(2))
|
||||||
|
flow1 >> flow2 # Default transition based on flow1 returning None
|
||||||
# Chain flow1 to flow2
|
|
||||||
flow1 >> flow2
|
|
||||||
|
|
||||||
# Wrapper flow to ensure proper execution
|
|
||||||
wrapper_flow = Flow(start=flow1)
|
wrapper_flow = Flow(start=flow1)
|
||||||
wrapper_flow.run(shared_storage)
|
wrapper_flow.run(shared_storage)
|
||||||
|
|
||||||
self.assertEqual(shared_storage['current'], 40)
|
self.assertEqual(shared_storage['current'], 40)
|
||||||
|
|
||||||
|
def test_composition_with_action_propagation(self):
|
||||||
|
"""
|
||||||
|
Test that an outer flow can branch based on the action returned
|
||||||
|
by the last node's post() within an inner flow.
|
||||||
|
"""
|
||||||
|
shared_storage = {}
|
||||||
|
|
||||||
|
# 1. Define an inner flow that ends with a node returning a specific action
|
||||||
|
inner_start_node = NumberNode(100) # current = 100, post -> None
|
||||||
|
inner_end_node = SignalNode("inner_done") # post -> "inner_done"
|
||||||
|
inner_start_node >> inner_end_node
|
||||||
|
# Inner flow will execute start->end, and the Flow's execution will return "inner_done"
|
||||||
|
inner_flow = Flow(start=inner_start_node)
|
||||||
|
|
||||||
|
# 2. Define target nodes for the outer flow branches
|
||||||
|
path_a_node = PathNode("A") # post -> None
|
||||||
|
path_b_node = PathNode("B") # post -> None
|
||||||
|
|
||||||
|
# 3. Define the outer flow starting with the inner flow
|
||||||
|
outer_flow = Flow()
|
||||||
|
outer_flow.start(inner_flow) # Use the start() method
|
||||||
|
|
||||||
|
# 4. Define branches FROM the inner_flow object based on its returned action
|
||||||
|
inner_flow - "inner_done" >> path_b_node # This path should be taken
|
||||||
|
inner_flow - "other_action" >> path_a_node # This path should NOT be taken
|
||||||
|
|
||||||
|
# 5. Run the outer flow and capture the last action
|
||||||
|
# Execution: inner_start -> inner_end -> path_b
|
||||||
|
last_action_outer = outer_flow.run(shared_storage)
|
||||||
|
|
||||||
|
# 6. Assert the results
|
||||||
|
# Check state after inner flow execution
|
||||||
|
self.assertEqual(shared_storage.get('current'), 100)
|
||||||
|
self.assertEqual(shared_storage.get('last_signal_emitted'), "inner_done")
|
||||||
|
# Check that the correct outer path was taken
|
||||||
|
self.assertEqual(shared_storage.get('path_taken'), "B")
|
||||||
|
# Check the action returned by the outer flow. The last node executed was
|
||||||
|
# path_b_node, which returns None from its post method.
|
||||||
|
self.assertIsNone(last_action_outer)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
Loading…
Reference in New Issue