tests
parent 881a903e2f
commit 5f95f23bc9
@@ -6,10 +6,10 @@ class BaseNode:
     # process(): this is for the LLM call, and should be idempotent for retries
     # postprocess(): this is to summarize the result and return the condition for the successor node
     def __init__(self):
-        self.parameters, self.successors = {}, {}
+        self.params, self.successors = {}, {}

-    def set_parameters(self, params): # make sure params is immutable
-        self.parameters = params # must be immutable during pre/post/process
+    def set_params(self, params): # make sure params is immutable
+        self.params = params # must be immutable during pre/post/process

     def add_successor(self, node, condition="default"):
         if condition in self.successors:
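Note: the comments above describe the node lifecycle this commit renames around (preprocess -> process -> postprocess, with parameters set via set_params). A minimal sketch of a conforming node, assuming the Node base class that the tests below import from minillmflow; GreetNode and the 'greeting' key are hypothetical, not part of this commit:

    from minillmflow import Node

    class GreetNode(Node):  # hypothetical example node
        def preprocess(self, shared_storage):
            # read inputs from shared storage
            return shared_storage.get('name', 'world')

        def process(self, shared_storage, prep_result):
            # stands in for the LLM call; idempotent so retries are safe
            return f"Hello, {prep_result}!"

        def postprocess(self, shared_storage, prep_result, proc_result):
            # summarize the result and return the condition for the successor node
            shared_storage['greeting'] = proc_result
            return "default"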
@@ -30,12 +30,12 @@ class BaseNode:
     def postprocess(self, shared_storage, prep_result, proc_result):
         return "default" # condition for next node

-    def _run(self, shared_storage=None):
-        prep = self.preprocess(shared_storage)
-        proc = self._process(shared_storage, prep)
-        return self.postprocess(shared_storage, prep, proc)
+    def _run(self, shared_storage):
+        prep_result = self.preprocess(shared_storage)
+        proc_result = self._process(shared_storage, prep_result)
+        return self.postprocess(shared_storage, prep_result, proc_result)

-    def run(self, shared_storage=None):
+    def run(self, shared_storage):
         if self.successors:
             warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
         return self._run(shared_storage)
@@ -91,37 +91,28 @@ class BatchNode(Node):
         return results

 class AsyncNode(Node):
-    """
-    A Node whose postprocess step is async.
-    You can also override process() to be async if needed.
-    """
     def postprocess(self, shared_storage, prep_result, proc_result):
-        # Not used in async workflow; define postprocess_async() instead.
         raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")

     async def postprocess_async(self, shared_storage, prep_result, proc_result):
-        """
-        Async version of postprocess. By default, returns "default".
-        Override as needed.
-        """
         await asyncio.sleep(0) # trivial async pause (no-op)
         return "default"

-    async def run_async(self, shared_storage=None):
+    async def run_async(self, shared_storage):
         if self.successors:
             warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
         return await self._run_async(shared_storage)

-    async def _run_async(self, shared_storage=None):
-        prep = self.preprocess(shared_storage)
-        proc = self._process(shared_storage, prep)
-        return await self.postprocess_async(shared_storage, prep, proc)
+    async def _run_async(self, shared_storage):
+        prep_result = self.preprocess(shared_storage)
+        proc_result = self._process(shared_storage, prep_result)
+        return await self.postprocess_async(shared_storage, prep_result, proc_result)

-    def _run(self, shared_storage=None):
-        raise RuntimeError("AsyncNode requires run_async, and should be run in an AsyncFlow")
+    def _run(self, shared_storage):
+        raise RuntimeError("AsyncNode requires asynchronous execution. Use 'await node.run_async()' if inside an async function, or 'asyncio.run(node.run_async())' if in synchronous code.")

 class BaseFlow(BaseNode):
-    def __init__(self, start_node=None):
+    def __init__(self, start_node):
         super().__init__()
         self.start_node = start_node
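Note: the reworded RuntimeError spells out both entry points. A hedged sketch of driving an AsyncNode from synchronous code, assuming AsyncNode is importable from minillmflow (this diff does not show the package exports); EchoNode is hypothetical:

    import asyncio
    from minillmflow import AsyncNode  # assumed export

    class EchoNode(AsyncNode):  # hypothetical example node
        def process(self, shared_storage, prep_result):
            return "echo"

        async def postprocess_async(self, shared_storage, prep_result, proc_result):
            shared_storage['echo'] = proc_result
            return "default"

    shared_storage = {}
    # from synchronous code, as the error message suggests:
    asyncio.run(EchoNode().run_async(shared_storage))
    # inside an async function you would instead write:
    #     await EchoNode().run_async(shared_storage)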
@@ -134,28 +125,26 @@ class BaseFlow(BaseNode):
         return next_node

 class Flow(BaseFlow):
-    def _process_flow(self, shared_storage):
+    def _process(self, shared_storage, params=None):
         current_node = self.start_node
+        params = params if params is not None else self.params.copy()

         while current_node:
-            # Pass down the Flow's parameters to the current node
-            current_node.set_parameters(self.parameters)
-            # Synchronous run
+            current_node.set_params(params)
             condition = current_node._run(shared_storage)
-            # Decide next node
             current_node = self.get_next_node(current_node, condition)

-    def _run(self, shared_storage=None):
-        prep_result = self.preprocess(shared_storage)
-        self._process_flow(shared_storage)
-        return self.postprocess(shared_storage, prep_result, None)
+    def process(self, shared_storage, prep_result):
+        raise NotImplementedError("Flow should not process directly")

-class AsyncFlow(BaseFlow):
-    async def _process_flow_async(self, shared_storage):
+class AsyncFlow(BaseFlow, AsyncNode):
+    async def _process_async(self, shared_storage, params=None):
         current_node = self.start_node
+        params = params if params is not None else self.params.copy()

         while current_node:
-            current_node.set_parameters(self.parameters)
+            current_node.set_params(params)

-            # If node is async-capable, call run_async; otherwise run sync
             if hasattr(current_node, "run_async") and callable(current_node.run_async):
                 condition = await current_node._run_async(shared_storage)
             else:
@@ -163,53 +152,33 @@ class AsyncFlow(BaseFlow):

             current_node = self.get_next_node(current_node, condition)

-    async def _run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage):
         prep_result = self.preprocess(shared_storage)
-        await self._process_flow_async(shared_storage)
-        return self.postprocess(shared_storage, prep_result, None)
+        await self._process_async(shared_storage)
+        return await self.postprocess_async(shared_storage, prep_result, None)

-    def _run(self, shared_storage=None):
-        try:
-            return asyncio.run(self._run_async(shared_storage))
-        except RuntimeError as e:
-            raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e

 class BaseBatchFlow(BaseFlow):
     def preprocess(self, shared_storage):
-        return []
+        return [] # return an iterable of parameter dictionaries

 class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, shared_storage=None):
+    def _run(self, shared_storage):
         prep_result = self.preprocess(shared_storage)
-        all_results = []
-
-        # For each set of parameters (or items) we got from preprocess
         for param_dict in prep_result:
-            # Merge param_dict into the Flow's parameters
-            original_params = self.parameters.copy()
-            self.parameters.update(param_dict)
-
-            # Run from the start node to end
-            self._process_flow(shared_storage)
-
-            # Optionally collect results from shared_storage or a custom method
-            all_results.append(f"Finished run with parameters: {param_dict}")
-
-            # Reset the parameters if needed
-            self.parameters = original_params
+            merged_params = self.params.copy()
+            merged_params.update(param_dict)
+            self._process(shared_storage, params=merged_params)
+        return self.postprocess(shared_storage, prep_result, None)

 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, shared_storage=None):
+    async def _run_async(self, shared_storage):
         prep_result = self.preprocess(shared_storage)
-        all_results = []
-
         for param_dict in prep_result:
-            original_params = self.parameters.copy()
-            self.parameters.update(param_dict)
-
-            await self._process_flow_async(shared_storage)
-
-            all_results.append(f"Finished async run with parameters: {param_dict}")
-
-            # Reset back to original parameters if needed
-            self.parameters = original_params
+            merged_params = self.params.copy()
+            merged_params.update(param_dict)
+            await self._process_async(shared_storage, params=merged_params)
+        return await self.postprocess_async(shared_storage, prep_result, None)
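Note: the rewritten BatchFlow._run treats preprocess as the batch driver: each dict it yields is merged over the flow's own params, and the wrapped flow is executed once per dict. Because merged_params is a fresh copy per iteration, one batch item's overrides can never leak into the next. A sketch of the intended subclassing pattern, mirroring the tests added below; PerKeyFlow is a hypothetical name:

    class PerKeyFlow(BatchFlow):  # hypothetical subclass
        def preprocess(self, shared_storage):
            # one parameter dict per execution of the wrapped flow
            return [{'key': k} for k in shared_storage['input_data']]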
@@ -97,7 +97,7 @@ class TestAsyncFlow(unittest.TestCase):

         # We'll run the flow synchronously (which under the hood is asyncio.run())
         shared_storage = {}
-        flow.run(shared_storage)
+        asyncio.run(flow.run_async(shared_storage))

         self.assertEqual(shared_storage['current'], 6)
@@ -144,7 +144,7 @@ class TestAsyncFlow(unittest.TestCase):
         start_node - "negative_branch" >> negative_node

         flow = AsyncFlow(start_node)
-        flow.run(shared_storage)
+        asyncio.run(flow.run_async(shared_storage))

         self.assertEqual(shared_storage["path"], "positive",
                          "Should have taken the positive branch")
@@ -0,0 +1,184 @@
+import unittest
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from minillmflow import Node, BatchFlow, Flow
+
+class DataProcessNode(Node):
+    def process(self, shared_storage, prep_result):
+        key = self.params.get('key')
+        data = shared_storage['input_data'][key]
+        if 'results' not in shared_storage:
+            shared_storage['results'] = {}
+        shared_storage['results'][key] = data * 2
+
+class ErrorProcessNode(Node):
+    def process(self, shared_storage, prep_result):
+        key = self.params.get('key')
+        if key == 'error_key':
+            raise ValueError(f"Error processing key: {key}")
+        if 'results' not in shared_storage:
+            shared_storage['results'] = {}
+        shared_storage['results'][key] = True
+
+class TestBatchFlow(unittest.TestCase):
+    def setUp(self):
+        self.process_node = DataProcessNode()
+
+    def test_basic_batch_processing(self):
+        """Test basic batch processing with multiple keys"""
+        class SimpleTestBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = SimpleTestBatchFlow(start_node=self.process_node)
+        flow.run(shared_storage)
+
+        expected_results = {
+            'a': 2,
+            'b': 4,
+            'c': 6
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_empty_input(self):
+        """Test batch processing with empty input dictionary"""
+        class EmptyTestBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {}
+        }
+
+        flow = EmptyTestBatchFlow(start_node=self.process_node)
+        flow.run(shared_storage)
+
+        self.assertEqual(shared_storage.get('results', {}), {})
+
+    def test_single_item(self):
+        """Test batch processing with single item"""
+        class SingleItemBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'single': 5
+            }
+        }
+
+        flow = SingleItemBatchFlow(start_node=self.process_node)
+        flow.run(shared_storage)
+
+        expected_results = {
+            'single': 10
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_error_handling(self):
+        """Test error handling during batch processing"""
+        class ErrorTestBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        shared_storage = {
+            'input_data': {
+                'normal_key': 1,
+                'error_key': 2,
+                'another_key': 3
+            }
+        }
+
+        flow = ErrorTestBatchFlow(start_node=ErrorProcessNode())
+
+        with self.assertRaises(ValueError):
+            flow.run(shared_storage)
+
+    def test_nested_flow(self):
+        """Test batch processing with nested flows"""
+        class InnerNode(Node):
+            def process(self, shared_storage, prep_result):
+                key = self.params.get('key')
+                if 'intermediate_results' not in shared_storage:
+                    shared_storage['intermediate_results'] = {}
+                shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
+
+        class OuterNode(Node):
+            def process(self, shared_storage, prep_result):
+                key = self.params.get('key')
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
+
+        class NestedBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{'key': k} for k in shared_storage['input_data'].keys()]
+
+        # Create inner flow
+        inner_node = InnerNode()
+        outer_node = OuterNode()
+        inner_node >> outer_node
+
+        shared_storage = {
+            'input_data': {
+                'x': 1,
+                'y': 2
+            }
+        }
+
+        flow = NestedBatchFlow(start_node=inner_node)
+        flow.run(shared_storage)
+
+        expected_results = {
+            'x': 4,  # (1 + 1) * 2
+            'y': 6   # (2 + 1) * 2
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+    def test_custom_parameters(self):
+        """Test batch processing with additional custom parameters"""
+        class CustomParamNode(Node):
+            def process(self, shared_storage, prep_result):
+                key = self.params.get('key')
+                multiplier = self.params.get('multiplier', 1)
+                if 'results' not in shared_storage:
+                    shared_storage['results'] = {}
+                shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
+
+        class CustomParamBatchFlow(BatchFlow):
+            def preprocess(self, shared_storage):
+                return [{
+                    'key': k,
+                    'multiplier': i + 1
+                } for i, k in enumerate(shared_storage['input_data'].keys())]
+
+        shared_storage = {
+            'input_data': {
+                'a': 1,
+                'b': 2,
+                'c': 3
+            }
+        }
+
+        flow = CustomParamBatchFlow(start_node=CustomParamNode())
+        flow.run(shared_storage)
+
+        expected_results = {
+            'a': 1 * 1,  # first item, multiplier = 1
+            'b': 2 * 2,  # second item, multiplier = 2
+            'c': 3 * 3   # third item, multiplier = 3
+        }
+        self.assertEqual(shared_storage['results'], expected_results)
+
+if __name__ == '__main__':
+    unittest.main()
@@ -0,0 +1,162 @@
+import unittest
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from minillmflow import Node, BatchNode, Flow
+
+class ArrayChunkNode(BatchNode):
+    def __init__(self, chunk_size=10):
+        super().__init__()
+        self.chunk_size = chunk_size
+
+    def preprocess(self, shared_storage):
+        # Get array from shared storage and split into chunks
+        array = shared_storage.get('input_array', [])
+        chunks = []
+        for i in range(0, len(array), self.chunk_size):
+            end = min(i + self.chunk_size, len(array))
+            chunks.append((i, end))
+        return chunks
+
+    def process(self, shared_storage, chunk_indices):
+        start, end = chunk_indices
+        array = shared_storage['input_array']
+        # Process the chunk and return its sum
+        chunk_sum = sum(array[start:end])
+        return chunk_sum
+
+    def postprocess(self, shared_storage, prep_result, proc_result):
+        # Store chunk results in shared storage
+        shared_storage['chunk_results'] = proc_result
+        return "default"
+
+class SumReduceNode(Node):
+    def process(self, shared_storage, data):
+        # Get chunk results from shared storage and sum them
+        chunk_results = shared_storage.get('chunk_results', [])
+        total = sum(chunk_results)
+        shared_storage['total'] = total
+
+class TestBatchNode(unittest.TestCase):
+    def test_array_chunking(self):
+        """
+        Test that the array is correctly split into chunks
+        """
+        shared_storage = {
+            'input_array': list(range(25))  # [0,1,2,...,24]
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        chunks = chunk_node.preprocess(shared_storage)
+
+        self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])
+
+    def test_map_reduce_sum(self):
+        """
+        Test a complete map-reduce pipeline that sums a large array:
+        1. Map: Split array into chunks and sum each chunk
+        2. Reduce: Sum all the chunk sums
+        """
+        # Create test array: [0,1,2,...,99]
+        array = list(range(100))
+        expected_sum = sum(array)  # 4950
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        # Create nodes
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        reduce_node = SumReduceNode()
+
+        # Connect nodes
+        chunk_node >> reduce_node
+
+        # Create and run pipeline
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_uneven_chunks(self):
+        """
+        Test that the map-reduce works correctly with array lengths
+        that don't divide evenly by chunk_size
+        """
+        array = list(range(25))
+        expected_sum = sum(array)  # 300
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_custom_chunk_size(self):
+        """
+        Test that the map-reduce works with different chunk sizes
+        """
+        array = list(range(100))
+        expected_sum = sum(array)
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        # Use chunk_size=15 instead of default 10
+        chunk_node = ArrayChunkNode(chunk_size=15)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_single_element_chunks(self):
+        """
+        Test extreme case where chunk_size=1
+        """
+        array = list(range(5))
+        expected_sum = sum(array)
+
+        shared_storage = {
+            'input_array': array
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=1)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], expected_sum)
+
+    def test_empty_array(self):
+        """
+        Test edge case of empty input array
+        """
+        shared_storage = {
+            'input_array': []
+        }
+
+        chunk_node = ArrayChunkNode(chunk_size=10)
+        reduce_node = SumReduceNode()
+
+        chunk_node >> reduce_node
+        pipeline = Flow(start_node=chunk_node)
+        pipeline.run(shared_storage)
+
+        self.assertEqual(shared_storage['total'], 0)
+
+if __name__ == '__main__':
+    unittest.main()
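Note: condensed wiring of the map-reduce pipeline these tests exercise, using the classes defined in the test file above:

    shared_storage = {'input_array': list(range(100))}

    chunk_node = ArrayChunkNode(chunk_size=10)  # map: per-chunk sums
    reduce_node = SumReduceNode()               # reduce: total of chunk sums
    chunk_node >> reduce_node

    Flow(start_node=chunk_node).run(shared_storage)
    assert shared_storage['total'] == 4950  # sum(range(100))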
@@ -2,10 +2,11 @@ import unittest
 import asyncio
 import sys
 from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

 from minillmflow import Node, Flow

+# Simple example Nodes
 class NumberNode(Node):
     def __init__(self, number):
         super().__init__()
@@ -30,106 +31,79 @@ class MultiplyNode(Node):
     def process(self, shared_storage, prep_result):
         shared_storage['current'] *= self.number


 class TestFlowComposition(unittest.TestCase):

     def test_flow_as_node(self):
         """
-        Demonstrates that a Flow can itself be chained like a Node.
-        We create a flow (f1) that starts with NumberNode(5) -> AddNode(10).
-        Then we chain f1 >> MultiplyNode(2).
-
-        Expected result after running from f1:
-        start = 5
-        5 + 10 = 15
-        15 * 2 = 30
+        1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
+        2) Create a second Flow (f2) whose start_node is f1.
+        3) Create a wrapper Flow (f3) that contains f2 to ensure proper execution.
+        Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
         """
         shared_storage = {}

         # Inner flow f1
         f1 = Flow(start_node=NumberNode(5))
-        f1 >> AddNode(10)
+        f1 >> AddNode(10) >> MultiplyNode(2)

-        # Then chain a node after the flow
-        f1 >> MultiplyNode(2)
+        # f2 starts with f1
+        f2 = Flow(start_node=f1)

-        # Run from f1
-        f1.run(shared_storage)
+        # Wrapper flow f3 to ensure proper execution
+        f3 = Flow(start_node=f2)
+        f3.run(shared_storage)

         self.assertEqual(shared_storage['current'], 30)

     def test_nested_flow(self):
         """
-        Demonstrates embedding one Flow inside another Flow.
+        Demonstrates nested flows with proper wrapping:
         inner_flow: NumberNode(5) -> AddNode(3)
-        outer_flow: starts with inner_flow -> MultiplyNode(4)
-
-        Expected result:
-        (5 + 3) * 4 = 32
+        middle_flow: starts with inner_flow -> MultiplyNode(4)
+        wrapper_flow: contains middle_flow to ensure proper execution
+        Expected final result: (5 + 3) * 4 = 32.
         """
         shared_storage = {}

-        # Define an inner flow
+        # Build the inner flow
         inner_flow = Flow(start_node=NumberNode(5))
         inner_flow >> AddNode(3)

-        # Define an outer flow, whose start node is inner_flow
-        outer_flow = Flow(start_node=inner_flow)
-        outer_flow >> MultiplyNode(4)
+        # Build the middle flow, whose start_node is the inner flow
+        middle_flow = Flow(start_node=inner_flow)
+        middle_flow >> MultiplyNode(4)

-        # Run outer_flow
-        outer_flow.run(shared_storage)
+        # Wrapper flow to ensure proper execution
+        wrapper_flow = Flow(start_node=middle_flow)
+        wrapper_flow.run(shared_storage)

-        self.assertEqual(shared_storage['current'], 32) # (5+3)*4=32
+        self.assertEqual(shared_storage['current'], 32)

     def test_flow_chaining_flows(self):
         """
-        Demonstrates chaining one flow to another flow.
-        flow1: NumberNode(10) -> AddNode(10) # final shared_storage['current'] = 20
-        flow2: MultiplyNode(2) # final shared_storage['current'] = 40
-
-        flow1 >> flow2 means once flow1 finishes, flow2 starts.
-
-        Expected result: (10 + 10) * 2 = 40
+        Demonstrates chaining two flows with proper wrapping:
+        flow1: NumberNode(10) -> AddNode(10) # final = 20
+        flow2: MultiplyNode(2) # final = 40
+        wrapper_flow: contains both flow1 and flow2 to ensure proper execution
+        Expected final result: (10 + 10) * 2 = 40.
         """
         shared_storage = {}

         # flow1
-        flow1 = Flow(start_node=NumberNode(10))
-        flow1 >> AddNode(10)
+        numbernode = NumberNode(10)
+        numbernode >> AddNode(10)
+        flow1 = Flow(start_node=numbernode)

         # flow2
         flow2 = Flow(start_node=MultiplyNode(2))

-        # Chain them: flow1 >> flow2
+        # Chain flow1 to flow2
         flow1 >> flow2

-        # Start running from flow1
-        flow1.run(shared_storage)
+        # Wrapper flow to ensure proper execution
+        wrapper_flow = Flow(start_node=flow1)
+        wrapper_flow.run(shared_storage)

         self.assertEqual(shared_storage['current'], 40)

-    def test_flow_with_parameters(self):
-        """
-        Demonstrates passing parameters into a Flow (and retrieved by a Node).
-        """
-        class ParamNode(Node):
-            def process(self, shared_storage, prep_result):
-                # Reads 'level' from the node's (or flow's) parameters
-                shared_storage['param'] = self.parameters.get('level', 'no param')
-
-        shared_storage = {}
-
-        # Create a flow with a ParamNode
-        f = Flow(start_node=ParamNode())
-        # Set parameters on the flow
-        f.parameters = {'level': 'Level 1'}
-
-        f.run(shared_storage)
-
-        self.assertEqual(shared_storage['param'], 'Level 1')

 if __name__ == '__main__':
     unittest.main()
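Note: the pattern these revised tests converge on is to wrap the outermost flow in one more Flow before running, so that nodes chained after a sub-flow are executed. Condensed from test_nested_flow above:

    shared_storage = {}

    inner_flow = Flow(start_node=NumberNode(5))
    inner_flow >> AddNode(3)

    middle_flow = Flow(start_node=inner_flow)
    middle_flow >> MultiplyNode(4)

    wrapper_flow = Flow(start_node=middle_flow)  # wrapper ensures MultiplyNode runs
    wrapper_flow.run(shared_storage)             # (5 + 3) * 4 = 32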