zachary62 2024-12-25 02:14:29 +00:00
parent 89c003f657
commit 6206eb3650
1 changed file with 142 additions and 136 deletions


@@ -1,55 +1,57 @@
 import asyncio

 # ---------------------------------------------------------------------
-# Base Classes
+# BaseNode (no duplication of run_one)
 # ---------------------------------------------------------------------
 class BaseNode:
     def __init__(self):
-        self.set_parameters({})
+        self.parameters = {}
         self.successors = {}

-    def set_parameters(self, parameters):
-        self.parameters = parameters.copy() if parameters else {}

     def add_successor(self, node, condition="default"):
         self.successors[condition] = node
         return node

     async def preprocess(self, shared_storage):
+        """
+        Override if needed to load or prepare data.
+        """
         return None

-    async def process_one(self, shared_storage, item):
+    async def process(self, shared_storage, data):
         """
-        The main single-item processing method that end developers override.
-        Default does nothing.
+        Public method for user logic. Default does nothing.
         """
         return None

-    async def robust_process_one(self, shared_storage, item):
+    async def _process(self, shared_storage, data):
         """
-        In BaseNode, this is just a pass-through to `process_one`.
-        Subclasses (like Node with retry) can override this to add extra logic.
+        Internal hook that calls `process(...)`.
+        Subclasses override this to add extra logic (e.g. retries).
         """
-        return await self.process_one(shared_storage, item)
-
-    async def process(self, shared_storage, preprocess_result):
-        """
-        Calls `robust_process_one` instead of `process_one` so that
-        any subclass overrides of robust_process_one will apply.
-        """
-        return await self.robust_process_one(shared_storage, preprocess_result)
+        return await self.process(shared_storage, data)

     async def postprocess(self, shared_storage, preprocess_result, process_result):
+        """
+        By default, returns "default" to pick the default successor.
+        """
         return "default"

     async def run_one(self, shared_storage):
+        """
+        One cycle of the node:
+          1) preprocess
+          2) _process
+          3) postprocess
+          4) pick successor
+        """
         preprocess_result = await self.preprocess(shared_storage)
-        process_result = await self.process(shared_storage, preprocess_result)
+        process_result = await self._process(shared_storage, preprocess_result)
         condition = await self.postprocess(shared_storage, preprocess_result, process_result)
         if not self.successors:
             return None
-        elif len(self.successors) == 1:
+        if len(self.successors) == 1:
             return next(iter(self.successors.values()))
         return self.successors.get(condition)
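As an illustration (not part of the commit), here is a minimal sketch of how the reworked preprocess/_process/postprocess contract might be exercised. It assumes the BaseNode class above is in scope; GreetNode and the shared-storage keys are hypothetical names invented for the example.

import asyncio

class GreetNode(BaseNode):
    async def preprocess(self, shared_storage):
        # Load the value to greet from shared storage (hypothetical key).
        return shared_storage.get("name", "world")

    async def process(self, shared_storage, data):
        # User logic only; retry or batch behavior comes from Node/BatchNode subclasses.
        shared_storage["greeting"] = f"Hello, {data}!"
        return shared_storage["greeting"]

async def main():
    shared = {"name": "Ada"}
    node = GreetNode()
    # run_one returns the next node to run (None here, since no successors were added).
    next_node = await node.run_one(shared)
    print(shared["greeting"], next_node)

asyncio.run(main())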
@@ -62,10 +64,14 @@ class BaseNode:
         while current_node:
             current_node = await current_node.run_one(shared_storage)

+    # Syntactic sugar for chaining
     def __rshift__(self, other):
         return self.add_successor(other)

     def __gt__(self, other):
+        """
+        For branching: node > "condition" > another_node
+        """
         if isinstance(other, str):
             return _ConditionalTransition(self, other)
         elif isinstance(other, BaseNode):
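Again as a hedged sketch rather than part of the diff, the chaining sugar above could be used roughly like this; fetch, parse, and retry are invented placeholder nodes, and BaseNode and _ConditionalTransition from the surrounding file are assumed to be in scope.

fetch = BaseNode()
parse = BaseNode()
retry = BaseNode()

# Default transition: fetch -> parse
fetch >> parse

# Conditional transition: if fetch's postprocess returns "fail", go to retry.
# Parentheses keep Python from treating this as a chained comparison.
(fetch > "fail") > retry

# Equivalent explicit form: fetch.add_successor(retry, "fail")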
@@ -87,139 +93,139 @@ class _ConditionalTransition:
         return self.source_node.add_successor(target_node, self.condition)

-# ---------------------------------------------------------------------
-# Flow: allows you to define a "start_node" that is run in a sub-flow
-# ---------------------------------------------------------------------
-class Flow(BaseNode):
-    def __init__(self, start_node=None):
-        super().__init__()
-        self.start_node = start_node
-
-    async def process_one(self, shared_storage, item):
-        # Instead of doing a single operation, we run a sub-flow
-        if self.start_node:
-            current_node = self.start_node
-            while current_node:
-                # Pass down the parameters
-                current_node.set_parameters(self.parameters)
-                current_node = await current_node.run_one(shared_storage or {})
-
-# ---------------------------------------------------------------------
-# Node: adds robust retry logic on top of BaseNode
-# ---------------------------------------------------------------------
 class Node(BaseNode):
     """
-    Retries its single-item operation up to `max_retries` times,
-    waiting `delay_s` seconds between attempts.
-    By default: max_retries=5, delay_s=0.1
-    End developers simply override `process_one` to define logic.
+    Single-item node with robust logic.
+    End devs override `process(...)`.
+    `_process(...)` adds the retry logic.
     """
     def __init__(self, max_retries=5, delay_s=0.1):
         super().__init__()
-        self.parameters.setdefault("max_retries", max_retries)
-        self.parameters.setdefault("delay_s", delay_s)
+        self.max_retries = max_retries
+        self.delay_s = delay_s

-    async def fail_one(self, shared_storage, item, exc):
+    async def fail_item(self, shared_storage, data, exc):
         """
-        Called if the final retry also fails. By default,
-        just returns a special string or could log an error.
-        End developers can override this to do something else
-        (e.g., store the failure in a separate list or
-        trigger alternative logic).
+        Called if we exhaust all retries.
         """
-        # Example: log and return a special status
-        print(f"[FAIL_ONE] item={item}, error={exc}")
+        print(f"[FAIL_ITEM] data={data}, error={exc}")
         return "fail"

-    async def robust_process_one(self, shared_storage, item):
-        max_retries = self.parameters.get("max_retries", 5)
-        delay_s = self.parameters.get("delay_s", 0.1)
-        for attempt in range(max_retries):
-            try:
-                # Defer to the user's process_one logic
-                return await super().robust_process_one(shared_storage, item)
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    # Final attempt failed; call fail_one
-                    return await self.fail_one(shared_storage, item, e)
-                # Otherwise, wait a bit and try again
-                await asyncio.sleep(delay_s)
+    async def _process(self, shared_storage, data):
+        """
+        Wraps the user's `process(...)` with retry logic.
+        """
+        for attempt in range(self.max_retries):
+            try:
+                return await super()._process(shared_storage, data)
+            except Exception as e:
+                if attempt == self.max_retries - 1:
+                    return await self.fail_item(shared_storage, data, e)
+                await asyncio.sleep(self.delay_s)

-# ---------------------------------------------------------------------
-# BatchMixin: processes a collection of items by calling robust_process_one for each
-# ---------------------------------------------------------------------
-class BatchMixin:
-    async def process(self, shared_storage, items):
-        """
-        Processes a *collection* of items in a loop, calling robust_process_one per item.
-        """
-        partial_results = []
-        for item in items:
-            r = await self.robust_process_one(shared_storage, item)
-            partial_results.append(r)
-        return self.merge(shared_storage, partial_results)
-
-    def merge(self, shared_storage, partial_results):
-        """
-        Combines partial results into a single output.
-        By default, returns the list of partial results.
-        """
-        return partial_results
-
-    async def preprocess(self, shared_storage):
-        """
-        Typically, you'd return a list or collection of items to process here.
-        By default, returns an empty list.
-        """
-        return []
-
-# ---------------------------------------------------------------------
-# BatchNode: combines Node (robust logic) + BatchMixin (batch logic)
-# ---------------------------------------------------------------------
-class BatchNode(BatchMixin, Node):
-    """
-    A batch-processing node that:
-      - Inherits robust retry logic from Node
-      - Uses BatchMixin to process a list of items
-    """
-    async def preprocess(self, shared_storage):
-        # Gather or return the batch items. By default, no items.
-        return []
-
-    async def process_one(self, shared_storage, item):
-        """
-        The per-item logic that the end developer will override.
-        By default, does nothing.
-        """
-        return None
-
-# ---------------------------------------------------------------------
-# BatchFlow: combines Flow (sub-flow logic) + batch processing + robust logic
-# ---------------------------------------------------------------------
-class BatchFlow(BatchMixin, Flow):
-    """
-    This class runs a sub-flow (start_node) for each item in a batch.
-    If you also want robust retries, you can adapt or combine with `Node`.
-    """
-    async def preprocess(self, shared_storage):
-        # Return your batch items here
-        return []
-
-    async def process(self, shared_storage, items):
-        """
-        For each item, run the sub-flow (start_node).
-        """
-        results = []
-        for item in items:
-            # Here we re-run the sub-flow, which happens inside process_one of Flow
-            await self.process_one(shared_storage, item)
-            # Optionally collect results or do something after the sub-flow
-            results.append(f"Finished sub-flow for item: {item}")
-        return results
+# ---------------------------------------------------------------------
+# BatchNode: processes multiple items
+# ---------------------------------------------------------------------
+class BatchNode(BaseNode):
+    """
+    Processes a list of items in `process(...)`.
+    The user overrides `process_one(item)`.
+    `_process_one(...)` handles robust retries for each item.
+    """
+    def __init__(self, max_retries=5, delay_s=0.1):
+        super().__init__()
+        self.max_retries = max_retries
+        self.delay_s = delay_s
+
+    async def preprocess(self, shared_storage):
+        """
+        Typically return a list of items to process.
+        """
+        return []
+
+    async def process_one(self, shared_storage, item):
+        """
+        End developers override single-item logic here.
+        """
+        return None
+
+    async def _process_one(self, shared_storage, item):
+        """
+        Retry logic around process_one(item).
+        """
+        for attempt in range(self.max_retries):
+            try:
+                return await self.process_one(shared_storage, item)
+            except Exception as e:
+                if attempt == self.max_retries - 1:
+                    print(f"[FAIL_ITEM] item={item}, error={e}")
+                    return "fail"
+                await asyncio.sleep(self.delay_s)
+
+    async def _process(self, shared_storage, items):
+        """
+        Loops over items, calling _process_one per item.
+        """
+        results = []
+        for item in items:
+            r = await self._process_one(shared_storage, item)
+            results.append(r)
+        return results
+
+class Flow(BaseNode):
+    """
+    Runs a sub-flow from `start_node` once per call.
+    """
+    def __init__(self, start_node=None):
+        super().__init__()
+        self.start_node = start_node
+
+    async def _process(self, shared_storage, _):
+        if self.start_node:
+            current_node = self.start_node
+            while current_node:
+                current_node = await current_node.run_one(shared_storage or {})
+        return "Flow done"
+
+class BatchFlow(BaseNode):
+    """
+    For each param_dict in the batch, merges it into self.parameters,
+    then runs the sub-flow from `start_node`.
+    """
+    def __init__(self, start_node=None):
+        super().__init__()
+        self.start_node = start_node
+
+    async def preprocess(self, shared_storage):
+        """
+        Return a list of param_dict objects.
+        """
+        return []
+
+    async def process_one(self, shared_storage, param_dict):
+        """
+        Merge param_dict into the node's parameters,
+        then run the sub-flow.
+        """
+        node_parameters = self.parameters.copy()
+        node_parameters.update(param_dict)
+        if self.start_node:
+            current_node = self.start_node
+            while current_node:
+                # set the combined parameters
+                current_node.set_parameters(node_parameters)
+                current_node = await current_node.run_one(shared_storage or {})
+
+    async def _process(self, shared_storage, items):
+        """
+        For each param_dict in items, run the sub-flow once.
+        """
+        results = []
+        for param_dict in items:
+            await self.process_one(shared_storage, param_dict)
+            results.append(f"Ran sub-flow for param_dict={param_dict}")
+        return results
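To show how the pieces introduced by this commit might fit together end to end, here is a hedged usage sketch. It assumes the module above is importable as-is; FlakyFetch, SquareBatch, and the shared-storage keys are hypothetical names for illustration only.

import asyncio
import random

class FlakyFetch(Node):
    # `process` may raise; Node._process retries it up to max_retries times.
    async def process(self, shared_storage, data):
        if random.random() < 0.5:
            raise RuntimeError("transient failure")
        shared_storage["numbers"] = [1, 2, 3]
        return "fetched"

class SquareBatch(BatchNode):
    # preprocess supplies the items, process_one handles each item with retries.
    async def preprocess(self, shared_storage):
        return shared_storage.get("numbers", [])

    async def process_one(self, shared_storage, item):
        return item * item

    async def postprocess(self, shared_storage, preprocess_result, process_result):
        shared_storage["squares"] = process_result
        return "default"

async def main():
    fetch = FlakyFetch(max_retries=3, delay_s=0.01)
    square = SquareBatch()
    fetch >> square  # default successor

    # Seed the dict so Flow's `shared_storage or {}` keeps passing the same object.
    shared = {"run_id": "demo"}
    flow = Flow(start_node=fetch)
    await flow.run_one(shared)
    print(shared.get("squares"))  # e.g. [1, 4, 9] once the fetch succeeds

asyncio.run(main())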