fix tests
This commit is contained in:
parent 812c041bb4
commit 4b9b357608
@@ -6,95 +6,99 @@ class BaseNode:
     def add_successor(self, node, cond="default"):
         if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
         self.successors[cond] = node; return node
-    def preprocess(self, s): return None
+    def prep(self, shared): return None
-    def process(self, s, p): return None
+    def exec(self, shared, prep_res): return None
-    def _process(self, s, p): return self.process(s, p)
+    def _exec(self, shared, prep_res): return self.exec(shared, prep_res)
-    def postprocess(self, s, pr, r): return "default"
+    def post(self, shared, prep_res, exec_res): return "default"
-    def _run(self, s):
-        pr = self.preprocess(s)
-        r = self._process(s, pr)
-        return self.postprocess(s, pr, r)
+    def _run(self, shared):
+        prep_res = self.prep(shared)
+        exec_res = self._exec(shared, prep_res)
+        return self.post(shared, prep_res, exec_res)
-    def run(self, s):
-        if self.successors: warnings.warn("Has successors; use Flow.run() instead.")
-        return self._run(s)
+    def run(self, shared):
+        if self.successors:
+            warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.")
+        return self._run(shared)
     def __rshift__(self, other): return self.add_successor(other)
     def __sub__(self, cond):
         if isinstance(cond, str): return _ConditionalTransition(self, cond)
         raise TypeError("Condition must be a string")

 class _ConditionalTransition:
-    def __init__(self, src, c): self.src, self.c = src, c
+    def __init__(self, src, cond): self.src, self.cond = src, cond
-    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c)
+    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.cond)
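As context for reviewers: the two operator overloads above are untouched by this rename, and they are what make the transition syntax work. A minimal usage sketch (node classes hypothetical, not part of this commit):

    a, b, c = MyNode(), MyNode(), MyNode()  # MyNode: any hypothetical Node subclass
    a >> b             # default transition, i.e. a.add_successor(b, "default")
    a - "retry" >> c   # conditional transition, routed through _ConditionalTransition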
 class Node(BaseNode):
     def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
-    def process_after_fail(self, s, d, e): raise e
+    def process_after_fail(self, shared, prep_res, exc): raise exc
-    def _process(self, s, d):
+    def _exec(self, shared, prep_res):
         for i in range(self.max_retries):
-            try: return super()._process(s, d)
+            try: return super()._exec(shared, prep_res)
             except Exception as e:
-                if i == self.max_retries - 1: return self.process_after_fail(s, d, e)
+                if i == self.max_retries - 1: return self.process_after_fail(shared, prep_res, e)
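A sketch of the renamed retry path, using a hypothetical flaky node: exec() is attempted max_retries times and process_after_fail() only fires on the final failure:

    class FlakyNode(Node):  # hypothetical example, not part of this commit
        def exec(self, shared, prep_res):
            raise IOError("transient failure")
        def process_after_fail(self, shared, prep_res, exc):
            return "gave_up"  # swallow the error instead of re-raising it

    FlakyNode(max_retries=3).run({})  # three attempts, then "gave_up" is passed to post()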
 class BatchNode(Node):
-    def preprocess(self, s): return []
+    def prep(self, shared): return []
-    def process(self, s, item): return None
+    def exec(self, shared, item): return None
-    def _process(self, s, items): return [super(Node, self)._process(s, i) for i in items]
+    def _exec(self, shared, items): return [super(Node, self)._exec(shared, i) for i in items]

 class BaseFlow(BaseNode):
     def __init__(self, start_node): super().__init__(); self.start_node = start_node
-    def get_next_node(self, curr, c):
-        nxt = curr.successors.get(c)
-        if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}")
+    def get_next_node(self, curr, cond):
+        nxt = curr.successors.get(cond)
+        if nxt is None and curr.successors:
+            warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
         return nxt

 class Flow(BaseFlow):
-    def _process(self, s, p=None):
-        curr, p = self.start_node, (p if p is not None else self.params.copy())
+    def _exec(self, shared, params=None):
+        curr = self.start_node
+        p = params if params else self.params.copy()
         while curr:
             curr.set_params(p)
-            c = curr._run(s)
+            c = curr._run(shared)
             curr = self.get_next_node(curr, c)
-    def process(self, s, pr): raise NotImplementedError("Use Flow._process(...) instead")
+    def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")
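An end-to-end sketch with the renamed hooks (classes hypothetical): each node's exec() returns None, so post() yields "default" and the flow walks the default chain:

    class Hello(Node):  # hypothetical
        def exec(self, shared, prep_res): shared['msg'] = 'hello'

    class World(Node):  # hypothetical
        def exec(self, shared, prep_res): shared['msg'] += ' world'

    hello, world = Hello(), World()
    hello >> world
    shared = {}
    Flow(start_node=hello).run(shared)  # runs hello, then world
    assert shared['msg'] == 'hello world'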
 class BaseBatchFlow(BaseFlow):
-    def preprocess(self, s): return []
+    def prep(self, shared): return []

 class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, s):
-        pr = self.preprocess(s)
-        for d in pr:
+    def _run(self, shared):
+        prep_res = self.prep(shared)
+        for d in prep_res:
             mp = self.params.copy(); mp.update(d)
-            self._process(s, mp)
-        return self.postprocess(s, pr, None)
+            self._exec(shared, mp)
+        return self.post(shared, prep_res, None)
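BatchFlow reruns the whole sub-flow once per parameter dict returned by prep(); a hedged sketch (subclass hypothetical):

    class PerKeyFlow(BatchFlow):  # hypothetical
        def prep(self, shared):
            # one params dict per run of the sub-flow; each dict is merged
            # into self.params, so nodes read it via self.params.get('key')
            return [{'key': k} for k in shared['input_data']]

    # PerKeyFlow(start_node=some_node).run(shared) then runs the chain once per key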
 class AsyncNode(Node):
-    def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async")
+    def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async")
-    async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default"
+    async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
-    async def run_async(self, s):
-        if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.")
-        return await self._run_async(s)
+    async def run_async(self, shared):
+        if self.successors:
+            warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.")
+        return await self._run_async(shared)
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        r = self._process(s, pr)
-        return await self.postprocess_async(s, pr, r)
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        exec_res = self._exec(shared, prep_res)
+        return await self.post_async(shared, prep_res, exec_res)
-    def _run(self, s): raise RuntimeError("AsyncNode requires async execution")
+    def _run(self, shared): raise RuntimeError("AsyncNode requires async execution")
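Sketch of the async lifecycle after the rename: exec() stays synchronous while post_async() awaits (class hypothetical):

    import asyncio

    class FetchNode(AsyncNode):  # hypothetical
        def exec(self, shared, prep_res):
            return 21  # synchronous work is still allowed here
        async def post_async(self, shared, prep_res, exec_res):
            await asyncio.sleep(0)  # stand-in for real async I/O
            shared['answer'] = exec_res * 2
            return "default"

    asyncio.run(FetchNode().run_async({}))  # plain .run() would raise RuntimeError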
 class AsyncFlow(BaseFlow, AsyncNode):
-    async def _process_async(self, s, p=None):
-        curr, p = self.start_node, (p if p else self.params.copy())
+    async def _exec_async(self, shared, params=None):
+        curr, p = self.start_node, (params if params else self.params.copy())
         while curr:
             curr.set_params(p)
-            c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s)
+            c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
             curr = self.get_next_node(curr, c)
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        await self._process_async(s)
-        return await self.postprocess_async(s, pr, None)
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        await self._exec_async(shared)
+        return await self.post_async(shared, prep_res, None)

 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        for d in pr:
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        for d in prep_res:
             mp = self.params.copy(); mp.update(d)
-            await self._process_async(s, mp)
-        return await self.postprocess_async(s, pr, None)
+            await self._exec_async(shared, mp)
+        return await self.post_async(shared, prep_res, None)
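An AsyncFlow can mix the two node kinds because _exec_async awaits only nodes that expose run_async; a sketch reusing the hypothetical FetchNode and Hello from the sketches above:

    start = FetchNode()   # async node, awaited via _run_async()
    start >> Hello()      # plain sync node, driven via _run()
    asyncio.run(AsyncFlow(start).run_async({}))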
@@ -7,7 +7,7 @@ sys.path.append(str(Path(__file__).parent.parent))
 from minillmflow import AsyncNode, BatchAsyncFlow

 class AsyncDataProcessNode(AsyncNode):
-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         key = self.params.get('key')
         data = shared_storage['input_data'][key]
         if 'results' not in shared_storage:
@@ -15,14 +15,14 @@ class AsyncDataProcessNode(AsyncNode):
         shared_storage['results'][key] = data
         return data

-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         await asyncio.sleep(0.01)  # Simulate async work
         key = self.params.get('key')
         shared_storage['results'][key] = proc_result * 2  # Double the value
         return "processed"

 class AsyncErrorNode(AsyncNode):
-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         key = self.params.get('key')
         if key == 'error_key':
             raise ValueError(f"Async error processing key: {key}")
@@ -35,7 +35,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_basic_async_batch_processing(self):
         """Test basic async batch processing with multiple keys"""
         class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -59,7 +59,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_empty_async_batch(self):
         """Test async batch processing with empty input"""
         class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -74,7 +74,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_async_error_handling(self):
         """Test error handling during async batch processing"""
         class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -93,7 +93,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_nested_async_flow(self):
         """Test async batch processing with nested flows"""
         class AsyncInnerNode(AsyncNode):
-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 key = self.params.get('key')
                 if 'intermediate_results' not in shared_storage:
                     shared_storage['intermediate_results'] = {}
@@ -102,7 +102,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
                 return "next"

         class AsyncOuterNode(AsyncNode):
-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 key = self.params.get('key')
                 if 'results' not in shared_storage:
                     shared_storage['results'] = {}
@@ -111,7 +111,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
                 return "done"

         class NestedAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         # Create inner flow
@@ -138,7 +138,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_custom_async_parameters(self):
         """Test async batch processing with additional custom parameters"""
         class CustomParamAsyncNode(AsyncNode):
-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 key = self.params.get('key')
                 multiplier = self.params.get('multiplier', 1)
                 await asyncio.sleep(0.01)
@@ -148,7 +148,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
                 return "done"

         class CustomParamAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{
                     'key': k,
                     'multiplier': i + 1
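For orientation, the elided test bodies presumably drive these flows through the event loop along these lines (a hedged sketch; the exact harness code is not shown in this diff):

    flow = SimpleTestAsyncBatchFlow(start_node=AsyncDataProcessNode())
    asyncio.run(flow.run_async(shared_storage))
    # shared_storage['results'] then maps each key to its doubled value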
@@ -11,19 +11,19 @@ class AsyncNumberNode(AsyncNode):
     """
     Simple async node that sets 'current' to a given number.
     Demonstrates overriding .process() (sync) and using
-    postprocess_async() for the async portion.
+    post_async() for the async portion.
     """
     def __init__(self, number):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         # Synchronous work is allowed inside an AsyncNode,
-        # but final 'condition' is determined by postprocess_async().
+        # but final 'condition' is determined by post_async().
         shared_storage['current'] = self.number
         return "set_number"

-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         # Possibly do asynchronous tasks here
         await asyncio.sleep(0.01)
         # Return a condition for the flow
@@ -34,11 +34,11 @@ class AsyncIncrementNode(AsyncNode):
     """
     Demonstrates incrementing the 'current' value asynchronously.
     """
-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] = shared_storage.get('current', 0) + 1
         return "incremented"

-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         await asyncio.sleep(0.01)  # simulate async I/O
         return "done"
@@ -105,18 +105,18 @@ class TestAsyncFlow(unittest.TestCase):
         """
         Demonstrate a branching scenario where we return different
         conditions. For example, you could have an async node that
-        returns "go_left" or "go_right" in postprocess_async, but here
+        returns "go_left" or "go_right" in post_async, but here
         we'll keep it simpler for demonstration.
         """

         class BranchingAsyncNode(AsyncNode):
-            def process(self, shared_storage, data):
+            def exec(self, shared_storage, data):
                 value = shared_storage.get("value", 0)
                 shared_storage["value"] = value
                 # We'll decide branch based on whether 'value' is positive
                 return None

-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 await asyncio.sleep(0.01)
                 if shared_storage["value"] >= 0:
                     return "positive_branch"
@@ -124,12 +124,12 @@ class TestAsyncFlow(unittest.TestCase):
                 return "negative_branch"

         class PositiveNode(Node):
-            def process(self, shared_storage, data):
+            def exec(self, shared_storage, data):
                 shared_storage["path"] = "positive"
                 return None

         class NegativeNode(Node):
-            def process(self, shared_storage, data):
+            def exec(self, shared_storage, data):
                 shared_storage["path"] = "negative"
                 return None
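The wiring for this branching test is elided from the hunk; a hedged sketch of how the conditions would connect using the conditional-transition operator:

    branching = BranchingAsyncNode()
    branching - "positive_branch" >> PositiveNode()
    branching - "negative_branch" >> NegativeNode()
    # AsyncFlow(branching) then follows whichever condition post_async returns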
@@ -6,7 +6,7 @@ sys.path.append(str(Path(__file__).parent.parent))
 from minillmflow import Node, BatchFlow, Flow

 class DataProcessNode(Node):
-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         key = self.params.get('key')
         data = shared_storage['input_data'][key]
         if 'results' not in shared_storage:
@@ -14,7 +14,7 @@ class DataProcessNode(Node):
         shared_storage['results'][key] = data * 2

 class ErrorProcessNode(Node):
-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         key = self.params.get('key')
         if key == 'error_key':
             raise ValueError(f"Error processing key: {key}")
@@ -29,7 +29,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_basic_batch_processing(self):
         """Test basic batch processing with multiple keys"""
         class SimpleTestBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -53,7 +53,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_empty_input(self):
         """Test batch processing with empty input dictionary"""
         class EmptyTestBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -68,7 +68,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_single_item(self):
         """Test batch processing with single item"""
         class SingleItemBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -88,7 +88,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_error_handling(self):
         """Test error handling during batch processing"""
         class ErrorTestBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -107,21 +107,21 @@ class TestBatchFlow(unittest.TestCase):
     def test_nested_flow(self):
         """Test batch processing with nested flows"""
         class InnerNode(Node):
-            def process(self, shared_storage, prep_result):
+            def exec(self, shared_storage, prep_result):
                 key = self.params.get('key')
                 if 'intermediate_results' not in shared_storage:
                     shared_storage['intermediate_results'] = {}
                 shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1

         class OuterNode(Node):
-            def process(self, shared_storage, prep_result):
+            def exec(self, shared_storage, prep_result):
                 key = self.params.get('key')
                 if 'results' not in shared_storage:
                     shared_storage['results'] = {}
                 shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2

         class NestedBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         # Create inner flow
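The code after "# Create inner flow" is elided from this hunk; since a Flow is itself a node, nesting typically composes along these lines (a hedged sketch, not the test's actual body):

    inner = Flow(start_node=InnerNode())
    inner >> OuterNode()  # a whole Flow can chain into further nodes
    NestedBatchFlow(start_node=inner).run(shared_storage)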
@@ -148,7 +148,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_custom_parameters(self):
         """Test batch processing with additional custom parameters"""
         class CustomParamNode(Node):
-            def process(self, shared_storage, prep_result):
+            def exec(self, shared_storage, prep_result):
                 key = self.params.get('key')
                 multiplier = self.params.get('multiplier', 1)
                 if 'results' not in shared_storage:
@@ -156,7 +156,7 @@ class TestBatchFlow(unittest.TestCase):
                 shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier

         class CustomParamBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{
                     'key': k,
                     'multiplier': i + 1
@@ -10,7 +10,7 @@ class ArrayChunkNode(BatchNode):
         super().__init__()
         self.chunk_size = chunk_size

-    def preprocess(self, shared_storage):
+    def prep(self, shared_storage):
         # Get array from shared storage and split into chunks
         array = shared_storage.get('input_array', [])
         chunks = []
@@ -19,20 +19,20 @@ class ArrayChunkNode(BatchNode):
             chunks.append((i, end))
         return chunks

-    def process(self, shared_storage, chunk_indices):
+    def exec(self, shared_storage, chunk_indices):
         start, end = chunk_indices
         array = shared_storage['input_array']
         # Process the chunk and return its sum
         chunk_sum = sum(array[start:end])
         return chunk_sum

-    def postprocess(self, shared_storage, prep_result, proc_result):
+    def post(self, shared_storage, prep_result, proc_result):
         # Store chunk results in shared storage
         shared_storage['chunk_results'] = proc_result
         return "default"

 class SumReduceNode(Node):
-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         # Get chunk results from shared storage and sum them
         chunk_results = shared_storage.get('chunk_results', [])
         total = sum(chunk_results)
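ArrayChunkNode and SumReduceNode form the map-reduce idiom for BatchNode; a sketch of running the pair as a flow, assuming shared storage shaped like the test below:

    shared = {'input_array': list(range(25))}
    chunk = ArrayChunkNode(chunk_size=10)
    chunk >> SumReduceNode()
    Flow(start_node=chunk).run(shared)
    # exec() runs once per (start, end) pair from prep(); post() then stores
    # the list of chunk sums for the reduce node to total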
@@ -48,7 +48,7 @@ class TestBatchNode(unittest.TestCase):
         }

         chunk_node = ArrayChunkNode(chunk_size=10)
-        chunks = chunk_node.preprocess(shared_storage)
+        chunks = chunk_node.prep(shared_storage)

         self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])
@@ -10,7 +10,7 @@ class NumberNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] = self.number

 class AddNode(Node):
@@ -18,7 +18,7 @@ class AddNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] += self.number

 class MultiplyNode(Node):
@@ -26,18 +26,18 @@ class MultiplyNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] *= self.number

 class CheckPositiveNode(Node):
-    def postprocess(self, shared_storage, prep_result, proc_result):
+    def post(self, shared_storage, prep_result, proc_result):
         if shared_storage['current'] >= 0:
             return 'positive'
         else:
             return 'negative'

 class NoOpNode(Node):
-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         # Do nothing, just pass
         pass
@@ -12,7 +12,7 @@ class NumberNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         shared_storage['current'] = self.number

 class AddNode(Node):
@@ -20,7 +20,7 @@ class AddNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         shared_storage['current'] += self.number

 class MultiplyNode(Node):
@@ -28,7 +28,7 @@ class MultiplyNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         shared_storage['current'] *= self.number

 class TestFlowComposition(unittest.TestCase):