update the pocketflow design

This commit is contained in:
zachary62 2025-04-11 15:23:21 -04:00
parent 3c32908212
commit cf4875710b
4 changed files with 312 additions and 129 deletions

View File

@ -37,17 +37,17 @@ class BatchNode(Node):
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])] def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
class Flow(BaseNode): class Flow(BaseNode):
def __init__(self,start): super().__init__(); self.start=start def __init__(self,start=None): super().__init__(); self.start_node=start
def start(self,start): self.start=start; return start def start(self,start): self.start_node=start; return start
def get_next_node(self,curr,action): def get_next_node(self,curr,action):
nxt=curr.successors.get(action or "default") nxt=curr.successors.get(action or "default")
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return nxt return nxt
def _orch(self,shared,params=None): def _orch(self,shared,params=None):
curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
return last_action return last_action
def _run(self,shared): pr=self.prep(shared); self._orch(shared); return self.post(shared,pr,None) def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
def post(self,shared,prep_res,exec_res): return exec_res def post(self,shared,prep_res,exec_res): return exec_res
class BatchFlow(Flow): class BatchFlow(Flow):
@ -81,10 +81,10 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode):
class AsyncFlow(Flow,AsyncNode): class AsyncFlow(Flow,AsyncNode):
async def _orch_async(self,shared,params=None): async def _orch_async(self,shared,params=None):
curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
return last_action return last_action
async def _run_async(self,shared): p=await self.prep_async(shared); await self._orch_async(shared); return await self.post_async(shared,p,None) async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
async def post_async(self,shared,prep_res,exec_res): return exec_res async def post_async(self,shared,prep_res,exec_res): return exec_res
class AsyncBatchFlow(AsyncFlow,BatchFlow): class AsyncBatchFlow(AsyncFlow,BatchFlow):

View File

@ -6,7 +6,6 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, AsyncNode, AsyncFlow from pocketflow import Node, AsyncNode, AsyncFlow
class AsyncNumberNode(AsyncNode): class AsyncNumberNode(AsyncNode):
""" """
Simple async node that sets 'current' to a given number. Simple async node that sets 'current' to a given number.
@ -29,7 +28,6 @@ class AsyncNumberNode(AsyncNode):
# Return a condition for the flow # Return a condition for the flow
return "number_set" return "number_set"
class AsyncIncrementNode(AsyncNode): class AsyncIncrementNode(AsyncNode):
""" """
Demonstrates incrementing the 'current' value asynchronously. Demonstrates incrementing the 'current' value asynchronously.
@ -42,6 +40,37 @@ class AsyncIncrementNode(AsyncNode):
await asyncio.sleep(0.01) # simulate async I/O await asyncio.sleep(0.01) # simulate async I/O
return "done" return "done"
class AsyncSignalNode(AsyncNode):
""" An async node that returns a specific signal string from post_async. """
def __init__(self, signal="default_async_signal"):
super().__init__()
self.signal = signal
# No prep needed usually if just signaling
async def prep_async(self, shared_storage):
await asyncio.sleep(0.01) # Simulate async work
async def post_async(self, shared_storage, prep_result, exec_result):
# Store the signal in shared storage for verification
shared_storage['last_async_signal_emitted'] = self.signal
await asyncio.sleep(0.01) # Simulate async work
print(self.signal)
return self.signal # Return the specific action string
class AsyncPathNode(AsyncNode):
""" An async node to indicate which path was taken in the outer flow. """
def __init__(self, path_id):
super().__init__()
self.path_id = path_id
async def prep_async(self, shared_storage):
await asyncio.sleep(0.01) # Simulate async work
shared_storage['async_path_taken'] = self.path_id
# post_async implicitly returns None (for default transition out if needed)
async def post_async(self, shared_storage, prep_result, exec_result):
await asyncio.sleep(0.01)
# Return None by default
class TestAsyncNode(unittest.TestCase): class TestAsyncNode(unittest.TestCase):
""" """
@ -149,6 +178,51 @@ class TestAsyncFlow(unittest.TestCase):
self.assertEqual(shared_storage["path"], "positive", self.assertEqual(shared_storage["path"], "positive",
"Should have taken the positive branch") "Should have taken the positive branch")
def test_async_composition_with_action_propagation(self):
"""
Test AsyncFlow branches based on action from nested AsyncFlow's last node.
"""
async def run_test():
shared_storage = {}
# 1. Define an inner async flow ending with AsyncSignalNode
# Use existing AsyncNumberNode which should return None from post_async implicitly
inner_start_node = AsyncNumberNode(200)
inner_end_node = AsyncSignalNode("async_inner_done") # post_async -> "async_inner_done"
inner_start_node - "number_set" >> inner_end_node
# Inner flow will execute start->end, Flow exec returns "async_inner_done"
inner_flow = AsyncFlow(start=inner_start_node)
# 2. Define target async nodes for the outer flow branches
path_a_node = AsyncPathNode("AsyncA") # post_async -> None
path_b_node = AsyncPathNode("AsyncB") # post_async -> None
# 3. Define the outer async flow starting with the inner async flow
outer_flow = AsyncFlow(start=inner_flow)
# 4. Define branches FROM the inner_flow object based on its returned action
inner_flow - "async_inner_done" >> path_b_node # This path should be taken
inner_flow - "other_action" >> path_a_node # This path should NOT be taken
# 5. Run the outer async flow and capture the last action
# Execution: inner_start -> inner_end -> path_b
last_action_outer = await outer_flow.run_async(shared_storage)
# 6. Return results for assertion
return shared_storage, last_action_outer
# Run the async test function
shared_storage, last_action_outer = asyncio.run(run_test())
# 7. Assert the results
# Check state after inner flow execution
self.assertEqual(shared_storage.get('current'), 200) # From AsyncNumberNode
self.assertEqual(shared_storage.get('last_async_signal_emitted'), "async_inner_done")
# Check that the correct outer path was taken
self.assertEqual(shared_storage.get('async_path_taken'), "AsyncB")
# Check the action returned by the outer flow. The last node executed was
# path_b_node, which returns None from its post_async method.
self.assertIsNone(last_action_outer)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,156 +1,224 @@
# tests/test_flow_basic.py
import unittest import unittest
import sys import sys
from pathlib import Path from pathlib import Path
import warnings
sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, Flow from pocketflow import Node, Flow
# --- Node Definitions ---
# Nodes intended for default transitions (>>) should NOT return a specific
# action string from post. Let it return None by default.
# Nodes intended for conditional transitions (-) MUST return the action string.
class NumberNode(Node): class NumberNode(Node):
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def prep(self, shared_storage): def prep(self, shared_storage):
shared_storage['current'] = self.number shared_storage['current'] = self.number
# post implicitly returns None - used for default transition
class AddNode(Node): class AddNode(Node):
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def prep(self, shared_storage): def prep(self, shared_storage):
shared_storage['current'] += self.number shared_storage['current'] += self.number
# post implicitly returns None - used for default transition
class MultiplyNode(Node): class MultiplyNode(Node):
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def prep(self, shared_storage): def prep(self, shared_storage):
shared_storage['current'] *= self.number shared_storage['current'] *= self.number
# post implicitly returns None - used for default transition
class CheckPositiveNode(Node): class CheckPositiveNode(Node):
# This node IS designed for conditional branching
def prep(self, shared_storage):
pass
def post(self, shared_storage, prep_result, proc_result): def post(self, shared_storage, prep_result, proc_result):
# MUST return the specific action string for branching
if shared_storage['current'] >= 0: if shared_storage['current'] >= 0:
return 'positive' return 'positive'
else: else:
return 'negative' return 'negative'
class NoOpNode(Node): class NoOpNode(Node):
def prep(self, shared_storage): # Just a placeholder node
# Do nothing, just pass pass # post implicitly returns None
pass
class EndSignalNode(Node):
class TestNode(unittest.TestCase): # A node specifically to return a value when it's the end
def test_single_number(self): def __init__(self, signal="finished"):
super().__init__()
self.signal = signal
def post(self, shared_storage, prep_result, exec_result):
return self.signal # Return a specific signal
# --- Test Class ---
class TestFlowBasic(unittest.TestCase):
def test_start_method_initialization(self):
"""Test initializing flow with start() after creation."""
shared_storage = {} shared_storage = {}
start = NumberNode(5) n1 = NumberNode(5)
pipeline = Flow(start=start) pipeline = Flow()
pipeline.run(shared_storage) pipeline.start(n1)
last_action = pipeline.run(shared_storage)
self.assertEqual(shared_storage['current'], 5) self.assertEqual(shared_storage['current'], 5)
# NumberNode.post returns None (default)
self.assertIsNone(last_action)
def test_sequence(self): def test_start_method_chaining(self):
""" """Test fluent chaining using start().next()..."""
Test a simple linear pipeline: shared_storage = {}
NumberNode(5) -> AddNode(3) -> MultiplyNode(2) pipeline = Flow()
# Chain: NumberNode -> AddNode -> MultiplyNode
# All use default transitions (post returns None)
pipeline.start(NumberNode(5)).next(AddNode(3)).next(MultiplyNode(2))
last_action = pipeline.run(shared_storage)
self.assertEqual(shared_storage['current'], 16)
# Last node (MultiplyNode) post returns None
self.assertIsNone(last_action)
Expected result: def test_sequence_with_rshift(self):
(5 + 3) * 2 = 16 """Test a simple linear pipeline using >>"""
"""
shared_storage = {} shared_storage = {}
n1 = NumberNode(5) n1 = NumberNode(5)
n2 = AddNode(3) n2 = AddNode(3)
n3 = MultiplyNode(2) n3 = MultiplyNode(2)
# Chain them in sequence using the >> operator pipeline = Flow()
n1 >> n2 >> n3 # All default transitions (post returns None)
pipeline.start(n1) >> n2 >> n3
pipeline = Flow(start=n1)
pipeline.run(shared_storage)
last_action = pipeline.run(shared_storage)
self.assertEqual(shared_storage['current'], 16) self.assertEqual(shared_storage['current'], 16)
# Last node (n3: MultiplyNode) post returns None
self.assertIsNone(last_action)
def test_branching_positive(self): def test_branching_positive(self):
""" """Test positive branch: CheckPositiveNode returns 'positive'"""
Test a branching pipeline with positive route:
start = NumberNode(5)
check = CheckPositiveNode()
if 'positive' -> AddNode(10)
if 'negative' -> AddNode(-20)
Since we start with 5,
check returns 'positive',
so we add 10. Final result = 15.
"""
shared_storage = {} shared_storage = {}
start = NumberNode(5) start_node = NumberNode(5) # post -> None
check = CheckPositiveNode() check_node = CheckPositiveNode() # post -> 'positive' or 'negative'
add_if_positive = AddNode(10) add_if_positive = AddNode(10) # post -> None
add_if_negative = AddNode(-20) add_if_negative = AddNode(-20) # post -> None (won't run)
start >> check pipeline = Flow()
# start -> check (default); check branches on 'positive'/'negative'
pipeline.start(start_node) >> check_node
check_node - "positive" >> add_if_positive
check_node - "negative" >> add_if_negative
# Use the new dash operator for condition # Execution: start_node -> check_node -> add_if_positive
check - "positive" >> add_if_positive last_action = pipeline.run(shared_storage)
check - "negative" >> add_if_negative self.assertEqual(shared_storage['current'], 15) # 5 + 10
# Last node executed was add_if_positive, its post returns None
self.assertIsNone(last_action)
pipeline = Flow(start=start) def test_branching_negative(self):
pipeline.run(shared_storage) """Test negative branch: CheckPositiveNode returns 'negative'"""
self.assertEqual(shared_storage['current'], 15)
def test_negative_branch(self):
"""
Same branching pipeline, but starting with -5.
That should return 'negative' from CheckPositiveNode
and proceed to add_if_negative, i.e. add -20.
Final result: (-5) + (-20) = -25.
"""
shared_storage = {} shared_storage = {}
start = NumberNode(-5) start_node = NumberNode(-5) # post -> None
check = CheckPositiveNode() check_node = CheckPositiveNode() # post -> 'positive' or 'negative'
add_if_positive = AddNode(10) add_if_positive = AddNode(10) # post -> None (won't run)
add_if_negative = AddNode(-20) add_if_negative = AddNode(-20) # post -> None
# Build the flow pipeline = Flow()
start >> check pipeline.start(start_node) >> check_node
check - "positive" >> add_if_positive check_node - "positive" >> add_if_positive
check - "negative" >> add_if_negative check_node - "negative" >> add_if_negative
pipeline = Flow(start=start) # Execution: start_node -> check_node -> add_if_negative
pipeline.run(shared_storage) last_action = pipeline.run(shared_storage)
self.assertEqual(shared_storage['current'], -25) # -5 + -20
# Last node executed was add_if_negative, its post returns None
self.assertIsNone(last_action)
# Should have gone down the 'negative' branch def test_cycle_until_negative_ends_with_signal(self):
self.assertEqual(shared_storage['current'], -25) """Test cycle, ending on a node that returns a signal"""
def test_cycle_until_negative(self):
"""
Demonstrate a cyclical pipeline:
Start with 10, check if positive -> subtract 3, then go back to check.
Repeat until the number becomes negative, at which point pipeline ends.
"""
shared_storage = {} shared_storage = {}
n1 = NumberNode(10) n1 = NumberNode(10) # post -> None
check = CheckPositiveNode() check = CheckPositiveNode() # post -> 'positive' or 'negative'
subtract3 = AddNode(-3) subtract3 = AddNode(-3) # post -> None
no_op = NoOpNode() # Dummy node for the 'negative' branch end_node = EndSignalNode("cycle_done") # post -> "cycle_done"
# Build the cycle: pipeline = Flow()
# n1 -> check -> if 'positive': subtract3 -> back to check pipeline.start(n1) >> check
n1 >> check # Branching from CheckPositiveNode
check - 'positive' >> subtract3 check - 'positive' >> subtract3
subtract3 >> check check - 'negative' >> end_node # End on negative branch
# After subtracting, go back to check (default transition)
# Attach a no-op node on the negative branch to avoid warning subtract3 >> check
check - 'negative' >> no_op
pipeline = Flow(start=n1) # Execution: n1->check->sub3->check->sub3->check->sub3->check->sub3->check->end_node
pipeline.run(shared_storage) last_action = pipeline.run(shared_storage)
self.assertEqual(shared_storage['current'], -2) # 10 -> 7 -> 4 -> 1 -> -2
# Last node executed was end_node, its post returns "cycle_done"
self.assertEqual(last_action, "cycle_done")
# final result should be -2: (10 -> 7 -> 4 -> 1 -> -2) def test_flow_ends_warning_default_missing(self):
self.assertEqual(shared_storage['current'], -2) """Test warning when default transition is needed but not found"""
shared_storage = {}
# Node that returns a specific action from post
class ActionNode(Node):
def post(self, *args): return "specific_action"
start_node = ActionNode()
next_node = NoOpNode()
pipeline = Flow()
pipeline.start(start_node)
# Define successor only for the specific action
start_node - "specific_action" >> next_node
# Make start_node return None instead, triggering default search
start_node.post = lambda *args: None
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# Run flow. start_node runs, post returns None.
# Flow looks for "default", but only "specific_action" exists.
last_action = pipeline.run(shared_storage)
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
# Warning message should indicate "default" wasn't found
self.assertIn("Flow ends: 'None' not found in ['specific_action']", str(w[-1].message))
# Last action is from start_node's post
self.assertIsNone(last_action)
def test_flow_ends_warning_specific_missing(self):
"""Test warning when specific action is returned but not found"""
shared_storage = {}
# Node that returns a specific action from post
class ActionNode(Node):
def post(self, *args): return "specific_action"
start_node = ActionNode()
next_node = NoOpNode()
pipeline = Flow()
pipeline.start(start_node)
# Define successor only for "default"
start_node >> next_node # same as start_node.next(next_node, "default")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# Run flow. start_node runs, post returns "specific_action".
# Flow looks for "specific_action", but only "default" exists.
last_action = pipeline.run(shared_storage)
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
# Warning message should indicate "specific_action" wasn't found
self.assertIn("Flow ends: 'specific_action' not found in ['default']", str(w[-1].message))
# Last action is from start_node's post
self.assertEqual(last_action, "specific_action")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,38 +1,62 @@
# tests/test_flow_composition.py
import unittest import unittest
import asyncio import asyncio # Keep import, might be needed if other tests use it indirectly
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, Flow from pocketflow import Node, Flow
# Simple example Nodes # --- Existing Nodes ---
class NumberNode(Node): class NumberNode(Node):
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def prep(self, shared_storage): def prep(self, shared_storage):
shared_storage['current'] = self.number shared_storage['current'] = self.number
# post implicitly returns None
class AddNode(Node): class AddNode(Node):
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def prep(self, shared_storage): def prep(self, shared_storage):
shared_storage['current'] += self.number shared_storage['current'] += self.number
# post implicitly returns None
class MultiplyNode(Node): class MultiplyNode(Node):
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def prep(self, shared_storage): def prep(self, shared_storage):
shared_storage['current'] *= self.number shared_storage['current'] *= self.number
# post implicitly returns None
# --- New Nodes for Action Propagation Test ---
class SignalNode(Node):
"""A node that returns a specific signal string from its post method."""
def __init__(self, signal="default_signal"):
super().__init__()
self.signal = signal
# No prep needed usually if just signaling
def post(self, shared_storage, prep_result, exec_result):
# Store the signal in shared storage for verification
shared_storage['last_signal_emitted'] = self.signal
return self.signal # Return the specific action string
class PathNode(Node):
"""A node to indicate which path was taken in the outer flow."""
def __init__(self, path_id):
super().__init__()
self.path_id = path_id
def prep(self, shared_storage):
shared_storage['path_taken'] = self.path_id
# post implicitly returns None
# --- Test Class ---
class TestFlowComposition(unittest.TestCase): class TestFlowComposition(unittest.TestCase):
# --- Existing Tests (Unchanged) ---
def test_flow_as_node(self): def test_flow_as_node(self):
""" """
1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2). 1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
@ -41,18 +65,11 @@ class TestFlowComposition(unittest.TestCase):
Expected final result in shared_storage['current']: (5 + 10) * 2 = 30. Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
""" """
shared_storage = {} shared_storage = {}
# Inner flow f1
f1 = Flow(start=NumberNode(5)) f1 = Flow(start=NumberNode(5))
f1 >> AddNode(10) >> MultiplyNode(2) f1 >> AddNode(10) >> MultiplyNode(2)
# f2 starts with f1
f2 = Flow(start=f1) f2 = Flow(start=f1)
# Wrapper flow f3 to ensure proper execution
f3 = Flow(start=f2) f3 = Flow(start=f2)
f3.run(shared_storage) f3.run(shared_storage)
self.assertEqual(shared_storage['current'], 30) self.assertEqual(shared_storage['current'], 30)
def test_nested_flow(self): def test_nested_flow(self):
@ -64,19 +81,12 @@ class TestFlowComposition(unittest.TestCase):
Expected final result: (5 + 3) * 4 = 32. Expected final result: (5 + 3) * 4 = 32.
""" """
shared_storage = {} shared_storage = {}
# Build the inner flow
inner_flow = Flow(start=NumberNode(5)) inner_flow = Flow(start=NumberNode(5))
inner_flow >> AddNode(3) inner_flow >> AddNode(3)
# Build the middle flow, whose start is the inner flow
middle_flow = Flow(start=inner_flow) middle_flow = Flow(start=inner_flow)
middle_flow >> MultiplyNode(4) middle_flow >> MultiplyNode(4)
# Wrapper flow to ensure proper execution
wrapper_flow = Flow(start=middle_flow) wrapper_flow = Flow(start=middle_flow)
wrapper_flow.run(shared_storage) wrapper_flow.run(shared_storage)
self.assertEqual(shared_storage['current'], 32) self.assertEqual(shared_storage['current'], 32)
def test_flow_chaining_flows(self): def test_flow_chaining_flows(self):
@ -88,23 +98,54 @@ class TestFlowComposition(unittest.TestCase):
Expected final result: (10 + 10) * 2 = 40. Expected final result: (10 + 10) * 2 = 40.
""" """
shared_storage = {} shared_storage = {}
# flow1
numbernode = NumberNode(10) numbernode = NumberNode(10)
numbernode >> AddNode(10) numbernode >> AddNode(10)
flow1 = Flow(start=numbernode) flow1 = Flow(start=numbernode)
# flow2
flow2 = Flow(start=MultiplyNode(2)) flow2 = Flow(start=MultiplyNode(2))
flow1 >> flow2 # Default transition based on flow1 returning None
# Chain flow1 to flow2
flow1 >> flow2
# Wrapper flow to ensure proper execution
wrapper_flow = Flow(start=flow1) wrapper_flow = Flow(start=flow1)
wrapper_flow.run(shared_storage) wrapper_flow.run(shared_storage)
self.assertEqual(shared_storage['current'], 40) self.assertEqual(shared_storage['current'], 40)
def test_composition_with_action_propagation(self):
"""
Test that an outer flow can branch based on the action returned
by the last node's post() within an inner flow.
"""
shared_storage = {}
# 1. Define an inner flow that ends with a node returning a specific action
inner_start_node = NumberNode(100) # current = 100, post -> None
inner_end_node = SignalNode("inner_done") # post -> "inner_done"
inner_start_node >> inner_end_node
# Inner flow will execute start->end, and the Flow's execution will return "inner_done"
inner_flow = Flow(start=inner_start_node)
# 2. Define target nodes for the outer flow branches
path_a_node = PathNode("A") # post -> None
path_b_node = PathNode("B") # post -> None
# 3. Define the outer flow starting with the inner flow
outer_flow = Flow()
outer_flow.start(inner_flow) # Use the start() method
# 4. Define branches FROM the inner_flow object based on its returned action
inner_flow - "inner_done" >> path_b_node # This path should be taken
inner_flow - "other_action" >> path_a_node # This path should NOT be taken
# 5. Run the outer flow and capture the last action
# Execution: inner_start -> inner_end -> path_b
last_action_outer = outer_flow.run(shared_storage)
# 6. Assert the results
# Check state after inner flow execution
self.assertEqual(shared_storage.get('current'), 100)
self.assertEqual(shared_storage.get('last_signal_emitted'), "inner_done")
# Check that the correct outer path was taken
self.assertEqual(shared_storage.get('path_taken'), "B")
# Check the action returned by the outer flow. The last node executed was
# path_b_node, which returns None from its post method.
self.assertIsNone(last_action_outer)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()