update the pocketflow design
This commit is contained in:
parent
3c32908212
commit
cf4875710b
|
|
@ -37,17 +37,17 @@ class BatchNode(Node):
|
|||
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
|
||||
|
||||
class Flow(BaseNode):
|
||||
def __init__(self,start): super().__init__(); self.start=start
|
||||
def start(self,start): self.start=start; return start
|
||||
def __init__(self,start=None): super().__init__(); self.start_node=start
|
||||
def start(self,start): self.start_node=start; return start
|
||||
def get_next_node(self,curr,action):
|
||||
nxt=curr.successors.get(action or "default")
|
||||
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||
return nxt
|
||||
def _orch(self,shared,params=None):
|
||||
curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None
|
||||
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
|
||||
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
||||
return last_action
|
||||
def _run(self,shared): pr=self.prep(shared); self._orch(shared); return self.post(shared,pr,None)
|
||||
def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
|
||||
def post(self,shared,prep_res,exec_res): return exec_res
|
||||
|
||||
class BatchFlow(Flow):
|
||||
|
|
@ -81,10 +81,10 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode):
|
|||
|
||||
class AsyncFlow(Flow,AsyncNode):
|
||||
async def _orch_async(self,shared,params=None):
|
||||
curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None
|
||||
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
|
||||
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
||||
return last_action
|
||||
async def _run_async(self,shared): p=await self.prep_async(shared); await self._orch_async(shared); return await self.post_async(shared,p,None)
|
||||
async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
|
||||
async def post_async(self,shared,prep_res,exec_res): return exec_res
|
||||
|
||||
class AsyncBatchFlow(AsyncFlow,BatchFlow):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from pathlib import Path
|
|||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from pocketflow import Node, AsyncNode, AsyncFlow
|
||||
|
||||
|
||||
class AsyncNumberNode(AsyncNode):
|
||||
"""
|
||||
Simple async node that sets 'current' to a given number.
|
||||
|
|
@ -29,7 +28,6 @@ class AsyncNumberNode(AsyncNode):
|
|||
# Return a condition for the flow
|
||||
return "number_set"
|
||||
|
||||
|
||||
class AsyncIncrementNode(AsyncNode):
|
||||
"""
|
||||
Demonstrates incrementing the 'current' value asynchronously.
|
||||
|
|
@ -42,6 +40,37 @@ class AsyncIncrementNode(AsyncNode):
|
|||
await asyncio.sleep(0.01) # simulate async I/O
|
||||
return "done"
|
||||
|
||||
class AsyncSignalNode(AsyncNode):
|
||||
""" An async node that returns a specific signal string from post_async. """
|
||||
def __init__(self, signal="default_async_signal"):
|
||||
super().__init__()
|
||||
self.signal = signal
|
||||
|
||||
# No prep needed usually if just signaling
|
||||
async def prep_async(self, shared_storage):
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
|
||||
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||
# Store the signal in shared storage for verification
|
||||
shared_storage['last_async_signal_emitted'] = self.signal
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
print(self.signal)
|
||||
return self.signal # Return the specific action string
|
||||
|
||||
class AsyncPathNode(AsyncNode):
|
||||
""" An async node to indicate which path was taken in the outer flow. """
|
||||
def __init__(self, path_id):
|
||||
super().__init__()
|
||||
self.path_id = path_id
|
||||
|
||||
async def prep_async(self, shared_storage):
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
shared_storage['async_path_taken'] = self.path_id
|
||||
|
||||
# post_async implicitly returns None (for default transition out if needed)
|
||||
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||
await asyncio.sleep(0.01)
|
||||
# Return None by default
|
||||
|
||||
class TestAsyncNode(unittest.TestCase):
|
||||
"""
|
||||
|
|
@ -149,6 +178,51 @@ class TestAsyncFlow(unittest.TestCase):
|
|||
self.assertEqual(shared_storage["path"], "positive",
|
||||
"Should have taken the positive branch")
|
||||
|
||||
def test_async_composition_with_action_propagation(self):
|
||||
"""
|
||||
Test AsyncFlow branches based on action from nested AsyncFlow's last node.
|
||||
"""
|
||||
async def run_test():
|
||||
shared_storage = {}
|
||||
|
||||
# 1. Define an inner async flow ending with AsyncSignalNode
|
||||
# Use existing AsyncNumberNode which should return None from post_async implicitly
|
||||
inner_start_node = AsyncNumberNode(200)
|
||||
inner_end_node = AsyncSignalNode("async_inner_done") # post_async -> "async_inner_done"
|
||||
inner_start_node - "number_set" >> inner_end_node
|
||||
# Inner flow will execute start->end, Flow exec returns "async_inner_done"
|
||||
inner_flow = AsyncFlow(start=inner_start_node)
|
||||
|
||||
# 2. Define target async nodes for the outer flow branches
|
||||
path_a_node = AsyncPathNode("AsyncA") # post_async -> None
|
||||
path_b_node = AsyncPathNode("AsyncB") # post_async -> None
|
||||
|
||||
# 3. Define the outer async flow starting with the inner async flow
|
||||
outer_flow = AsyncFlow(start=inner_flow)
|
||||
|
||||
# 4. Define branches FROM the inner_flow object based on its returned action
|
||||
inner_flow - "async_inner_done" >> path_b_node # This path should be taken
|
||||
inner_flow - "other_action" >> path_a_node # This path should NOT be taken
|
||||
|
||||
# 5. Run the outer async flow and capture the last action
|
||||
# Execution: inner_start -> inner_end -> path_b
|
||||
last_action_outer = await outer_flow.run_async(shared_storage)
|
||||
|
||||
# 6. Return results for assertion
|
||||
return shared_storage, last_action_outer
|
||||
|
||||
# Run the async test function
|
||||
shared_storage, last_action_outer = asyncio.run(run_test())
|
||||
|
||||
# 7. Assert the results
|
||||
# Check state after inner flow execution
|
||||
self.assertEqual(shared_storage.get('current'), 200) # From AsyncNumberNode
|
||||
self.assertEqual(shared_storage.get('last_async_signal_emitted'), "async_inner_done")
|
||||
# Check that the correct outer path was taken
|
||||
self.assertEqual(shared_storage.get('async_path_taken'), "AsyncB")
|
||||
# Check the action returned by the outer flow. The last node executed was
|
||||
# path_b_node, which returns None from its post_async method.
|
||||
self.assertIsNone(last_action_outer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,155 +1,223 @@
|
|||
# tests/test_flow_basic.py
|
||||
import unittest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from pocketflow import Node, Flow
|
||||
|
||||
# --- Node Definitions ---
|
||||
# Nodes intended for default transitions (>>) should NOT return a specific
|
||||
# action string from post. Let it return None by default.
|
||||
# Nodes intended for conditional transitions (-) MUST return the action string.
|
||||
|
||||
class NumberNode(Node):
|
||||
def __init__(self, number):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['current'] = self.number
|
||||
# post implicitly returns None - used for default transition
|
||||
|
||||
class AddNode(Node):
|
||||
def __init__(self, number):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['current'] += self.number
|
||||
# post implicitly returns None - used for default transition
|
||||
|
||||
class MultiplyNode(Node):
|
||||
def __init__(self, number):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['current'] *= self.number
|
||||
# post implicitly returns None - used for default transition
|
||||
|
||||
class CheckPositiveNode(Node):
|
||||
# This node IS designed for conditional branching
|
||||
def prep(self, shared_storage):
|
||||
pass
|
||||
def post(self, shared_storage, prep_result, proc_result):
|
||||
# MUST return the specific action string for branching
|
||||
if shared_storage['current'] >= 0:
|
||||
return 'positive'
|
||||
else:
|
||||
return 'negative'
|
||||
|
||||
class NoOpNode(Node):
|
||||
def prep(self, shared_storage):
|
||||
# Do nothing, just pass
|
||||
pass
|
||||
# Just a placeholder node
|
||||
pass # post implicitly returns None
|
||||
|
||||
class TestNode(unittest.TestCase):
|
||||
def test_single_number(self):
|
||||
class EndSignalNode(Node):
|
||||
# A node specifically to return a value when it's the end
|
||||
def __init__(self, signal="finished"):
|
||||
super().__init__()
|
||||
self.signal = signal
|
||||
def post(self, shared_storage, prep_result, exec_result):
|
||||
return self.signal # Return a specific signal
|
||||
|
||||
# --- Test Class ---
|
||||
class TestFlowBasic(unittest.TestCase):
|
||||
|
||||
def test_start_method_initialization(self):
|
||||
"""Test initializing flow with start() after creation."""
|
||||
shared_storage = {}
|
||||
start = NumberNode(5)
|
||||
pipeline = Flow(start=start)
|
||||
pipeline.run(shared_storage)
|
||||
n1 = NumberNode(5)
|
||||
pipeline = Flow()
|
||||
pipeline.start(n1)
|
||||
last_action = pipeline.run(shared_storage)
|
||||
self.assertEqual(shared_storage['current'], 5)
|
||||
# NumberNode.post returns None (default)
|
||||
self.assertIsNone(last_action)
|
||||
|
||||
def test_sequence(self):
|
||||
"""
|
||||
Test a simple linear pipeline:
|
||||
NumberNode(5) -> AddNode(3) -> MultiplyNode(2)
|
||||
def test_start_method_chaining(self):
|
||||
"""Test fluent chaining using start().next()..."""
|
||||
shared_storage = {}
|
||||
pipeline = Flow()
|
||||
# Chain: NumberNode -> AddNode -> MultiplyNode
|
||||
# All use default transitions (post returns None)
|
||||
pipeline.start(NumberNode(5)).next(AddNode(3)).next(MultiplyNode(2))
|
||||
last_action = pipeline.run(shared_storage)
|
||||
self.assertEqual(shared_storage['current'], 16)
|
||||
# Last node (MultiplyNode) post returns None
|
||||
self.assertIsNone(last_action)
|
||||
|
||||
Expected result:
|
||||
(5 + 3) * 2 = 16
|
||||
"""
|
||||
def test_sequence_with_rshift(self):
|
||||
"""Test a simple linear pipeline using >>"""
|
||||
shared_storage = {}
|
||||
n1 = NumberNode(5)
|
||||
n2 = AddNode(3)
|
||||
n3 = MultiplyNode(2)
|
||||
|
||||
# Chain them in sequence using the >> operator
|
||||
n1 >> n2 >> n3
|
||||
|
||||
pipeline = Flow(start=n1)
|
||||
pipeline.run(shared_storage)
|
||||
pipeline = Flow()
|
||||
# All default transitions (post returns None)
|
||||
pipeline.start(n1) >> n2 >> n3
|
||||
|
||||
last_action = pipeline.run(shared_storage)
|
||||
self.assertEqual(shared_storage['current'], 16)
|
||||
# Last node (n3: MultiplyNode) post returns None
|
||||
self.assertIsNone(last_action)
|
||||
|
||||
def test_branching_positive(self):
|
||||
"""
|
||||
Test a branching pipeline with positive route:
|
||||
start = NumberNode(5)
|
||||
check = CheckPositiveNode()
|
||||
if 'positive' -> AddNode(10)
|
||||
if 'negative' -> AddNode(-20)
|
||||
|
||||
Since we start with 5,
|
||||
check returns 'positive',
|
||||
so we add 10. Final result = 15.
|
||||
"""
|
||||
"""Test positive branch: CheckPositiveNode returns 'positive'"""
|
||||
shared_storage = {}
|
||||
start = NumberNode(5)
|
||||
check = CheckPositiveNode()
|
||||
add_if_positive = AddNode(10)
|
||||
add_if_negative = AddNode(-20)
|
||||
start_node = NumberNode(5) # post -> None
|
||||
check_node = CheckPositiveNode() # post -> 'positive' or 'negative'
|
||||
add_if_positive = AddNode(10) # post -> None
|
||||
add_if_negative = AddNode(-20) # post -> None (won't run)
|
||||
|
||||
start >> check
|
||||
pipeline = Flow()
|
||||
# start -> check (default); check branches on 'positive'/'negative'
|
||||
pipeline.start(start_node) >> check_node
|
||||
check_node - "positive" >> add_if_positive
|
||||
check_node - "negative" >> add_if_negative
|
||||
|
||||
# Use the new dash operator for condition
|
||||
check - "positive" >> add_if_positive
|
||||
check - "negative" >> add_if_negative
|
||||
# Execution: start_node -> check_node -> add_if_positive
|
||||
last_action = pipeline.run(shared_storage)
|
||||
self.assertEqual(shared_storage['current'], 15) # 5 + 10
|
||||
# Last node executed was add_if_positive, its post returns None
|
||||
self.assertIsNone(last_action)
|
||||
|
||||
pipeline = Flow(start=start)
|
||||
pipeline.run(shared_storage)
|
||||
|
||||
self.assertEqual(shared_storage['current'], 15)
|
||||
|
||||
def test_negative_branch(self):
|
||||
"""
|
||||
Same branching pipeline, but starting with -5.
|
||||
That should return 'negative' from CheckPositiveNode
|
||||
and proceed to add_if_negative, i.e. add -20.
|
||||
|
||||
Final result: (-5) + (-20) = -25.
|
||||
"""
|
||||
def test_branching_negative(self):
|
||||
"""Test negative branch: CheckPositiveNode returns 'negative'"""
|
||||
shared_storage = {}
|
||||
start = NumberNode(-5)
|
||||
check = CheckPositiveNode()
|
||||
add_if_positive = AddNode(10)
|
||||
add_if_negative = AddNode(-20)
|
||||
start_node = NumberNode(-5) # post -> None
|
||||
check_node = CheckPositiveNode() # post -> 'positive' or 'negative'
|
||||
add_if_positive = AddNode(10) # post -> None (won't run)
|
||||
add_if_negative = AddNode(-20) # post -> None
|
||||
|
||||
# Build the flow
|
||||
start >> check
|
||||
check - "positive" >> add_if_positive
|
||||
check - "negative" >> add_if_negative
|
||||
pipeline = Flow()
|
||||
pipeline.start(start_node) >> check_node
|
||||
check_node - "positive" >> add_if_positive
|
||||
check_node - "negative" >> add_if_negative
|
||||
|
||||
pipeline = Flow(start=start)
|
||||
pipeline.run(shared_storage)
|
||||
# Execution: start_node -> check_node -> add_if_negative
|
||||
last_action = pipeline.run(shared_storage)
|
||||
self.assertEqual(shared_storage['current'], -25) # -5 + -20
|
||||
# Last node executed was add_if_negative, its post returns None
|
||||
self.assertIsNone(last_action)
|
||||
|
||||
# Should have gone down the 'negative' branch
|
||||
self.assertEqual(shared_storage['current'], -25)
|
||||
|
||||
def test_cycle_until_negative(self):
|
||||
"""
|
||||
Demonstrate a cyclical pipeline:
|
||||
Start with 10, check if positive -> subtract 3, then go back to check.
|
||||
Repeat until the number becomes negative, at which point pipeline ends.
|
||||
"""
|
||||
def test_cycle_until_negative_ends_with_signal(self):
|
||||
"""Test cycle, ending on a node that returns a signal"""
|
||||
shared_storage = {}
|
||||
n1 = NumberNode(10)
|
||||
check = CheckPositiveNode()
|
||||
subtract3 = AddNode(-3)
|
||||
no_op = NoOpNode() # Dummy node for the 'negative' branch
|
||||
n1 = NumberNode(10) # post -> None
|
||||
check = CheckPositiveNode() # post -> 'positive' or 'negative'
|
||||
subtract3 = AddNode(-3) # post -> None
|
||||
end_node = EndSignalNode("cycle_done") # post -> "cycle_done"
|
||||
|
||||
# Build the cycle:
|
||||
# n1 -> check -> if 'positive': subtract3 -> back to check
|
||||
n1 >> check
|
||||
pipeline = Flow()
|
||||
pipeline.start(n1) >> check
|
||||
# Branching from CheckPositiveNode
|
||||
check - 'positive' >> subtract3
|
||||
check - 'negative' >> end_node # End on negative branch
|
||||
# After subtracting, go back to check (default transition)
|
||||
subtract3 >> check
|
||||
|
||||
# Attach a no-op node on the negative branch to avoid warning
|
||||
check - 'negative' >> no_op
|
||||
# Execution: n1->check->sub3->check->sub3->check->sub3->check->sub3->check->end_node
|
||||
last_action = pipeline.run(shared_storage)
|
||||
self.assertEqual(shared_storage['current'], -2) # 10 -> 7 -> 4 -> 1 -> -2
|
||||
# Last node executed was end_node, its post returns "cycle_done"
|
||||
self.assertEqual(last_action, "cycle_done")
|
||||
|
||||
pipeline = Flow(start=n1)
|
||||
pipeline.run(shared_storage)
|
||||
def test_flow_ends_warning_default_missing(self):
|
||||
"""Test warning when default transition is needed but not found"""
|
||||
shared_storage = {}
|
||||
# Node that returns a specific action from post
|
||||
class ActionNode(Node):
|
||||
def post(self, *args): return "specific_action"
|
||||
start_node = ActionNode()
|
||||
next_node = NoOpNode()
|
||||
|
||||
# final result should be -2: (10 -> 7 -> 4 -> 1 -> -2)
|
||||
self.assertEqual(shared_storage['current'], -2)
|
||||
pipeline = Flow()
|
||||
pipeline.start(start_node)
|
||||
# Define successor only for the specific action
|
||||
start_node - "specific_action" >> next_node
|
||||
|
||||
# Make start_node return None instead, triggering default search
|
||||
start_node.post = lambda *args: None
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
# Run flow. start_node runs, post returns None.
|
||||
# Flow looks for "default", but only "specific_action" exists.
|
||||
last_action = pipeline.run(shared_storage)
|
||||
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue(issubclass(w[-1].category, UserWarning))
|
||||
# Warning message should indicate "default" wasn't found
|
||||
self.assertIn("Flow ends: 'None' not found in ['specific_action']", str(w[-1].message))
|
||||
# Last action is from start_node's post
|
||||
self.assertIsNone(last_action)
|
||||
|
||||
def test_flow_ends_warning_specific_missing(self):
|
||||
"""Test warning when specific action is returned but not found"""
|
||||
shared_storage = {}
|
||||
# Node that returns a specific action from post
|
||||
class ActionNode(Node):
|
||||
def post(self, *args): return "specific_action"
|
||||
start_node = ActionNode()
|
||||
next_node = NoOpNode()
|
||||
|
||||
pipeline = Flow()
|
||||
pipeline.start(start_node)
|
||||
# Define successor only for "default"
|
||||
start_node >> next_node # same as start_node.next(next_node, "default")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
# Run flow. start_node runs, post returns "specific_action".
|
||||
# Flow looks for "specific_action", but only "default" exists.
|
||||
last_action = pipeline.run(shared_storage)
|
||||
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue(issubclass(w[-1].category, UserWarning))
|
||||
# Warning message should indicate "specific_action" wasn't found
|
||||
self.assertIn("Flow ends: 'specific_action' not found in ['default']", str(w[-1].message))
|
||||
# Last action is from start_node's post
|
||||
self.assertEqual(last_action, "specific_action")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -1,38 +1,62 @@
|
|||
# tests/test_flow_composition.py
|
||||
import unittest
|
||||
import asyncio
|
||||
import asyncio # Keep import, might be needed if other tests use it indirectly
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from pocketflow import Node, Flow
|
||||
|
||||
# Simple example Nodes
|
||||
# --- Existing Nodes ---
|
||||
class NumberNode(Node):
|
||||
def __init__(self, number):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['current'] = self.number
|
||||
# post implicitly returns None
|
||||
|
||||
class AddNode(Node):
|
||||
def __init__(self, number):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['current'] += self.number
|
||||
# post implicitly returns None
|
||||
|
||||
class MultiplyNode(Node):
|
||||
def __init__(self, number):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['current'] *= self.number
|
||||
# post implicitly returns None
|
||||
|
||||
# --- New Nodes for Action Propagation Test ---
|
||||
class SignalNode(Node):
|
||||
"""A node that returns a specific signal string from its post method."""
|
||||
def __init__(self, signal="default_signal"):
|
||||
super().__init__()
|
||||
self.signal = signal
|
||||
# No prep needed usually if just signaling
|
||||
def post(self, shared_storage, prep_result, exec_result):
|
||||
# Store the signal in shared storage for verification
|
||||
shared_storage['last_signal_emitted'] = self.signal
|
||||
return self.signal # Return the specific action string
|
||||
|
||||
class PathNode(Node):
|
||||
"""A node to indicate which path was taken in the outer flow."""
|
||||
def __init__(self, path_id):
|
||||
super().__init__()
|
||||
self.path_id = path_id
|
||||
def prep(self, shared_storage):
|
||||
shared_storage['path_taken'] = self.path_id
|
||||
# post implicitly returns None
|
||||
|
||||
# --- Test Class ---
|
||||
class TestFlowComposition(unittest.TestCase):
|
||||
|
||||
# --- Existing Tests (Unchanged) ---
|
||||
def test_flow_as_node(self):
|
||||
"""
|
||||
1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
|
||||
|
|
@ -41,18 +65,11 @@ class TestFlowComposition(unittest.TestCase):
|
|||
Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
|
||||
"""
|
||||
shared_storage = {}
|
||||
|
||||
# Inner flow f1
|
||||
f1 = Flow(start=NumberNode(5))
|
||||
f1 >> AddNode(10) >> MultiplyNode(2)
|
||||
|
||||
# f2 starts with f1
|
||||
f2 = Flow(start=f1)
|
||||
|
||||
# Wrapper flow f3 to ensure proper execution
|
||||
f3 = Flow(start=f2)
|
||||
f3.run(shared_storage)
|
||||
|
||||
self.assertEqual(shared_storage['current'], 30)
|
||||
|
||||
def test_nested_flow(self):
|
||||
|
|
@ -64,19 +81,12 @@ class TestFlowComposition(unittest.TestCase):
|
|||
Expected final result: (5 + 3) * 4 = 32.
|
||||
"""
|
||||
shared_storage = {}
|
||||
|
||||
# Build the inner flow
|
||||
inner_flow = Flow(start=NumberNode(5))
|
||||
inner_flow >> AddNode(3)
|
||||
|
||||
# Build the middle flow, whose start is the inner flow
|
||||
middle_flow = Flow(start=inner_flow)
|
||||
middle_flow >> MultiplyNode(4)
|
||||
|
||||
# Wrapper flow to ensure proper execution
|
||||
wrapper_flow = Flow(start=middle_flow)
|
||||
wrapper_flow.run(shared_storage)
|
||||
|
||||
self.assertEqual(shared_storage['current'], 32)
|
||||
|
||||
def test_flow_chaining_flows(self):
|
||||
|
|
@ -88,23 +98,54 @@ class TestFlowComposition(unittest.TestCase):
|
|||
Expected final result: (10 + 10) * 2 = 40.
|
||||
"""
|
||||
shared_storage = {}
|
||||
|
||||
# flow1
|
||||
numbernode = NumberNode(10)
|
||||
numbernode >> AddNode(10)
|
||||
flow1 = Flow(start=numbernode)
|
||||
|
||||
# flow2
|
||||
flow2 = Flow(start=MultiplyNode(2))
|
||||
|
||||
# Chain flow1 to flow2
|
||||
flow1 >> flow2
|
||||
|
||||
# Wrapper flow to ensure proper execution
|
||||
flow1 >> flow2 # Default transition based on flow1 returning None
|
||||
wrapper_flow = Flow(start=flow1)
|
||||
wrapper_flow.run(shared_storage)
|
||||
|
||||
self.assertEqual(shared_storage['current'], 40)
|
||||
|
||||
def test_composition_with_action_propagation(self):
|
||||
"""
|
||||
Test that an outer flow can branch based on the action returned
|
||||
by the last node's post() within an inner flow.
|
||||
"""
|
||||
shared_storage = {}
|
||||
|
||||
# 1. Define an inner flow that ends with a node returning a specific action
|
||||
inner_start_node = NumberNode(100) # current = 100, post -> None
|
||||
inner_end_node = SignalNode("inner_done") # post -> "inner_done"
|
||||
inner_start_node >> inner_end_node
|
||||
# Inner flow will execute start->end, and the Flow's execution will return "inner_done"
|
||||
inner_flow = Flow(start=inner_start_node)
|
||||
|
||||
# 2. Define target nodes for the outer flow branches
|
||||
path_a_node = PathNode("A") # post -> None
|
||||
path_b_node = PathNode("B") # post -> None
|
||||
|
||||
# 3. Define the outer flow starting with the inner flow
|
||||
outer_flow = Flow()
|
||||
outer_flow.start(inner_flow) # Use the start() method
|
||||
|
||||
# 4. Define branches FROM the inner_flow object based on its returned action
|
||||
inner_flow - "inner_done" >> path_b_node # This path should be taken
|
||||
inner_flow - "other_action" >> path_a_node # This path should NOT be taken
|
||||
|
||||
# 5. Run the outer flow and capture the last action
|
||||
# Execution: inner_start -> inner_end -> path_b
|
||||
last_action_outer = outer_flow.run(shared_storage)
|
||||
|
||||
# 6. Assert the results
|
||||
# Check state after inner flow execution
|
||||
self.assertEqual(shared_storage.get('current'), 100)
|
||||
self.assertEqual(shared_storage.get('last_signal_emitted'), "inner_done")
|
||||
# Check that the correct outer path was taken
|
||||
self.assertEqual(shared_storage.get('path_taken'), "B")
|
||||
# Check the action returned by the outer flow. The last node executed was
|
||||
# path_b_node, which returns None from its post method.
|
||||
self.assertIsNone(last_action_outer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue