import asyncio


def _wrap_async(fn):
    """
    Given a synchronous function fn, return a coroutine (async function)
    that simply calls fn and returns its result.
    """
    async def _async_wrapper(self, *args, **kwargs):
        return fn(self, *args, **kwargs)
    return _async_wrapper


class NodeMeta(type):
    """
    Metaclass that converts certain methods into async if they are not already.
    """
    def __new__(mcs, name, bases, attrs):
        # Add ANY method names you want to auto-wrap here:
        methods_to_wrap = (
            "preprocess",
            "process",
            "postprocess",
            "process_after_fail",
            "process_one",
            "process_one_after_fail",
        )

        for attr_name in methods_to_wrap:
            if attr_name in attrs:
                # If it's not already a coroutine function, wrap it
                if not asyncio.iscoroutinefunction(attrs[attr_name]):
                    old_fn = attrs[attr_name]
                    attrs[attr_name] = _wrap_async(old_fn)

        return super().__new__(mcs, name, bases, attrs)
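

# Illustration (hypothetical subclass, shown as comments only): on any class
# built with NodeMeta, a plain `def process(...)` is swapped for a coroutine
# function at class-creation time, so the engine can always `await` it:
#
#     class Echo(BaseNode):              # BaseNode uses NodeMeta (below)
#         def process(self, shared_storage, data):
#             return data
#
#     asyncio.iscoroutinefunction(Echo.process)  # True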


class BaseNode(metaclass=NodeMeta):
    def __init__(self):
        self.parameters = {}
        self.successors = {}

    def set_parameters(self, parameters):
        # Minimal setter used by flows (e.g. BatchFlow below) to push
        # per-run parameters into each node before it executes.
        self.parameters = dict(parameters)

    # These are written as plain methods, but NodeMeta wraps them into
    # coroutines at class creation. If a subclass overrides them with
    # non-async definitions, they'll get wrapped automatically too.
    def preprocess(self, shared_storage):
        return None

    def process(self, shared_storage, data):
        return None

    def postprocess(self, shared_storage, preprocess_result, process_result):
        return "default"

    async def _process(self, shared_storage, data):
        return await self.process(shared_storage, data)

    async def _run_one(self, shared_storage):
        preprocess_result = await self.preprocess(shared_storage)
        process_result = await self._process(shared_storage, preprocess_result)
        condition = await self.postprocess(shared_storage, preprocess_result, process_result)

        if not self.successors:
            return None
        if len(self.successors) == 1:
            return next(iter(self.successors.values()))
        return self.successors.get(condition)

    def run(self, shared_storage=None):
        return asyncio.run(self.run_async(shared_storage))

    async def run_async(self, shared_storage=None):
        # Test against None explicitly: `shared_storage or {}` would discard
        # a caller-supplied empty dict (it is falsy), so nodes would stop
        # sharing the caller's storage.
        shared_storage = shared_storage if shared_storage is not None else {}
        current_node = self
        while current_node:
            current_node = await current_node._run_one(shared_storage)

    # Syntactic sugar for chaining
    def add_successor(self, node, condition="default"):
        self.successors[condition] = node
        return node

    def __rshift__(self, other):
        return self.add_successor(other)

    def __gt__(self, other):
        if isinstance(other, str):
            return _ConditionalTransition(self, other)
        elif isinstance(other, BaseNode):
            return self.add_successor(other)
        raise TypeError("Unsupported operand type")

    def __call__(self, condition):
        return _ConditionalTransition(self, condition)
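

# Chaining sketch (node names are illustrative, not defined in this module):
#
#     a >> b             # default transition: same as a.add_successor(b)
#     (a > "retry") > b  # conditional transition via __gt__ and the helper below
#     a("retry") > b     # same thing via __call__; note the unparenthesized
#                        # form a > "retry" > b is a Python chained comparison,
#                        # which also evaluates "retry" > b and raises TypeError.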


class _ConditionalTransition:
    def __init__(self, source_node, condition):
        self.source_node = source_node
        self.condition = condition

    def __gt__(self, target_node):
        if not isinstance(target_node, BaseNode):
            raise TypeError("Target must be a BaseNode")
        return self.source_node.add_successor(target_node, self.condition)


class Node(BaseNode):
    def __init__(self, max_retries=5, delay_s=0.1):
        super().__init__()
        self.max_retries = max_retries
        self.delay_s = delay_s

    def process_after_fail(self, shared_storage, data, exc):
        print(f"[FAIL_ITEM] data={data}, error={exc}")
        return "fail"

    async def _process(self, shared_storage, data):
        # Retry the wrapped process() call with a fixed delay between
        # attempts; on the final failure, delegate to process_after_fail().
        for attempt in range(self.max_retries):
            try:
                return await super()._process(shared_storage, data)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    return await self.process_after_fail(shared_storage, data, e)
                await asyncio.sleep(self.delay_s)
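

# Retry sketch (hypothetical subclass): a failing process() is retried up to
# max_retries times, then process_after_fail() decides the outcome:
#
#     class FlakyNode(Node):
#         def process(self, shared_storage, data):
#             raise RuntimeError("transient failure")
#
#     FlakyNode(max_retries=2, delay_s=0).run()  # prints [FAIL_ITEM] ...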


class BatchNode(BaseNode):
    def __init__(self, max_retries=5, delay_s=0.1):
        super().__init__()
        self.max_retries = max_retries
        self.delay_s = delay_s

    def preprocess(self, shared_storage):
        # Return the list of items to process; subclasses override this.
        return []

    def process_one(self, shared_storage, item):
        return None

    def process_one_after_fail(self, shared_storage, item, exc):
        print(f"[FAIL_ITEM] item={item}, error={exc}")
        # By default, just return a "fail" marker. Could be anything you want.
        return "fail"

    async def _process_one(self, shared_storage, item):
        for attempt in range(self.max_retries):
            try:
                return await self.process_one(shared_storage, item)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    # If out of retries, let a subclass handle what to do next
                    return await self.process_one_after_fail(shared_storage, item, e)
                await asyncio.sleep(self.delay_s)

    async def _process(self, shared_storage, items):
        results = []
        for item in items:
            r = await self._process_one(shared_storage, item)
            results.append(r)
        return results
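

# Note: BatchNode handles items strictly in order. If items are independent,
# a variant could run them concurrently (an alternative, not what the class
# above does):
#
#     results = await asyncio.gather(
#         *(self._process_one(shared_storage, item) for item in items)
#     )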


class Flow(BaseNode):
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    async def _process(self, shared_storage, _):
        # Walk the sub-graph starting at start_node. Pass shared_storage
        # through unchanged so every sub-node mutates the same dict
        # (`shared_storage or {}` would fork a fresh dict per step whenever
        # the storage is still empty, since an empty dict is falsy).
        if self.start_node:
            current_node = self.start_node
            while current_node:
                current_node = await current_node._run_one(shared_storage)
        return "Flow done"


class BatchFlow(BaseNode):
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    def preprocess(self, shared_storage):
        # Return the list of parameter dicts, one per sub-flow run.
        return []

    async def _process_one(self, shared_storage, param_dict):
        node_parameters = self.parameters.copy()
        node_parameters.update(param_dict)

        if self.start_node:
            current_node = self.start_node
            while current_node:
                # Set the combined parameters on each node before running it,
                # and reuse the same shared_storage dict across all sub-nodes.
                current_node.set_parameters(node_parameters)
                current_node = await current_node._run_one(shared_storage)

    async def _process(self, shared_storage, items):
        results = []
        for param_dict in items:
            await self._process_one(shared_storage, param_dict)
            results.append(f"Ran sub-flow for param_dict={param_dict}")
        return results
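

# ---------------------------------------------------------------------------
# Minimal usage sketch. The node classes below (GreetNode, ShoutNode,
# SquareBatch) are hypothetical demo code, not part of the library above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class GreetNode(Node):
        def process(self, shared_storage, data):
            shared_storage["greeting"] = "hello"

    class ShoutNode(Node):
        def process(self, shared_storage, data):
            shared_storage["greeting"] = shared_storage["greeting"].upper()

    class SquareBatch(BatchNode):
        def preprocess(self, shared_storage):
            return [1, 2, 3]

        def process_one(self, shared_storage, item):
            return item * item

        def postprocess(self, shared_storage, preprocess_result, process_result):
            # Stash the per-item results so the caller can inspect them.
            shared_storage["squares"] = process_result
            return "default"

    # Chain two nodes with >> and run them inside a Flow; both nodes
    # read and write the same storage dict.
    start = GreetNode()
    start >> ShoutNode()
    storage = {}
    Flow(start).run(storage)
    print(storage["greeting"])  # HELLO

    # Run a batch node on its own; each item goes through process_one().
    SquareBatch().run(storage)
    print(storage["squares"])  # [1, 4, 9]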