async wrapper

zachary62 2024-12-25 04:52:35 +00:00
parent 6206eb3650
commit c8cae054bc
1 changed file with 69 additions and 105 deletions


@@ -1,50 +1,59 @@
 import asyncio
-# ---------------------------------------------------------------------
-# BaseNode (no duplication of run_one)
-# ---------------------------------------------------------------------
-class BaseNode:
+def _wrap_async(fn):
+    """
+    Given a synchronous function fn, return a coroutine (async function) that
+    simply awaits the (synchronous) call.
+    """
+    async def _async_wrapper(self, *args, **kwargs):
+        return fn(self, *args, **kwargs)
+    return _async_wrapper
+class NodeMeta(type):
+    """
+    Metaclass that converts certain methods into async if they are not already.
+    """
+    def __new__(mcs, name, bases, attrs):
+        # Add ANY method names you want to auto-wrap here:
+        methods_to_wrap = (
+            "preprocess",
+            "process",
+            "postprocess",
+            "process_after_fail",
+            "process_one",
+            "process_one_after_fail",
+        )
+        for attr_name in methods_to_wrap:
+            if attr_name in attrs:
+                # If it's not already a coroutine function, wrap it
+                if not asyncio.iscoroutinefunction(attrs[attr_name]):
+                    old_fn = attrs[attr_name]
+                    attrs[attr_name] = _wrap_async(old_fn)
+        return super().__new__(mcs, name, bases, attrs)
+class BaseNode(metaclass=NodeMeta):
     def __init__(self):
         self.parameters = {}
         self.successors = {}
-    def add_successor(self, node, condition="default"):
-        self.successors[condition] = node
-        return node
-    async def preprocess(self, shared_storage):
-        """
-        Override if needed to load or prepare data.
-        """
+    # By default these are already async. If a subclass overrides them
+    # with non-async definitions, they'll get wrapped automatically.
+    def preprocess(self, shared_storage):
         return None
-    async def process(self, shared_storage, data):
-        """
-        Public method for user logic. Default does nothing.
-        """
+    def process(self, shared_storage, data):
         return None
-    async def _process(self, shared_storage, data):
-        """
-        Internal hook that calls `process(...)`.
-        Subclasses override this to add extra logic (e.g. retries).
-        """
-        return await self.process(shared_storage, data)
-    async def postprocess(self, shared_storage, preprocess_result, process_result):
-        """
-        By default, returns "default" to pick the default successor.
-        """
+    def postprocess(self, shared_storage, preprocess_result, process_result):
         return "default"
-    async def run_one(self, shared_storage):
-        """
-        One cycle of the node:
-          1) preprocess
-          2) _process
-          3) postprocess
-          4) pick successor
-        """
+    async def _process(self, shared_storage, data):
+        return await self.process(shared_storage, data)
+    async def _run_one(self, shared_storage):
         preprocess_result = await self.preprocess(shared_storage)
         process_result = await self._process(shared_storage, preprocess_result)
         condition = await self.postprocess(shared_storage, preprocess_result, process_result)
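
The hunk above is the heart of the change: _wrap_async turns a plain method into a coroutine function, and NodeMeta applies it at class-creation time to the hook names listed in methods_to_wrap, so subclasses can keep writing synchronous overrides. A minimal sketch of the effect (the Greeter subclass and the direct call to the internal _process hook are illustrative only, assuming the classes from this commit are in scope):

    import asyncio

    class Greeter(BaseNode):
        # Plain synchronous override; NodeMeta wraps it into a coroutine function.
        def process(self, shared_storage, data):
            return "hello"

    async def demo():
        node = Greeter()
        print(asyncio.iscoroutinefunction(Greeter.process))  # True: auto-wrapped
        print(await node._process(None, None))                # prints "hello"

    asyncio.run(demo())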
@@ -62,16 +71,17 @@ class BaseNode:
         shared_storage = shared_storage or {}
         current_node = self
         while current_node:
-            current_node = await current_node.run_one(shared_storage)
+            current_node = await current_node._run_one(shared_storage)
     # Syntactic sugar for chaining
+    def add_successor(self, node, condition="default"):
+        self.successors[condition] = node
+        return node
     def __rshift__(self, other):
         return self.add_successor(other)
     def __gt__(self, other):
-        """
-        For branching: node > "condition" > another_node
-        """
         if isinstance(other, str):
             return _ConditionalTransition(self, other)
         elif isinstance(other, BaseNode):
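
add_successor now lives next to the chaining sugar it backs: >> registers the default successor and returns the target so calls chain, while > hands a condition string to _ConditionalTransition (its body sits just past this hunk). A small illustrative sketch, assuming BaseNode from above is in scope and that a, b, c are throwaway nodes:

    a, b, c = BaseNode(), BaseNode(), BaseNode()
    a >> b >> c                  # default chain a -> b -> c (each >> returns its target)
    a.add_successor(c, "fail")   # conditional edge: a --"fail"--> c
    print(list(a.successors))    # ['default', 'fail']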
@@ -94,90 +104,59 @@ class _ConditionalTransition:
 class Node(BaseNode):
-    """
-    Single-item node with robust logic.
-    End devs override `process(...)`.
-    `_process(...)` adds the retry logic.
-    """
     def __init__(self, max_retries=5, delay_s=0.1):
         super().__init__()
         self.max_retries = max_retries
         self.delay_s = delay_s
-    async def fail_item(self, shared_storage, data, exc):
-        """
-        Called if we exhaust all retries.
-        """
+    def process_after_fail(self, shared_storage, data, exc):
         print(f"[FAIL_ITEM] data={data}, error={exc}")
         return "fail"
     async def _process(self, shared_storage, data):
-        """
-        Wraps the user's `process(...)` with retry logic.
-        """
         for attempt in range(self.max_retries):
             try:
                 return await super()._process(shared_storage, data)
             except Exception as e:
                 if attempt == self.max_retries - 1:
-                    return await self.fail_item(shared_storage, data, e)
+                    return await self.process_after_fail(shared_storage, data, e)
                 await asyncio.sleep(self.delay_s)
-# ---------------------------------------------------------------------
-# BatchNode: processes multiple items
-# ---------------------------------------------------------------------
 class BatchNode(BaseNode):
-    """
-    Processes a list of items in `process(...)`.
-    The user overrides `process_one(item)`.
-    `_process_one(...)` handles robust retries for each item.
-    """
     def __init__(self, max_retries=5, delay_s=0.1):
         super().__init__()
         self.max_retries = max_retries
         self.delay_s = delay_s
-    async def preprocess(self, shared_storage):
-        """
-        Typically return a list of items to process.
-        """
+    def preprocess(self, shared_storage):
         return []
-    async def process_one(self, shared_storage, item):
-        """
-        End developers override single-item logic here.
-        """
+    def process_one(self, shared_storage, item):
         return None
+    def process_one_after_fail(self, shared_storage, item, exc):
+        print(f"[FAIL_ITEM] item={item}, error={exc}")
+        # By default, just return a "fail" marker. Could be anything you want.
+        return "fail"
     async def _process_one(self, shared_storage, item):
-        """
-        Retry logic around process_one(item).
-        """
         for attempt in range(self.max_retries):
             try:
                 return await self.process_one(shared_storage, item)
             except Exception as e:
                 if attempt == self.max_retries - 1:
-                    print(f"[FAIL_ITEM] item={item}, error={e}")
-                    return "fail"
+                    # If out of retries, let a subclass handle what to do next
+                    return await self.process_one_after_fail(shared_storage, item, e)
                 await asyncio.sleep(self.delay_s)
     async def _process(self, shared_storage, items):
-        """
-        Loops over items, calling _process_one per item.
-        """
         results = []
         for item in items:
             r = await self._process_one(shared_storage, item)
             results.append(r)
         return results
 class Flow(BaseNode):
-    """
-    Runs a sub-flow from `start_node` once per call.
-    """
     def __init__(self, start_node=None):
         super().__init__()
         self.start_node = start_node
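
The retry loops above now hand exhausted failures to overridable hooks (process_after_fail, process_one_after_fail) instead of hard-coding a print-and-return. A rough sketch of how a subclass might use that, with a made-up SquareBatch and values, driving the internal _process hook directly and assuming the classes from this commit are in scope:

    import asyncio

    class SquareBatch(BatchNode):
        def preprocess(self, shared_storage):
            return [1, 2, "oops", 4]
        def process_one(self, shared_storage, item):
            return item * item                  # raises TypeError for "oops"
        def process_one_after_fail(self, shared_storage, item, exc):
            return None                         # swallow the failure instead of "fail"

    async def demo():
        batch = SquareBatch(max_retries=2, delay_s=0)
        items = await batch.preprocess({})      # sync override, auto-wrapped by NodeMeta
        print(await batch._process({}, items))  # [1, 4, None, 16]

    asyncio.run(demo())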
@@ -186,30 +165,18 @@ class Flow(BaseNode):
         if self.start_node:
             current_node = self.start_node
             while current_node:
-                current_node = await current_node.run_one(shared_storage or {})
+                current_node = await current_node._run_one(shared_storage or {})
         return "Flow done"
 class BatchFlow(BaseNode):
-    """
-    For each param_dict in the batch, merges it into self.parameters,
-    then runs the sub-flow from `start_node`.
-    """
     def __init__(self, start_node=None):
         super().__init__()
         self.start_node = start_node
-    async def preprocess(self, shared_storage):
-        """
-        Return a list of param_dict objects.
-        """
+    def preprocess(self, shared_storage):
         return []
-    async def process_one(self, shared_storage, param_dict):
-        """
-        Merge param_dict into the node's parameters,
-        then run the sub-flow.
-        """
+    async def _process_one(self, shared_storage, param_dict):
         node_parameters = self.parameters.copy()
         node_parameters.update(param_dict)
@@ -218,14 +185,11 @@ class BatchFlow(BaseNode):
         while current_node:
             # set the combined parameters
             current_node.set_parameters(node_parameters)
-            current_node = await current_node.run_one(shared_storage or {})
+            current_node = await current_node._run_one(shared_storage or {})
     async def _process(self, shared_storage, items):
-        """
-        For each param_dict in items, run the sub-flow once.
-        """
         results = []
         for param_dict in items:
-            await self.process_one(shared_storage, param_dict)
+            await self._process_one(shared_storage, param_dict)
             results.append(f"Ran sub-flow for param_dict={param_dict}")
         return results
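
BatchFlow._process now delegates to the internal _process_one, which merges each param_dict into a copy of the flow's parameters and pushes the result into every node of the sub-flow through set_parameters (defined elsewhere in this file, not shown in this diff). A hypothetical sketch of the intended call pattern, assuming set_parameters simply stores the merged dict on node.parameters:

    import asyncio

    class ShowParams(Node):
        # Illustrative node that just prints the parameters it was given.
        def process(self, shared_storage, data):
            print("running with", self.parameters)

    flow = BatchFlow(start_node=ShowParams())

    async def demo():
        results = await flow._process({}, [{"user": "a"}, {"user": "b"}])
        print(results)  # ["Ran sub-flow for param_dict={'user': 'a'}", ...]

    asyncio.run(demo())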