tests

commit 5f95f23bc9
parent 881a903e2f
@@ -6,10 +6,10 @@ class BaseNode:
     # process(): this is for the LLM call, and should be idempotent for retries
     # postprocess(): this is to summarize the result and return the condition for the successor node
     def __init__(self):
-        self.parameters, self.successors = {}, {}
+        self.params, self.successors = {}, {}
 
-    def set_parameters(self, params): # make sure params is immutable
-        self.parameters = params # must be immutable during pre/post/process
+    def set_params(self, params): # make sure params is immutable
+        self.params = params # must be immutable during pre/post/process
 
     def add_successor(self, node, condition="default"):
         if condition in self.successors:
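The rename from `parameters` to `params` touches every call site below. A minimal sketch of the new naming, assuming the `Node` API in this diff (`GreetNode` is an illustrative class, not part of the library):

    from minillmflow import Node

    class GreetNode(Node):  # hypothetical node, for illustration only
        def process(self, shared_storage, prep_result):
            # parameters are now read via self.params (was self.parameters)
            shared_storage['greeting'] = f"hello, {self.params.get('name', 'world')}"

    storage = {}
    node = GreetNode()
    node.set_params({'name': 'minillmflow'})  # was set_parameters(...)
    node.run(storage)                          # storage['greeting'] == 'hello, minillmflow'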
@@ -30,12 +30,12 @@ class BaseNode:
     def postprocess(self, shared_storage, prep_result, proc_result):
         return "default" # condition for next node
 
-    def _run(self, shared_storage=None):
-        prep = self.preprocess(shared_storage)
-        proc = self._process(shared_storage, prep)
-        return self.postprocess(shared_storage, prep, proc)
+    def _run(self, shared_storage):
+        prep_result = self.preprocess(shared_storage)
+        proc_result = self._process(shared_storage, prep_result)
+        return self.postprocess(shared_storage, prep_result, proc_result)
 
-    def run(self, shared_storage=None):
+    def run(self, shared_storage):
         if self.successors:
             warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
         return self._run(shared_storage)
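For orientation, `_run` above is a fixed three-stage pipeline. A minimal sketch of a concrete node plugging into it, assuming the `BaseNode` API shown in this diff (`SquareNode` is illustrative):

    from minillmflow import Node

    class SquareNode(Node):  # hypothetical node, for illustration only
        def preprocess(self, shared_storage):
            return shared_storage['x']            # becomes prep_result
        def process(self, shared_storage, prep_result):
            return prep_result * prep_result      # becomes proc_result (the retryable LLM-call slot)
        def postprocess(self, shared_storage, prep_result, proc_result):
            shared_storage['x_squared'] = proc_result
            return "default"                      # condition used to pick the successor node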
@@ -91,37 +91,28 @@ class BatchNode(Node):
         return results
 
 class AsyncNode(Node):
-    """
-    A Node whose postprocess step is async.
-    You can also override process() to be async if needed.
-    """
     def postprocess(self, shared_storage, prep_result, proc_result):
         # Not used in async workflow; define postprocess_async() instead.
         raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
 
     async def postprocess_async(self, shared_storage, prep_result, proc_result):
-        """
-        Async version of postprocess. By default, returns "default".
-        Override as needed.
-        """
         await asyncio.sleep(0) # trivial async pause (no-op)
         return "default"
 
-    async def run_async(self, shared_storage=None):
+    async def run_async(self, shared_storage):
         if self.successors:
             warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
         return await self._run_async(shared_storage)
 
-    async def _run_async(self, shared_storage=None):
-        prep = self.preprocess(shared_storage)
-        proc = self._process(shared_storage, prep)
-        return await self.postprocess_async(shared_storage, prep, proc)
+    async def _run_async(self, shared_storage):
+        prep_result = self.preprocess(shared_storage)
+        proc_result = self._process(shared_storage, prep_result)
+        return await self.postprocess_async(shared_storage, prep_result, proc_result)
 
-    def _run(self, shared_storage=None):
-        raise RuntimeError("AsyncNode requires run_async, and should be run in an AsyncFlow")
+    def _run(self, shared_storage):
+        raise RuntimeError("AsyncNode requires asynchronous execution. Use 'await node.run_async()' if inside an async function, or 'asyncio.run(node.run_async())' if in synchronous code.")
 
 class BaseFlow(BaseNode):
-    def __init__(self, start_node=None):
+    def __init__(self, start_node):
         super().__init__()
         self.start_node = start_node
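Since `_run` now raises on a plain synchronous call, an `AsyncNode` has to be driven through `run_async`. A minimal sketch under that assumption (`TickNode` is illustrative):

    import asyncio
    from minillmflow import AsyncNode

    class TickNode(AsyncNode):  # hypothetical node, for illustration only
        def process(self, shared_storage, prep_result):
            shared_storage['ticks'] = shared_storage.get('ticks', 0) + 1
        async def postprocess_async(self, shared_storage, prep_result, proc_result):
            await asyncio.sleep(0.01)  # stand-in for real async work
            return "default"

    storage = {}
    asyncio.run(TickNode().run_async(storage))  # TickNode().run(storage) would raise RuntimeError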
@@ -134,28 +125,26 @@ class BaseFlow(BaseNode):
         return next_node
 
 class Flow(BaseFlow):
-    def _process_flow(self, shared_storage):
+    def _process(self, shared_storage, params=None):
         current_node = self.start_node
+        params = params if params is not None else self.params.copy()
 
         while current_node:
-            # Pass down the Flow's parameters to the current node
-            current_node.set_parameters(self.parameters)
-            # Synchronous run
+            current_node.set_params(params)
             condition = current_node._run(shared_storage)
             # Decide next node
             current_node = self.get_next_node(current_node, condition)
 
-    def _run(self, shared_storage=None):
-        prep_result = self.preprocess(shared_storage)
-        self._process_flow(shared_storage)
-        return self.postprocess(shared_storage, prep_result, None)
+    def process(self, shared_storage, prep_result):
+        raise NotImplementedError("Flow should not process directly")
 
-class AsyncFlow(BaseFlow):
-    async def _process_flow_async(self, shared_storage):
+class AsyncFlow(BaseFlow, AsyncNode):
+    async def _process_async(self, shared_storage, params=None):
         current_node = self.start_node
+        params = params if params is not None else self.params.copy()
 
         while current_node:
-            current_node.set_parameters(self.parameters)
+            current_node.set_params(params)
 
             # If node is async-capable, call run_async; otherwise run sync
             if hasattr(current_node, "run_async") and callable(current_node.run_async):
                 condition = await current_node._run_async(shared_storage)
             else:
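The `while` loop above is the whole orchestrator: hand a params copy to the current node, run it, then follow the condition it returns. A self-contained sketch of wiring a small conditional graph for it, assuming the `add_successor` API from the first hunk (both node classes are illustrative):

    from minillmflow import Node, Flow

    class CheckNode(Node):  # hypothetical, for illustration only
        def process(self, shared_storage, prep_result):
            pass
        def postprocess(self, shared_storage, prep_result, proc_result):
            return "big" if shared_storage['x'] > self.params.get('threshold', 0) else "small"

    class LabelNode(Node):  # hypothetical, for illustration only
        def __init__(self, label):
            super().__init__()
            self.label = label
        def process(self, shared_storage, prep_result):
            shared_storage['label'] = self.label

    check = CheckNode()
    check.add_successor(LabelNode("big"), condition="big")
    check.add_successor(LabelNode("small"), condition="small")

    flow = Flow(start_node=check)
    flow.set_params({'threshold': 10})  # copied down to every node on each loop step
    storage = {'x': 42}
    flow.run(storage)                   # storage['label'] == 'big'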
@@ -163,53 +152,33 @@ class AsyncFlow(BaseFlow):
 
             current_node = self.get_next_node(current_node, condition)
 
-    async def _run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage):
         prep_result = self.preprocess(shared_storage)
-        await self._process_flow_async(shared_storage)
-        return self.postprocess(shared_storage, prep_result, None)
+        await self._process_async(shared_storage)
+        return await self.postprocess_async(shared_storage, prep_result, None)
 
-    def _run(self, shared_storage=None):
-        try:
-            return asyncio.run(self._run_async(shared_storage))
-        except RuntimeError as e:
-            raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e
 
 class BaseBatchFlow(BaseFlow):
     def preprocess(self, shared_storage):
-        return []
+        return [] # return an iterable of parameter dictionaries
 
 class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, shared_storage=None):
+    def _run(self, shared_storage):
         prep_result = self.preprocess(shared_storage)
-        all_results = []
-
-        # For each set of parameters (or items) we got from preprocess
         for param_dict in prep_result:
-            # Merge param_dict into the Flow's parameters
-            original_params = self.parameters.copy()
-            self.parameters.update(param_dict)
-
-            # Run from the start node to end
-            self._process_flow(shared_storage)
-
-            # Optionally collect results from shared_storage or a custom method
-            all_results.append(f"Finished run with parameters: {param_dict}")
-
-            # Reset the parameters if needed
-            self.parameters = original_params
+            merged_params = self.params.copy()
+            merged_params.update(param_dict)
+            self._process(shared_storage, params=merged_params)
 
         return self.postprocess(shared_storage, prep_result, None)
 
 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage):
         prep_result = self.preprocess(shared_storage)
-        all_results = []
-
         for param_dict in prep_result:
-            original_params = self.parameters.copy()
-            self.parameters.update(param_dict)
-
-            await self._process_flow_async(shared_storage)
-
-            all_results.append(f"Finished async run with parameters: {param_dict}")
-
-            # Reset back to original parameters if needed
-            self.parameters = original_params
+            merged_params = self.params.copy()
+            merged_params.update(param_dict)
+            await self._process_async(shared_storage, params=merged_params)
 
         return await self.postprocess_async(shared_storage, prep_result, None)
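The batch loop now merges each `param_dict` into a copy of the flow's own params instead of mutating and restoring them. The same merge, sketched in isolation with plain dicts (values are illustrative):

    flow_params = {'model': 'demo', 'key': None}   # flow-level params
    param_dict = {'key': 'a'}                      # one item from preprocess()

    merged_params = flow_params.copy()             # the flow's params stay untouched
    merged_params.update(param_dict)               # per-item values win on collision
    assert merged_params == {'model': 'demo', 'key': 'a'}
    assert flow_params == {'model': 'demo', 'key': None}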
@@ -97,7 +97,7 @@ class TestAsyncFlow(unittest.TestCase):
 
         # We'll run the flow synchronously (which under the hood is asyncio.run())
         shared_storage = {}
-        flow.run(shared_storage)
+        asyncio.run(flow.run_async(shared_storage))
 
         self.assertEqual(shared_storage['current'], 6)
@@ -144,7 +144,7 @@ class TestAsyncFlow(unittest.TestCase):
         start_node - "negative_branch" >> negative_node
 
         flow = AsyncFlow(start_node)
-        flow.run(shared_storage)
+        asyncio.run(flow.run_async(shared_storage))
 
         self.assertEqual(shared_storage["path"], "positive",
                          "Should have taken the positive branch")
@@ -0,0 +1,184 @@
+import unittest
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from minillmflow import Node, BatchFlow, Flow
+
+class DataProcessNode(Node):
+    def process(self, shared_storage, prep_result):
+        key = self.params.get('key')
+        data = shared_storage['input_data'][key]
+        if 'results' not in shared_storage:
+            shared_storage['results'] = {}
+        shared_storage['results'][key] = data * 2
+
+class ErrorProcessNode(Node):
+    def process(self, shared_storage, prep_result):
+        key = self.params.get('key')
+        if key == 'error_key':
+            raise ValueError(f"Error processing key: {key}")
+        if 'results' not in shared_storage:
+            shared_storage['results'] = {}
+        shared_storage['results'][key] = True
+
+class TestBatchFlow(unittest.TestCase):
+    def setUp(self):
+        self.process_node = DataProcessNode()
+
+    def test_basic_batch_processing(self):
+        """Test basic batch processing with multiple keys"""
+        class SimpleTestBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = SimpleTestBatchFlow(start_node=self.process_node)
+        flow.run(shared_storage)
+
+        expected_results = {
+            'a': 2,
+            'b': 4,
+            'c': 6
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_empty_input(self):
+        """Test batch processing with empty input dictionary"""
+        class EmptyTestBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {}
+        }
+
+        flow = EmptyTestBatchFlow(start_node=self.process_node)
+        flow.run(shared_storage)
+
+        self.assertEqual(shared_storage.get('results', {}), {})
+
+    def test_single_item(self):
+        """Test batch processing with single item"""
+        class SingleItemBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'single': 5
+            }
+        }
+
+        flow = SingleItemBatchFlow(start_node=self.process_node)
+        flow.run(shared_storage)
+
+        expected_results = {
+            'single': 10
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_error_handling(self):
+        """Test error handling during batch processing"""
+        class ErrorTestBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'normal_key': 1,
+                'error_key': 2,
+                'another_key': 3
+            }
+        }
+
+        flow = ErrorTestBatchFlow(start_node=ErrorProcessNode())
+
+        with self.assertRaises(ValueError):
+            flow.run(shared_storage)
+
+    def test_nested_flow(self):
+        """Test batch processing with nested flows"""
+        class InnerNode(Node):
+            def process(self, shared_storage, prep_result):
+                key = self.params.get('key')
+                if 'intermediate_results' not in shared_storage:
+                    shared_storage['intermediate_results'] = {}
+                shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
+
+        class OuterNode(Node):
+            def process(self, shared_storage, prep_result):
+                key = self.params.get('key')
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
+
+        class NestedBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        # Create inner flow
+        inner_node = InnerNode()
+        outer_node = OuterNode()
+        inner_node >> outer_node
+
+        shared_storage = {
+            'input_data': {
+                'x': 1,
+                'y': 2
+            }
+        }
+
+        flow = NestedBatchFlow(start_node=inner_node)
+        flow.run(shared_storage)
+
+        expected_results = {
+            'x': 4,  # (1 + 1) * 2
+            'y': 6   # (2 + 1) * 2
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_custom_parameters(self):
+        """Test batch processing with additional custom parameters"""
+        class CustomParamNode(Node):
+            def process(self, shared_storage, prep_result):
+                key = self.params.get('key')
+                multiplier = self.params.get('multiplier', 1)
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
+
+        class CustomParamBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{
+                    'key': k,
+                    'multiplier': i + 1
+                } for i, k in enumerate(shared_storage['input_data'].keys())]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = CustomParamBatchFlow(start_node=CustomParamNode())
+        flow.run(shared_storage)
+
+        expected_results = {
+            'a': 1 * 1,  # first item, multiplier = 1
+            'b': 2 * 2,  # second item, multiplier = 2
+            'c': 3 * 3   # third item, multiplier = 3
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+if __name__ == '__main__':
+    unittest.main()
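These tests pin down the BatchFlow contract: `preprocess()` yields one param dict per pass, and each pass runs the whole node graph with that dict merged into the flow's params. The same contract compressed outside unittest (class names are illustrative):

    from minillmflow import Node, BatchFlow

    class Doubler(Node):  # hypothetical, for illustration only
        def process(self, shared_storage, prep_result):
            key = self.params['key']
            shared_storage.setdefault('results', {})[key] = shared_storage['input_data'][key] * 2

    class PerKeyFlow(BatchFlow):  # hypothetical, for illustration only
        def preprocess(self, shared_storage):
            return [{'key': k} for k in shared_storage['input_data']]

    storage = {'input_data': {'a': 1, 'b': 2}}
    PerKeyFlow(start_node=Doubler()).run(storage)
    assert storage['results'] == {'a': 2, 'b': 4}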
@@ -0,0 +1,162 @@
+import unittest
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from minillmflow import Node, BatchNode, Flow
+
+class ArrayChunkNode(BatchNode):
+    def __init__(self, chunk_size=10):
+        super().__init__()
+        self.chunk_size = chunk_size
+
+    def preprocess(self, shared_storage):
+        # Get array from shared storage and split into chunks
+        array = shared_storage.get('input_array', [])
+        chunks = []
+        for i in range(0, len(array), self.chunk_size):
+            end = min(i + self.chunk_size, len(array))
+            chunks.append((i, end))
+        return chunks
+
+    def process(self, shared_storage, chunk_indices):
+        start, end = chunk_indices
+        array = shared_storage['input_array']
+        # Process the chunk and return its sum
+        chunk_sum = sum(array[start:end])
+        return chunk_sum
+
+    def postprocess(self, shared_storage, prep_result, proc_result):
+        # Store chunk results in shared storage
+        shared_storage['chunk_results'] = proc_result
+        return "default"
+
+class SumReduceNode(Node):
+    def process(self, shared_storage, data):
+        # Get chunk results from shared storage and sum them
+        chunk_results = shared_storage.get('chunk_results', [])
+        total = sum(chunk_results)
+        shared_storage['total'] = total
+
+class TestBatchNode(unittest.TestCase):
+    def test_array_chunking(self):
+        """
+        Test that the array is correctly split into chunks
+        """
+        shared_storage = {
+            'input_array': list(range(25))  # [0,1,2,...,24]
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        chunks = chunk_node.preprocess(shared_storage)
+
+        self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])
+
+    def test_map_reduce_sum(self):
+        """
+        Test a complete map-reduce pipeline that sums a large array:
+        1. Map: Split array into chunks and sum each chunk
+        2. Reduce: Sum all the chunk sums
+        """
+        # Create test array: [0,1,2,...,99]
+        array = list(range(100))
+        expected_sum = sum(array)  # 4950
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        # Create nodes
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        reduce_node = SumReduceNode()
+
+        # Connect nodes
+        chunk_node >> reduce_node
+
+        # Create and run pipeline
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_uneven_chunks(self):
+        """
+        Test that the map-reduce works correctly with array lengths
+        that don't divide evenly by chunk_size
+        """
+        array = list(range(25))
+        expected_sum = sum(array)  # 300
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_custom_chunk_size(self):
+        """
+        Test that the map-reduce works with different chunk sizes
+        """
+        array = list(range(100))
+        expected_sum = sum(array)
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        # Use chunk_size=15 instead of default 10
+        chunk_node = ArrayChunkNode(chunk_size=15)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_single_element_chunks(self):
+        """
+        Test extreme case where chunk_size=1
+        """
+        array = list(range(5))
+        expected_sum = sum(array)
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=1)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_empty_array(self):
+        """
+        Test edge case of empty input array
+        """
+        shared_storage = {
+            'input_array': []
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], 0)
+
+if __name__ == '__main__':
+    unittest.main()
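The chunk boundaries follow directly from the range arithmetic in `preprocess`: for a 25-element array and `chunk_size=10`, `range(0, 25, 10)` yields starts 0, 10, 20, and each end is clamped with `min`. A standalone check of that arithmetic:

    array_len, chunk_size = 25, 10  # values taken from test_array_chunking above
    chunks = [(i, min(i + chunk_size, array_len)) for i in range(0, array_len, chunk_size)]
    assert chunks == [(0, 10), (10, 20), (20, 25)]  # last chunk is shorter, never padded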
@@ -2,10 +2,11 @@ import unittest
+import asyncio
 import sys
 from pathlib import Path
 
 sys.path.append(str(Path(__file__).parent.parent))
 
 from minillmflow import Node, Flow
 
 # Simple example Nodes
 class NumberNode(Node):
     def __init__(self, number):
         super().__init__()
@@ -30,106 +31,79 @@ class MultiplyNode(Node):
     def process(self, shared_storage, prep_result):
         shared_storage['current'] *= self.number
 
 
 class TestFlowComposition(unittest.TestCase):
 
     def test_flow_as_node(self):
         """
-        Demonstrates that a Flow can itself be chained like a Node.
-        We create a flow (f1) that starts with NumberNode(5) -> AddNode(10).
-        Then we chain f1 >> MultiplyNode(2).
-
-        Expected result after running from f1:
-          start = 5
-          5 + 10 = 15
-          15 * 2 = 30
+        1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
+        2) Create a second Flow (f2) whose start_node is f1.
+        3) Create a wrapper Flow (f3) that contains f2 to ensure proper execution.
+        Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
         """
         shared_storage = {}
 
         # Inner flow f1
         f1 = Flow(start_node=NumberNode(5))
-        f1 >> AddNode(10)
-
-        # Then chain a node after the flow
-        f1 >> MultiplyNode(2)
-
-        # Run from f1
-        f1.run(shared_storage)
+        f1 >> AddNode(10) >> MultiplyNode(2)
+
+        # f2 starts with f1
+        f2 = Flow(start_node=f1)
+
+        # Wrapper flow f3 to ensure proper execution
+        f3 = Flow(start_node=f2)
+        f3.run(shared_storage)
 
         self.assertEqual(shared_storage['current'], 30)
 
     def test_nested_flow(self):
         """
-        Demonstrates embedding one Flow inside another Flow.
+        Demonstrates nested flows with proper wrapping:
           inner_flow: NumberNode(5) -> AddNode(3)
-          outer_flow: starts with inner_flow -> MultiplyNode(4)
-
-        Expected result:
-          (5 + 3) * 4 = 32
+          middle_flow: starts with inner_flow -> MultiplyNode(4)
+          wrapper_flow: contains middle_flow to ensure proper execution
+        Expected final result: (5 + 3) * 4 = 32.
         """
         shared_storage = {}
 
-        # Define an inner flow
+        # Build the inner flow
         inner_flow = Flow(start_node=NumberNode(5))
         inner_flow >> AddNode(3)
 
-        # Define an outer flow, whose start node is inner_flow
-        outer_flow = Flow(start_node=inner_flow)
-        outer_flow >> MultiplyNode(4)
-
-        # Run outer_flow
-        outer_flow.run(shared_storage)
-
-        self.assertEqual(shared_storage['current'], 32) # (5+3)*4=32
+        # Build the middle flow, whose start_node is the inner flow
+        middle_flow = Flow(start_node=inner_flow)
+        middle_flow >> MultiplyNode(4)
+
+        # Wrapper flow to ensure proper execution
+        wrapper_flow = Flow(start_node=middle_flow)
+        wrapper_flow.run(shared_storage)
+
+        self.assertEqual(shared_storage['current'], 32)
 
     def test_flow_chaining_flows(self):
         """
-        Demonstrates chaining one flow to another flow.
-        flow1: NumberNode(10) -> AddNode(10) # final shared_storage['current'] = 20
-        flow2: MultiplyNode(2) # final shared_storage['current'] = 40
-
-        flow1 >> flow2 means once flow1 finishes, flow2 starts.
-
-        Expected result: (10 + 10) * 2 = 40
+        Demonstrates chaining two flows with proper wrapping:
+          flow1: NumberNode(10) -> AddNode(10) # final = 20
+          flow2: MultiplyNode(2) # final = 40
+          wrapper_flow: contains both flow1 and flow2 to ensure proper execution
+        Expected final result: (10 + 10) * 2 = 40.
         """
         shared_storage = {}
 
         # flow1
-        flow1 = Flow(start_node=NumberNode(10))
-        flow1 >> AddNode(10)
+        numbernode = NumberNode(10)
+        numbernode >> AddNode(10)
+        flow1 = Flow(start_node=numbernode)
 
         # flow2
         flow2 = Flow(start_node=MultiplyNode(2))
 
-        # Chain them: flow1 >> flow2
+        # Chain flow1 to flow2
         flow1 >> flow2
 
-        # Start running from flow1
-        flow1.run(shared_storage)
+        # Wrapper flow to ensure proper execution
+        wrapper_flow = Flow(start_node=flow1)
+        wrapper_flow.run(shared_storage)
 
         self.assertEqual(shared_storage['current'], 40)
 
     def test_flow_with_parameters(self):
         """
         Demonstrates passing parameters into a Flow (and retrieved by a Node).
         """
 
         class ParamNode(Node):
             def process(self, shared_storage, prep_result):
                 # Reads 'level' from the node's (or flow's) parameters
                 shared_storage['param'] = self.parameters.get('level', 'no param')
 
         shared_storage = {}
 
         # Create a flow with a ParamNode
         f = Flow(start_node=ParamNode())
         # Set parameters on the flow
         f.parameters = {'level': 'Level 1'}
 
         f.run(shared_storage)
 
         self.assertEqual(shared_storage['param'], 'Level 1')
 
 
 if __name__ == '__main__':
     unittest.main()
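The recurring "wrapper flow" pattern in these tests follows from the run() warning added earlier: a Flow that has successors of its own should not be run directly, so the tests wrap it in a parent Flow whose only job is to drive the whole chain. A minimal sketch of the pattern, assuming the API in this diff and the NumberNode/MultiplyNode classes from the test file above:

    inner = Flow(start_node=NumberNode(1))
    inner >> MultiplyNode(3)          # `inner` now has a successor of its own...

    wrapper = Flow(start_node=inner)  # ...so a parent Flow drives the chain
    storage = {}
    wrapper.run(storage)              # storage['current'] == 3; inner.run(storage) alone
                                      # would warn and skip MultiplyNode(3)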