add test cases

This commit is contained in:
zachary62 2024-12-31 02:37:49 +00:00
parent c1ba9dd0d4
commit 643534bc2f
8 changed files with 508 additions and 18 deletions

View File

@ -162,7 +162,7 @@ most_relevant: filename"""
assert "most_relevant" in result if result["has_relevant"] else True
return result
# handle errors by returning a default response in case of exception after retries
def process_after_fail(self,shared,prep_res,exc):
def exec_fallback(self,prep_res,exc):
    # Invoked as exec_fallback(prep_res, exc) once the last retry of exec()
    # has failed; returning a value here (instead of re-raising exc) lets the
    # node continue with a default "nothing relevant found" answer.
    return {"think":"error finding the file", "has_relevant":False}
def post(self, shared, prep_res, exec_res):

View File

@ -355,7 +355,7 @@
" assert \"most_relevant\" in result if result[\"has_relevant\"] else True\n",
" return result\n",
" # handle errors by returning a default response in case of exception after retries\n",
" def process_after_fail(self,shared,prep_res,exc):\n",
" def exec_fallback(self,prep_res,exc):\n",
" # if not overridden, the default is to throw the exception\n",
" return {\"think\":\"error finding the file\", \"has_relevant\":False}\n",
" def post(self, shared, prep_res, exec_res):\n",

View File

@ -7,7 +7,7 @@ nav_order: 5
# Async
**Mini LLM Flow** allows fully asynchronous nodes by implementing `prep_async()`, `exec_async()`, and/or `post_async()`. This is useful for:
**Mini LLM Flow** allows fully asynchronous nodes by implementing `prep_async()`, `exec_async()`, `exec_fallback_async()`, and/or `post_async()`. This is useful for:
## Implementation

View File

@ -42,7 +42,7 @@ When an exception occurs in `exec()`, the Node automatically retries until:
If you want to **gracefully handle** the error rather than raising it, you can override:
```python
def process_after_fail(self, shared, prep_res, exc):
def exec_fallback(self, prep_res, exc):
raise exc
```
@ -64,7 +64,7 @@ class SummarizeFile(Node):
summary = call_llm(prompt) # might fail
return summary
def process_after_fail(self, shared, prep_res, exc):
def exec_fallback(self, prep_res, exc):
# Provide a simple fallback instead of crashing
return "There was an error processing your request."
@ -76,7 +76,7 @@ class SummarizeFile(Node):
summarize_node = SummarizeFile(max_retries=3)
# Run the node standalone for testing (calls prep->exec->post).
# If exec() fails, it retries up to 3 times before calling process_after_fail().
# If exec() fails, it retries up to 3 times before calling exec_fallback().
summarize_node.set_params({"filename": "test_file.txt"})
action_result = summarize_node.run(shared)

View File

@ -1,4 +1,4 @@
import asyncio, warnings
import asyncio, warnings, copy
class BaseNode:
def __init__(self): self.params,self.successors={},{}
@ -25,15 +25,15 @@ class _ConditionalTransition:
class Node(BaseNode):
def __init__(self,max_retries=1): super().__init__();self.max_retries=max_retries
def process_after_fail(self,prep_res,exc): raise exc
def exec_fallback(self,prep_res,exc): raise exc
def _exec(self,prep_res):
for i in range(self.max_retries):
try: return super()._exec(prep_res)
try: return self.exec(prep_res)
except Exception as e:
if i==self.max_retries-1: return self.process_after_fail(prep_res,e)
if i==self.max_retries-1: return self.exec_fallback(prep_res,e)
class BatchNode(Node):
def _exec(self,items): return [super(Node,self)._exec(i) for i in items]
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in items]
class Flow(BaseNode):
def __init__(self,start): super().__init__();self.start=start
@ -43,8 +43,8 @@ class Flow(BaseNode):
warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return nxt
def _orch(self,shared,params=None):
curr,p=self.start,(params or {**self.params})
while curr: curr.set_params(p);c=curr._run(shared);curr=self.get_next_node(curr,c)
curr,p=copy.copy(self.start),(params or {**self.params})
while curr: curr.set_params(p);c=curr._run(shared);curr=copy.copy(self.get_next_node(curr,c))
def _run(self,shared): pr=self.prep(shared);self._orch(shared);return self.post(shared,pr,None)
def exec(self,prep_res): raise RuntimeError("Flow can't exec.")
@ -58,11 +58,17 @@ class AsyncNode(Node):
def prep(self,shared): raise RuntimeError("Use prep_async.")
def exec(self,prep_res): raise RuntimeError("Use exec_async.")
def post(self,shared,prep_res,exec_res): raise RuntimeError("Use post_async.")
def exec_fallback(self,prep_res,exc): raise RuntimeError("Use exec_fallback_async.")
def _run(self,shared): raise RuntimeError("Use run_async.")
async def prep_async(self,shared): pass
async def exec_async(self,prep_res): pass
async def exec_fallback_async(self,prep_res,exc): raise exc
async def post_async(self,shared,prep_res,exec_res): pass
async def _exec(self,prep_res): return await self.exec_async(prep_res)
async def _exec(self,prep_res):
for i in range(self.max_retries):
try: return await self.exec_async(prep_res)
except Exception as e:
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
async def run_async(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
return await self._run_async(shared)
@ -71,18 +77,18 @@ class AsyncNode(Node):
return await self.post_async(shared,p,e)
class AsyncBatchNode(AsyncNode):
async def _exec(self,items): return [await super()._exec(i) for i in items]
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
class AsyncParallelBatchNode(AsyncNode):
async def _exec(self,items): return await asyncio.gather(*(super()._exec(i) for i in items))
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
class AsyncFlow(Flow,AsyncNode):
async def _orch_async(self,shared,params=None):
curr,p=self.start,(params or {**self.params})
curr,p=copy.copy(self.start),(params or {**self.params})
while curr:
curr.set_params(p)
c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
curr=self.get_next_node(curr,c)
curr=copy.copy(self.get_next_node(curr,c))
async def _run_async(self,shared):
pr=await self.prep_async(shared);await self._orch_async(shared)
return await self.post_async(shared,pr,None)

View File

@ -0,0 +1,181 @@
import unittest
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from minillmflow import AsyncNode, AsyncBatchNode, AsyncFlow
class AsyncArrayChunkNode(AsyncBatchNode):
    """Batch node that cuts ``input_array`` into chunks and sums each chunk."""

    def __init__(self, chunk_size=10):
        super().__init__()
        self.chunk_size = chunk_size

    async def prep_async(self, shared_storage):
        """Return ``shared_storage['input_array']`` split into ``chunk_size`` slices."""
        values = shared_storage.get('input_array', [])
        # Slicing clamps at the end of the sequence, so the final chunk may
        # simply be shorter than chunk_size.
        return [
            values[offset:offset + self.chunk_size]
            for offset in range(0, len(values), self.chunk_size)
        ]

    async def exec_async(self, chunk):
        """Sum a single chunk after a short simulated async delay."""
        await asyncio.sleep(0.01)
        return sum(chunk)

    async def post_async(self, shared_storage, prep_result, proc_result):
        """Publish the per-chunk sums under ``chunk_results``."""
        shared_storage['chunk_results'] = proc_result
        return "processed"
class AsyncSumReduceNode(AsyncNode):
    """Node that adds up the stored chunk sums inside its prep step."""

    async def prep_async(self, shared_storage):
        """Write the sum of ``chunk_results`` to ``total`` and return "reduced"."""
        partial_sums = shared_storage.get('chunk_results', [])
        await asyncio.sleep(0.01)  # Simulate async processing
        shared_storage['total'] = sum(partial_sums)
        return "reduced"
class TestAsyncBatchNode(unittest.TestCase):
    """Tests for chunked processing with ``AsyncArrayChunkNode``."""

    def test_array_chunking(self):
        """
        Test that the array is correctly split into chunks and processed asynchronously
        """
        shared_storage = {'input_array': list(range(25))}  # [0,1,2,...,24]

        node = AsyncArrayChunkNode(chunk_size=10)
        asyncio.run(node.run_async(shared_storage))

        # Sum of chunks [0-9], [10-19], [20-24]
        expected_sums = [45, 145, 110]
        self.assertEqual(shared_storage['chunk_results'], expected_sums)
# def test_async_map_reduce_sum(self):
# """
# Test a complete async map-reduce pipeline that sums a large array:
# 1. Map: Split array into chunks and sum each chunk asynchronously
# 2. Reduce: Sum all the chunk sums asynchronously
# """
# array = list(range(100))
# expected_sum = sum(array) # 4950
# shared_storage = {
# 'input_array': array
# }
# # Create nodes
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
# reduce_node = AsyncSumReduceNode()
# # Connect nodes
# chunk_node - "processed" >> reduce_node
# # Create and run pipeline
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_uneven_chunks(self):
# """
# Test that the async map-reduce works correctly with array lengths
# that don't divide evenly by chunk_size
# """
# array = list(range(25))
# expected_sum = sum(array) # 300
# shared_storage = {
# 'input_array': array
# }
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_custom_chunk_size(self):
# """
# Test that the async map-reduce works with different chunk sizes
# """
# array = list(range(100))
# expected_sum = sum(array)
# shared_storage = {
# 'input_array': array
# }
# # Use chunk_size=15 instead of default 10
# chunk_node = AsyncArrayChunkNode(chunk_size=15)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_single_element_chunks(self):
# """
# Test extreme case where chunk_size=1
# """
# array = list(range(5))
# expected_sum = sum(array)
# shared_storage = {
# 'input_array': array
# }
# chunk_node = AsyncArrayChunkNode(chunk_size=1)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_empty_array(self):
# """
# Test edge case of empty input array
# """
# shared_storage = {
# 'input_array': []
# }
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], 0)
# def test_error_handling(self):
# """
# Test error handling in async batch processing
# """
# class ErrorAsyncBatchNode(AsyncBatchNode):
# async def exec_async(self, item):
# if item == 2:
# raise ValueError("Error processing item 2")
# return item
# shared_storage = {
# 'input_array': [1, 2, 3]
# }
# error_node = ErrorAsyncBatchNode()
# with self.assertRaises(ValueError):
# asyncio.run(error_node.run_async(shared_storage))
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()

View File

@ -0,0 +1,160 @@
import unittest
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from minillmflow import AsyncNode, AsyncParallelBatchNode, AsyncParallelBatchFlow
class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
    """Parallel batch node that doubles every number of one selected batch."""

    def __init__(self, delay=0.1):
        super().__init__()
        self.delay = delay

    async def prep_async(self, shared_storage):
        """Return the batch addressed by this node's ``batch_id`` parameter."""
        batches = shared_storage['batches']
        return batches[self.params['batch_id']]

    async def exec_async(self, number):
        """Double one number after sleeping for ``self.delay`` seconds."""
        await asyncio.sleep(self.delay)  # Simulate async processing
        return number * 2

    async def post_async(self, shared_storage, prep_result, exec_result):
        """File the doubled numbers under ``processed_numbers[batch_id]``."""
        # Create the result mapping on first use, then record this batch.
        results_by_batch = shared_storage.setdefault('processed_numbers', {})
        results_by_batch[self.params['batch_id']] = exec_result
        return "processed"
class AsyncAggregatorNode(AsyncNode):
    """Node that flattens the per-batch results and stores their sum."""

    async def prep_async(self, shared_storage):
        """Concatenate the results stored under keys ``0..n-1`` in key order."""
        per_batch = shared_storage.get('processed_numbers', {})
        # Index by position rather than iterating the dict so the output
        # follows batch-id order regardless of insertion order.
        return [
            value
            for batch_id in range(len(per_batch))
            for value in per_batch[batch_id]
        ]

    async def exec_async(self, prep_result):
        """Sum the flattened numbers after a short simulated async delay."""
        await asyncio.sleep(0.01)
        return sum(prep_result)

    async def post_async(self, shared_storage, prep_result, exec_result):
        """Store the grand total under ``total``."""
        shared_storage['total'] = exec_result
        return "aggregated"
class TestAsyncParallelBatchFlow(unittest.TestCase):
    """Tests for running several batches through ``AsyncParallelBatchFlow``."""

    def setUp(self):
        # Give every test its own event loop so tests cannot affect each other.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()
        # Do not leave the closed loop installed as the current event loop;
        # later code would otherwise fail with "Event loop is closed".
        asyncio.set_event_loop(None)

    def test_parallel_batch_flow(self):
        """
        Test basic parallel batch processing flow with batch IDs
        """
        class TestParallelBatchFlow(AsyncParallelBatchFlow):
            async def prep_async(self, shared_storage):
                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

        shared_storage = {
            'batches': [
                [1, 2, 3],  # batch_id: 0
                [4, 5, 6],  # batch_id: 1
                [7, 8, 9],  # batch_id: 2
            ]
        }

        processor = AsyncParallelNumberProcessor(delay=0.1)
        aggregator = AsyncAggregatorNode()
        processor - "processed" >> aggregator
        flow = TestParallelBatchFlow(start=processor)

        start_time = self.loop.time()
        self.loop.run_until_complete(flow.run_async(shared_storage))
        execution_time = self.loop.time() - start_time

        # Verify each batch was processed correctly
        expected_batch_results = {
            0: [2, 4, 6],     # [1,2,3] * 2
            1: [8, 10, 12],   # [4,5,6] * 2
            2: [14, 16, 18],  # [7,8,9] * 2
        }
        self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)

        # Verify total
        expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
        self.assertEqual(shared_storage['total'], expected_total)

        # Verify parallel execution: the test expects the whole run to finish
        # in under 0.2s even though every item sleeps for 0.1s.
        self.assertLess(execution_time, 0.2)

    def test_error_handling(self):
        """
        Test error handling in parallel batch flow
        """
        class ErrorProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                if item == 2:
                    raise ValueError(f"Error processing item {item}")
                return item

        class ErrorBatchFlow(AsyncParallelBatchFlow):
            async def prep_async(self, shared_storage):
                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

        shared_storage = {
            'batches': [
                [1, 2, 3],  # Contains error-triggering value
                [4, 5, 6],
            ]
        }

        processor = ErrorProcessor()
        flow = ErrorBatchFlow(start=processor)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(flow.run_async(shared_storage))

    def test_multiple_batch_sizes(self):
        """
        Test parallel batch flow with varying batch sizes
        """
        class VaryingBatchFlow(AsyncParallelBatchFlow):
            async def prep_async(self, shared_storage):
                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

        shared_storage = {
            'batches': [
                [1],            # batch_id: 0
                [2, 3, 4],      # batch_id: 1
                [5, 6],         # batch_id: 2
                [7, 8, 9, 10],  # batch_id: 3
            ]
        }

        processor = AsyncParallelNumberProcessor(delay=0.05)
        aggregator = AsyncAggregatorNode()
        processor - "processed" >> aggregator
        flow = VaryingBatchFlow(start=processor)

        self.loop.run_until_complete(flow.run_async(shared_storage))

        # Verify each batch was processed correctly
        expected_batch_results = {
            0: [2],               # [1] * 2
            1: [4, 6, 8],         # [2,3,4] * 2
            2: [10, 12],          # [5,6] * 2
            3: [14, 16, 18, 20],  # [7,8,9,10] * 2
        }
        self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)

        # Verify total
        expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
        self.assertEqual(shared_storage['total'], expected_total)
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()

View File

@ -0,0 +1,143 @@
import unittest
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from minillmflow import AsyncParallelBatchNode, AsyncParallelBatchFlow
class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
    """Parallel batch node that doubles every entry of ``input_numbers``."""

    def __init__(self, delay=0.1):
        super().__init__()
        self.delay = delay

    async def prep_async(self, shared_storage):
        """Return the numbers to process (an empty list when none are stored)."""
        return shared_storage.get('input_numbers', [])

    async def exec_async(self, number):
        """Double one number after sleeping for ``self.delay`` seconds."""
        await asyncio.sleep(self.delay)  # Simulate async processing
        doubled = number * 2
        return doubled

    async def post_async(self, shared_storage, prep_result, exec_result):
        """Store the doubled numbers under ``processed_numbers``."""
        shared_storage['processed_numbers'] = exec_result
        return "processed"
class TestAsyncParallelBatchNode(unittest.TestCase):
    """Tests for concurrent item processing with ``AsyncParallelBatchNode``."""

    def setUp(self):
        # Reset the event loop for each test
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()
        # Do not leave the closed loop installed as the current event loop;
        # later code would otherwise fail with "Event loop is closed".
        asyncio.set_event_loop(None)

    def test_parallel_processing(self):
        """
        Test that numbers are processed in parallel by measuring execution time
        """
        shared_storage = {
            'input_numbers': list(range(5))
        }

        processor = AsyncParallelNumberProcessor(delay=0.1)

        # Run the processor, timing it with the clock of the loop that runs it
        start_time = self.loop.time()
        self.loop.run_until_complete(processor.run_async(shared_storage))
        end_time = self.loop.time()

        # Check results
        expected = [0, 2, 4, 6, 8]  # Each number doubled
        self.assertEqual(shared_storage['processed_numbers'], expected)

        # Since processing is parallel, total time should be approximately
        # equal to the delay of a single operation, not delay * number_of_items
        execution_time = end_time - start_time
        self.assertLess(execution_time, 0.2)  # Should be around 0.1s plus minimal overhead

    def test_empty_input(self):
        """
        Test processing of empty input
        """
        shared_storage = {
            'input_numbers': []
        }

        processor = AsyncParallelNumberProcessor()
        self.loop.run_until_complete(processor.run_async(shared_storage))

        self.assertEqual(shared_storage['processed_numbers'], [])

    def test_single_item(self):
        """
        Test processing of a single item
        """
        shared_storage = {
            'input_numbers': [42]
        }

        processor = AsyncParallelNumberProcessor()
        self.loop.run_until_complete(processor.run_async(shared_storage))

        self.assertEqual(shared_storage['processed_numbers'], [84])

    def test_large_batch(self):
        """
        Test processing of a large batch of numbers
        """
        input_size = 100
        shared_storage = {
            'input_numbers': list(range(input_size))
        }

        processor = AsyncParallelNumberProcessor(delay=0.01)
        self.loop.run_until_complete(processor.run_async(shared_storage))

        expected = [x * 2 for x in range(input_size)]
        self.assertEqual(shared_storage['processed_numbers'], expected)

    def test_error_handling(self):
        """
        Test error handling during parallel processing
        """
        class ErrorProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                if item == 2:
                    raise ValueError(f"Error processing item {item}")
                return item

        shared_storage = {
            'input_numbers': [1, 2, 3]
        }

        processor = ErrorProcessor()
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(processor.run_async(shared_storage))

    def test_concurrent_execution(self):
        """
        Test that tasks are actually running concurrently by tracking execution order
        """
        execution_order = []

        class OrderTrackingProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                delay = 0.1 if item % 2 == 0 else 0.05
                await asyncio.sleep(delay)
                execution_order.append(item)
                return item

        shared_storage = {
            'input_numbers': list(range(4))  # [0, 1, 2, 3]
        }

        processor = OrderTrackingProcessor()
        self.loop.run_until_complete(processor.run_async(shared_storage))

        # Odd numbers should finish before even numbers due to shorter delay
        self.assertLess(execution_order.index(1), execution_order.index(0))
        self.assertLess(execution_order.index(3), execution_order.index(2))
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()