fix tests

This commit is contained in:
zachary62 2024-12-26 03:18:00 +00:00
parent 812c041bb4
commit 4b9b357608
7 changed files with 105 additions and 101 deletions

View File

@ -6,95 +6,99 @@ class BaseNode:
def add_successor(self, node, cond="default"):
if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
self.successors[cond] = node; return node
def preprocess(self, s): return None
def process(self, s, p): return None
def _process(self, s, p): return self.process(s, p)
def postprocess(self, s, pr, r): return "default"
def _run(self, s):
pr = self.preprocess(s)
r = self._process(s, pr)
return self.postprocess(s, pr, r)
def run(self, s):
if self.successors: warnings.warn("Has successors; use Flow.run() instead.")
return self._run(s)
def prep(self, shared): return None
def exec(self, shared, prep_res): return None
def _exec(self, shared, prep_res): return self.exec(shared, prep_res)
def post(self, shared, prep_res, exec_res): return "default"
def _run(self, shared):
prep_res = self.prep(shared)
exec_res = self._exec(shared, prep_res)
return self.post(shared, prep_res, exec_res)
def run(self, shared):
if self.successors:
warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.")
return self._run(shared)
def __rshift__(self, other): return self.add_successor(other)
def __sub__(self, cond):
if isinstance(cond, str): return _ConditionalTransition(self, cond)
raise TypeError("Condition must be a string")
class _ConditionalTransition:
def __init__(self, src, c): self.src, self.c = src, c
def __rshift__(self, tgt): return self.src.add_successor(tgt, self.c)
def __init__(self, src, cond): self.src, self.cond = src, cond
def __rshift__(self, tgt): return self.src.add_successor(tgt, self.cond)
class Node(BaseNode):
def __init__(self, max_retries=1): super().__init__(); self.max_retries = max_retries
def process_after_fail(self, s, d, e): raise e
def _process(self, s, d):
def process_after_fail(self, shared, prep_res, exc): raise exc
def _exec(self, shared, prep_res):
for i in range(self.max_retries):
try: return super()._process(s, d)
try: return super()._exec(shared, prep_res)
except Exception as e:
if i == self.max_retries - 1: return self.process_after_fail(s, d, e)
if i == self.max_retries - 1: return self.process_after_fail(shared, prep_res, e)
class BatchNode(Node):
def preprocess(self, s): return []
def process(self, s, item): return None
def _process(self, s, items): return [super(Node, self)._process(s, i) for i in items]
def prep(self, shared): return []
def exec(self, shared, item): return None
def _exec(self, shared, items): return [super(Node, self)._exec(shared, i) for i in items]
class BaseFlow(BaseNode):
def __init__(self, start_node): super().__init__(); self.start_node = start_node
def get_next_node(self, curr, c):
nxt = curr.successors.get(c)
if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{c}' not found in {list(curr.successors.keys())}")
def get_next_node(self, curr, cond):
nxt = curr.successors.get(cond)
if nxt is None and curr.successors:
warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
return nxt
class Flow(BaseFlow):
def _process(self, s, p=None):
curr, p = self.start_node, (p if p is not None else self.params.copy())
def _exec(self, shared, params=None):
curr = self.start_node
p = params if params else self.params.copy()
while curr:
curr.set_params(p)
c = curr._run(s)
c = curr._run(shared)
curr = self.get_next_node(curr, c)
def process(self, s, pr): raise NotImplementedError("Use Flow._process(...) instead")
def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")
class BaseBatchFlow(BaseFlow):
def preprocess(self, s): return []
def prep(self, shared): return []
class BatchFlow(BaseBatchFlow, Flow):
def _run(self, s):
pr = self.preprocess(s)
for d in pr:
def _run(self, shared):
prep_res = self.prep(shared)
for d in prep_res:
mp = self.params.copy(); mp.update(d)
self._process(s, mp)
return self.postprocess(s, pr, None)
self._exec(shared, mp)
return self.post(shared, prep_res, None)
class AsyncNode(Node):
def postprocess(self, s, pr, r): raise NotImplementedError("Use postprocess_async")
async def postprocess_async(self, s, pr, r): await asyncio.sleep(0); return "default"
async def run_async(self, s):
if self.successors: warnings.warn("Has successors; use AsyncFlow.run_async() instead.")
return await self._run_async(s)
async def _run_async(self, s):
pr = self.preprocess(s)
r = self._process(s, pr)
return await self.postprocess_async(s, pr, r)
def _run(self, s): raise RuntimeError("AsyncNode requires async execution")
def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async")
async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
async def run_async(self, shared):
if self.successors:
warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.")
return await self._run_async(shared)
async def _run_async(self, shared):
prep_res = self.prep(shared)
exec_res = self._exec(shared, prep_res)
return await self.post_async(shared, prep_res, exec_res)
def _run(self, shared): raise RuntimeError("AsyncNode requires async execution")
class AsyncFlow(BaseFlow, AsyncNode):
async def _process_async(self, s, p=None):
curr, p = self.start_node, (p if p else self.params.copy())
async def _exec_async(self, shared, params=None):
curr, p = self.start_node, (params if params else self.params.copy())
while curr:
curr.set_params(p)
c = await curr._run_async(s) if hasattr(curr, "run_async") else curr._run(s)
c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
curr = self.get_next_node(curr, c)
async def _run_async(self, s):
pr = self.preprocess(s)
await self._process_async(s)
return await self.postprocess_async(s, pr, None)
async def _run_async(self, shared):
prep_res = self.prep(shared)
await self._exec_async(shared)
return await self.post_async(shared, prep_res, None)
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, s):
pr = self.preprocess(s)
for d in pr:
async def _run_async(self, shared):
prep_res = self.prep(shared)
for d in prep_res:
mp = self.params.copy(); mp.update(d)
await self._process_async(s, mp)
return await self.postprocess_async(s, pr, None)
await self._exec_async(shared, mp)
return await self.post_async(shared, prep_res, None)

View File

@ -7,7 +7,7 @@ sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import AsyncNode, BatchAsyncFlow
class AsyncDataProcessNode(AsyncNode):
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
key = self.params.get('key')
data = shared_storage['input_data'][key]
if 'results' not in shared_storage:
@ -15,14 +15,14 @@ class AsyncDataProcessNode(AsyncNode):
shared_storage['results'][key] = data
return data
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0.01) # Simulate async work
key = self.params.get('key')
shared_storage['results'][key] = proc_result * 2 # Double the value
return "processed"
class AsyncErrorNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key')
if key == 'error_key':
raise ValueError(f"Async error processing key: {key}")
@ -35,7 +35,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_basic_async_batch_processing(self):
"""Test basic async batch processing with multiple keys"""
class SimpleTestAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -59,7 +59,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_empty_async_batch(self):
"""Test async batch processing with empty input"""
class EmptyTestAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -74,7 +74,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_async_error_handling(self):
"""Test error handling during async batch processing"""
class ErrorTestAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -93,7 +93,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_nested_async_flow(self):
"""Test async batch processing with nested flows"""
class AsyncInnerNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key')
if 'intermediate_results' not in shared_storage:
shared_storage['intermediate_results'] = {}
@ -102,7 +102,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
return "next"
class AsyncOuterNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key')
if 'results' not in shared_storage:
shared_storage['results'] = {}
@ -111,7 +111,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
return "done"
class NestedAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
# Create inner flow
@ -138,7 +138,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
def test_custom_async_parameters(self):
"""Test async batch processing with additional custom parameters"""
class CustomParamAsyncNode(AsyncNode):
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
key = self.params.get('key')
multiplier = self.params.get('multiplier', 1)
await asyncio.sleep(0.01)
@ -148,7 +148,7 @@ class TestAsyncBatchFlow(unittest.TestCase):
return "done"
class CustomParamAsyncBatchFlow(BatchAsyncFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{
'key': k,
'multiplier': i + 1

View File

@ -11,19 +11,19 @@ class AsyncNumberNode(AsyncNode):
"""
Simple async node that sets 'current' to a given number.
Demonstrates overriding .process() (sync) and using
postprocess_async() for the async portion.
post_async() for the async portion.
"""
def __init__(self, number):
super().__init__()
self.number = number
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
# Synchronous work is allowed inside an AsyncNode,
# but final 'condition' is determined by postprocess_async().
# but final 'condition' is determined by post_async().
shared_storage['current'] = self.number
return "set_number"
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
# Possibly do asynchronous tasks here
await asyncio.sleep(0.01)
# Return a condition for the flow
@ -34,11 +34,11 @@ class AsyncIncrementNode(AsyncNode):
"""
Demonstrates incrementing the 'current' value asynchronously.
"""
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
shared_storage['current'] = shared_storage.get('current', 0) + 1
return "incremented"
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0.01) # simulate async I/O
return "done"
@ -105,18 +105,18 @@ class TestAsyncFlow(unittest.TestCase):
"""
Demonstrate a branching scenario where we return different
conditions. For example, you could have an async node that
returns "go_left" or "go_right" in postprocess_async, but here
returns "go_left" or "go_right" in post_async, but here
we'll keep it simpler for demonstration.
"""
class BranchingAsyncNode(AsyncNode):
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
value = shared_storage.get("value", 0)
shared_storage["value"] = value
# We'll decide branch based on whether 'value' is positive
return None
async def postprocess_async(self, shared_storage, prep_result, proc_result):
async def post_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0.01)
if shared_storage["value"] >= 0:
return "positive_branch"
@ -124,12 +124,12 @@ class TestAsyncFlow(unittest.TestCase):
return "negative_branch"
class PositiveNode(Node):
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
shared_storage["path"] = "positive"
return None
class NegativeNode(Node):
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
shared_storage["path"] = "negative"
return None

View File

@ -6,7 +6,7 @@ sys.path.append(str(Path(__file__).parent.parent))
from minillmflow import Node, BatchFlow, Flow
class DataProcessNode(Node):
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
key = self.params.get('key')
data = shared_storage['input_data'][key]
if 'results' not in shared_storage:
@ -14,7 +14,7 @@ class DataProcessNode(Node):
shared_storage['results'][key] = data * 2
class ErrorProcessNode(Node):
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
key = self.params.get('key')
if key == 'error_key':
raise ValueError(f"Error processing key: {key}")
@ -29,7 +29,7 @@ class TestBatchFlow(unittest.TestCase):
def test_basic_batch_processing(self):
"""Test basic batch processing with multiple keys"""
class SimpleTestBatchFlow(BatchFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -53,7 +53,7 @@ class TestBatchFlow(unittest.TestCase):
def test_empty_input(self):
"""Test batch processing with empty input dictionary"""
class EmptyTestBatchFlow(BatchFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -68,7 +68,7 @@ class TestBatchFlow(unittest.TestCase):
def test_single_item(self):
"""Test batch processing with single item"""
class SingleItemBatchFlow(BatchFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -88,7 +88,7 @@ class TestBatchFlow(unittest.TestCase):
def test_error_handling(self):
"""Test error handling during batch processing"""
class ErrorTestBatchFlow(BatchFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
shared_storage = {
@ -107,21 +107,21 @@ class TestBatchFlow(unittest.TestCase):
def test_nested_flow(self):
"""Test batch processing with nested flows"""
class InnerNode(Node):
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
key = self.params.get('key')
if 'intermediate_results' not in shared_storage:
shared_storage['intermediate_results'] = {}
shared_storage['intermediate_results'][key] = shared_storage['input_data'][key] + 1
class OuterNode(Node):
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
key = self.params.get('key')
if 'results' not in shared_storage:
shared_storage['results'] = {}
shared_storage['results'][key] = shared_storage['intermediate_results'][key] * 2
class NestedBatchFlow(BatchFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{'key': k} for k in shared_storage['input_data'].keys()]
# Create inner flow
@ -148,7 +148,7 @@ class TestBatchFlow(unittest.TestCase):
def test_custom_parameters(self):
"""Test batch processing with additional custom parameters"""
class CustomParamNode(Node):
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
key = self.params.get('key')
multiplier = self.params.get('multiplier', 1)
if 'results' not in shared_storage:
@ -156,7 +156,7 @@ class TestBatchFlow(unittest.TestCase):
shared_storage['results'][key] = shared_storage['input_data'][key] * multiplier
class CustomParamBatchFlow(BatchFlow):
def preprocess(self, shared_storage):
def prep(self, shared_storage):
return [{
'key': k,
'multiplier': i + 1

View File

@ -10,7 +10,7 @@ class ArrayChunkNode(BatchNode):
super().__init__()
self.chunk_size = chunk_size
def preprocess(self, shared_storage):
def prep(self, shared_storage):
# Get array from shared storage and split into chunks
array = shared_storage.get('input_array', [])
chunks = []
@ -19,20 +19,20 @@ class ArrayChunkNode(BatchNode):
chunks.append((i, end))
return chunks
def process(self, shared_storage, chunk_indices):
def exec(self, shared_storage, chunk_indices):
start, end = chunk_indices
array = shared_storage['input_array']
# Process the chunk and return its sum
chunk_sum = sum(array[start:end])
return chunk_sum
def postprocess(self, shared_storage, prep_result, proc_result):
def post(self, shared_storage, prep_result, proc_result):
# Store chunk results in shared storage
shared_storage['chunk_results'] = proc_result
return "default"
class SumReduceNode(Node):
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
# Get chunk results from shared storage and sum them
chunk_results = shared_storage.get('chunk_results', [])
total = sum(chunk_results)
@ -48,7 +48,7 @@ class TestBatchNode(unittest.TestCase):
}
chunk_node = ArrayChunkNode(chunk_size=10)
chunks = chunk_node.preprocess(shared_storage)
chunks = chunk_node.prep(shared_storage)
self.assertEqual(chunks, [(0, 10), (10, 20), (20, 25)])

View File

@ -10,7 +10,7 @@ class NumberNode(Node):
super().__init__()
self.number = number
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
shared_storage['current'] = self.number
class AddNode(Node):
@ -18,7 +18,7 @@ class AddNode(Node):
super().__init__()
self.number = number
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
shared_storage['current'] += self.number
class MultiplyNode(Node):
@ -26,18 +26,18 @@ class MultiplyNode(Node):
super().__init__()
self.number = number
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
shared_storage['current'] *= self.number
class CheckPositiveNode(Node):
def postprocess(self, shared_storage, prep_result, proc_result):
def post(self, shared_storage, prep_result, proc_result):
if shared_storage['current'] >= 0:
return 'positive'
else:
return 'negative'
class NoOpNode(Node):
def process(self, shared_storage, data):
def exec(self, shared_storage, data):
# Do nothing, just pass
pass

View File

@ -12,7 +12,7 @@ class NumberNode(Node):
super().__init__()
self.number = number
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
shared_storage['current'] = self.number
class AddNode(Node):
@ -20,7 +20,7 @@ class AddNode(Node):
super().__init__()
self.number = number
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
shared_storage['current'] += self.number
class MultiplyNode(Node):
@ -28,7 +28,7 @@ class MultiplyNode(Node):
super().__init__()
self.number = number
def process(self, shared_storage, prep_result):
def exec(self, shared_storage, prep_result):
shared_storage['current'] *= self.number
class TestFlowComposition(unittest.TestCase):