debug empty shared memory

zachary62 2024-12-25 21:54:34 +00:00
parent 83dbc13054
commit c4c78b1939
2 changed files with 151 additions and 106 deletions


@@ -1,4 +1,5 @@
 import asyncio
+import warnings
 
 class BaseNode:
     """
@@ -62,10 +63,19 @@ class BaseNode:
         """
         return _ConditionalTransition(self, condition)
 
+    def __sub__(self, condition):
+        """
+        For chaining with - operator, e.g. node - "some_condition" >> next_node
+        """
+        if isinstance(condition, str):
+            return _ConditionalTransition(self, condition)
+        raise TypeError("Condition must be a string")
+
 class _ConditionalTransition:
     """
     Helper for Node > 'condition' >> AnotherNode style
+    (and also Node - 'condition' >> AnotherNode now).
     """
     def __init__(self, source_node, condition):
         self.source_node = source_node
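A quick illustration of what the new __sub__ overload enables. Here `review` and `payment` are hypothetical BaseNode instances; because `-` binds tighter than `>>` in Python, the expression groups as `(review - "approved") >> payment` with no parentheses needed:

    review - "approved" >> payment   # registers payment as the "approved" successor
    review - 42 >> payment           # raises TypeError("Condition must be a string")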
@@ -74,7 +84,6 @@ class _ConditionalTransition:
     def __rshift__(self, target_node):
         return self.source_node.add_successor(target_node, self.condition)
 
-# robust running process
 class Node(BaseNode):
     def __init__(self, max_retries=1):
         super().__init__()
@@ -82,7 +91,6 @@ class Node(BaseNode):
     def process_after_fail(self, shared_storage, data, exc):
         raise exc
-        # return "fail"
 
     def _process(self, shared_storage, data):
         for attempt in range(self.max_retries):
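With the commented-out "fail" fallback removed, the default process_after_fail now always re-raises once retries are exhausted. A subclass that would rather swallow the error and let the flow route on a marker can still do so; a minimal sketch (TolerantNode is a made-up name):

    class TolerantNode(Node):
        def process_after_fail(self, shared_storage, data, exc):
            # Swallow the error; this value reaches postprocess as proc_result
            return "fail"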
@@ -92,26 +100,14 @@ class Node(BaseNode):
             if attempt == self.max_retries - 1:
                 return self.process_after_fail(shared_storage, data, e)
 
-class Flow(BaseNode):
-    def __init__(self, start_node=None):
-        self.start_node = start_node
-
-    def _process(self, shared_storage, _):
-        current_node = self.start_node
-        while current_node:
-            condition = current_node.run(shared_storage)
-            current_node = current_node.successors.get(condition, None)
-
-    def postprocess(self, shared_storage, prep_result, proc_result):
-        return None
-
 class AsyncNode(Node):
     """
     A Node whose postprocess step is async.
     You can also override process() to be async if needed.
     """
-    def postprocess(self, shared_storage, prep_result, proc_result):
-        # Not used in async workflow; define postprocess_async() instead.
-        raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
 
     async def postprocess_async(self, shared_storage, prep_result, proc_result):
         """
@@ -122,106 +118,155 @@ class AsyncNode(Node):
         return "default"
 
     async def run_async(self, shared_storage=None):
-        """
-        Async version of run.
-        If your process method is also async, you'll need to adapt accordingly.
-        """
-        # We can keep preprocess synchronous or make it async as well,
-        # depending on your usage. Here it's left as sync for simplicity.
         prep = self.preprocess(shared_storage)
-        # process can remain sync if you prefer, or you can define an async process.
         proc = self._process(shared_storage, prep)
-        # postprocess is async
         return await self.postprocess_async(shared_storage, prep, proc)
 
-class AsyncFlow(Flow):
+class BaseFlow(BaseNode):
     """
-    A Flow that can handle a mixture of sync and async nodes.
-    If the node is an AsyncNode, calls `run_async`.
-    Otherwise, calls `run`.
+    Abstract base flow that provides the main logic of:
+    - Starting from self.start_node
+    - Looping until no more successors
+    Subclasses must define how they *call* each node (sync or async).
     """
-    async def _process(self, shared_storage, _):
-        current_node = self.start_node
-        while current_node:
-            if hasattr(current_node, "run_async") and callable(current_node.run_async):
-                # If it's an async node, await its run_async
-                condition = await current_node.run_async(shared_storage)
-            else:
-                # Otherwise, assume it's a sync node
-                condition = current_node.run(shared_storage)
-            current_node = current_node.successors.get(condition, None)
-
-    async def run_async(self, shared_storage=None):
-        """
-        Kicks off the async flow. Similar to Flow.run,
-        but uses our async _process method.
-        """
-        prep = self.preprocess(shared_storage)
-        # Note: flows typically don't need a meaningful process step
-        # because the "process" is the iteration through the nodes.
-        await self._process(shared_storage, prep)
-        return self.postprocess(shared_storage, prep, None)
-
-class BatchNode(BaseNode):
-    def __init__(self, max_retries=5, delay_s=0.1):
-        super().__init__()
-        self.max_retries = max_retries
-        self.delay_s = delay_s
-
-    def preprocess(self, shared_storage):
-        return []
-
-    def process_one(self, shared_storage, item):
-        return None
-
-    def process_one_after_fail(self, shared_storage, item, exc):
-        print(f"[FAIL_ITEM] item={item}, error={exc}")
-        # By default, just return a "fail" marker. Could be anything you want.
-        return "fail"
-
-    async def _process_one(self, shared_storage, item):
-        for attempt in range(self.max_retries):
-            try:
-                return await self.process_one(shared_storage, item)
-            except Exception as e:
-                if attempt == self.max_retries - 1:
-                    # If out of retries, let a subclass handle what to do next
-                    return await self.process_one_after_fail(shared_storage, item, e)
-                await asyncio.sleep(self.delay_s)
-
-    async def _process(self, shared_storage, items):
-        results = []
-        for item in items:
-            r = await self._process_one(shared_storage, item)
-            results.append(r)
-        return results
 
-class BatchFlow(BaseNode):
     def __init__(self, start_node=None):
         super().__init__()
         self.start_node = start_node
 
-    def preprocess(self, shared_storage):
-        return []
-
-    async def _process_one(self, shared_storage, param_dict):
-        node_parameters = self.parameters.copy()
-        node_parameters.update(param_dict)
-
-        if self.start_node:
+    def get_next_node(self, current_node, condition):
+        next_node = current_node.successors.get(condition, None)
+        if next_node is None and current_node.successors:
+            warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
+        return next_node
 
+    def run(self, shared_storage=None):
+        """
+        By default, do nothing (or raise).
+        Subclasses (Flow, AsyncFlow) will implement.
+        """
+        raise NotImplementedError("BaseFlow.run must be implemented by subclasses")
 
+    async def run_async(self, shared_storage=None):
+        """
+        By default, do nothing (or raise).
+        Subclasses (Flow, AsyncFlow) will implement.
+        """
+        raise NotImplementedError("BaseFlow.run_async must be implemented by subclasses")
 
+class Flow(BaseFlow):
+    """
+    Synchronous flow: each node is called with .run(shared_storage).
+    """
+    def _process_flow(self, shared_storage):
         current_node = self.start_node
         while current_node:
-            # set the combined parameters
-            current_node.set_parameters(node_parameters)
-            current_node = await current_node._run_one(shared_storage or {})
+            # Pass down the Flow's parameters to the current node
+            current_node.set_parameters(self.parameters)
+            # Synchronous run
+            condition = current_node.run(shared_storage)
+            # Decide next node
+            current_node = self.get_next_node(current_node, condition)
 
-    async def _process(self, shared_storage, items):
-        results = []
-        for param_dict in items:
-            await self._process_one(shared_storage, param_dict)
-            results.append(f"Ran sub-flow for param_dict={param_dict}")
-        return results
+    def run(self, shared_storage=None):
+        prep_result = self.preprocess(shared_storage)
+        self._process_flow(shared_storage)
+        return self.postprocess(shared_storage, prep_result, None)
 
+class AsyncFlow(BaseFlow):
+    """
+    Asynchronous flow: if a node has .run_async, we await it.
+    Otherwise, we fallback to .run.
+    """
+    async def _process_flow_async(self, shared_storage):
+        current_node = self.start_node
+        while current_node:
+            current_node.set_parameters(self.parameters)
+            # If node is async-capable, call run_async; otherwise run sync
+            if hasattr(current_node, "run_async") and callable(current_node.run_async):
+                condition = await current_node.run_async(shared_storage)
+            else:
+                condition = current_node.run(shared_storage)
+            current_node = self.get_next_node(current_node, condition)
 
+    async def run_async(self, shared_storage=None):
+        prep_result = self.preprocess(shared_storage)
+        await self._process_flow_async(shared_storage)
+        return self.postprocess(shared_storage, prep_result, None)
 
+    def run(self, shared_storage=None):
+        return asyncio.run(self.run_async(shared_storage))
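Since AsyncFlow.run wraps run_async in asyncio.run, an async pipeline can now be started from plain synchronous code. Note that asyncio.run raises if called from inside an already-running event loop, so inside a coroutine you would await run_async directly. A sketch, with first_node standing in for any Node or AsyncNode:

    flow = AsyncFlow(start_node=first_node)
    flow.run({})                 # fine at a sync call site
    # ...or, from within a coroutine:
    # await flow.run_async({})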
 
+class BaseBatchFlow(BaseFlow):
+    """
+    Abstract base for a flow that runs multiple times (a batch),
+    once for each set of parameters or items from preprocess().
+    """
+    def preprocess(self, shared_storage):
+        """
+        By default, returns an iterable of parameter-dicts or items
+        for the flow to process in a batch.
+        """
+        return []
+
+    def post_batch_run(self, all_results):
+        """
+        Hook for after the entire batch is done, to combine results, etc.
+        """
+        return all_results
+
+class BatchFlow(BaseBatchFlow, Flow):
+    """
+    Synchronous batch flow: calls the flow repeatedly
+    for each set of parameters/items in preprocess().
+    """
+    def run(self, shared_storage=None):
+        prep_result = self.preprocess(shared_storage)
+        all_results = []
+        # For each set of parameters (or items) we got from preprocess
+        for param_dict in prep_result:
+            # Merge param_dict into the Flow's parameters
+            original_params = self.parameters.copy()
+            self.parameters.update(param_dict)
+            # Run from the start node to end
+            self._process_flow(shared_storage)
+            # Optionally collect results from shared_storage or a custom method
+            all_results.append(f"Finished run with parameters: {param_dict}")
+            # Reset the parameters if needed
+            self.parameters = original_params
+        # Postprocess the entire batch
+        result = self.post_batch_run(all_results)
+        return self.postprocess(shared_storage, prep_result, result)
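One way to drive the synchronous batch flow: override preprocess to yield one parameter dict per run, and the same node graph is re-run once per dict. A sketch with a hypothetical subclass and node:

    class PerFileBatch(BatchFlow):
        def preprocess(self, shared_storage):
            return [{"filename": name} for name in ("a.txt", "b.txt")]

    PerFileBatch(start_node=process_node).run({})
    # Each pass merges one dict into the flow's parameters before _process_flow,
    # so process_node sees {"filename": "a.txt"}, then {"filename": "b.txt"}.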
 
+class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
+    """
+    Asynchronous batch flow: calls the flow repeatedly in an async manner
+    for each set of parameters/items in preprocess().
+    """
+    async def run_async(self, shared_storage=None):
+        prep_result = self.preprocess(shared_storage)
+        all_results = []
+        for param_dict in prep_result:
+            original_params = self.parameters.copy()
+            self.parameters.update(param_dict)
+            await self._process_flow_async(shared_storage)
+            all_results.append(f"Finished async run with parameters: {param_dict}")
+            # Reset back to original parameters if needed
+            self.parameters = original_params
+        # Combine or process results at the end
+        result = self.post_batch_run(all_results)
+        return self.postprocess(shared_storage, prep_result, result)
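BatchAsyncFlow defines no run of its own, so by the MRO (BatchAsyncFlow, BaseBatchFlow, AsyncFlow, BaseFlow) it inherits AsyncFlow.run, meaning even the async batch can be kicked off synchronously. Each parameter set is awaited to completion before the next begins, so runs are sequential, not concurrent. A sketch mirroring the sync version (names and URLs are illustrative):

    class PerUrlBatch(BatchAsyncFlow):
        def preprocess(self, shared_storage):
            return [{"url": u} for u in ("https://example.com/1", "https://example.com/2")]

    PerUrlBatch(start_node=fetch_node).run({})   # asyncio.run under the hood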