143 lines
4.8 KiB
Python
143 lines
4.8 KiB
Python
import unittest
|
|
import asyncio
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
from pocketflow import AsyncParallelBatchNode, AsyncParallelBatchFlow
|
|
|
|
class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
    """Test node that doubles each number in ``shared_storage['input_numbers']``.

    Every item passes through ``exec_async``, which awaits ``asyncio.sleep`` for
    ``delay`` seconds before returning the doubled value; the tests in this file
    use that delay to observe whether items are processed concurrently.
    """

    def __init__(self, delay=0.1):
        """Remember the simulated per-item processing delay, in seconds."""
        super().__init__()
        self.delay = delay

    async def prep_async(self, shared_storage):
        """Return the numbers to process, or an empty list when none are given."""
        return shared_storage.get('input_numbers', [])

    async def exec_async(self, number):
        """Await a simulated async operation, then return ``number`` doubled."""
        await asyncio.sleep(self.delay)
        doubled = number * 2
        return doubled

    async def post_async(self, shared_storage, prep_result, exec_result):
        """Save ``exec_result`` under ``'processed_numbers'`` and signal ``"processed"``."""
        shared_storage['processed_numbers'] = exec_result
        return "processed"
|
|
|
|
class TestAsyncParallelBatchNode(unittest.TestCase):
    """Tests for ``AsyncParallelBatchNode`` using ``AsyncParallelNumberProcessor``."""

    def setUp(self):
        # Give every test its own fresh event loop so no state leaks between tests.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        # Close this test's loop and also uninstall it as the thread's current
        # loop, so later code never picks up a closed loop implicitly.
        self.loop.close()
        asyncio.set_event_loop(None)

    def _run(self, node, shared_storage):
        """Run ``node.run_async(shared_storage)`` to completion on this test's loop."""
        return self.loop.run_until_complete(node.run_async(shared_storage))

    def test_parallel_processing(self):
        """
        Test that numbers are processed in parallel by measuring execution time
        """
        shared_storage = {
            'input_numbers': list(range(5))
        }

        processor = AsyncParallelNumberProcessor(delay=0.1)

        # Time the run with this test's own loop clock (monotonic) instead of
        # asyncio.get_event_loop(), whose implicit-loop lookup is deprecated /
        # changing in recent Python releases.
        start_time = self.loop.time()
        self._run(processor, shared_storage)
        end_time = self.loop.time()

        # Check results
        expected = [0, 2, 4, 6, 8]  # Each number doubled
        self.assertEqual(shared_storage['processed_numbers'], expected)

        # Since processing is parallel, total time should be approximately
        # equal to the delay of a single operation, not delay * number_of_items
        # (which would be 0.5s here).
        execution_time = end_time - start_time
        self.assertLess(execution_time, 0.2)  # Should be around 0.1s plus minimal overhead

    def test_empty_input(self):
        """
        Test processing of empty input
        """
        shared_storage = {
            'input_numbers': []
        }

        processor = AsyncParallelNumberProcessor()
        self._run(processor, shared_storage)

        self.assertEqual(shared_storage['processed_numbers'], [])

    def test_single_item(self):
        """
        Test processing of a single item
        """
        shared_storage = {
            'input_numbers': [42]
        }

        processor = AsyncParallelNumberProcessor()
        self._run(processor, shared_storage)

        self.assertEqual(shared_storage['processed_numbers'], [84])

    def test_large_batch(self):
        """
        Test processing of a large batch of numbers
        """
        input_size = 100
        shared_storage = {
            'input_numbers': list(range(input_size))
        }

        processor = AsyncParallelNumberProcessor(delay=0.01)
        self._run(processor, shared_storage)

        expected = [x * 2 for x in range(input_size)]
        self.assertEqual(shared_storage['processed_numbers'], expected)

    def test_error_handling(self):
        """
        Test error handling during parallel processing
        """
        class ErrorProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                if item == 2:
                    raise ValueError(f"Error processing item {item}")
                return item

        shared_storage = {
            'input_numbers': [1, 2, 3]
        }

        processor = ErrorProcessor()
        with self.assertRaises(ValueError):
            self._run(processor, shared_storage)

    def test_concurrent_execution(self):
        """
        Test that tasks are actually running concurrently by tracking execution order
        """
        execution_order = []

        class OrderTrackingProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                # Even items sleep twice as long as odd ones, so under concurrent
                # execution the odd items must finish first.
                delay = 0.1 if item % 2 == 0 else 0.05
                await asyncio.sleep(delay)
                execution_order.append(item)
                return item

        shared_storage = {
            'input_numbers': list(range(4))  # [0, 1, 2, 3]
        }

        processor = OrderTrackingProcessor()
        self._run(processor, shared_storage)

        # Every item ran exactly once.
        self.assertEqual(sorted(execution_order), [0, 1, 2, 3])

        # Odd numbers should finish before even numbers due to shorter delay
        self.assertLess(execution_order.index(1), execution_order.index(0))
        self.assertLess(execution_order.index(3), execution_order.index(2))

        # Results are reported in input order even though completion order
        # differs (consistent with test_parallel_processing).
        self.assertEqual(shared_storage['processed_numbers'], [0, 1, 2, 3])
|
|
|
|
if __name__ == '__main__':
    # Run all tests in this module when executed as a script.
    unittest.main()