fix trsts

This commit is contained in:
zachary62 2024-12-26 03:18:00 +00:00
parent 812c041bb4
commit 4b9b357608
7 changed files with 105 additions and 101 deletions

View File

@ -6,95 +6,99 @@ class BaseNode:
def add_successor(self, node, cond="default"): def add_successor(self, node, cond="default"):
if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'") if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
self.successors[cond] = node; return node self.successors[cond] = node; return node
def preprocess(self, s): return None def prep(self, shared): return None
def process(self, s, p): return None def exec(self, shared, prep_res): return None
def _process(self, s, p): return self.process(s, p) def _exec(self, shared, prep_res): return self.exec(shared, prep_res)
def postprocess(self, s, pr, r): return "default" def post(self, shared, prep_res, exec_res): return "default"
def _run(self, s): def _run(self, shared):
pr = self.preprocess(s) prep_res = self.prep(shared)
r = self._process(s, pr) exec_res = self._exec(shared, prep_res)
return self.postprocess(s, pr, r) return self.post(shared, prep_res, exec_res)
def run(self, s): def run(self, shared):
if self.successors: warnings.warn("Has successors; use Flow.run() instead.") if self.successors:
return self._run(s) warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.")
return self._run(shared)
def __rshift__(self, other): return self.add_successor(other) def __rshift__(self, other): return self.add_successor(other)
def __sub__(self, cond): def __sub__(self, cond):
if isinstance(cond, str): return _ConditionalTransition(self, cond) if isinstance(cond, str): return _ConditionalTransition(self, cond)
raise TypeError("Condition must be a string") raise TypeError("Condition must be a string")
class _ConditionalTransition: class _ConditionalTransition:
def __init__(self, src, c): self.src, self.c = src, c def __init__(self, src, cond): self.src, self.cond = src, cond
def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c) def __rshift__(self, tgt): return self.src.add_successor(tgt, self.cond)
class Node(BaseNode): class Node(BaseNode):
def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
def process_after_fail(self, s, d, e): raise e def process_after_fail(self, shared, prep_res, exc): raise exc
def _process(self, s, d): def _exec(self, shared, prep_res):
for i in range(self.max_retries): for i in range(self.max_retries):
try: return super()._process(s, d) try: return super()._exec(shared, prep_res)
except Exception as e: except Exception as e:
if i == self.max_retries - 1: return self.process_after_fail(s, d, e) if i == self.max_retries - 1: return self.process_after_fail(shared, prep_res, e)
class BatchNode(Node): class BatchNode(Node):
def preprocess(self, s): return [] def prep(self, shared): return []
def process(self, s, item): return None def exec(self, shared, item): return None
def _process(self, s, items): return [super(Node, self)._process(s, i) for i in items] def _exec(self, shared, items): return [super(Node, self)._exec(shared, i) for i in items]
class BaseFlow(BaseNode): class BaseFlow(BaseNode):
def __init__(self, start_node): super().__init__(); self.start_node = start_node def __init__(self, start_node): super().__init__(); self.start_node = start_node
def get_next_node(self, curr, c): def get_next_node(self, curr, cond):
nxt = curr.successors.get(c) nxt = curr.successors.get(cond)
if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}") if nxt is None and curr.successors:
warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
return nxt return nxt
class Flow(BaseFlow): class Flow(BaseFlow):
def _process(self, s, p=None): def _exec(self, shared, params=None):
curr, p = self.start_node, (p if p is not None else self.params.copy()) curr = self.start_node
p = params if params else self.params.copy()
while curr: while curr:
curr.set_params(p) curr.set_params(p)
c = curr._run(s) c = curr._run(shared)
curr = self.get_next_node(curr, c) curr = self.get_next_node(curr, c)
def process(self, s, pr): raise NotImplementedError("Use Flow._process(...) instead") def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")
class BaseBatchFlow(BaseFlow): class BaseBatchFlow(BaseFlow):
def preprocess(self, s): return [] def prep(self, shared): return []
class BatchFlow(BaseBatchFlow, Flow): class BatchFlow(BaseBatchFlow, Flow):
def _run(self, s): def _run(self, shared):
pr = self.preprocess(s) prep_res = self.prep(shared)
for d in pr: for d in prep_res:
mp = self.params.copy(); mp.update(d) mp = self.params.copy(); mp.update(d)
self._process(s, mp) self._exec(shared, mp)
return self.postprocess(s, pr, None) return self.post(shared, prep_res, None)
class AsyncNode(Node): class AsyncNode(Node):
def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async") def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async")
async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default" async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
async def run_async(self, s): async def run_async(self, shared):
if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.") if self.successors:
return await self._run_async(s) warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.")
async def _run_async(self, s): return await self._run_async(shared)
pr = self.preprocess(s) async def _run_async(self, shared):
r = self._process(s, pr) prep_res = self.prep(shared)
return await self.postprocess_async(s, pr, r) exec_res = self._exec(shared, prep_res)
def _run(self, s): raise RuntimeError("AsyncNode requires async execution") return await self.post_async(shared, prep_res, exec_res)
def _run(self, shared): raise RuntimeError("AsyncNode requires async execution")
class AsyncFlow(BaseFlow, AsyncNode): class AsyncFlow(BaseFlow, AsyncNode):
async def _process_async(self, s, p=None): async def _exec_async(self, shared, params=None):
curr, p = self.start_node, (p if p else self.params.copy()) curr, p = self.start_node, (params if params else self.params.copy())
while curr: while curr:
curr.set_params(p) curr.set_params(p)
c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s) c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
curr = self.get_next_node(curr, c) curr = self.get_next_node(curr, c)
async def _run_async(self, s): async def _run_async(self, shared):
pr = self.preprocess(s) prep_res = self.prep(shared)
await self._process_async(s) await self._exec_async(shared)
return await self.postprocess_async(s, pr, None) return await self.post_async(shared, prep_res, None)
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow): class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, s): async def _run_async(self, shared):
pr = self.preprocess(s) prep_res = self.prep(shared)
for d in pr: for d in prep_res:
mp = self.params.copy(); mp.update(d) mp = self.params.copy(); mp.update(d)
await self._process_async(s, mp) await self._exec_async(shared, mp)
return await self.postprocess_async(s, pr, None) return await self.post_async(shared, prep_res, None)

View File

@ -7,7 +7,7 @@ sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import AsyncNode, BatchAsyncFlow from minillmflow import AsyncNode, BatchAsyncFlow
class AsyncDataProcessNode(AsyncNode): class AsyncDataProcessNode(AsyncNode):
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
key = self.params.get('key') key = self.params.get('key')
data = shared_storage['input_data'][key] data = shared_storage['input_data'][key]
if 'results' not in shared_storage: if 'results' not in shared_storage:
@ -15,14 +15,14 @@ class AsyncDataProcessNode(AsyncNode):
shared_storage['results'][key] = data shared_storage['results'][key] = data
return data return data
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0.01) # Simulate async work await asyncio.sleep(0.01) # Simulate async work
key = self.params.get('key') key = self.params.get('key')
shared_storage['results'][key] = proc_result * 2 # Double the value shared_storage['results'][key] = proc_result * 2 # Double the value
return "processed" return "processed"
class AsyncErrorNode(AsyncNode): class AsyncErrorNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key') key = self.params.get('key')
if key == 'error_key': if key == 'error_key':
raise ValueError(f"Async error processing key: {key}") raise ValueError(f"Async error processing key: {key}")
@ -35,7 +35,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_basic_async_batch_processing(self): def test_basic_async_batch_processing(self):
"""Test basic async batch processing with multiple keys""" """Test basic async batch processing with multiple keys"""
class SimpleTestAsyncBatchFlow(BatchAsyncFlow): class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -59,7 +59,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_empty_async_batch(self): def test_empty_async_batch(self):
"""Test async batch processing with empty input""" """Test async batch processing with empty input"""
class EmptyTestAsyncBatchFlow(BatchAsyncFlow): class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -74,7 +74,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_async_error_handling(self): def test_async_error_handling(self):
"""Test error handling during async batch processing""" """Test error handling during async batch processing"""
class ErrorTestAsyncBatchFlow(BatchAsyncFlow): class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -93,7 +93,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_nested_async_flow(self): def test_nested_async_flow(self):
"""Test async batch processing with nested flows""" """Test async batch processing with nested flows"""
class AsyncInnerNode(AsyncNode): class AsyncInnerNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key') key = self.params.get('key')
if 'intermediate_results' not in shared_storage: if 'intermediate_results' not in shared_storage:
shared_storage['intermediate_results'] = {} shared_storage['intermediate_results'] = {}
@ -102,7 +102,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
return "next" return "next"
class AsyncOuterNode(AsyncNode): class AsyncOuterNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key') key = self.params.get('key')
if 'results' not in shared_storage: if 'results' not in shared_storage:
shared_storage['results'] = {} shared_storage['results'] = {}
@ -111,7 +111,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
return "done" return "done"
class NestedAsyncBatchFlow(BatchAsyncFlow): class NestedAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
# Create inner flow # Create inner flow
@ -138,7 +138,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_custom_async_parameters(self): def test_custom_async_parameters(self):
"""Test async batch processing with additional custom parameters""" """Test async batch processing with additional custom parameters"""
class CustomParamAsyncNode(AsyncNode): class CustomParamAsyncNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key') key = self.params.get('key')
multiplier = self.params.get('multiplier', 1) multiplier = self.params.get('multiplier', 1)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
@ -148,7 +148,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
return "done" return "done"
class CustomParamAsyncBatchFlow(BatchAsyncFlow): class CustomParamAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{ return [{
'key': k, 'key': k,
'multiplier': i + 1 'multiplier': i + 1

View File

@ -11,19 +11,19 @@ class AsyncNumberNode(AsyncNode):
""" """
Simple async node that sets 'current' to a given number. Simple async node that sets 'current' to a given number.
Demonstrates overriding .process() (sync) and using Demonstrates overriding .process() (sync) and using
postprocess_async() for the async portion. post_async() for the async portion.
""" """
def __init__(self, number): def __init__(self, number):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, data): def exec(self, shared_storage, data):
# Synchronous work is allowed inside an AsyncNode, # Synchronous work is allowed inside an AsyncNode,
# but final 'condition' is determined by postprocess_async(). # but final 'condition' is determined by post_async().
shared_storage['current'] = self.number shared_storage['current'] = self.number
return "set_number" return "set_number"
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
# Possibly do asynchronous tasks here # Possibly do asynchronous tasks here
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# Return a condition for the flow # Return a condition for the flow
@ -34,11 +34,11 @@ class AsyncIncrementNode(AsyncNode):
""" """
Demonstrates incrementing the 'current' value asynchronously. Demonstrates incrementing the 'current' value asynchronously.
""" """
def process(self, shared_storage, data): def exec(self, shared_storage, data):
shared_storage['current'] = shared_storage.get('current', 0) + 1 shared_storage['current'] = shared_storage.get('current', 0) + 1
return "incremented" return "incremented"
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0.01) # simulate async I/O await asyncio.sleep(0.01) # simulate async I/O
return "done" return "done"
@ -105,18 +105,18 @@ class TestAsyncFlow(unittest.TestCase):
""" """
Demonstrate a branching scenario where we return different Demonstrate a branching scenario where we return different
conditions. For example, you could have an async node that conditions. For example, you could have an async node that
returns "go_left" or "go_right" in postprocess_async, but here returns "go_left" or "go_right" in post_async, but here
we'll keep it simpler for demonstration. we'll keep it simpler for demonstration.
""" """
class BranchingAsyncNode(AsyncNode): class BranchingAsyncNode(AsyncNode):
def process(self, shared_storage, data): def exec(self, shared_storage, data):
value = shared_storage.get("value", 0) value = shared_storage.get("value", 0)
shared_storage["value"] = value shared_storage["value"] = value
# We'll decide branch based on whether 'value' is positive # We'll decide branch based on whether 'value' is positive
return None return None
async def postprocess_async(self, shared_storage, prep_result, proc_result): async def post_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
if shared_storage["value"] >= 0: if shared_storage["value"] >= 0:
return "positive_branch" return "positive_branch"
@ -124,12 +124,12 @@ class TestAsyncFlow(unittest.TestCase):
return "negative_branch" return "negative_branch"
class PositiveNode(Node): class PositiveNode(Node):
def process(self, shared_storage, data): def exec(self, shared_storage, data):
shared_storage["path"] = "positive" shared_storage["path"] = "positive"
return None return None
class NegativeNode(Node): class NegativeNode(Node):
def process(self, shared_storage, data): def exec(self, shared_storage, data):
shared_storage["path"] = "negative" shared_storage["path"] = "negative"
return None return None

View File

@ -6,7 +6,7 @@ sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import Node, BatchFlow, Flow from minillmflow import Node, BatchFlow, Flow
class DataProcessNode(Node): class DataProcessNode(Node):
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
key = self.params.get('key') key = self.params.get('key')
data = shared_storage['input_data'][key] data = shared_storage['input_data'][key]
if 'results' not in shared_storage: if 'results' not in shared_storage:
@ -14,7 +14,7 @@ class DataProcessNode(Node):
shared_storage['results'][key] = data * 2 shared_storage['results'][key] = data * 2
class ErrorProcessNode(Node): class ErrorProcessNode(Node):
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
key = self.params.get('key') key = self.params.get('key')
if key == 'error_key': if key == 'error_key':
raise ValueError(f"Error processing key: {key}") raise ValueError(f"Error processing key: {key}")
@ -29,7 +29,7 @@ class TestBatchFlow(unittest.TestCase):
def test_basic_batch_processing(self): def test_basic_batch_processing(self):
"""Test basic batch processing with multiple keys""" """Test basic batch processing with multiple keys"""
class SimpleTestBatchFlow(BatchFlow): class SimpleTestBatchFlow(BatchFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -53,7 +53,7 @@ class TestBatchFlow(unittest.TestCase):
def test_empty_input(self): def test_empty_input(self):
"""Test batch processing with empty input dictionary""" """Test batch processing with empty input dictionary"""
class EmptyTestBatchFlow(BatchFlow): class EmptyTestBatchFlow(BatchFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -68,7 +68,7 @@ class TestBatchFlow(unittest.TestCase):
def test_single_item(self): def test_single_item(self):
"""Test batch processing with single item""" """Test batch processing with single item"""
class SingleItemBatchFlow(BatchFlow): class SingleItemBatchFlow(BatchFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -88,7 +88,7 @@ class TestBatchFlow(unittest.TestCase):
def test_error_handling(self): def test_error_handling(self):
"""Test error handling during batch processing""" """Test error handling during batch processing"""
class ErrorTestBatchFlow(BatchFlow): class ErrorTestBatchFlow(BatchFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = { shared_storage = {
@ -107,21 +107,21 @@ class TestBatchFlow(unittest.TestCase):
def test_nested_flow(self): def test_nested_flow(self):
"""Test batch processing with nested flows""" """Test batch processing with nested flows"""
class InnerNode(Node): class InnerNode(Node):
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
key = self.params.get('key') key = self.params.get('key')
if 'intermediate_results' not in shared_storage: if 'intermediate_results' not in shared_storage:
shared_storage['intermediate_results'] = {} shared_storage['intermediate_results'] = {}
shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1 shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
class OuterNode(Node): class OuterNode(Node):
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
key = self.params.get('key') key = self.params.get('key')
if 'results' not in shared_storage: if 'results' not in shared_storage:
shared_storage['results'] = {} shared_storage['results'] = {}
shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2 shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
class NestedBatchFlow(BatchFlow): class NestedBatchFlow(BatchFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()] return [{'key': k} for k in shared_storage['input_data'].keys()]
# Create inner flow # Create inner flow
@ -148,7 +148,7 @@ class TestBatchFlow(unittest.TestCase):
def test_custom_parameters(self): def test_custom_parameters(self):
"""Test batch processing with additional custom parameters""" """Test batch processing with additional custom parameters"""
class CustomParamNode(Node): class CustomParamNode(Node):
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
key = self.params.get('key') key = self.params.get('key')
multiplier = self.params.get('multiplier', 1) multiplier = self.params.get('multiplier', 1)
if 'results' not in shared_storage: if 'results' not in shared_storage:
@ -156,7 +156,7 @@ class TestBatchFlow(unittest.TestCase):
shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
class CustomParamBatchFlow(BatchFlow): class CustomParamBatchFlow(BatchFlow):
def preprocess(self, shared_storage): def prep(self, shared_storage):
return [{ return [{
'key': k, 'key': k,
'multiplier': i + 1 'multiplier': i + 1

View File

@ -10,7 +10,7 @@ class ArrayChunkNode(BatchNode):
super().__init__() super().__init__()
self.chunk_size = chunk_size self.chunk_size = chunk_size
def preprocess(self, shared_storage): def prep(self, shared_storage):
# Get array from shared storage and split into chunks # Get array from shared storage and split into chunks
array = shared_storage.get('input_array', []) array = shared_storage.get('input_array', [])
chunks = [] chunks = []
@ -19,20 +19,20 @@ class ArrayChunkNode(BatchNode):
chunks.append((i, end)) chunks.append((i, end))
return chunks return chunks
def process(self, shared_storage, chunk_indices): def exec(self, shared_storage, chunk_indices):
start, end = chunk_indices start, end = chunk_indices
array = shared_storage['input_array'] array = shared_storage['input_array']
# Process the chunk and return its sum # Process the chunk and return its sum
chunk_sum = sum(array[start:end]) chunk_sum = sum(array[start:end])
return chunk_sum return chunk_sum
def postprocess(self, shared_storage, prep_result, proc_result): def post(self, shared_storage, prep_result, proc_result):
# Store chunk results in shared storage # Store chunk results in shared storage
shared_storage['chunk_results'] = proc_result shared_storage['chunk_results'] = proc_result
return "default" return "default"
class SumReduceNode(Node): class SumReduceNode(Node):
def process(self, shared_storage, data): def exec(self, shared_storage, data):
# Get chunk results from shared storage and sum them # Get chunk results from shared storage and sum them
chunk_results = shared_storage.get('chunk_results', []) chunk_results = shared_storage.get('chunk_results', [])
total = sum(chunk_results) total = sum(chunk_results)
@ -48,7 +48,7 @@ class TestBatchNode(unittest.TestCase):
} }
chunk_node = ArrayChunkNode(chunk_size=10) chunk_node = ArrayChunkNode(chunk_size=10)
chunks = chunk_node.preprocess(shared_storage) chunks = chunk_node.prep(shared_storage)
self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)]) self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])

View File

@ -10,7 +10,7 @@ class NumberNode(Node):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, data): def exec(self, shared_storage, data):
shared_storage['current'] = self.number shared_storage['current'] = self.number
class AddNode(Node): class AddNode(Node):
@ -18,7 +18,7 @@ class AddNode(Node):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, data): def exec(self, shared_storage, data):
shared_storage['current'] += self.number shared_storage['current'] += self.number
class MultiplyNode(Node): class MultiplyNode(Node):
@ -26,18 +26,18 @@ class MultiplyNode(Node):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, data): def exec(self, shared_storage, data):
shared_storage['current'] *= self.number shared_storage['current'] *= self.number
class CheckPositiveNode(Node): class CheckPositiveNode(Node):
def postprocess(self, shared_storage, prep_result, proc_result): def post(self, shared_storage, prep_result, proc_result):
if shared_storage['current'] >= 0: if shared_storage['current'] >= 0:
return 'positive' return 'positive'
else: else:
return 'negative' return 'negative'
class NoOpNode(Node): class NoOpNode(Node):
def process(self, shared_storage, data): def exec(self, shared_storage, data):
# Do nothing, just pass # Do nothing, just pass
pass pass

View File

@ -12,7 +12,7 @@ class NumberNode(Node):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
shared_storage['current'] = self.number shared_storage['current'] = self.number
class AddNode(Node): class AddNode(Node):
@ -20,7 +20,7 @@ class AddNode(Node):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
shared_storage['current'] += self.number shared_storage['current'] += self.number
class MultiplyNode(Node): class MultiplyNode(Node):
@ -28,7 +28,7 @@ class MultiplyNode(Node):
super().__init__() super().__init__()
self.number = number self.number = number
def process(self, shared_storage, prep_result): def exec(self, shared_storage, prep_result):
shared_storage['current'] *= self.number shared_storage['current'] *= self.number
class TestFlowComposition(unittest.TestCase): class TestFlowComposition(unittest.TestCase):