# Listing metadata from the code viewer this file was copied from: 160 lines, 5.5 KiB, Python
import unittest
|
|
import asyncio
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
from pocketflow import AsyncNode, AsyncParallelBatchNode, AsyncParallelBatchFlow
|
|
|
|
class AsyncParallelNumberProcessor(AsyncParallelBatchNode):
    """Test node that doubles the numbers of a single batch.

    The batch to work on is selected through ``self.params['batch_id']``;
    the doubled values are stored back into the shared storage under the
    same id.
    """

    def __init__(self, delay=0.1):
        """Remember ``delay``, the seconds ``exec_async`` sleeps per call."""
        super().__init__()
        self.delay = delay

    async def prep_async(self, shared_storage):
        """Return the batch registered under this node's ``batch_id``."""
        batch_id = self.params['batch_id']
        return shared_storage['batches'][batch_id]

    async def exec_async(self, number):
        """Sleep for ``self.delay`` seconds, then return ``number * 2``."""
        # The sleep stands in for real asynchronous work.
        await asyncio.sleep(self.delay)
        return number * 2

    async def post_async(self, shared_storage, prep_result, exec_result):
        """Store ``exec_result`` under this node's ``batch_id``.

        Creates the ``'processed_numbers'`` mapping on first use and returns
        the action string ``"processed"``.
        """
        results_by_batch = shared_storage.setdefault('processed_numbers', {})
        results_by_batch[self.params['batch_id']] = exec_result
        return "processed"
|
|
|
|
class AsyncAggregatorNode(AsyncNode):
    """Test node that flattens the per-batch results and stores their sum."""

    async def prep_async(self, shared_storage):
        """Return every processed number as one flat list, ordered by batch id.

        The ids actually present in ``shared_storage['processed_numbers']``
        are iterated in sorted order rather than assumed to be exactly
        ``0..n-1``, so a gap in the ids (for instance a batch whose result
        has not been written yet) does not raise ``KeyError``.

        Returns an empty list when no results have been stored.
        """
        processed = shared_storage.get('processed_numbers', {})
        return [
            value
            for batch_id in sorted(processed)
            for value in processed[batch_id]
        ]

    async def exec_async(self, prep_result):
        """Return the sum of ``prep_result`` after a short simulated delay."""
        await asyncio.sleep(0.01)
        return sum(prep_result)

    async def post_async(self, shared_storage, prep_result, exec_result):
        """Store ``exec_result`` as ``shared_storage['total']``.

        Returns the action string ``"aggregated"``.
        """
        shared_storage['total'] = exec_result
        return "aggregated"
|
|
|
|
class TestAsyncParallelBatchFlow(unittest.TestCase):
    """Tests for ``AsyncParallelBatchFlow`` driven by batch ids.

    Every test runs on its own event loop, created in ``setUp`` and
    disposed of in ``tearDown``.
    """

    def setUp(self):
        """Create a fresh event loop and install it as the current one."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Close the per-test loop and uninstall it.

        Without the reset, the closed loop would remain the current event
        loop of this thread and leak into whatever runs next.
        """
        try:
            self.loop.close()
        finally:
            asyncio.set_event_loop(None)

    def test_parallel_batch_flow(self):
        """
        Test basic parallel batch processing flow with batch IDs
        """
        class TestParallelBatchFlow(AsyncParallelBatchFlow):
            async def prep_async(self, shared_storage):
                # One parameter dict per batch, keyed by its index.
                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

        shared_storage = {
            'batches': [
                [1, 2, 3],  # batch_id: 0
                [4, 5, 6],  # batch_id: 1
                [7, 8, 9]   # batch_id: 2
            ]
        }

        processor = AsyncParallelNumberProcessor(delay=0.1)
        aggregator = AsyncAggregatorNode()

        processor - "processed" >> aggregator
        flow = TestParallelBatchFlow(start=processor)

        start_time = self.loop.time()
        self.loop.run_until_complete(flow.run_async(shared_storage))
        execution_time = self.loop.time() - start_time

        # Verify each batch was processed correctly
        expected_batch_results = {
            0: [2, 4, 6],     # [1,2,3] * 2
            1: [8, 10, 12],   # [4,5,6] * 2
            2: [14, 16, 18]   # [7,8,9] * 2
        }
        self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)

        # Verify total
        expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
        self.assertEqual(shared_storage['total'], expected_total)

        # Verify parallel execution: each of the 3 batches sleeps at least
        # 0.1 s, so running them one after another would need >= 0.3 s.
        self.assertLess(execution_time, 0.2)

    def test_error_handling(self):
        """
        Test error handling in parallel batch flow
        """
        class ErrorProcessor(AsyncParallelNumberProcessor):
            async def exec_async(self, item):
                # The value 2 is the designated failure trigger.
                if item == 2:
                    raise ValueError(f"Error processing item {item}")
                return item

        class ErrorBatchFlow(AsyncParallelBatchFlow):
            async def prep_async(self, shared_storage):
                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

        shared_storage = {
            'batches': [
                [1, 2, 3],  # Contains error-triggering value
                [4, 5, 6]
            ]
        }

        processor = ErrorProcessor()
        flow = ErrorBatchFlow(start=processor)

        # The ValueError raised inside the node must surface from the flow.
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(flow.run_async(shared_storage))

    def test_multiple_batch_sizes(self):
        """
        Test parallel batch flow with varying batch sizes
        """
        class VaryingBatchFlow(AsyncParallelBatchFlow):
            async def prep_async(self, shared_storage):
                return [{'batch_id': i} for i in range(len(shared_storage['batches']))]

        shared_storage = {
            'batches': [
                [1],           # batch_id: 0
                [2, 3, 4],     # batch_id: 1
                [5, 6],        # batch_id: 2
                [7, 8, 9, 10]  # batch_id: 3
            ]
        }

        processor = AsyncParallelNumberProcessor(delay=0.05)
        aggregator = AsyncAggregatorNode()

        processor - "processed" >> aggregator
        flow = VaryingBatchFlow(start=processor)

        self.loop.run_until_complete(flow.run_async(shared_storage))

        # Verify each batch was processed correctly
        expected_batch_results = {
            0: [2],                 # [1] * 2
            1: [4, 6, 8],           # [2,3,4] * 2
            2: [10, 12],            # [5,6] * 2
            3: [14, 16, 18, 20]     # [7,8,9,10] * 2
        }
        self.assertEqual(shared_storage['processed_numbers'], expected_batch_results)

        # Verify total
        expected_total = sum(num * 2 for batch in shared_storage['batches'] for num in batch)
        self.assertEqual(shared_storage['total'], expected_total)
|
|
|
|
if __name__ == '__main__':
    # Executed directly (not imported): discover and run the tests above.
    unittest.main()