import unittest import asyncio import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from pocketflow import Node, AsyncNode, AsyncFlow class AsyncNumberNode(AsyncNode): """ Simple async node that sets 'current' to a given number. Demonstrates overriding .process() (sync) and using post_async() for the async portion. """ def __init__(self, number): super().__init__() self.number = number async def prep_async(self, shared_storage): # Synchronous work is allowed inside an AsyncNode, # but final 'condition' is determined by post_async(). shared_storage['current'] = self.number return "set_number" async def post_async(self, shared_storage, prep_result, proc_result): # Possibly do asynchronous tasks here await asyncio.sleep(0.01) # Return a condition for the flow return "number_set" class AsyncIncrementNode(AsyncNode): """ Demonstrates incrementing the 'current' value asynchronously. """ async def prep_async(self, shared_storage): shared_storage['current'] = shared_storage.get('current', 0) + 1 return "incremented" async def post_async(self, shared_storage, prep_result, proc_result): await asyncio.sleep(0.01) # simulate async I/O return "done" class AsyncSignalNode(AsyncNode): """ An async node that returns a specific signal string from post_async. """ def __init__(self, signal="default_async_signal"): super().__init__() self.signal = signal # No prep needed usually if just signaling async def prep_async(self, shared_storage): await asyncio.sleep(0.01) # Simulate async work async def post_async(self, shared_storage, prep_result, exec_result): # Store the signal in shared storage for verification shared_storage['last_async_signal_emitted'] = self.signal await asyncio.sleep(0.01) # Simulate async work print(self.signal) return self.signal # Return the specific action string class AsyncPathNode(AsyncNode): """ An async node to indicate which path was taken in the outer flow. """ def __init__(self, path_id): super().__init__() self.path_id = path_id async def prep_async(self, shared_storage): await asyncio.sleep(0.01) # Simulate async work shared_storage['async_path_taken'] = self.path_id # post_async implicitly returns None (for default transition out if needed) async def post_async(self, shared_storage, prep_result, exec_result): await asyncio.sleep(0.01) # Return None by default class TestAsyncNode(unittest.TestCase): """ Test the AsyncNode (and descendants) in isolation (not in a flow). """ def test_async_number_node_direct_call(self): """ Even though AsyncNumberNode is designed for an async flow, we can still test it directly by calling run_async(). """ async def run_node(): node = AsyncNumberNode(42) shared_storage = {} condition = await node.run_async(shared_storage) return shared_storage, condition shared_storage, condition = asyncio.run(run_node()) self.assertEqual(shared_storage['current'], 42) self.assertEqual(condition, "number_set") def test_async_increment_node_direct_call(self): async def run_node(): node = AsyncIncrementNode() shared_storage = {'current': 10} condition = await node.run_async(shared_storage) return shared_storage, condition shared_storage, condition = asyncio.run(run_node()) self.assertEqual(shared_storage['current'], 11) self.assertEqual(condition, "done") class TestAsyncFlow(unittest.TestCase): """ Test how AsyncFlow orchestrates multiple async nodes. """ def test_simple_async_flow(self): """ Flow: 1) AsyncNumberNode(5) -> sets 'current' to 5 2) AsyncIncrementNode() -> increments 'current' to 6 """ # Create our nodes start = AsyncNumberNode(5) inc_node = AsyncIncrementNode() # Chain them: start >> inc_node start - "number_set" >> inc_node # Create an AsyncFlow with start flow = AsyncFlow(start) # We'll run the flow synchronously (which under the hood is asyncio.run()) shared_storage = {} asyncio.run(flow.run_async(shared_storage)) self.assertEqual(shared_storage['current'], 6) def test_async_flow_branching(self): """ Demonstrate a branching scenario where we return different conditions. For example, you could have an async node that returns "go_left" or "go_right" in post_async, but here we'll keep it simpler for demonstration. """ class BranchingAsyncNode(AsyncNode): def exec(self, data): value = shared_storage.get("value", 0) shared_storage["value"] = value # We'll decide branch based on whether 'value' is positive return None async def post_async(self, shared_storage, prep_result, proc_result): await asyncio.sleep(0.01) if shared_storage["value"] >= 0: return "positive_branch" else: return "negative_branch" class PositiveNode(Node): def exec(self, data): shared_storage["path"] = "positive" return None class NegativeNode(Node): def exec(self, data): shared_storage["path"] = "negative" return None shared_storage = {"value": 10} start = BranchingAsyncNode() positive_node = PositiveNode() negative_node = NegativeNode() # Condition-based chaining start - "positive_branch" >> positive_node start - "negative_branch" >> negative_node flow = AsyncFlow(start) asyncio.run(flow.run_async(shared_storage)) self.assertEqual(shared_storage["path"], "positive", "Should have taken the positive branch") def test_async_composition_with_action_propagation(self): """ Test AsyncFlow branches based on action from nested AsyncFlow's last node. """ async def run_test(): shared_storage = {} # 1. Define an inner async flow ending with AsyncSignalNode # Use existing AsyncNumberNode which should return None from post_async implicitly inner_start_node = AsyncNumberNode(200) inner_end_node = AsyncSignalNode("async_inner_done") # post_async -> "async_inner_done" inner_start_node - "number_set" >> inner_end_node # Inner flow will execute start->end, Flow exec returns "async_inner_done" inner_flow = AsyncFlow(start=inner_start_node) # 2. Define target async nodes for the outer flow branches path_a_node = AsyncPathNode("AsyncA") # post_async -> None path_b_node = AsyncPathNode("AsyncB") # post_async -> None # 3. Define the outer async flow starting with the inner async flow outer_flow = AsyncFlow(start=inner_flow) # 4. Define branches FROM the inner_flow object based on its returned action inner_flow - "async_inner_done" >> path_b_node # This path should be taken inner_flow - "other_action" >> path_a_node # This path should NOT be taken # 5. Run the outer async flow and capture the last action # Execution: inner_start -> inner_end -> path_b last_action_outer = await outer_flow.run_async(shared_storage) # 6. Return results for assertion return shared_storage, last_action_outer # Run the async test function shared_storage, last_action_outer = asyncio.run(run_test()) # 7. Assert the results # Check state after inner flow execution self.assertEqual(shared_storage.get('current'), 200) # From AsyncNumberNode self.assertEqual(shared_storage.get('last_async_signal_emitted'), "async_inner_done") # Check that the correct outer path was taken self.assertEqual(shared_storage.get('async_path_taken'), "AsyncB") # Check the action returned by the outer flow. The last node executed was # path_b_node, which returns None from its post_async method. self.assertIsNone(last_action_outer) if __name__ == '__main__': unittest.main()