fix tests
This commit is contained in:
parent 812c041bb4
commit 4b9b357608
@@ -6,95 +6,99 @@ class BaseNode:
     def add_successor(self, node, cond="default"):
         if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
         self.successors[cond] = node; return node
-    def preprocess(self, s): return None
-    def process(self, s, p): return None
-    def _process(self, s, p): return self.process(s, p)
-    def postprocess(self, s, pr, r): return "default"
-    def _run(self, s):
-        pr = self.preprocess(s)
-        r = self._process(s, pr)
-        return self.postprocess(s, pr, r)
-    def run(self, s):
-        if self.successors: warnings.warn("Has successors; use Flow.run() instead.")
-        return self._run(s)
+    def prep(self, shared): return None
+    def exec(self, shared, prep_res): return None
+    def _exec(self, shared, prep_res): return self.exec(shared, prep_res)
+    def post(self, shared, prep_res, exec_res): return "default"
+    def _run(self, shared):
+        prep_res = self.prep(shared)
+        exec_res = self._exec(shared, prep_res)
+        return self.post(shared, prep_res, exec_res)
+    def run(self, shared):
+        if self.successors:
+            warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.")
+        return self._run(shared)
     def __rshift__(self, other): return self.add_successor(other)
     def __sub__(self, cond):
         if isinstance(cond, str): return _ConditionalTransition(self, cond)
         raise TypeError("Condition must be a string")

 class _ConditionalTransition:
-    def __init__(self, src, c): self.src, self.c = src, c
-    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c)
+    def __init__(self, src, cond): self.src, self.cond = src, cond
+    def __rshift__(self, tgt): return self.src.add_successor(tgt, self.cond)

 class Node(BaseNode):
     def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
-    def process_after_fail(self, s, d, e): raise e
-    def _process(self, s, d):
+    def process_after_fail(self, shared, prep_res, exc): raise exc
+    def _exec(self, shared, prep_res):
         for i in range(self.max_retries):
-            try: return super()._process(s, d)
+            try: return super()._exec(shared, prep_res)
             except Exception as e:
-                if i == self.max_retries - 1: return self.process_after_fail(s, d, e)
+                if i == self.max_retries - 1: return self.process_after_fail(shared, prep_res, e)

 class BatchNode(Node):
-    def preprocess(self, s): return []
-    def process(self, s, item): return None
-    def _process(self, s, items): return [super(Node, self)._process(s, i) for i in items]
+    def prep(self, shared): return []
+    def exec(self, shared, item): return None
+    def _exec(self, shared, items): return [super(Node, self)._exec(shared, i) for i in items]

 class BaseFlow(BaseNode):
     def __init__(self, start_node): super().__init__(); self.start_node = start_node
-    def get_next_node(self, curr, c):
-        nxt = curr.successors.get(c)
-        if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}")
+    def get_next_node(self, curr, cond):
+        nxt = curr.successors.get(cond)
+        if nxt is None and curr.successors:
+            warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
         return nxt

 class Flow(BaseFlow):
-    def _process(self, s, p=None):
-        curr, p = self.start_node, (p if p is not None else self.params.copy())
+    def _exec(self, shared, params=None):
+        curr = self.start_node
+        p = params if params else self.params.copy()
         while curr:
             curr.set_params(p)
-            c = curr._run(s)
+            c = curr._run(shared)
             curr = self.get_next_node(curr, c)
-    def process(self, s, pr): raise NotImplementedError("Use Flow._process(...) instead")
+    def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")

 class BaseBatchFlow(BaseFlow):
-    def preprocess(self, s): return []
+    def prep(self, shared): return []

 class BatchFlow(BaseBatchFlow, Flow):
-    def _run(self, s):
-        pr = self.preprocess(s)
-        for d in pr:
+    def _run(self, shared):
+        prep_res = self.prep(shared)
+        for d in prep_res:
             mp = self.params.copy(); mp.update(d)
-            self._process(s, mp)
-        return self.postprocess(s, pr, None)
+            self._exec(shared, mp)
+        return self.post(shared, prep_res, None)

 class AsyncNode(Node):
-    def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async")
-    async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default"
-    async def run_async(self, s):
-        if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.")
-        return await self._run_async(s)
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        r = self._process(s, pr)
-        return await self.postprocess_async(s, pr, r)
-    def _run(self, s): raise RuntimeError("AsyncNode requires async execution")
+    def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async")
+    async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
+    async def run_async(self, shared):
+        if self.successors:
+            warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.")
+        return await self._run_async(shared)
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        exec_res = self._exec(shared, prep_res)
+        return await self.post_async(shared, prep_res, exec_res)
+    def _run(self, shared): raise RuntimeError("AsyncNode requires async execution")

 class AsyncFlow(BaseFlow, AsyncNode):
-    async def _process_async(self, s, p=None):
-        curr, p = self.start_node, (p if p else self.params.copy())
+    async def _exec_async(self, shared, params=None):
+        curr, p = self.start_node, (params if params else self.params.copy())
         while curr:
             curr.set_params(p)
-            c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s)
+            c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
             curr = self.get_next_node(curr, c)
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        await self._process_async(s)
-        return await self.postprocess_async(s, pr, None)
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        await self._exec_async(shared)
+        return await self.post_async(shared, prep_res, None)

 class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
-    async def _run_async(self, s):
-        pr = self.preprocess(s)
-        for d in pr:
+    async def _run_async(self, shared):
+        prep_res = self.prep(shared)
+        for d in prep_res:
             mp = self.params.copy(); mp.update(d)
-            await self._process_async(s, mp)
-        return await self.postprocess_async(s, pr, None)
+            await self._exec_async(shared, mp)
+        return await self.post_async(shared, prep_res, None)
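The hunk above renames the node lifecycle hooks from preprocess/process/postprocess to prep/exec/post, and the internal drivers from _process to _exec. A minimal usage sketch of the renamed API (assuming minillmflow exports Node and Flow as defined above; Greet and Shout are illustrative names, not part of this commit):

from minillmflow import Node, Flow

class Greet(Node):
    def prep(self, shared): return shared.get("name", "world")     # gather input
    def exec(self, shared, prep_res): return f"hello, {prep_res}"  # do the work
    def post(self, shared, prep_res, exec_res):
        shared["greeting"] = exec_res
        return "default"  # condition key selecting the next successor

class Shout(Node):
    def exec(self, shared, prep_res):
        shared["greeting"] = shared["greeting"].upper()

greet = Greet()
greet >> Shout()  # wires Shout under the "default" condition
shared = {"name": "editor"}
Flow(greet).run(shared)
print(shared["greeting"])  # -> HELLO, EDITOR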
@@ -7,7 +7,7 @@ sys.path.append(str(Path(__file__).parent.parent))
 from minillmflow import AsyncNode, BatchAsyncFlow

 class AsyncDataProcessNode(AsyncNode):
-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         key = self.params.get('key')
         data = shared_storage['input_data'][key]
         if 'results' not in shared_storage:
@@ -15,14 +15,14 @@ class AsyncDataProcessNode(AsyncNode):
         shared_storage['results'][key] = data
         return data

-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         await asyncio.sleep(0.01)  # Simulate async work
         key = self.params.get('key')
         shared_storage['results'][key] = proc_result * 2  # Double the value
         return "processed"

 class AsyncErrorNode(AsyncNode):
-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         key = self.params.get('key')
         if key == 'error_key':
             raise ValueError(f"Async error processing key: {key}")
@@ -35,7 +35,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_basic_async_batch_processing(self):
         """Test basic async batch processing with multiple keys"""
         class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -59,7 +59,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_empty_async_batch(self):
         """Test async batch processing with empty input"""
         class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -74,7 +74,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_async_error_handling(self):
         """Test error handling during async batch processing"""
         class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -93,7 +93,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_nested_async_flow(self):
         """Test async batch processing with nested flows"""
         class AsyncInnerNode(AsyncNode):
-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 key = self.params.get('key')
                 if 'intermediate_results' not in shared_storage:
                     shared_storage['intermediate_results'] = {}
@@ -102,7 +102,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
                 return "next"

         class AsyncOuterNode(AsyncNode):
-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 key = self.params.get('key')
                 if 'results' not in shared_storage:
                     shared_storage['results'] = {}
@@ -111,7 +111,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
                 return "done"

         class NestedAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         # Create inner flow
@@ -138,7 +138,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
     def test_custom_async_parameters(self):
         """Test async batch processing with additional custom parameters"""
         class CustomParamAsyncNode(AsyncNode):
-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 key = self.params.get('key')
                 multiplier = self.params.get('multiplier', 1)
                 await asyncio.sleep(0.01)
@@ -148,7 +148,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
                 return "done"

         class CustomParamAsyncBatchFlow(BatchAsyncFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{
                     'key': k,
                     'multiplier': i + 1
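These tests fan one flow run out per parameter dict returned by prep(). A condensed sketch of the pattern (hypothetical node and data, assuming the imports shown in this file):

import asyncio
from minillmflow import AsyncNode, BatchAsyncFlow

class DoubleValue(AsyncNode):
    async def post_async(self, shared_storage, prep_result, proc_result):
        key = self.params.get('key')  # injected per batch item
        shared_storage.setdefault('results', {})[key] = shared_storage['input_data'][key] * 2
        return "processed"

class PerKeyAsyncBatchFlow(BatchAsyncFlow):
    def prep(self, shared_storage):
        # one parameter dict per batch item; each run sees its own 'key'
        return [{'key': k} for k in shared_storage['input_data']]

shared = {'input_data': {'a': 1, 'b': 2}}
asyncio.run(PerKeyAsyncBatchFlow(DoubleValue()).run_async(shared))
# shared['results'] == {'a': 2, 'b': 4}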
@@ -11,19 +11,19 @@ class AsyncNumberNode(AsyncNode):
     """
     Simple async node that sets 'current' to a given number.
     Demonstrates overriding .process() (sync) and using
-    postprocess_async() for the async portion.
+    post_async() for the async portion.
     """
     def __init__(self, number):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         # Synchronous work is allowed inside an AsyncNode,
-        # but final 'condition' is determined by postprocess_async().
+        # but final 'condition' is determined by post_async().
         shared_storage['current'] = self.number
         return "set_number"

-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         # Possibly do asynchronous tasks here
         await asyncio.sleep(0.01)
         # Return a condition for the flow
@@ -34,11 +34,11 @@ class AsyncIncrementNode(AsyncNode):
     """
     Demonstrates incrementing the 'current' value asynchronously.
     """
-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] = shared_storage.get('current', 0) + 1
         return "incremented"

-    async def postprocess_async(self, shared_storage, prep_result, proc_result):
+    async def post_async(self, shared_storage, prep_result, proc_result):
         await asyncio.sleep(0.01)  # simulate async I/O
         return "done"

@@ -105,18 +105,18 @@ class TestAsyncFlow(unittest.TestCase):
         """
         Demonstrate a branching scenario where we return different
         conditions. For example, you could have an async node that
-        returns "go_left" or "go_right" in postprocess_async, but here
+        returns "go_left" or "go_right" in post_async, but here
         we'll keep it simpler for demonstration.
         """

         class BranchingAsyncNode(AsyncNode):
-            def process(self, shared_storage, data):
+            def exec(self, shared_storage, data):
                 value = shared_storage.get("value", 0)
                 shared_storage["value"] = value
                 # We'll decide branch based on whether 'value' is positive
                 return None

-            async def postprocess_async(self, shared_storage, prep_result, proc_result):
+            async def post_async(self, shared_storage, prep_result, proc_result):
                 await asyncio.sleep(0.01)
                 if shared_storage["value"] >= 0:
                     return "positive_branch"
@@ -124,12 +124,12 @@ class TestAsyncFlow(unittest.TestCase):
                 return "negative_branch"

         class PositiveNode(Node):
-            def process(self, shared_storage, data):
+            def exec(self, shared_storage, data):
                 shared_storage["path"] = "positive"
                 return None

         class NegativeNode(Node):
-            def process(self, shared_storage, data):
+            def exec(self, shared_storage, data):
                 shared_storage["path"] = "negative"
                 return None
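The branch conditions returned by post_async drive the `-`/`>>` wiring from the core module. A sketch of how nodes like the ones above would be connected (hypothetical wiring that mirrors the test's intent; the class names are local to the test method):

branching = BranchingAsyncNode()
branching - "positive_branch" >> PositiveNode()
branching - "negative_branch" >> NegativeNode()
flow = AsyncFlow(branching)
asyncio.run(flow.run_async({"value": 5}))  # post_async returns "positive_branch"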
@@ -6,7 +6,7 @@ sys.path.append(str(Path(__file__).parent.parent))
 from minillmflow import Node, BatchFlow, Flow

 class DataProcessNode(Node):
-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         key = self.params.get('key')
         data = shared_storage['input_data'][key]
         if 'results' not in shared_storage:
@@ -14,7 +14,7 @@ class DataProcessNode(Node):
         shared_storage['results'][key] = data * 2

 class ErrorProcessNode(Node):
-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         key = self.params.get('key')
         if key == 'error_key':
             raise ValueError(f"Error processing key: {key}")
@@ -29,7 +29,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_basic_batch_processing(self):
         """Test basic batch processing with multiple keys"""
         class SimpleTestBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -53,7 +53,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_empty_input(self):
         """Test batch processing with empty input dictionary"""
         class EmptyTestBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -68,7 +68,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_single_item(self):
         """Test batch processing with single item"""
         class SingleItemBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -88,7 +88,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_error_handling(self):
         """Test error handling during batch processing"""
         class ErrorTestBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         shared_storage = {
@@ -107,21 +107,21 @@ class TestBatchFlow(unittest.TestCase):
     def test_nested_flow(self):
         """Test batch processing with nested flows"""
         class InnerNode(Node):
-            def process(self, shared_storage, prep_result):
+            def exec(self, shared_storage, prep_result):
                 key = self.params.get('key')
                 if 'intermediate_results' not in shared_storage:
                     shared_storage['intermediate_results'] = {}
                 shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1

         class OuterNode(Node):
-            def process(self, shared_storage, prep_result):
+            def exec(self, shared_storage, prep_result):
                 key = self.params.get('key')
                 if 'results' not in shared_storage:
                     shared_storage['results'] = {}
                 shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2

         class NestedBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{'key': k} for k in shared_storage['input_data'].keys()]

         # Create inner flow
@@ -148,7 +148,7 @@ class TestBatchFlow(unittest.TestCase):
     def test_custom_parameters(self):
         """Test batch processing with additional custom parameters"""
         class CustomParamNode(Node):
-            def process(self, shared_storage, prep_result):
+            def exec(self, shared_storage, prep_result):
                 key = self.params.get('key')
                 multiplier = self.params.get('multiplier', 1)
                 if 'results' not in shared_storage:
@@ -156,7 +156,7 @@ class TestBatchFlow(unittest.TestCase):
                 shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier

         class CustomParamBatchFlow(BatchFlow):
-            def preprocess(self, shared_storage):
+            def prep(self, shared_storage):
                 return [{
                     'key': k,
                     'multiplier': i + 1
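The synchronous BatchFlow follows the same shape. A condensed end-to-end sketch (hypothetical values; assumes DataProcessNode and BatchFlow as imported in this file):

class KeyedBatchFlow(BatchFlow):
    def prep(self, shared_storage):
        # one run of the inner flow per key
        return [{'key': k} for k in shared_storage['input_data']]

shared = {'input_data': {'a': 1, 'b': 2, 'c': 3}}
KeyedBatchFlow(DataProcessNode()).run(shared)
# shared['results'] == {'a': 2, 'b': 4, 'c': 6}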
@@ -10,7 +10,7 @@ class ArrayChunkNode(BatchNode):
         super().__init__()
         self.chunk_size = chunk_size

-    def preprocess(self, shared_storage):
+    def prep(self, shared_storage):
         # Get array from shared storage and split into chunks
         array = shared_storage.get('input_array', [])
         chunks = []
@@ -19,20 +19,20 @@ class ArrayChunkNode(BatchNode):
             chunks.append((i, end))
         return chunks

-    def process(self, shared_storage, chunk_indices):
+    def exec(self, shared_storage, chunk_indices):
         start, end = chunk_indices
         array = shared_storage['input_array']
         # Process the chunk and return its sum
         chunk_sum = sum(array[start:end])
         return chunk_sum

-    def postprocess(self, shared_storage, prep_result, proc_result):
+    def post(self, shared_storage, prep_result, proc_result):
         # Store chunk results in shared storage
         shared_storage['chunk_results'] = proc_result
         return "default"

 class SumReduceNode(Node):
-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         # Get chunk results from shared storage and sum them
         chunk_results = shared_storage.get('chunk_results', [])
         total = sum(chunk_results)
@@ -48,7 +48,7 @@ class TestBatchNode(unittest.TestCase):
         }

         chunk_node = ArrayChunkNode(chunk_size=10)
-        chunks = chunk_node.preprocess(shared_storage)
+        chunks = chunk_node.prep(shared_storage)

         self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])
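Chained together, the two nodes above form a small map-reduce. A sketch (hypothetical array; assumes Flow is imported alongside BatchNode and Node):

chunk_node = ArrayChunkNode(chunk_size=10)
chunk_node >> SumReduceNode()  # reduce runs after all chunks are processed
shared = {'input_array': list(range(25))}
Flow(chunk_node).run(shared)
# shared['chunk_results'] holds the per-chunk sums; SumReduceNode totals them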
@@ -10,7 +10,7 @@ class NumberNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] = self.number

 class AddNode(Node):
@@ -18,7 +18,7 @@ class AddNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] += self.number

 class MultiplyNode(Node):
@@ -26,18 +26,18 @@ class MultiplyNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         shared_storage['current'] *= self.number

 class CheckPositiveNode(Node):
-    def postprocess(self, shared_storage, prep_result, proc_result):
+    def post(self, shared_storage, prep_result, proc_result):
         if shared_storage['current'] >= 0:
             return 'positive'
         else:
             return 'negative'

 class NoOpNode(Node):
-    def process(self, shared_storage, data):
+    def exec(self, shared_storage, data):
         # Do nothing, just pass
         pass
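CheckPositiveNode's post() return value picks the branch at run time. A sketch of one possible wiring (hypothetical pipeline; assumes Flow is imported as in the other tests):

start = NumberNode(5)
check = CheckPositiveNode()
start >> AddNode(-10) >> check
check - 'positive' >> NoOpNode()
check - 'negative' >> MultiplyNode(-1)
shared = {}
Flow(start).run(shared)
# 5 + (-10) = -5 takes the 'negative' branch; MultiplyNode(-1) makes current == 5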
@@ -12,7 +12,7 @@ class NumberNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         shared_storage['current'] = self.number

 class AddNode(Node):
@@ -20,7 +20,7 @@ class AddNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         shared_storage['current'] += self.number

 class MultiplyNode(Node):
@@ -28,7 +28,7 @@ class MultiplyNode(Node):
         super().__init__()
         self.number = number

-    def process(self, shared_storage, prep_result):
+    def exec(self, shared_storage, prep_result):
         shared_storage['current'] *= self.number

 class TestFlowComposition(unittest.TestCase):
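Because a Flow is itself a node, flows compose by wiring one flow as a step in another. A sketch of the composition these tests exercise (hypothetical chaining using the nodes above):

inner_start = NumberNode(5)
inner_start >> AddNode(3)
inner = Flow(inner_start)   # inner flow computes 5 + 3
inner >> MultiplyNode(4)    # then the outer flow multiplies by 4
outer = Flow(inner)
shared = {}
outer.run(shared)
# shared['current'] == (5 + 3) * 4 == 32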