# Source: pocketflow/tests/test_async_batch_node.py
# File-viewer metadata: 181 lines, 6.0 KiB, Python

import unittest
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from minillmflow import AsyncNode, AsyncBatchNode, AsyncFlow
class AsyncArrayChunkNode(AsyncBatchNode):
    """Batch node that splits an input array into chunks and sums each chunk.

    The array is read from ``shared_storage['input_array']`` and the value
    handed to ``post_async`` is stored under ``shared_storage['chunk_results']``.
    """

    def __init__(self, chunk_size=10):
        """Create the node.

        Args:
            chunk_size: Maximum number of items per chunk; must be at least 1.

        Raises:
            ValueError: If ``chunk_size`` is less than 1.
        """
        # range() fails with an unrelated-looking error for a zero step and
        # silently yields nothing for a negative one, so reject both up front.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
        super().__init__()
        self.chunk_size = chunk_size

    async def prep_async(self, shared_storage):
        """Return ``input_array`` split into consecutive chunks.

        Args:
            shared_storage: Mapping read for ``'input_array'``; a missing key
                is treated as an empty list.

        Returns:
            A list of slices of at most ``self.chunk_size`` items each; the
            last slice may be shorter. Empty when the array is empty.
        """
        array = shared_storage.get('input_array', [])
        size = self.chunk_size
        # Slicing clamps at the end of the sequence, so the final (possibly
        # short) chunk needs no explicit upper bound.
        return [array[start:start + size] for start in range(0, len(array), size)]

    async def exec_async(self, chunk):
        """Return the sum of one chunk after a short simulated async delay."""
        # Simulate async processing of each chunk
        await asyncio.sleep(0.01)
        return sum(chunk)

    async def post_async(self, shared_storage, prep_result, proc_result):
        """Store ``proc_result`` under ``'chunk_results'`` and return ``"processed"``.

        ``prep_result`` is accepted but not used.
        """
        shared_storage['chunk_results'] = proc_result
        return "processed"
class AsyncSumReduceNode(AsyncNode):
    """Node that adds up the stored chunk results into a single total."""

    async def prep_async(self, shared_storage):
        """Sum ``shared_storage['chunk_results']`` and store it as ``'total'``.

        A missing ``'chunk_results'`` key is treated as an empty list, which
        gives a total of 0.

        Returns:
            The string ``"reduced"``.
        """
        # NOTE(review): the whole reduction happens in the prep step and the
        # returned string comes from prep, not post; how the framework uses
        # that value is not visible in this file -- confirm it is intended.
        partial_sums = shared_storage.get('chunk_results', [])

        # Simulate async processing
        await asyncio.sleep(0.01)

        shared_storage['total'] = sum(partial_sums)
        return "reduced"
class TestAsyncBatchNode(unittest.TestCase):
    """Tests for chunked async batch processing with ``AsyncArrayChunkNode``."""

    def test_array_chunking(self):
        """Check that a 25-item array is chunked by 10 and each chunk is summed."""
        store = {'input_array': list(range(25))}  # [0, 1, 2, ..., 24]
        node = AsyncArrayChunkNode(chunk_size=10)

        asyncio.run(node.run_async(store))

        # 25 items in chunks of 10 -> [0-9], [10-19], [20-24]
        chunk_sums = store['chunk_results']
        self.assertEqual(chunk_sums, [45, 145, 110])
# def test_async_map_reduce_sum(self):
# """
# Test a complete async map-reduce pipeline that sums a large array:
# 1. Map: Split array into chunks and sum each chunk asynchronously
# 2. Reduce: Sum all the chunk sums asynchronously
# """
# array = list(range(100))
# expected_sum = sum(array) # 4950
# shared_storage = {
# 'input_array': array
# }
# # Create nodes
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
# reduce_node = AsyncSumReduceNode()
# # Connect nodes
# chunk_node - "processed" >> reduce_node
# # Create and run pipeline
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_uneven_chunks(self):
# """
# Test that the async map-reduce works correctly with array lengths
# that don't divide evenly by chunk_size
# """
# array = list(range(25))
# expected_sum = sum(array) # 300
# shared_storage = {
# 'input_array': array
# }
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_custom_chunk_size(self):
# """
# Test that the async map-reduce works with different chunk sizes
# """
# array = list(range(100))
# expected_sum = sum(array)
# shared_storage = {
# 'input_array': array
# }
# # Use chunk_size=15 instead of default 10
# chunk_node = AsyncArrayChunkNode(chunk_size=15)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_single_element_chunks(self):
# """
# Test extreme case where chunk_size=1
# """
# array = list(range(5))
# expected_sum = sum(array)
# shared_storage = {
# 'input_array': array
# }
# chunk_node = AsyncArrayChunkNode(chunk_size=1)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], expected_sum)
# def test_empty_array(self):
# """
# Test edge case of empty input array
# """
# shared_storage = {
# 'input_array': []
# }
# chunk_node = AsyncArrayChunkNode(chunk_size=10)
# reduce_node = AsyncSumReduceNode()
# chunk_node - "processed" >> reduce_node
# pipeline = AsyncFlow(start=chunk_node)
# asyncio.run(pipeline.run_async(shared_storage))
# self.assertEqual(shared_storage['total'], 0)
# def test_error_handling(self):
# """
# Test error handling in async batch processing
# """
# class ErrorAsyncBatchNode(AsyncBatchNode):
# async def exec_async(self, item):
# if item == 2:
# raise ValueError("Error processing item 2")
# return item
# shared_storage = {
# 'input_array': [1, 2, 3]
# }
# error_node = ErrorAsyncBatchNode()
# with self.assertRaises(ValueError):
# asyncio.run(error_node.run_async(shared_storage))
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()