# pocketflow/tests/test_async_parallel_batch_n...
import unittest
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from minillmflow import AsyncParallelBatchNode, AsyncParallelBatchFlow


class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
    def __init__(self, delay=0.1):
        super().__init__()
        self.delay = delay

    async def prep_async(self, shared_storage):
        # Items to process in parallel: one exec_async call per number
        numbers = shared_storage.get('input_numbers', [])
        return numbers

    async def exec_async(self, number):
        await asyncio.sleep(self.delay)  # Simulate async processing
        return number * 2

    async def post_async(self, shared_storage, prep_result, exec_result):
        # exec_result holds the doubled numbers, in input order (as the tests expect)
        shared_storage['processed_numbers'] = exec_result
        return "processed"


class TestAsyncParallelBatchNode(unittest.TestCase):
    def setUp(self):
        # Reset the event loop for each test
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()

    def test_parallel_processing(self):
        """
        Test that numbers are processed in parallel by measuring execution time
        """
        shared_storage = {
            'input_numbers': list(range(5))
        }
        processor = AsyncParallelNumberProcessor(delay=0.1)

        # Run the processor and time it with the loop's monotonic clock
        start_time = self.loop.time()
        self.loop.run_until_complete(processor.run_async(shared_storage))
        end_time = self.loop.time()

        # Check results
        expected = [0, 2, 4, 6, 8]  # Each number doubled
        self.assertEqual(shared_storage['processed_numbers'], expected)

        # Since processing is parallel, total time should be approximately
        # equal to the delay of a single operation, not delay * number_of_items
        execution_time = end_time - start_time
        self.assertLess(execution_time, 0.2)  # Should be around 0.1s plus minimal overhead
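
    # Rough arithmetic behind the 0.2 s bound above: sequential processing would take
    # 5 items * 0.1 s = 0.5 s, while a concurrent fan-out finishes in roughly
    # max(delay) = 0.1 s, so 0.2 s leaves headroom for scheduling overhead while still
    # failing if the items were processed one after another.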

    def test_empty_input(self):
        """
        Test processing of empty input
        """
        shared_storage = {
            'input_numbers': []
        }
        processor = AsyncParallelNumberProcessor()
        self.loop.run_until_complete(processor.run_async(shared_storage))
        self.assertEqual(shared_storage['processed_numbers'], [])

    def test_single_item(self):
        """
        Test processing of a single item
        """
        shared_storage = {
            'input_numbers': [42]
        }
        processor = AsyncParallelNumberProcessor()
        self.loop.run_until_complete(processor.run_async(shared_storage))
        self.assertEqual(shared_storage['processed_numbers'], [84])

    def test_large_batch(self):
        """
        Test processing of a large batch of numbers
        """
        input_size = 100
        shared_storage = {
            'input_numbers': list(range(input_size))
        }
        processor = AsyncParallelNumberProcessor(delay=0.01)
        self.loop.run_until_complete(processor.run_async(shared_storage))

        expected = [x * 2 for x in range(input_size)]
        self.assertEqual(shared_storage['processed_numbers'], expected)

    def test_error_handling(self):
        """
        Test error handling during parallel processing
        """
        class ErrorProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                if item == 2:
                    raise ValueError(f"Error processing item {item}")
                return item

        shared_storage = {
            'input_numbers': [1, 2, 3]
        }
        processor = ErrorProcessor()

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(processor.run_async(shared_storage))
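
    # A hedged note on the failure semantics exercised above: if the base class fans
    # out with asyncio.gather (an assumption about minillmflow, not something this
    # test verifies), the first ValueError raised by an exec_async call propagates out
    # of run_async, which is exactly what assertRaises checks. A variant using
    # gather(..., return_exceptions=True) would instead return the exception objects
    # inside the results list rather than raising, e.g.:
    #
    #     results = await asyncio.gather(*coros, return_exceptions=True)
    #     errors = [r for r in results if isinstance(r, Exception)]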

    def test_concurrent_execution(self):
        """
        Test that tasks are actually running concurrently by tracking execution order
        """
        execution_order = []

        class OrderTrackingProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                # Even items sleep twice as long as odd items
                delay = 0.1 if item % 2 == 0 else 0.05
                await asyncio.sleep(delay)
                execution_order.append(item)
                return item

        shared_storage = {
            'input_numbers': list(range(4))  # [0, 1, 2, 3]
        }
        processor = OrderTrackingProcessor()
        self.loop.run_until_complete(processor.run_async(shared_storage))

        # Odd numbers should finish before even numbers due to their shorter delay;
        # under sequential execution they would instead finish in input order
        self.assertLess(execution_order.index(1), execution_order.index(0))
        self.assertLess(execution_order.index(3), execution_order.index(2))
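

# This module can be run directly (python <path to this file>) via the guard below, or
# through unittest discovery, e.g. `python -m unittest discover -s tests` from the
# repository root (assuming the pocketflow layout implied by the path in the header).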
if __name__ == '__main__':
    unittest.main()