# pocketflow/minillmflow/__init__.py

import asyncio


class BaseNode:
    """
    A base node that provides:
      - preprocess()
      - process()
      - postprocess()
      - run() -- just runs itself (no chaining)
    """
    def __init__(self):
        self.parameters = {}
        self.successors = {}

    def set_parameters(self, params):
        self.parameters.update(params)

    def add_successor(self, node, condition="default"):
        if condition in self.successors:
            print(f"Warning: overwriting existing successor for condition '{condition}'")
        self.successors[condition] = node
        return node

    def preprocess(self, shared_storage):
        return None

    def process(self, shared_storage, prep_result):
        return None

    def _process(self, shared_storage, prep_result):
        # Subclasses may wrap this with retry or other logic.
        return self.process(shared_storage, prep_result)

    def postprocess(self, shared_storage, prep_result, proc_result):
        return "default"

    def run(self, shared_storage=None):
        prep = self.preprocess(shared_storage)
        proc = self._process(shared_storage, prep)
        return self.postprocess(shared_storage, prep, proc)
    def __rshift__(self, other):
        """
        For chaining with the >> operator, e.g. node1 >> node2
        """
        return self.add_successor(other)

    def __gt__(self, other):
        """
        For conditional chaining with the > operator. Because >> binds
        tighter than >, the comparison must be parenthesized:
        (node1 > "some_condition") >> node2
        """
        if isinstance(other, str):
            return _ConditionalTransition(self, other)
        elif isinstance(other, BaseNode):
            return self.add_successor(other)
        raise TypeError("Unsupported operand type")

    def __call__(self, condition):
        """
        For node("condition") >> next_node syntax
        """
        return _ConditionalTransition(self, condition)


class _ConditionalTransition:
    """
    Helper for the (node > 'condition') >> another_node style
    """
    def __init__(self, source_node, condition):
        self.source_node = source_node
        self.condition = condition

    def __rshift__(self, target_node):
        return self.source_node.add_successor(target_node, self.condition)
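

# Example of the chaining syntax (a sketch; NodeA, NodeB, and NodeC are
# hypothetical BaseNode subclasses):
#
#     a, b, c = NodeA(), NodeB(), NodeC()
#     a >> b                    # transition on the "default" condition
#     (a > "retry") >> c        # conditional transition (parens required)
#     a("retry") >> c           # equivalent call-based syntax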


# A node with a robust (retrying) process step.
class Node(BaseNode):
    def __init__(self, max_retries=1):
        super().__init__()
        self.max_retries = max_retries

    def process_after_fail(self, shared_storage, data, exc):
        # Default: re-raise. Override to return a condition string
        # (e.g. "fail") and route to a fallback successor instead.
        raise exc

    def _process(self, shared_storage, data):
        for attempt in range(self.max_retries):
            try:
                return super()._process(shared_storage, data)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    return self.process_after_fail(shared_storage, data, e)
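

# Example of a retrying node (a sketch; FlakyNode and call_unreliable_api
# are hypothetical):
#
#     class FlakyNode(Node):
#         def process(self, shared_storage, prep_result):
#             return call_unreliable_api()    # may raise; retried up to 3 times
#         def process_after_fail(self, shared_storage, data, exc):
#             return "fail"                   # route to a fallback successor
#
#     node = FlakyNode(max_retries=3)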


class Flow(BaseNode):
    def __init__(self, start_node=None):
        super().__init__()  # without this, parameters/successors are never set
        self.start_node = start_node

    def _process(self, shared_storage, _):
        # Run each node, then follow the successor matching the
        # condition it returned, until no successor exists.
        current_node = self.start_node
        while current_node:
            condition = current_node.run(shared_storage)
            current_node = current_node.successors.get(condition, None)

    def postprocess(self, shared_storage, prep_result, proc_result):
        return None
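

# Example of wiring and running a flow (a sketch; LoadNode and
# SummarizeNode are hypothetical Node subclasses):
#
#     load, summarize = LoadNode(), SummarizeNode()
#     load >> summarize
#     Flow(start_node=load).run(shared_storage={})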


class AsyncNode(Node):
    """
    A Node whose postprocess step is async.
    You can also override process() to be async if needed.
    """
    async def postprocess_async(self, shared_storage, prep_result, proc_result):
        """
        Async version of postprocess. By default, returns "default".
        Override as needed.
        """
        await asyncio.sleep(0)  # trivial async pause (no-op)
        return "default"

    async def run_async(self, shared_storage=None):
        """
        Async version of run.
        If your process method is also async, you'll need to adapt accordingly.
        """
        # preprocess and process are kept synchronous here for simplicity;
        # either can be made async if your usage requires it.
        prep = self.preprocess(shared_storage)
        proc = self._process(shared_storage, prep)
        # postprocess is async
        return await self.postprocess_async(shared_storage, prep, proc)
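

# Example of an async node (a sketch; NotifyNode and send_notification
# are hypothetical):
#
#     class NotifyNode(AsyncNode):
#         async def postprocess_async(self, shared_storage, prep_result, proc_result):
#             await send_notification(proc_result)    # any awaitable
#             return "default"
#
#     asyncio.run(NotifyNode().run_async({}))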


class AsyncFlow(Flow):
    """
    A Flow that can handle a mixture of sync and async nodes.
    If the node is an AsyncNode, calls `run_async`;
    otherwise, calls `run`. Drive it with run_async(), not run().
    """
    async def _process(self, shared_storage, _):
        current_node = self.start_node
        while current_node:
            if hasattr(current_node, "run_async") and callable(current_node.run_async):
                # An async node: await its run_async
                condition = await current_node.run_async(shared_storage)
            else:
                # Otherwise, assume it's a sync node
                condition = current_node.run(shared_storage)
            current_node = current_node.successors.get(condition, None)

    async def run_async(self, shared_storage=None):
        """
        Kicks off the async flow. Similar to Flow.run,
        but uses our async _process method.
        """
        prep = self.preprocess(shared_storage)
        # Flows don't need a meaningful process step:
        # the "process" is the iteration through the nodes.
        await self._process(shared_storage, prep)
        return self.postprocess(shared_storage, prep, None)
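

# Example of a mixed sync/async flow (a sketch; FetchNode and NotifyNode
# are hypothetical):
#
#     fetch = FetchNode()      # a plain (sync) Node
#     notify = NotifyNode()    # an AsyncNode
#     fetch >> notify
#     asyncio.run(AsyncFlow(start_node=fetch).run_async({}))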


class BatchNode(BaseNode):
    def __init__(self, max_retries=5, delay_s=0.1):
        super().__init__()
        self.max_retries = max_retries
        self.delay_s = delay_s

    def preprocess(self, shared_storage):
        return []

    async def process_one(self, shared_storage, item):
        # Async, since _process_one awaits it. Override per item.
        return None

    async def process_one_after_fail(self, shared_storage, item, exc):
        print(f"[FAIL_ITEM] item={item}, error={exc}")
        # By default, just return a "fail" marker. Could be anything you want.
        return "fail"

    async def _process_one(self, shared_storage, item):
        for attempt in range(self.max_retries):
            try:
                return await self.process_one(shared_storage, item)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    # If out of retries, let a subclass handle what to do next
                    return await self.process_one_after_fail(shared_storage, item, e)
                await asyncio.sleep(self.delay_s)

    async def _process(self, shared_storage, items):
        results = []
        for item in items:
            r = await self._process_one(shared_storage, item)
            results.append(r)
        return results

    async def run_async(self, shared_storage=None):
        # _process is async here, so batch nodes are driven via run_async
        # (the inherited sync run() would not await it).
        prep = self.preprocess(shared_storage)
        proc = await self._process(shared_storage, prep)
        return self.postprocess(shared_storage, prep, proc)
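

# Example of a batch node (a sketch; UploadNode and upload are hypothetical):
#
#     class UploadNode(BatchNode):
#         def preprocess(self, shared_storage):
#             return shared_storage["files"]    # the items to process
#         async def process_one(self, shared_storage, item):
#             return await upload(item)         # may raise; retried with delay
#
#     results = asyncio.run(UploadNode(max_retries=3).run_async({"files": [...]}))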


class BatchFlow(BaseNode):
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    def preprocess(self, shared_storage):
        return []

    async def _process_one(self, shared_storage, param_dict):
        node_parameters = self.parameters.copy()
        node_parameters.update(param_dict)
        current_node = self.start_node
        while current_node:
            # Set the combined parameters, run the node, then follow
            # the successor matching the returned condition.
            current_node.set_parameters(node_parameters)
            condition = current_node.run(shared_storage or {})
            current_node = current_node.successors.get(condition, None)

    async def _process(self, shared_storage, items):
        results = []
        for param_dict in items:
            await self._process_one(shared_storage, param_dict)
            results.append(f"Ran sub-flow for param_dict={param_dict}")
        return results

    async def run_async(self, shared_storage=None):
        # Like BatchNode, a BatchFlow is driven via run_async.
        prep = self.preprocess(shared_storage)
        proc = await self._process(shared_storage, prep)
        return self.postprocess(shared_storage, prep, proc)
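

# Example of a batch flow (a sketch; PerUserFlow and load_user_node are
# hypothetical):
#
#     class PerUserFlow(BatchFlow):
#         def preprocess(self, shared_storage):
#             # one parameter dict per sub-flow run
#             return [{"user_id": uid} for uid in shared_storage["users"]]
#
#     asyncio.run(PerUserFlow(start_node=load_user_node).run_async({"users": [1, 2, 3]}))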