add test cases
This commit is contained in:
parent
c1ba9dd0d4
commit
643534bc2f
|
|
@ -162,7 +162,7 @@ most_relevant: filename"""
|
|||
assert "most_relevant" in result if result["has_relevant"] else True
|
||||
return result
|
||||
# handle errors by returning a default response in case of exception after retries
|
||||
def process_after_fail(self,shared,prep_res,exc):
|
||||
def exec_fallback(self,shared,prep_res,exc):
|
||||
# if not overridden, the default is to throw the exception
|
||||
return {"think":"error finding the file", "has_relevant":False}
|
||||
def post(self, shared, prep_res, exec_res):
|
||||
|
|
|
|||
|
|
@ -355,7 +355,7 @@
|
|||
" assert \"most_relevant\" in result if result[\"has_relevant\"] else True\n",
|
||||
" return result\n",
|
||||
" # handle errors by returning a default response in case of exception after retries\n",
|
||||
" def process_after_fail(self,shared,prep_res,exc):\n",
|
||||
" def exec_fallback(self,shared,prep_res,exc):\n",
|
||||
" # if not overridden, the default is to throw the exception\n",
|
||||
" return {\"think\":\"error finding the file\", \"has_relevant\":False}\n",
|
||||
" def post(self, shared, prep_res, exec_res):\n",
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ nav_order: 5
|
|||
|
||||
# Async
|
||||
|
||||
**Mini LLM Flow** allows fully asynchronous nodes by implementing `prep_async()`, `exec_async()`, and/or `post_async()`. This is useful for:
|
||||
**Mini LLM Flow** allows fully asynchronous nodes by implementing `prep_async()`, `exec_async()`, `exec_fallback_async()`, and/or `post_async()`. This is useful for:
|
||||
|
||||
## Implementation
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ When an exception occurs in `exec()`, the Node automatically retries until:
|
|||
If you want to **gracefully handle** the error rather than raising it, you can override:
|
||||
|
||||
```python
|
||||
def process_after_fail(self, shared, prep_res, exc):
|
||||
def exec_fallback(self, shared, prep_res, exc):
|
||||
raise exc
|
||||
```
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ class SummarizeFile(Node):
|
|||
summary = call_llm(prompt) # might fail
|
||||
return summary
|
||||
|
||||
def process_after_fail(self, shared, prep_res, exc):
|
||||
def exec_fallback(self, shared, prep_res, exc):
|
||||
# Provide a simple fallback instead of crashing
|
||||
return "There was an error processing your request."
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ class SummarizeFile(Node):
|
|||
summarize_node = SummarizeFile(max_retries=3)
|
||||
|
||||
# Run the node standalone for testing (calls prep->exec->post).
|
||||
# If exec() fails, it retries up to 3 times before calling process_after_fail().
|
||||
# If exec() fails, it retries up to 3 times before calling exec_fallback().
|
||||
summarize_node.set_params({"filename": "test_file.txt"})
|
||||
action_result = summarize_node.run(shared)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import asyncio, warnings
|
||||
import asyncio, warnings, copy
|
||||
|
||||
class BaseNode:
|
||||
def __init__(self): self.params,self.successors={},{}
|
||||
|
|
@ -25,15 +25,15 @@ class _ConditionalTransition:
|
|||
|
||||
class Node(BaseNode):
|
||||
def __init__(self,max_retries=1): super().__init__();self.max_retries=max_retries
|
||||
def process_after_fail(self,prep_res,exc): raise exc
|
||||
def exec_fallback(self,prep_res,exc): raise exc
|
||||
def _exec(self,prep_res):
|
||||
for i in range(self.max_retries):
|
||||
try: return super()._exec(prep_res)
|
||||
try: return self.exec(prep_res)
|
||||
except Exception as e:
|
||||
if i==self.max_retries-1: return self.process_after_fail(prep_res,e)
|
||||
if i==self.max_retries-1: return self.exec_fallback(prep_res,e)
|
||||
|
||||
class BatchNode(Node):
|
||||
def _exec(self,items): return [super(Node,self)._exec(i) for i in items]
|
||||
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in items]
|
||||
|
||||
class Flow(BaseNode):
|
||||
def __init__(self,start): super().__init__();self.start=start
|
||||
|
|
@ -43,8 +43,8 @@ class Flow(BaseNode):
|
|||
warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||
return nxt
|
||||
def _orch(self,shared,params=None):
|
||||
curr,p=self.start,(params or {**self.params})
|
||||
while curr: curr.set_params(p);c=curr._run(shared);curr=self.get_next_node(curr,c)
|
||||
curr,p=copy.copy(self.start),(params or {**self.params})
|
||||
while curr: curr.set_params(p);c=curr._run(shared);curr=copy.copy(self.get_next_node(curr,c))
|
||||
def _run(self,shared): pr=self.prep(shared);self._orch(shared);return self.post(shared,pr,None)
|
||||
def exec(self,prep_res): raise RuntimeError("Flow can't exec.")
|
||||
|
||||
|
|
@ -58,11 +58,17 @@ class AsyncNode(Node):
|
|||
def prep(self,shared): raise RuntimeError("Use prep_async.")
|
||||
def exec(self,prep_res): raise RuntimeError("Use exec_async.")
|
||||
def post(self,shared,prep_res,exec_res): raise RuntimeError("Use post_async.")
|
||||
def exec_fallback(self,prep_res,exc): raise RuntimeError("Use exec_fallback_async.")
|
||||
def _run(self,shared): raise RuntimeError("Use run_async.")
|
||||
async def prep_async(self,shared): pass
|
||||
async def exec_async(self,prep_res): pass
|
||||
async def exec_fallback_async(self,prep_res,exc): raise exc
|
||||
async def post_async(self,shared,prep_res,exec_res): pass
|
||||
async def _exec(self,prep_res): return await self.exec_async(prep_res)
|
||||
async def _exec(self,prep_res):
|
||||
for i in range(self.max_retries):
|
||||
try: return await self.exec_async(prep_res)
|
||||
except Exception as e:
|
||||
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
|
||||
async def run_async(self,shared):
|
||||
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
|
||||
return await self._run_async(shared)
|
||||
|
|
@ -71,18 +77,18 @@ class AsyncNode(Node):
|
|||
return await self.post_async(shared,p,e)
|
||||
|
||||
class AsyncBatchNode(AsyncNode):
|
||||
async def _exec(self,items): return [await super()._exec(i) for i in items]
|
||||
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
|
||||
|
||||
class AsyncParallelBatchNode(AsyncNode):
|
||||
async def _exec(self,items): return await asyncio.gather(*(super()._exec(i) for i in items))
|
||||
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
|
||||
|
||||
class AsyncFlow(Flow,AsyncNode):
|
||||
async def _orch_async(self,shared,params=None):
|
||||
curr,p=self.start,(params or {**self.params})
|
||||
curr,p=copy.copy(self.start),(params or {**self.params})
|
||||
while curr:
|
||||
curr.set_params(p)
|
||||
c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
|
||||
curr=self.get_next_node(curr,c)
|
||||
curr=copy.copy(self.get_next_node(curr,c))
|
||||
async def _run_async(self,shared):
|
||||
pr=await self.prep_async(shared);await self._orch_async(shared)
|
||||
return await self.post_async(shared,pr,None)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,181 @@
|
|||
import unittest
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from minillmflow import AsyncNode, AsyncBatchNode, AsyncFlow
|
||||
|
||||
class AsyncArrayChunkNode(AsyncBatchNode):
|
||||
def __init__(self, chunk_size=10):
|
||||
super().__init__()
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
async def prep_async(self, shared_storage):
|
||||
# Get array from shared storage and split into chunks
|
||||
array = shared_storage.get('input_array', [])
|
||||
chunks = []
|
||||
for start in range(0, len(array), self.chunk_size):
|
||||
end = min(start + self.chunk_size, len(array))
|
||||
chunks.append(array[start:end])
|
||||
return chunks
|
||||
|
||||
async def exec_async(self, chunk):
|
||||
# Simulate async processing of each chunk
|
||||
await asyncio.sleep(0.01)
|
||||
return sum(chunk)
|
||||
|
||||
async def post_async(self, shared_storage, prep_result, proc_result):
|
||||
# Store chunk results in shared storage
|
||||
shared_storage['chunk_results'] = proc_result
|
||||
return "processed"
|
||||
|
||||
class AsyncSumReduceNode(AsyncNode):
|
||||
async def prep_async(self, shared_storage):
|
||||
# Get chunk results from shared storage
|
||||
chunk_results = shared_storage.get('chunk_results', [])
|
||||
await asyncio.sleep(0.01) # Simulate async processing
|
||||
total = sum(chunk_results)
|
||||
shared_storage['total'] = total
|
||||
return "reduced"
|
||||
|
||||
class TestAsyncBatchNode(unittest.TestCase):
|
||||
def test_array_chunking(self):
|
||||
"""
|
||||
Test that the array is correctly split into chunks and processed asynchronously
|
||||
"""
|
||||
shared_storage = {
|
||||
'input_array': list(range(25)) # [0,1,2,...,24]
|
||||
}
|
||||
|
||||
chunk_node = AsyncArrayChunkNode(chunk_size=10)
|
||||
asyncio.run(chunk_node.run_async(shared_storage))
|
||||
|
||||
results = shared_storage['chunk_results']
|
||||
self.assertEqual(results, [45, 145, 110]) # Sum of chunks [0-9], [10-19], [20-24]
|
||||
|
||||
# def test_async_map_reduce_sum(self):
|
||||
# """
|
||||
# Test a complete async map-reduce pipeline that sums a large array:
|
||||
# 1. Map: Split array into chunks and sum each chunk asynchronously
|
||||
# 2. Reduce: Sum all the chunk sums asynchronously
|
||||
# """
|
||||
# array = list(range(100))
|
||||
# expected_sum = sum(array) # 4950
|
||||
|
||||
# shared_storage = {
|
||||
# 'input_array': array
|
||||
# }
|
||||
|
||||
# # Create nodes
|
||||
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
|
||||
# reduce_node = AsyncSumReduceNode()
|
||||
|
||||
# # Connect nodes
|
||||
# chunk_node - "processed" >> reduce_node
|
||||
|
||||
# # Create and run pipeline
|
||||
# pipeline = AsyncFlow(start=chunk_node)
|
||||
# asyncio.run(pipeline.run_async(shared_storage))
|
||||
|
||||
# self.assertEqual(shared_storage['total'], expected_sum)
|
||||
|
||||
# def test_uneven_chunks(self):
|
||||
# """
|
||||
# Test that the async map-reduce works correctly with array lengths
|
||||
# that don't divide evenly by chunk_size
|
||||
# """
|
||||
# array = list(range(25))
|
||||
# expected_sum = sum(array) # 300
|
||||
|
||||
# shared_storage = {
|
||||
# 'input_array': array
|
||||
# }
|
||||
|
||||
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
|
||||
# reduce_node = AsyncSumReduceNode()
|
||||
|
||||
# chunk_node - "processed" >> reduce_node
|
||||
# pipeline = AsyncFlow(start=chunk_node)
|
||||
# asyncio.run(pipeline.run_async(shared_storage))
|
||||
|
||||
# self.assertEqual(shared_storage['total'], expected_sum)
|
||||
|
||||
# def test_custom_chunk_size(self):
|
||||
# """
|
||||
# Test that the async map-reduce works with different chunk sizes
|
||||
# """
|
||||
# array = list(range(100))
|
||||
# expected_sum = sum(array)
|
||||
|
||||
# shared_storage = {
|
||||
# 'input_array': array
|
||||
# }
|
||||
|
||||
# # Use chunk_size=15 instead of default 10
|
||||
# chunk_node = AsyncArrayChunkNode(chunk_size=15)
|
||||
# reduce_node = AsyncSumReduceNode()
|
||||
|
||||
# chunk_node - "processed" >> reduce_node
|
||||
# pipeline = AsyncFlow(start=chunk_node)
|
||||
# asyncio.run(pipeline.run_async(shared_storage))
|
||||
|
||||
# self.assertEqual(shared_storage['total'], expected_sum)
|
||||
|
||||
# def test_single_element_chunks(self):
|
||||
# """
|
||||
# Test extreme case where chunk_size=1
|
||||
# """
|
||||
# array = list(range(5))
|
||||
# expected_sum = sum(array)
|
||||
|
||||
# shared_storage = {
|
||||
# 'input_array': array
|
||||
# }
|
||||
|
||||
# chunk_node = AsyncArrayChunkNode(chunk_size=1)
|
||||
# reduce_node = AsyncSumReduceNode()
|
||||
|
||||
# chunk_node - "processed" >> reduce_node
|
||||
# pipeline = AsyncFlow(start=chunk_node)
|
||||
# asyncio.run(pipeline.run_async(shared_storage))
|
||||
|
||||
# self.assertEqual(shared_storage['total'], expected_sum)
|
||||
|
||||
# def test_empty_array(self):
|
||||
# """
|
||||
# Test edge case of empty input array
|
||||
# """
|
||||
# shared_storage = {
|
||||
# 'input_array': []
|
||||
# }
|
||||
|
||||
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
|
||||
# reduce_node = AsyncSumReduceNode()
|
||||
|
||||
# chunk_node - "processed" >> reduce_node
|
||||
# pipeline = AsyncFlow(start=chunk_node)
|
||||
# asyncio.run(pipeline.run_async(shared_storage))
|
||||
|
||||
# self.assertEqual(shared_storage['total'], 0)
|
||||
|
||||
# def test_error_handling(self):
|
||||
# """
|
||||
# Test error handling in async batch processing
|
||||
# """
|
||||
# class ErrorAsyncBatchNode(AsyncBatchNode):
|
||||
# async def exec_async(self, item):
|
||||
# if item == 2:
|
||||
# raise ValueError("Error processing item 2")
|
||||
# return item
|
||||
|
||||
# shared_storage = {
|
||||
# 'input_array': [1, 2, 3]
|
||||
# }
|
||||
|
||||
# error_node = ErrorAsyncBatchNode()
|
||||
# with self.assertRaises(ValueError):
|
||||
# asyncio.run(error_node.run_async(shared_storage))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
import unittest
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from minillmflow import AsyncNode, AsyncParallelBatchNode, AsyncParallelBatchFlow
|
||||
|
||||
class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
|
||||
def __init__(self, delay=0.1):
|
||||
super().__init__()
|
||||
self.delay = delay
|
||||
|
||||
async def prep_async(self, shared_storage):
|
||||
batch = shared_storage['batches'][self.params['batch_id']]
|
||||
return batch
|
||||
|
||||
async def exec_async(self, number):
|
||||
await asyncio.sleep(self.delay) # Simulate async processing
|
||||
return number * 2
|
||||
|
||||
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||
if 'processed_numbers' not in shared_storage:
|
||||
shared_storage['processed_numbers'] = {}
|
||||
shared_storage['processed_numbers'][self.params['batch_id']] = exec_result
|
||||
return "processed"
|
||||
|
||||
class AsyncAggregatorNode(AsyncNode):
|
||||
async def prep_async(self, shared_storage):
|
||||
# Combine all batch results in order
|
||||
all_results = []
|
||||
processed = shared_storage.get('processed_numbers', {})
|
||||
for i in range(len(processed)):
|
||||
all_results.extend(processed[i])
|
||||
return all_results
|
||||
|
||||
async def exec_async(self, prep_result):
|
||||
await asyncio.sleep(0.01)
|
||||
return sum(prep_result)
|
||||
|
||||
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||
shared_storage['total'] = exec_result
|
||||
return "aggregated"
|
||||
|
||||
class TestAsyncParallelBatchFlow(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
def test_parallel_batch_flow(self):
|
||||
"""
|
||||
Test basic parallel batch processing flow with batch IDs
|
||||
"""
|
||||
class TestParallelBatchFlow(AsyncParallelBatchFlow):
|
||||
async def prep_async(self, shared_storage):
|
||||
return [{'batch_id': i} for i in range(len(shared_storage['batches']))]
|
||||
|
||||
shared_storage = {
|
||||
'batches': [
|
||||
[1, 2, 3], # batch_id: 0
|
||||
[4, 5, 6], # batch_id: 1
|
||||
[7, 8, 9] # batch_id: 2
|
||||
]
|
||||
}
|
||||
|
||||
processor = AsyncParallelNumberProcessor(delay=0.1)
|
||||
aggregator = AsyncAggregatorNode()
|
||||
|
||||
processor - "processed" >> aggregator
|
||||
flow = TestParallelBatchFlow(start=processor)
|
||||
|
||||
start_time = self.loop.time()
|
||||
self.loop.run_until_complete(flow.run_async(shared_storage))
|
||||
execution_time = self.loop.time() - start_time
|
||||
|
||||
# Verify each batch was processed correctly
|
||||
expected_batch_results = {
|
||||
0: [2, 4, 6], # [1,2,3] * 2
|
||||
1: [8, 10, 12], # [4,5,6] * 2
|
||||
2: [14, 16, 18] # [7,8,9] * 2
|
||||
}
|
||||
self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)
|
||||
|
||||
# Verify total
|
||||
expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
|
||||
self.assertEqual(shared_storage['total'], expected_total)
|
||||
|
||||
# Verify parallel execution
|
||||
self.assertLess(execution_time, 0.2)
|
||||
|
||||
def test_error_handling(self):
|
||||
"""
|
||||
Test error handling in parallel batch flow
|
||||
"""
|
||||
class ErrorProcessor(AsyncParallelNumberProcessor):
|
||||
async def exec_async(self, item):
|
||||
if item == 2:
|
||||
raise ValueError(f"Error processing item {item}")
|
||||
return item
|
||||
|
||||
class ErrorBatchFlow(AsyncParallelBatchFlow):
|
||||
async def prep_async(self, shared_storage):
|
||||
return [{'batch_id': i} for i in range(len(shared_storage['batches']))]
|
||||
|
||||
shared_storage = {
|
||||
'batches': [
|
||||
[1, 2, 3], # Contains error-triggering value
|
||||
[4, 5, 6]
|
||||
]
|
||||
}
|
||||
|
||||
processor = ErrorProcessor()
|
||||
flow = ErrorBatchFlow(start=processor)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(flow.run_async(shared_storage))
|
||||
|
||||
def test_multiple_batch_sizes(self):
|
||||
"""
|
||||
Test parallel batch flow with varying batch sizes
|
||||
"""
|
||||
class VaryingBatchFlow(AsyncParallelBatchFlow):
|
||||
async def prep_async(self, shared_storage):
|
||||
return [{'batch_id': i} for i in range(len(shared_storage['batches']))]
|
||||
|
||||
shared_storage = {
|
||||
'batches': [
|
||||
[1], # batch_id: 0
|
||||
[2, 3, 4], # batch_id: 1
|
||||
[5, 6], # batch_id: 2
|
||||
[7, 8, 9, 10] # batch_id: 3
|
||||
]
|
||||
}
|
||||
|
||||
processor = AsyncParallelNumberProcessor(delay=0.05)
|
||||
aggregator = AsyncAggregatorNode()
|
||||
|
||||
processor - "processed" >> aggregator
|
||||
flow = VaryingBatchFlow(start=processor)
|
||||
|
||||
self.loop.run_until_complete(flow.run_async(shared_storage))
|
||||
|
||||
# Verify each batch was processed correctly
|
||||
expected_batch_results = {
|
||||
0: [2], # [1] * 2
|
||||
1: [4, 6, 8], # [2,3,4] * 2
|
||||
2: [10, 12], # [5,6] * 2
|
||||
3: [14, 16, 18, 20] # [7,8,9,10] * 2
|
||||
}
|
||||
self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)
|
||||
|
||||
# Verify total
|
||||
expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
|
||||
self.assertEqual(shared_storage['total'], expected_total)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
import unittest
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from minillmflow import AsyncParallelBatchNode, AsyncParallelBatchFlow
|
||||
|
||||
class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
|
||||
def __init__(self, delay=0.1):
|
||||
super().__init__()
|
||||
self.delay = delay
|
||||
|
||||
async def prep_async(self, shared_storage):
|
||||
numbers = shared_storage.get('input_numbers', [])
|
||||
return numbers
|
||||
|
||||
async def exec_async(self, number):
|
||||
await asyncio.sleep(self.delay) # Simulate async processing
|
||||
return number * 2
|
||||
|
||||
async def post_async(self, shared_storage, prep_result, exec_result):
|
||||
shared_storage['processed_numbers'] = exec_result
|
||||
return "processed"
|
||||
|
||||
class TestAsyncParallelBatchNode(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Reset the event loop for each test
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
def test_parallel_processing(self):
|
||||
"""
|
||||
Test that numbers are processed in parallel by measuring execution time
|
||||
"""
|
||||
shared_storage = {
|
||||
'input_numbers': list(range(5))
|
||||
}
|
||||
|
||||
processor = AsyncParallelNumberProcessor(delay=0.1)
|
||||
|
||||
# Run the processor
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
self.loop.run_until_complete(processor.run_async(shared_storage))
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Check results
|
||||
expected = [0, 2, 4, 6, 8] # Each number doubled
|
||||
self.assertEqual(shared_storage['processed_numbers'], expected)
|
||||
|
||||
# Since processing is parallel, total time should be approximately
|
||||
# equal to the delay of a single operation, not delay * number_of_items
|
||||
execution_time = end_time - start_time
|
||||
self.assertLess(execution_time, 0.2) # Should be around 0.1s plus minimal overhead
|
||||
|
||||
def test_empty_input(self):
|
||||
"""
|
||||
Test processing of empty input
|
||||
"""
|
||||
shared_storage = {
|
||||
'input_numbers': []
|
||||
}
|
||||
|
||||
processor = AsyncParallelNumberProcessor()
|
||||
self.loop.run_until_complete(processor.run_async(shared_storage))
|
||||
|
||||
self.assertEqual(shared_storage['processed_numbers'], [])
|
||||
|
||||
def test_single_item(self):
|
||||
"""
|
||||
Test processing of a single item
|
||||
"""
|
||||
shared_storage = {
|
||||
'input_numbers': [42]
|
||||
}
|
||||
|
||||
processor = AsyncParallelNumberProcessor()
|
||||
self.loop.run_until_complete(processor.run_async(shared_storage))
|
||||
|
||||
self.assertEqual(shared_storage['processed_numbers'], [84])
|
||||
|
||||
def test_large_batch(self):
|
||||
"""
|
||||
Test processing of a large batch of numbers
|
||||
"""
|
||||
input_size = 100
|
||||
shared_storage = {
|
||||
'input_numbers': list(range(input_size))
|
||||
}
|
||||
|
||||
processor = AsyncParallelNumberProcessor(delay=0.01)
|
||||
self.loop.run_until_complete(processor.run_async(shared_storage))
|
||||
|
||||
expected = [x * 2 for x in range(input_size)]
|
||||
self.assertEqual(shared_storage['processed_numbers'], expected)
|
||||
|
||||
def test_error_handling(self):
|
||||
"""
|
||||
Test error handling during parallel processing
|
||||
"""
|
||||
class ErrorProcessor(AsyncParallelNumberProcessor):
|
||||
async def exec_async(self, item):
|
||||
if item == 2:
|
||||
raise ValueError(f"Error processing item {item}")
|
||||
return item
|
||||
|
||||
shared_storage = {
|
||||
'input_numbers': [1, 2, 3]
|
||||
}
|
||||
|
||||
processor = ErrorProcessor()
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(processor.run_async(shared_storage))
|
||||
|
||||
def test_concurrent_execution(self):
|
||||
"""
|
||||
Test that tasks are actually running concurrently by tracking execution order
|
||||
"""
|
||||
execution_order = []
|
||||
|
||||
class OrderTrackingProcessor(AsyncParallelNumberProcessor):
|
||||
async def exec_async(self, item):
|
||||
delay = 0.1 if item % 2 == 0 else 0.05
|
||||
await asyncio.sleep(delay)
|
||||
execution_order.append(item)
|
||||
return item
|
||||
|
||||
shared_storage = {
|
||||
'input_numbers': list(range(4)) # [0, 1, 2, 3]
|
||||
}
|
||||
|
||||
processor = OrderTrackingProcessor()
|
||||
self.loop.run_until_complete(processor.run_async(shared_storage))
|
||||
|
||||
# Odd numbers should finish before even numbers due to shorter delay
|
||||
self.assertLess(execution_order.index(1), execution_order.index(0))
|
||||
self.assertLess(execution_order.index(3), execution_order.index(2))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue