From 643534bc2fa58b91587f4da3da64be1732e7c950 Mon Sep 17 00:00:00 2001
From: zachary62
Date: Tue, 31 Dec 2024 02:37:49 +0000
Subject: [PATCH] add test cases

---
 assets/prompt                           |   2 +-
 cookbook/demo.ipynb                     |   2 +-
 docs/async.md                           |   2 +-
 docs/node.md                            |   6 +-
 minillmflow/__init__.py                 |  30 ++--
 tests/test_async_batch_node.py          | 181 ++++++++++++++++++++++++
 tests/test_async_parallel_batch_flow.py | 160 +++++++++++++++++++++
 tests/test_async_parallel_batch_node.py | 143 +++++++++++++++++++
 8 files changed, 508 insertions(+), 18 deletions(-)
 create mode 100644 tests/test_async_batch_node.py
 create mode 100644 tests/test_async_parallel_batch_flow.py
 create mode 100644 tests/test_async_parallel_batch_node.py

diff --git a/assets/prompt b/assets/prompt
index cf431f1..e991612 100644
--- a/assets/prompt
+++ b/assets/prompt
@@ -162,7 +162,7 @@ most_relevant: filename"""
         assert "most_relevant" in result if result["has_relevant"] else True
         return result
     # handle errors by returning a default response in case of exception after retries
-    def process_after_fail(self,shared,prep_res,exc):
+    def exec_fallback(self,prep_res,exc):
         # if not overridden, the default is to throw the exception
         return {"think":"error finding the file", "has_relevant":False}
     def post(self, shared, prep_res, exec_res):
diff --git a/cookbook/demo.ipynb b/cookbook/demo.ipynb
index 3b6ae9f..c4d54e7 100644
--- a/cookbook/demo.ipynb
+++ b/cookbook/demo.ipynb
@@ -355,7 +355,7 @@
     "        assert \"most_relevant\" in result if result[\"has_relevant\"] else True\n",
     "        return result\n",
     "    # handle errors by returning a default response in case of exception after retries\n",
-    "    def process_after_fail(self,shared,prep_res,exc):\n",
+    "    def exec_fallback(self,prep_res,exc):\n",
     "        # if not overridden, the default is to throw the exception\n",
     "        return {\"think\":\"error finding the file\", \"has_relevant\":False}\n",
     "    def post(self, shared, prep_res, exec_res):\n",
diff --git a/docs/async.md b/docs/async.md
index bbb5c95..e92133b 100644
--- a/docs/async.md
+++ b/docs/async.md
@@ -7,7 +7,7 @@ nav_order: 5

 # Async

-**Mini LLM Flow** allows fully asynchronous nodes by implementing `prep_async()`, `exec_async()`, and/or `post_async()`. This is useful for:
+**Mini LLM Flow** allows fully asynchronous nodes by implementing `prep_async()`, `exec_async()`, `exec_fallback_async()`, and/or `post_async()`. This is useful for:

 ## Implementation
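For reference, a minimal sketch of a node implementing all four async hooks that this patch documents; `FetchSummary` and its URL are invented for illustration, and the sleep stands in for a real async call:

```python
import asyncio
from minillmflow import AsyncNode

class FetchSummary(AsyncNode):
    async def prep_async(self, shared):
        # read the request out of the shared store
        return shared["url"]

    async def exec_async(self, url):
        await asyncio.sleep(0.1)  # stand-in for a real async LLM/HTTP call
        return f"summary of {url}"

    async def exec_fallback_async(self, prep_res, exc):
        # runs instead of raising, once max_retries attempts are exhausted
        return "fallback summary"

    async def post_async(self, shared, prep_res, exec_res):
        shared["summary"] = exec_res

shared = {"url": "https://example.com"}
asyncio.run(FetchSummary(max_retries=2).run_async(shared))
```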
diff --git a/docs/node.md b/docs/node.md
index aa55177..32b20b6 100644
--- a/docs/node.md
+++ b/docs/node.md
@@ -42,7 +42,7 @@ When an exception occurs in `exec()`, the Node automatically retries until:
 If you want to **gracefully handle** the error rather than raising it, you can override:

 ```python
-def process_after_fail(self, shared, prep_res, exc):
+def exec_fallback(self, prep_res, exc):
     raise exc
 ```

@@ -64,7 +64,7 @@ class SummarizeFile(Node):
         summary = call_llm(prompt) # might fail
         return summary

-    def process_after_fail(self, shared, prep_res, exc):
+    def exec_fallback(self, prep_res, exc):
         # Provide a simple fallback instead of crashing
         return "There was an error processing your request."

@@ -76,7 +76,7 @@ class SummarizeFile(Node):
 summarize_node = SummarizeFile(max_retries=3)

 # Run the node standalone for testing (calls prep->exec->post).
-# If exec() fails, it retries up to 3 times before calling process_after_fail().
+# If exec() fails, it retries up to 3 times before calling exec_fallback().
 summarize_node.set_params({"filename": "test_file.txt"})
 action_result = summarize_node.run(shared)

diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py
index ac1a4a0..58e0358 100644
--- a/minillmflow/__init__.py
+++ b/minillmflow/__init__.py
@@ -1,4 +1,4 @@
-import asyncio, warnings
+import asyncio, warnings, copy
 class BaseNode:
     def __init__(self): self.params,self.successors={},{}
@@ -25,15 +25,15 @@ class _ConditionalTransition:
 class Node(BaseNode):
     def __init__(self,max_retries=1): super().__init__();self.max_retries=max_retries
-    def process_after_fail(self,prep_res,exc): raise exc
+    def exec_fallback(self,prep_res,exc): raise exc
     def _exec(self,prep_res):
         for i in range(self.max_retries):
-            try: return super()._exec(prep_res)
+            try: return self.exec(prep_res)
             except Exception as e:
-                if i==self.max_retries-1: return self.process_after_fail(prep_res,e)
+                if i==self.max_retries-1: return self.exec_fallback(prep_res,e)

 class BatchNode(Node):
-    def _exec(self,items): return [super(Node,self)._exec(i) for i in items]
+    def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in items]

 class Flow(BaseNode):
     def __init__(self,start): super().__init__();self.start=start
@@ -43,8 +43,8 @@ class Flow(BaseNode):
             warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
         return nxt
     def _orch(self,shared,params=None):
-        curr,p=self.start,(params or {**self.params})
-        while curr: curr.set_params(p);c=curr._run(shared);curr=self.get_next_node(curr,c)
+        curr,p=copy.copy(self.start),(params or {**self.params})
+        while curr: curr.set_params(p);c=curr._run(shared);curr=copy.copy(self.get_next_node(curr,c))
     def _run(self,shared): pr=self.prep(shared);self._orch(shared);return self.post(shared,pr,None)
     def exec(self,prep_res): raise RuntimeError("Flow can't exec.")
@@ -58,11 +58,17 @@ class AsyncNode(Node):
     def prep(self,shared): raise RuntimeError("Use prep_async.")
     def exec(self,prep_res): raise RuntimeError("Use exec_async.")
     def post(self,shared,prep_res,exec_res): raise RuntimeError("Use post_async.")
+    def exec_fallback(self,prep_res,exc): raise RuntimeError("Use exec_fallback_async.")
     def _run(self,shared): raise RuntimeError("Use run_async.")
     async def prep_async(self,shared): pass
     async def exec_async(self,prep_res): pass
+    async def exec_fallback_async(self,prep_res,exc): raise exc
     async def post_async(self,shared,prep_res,exec_res): pass
-    async def _exec(self,prep_res): return await self.exec_async(prep_res)
+    async def _exec(self,prep_res):
+        for i in range(self.max_retries):
+            try: return await self.exec_async(prep_res)
+            except Exception as e:
+                if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
     async def run_async(self,shared):
         if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
         return await self._run_async(shared)
@@ -71,18 +77,18 @@ class AsyncNode(Node):
         return await self.post_async(shared,p,e)

 class AsyncBatchNode(AsyncNode):
-    async def _exec(self,items): return [await super()._exec(i) for i in items]
+    async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]

 class AsyncParallelBatchNode(AsyncNode):
-    async def _exec(self,items): return await asyncio.gather(*(super()._exec(i) for i in items))
+    async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))

 class AsyncFlow(Flow,AsyncNode):
     async def _orch_async(self,shared,params=None):
-        curr,p=self.start,(params or {**self.params})
+        curr,p=copy.copy(self.start),(params or {**self.params})
         while curr:
             curr.set_params(p)
             c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
-            curr=self.get_next_node(curr,c)
+            curr=copy.copy(self.get_next_node(curr,c))
     async def _run_async(self,shared):
         pr=await self.prep_async(shared);await self._orch_async(shared)
         return await self.post_async(shared,pr,None)
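Taken together, the rename and the new retry loop give `Node` this contract: `exec()` is attempted up to `max_retries` times, and only after the last failure does `exec_fallback(prep_res, exc)` run (note the signature takes no `shared`, since `_exec()` never receives it). The `copy.copy` calls in `_orch` additionally make each orchestration run against shallow copies of the nodes, so `set_params` on one run cannot leak into the next. A minimal sketch of the fallback path, with an invented `AlwaysFails` node:

```python
from minillmflow import Node

class AlwaysFails(Node):
    def prep(self, shared):
        return shared["value"]

    def exec(self, value):
        raise RuntimeError("simulated failure")  # fails on every attempt

    def exec_fallback(self, prep_res, exc):
        # reached only after max_retries attempts of exec()
        return f"default result for {prep_res}"

    def post(self, shared, prep_res, exec_res):
        shared["result"] = exec_res

shared = {"value": 42}
AlwaysFails(max_retries=3).run(shared)
assert shared["result"] == "default result for 42"
```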
Use AsyncFlow.") return await self._run_async(shared) @@ -71,18 +77,18 @@ class AsyncNode(Node): return await self.post_async(shared,p,e) class AsyncBatchNode(AsyncNode): - async def _exec(self,items): return [await super()._exec(i) for i in items] + async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] class AsyncParallelBatchNode(AsyncNode): - async def _exec(self,items): return await asyncio.gather(*(super()._exec(i) for i in items)) + async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) class AsyncFlow(Flow,AsyncNode): async def _orch_async(self,shared,params=None): - curr,p=self.start,(params or {**self.params}) + curr,p=copy.copy(self.start),(params or {**self.params}) while curr: curr.set_params(p) c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared) - curr=self.get_next_node(curr,c) + curr=copy.copy(self.get_next_node(curr,c)) async def _run_async(self,shared): pr=await self.prep_async(shared);await self._orch_async(shared) return await self.post_async(shared,pr,None) diff --git a/tests/test_async_batch_node.py b/tests/test_async_batch_node.py new file mode 100644 index 0000000..23b8186 --- /dev/null +++ b/tests/test_async_batch_node.py @@ -0,0 +1,181 @@ +import unittest +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) +from minillmflow import AsyncNode, AsyncBatchNode, AsyncFlow + +class AsyncArrayChunkNode(AsyncBatchNode): + def __init__(self, chunk_size=10): + super().__init__() + self.chunk_size = chunk_size + + async def prep_async(self, shared_storage): + # Get array from shared storage and split into chunks + array = shared_storage.get('input_array', []) + chunks = [] + for start in range(0, len(array), self.chunk_size): + end = min(start + self.chunk_size, len(array)) + chunks.append(array[start:end]) + return chunks + + async def exec_async(self, chunk): + # Simulate async processing of each chunk + await asyncio.sleep(0.01) + return sum(chunk) + + async def post_async(self, shared_storage, prep_result, proc_result): + # Store chunk results in shared storage + shared_storage['chunk_results'] = proc_result + return "processed" + +class AsyncSumReduceNode(AsyncNode): + async def prep_async(self, shared_storage): + # Get chunk results from shared storage + chunk_results = shared_storage.get('chunk_results', []) + await asyncio.sleep(0.01) # Simulate async processing + total = sum(chunk_results) + shared_storage['total'] = total + return "reduced" + +class TestAsyncBatchNode(unittest.TestCase): + def test_array_chunking(self): + """ + Test that the array is correctly split into chunks and processed asynchronously + """ + shared_storage = { + 'input_array': list(range(25)) # [0,1,2,...,24] + } + + chunk_node = AsyncArrayChunkNode(chunk_size=10) + asyncio.run(chunk_node.run_async(shared_storage)) + + results = shared_storage['chunk_results'] + self.assertEqual(results, [45, 145, 110]) # Sum of chunks [0-9], [10-19], [20-24] + + # def test_async_map_reduce_sum(self): + # """ + # Test a complete async map-reduce pipeline that sums a large array: + # 1. Map: Split array into chunks and sum each chunk asynchronously + # 2. 
+    #     2. Reduce: Sum all the chunk sums asynchronously
+    #     """
+    #     array = list(range(100))
+    #     expected_sum = sum(array)  # 4950
+
+    #     shared_storage = {
+    #         'input_array': array
+    #     }
+
+    #     # Create nodes
+    #     chunk_node = AsyncArrayChunkNode(chunk_size=10)
+    #     reduce_node = AsyncSumReduceNode()
+
+    #     # Connect nodes
+    #     chunk_node - "processed" >> reduce_node
+
+    #     # Create and run pipeline
+    #     pipeline = AsyncFlow(start=chunk_node)
+    #     asyncio.run(pipeline.run_async(shared_storage))
+
+    #     self.assertEqual(shared_storage['total'], expected_sum)
+
+    # def test_uneven_chunks(self):
+    #     """
+    #     Test that the async map-reduce works correctly with array lengths
+    #     that don't divide evenly by chunk_size
+    #     """
+    #     array = list(range(25))
+    #     expected_sum = sum(array)  # 300
+
+    #     shared_storage = {
+    #         'input_array': array
+    #     }
+
+    #     chunk_node = AsyncArrayChunkNode(chunk_size=10)
+    #     reduce_node = AsyncSumReduceNode()
+
+    #     chunk_node - "processed" >> reduce_node
+    #     pipeline = AsyncFlow(start=chunk_node)
+    #     asyncio.run(pipeline.run_async(shared_storage))
+
+    #     self.assertEqual(shared_storage['total'], expected_sum)
+
+    # def test_custom_chunk_size(self):
+    #     """
+    #     Test that the async map-reduce works with different chunk sizes
+    #     """
+    #     array = list(range(100))
+    #     expected_sum = sum(array)
+
+    #     shared_storage = {
+    #         'input_array': array
+    #     }
+
+    #     # Use chunk_size=15 instead of default 10
+    #     chunk_node = AsyncArrayChunkNode(chunk_size=15)
+    #     reduce_node = AsyncSumReduceNode()
+
+    #     chunk_node - "processed" >> reduce_node
+    #     pipeline = AsyncFlow(start=chunk_node)
+    #     asyncio.run(pipeline.run_async(shared_storage))
+
+    #     self.assertEqual(shared_storage['total'], expected_sum)
+
+    # def test_single_element_chunks(self):
+    #     """
+    #     Test extreme case where chunk_size=1
+    #     """
+    #     array = list(range(5))
+    #     expected_sum = sum(array)
+
+    #     shared_storage = {
+    #         'input_array': array
+    #     }
+
+    #     chunk_node = AsyncArrayChunkNode(chunk_size=1)
+    #     reduce_node = AsyncSumReduceNode()
+
+    #     chunk_node - "processed" >> reduce_node
+    #     pipeline = AsyncFlow(start=chunk_node)
+    #     asyncio.run(pipeline.run_async(shared_storage))
+
+    #     self.assertEqual(shared_storage['total'], expected_sum)
+
+    # def test_empty_array(self):
+    #     """
+    #     Test edge case of empty input array
+    #     """
+    #     shared_storage = {
+    #         'input_array': []
+    #     }
+
+    #     chunk_node = AsyncArrayChunkNode(chunk_size=10)
+    #     reduce_node = AsyncSumReduceNode()
+
+    #     chunk_node - "processed" >> reduce_node
+    #     pipeline = AsyncFlow(start=chunk_node)
+    #     asyncio.run(pipeline.run_async(shared_storage))
+
+    #     self.assertEqual(shared_storage['total'], 0)
+
+    # def test_error_handling(self):
+    #     """
+    #     Test error handling in async batch processing
+    #     """
+    #     class ErrorAsyncBatchNode(AsyncBatchNode):
+    #         async def exec_async(self, item):
+    #             if item == 2:
+    #                 raise ValueError("Error processing item 2")
+    #             return item
+
+    #     shared_storage = {
+    #         'input_array': [1, 2, 3]
+    #     }
+
+    #     error_node = ErrorAsyncBatchNode()
+    #     with self.assertRaises(ValueError):
+    #         asyncio.run(error_node.run_async(shared_storage))
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
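The map-reduce pipeline tests above are checked in commented out. When enabled, they exercise the `-`/`>>` transition wiring together with `AsyncFlow`; condensed, the usage they rely on looks like this sketch (same classes as in the test file, input chosen for illustration):

```python
import asyncio

# Reuses AsyncArrayChunkNode / AsyncSumReduceNode from the test file above.
chunk_node = AsyncArrayChunkNode(chunk_size=10)
reduce_node = AsyncSumReduceNode()

# On the "processed" action, hand control from the mapper to the reducer
chunk_node - "processed" >> reduce_node

pipeline = AsyncFlow(start=chunk_node)
shared_storage = {'input_array': list(range(100))}
asyncio.run(pipeline.run_async(shared_storage))
assert shared_storage['total'] == 4950  # sum(range(100))
```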
diff --git a/tests/test_async_parallel_batch_flow.py b/tests/test_async_parallel_batch_flow.py
new file mode 100644
index 0000000..9b761fe
--- /dev/null
+++ b/tests/test_async_parallel_batch_flow.py
@@ -0,0 +1,160 @@
+import unittest
+import asyncio
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from minillmflow import AsyncNode, AsyncParallelBatchNode, AsyncParallelBatchFlow
+
+class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
+    def __init__(self, delay=0.1):
+        super().__init__()
+        self.delay = delay
+
+    async def prep_async(self, shared_storage):
+        batch = shared_storage['batches'][self.params['batch_id']]
+        return batch
+
+    async def exec_async(self, number):
+        await asyncio.sleep(self.delay)  # Simulate async processing
+        return number * 2
+
+    async def post_async(self, shared_storage, prep_result, exec_result):
+        if 'processed_numbers' not in shared_storage:
+            shared_storage['processed_numbers'] = {}
+        shared_storage['processed_numbers'][self.params['batch_id']] = exec_result
+        return "processed"
+
+class AsyncAggregatorNode(AsyncNode):
+    async def prep_async(self, shared_storage):
+        # Combine all batch results in order
+        all_results = []
+        processed = shared_storage.get('processed_numbers', {})
+        for i in range(len(processed)):
+            all_results.extend(processed[i])
+        return all_results
+
+    async def exec_async(self, prep_result):
+        await asyncio.sleep(0.01)
+        return sum(prep_result)
+
+    async def post_async(self, shared_storage, prep_result, exec_result):
+        shared_storage['total'] = exec_result
+        return "aggregated"
+
+class TestAsyncParallelBatchFlow(unittest.TestCase):
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+    def tearDown(self):
+        self.loop.close()
+
+    def test_parallel_batch_flow(self):
+        """
+        Test basic parallel batch processing flow with batch IDs
+        """
+        class TestParallelBatchFlow(AsyncParallelBatchFlow):
+            async def prep_async(self, shared_storage):
+                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]
+
+        shared_storage = {
+            'batches': [
+                [1, 2, 3],  # batch_id: 0
+                [4, 5, 6],  # batch_id: 1
+                [7, 8, 9]   # batch_id: 2
+            ]
+        }
+
+        processor = AsyncParallelNumberProcessor(delay=0.1)
+        aggregator = AsyncAggregatorNode()
+
+        processor - "processed" >> aggregator
+        flow = TestParallelBatchFlow(start=processor)
+
+        start_time = self.loop.time()
+        self.loop.run_until_complete(flow.run_async(shared_storage))
+        execution_time = self.loop.time() - start_time
+
+        # Verify each batch was processed correctly
+        expected_batch_results = {
+            0: [2, 4, 6],     # [1,2,3] * 2
+            1: [8, 10, 12],   # [4,5,6] * 2
+            2: [14, 16, 18]   # [7,8,9] * 2
+        }
+        self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)
+
+        # Verify total
+        expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
+        self.assertEqual(shared_storage['total'], expected_total)
+
+        # Verify parallel execution
+        self.assertLess(execution_time, 0.2)
+
+    def test_error_handling(self):
+        """
+        Test error handling in parallel batch flow
+        """
+        class ErrorProcessor(AsyncParallelNumberProcessor):
+            async def exec_async(self, item):
+                if item == 2:
+                    raise ValueError(f"Error processing item {item}")
+                return item
+
+        class ErrorBatchFlow(AsyncParallelBatchFlow):
+            async def prep_async(self, shared_storage):
+                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]
+
+        shared_storage = {
+            'batches': [
+                [1, 2, 3],  # Contains error-triggering value
+                [4, 5, 6]
+            ]
+        }
+
+        processor = ErrorProcessor()
+        flow = ErrorBatchFlow(start=processor)
+
+        with self.assertRaises(ValueError):
+            self.loop.run_until_complete(flow.run_async(shared_storage))
+
+    def test_multiple_batch_sizes(self):
+        """
+        Test parallel batch flow with varying batch sizes
+        """
+        class VaryingBatchFlow(AsyncParallelBatchFlow):
+            async def prep_async(self, shared_storage):
+                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]
+
+        shared_storage = {
+            'batches': [
+                [1],           # batch_id: 0
+                [2, 3, 4],     # batch_id: 1
+                [5, 6],        # batch_id: 2
+                [7, 8, 9, 10]  # batch_id: 3
+            ]
+        }
+
+        processor = AsyncParallelNumberProcessor(delay=0.05)
+        aggregator = AsyncAggregatorNode()
+
+        processor - "processed" >> aggregator
+        flow = VaryingBatchFlow(start=processor)
+
+        self.loop.run_until_complete(flow.run_async(shared_storage))
+
+        # Verify each batch was processed correctly
+        expected_batch_results = {
+            0: [2],              # [1] * 2
+            1: [4, 6, 8],        # [2,3,4] * 2
+            2: [10, 12],         # [5,6] * 2
+            3: [14, 16, 18, 20]  # [7,8,9,10] * 2
+        }
+        self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)
+
+        # Verify total
+        expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
+        self.assertEqual(shared_storage['total'], expected_total)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
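`AsyncParallelBatchFlow` itself is not touched by this diff, but the tests pin down its contract: the flow's `prep_async()` returns a list of params dicts, and one orchestration runs concurrently per dict, with each run reading its own `self.params` (here, `batch_id`). A condensed sketch of that usage, reusing the processor class above (flow semantics are inferred from the tests; values are illustrative):

```python
import asyncio

class DoubleBatchesFlow(AsyncParallelBatchFlow):
    async def prep_async(self, shared_storage):
        # one params dict per batch; each orchestration sees its own batch_id
        return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

shared_storage = {'batches': [[1, 2], [3, 4, 5]]}
flow = DoubleBatchesFlow(start=AsyncParallelNumberProcessor(delay=0.01))
asyncio.run(flow.run_async(shared_storage))
# shared_storage['processed_numbers'] == {0: [2, 4], 1: [6, 8, 10]}
```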
diff --git a/tests/test_async_parallel_batch_node.py b/tests/test_async_parallel_batch_node.py
new file mode 100644
index 0000000..5008748
--- /dev/null
+++ b/tests/test_async_parallel_batch_node.py
@@ -0,0 +1,143 @@
+import unittest
+import asyncio
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from minillmflow import AsyncParallelBatchNode, AsyncParallelBatchFlow
+
+class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
+    def __init__(self, delay=0.1):
+        super().__init__()
+        self.delay = delay
+
+    async def prep_async(self, shared_storage):
+        numbers = shared_storage.get('input_numbers', [])
+        return numbers
+
+    async def exec_async(self, number):
+        await asyncio.sleep(self.delay)  # Simulate async processing
+        return number * 2
+
+    async def post_async(self, shared_storage, prep_result, exec_result):
+        shared_storage['processed_numbers'] = exec_result
+        return "processed"
+
+class TestAsyncParallelBatchNode(unittest.TestCase):
+    def setUp(self):
+        # Reset the event loop for each test
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+    def tearDown(self):
+        self.loop.close()
+
+    def test_parallel_processing(self):
+        """
+        Test that numbers are processed in parallel by measuring execution time
+        """
+        shared_storage = {
+            'input_numbers': list(range(5))
+        }
+
+        processor = AsyncParallelNumberProcessor(delay=0.1)
+
+        # Run the processor
+        start_time = asyncio.get_event_loop().time()
+        self.loop.run_until_complete(processor.run_async(shared_storage))
+        end_time = asyncio.get_event_loop().time()
+
+        # Check results
+        expected = [0, 2, 4, 6, 8]  # Each number doubled
+        self.assertEqual(shared_storage['processed_numbers'], expected)
+
+        # Since processing is parallel, total time should be approximately
+        # equal to the delay of a single operation, not delay * number_of_items
+        execution_time = end_time - start_time
+        self.assertLess(execution_time, 0.2)  # Should be around 0.1s plus minimal overhead
+
+    def test_empty_input(self):
+        """
+        Test processing of empty input
+        """
+        shared_storage = {
+            'input_numbers': []
+        }
+
+        processor = AsyncParallelNumberProcessor()
+        self.loop.run_until_complete(processor.run_async(shared_storage))
+
+        self.assertEqual(shared_storage['processed_numbers'], [])
+
+    def test_single_item(self):
+        """
+        Test processing of a single item
+        """
+        shared_storage = {
+            'input_numbers': [42]
+        }
+
+        processor = AsyncParallelNumberProcessor()
+        self.loop.run_until_complete(processor.run_async(shared_storage))
+
+        self.assertEqual(shared_storage['processed_numbers'], [84])
+
+    def test_large_batch(self):
+        """
+        Test processing of a large batch of numbers
+        """
+        input_size = 100
+        shared_storage = {
+            'input_numbers': list(range(input_size))
+        }
+
+        processor = AsyncParallelNumberProcessor(delay=0.01)
+        self.loop.run_until_complete(processor.run_async(shared_storage))
+
+        expected = [x * 2 for x in range(input_size)]
+        self.assertEqual(shared_storage['processed_numbers'], expected)
+
+    def test_error_handling(self):
+        """
+        Test error handling during parallel processing
+        """
+        class ErrorProcessor(AsyncParallelNumberProcessor):
+            async def exec_async(self, item):
+                if item == 2:
+                    raise ValueError(f"Error processing item {item}")
+                return item
+
+        shared_storage = {
+            'input_numbers': [1, 2, 3]
+        }
+
+        processor = ErrorProcessor()
+        with self.assertRaises(ValueError):
+            self.loop.run_until_complete(processor.run_async(shared_storage))
+
+    def test_concurrent_execution(self):
+        """
+        Test that tasks are actually running concurrently by tracking execution order
+        """
+        execution_order = []
+
+        class OrderTrackingProcessor(AsyncParallelNumberProcessor):
+            async def exec_async(self, item):
+                delay = 0.1 if item % 2 == 0 else 0.05
+                await asyncio.sleep(delay)
+                execution_order.append(item)
+                return item
+
+        shared_storage = {
+            'input_numbers': list(range(4))  # [0, 1, 2, 3]
+        }
+
+        processor = OrderTrackingProcessor()
+        self.loop.run_until_complete(processor.run_async(shared_storage))
+
+        # Odd numbers should finish before even numbers due to shorter delay
+        self.assertLess(execution_order.index(1), execution_order.index(0))
+        self.assertLess(execution_order.index(3), execution_order.index(2))
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
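To exercise all three new suites in one go, a throwaway discovery script like the sketch below works; this helper is hypothetical and not part of the patch (plain `unittest` discovery from the repo root achieves the same):

```python
# run_async_tests.py -- hypothetical convenience script, not included in this patch
import unittest

if __name__ == "__main__":
    suite = unittest.defaultTestLoader.discover("tests", pattern="test_async_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)
```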