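"""Unit tests for minillmflow's BatchAsyncFlow driving chains of AsyncNode subclasses."""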
import unittest
import asyncio
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))

from minillmflow import AsyncNode, BatchAsyncFlow


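# Copies the value for this node's 'key' param into shared results during exec,
# then doubles it during the async post step.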
class AsyncDataProcessNode(AsyncNode):
    def exec(self, shared_storage, prep_result):
        key = self.params.get('key')
        data = shared_storage['input_data'][key]
        if 'results' not in shared_storage:
            shared_storage['results'] = {}
        shared_storage['results'][key] = data
        return data

    async def post_async(self, shared_storage, prep_result, proc_result):
        await asyncio.sleep(0.01)  # Simulate async work
        key = self.params.get('key')
        shared_storage['results'][key] = proc_result * 2  # Double the value
        return "processed"


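# Raises ValueError when asked to process the designated error key.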
class AsyncErrorNode(AsyncNode):
    async def post_async(self, shared_storage, prep_result, proc_result):
        key = self.params.get('key')
        if key == 'error_key':
            raise ValueError(f"Async error processing key: {key}")
        return "processed"


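# Each test defines a BatchAsyncFlow subclass whose prep() returns one param dict
# per key in 'input_data'; the flow is expected to run its node chain once per dict.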
class TestAsyncBatchFlow(unittest.TestCase):
    def setUp(self):
        self.process_node = AsyncDataProcessNode()

    def test_basic_async_batch_processing(self):
        """Test basic async batch processing with multiple keys"""
        class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
            def prep(self, shared_storage):
                return [{'key': k} for k in shared_storage['input_data'].keys()]

        shared_storage = {
            'input_data': {
                'a': 1,
                'b': 2,
                'c': 3
            }
        }

        flow = SimpleTestAsyncBatchFlow(start=self.process_node)
        asyncio.run(flow.run_async(shared_storage))

        expected_results = {
            'a': 2,  # 1 * 2
            'b': 4,  # 2 * 2
            'c': 6   # 3 * 2
        }
        self.assertEqual(shared_storage['results'], expected_results)

    def test_empty_async_batch(self):
        """Test async batch processing with empty input"""
        class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
            def prep(self, shared_storage):
                return [{'key': k} for k in shared_storage['input_data'].keys()]

        shared_storage = {
            'input_data': {}
        }

        flow = EmptyTestAsyncBatchFlow(start=self.process_node)
        asyncio.run(flow.run_async(shared_storage))

        self.assertEqual(shared_storage.get('results', {}), {})

    def test_async_error_handling(self):
        """Test error handling during async batch processing"""
        class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
            def prep(self, shared_storage):
                return [{'key': k} for k in shared_storage['input_data'].keys()]

        shared_storage = {
            'input_data': {
                'normal_key': 1,
                'error_key': 2,
                'another_key': 3
            }
        }

        flow = ErrorTestAsyncBatchFlow(start=AsyncErrorNode())

        with self.assertRaises(ValueError):
            asyncio.run(flow.run_async(shared_storage))

    def test_nested_async_flow(self):
        """Test async batch processing with nested flows"""
        class AsyncInnerNode(AsyncNode):
            async def post_async(self, shared_storage, prep_result, proc_result):
                key = self.params.get('key')
                if 'intermediate_results' not in shared_storage:
                    shared_storage['intermediate_results'] = {}
                shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
                await asyncio.sleep(0.01)
                return "next"

        class AsyncOuterNode(AsyncNode):
            async def post_async(self, shared_storage, prep_result, proc_result):
                key = self.params.get('key')
                if 'results' not in shared_storage:
                    shared_storage['results'] = {}
                shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
                await asyncio.sleep(0.01)
                return "done"

        class NestedAsyncBatchFlow(BatchAsyncFlow):
            def prep(self, shared_storage):
                return [{'key': k} for k in shared_storage['input_data'].keys()]

        # Wire the inner chain: the "next" action returned by the inner node
        # routes execution to the outer node.
        inner_node = AsyncInnerNode()
        outer_node = AsyncOuterNode()
        inner_node - "next" >> outer_node

        shared_storage = {
            'input_data': {
                'x': 1,
                'y': 2
            }
        }

        flow = NestedAsyncBatchFlow(start=inner_node)
        asyncio.run(flow.run_async(shared_storage))

        expected_results = {
            'x': 4,  # (1 + 1) * 2
            'y': 6   # (2 + 1) * 2
        }
        self.assertEqual(shared_storage['results'], expected_results)

    def test_custom_async_parameters(self):
        """Test async batch processing with additional custom parameters"""
        class CustomParamAsyncNode(AsyncNode):
            async def post_async(self, shared_storage, prep_result, proc_result):
                key = self.params.get('key')
                multiplier = self.params.get('multiplier', 1)
                await asyncio.sleep(0.01)
                if 'results' not in shared_storage:
                    shared_storage['results'] = {}
                shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
                return "done"

        class CustomParamAsyncBatchFlow(BatchAsyncFlow):
            def prep(self, shared_storage):
                return [{
                    'key': k,
                    'multiplier': i + 1
                } for i, k in enumerate(shared_storage['input_data'].keys())]

        shared_storage = {
            'input_data': {
                'a': 1,
                'b': 2,
                'c': 3
            }
        }

        flow = CustomParamAsyncBatchFlow(start=CustomParamAsyncNode())
        asyncio.run(flow.run_async(shared_storage))

        expected_results = {
            'a': 1 * 1,  # first item, multiplier = 1
            'b': 2 * 2,  # second item, multiplier = 2
            'c': 3 * 3   # third item, multiplier = 3
        }
        self.assertEqual(shared_storage['results'], expected_results)

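# The guard below allows running this module directly with `python <this file>`;
# `python -m unittest` discovery should also pick these tests up, assuming the
# file follows the usual test*.py naming convention.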
if __name__ == '__main__':
    unittest.main()