This commit is contained in:
zachary62 2024-12-26 01:38:54 +00:00
parent 881a903e2f
commit 5f95f23bc9
5 changed files with 440 additions and 151 deletions

View File

@ -6,10 +6,10 @@ class BaseNode:
# process(): this is for the LLM call, and should be idempotent for retries
# postprocess(): this is to summarize the result and return the condition for the successor node
def __init__(self):
self.parameters, self.successors = {}, {}
self.params, self.successors = {}, {}
def set_parameters(self, params): # make sure params is immutable
self.parameters = params # must be immutable during pre/post/process
def set_params(self, params): # make sure params is immutable
self.params = params # must be immutable during pre/post/process
def add_successor(self, node, condition="default"):
if condition in self.successors:
@ -30,12 +30,12 @@ class BaseNode:
def postprocess(self, shared_storage, prep_result, proc_result):
return "default" # condition for next node
def _run(self, shared_storage=None):
prep = self.preprocess(shared_storage)
proc = self._process(shared_storage, prep)
return self.postprocess(shared_storage, prep, proc)
def _run(self, shared_storage):
prep_result = self.preprocess(shared_storage)
proc_result = self._process(shared_storage, prep_result)
return self.postprocess(shared_storage, prep_result, proc_result)
def run(self, shared_storage=None):
def run(self, shared_storage):
if self.successors:
warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
return self._run(shared_storage)
@ -91,37 +91,28 @@ class BatchNode(Node):
return results
class AsyncNode(Node):
"""
A Node whose postprocess step is async.
You can also override process() to be async if needed.
"""
def postprocess(self, shared_storage, prep_result, proc_result):
# Not used in async workflow; define postprocess_async() instead.
raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
async def postprocess_async(self, shared_storage, prep_result, proc_result):
"""
Async version of postprocess. By default, returns "default".
Override as needed.
"""
await asyncio.sleep(0) # trivial async pause (no-op)
return "default"
async def run_async(self, shared_storage=None):
async def run_async(self, shared_storage):
if self.successors:
warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
return await self._run_async(shared_storage)
async def _run_async(self, shared_storage=None):
prep = self.preprocess(shared_storage)
proc = self._process(shared_storage, prep)
return await self.postprocess_async(shared_storage, prep, proc)
async def _run_async(self, shared_storage):
prep_result = self.preprocess(shared_storage)
proc_result = self._process(shared_storage, prep_result)
return await self.postprocess_async(shared_storage, prep_result, proc_result)
def _run(self, shared_storage=None):
raise RuntimeError("AsyncNode requires run_async, and should be run in an AsyncFlow")
def _run(self, shared_storage):
raise RuntimeError("AsyncNode requires asynchronous execution. Use 'await node.run_async()' if inside an async function, or 'asyncio.run(node.run_async())' if in synchronous code.")
class BaseFlow(BaseNode):
def __init__(self, start_node=None):
def __init__(self, start_node):
super().__init__()
self.start_node = start_node
@ -134,28 +125,26 @@ class BaseFlow(BaseNode):
return next_node
class Flow(BaseFlow):
def _process_flow(self, shared_storage):
def _process(self, shared_storage, params=None):
current_node = self.start_node
params = params if params is not None else self.params.copy()
while current_node:
# Pass down the Flow's parameters to the current node
current_node.set_parameters(self.parameters)
# Synchronous run
current_node.set_params(params)
condition = current_node._run(shared_storage)
# Decide next node
current_node = self.get_next_node(current_node, condition)
def _run(self, shared_storage=None):
prep_result = self.preprocess(shared_storage)
self._process_flow(shared_storage)
return self.postprocess(shared_storage, prep_result, None)
class AsyncFlow(BaseFlow):
async def _process_flow_async(self, shared_storage):
def process(self, shared_storage, prep_result):
raise NotImplementedError("Flow should not process directly")
class AsyncFlow(BaseFlow, AsyncNode):
async def _process_async(self, shared_storage, params=None):
current_node = self.start_node
params = params if params is not None else self.params.copy()
while current_node:
current_node.set_parameters(self.parameters)
current_node.set_params(params)
# If node is async-capable, call run_async; otherwise run sync
if hasattr(current_node, "run_async") and callable(current_node.run_async):
condition = await current_node._run_async(shared_storage)
else:
@ -163,53 +152,33 @@ class AsyncFlow(BaseFlow):
current_node = self.get_next_node(current_node, condition)
async def _run_async(self, shared_storage=None):
async def _run_async(self, shared_storage):
prep_result = self.preprocess(shared_storage)
await self._process_flow_async(shared_storage)
return self.postprocess(shared_storage, prep_result, None)
await self._process_async(shared_storage)
return await self.postprocess_async(shared_storage, prep_result, None)
def _run(self, shared_storage=None):
try:
return asyncio.run(self._run_async(shared_storage))
except RuntimeError as e:
raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e
class BaseBatchFlow(BaseFlow):
def preprocess(self, shared_storage):
return []
return [] # return an iterable of parameter dictionaries
class BatchFlow(BaseBatchFlow, Flow):
def _run(self, shared_storage=None):
def _run(self, shared_storage):
prep_result = self.preprocess(shared_storage)
all_results = []
# For each set of parameters (or items) we got from preprocess
for param_dict in prep_result:
# Merge param_dict into the Flow's parameters
original_params = self.parameters.copy()
self.parameters.update(param_dict)
# Run from the start node to end
self._process_flow(shared_storage)
# Optionally collect results from shared_storage or a custom method
all_results.append(f"Finished run with parameters: {param_dict}")
# Reset the parameters if needed
self.parameters = original_params
merged_params = self.params.copy()
merged_params.update(param_dict)
self._process(shared_storage, params=merged_params)
return self.postprocess(shared_storage, prep_result, None)
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, shared_storage=None):
async def _run_async(self, shared_storage):
prep_result = self.preprocess(shared_storage)
all_results = []
for param_dict in prep_result:
original_params = self.parameters.copy()
self.parameters.update(param_dict)
await self._process_flow_async(shared_storage)
all_results.append(f"Finished async run with parameters: {param_dict}")
# Reset back to original parameters if needed
self.parameters = original_params
merged_params = self.params.copy()
merged_params.update(param_dict)
await self._process_async(shared_storage, params=merged_params)
return await self.postprocess_async(shared_storage, prep_result, None)

View File

@ -97,7 +97,7 @@ class TestAsyncFlow(unittest.TestCase):
# We'll run the flow synchronously (which under the hood is asyncio.run())
shared_storage = {}
flow.run(shared_storage)
asyncio.run(flow.run_async(shared_storage))
self.assertEqual(shared_storage['current'], 6)
@ -144,7 +144,7 @@ class TestAsyncFlow(unittest.TestCase):
start_node - "negative_branch" >> negative_node
flow = AsyncFlow(start_node)
flow.run(shared_storage)
asyncio.run(flow.run_async(shared_storage))
self.assertEqual(shared_storage["path"], "positive",
"Should have taken the positive branch")

184
tests/test_batch_flow.py Normal file
View File

@ -0,0 +1,184 @@
import unittest
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import Node, BatchFlow, Flow
class DataProcessNode(Node):
    """Test node that doubles a single entry of ``shared_storage['input_data']``.

    The entry to handle is chosen by the ``'key'`` item of ``self.params``.
    """

    def process(self, shared_storage, prep_result):
        """Write ``input_data[key] * 2`` into ``shared_storage['results'][key]``.

        ``prep_result`` is accepted for interface compatibility and is unused.
        """
        item_key = self.params.get('key')
        # Read the input first so a missing key fails before 'results' is created.
        doubled = shared_storage['input_data'][item_key] * 2
        shared_storage.setdefault('results', {})[item_key] = doubled
class ErrorProcessNode(Node):
    """Test node that fails on one specific key and marks every other key done."""

    def process(self, shared_storage, prep_result):
        """Raise ``ValueError`` for ``'error_key'``; otherwise record ``True``.

        The key comes from ``self.params['key']``. Successful keys are stored
        as ``shared_storage['results'][key] = True``.
        """
        key = self.params.get('key')
        if key == 'error_key':
            raise ValueError(f"Error processing key: {key}")
        shared_storage.setdefault('results', {})[key] = True
class TestBatchFlow(unittest.TestCase):
    """Exercises BatchFlow: one flow run per parameter dict from preprocess()."""

    def setUp(self):
        # A fresh doubling node for every test that needs the default behaviour.
        self.process_node = DataProcessNode()

    def test_basic_batch_processing(self):
        """Test basic batch processing with multiple keys"""
        class SimpleTestBatchFlow(BatchFlow):
            def preprocess(self, shared_storage):
                # One parameter dict per input key.
                return [{'key': name} for name in shared_storage['input_data']]

        shared_storage = {'input_data': {'a': 1, 'b': 2, 'c': 3}}

        SimpleTestBatchFlow(start_node=self.process_node).run(shared_storage)

        self.assertEqual(shared_storage['results'], {'a': 2, 'b': 4, 'c': 6})

    def test_empty_input(self):
        """Test batch processing with empty input dictionary"""
        class EmptyTestBatchFlow(BatchFlow):
            def preprocess(self, shared_storage):
                return [{'key': name} for name in shared_storage['input_data']]

        shared_storage = {'input_data': {}}

        EmptyTestBatchFlow(start_node=self.process_node).run(shared_storage)

        # With nothing to process, 'results' may never have been created.
        self.assertEqual(shared_storage.get('results', {}), {})

    def test_single_item(self):
        """Test batch processing with single item"""
        class SingleItemBatchFlow(BatchFlow):
            def preprocess(self, shared_storage):
                return [{'key': name} for name in shared_storage['input_data']]

        shared_storage = {'input_data': {'single': 5}}

        SingleItemBatchFlow(start_node=self.process_node).run(shared_storage)

        self.assertEqual(shared_storage['results'], {'single': 10})

    def test_error_handling(self):
        """Test error handling during batch processing"""
        class ErrorTestBatchFlow(BatchFlow):
            def preprocess(self, shared_storage):
                return [{'key': name} for name in shared_storage['input_data']]

        shared_storage = {
            'input_data': {'normal_key': 1, 'error_key': 2, 'another_key': 3},
        }
        flow = ErrorTestBatchFlow(start_node=ErrorProcessNode())

        # The node's ValueError must propagate out of the batch run.
        with self.assertRaises(ValueError):
            flow.run(shared_storage)

    def test_nested_flow(self):
        """Test batch processing with nested flows"""
        class InnerNode(Node):
            def process(self, shared_storage, prep_result):
                key = self.params.get('key')
                # Create the bucket before reading the input, as a separate step.
                bucket = shared_storage.setdefault('intermediate_results', {})
                bucket[key] = shared_storage['input_data'][key] + 1

        class OuterNode(Node):
            def process(self, shared_storage, prep_result):
                key = self.params.get('key')
                bucket = shared_storage.setdefault('results', {})
                bucket[key] = shared_storage['intermediate_results'][key] * 2

        class NestedBatchFlow(BatchFlow):
            def preprocess(self, shared_storage):
                return [{'key': name} for name in shared_storage['input_data']]

        # Two-step chain: increment, then double.
        inner_node = InnerNode()
        outer_node = OuterNode()
        inner_node >> outer_node

        shared_storage = {'input_data': {'x': 1, 'y': 2}}

        NestedBatchFlow(start_node=inner_node).run(shared_storage)

        expected_results = {
            'x': 4,  # (1 + 1) * 2
            'y': 6,  # (2 + 1) * 2
        }
        self.assertEqual(shared_storage['results'], expected_results)

    def test_custom_parameters(self):
        """Test batch processing with additional custom parameters"""
        class CustomParamNode(Node):
            def process(self, shared_storage, prep_result):
                key = self.params.get('key')
                multiplier = self.params.get('multiplier', 1)
                bucket = shared_storage.setdefault('results', {})
                bucket[key] = shared_storage['input_data'][key] * multiplier

        class CustomParamBatchFlow(BatchFlow):
            def preprocess(self, shared_storage):
                # The multiplier is the 1-based position of the key.
                return [
                    {'key': name, 'multiplier': position}
                    for position, name in enumerate(shared_storage['input_data'], start=1)
                ]

        shared_storage = {'input_data': {'a': 1, 'b': 2, 'c': 3}}

        CustomParamBatchFlow(start_node=CustomParamNode()).run(shared_storage)

        expected_results = {
            'a': 1 * 1,  # first item, multiplier = 1
            'b': 2 * 2,  # second item, multiplier = 2
            'c': 3 * 3,  # third item, multiplier = 3
        }
        self.assertEqual(shared_storage['results'], expected_results)
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

162
tests/test_batch_node.py Normal file
View File

@ -0,0 +1,162 @@
import unittest
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import Node, BatchNode, Flow
class ArrayChunkNode(BatchNode):
    """Map step: split ``shared_storage['input_array']`` into chunks and sum each."""

    def __init__(self, chunk_size=10):
        """Remember how many elements each chunk should span."""
        super().__init__()
        self.chunk_size = chunk_size

    def preprocess(self, shared_storage):
        """Return ``(start, end)`` index pairs covering the input array.

        A missing ``'input_array'`` is treated as an empty list. The last
        chunk is clipped to the array length.
        """
        array = shared_storage.get('input_array', [])
        length = len(array)
        step = self.chunk_size
        return [
            (begin, min(begin + step, length))
            for begin in range(0, length, step)
        ]

    def process(self, shared_storage, chunk_indices):
        """Return the sum of the slice described by ``chunk_indices``."""
        start, end = chunk_indices
        return sum(shared_storage['input_array'][start:end])

    def postprocess(self, shared_storage, prep_result, proc_result):
        """Publish the processing result as ``'chunk_results'``; continue on "default"."""
        shared_storage['chunk_results'] = proc_result
        return "default"
class SumReduceNode(Node):
    """Reduce step: add up the per-chunk sums into one total."""

    def process(self, shared_storage, data):
        """Store ``sum(chunk_results)`` in ``shared_storage['total']``.

        A missing ``'chunk_results'`` entry counts as an empty list. ``data``
        is accepted for interface compatibility and is unused.
        """
        shared_storage['total'] = sum(shared_storage.get('chunk_results', []))
class TestBatchNode(unittest.TestCase):
    """Exercises BatchNode through a chunk-and-sum map-reduce pipeline."""

    def _run_map_reduce(self, array, chunk_size):
        """Wire ArrayChunkNode >> SumReduceNode into a Flow, run it, return the storage."""
        shared_storage = {'input_array': array}

        chunk_node = ArrayChunkNode(chunk_size=chunk_size)
        reduce_node = SumReduceNode()
        chunk_node >> reduce_node

        Flow(start_node=chunk_node).run(shared_storage)
        return shared_storage

    def test_array_chunking(self):
        """
        Test that the array is correctly split into chunks
        """
        shared_storage = {'input_array': list(range(25))}  # [0,1,2,...,24]

        chunks = ArrayChunkNode(chunk_size=10).preprocess(shared_storage)

        self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])

    def test_map_reduce_sum(self):
        """
        Test a complete map-reduce pipeline that sums a large array:
        1. Map: Split array into chunks and sum each chunk
        2. Reduce: Sum all the chunk sums
        """
        array = list(range(100))  # [0,1,2,...,99]
        expected_sum = sum(array)  # 4950

        shared_storage = self._run_map_reduce(array, chunk_size=10)

        self.assertEqual(shared_storage['total'], expected_sum)

    def test_uneven_chunks(self):
        """
        Test that the map-reduce works correctly with array lengths
        that don't divide evenly by chunk_size
        """
        array = list(range(25))
        expected_sum = sum(array)  # 300

        shared_storage = self._run_map_reduce(array, chunk_size=10)

        self.assertEqual(shared_storage['total'], expected_sum)

    def test_custom_chunk_size(self):
        """
        Test that the map-reduce works with different chunk sizes
        """
        array = list(range(100))
        expected_sum = sum(array)

        # chunk_size=15 instead of the default 10
        shared_storage = self._run_map_reduce(array, chunk_size=15)

        self.assertEqual(shared_storage['total'], expected_sum)

    def test_single_element_chunks(self):
        """
        Test extreme case where chunk_size=1
        """
        array = list(range(5))
        expected_sum = sum(array)

        shared_storage = self._run_map_reduce(array, chunk_size=1)

        self.assertEqual(shared_storage['total'], expected_sum)

    def test_empty_array(self):
        """
        Test edge case of empty input array
        """
        shared_storage = self._run_map_reduce([], chunk_size=10)

        self.assertEqual(shared_storage['total'], 0)
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

View File

@ -2,10 +2,11 @@ import unittest
import asyncio
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import Node, Flow
# Simple example Nodes
class NumberNode(Node):
def __init__(self, number):
super().__init__()
@ -30,106 +31,79 @@ class MultiplyNode(Node):
def process(self, shared_storage, prep_result):
shared_storage['current'] *= self.number
class TestFlowComposition(unittest.TestCase):
def test_flow_as_node(self):
"""
Demonstrates that a Flow can itself be chained like a Node.
We create a flow (f1) that starts with NumberNode(5) -> AddNode(10).
Then we chain f1 >> MultiplyNode(2).
Expected result after running from f1:
start = 5
5 + 10 = 15
15 * 2 = 30
1) Create a Flow (f1) starting with NumberNode(5), then AddNode(10), then MultiplyNode(2).
2) Create a second Flow (f2) whose start_node is f1.
3) Create a wrapper Flow (f3) that contains f2 to ensure proper execution.
Expected final result in shared_storage['current']: (5 + 10) * 2 = 30.
"""
shared_storage = {}
# Inner flow f1
f1 = Flow(start_node=NumberNode(5))
f1 >> AddNode(10)
# Then chain a node after the flow
f1 >> MultiplyNode(2)
# Run from f1
f1.run(shared_storage)
f1 >> AddNode(10) >> MultiplyNode(2)
# f2 starts with f1
f2 = Flow(start_node=f1)
# Wrapper flow f3 to ensure proper execution
f3 = Flow(start_node=f2)
f3.run(shared_storage)
self.assertEqual(shared_storage['current'], 30)
def test_nested_flow(self):
"""
Demonstrates embedding one Flow inside another Flow.
Demonstrates nested flows with proper wrapping:
inner_flow: NumberNode(5) -> AddNode(3)
outer_flow: starts with inner_flow -> MultiplyNode(4)
Expected result:
(5 + 3) * 4 = 32
middle_flow: starts with inner_flow -> MultiplyNode(4)
wrapper_flow: contains middle_flow to ensure proper execution
Expected final result: (5 + 3) * 4 = 32.
"""
shared_storage = {}
# Define an inner flow
# Build the inner flow
inner_flow = Flow(start_node=NumberNode(5))
inner_flow >> AddNode(3)
# Define an outer flow, whose start node is inner_flow
outer_flow = Flow(start_node=inner_flow)
outer_flow >> MultiplyNode(4)
# Run outer_flow
outer_flow.run(shared_storage)
self.assertEqual(shared_storage['current'], 32) # (5+3)*4=32
# Build the middle flow, whose start_node is the inner flow
middle_flow = Flow(start_node=inner_flow)
middle_flow >> MultiplyNode(4)
# Wrapper flow to ensure proper execution
wrapper_flow = Flow(start_node=middle_flow)
wrapper_flow.run(shared_storage)
self.assertEqual(shared_storage['current'], 32)
def test_flow_chaining_flows(self):
"""
Demonstrates chaining one flow to another flow.
flow1: NumberNode(10) -> AddNode(10) # final shared_storage['current'] = 20
flow2: MultiplyNode(2) # final shared_storage['current'] = 40
flow1 >> flow2 means once flow1 finishes, flow2 starts.
Expected result: (10 + 10) * 2 = 40
Demonstrates chaining two flows with proper wrapping:
flow1: NumberNode(10) -> AddNode(10) # final = 20
flow2: MultiplyNode(2) # final = 40
wrapper_flow: contains both flow1 and flow2 to ensure proper execution
Expected final result: (10 + 10) * 2 = 40.
"""
shared_storage = {}
# flow1
flow1 = Flow(start_node=NumberNode(10))
flow1 >> AddNode(10)
numbernode = NumberNode(10)
numbernode >> AddNode(10)
flow1 = Flow(start_node=numbernode)
# flow2
flow2 = Flow(start_node=MultiplyNode(2))
# Chain them: flow1 >> flow2
# Chain flow1 to flow2
flow1 >> flow2
# Start running from flow1
flow1.run(shared_storage)
# Wrapper flow to ensure proper execution
wrapper_flow = Flow(start_node=flow1)
wrapper_flow.run(shared_storage)
self.assertEqual(shared_storage['current'], 40)
def test_flow_with_parameters(self):
"""
Demonstrates passing parameters into a Flow (and retrieved by a Node).
"""
class ParamNode(Node):
def process(self, shared_storage, prep_result):
# Reads 'level' from the node's (or flow's) parameters
shared_storage['param'] = self.parameters.get('level', 'no param')
shared_storage = {}
# Create a flow with a ParamNode
f = Flow(start_node=ParamNode())
# Set parameters on the flow
f.parameters = {'level': 'Level 1'}
f.run(shared_storage)
self.assertEqual(shared_storage['param'], 'Level 1')
if __name__ == '__main__':
unittest.main()
unittest.main()