debug empty shared memory
This commit is contained in:
parent
83dbc13054
commit
c4c78b1939
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
|
||||
import warnings
|
||||
|
||||
class BaseNode:
|
||||
"""
|
||||
A base node that provides:
|
||||
|
|
@ -62,10 +63,19 @@ class BaseNode:
|
|||
"""
|
||||
return _ConditionalTransition(self, condition)
|
||||
|
||||
def __sub__(self, condition):
|
||||
"""
|
||||
For chaining with - operator, e.g. node - "some_condition" >> next_node
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return _ConditionalTransition(self, condition)
|
||||
raise TypeError("Condition must be a string")
|
||||
|
||||
|
||||
class _ConditionalTransition:
|
||||
"""
|
||||
Helper for Node > 'condition' >> AnotherNode style
|
||||
(and also Node - 'condition' >> AnotherNode now).
|
||||
"""
|
||||
def __init__(self, source_node, condition):
|
||||
self.source_node = source_node
|
||||
|
|
@ -74,7 +84,6 @@ class _ConditionalTransition:
|
|||
def __rshift__(self, target_node):
|
||||
return self.source_node.add_successor(target_node, self.condition)
|
||||
|
||||
# robust running process
|
||||
class Node(BaseNode):
|
||||
def __init__(self, max_retries=1):
|
||||
super().__init__()
|
||||
|
|
@ -82,7 +91,6 @@ class Node(BaseNode):
|
|||
|
||||
def process_after_fail(self, shared_storage, data, exc):
|
||||
raise exc
|
||||
# return "fail"
|
||||
|
||||
def _process(self, shared_storage, data):
|
||||
for attempt in range(self.max_retries):
|
||||
|
|
@ -91,28 +99,16 @@ class Node(BaseNode):
|
|||
except Exception as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
return self.process_after_fail(shared_storage, data, e)
|
||||
|
||||
class Flow(BaseNode):
|
||||
def __init__(self, start_node=None):
|
||||
self.start_node = start_node
|
||||
|
||||
def _process(self, shared_storage, _):
|
||||
current_node = self.start_node
|
||||
while current_node:
|
||||
condition = current_node.run(shared_storage)
|
||||
current_node = current_node.successors.get(condition, None)
|
||||
|
||||
def postprocess(self, shared_storage, prep_result, proc_result):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class AsyncNode(Node):
|
||||
"""
|
||||
A Node whose postprocess step is async.
|
||||
You can also override process() to be async if needed.
|
||||
"""
|
||||
|
||||
def postprocess(self, shared_storage, prep_result, proc_result):
|
||||
# Not used in async workflow; define postprocess_async() instead.
|
||||
raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
|
||||
|
||||
async def postprocess_async(self, shared_storage, prep_result, proc_result):
|
||||
"""
|
||||
Async version of postprocess. By default, returns "default".
|
||||
|
|
@ -122,106 +118,155 @@ class AsyncNode(Node):
|
|||
return "default"
|
||||
|
||||
async def run_async(self, shared_storage=None):
|
||||
"""
|
||||
Async version of run.
|
||||
If your process method is also async, you'll need to adapt accordingly.
|
||||
"""
|
||||
# We can keep preprocess synchronous or make it async as well,
|
||||
# depending on your usage. Here it's left as sync for simplicity.
|
||||
prep = self.preprocess(shared_storage)
|
||||
|
||||
# process can remain sync if you prefer, or you can define an async process.
|
||||
proc = self._process(shared_storage, prep)
|
||||
|
||||
# postprocess is async
|
||||
return await self.postprocess_async(shared_storage, prep, proc)
|
||||
|
||||
|
||||
class AsyncFlow(Flow):
|
||||
class BaseFlow(BaseNode):
|
||||
"""
|
||||
A Flow that can handle a mixture of sync and async nodes.
|
||||
If the node is an AsyncNode, calls `run_async`.
|
||||
Otherwise, calls `run`.
|
||||
Abstract base flow that provides the main logic of:
|
||||
- Starting from self.start_node
|
||||
- Looping until no more successors
|
||||
Subclasses must define how they *call* each node (sync or async).
|
||||
"""
|
||||
async def _process(self, shared_storage, _):
|
||||
current_node = self.start_node
|
||||
while current_node:
|
||||
if hasattr(current_node, "run_async") and callable(current_node.run_async):
|
||||
# If it's an async node, await its run_async
|
||||
condition = await current_node.run_async(shared_storage)
|
||||
else:
|
||||
# Otherwise, assume it's a sync node
|
||||
condition = current_node.run(shared_storage)
|
||||
|
||||
current_node = current_node.successors.get(condition, None)
|
||||
|
||||
async def run_async(self, shared_storage=None):
|
||||
"""
|
||||
Kicks off the async flow. Similar to Flow.run,
|
||||
but uses our async _process method.
|
||||
"""
|
||||
prep = self.preprocess(shared_storage)
|
||||
# Note: flows typically don't need a meaningful process step
|
||||
# because the "process" is the iteration through the nodes.
|
||||
await self._process(shared_storage, prep)
|
||||
return self.postprocess(shared_storage, prep, None)
|
||||
|
||||
class BatchNode(BaseNode):
|
||||
def __init__(self, max_retries=5, delay_s=0.1):
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.delay_s = delay_s
|
||||
|
||||
def preprocess(self, shared_storage):
|
||||
return []
|
||||
|
||||
def process_one(self, shared_storage, item):
|
||||
return None
|
||||
|
||||
def process_one_after_fail(self, shared_storage, item, exc):
|
||||
print(f"[FAIL_ITEM] item={item}, error={exc}")
|
||||
# By default, just return a "fail" marker. Could be anything you want.
|
||||
return "fail"
|
||||
|
||||
async def _process_one(self, shared_storage, item):
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return await self.process_one(shared_storage, item)
|
||||
except Exception as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
# If out of retries, let a subclass handle what to do next
|
||||
return await self.process_one_after_fail(shared_storage, item, e)
|
||||
await asyncio.sleep(self.delay_s)
|
||||
|
||||
async def _process(self, shared_storage, items):
|
||||
results = []
|
||||
for item in items:
|
||||
r = await self._process_one(shared_storage, item)
|
||||
results.append(r)
|
||||
return results
|
||||
|
||||
class BatchFlow(BaseNode):
|
||||
def __init__(self, start_node=None):
|
||||
super().__init__()
|
||||
self.start_node = start_node
|
||||
|
||||
def get_next_node(self, current_node, condition):
|
||||
next_node = current_node.successors.get(condition, None)
|
||||
|
||||
if next_node is None and current_node.successors:
|
||||
warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
|
||||
|
||||
return next_node
|
||||
|
||||
def run(self, shared_storage=None):
|
||||
"""
|
||||
By default, do nothing (or raise).
|
||||
Subclasses (Flow, AsyncFlow) will implement.
|
||||
"""
|
||||
raise NotImplementedError("BaseFlow.run must be implemented by subclasses")
|
||||
|
||||
async def run_async(self, shared_storage=None):
|
||||
"""
|
||||
By default, do nothing (or raise).
|
||||
Subclasses (Flow, AsyncFlow) will implement.
|
||||
"""
|
||||
raise NotImplementedError("BaseFlow.run_async must be implemented by subclasses")
|
||||
|
||||
class Flow(BaseFlow):
|
||||
"""
|
||||
Synchronous flow: each node is called with .run(shared_storage).
|
||||
"""
|
||||
def _process_flow(self, shared_storage):
|
||||
current_node = self.start_node
|
||||
while current_node:
|
||||
# Pass down the Flow's parameters to the current node
|
||||
current_node.set_parameters(self.parameters)
|
||||
# Synchronous run
|
||||
condition = current_node.run(shared_storage)
|
||||
# Decide next node
|
||||
current_node = self.get_next_node(current_node, condition)
|
||||
|
||||
def run(self, shared_storage=None):
|
||||
prep_result = self.preprocess(shared_storage)
|
||||
self._process_flow(shared_storage)
|
||||
return self.postprocess(shared_storage, prep_result, None)
|
||||
|
||||
class AsyncFlow(BaseFlow):
|
||||
"""
|
||||
Asynchronous flow: if a node has .run_async, we await it.
|
||||
Otherwise, we fallback to .run.
|
||||
"""
|
||||
async def _process_flow_async(self, shared_storage):
|
||||
current_node = self.start_node
|
||||
while current_node:
|
||||
current_node.set_parameters(self.parameters)
|
||||
|
||||
# If node is async-capable, call run_async; otherwise run sync
|
||||
if hasattr(current_node, "run_async") and callable(current_node.run_async):
|
||||
condition = await current_node.run_async(shared_storage)
|
||||
else:
|
||||
condition = current_node.run(shared_storage)
|
||||
|
||||
current_node = self.get_next_node(current_node, condition)
|
||||
|
||||
async def run_async(self, shared_storage=None):
|
||||
prep_result = self.preprocess(shared_storage)
|
||||
await self._process_flow_async(shared_storage)
|
||||
return self.postprocess(shared_storage, prep_result, None)
|
||||
|
||||
def run(self, shared_storage=None):
|
||||
return asyncio.run(self.run_async(shared_storage))
|
||||
|
||||
class BaseBatchFlow(BaseFlow):
|
||||
"""
|
||||
Abstract base for a flow that runs multiple times (a batch),
|
||||
once for each set of parameters or items from preprocess().
|
||||
"""
|
||||
def preprocess(self, shared_storage):
|
||||
"""
|
||||
By default, returns an iterable of parameter-dicts or items
|
||||
for the flow to process in a batch.
|
||||
"""
|
||||
return []
|
||||
|
||||
async def _process_one(self, shared_storage, param_dict):
|
||||
node_parameters = self.parameters.copy()
|
||||
node_parameters.update(param_dict)
|
||||
def post_batch_run(self, all_results):
|
||||
"""
|
||||
Hook for after the entire batch is done, to combine results, etc.
|
||||
"""
|
||||
return all_results
|
||||
|
||||
if self.start_node:
|
||||
current_node = self.start_node
|
||||
while current_node:
|
||||
# set the combined parameters
|
||||
current_node.set_parameters(node_parameters)
|
||||
current_node = await current_node._run_one(shared_storage or {})
|
||||
class BatchFlow(BaseBatchFlow, Flow):
|
||||
"""
|
||||
Synchronous batch flow: calls the flow repeatedly
|
||||
for each set of parameters/items in preprocess().
|
||||
"""
|
||||
def run(self, shared_storage=None):
|
||||
prep_result = self.preprocess(shared_storage)
|
||||
all_results = []
|
||||
|
||||
async def _process(self, shared_storage, items):
|
||||
results = []
|
||||
for param_dict in items:
|
||||
await self._process_one(shared_storage, param_dict)
|
||||
results.append(f"Ran sub-flow for param_dict={param_dict}")
|
||||
return results
|
||||
# For each set of parameters (or items) we got from preprocess
|
||||
for param_dict in prep_result:
|
||||
# Merge param_dict into the Flow's parameters
|
||||
original_params = self.parameters.copy()
|
||||
self.parameters.update(param_dict)
|
||||
|
||||
# Run from the start node to end
|
||||
self._process_flow(shared_storage)
|
||||
|
||||
# Optionally collect results from shared_storage or a custom method
|
||||
all_results.append(f"Finished run with parameters: {param_dict}")
|
||||
|
||||
# Reset the parameters if needed
|
||||
self.parameters = original_params
|
||||
|
||||
# Postprocess the entire batch
|
||||
result = self.post_batch_run(all_results)
|
||||
return self.postprocess(shared_storage, prep_result, result)
|
||||
|
||||
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
|
||||
"""
|
||||
Asynchronous batch flow: calls the flow repeatedly in an async manner
|
||||
for each set of parameters/items in preprocess().
|
||||
"""
|
||||
async def run_async(self, shared_storage=None):
|
||||
prep_result = self.preprocess(shared_storage)
|
||||
all_results = []
|
||||
|
||||
for param_dict in prep_result:
|
||||
original_params = self.parameters.copy()
|
||||
self.parameters.update(param_dict)
|
||||
|
||||
await self._process_flow_async(shared_storage)
|
||||
|
||||
all_results.append(f"Finished async run with parameters: {param_dict}")
|
||||
|
||||
# Reset back to original parameters if needed
|
||||
self.parameters = original_params
|
||||
|
||||
# Combine or process results at the end
|
||||
result = self.post_batch_run(all_results)
|
||||
return self.postprocess(shared_storage, prep_result, result)
|
||||
Binary file not shown.
Loading…
Reference in New Issue