refactor

parent 89c003f657
commit 6206eb3650
@@ -1,55 +1,57 @@
 import asyncio

 # ---------------------------------------------------------------------
-# Base Classes
+# BaseNode (no duplication of run_one)
 # ---------------------------------------------------------------------
 class BaseNode:
     def __init__(self):
-        self.set_parameters({})
+        self.parameters = {}
         self.successors = {}

-    def set_parameters(self, parameters):
-        self.parameters = parameters.copy() if parameters else {}
-
     def add_successor(self, node, condition="default"):
         self.successors[condition] = node
         return node

     async def preprocess(self, shared_storage):
-        return None
-
-    async def process_one(self, shared_storage, item):
         """
-        The main single-item processing method that end developers override.
-        Default does nothing.
+        Override if needed to load or prepare data.
         """
         return None

-    async def robust_process_one(self, shared_storage, item):
+    async def process(self, shared_storage, data):
         """
-        In BaseNode, this is just a pass-through to `process_one`.
-        Subclasses (like Node with retry) can override this to add extra logic.
+        Public method for user logic. Default does nothing.
         """
-        return await self.process_one(shared_storage, item)
+        return None

-    async def process(self, shared_storage, preprocess_result):
+    async def _process(self, shared_storage, data):
         """
-        Calls `robust_process_one` instead of `process_one` so that
-        any subclass overrides of robust_process_one will apply.
+        Internal hook that calls `process(...)`.
+        Subclasses override this to add extra logic (e.g. retries).
         """
-        return await self.robust_process_one(shared_storage, preprocess_result)
+        return await self.process(shared_storage, data)

     async def postprocess(self, shared_storage, preprocess_result, process_result):
+        """
+        By default, returns "default" to pick the default successor.
+        """
         return "default"

     async def run_one(self, shared_storage):
+        """
+        One cycle of the node:
+        1) preprocess
+        2) _process
+        3) postprocess
+        4) pick successor
+        """
         preprocess_result = await self.preprocess(shared_storage)
-        process_result = await self.process(shared_storage, preprocess_result)
+        process_result = await self._process(shared_storage, preprocess_result)
         condition = await self.postprocess(shared_storage, preprocess_result, process_result)

         if not self.successors:
             return None
-        elif len(self.successors) == 1:
+        if len(self.successors) == 1:
             return next(iter(self.successors.values()))
         return self.successors.get(condition)

@@ -62,10 +64,14 @@ class BaseNode:
         while current_node:
             current_node = await current_node.run_one(shared_storage)

+    # Syntactic sugar for chaining
     def __rshift__(self, other):
         return self.add_successor(other)

     def __gt__(self, other):
+        """
+        For branching: node > "condition" > another_node
+        """
         if isinstance(other, str):
             return _ConditionalTransition(self, other)
         elif isinstance(other, BaseNode):

@@ -87,139 +93,139 @@ class _ConditionalTransition:
         return self.source_node.add_successor(target_node, self.condition)


-# ---------------------------------------------------------------------
-# Flow: allows you to define a "start_node" that is run in a sub-flow
-# ---------------------------------------------------------------------
-class Flow(BaseNode):
-    def __init__(self, start_node=None):
-        super().__init__()
-        self.start_node = start_node
-
-    async def process_one(self, shared_storage, item):
-        # Instead of doing a single operation, we run a sub-flow
-        if self.start_node:
-            current_node = self.start_node
-            while current_node:
-                # Pass down the parameters
-                current_node.set_parameters(self.parameters)
-                current_node = await current_node.run_one(shared_storage or {})
-
-
-# ---------------------------------------------------------------------
-# Node: adds robust retry logic on top of BaseNode
-# ---------------------------------------------------------------------
 class Node(BaseNode):
     """
-    Retries its single-item operation up to `max_retries` times,
-    waiting `delay_s` seconds between attempts.
-    By default: max_retries=5, delay_s=0.1
-    End developers simply override `process_one` to define logic.
+    Single-item node with robust logic.
+    End devs override `process(...)`.
+    `_process(...)` adds the retry logic.
+    """
+
+    def __init__(self, max_retries=5, delay_s=0.1):
+        super().__init__()
+        self.max_retries = max_retries
+        self.delay_s = delay_s
+
+    async def fail_item(self, shared_storage, data, exc):
+        """
+        Called if we exhaust all retries.
+        """
+        print(f"[FAIL_ITEM] data={data}, error={exc}")
+        return "fail"
+
+    async def _process(self, shared_storage, data):
+        """
+        Wraps the user’s `process(...)` with retry logic.
+        """
+        for attempt in range(self.max_retries):
+            try:
+                return await super()._process(shared_storage, data)
+            except Exception as e:
+                if attempt == self.max_retries - 1:
+                    return await self.fail_item(shared_storage, data, e)
+                await asyncio.sleep(self.delay_s)
+
+
+# ---------------------------------------------------------------------
+# BatchNode: processes multiple items
+# ---------------------------------------------------------------------
+class BatchNode(BaseNode):
+    """
+    Processes a list of items in `process(...)`.
+    The user overrides `process_one(item)`.
+    `_process_one(...)` handles robust retries for each item.
     """

     def __init__(self, max_retries=5, delay_s=0.1):
         super().__init__()
-        self.parameters.setdefault("max_retries", max_retries)
-        self.parameters.setdefault("delay_s", delay_s)
+        self.max_retries = max_retries
+        self.delay_s = delay_s

-    async def fail_one(self, shared_storage, item, exc):
-        """
-        Called if the final retry also fails. By default,
-        just returns a special string or could log an error.
-        End developers can override this to do something else
-        (e.g., store the failure in a separate list or
-        trigger alternative logic).
-        """
-        # Example: log and return a special status
-        print(f"[FAIL_ONE] item={item}, error={exc}")
-        return "fail"
-
-    async def robust_process_one(self, shared_storage, item):
-        max_retries = self.parameters.get("max_retries", 5)
-        delay_s = self.parameters.get("delay_s", 0.1)
-
-        for attempt in range(max_retries):
-            try:
-                # Defer to the user's process_one logic
-                return await super().robust_process_one(shared_storage, item)
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    # Final attempt failed; call fail_one
-                    return await self.fail_one(shared_storage, item, e)
-                # Otherwise, wait a bit and try again
-                await asyncio.sleep(delay_s)
-
-
-# ---------------------------------------------------------------------
-# BatchMixin: processes a collection of items by calling robust_process_one for each
-# ---------------------------------------------------------------------
-class BatchMixin:
-    async def process(self, shared_storage, items):
-        """
-        Processes a *collection* of items in a loop, calling robust_process_one per item.
-        """
-        partial_results = []
-        for item in items:
-            r = await self.robust_process_one(shared_storage, item)
-            partial_results.append(r)
-        return self.merge(shared_storage, partial_results)
-
-    def merge(self, shared_storage, partial_results):
-        """
-        Combines partial results into a single output.
-        By default, returns the list of partial results.
-        """
-        return partial_results

     async def preprocess(self, shared_storage):
         """
-        Typically, you'd return a list or collection of items to process here.
-        By default, returns an empty list.
+        Typically return a list of items to process.
         """
         return []


-# ---------------------------------------------------------------------
-# BatchNode: combines Node (robust logic) + BatchMixin (batch logic)
-# ---------------------------------------------------------------------
-class BatchNode(BatchMixin, Node):
-    """
-    A batch-processing node that:
-    - Inherits robust retry logic from Node
-    - Uses BatchMixin to process a list of items
-    """
-
-    async def preprocess(self, shared_storage):
-        # Gather or return the batch items. By default, no items.
-        return []
-
     async def process_one(self, shared_storage, item):
         """
-        The per-item logic that the end developer will override.
-        By default, does nothing.
+        End developers override single-item logic here.
         """
         return None

-
-# ---------------------------------------------------------------------
-# BatchFlow: combines Flow (sub-flow logic) + batch processing + robust logic
-# ---------------------------------------------------------------------
-class BatchFlow(BatchMixin, Flow):
-    """
-    This class runs a sub-flow (start_node) for each item in a batch.
-    If you also want robust retries, you can adapt or combine with `Node`.
-    """
-
-    async def preprocess(self, shared_storage):
-        # Return your batch items here
-        return []
-
-    async def process(self, shared_storage, items):
+    async def _process_one(self, shared_storage, item):
+        """
+        Retry logic around process_one(item).
+        """
+        for attempt in range(self.max_retries):
+            try:
+                return await self.process_one(shared_storage, item)
+            except Exception as e:
+                if attempt == self.max_retries - 1:
+                    print(f"[FAIL_ITEM] item={item}, error={e}")
+                    return "fail"
+                await asyncio.sleep(self.delay_s)
+
+    async def _process(self, shared_storage, items):
         """
-        For each item, run the sub-flow (start_node).
+        Loops over items, calling _process_one per item.
         """
         results = []
         for item in items:
-            # Here we re-run the sub-flow, which happens inside process_one of Flow
-            await self.process_one(shared_storage, item)
-            # Optionally collect results or do something after the sub-flow
-            results.append(f"Finished sub-flow for item: {item}")
+            r = await self._process_one(shared_storage, item)
+            results.append(r)
+        return results
+
+
+class Flow(BaseNode):
+    """
+    Runs a sub-flow from `start_node` once per call.
+    """
+    def __init__(self, start_node=None):
+        super().__init__()
+        self.start_node = start_node
+
+    async def _process(self, shared_storage, _):
+        if self.start_node:
+            current_node = self.start_node
+            while current_node:
+                current_node = await current_node.run_one(shared_storage or {})
+        return "Flow done"
+
+
+class BatchFlow(BaseNode):
+    """
+    For each param_dict in the batch, merges it into self.parameters,
+    then runs the sub-flow from `start_node`.
+    """
+
+    def __init__(self, start_node=None):
+        super().__init__()
+        self.start_node = start_node
+
+    async def preprocess(self, shared_storage):
+        """
+        Return a list of param_dict objects.
+        """
+        return []
+
+    async def process_one(self, shared_storage, param_dict):
+        """
+        Merge param_dict into the node's parameters,
+        then run the sub-flow.
+        """
+        node_parameters = self.parameters.copy()
+        node_parameters.update(param_dict)
+
+        if self.start_node:
+            current_node = self.start_node
+            while current_node:
+                # set the combined parameters
+                current_node.set_parameters(node_parameters)
+                current_node = await current_node.run_one(shared_storage or {})
+
+    async def _process(self, shared_storage, items):
+        """
+        For each param_dict in items, run the sub-flow once.
+        """
+        results = []
+        for param_dict in items:
+            await self.process_one(shared_storage, param_dict)
+            results.append(f"Ran sub-flow for param_dict={param_dict}")
         return results