pocketflow/minillmflow/__init__.py


import asyncio

# ---------------------------------------------------------------------
# BaseNode (no duplication of run_one)
# ---------------------------------------------------------------------
class BaseNode:
    def __init__(self):
        self.parameters = {}
        self.successors = {}

    def set_parameters(self, parameters):
        """
        Replace this node's parameters. Needed by BatchFlow, which pushes
        a merged param_dict into each sub-flow node before running it.
        """
        self.parameters = parameters
        return self

    def add_successor(self, node, condition="default"):
        self.successors[condition] = node
        return node

    async def preprocess(self, shared_storage):
        """
        Override if needed to load or prepare data.
        """
        return None

    async def process(self, shared_storage, data):
        """
        Public method for user logic. Default does nothing.
        """
        return None

    async def _process(self, shared_storage, data):
        """
        Internal hook that calls `process(...)`.
        Subclasses override this to add extra logic (e.g. retries).
        """
        return await self.process(shared_storage, data)

    async def postprocess(self, shared_storage, preprocess_result, process_result):
        """
        By default, returns "default" to pick the default successor.
        """
        return "default"

    async def run_one(self, shared_storage):
        """
        One cycle of the node:
          1) preprocess
          2) _process
          3) postprocess
          4) pick successor
        """
        preprocess_result = await self.preprocess(shared_storage)
        process_result = await self._process(shared_storage, preprocess_result)
        condition = await self.postprocess(shared_storage, preprocess_result, process_result)
        if not self.successors:
            return None
        if len(self.successors) == 1:
            return next(iter(self.successors.values()))
        return self.successors.get(condition)

    def run(self, shared_storage=None):
        return asyncio.run(self.run_async(shared_storage))

    async def run_async(self, shared_storage=None):
        shared_storage = shared_storage or {}
        current_node = self
        while current_node:
            current_node = await current_node.run_one(shared_storage)

    # Syntactic sugar for chaining
    def __rshift__(self, other):
        return self.add_successor(other)

    def __gt__(self, other):
        """
        For branching: (node > "condition") > another_node
        The parentheses are required: Python chains bare comparisons,
        so node > "condition" > another_node would not dispatch here.
        """
        if isinstance(other, str):
            return _ConditionalTransition(self, other)
        elif isinstance(other, BaseNode):
            return self.add_successor(other)
        raise TypeError("Unsupported operand type")

    def __call__(self, condition):
        return _ConditionalTransition(self, condition)


class _ConditionalTransition:
    def __init__(self, source_node, condition):
        self.source_node = source_node
        self.condition = condition

    def __gt__(self, target_node):
        if not isinstance(target_node, BaseNode):
            raise TypeError("Target must be a BaseNode")
        return self.source_node.add_successor(target_node, self.condition)
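
# Example wiring (illustrative sketch; NodeA/NodeB/NodeC are hypothetical
# subclasses, not part of this module):
#
#     node_a, node_b, node_c = NodeA(), NodeB(), NodeC()
#     node_a >> node_b                # default transition
#     node_a("fail") > node_c         # conditional transition via __call__
#     node_a.run({})                  # walks a -> b, or a -> c on "fail"
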
class Node(BaseNode):
    """
    Single-item node with robust retry logic.
    End devs override `process(...)`.
    `_process(...)` adds the retry logic.
    """
    def __init__(self, max_retries=5, delay_s=0.1):
        super().__init__()
        self.max_retries = max_retries
        self.delay_s = delay_s

    async def fail_item(self, shared_storage, data, exc):
        """
        Called if we exhaust all retries.
        """
        print(f"[FAIL_ITEM] data={data}, error={exc}")
        return "fail"

    async def _process(self, shared_storage, data):
        """
        Wraps the user's `process(...)` with retry logic.
        """
        for attempt in range(self.max_retries):
            try:
                return await super()._process(shared_storage, data)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    return await self.fail_item(shared_storage, data, e)
                await asyncio.sleep(self.delay_s)
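
# Example (illustrative sketch; FetchNode and flaky_fetch are hypothetical,
# assuming some awaitable that may raise on transient errors):
#
#     class FetchNode(Node):
#         async def process(self, shared_storage, data):
#             shared_storage["result"] = await flaky_fetch()
#
#     # Retries up to 5 times, sleeping 0.1 s between attempts,
#     # then falls back to fail_item(...).
#     FetchNode(max_retries=5, delay_s=0.1).run({})
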
# ---------------------------------------------------------------------
# BatchNode: processes multiple items
# ---------------------------------------------------------------------
class BatchNode(BaseNode):
    """
    Processes a list of items in `_process(...)`.
    The user overrides `process_one(item)`.
    `_process_one(...)` handles robust retries for each item.
    """
    def __init__(self, max_retries=5, delay_s=0.1):
        super().__init__()
        self.max_retries = max_retries
        self.delay_s = delay_s

    async def preprocess(self, shared_storage):
        """
        Typically return a list of items to process.
        """
        return []

    async def process_one(self, shared_storage, item):
        """
        End developers override single-item logic here.
        """
        return None

    async def _process_one(self, shared_storage, item):
        """
        Retry logic around process_one(item).
        """
        for attempt in range(self.max_retries):
            try:
                return await self.process_one(shared_storage, item)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    print(f"[FAIL_ITEM] item={item}, error={e}")
                    return "fail"
                await asyncio.sleep(self.delay_s)

    async def _process(self, shared_storage, items):
        """
        Loops over items, calling _process_one per item.
        """
        results = []
        for item in items:
            r = await self._process_one(shared_storage, item)
            results.append(r)
        return results
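
# Example (illustrative sketch; SquareBatch is a hypothetical subclass):
#
#     class SquareBatch(BatchNode):
#         async def preprocess(self, shared_storage):
#             return [1, 2, 3]
#         async def process_one(self, shared_storage, item):
#             return item * item
#         async def postprocess(self, shared_storage, items, results):
#             shared_storage["squares"] = results   # [1, 4, 9]
#             return "default"
#
#     SquareBatch().run({})
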
class Flow(BaseNode):
    """
    Runs a sub-flow from `start_node` once per call.
    """
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    async def _process(self, shared_storage, _):
        # Normalize once, so every node in the sub-flow mutates the
        # same dict (a fresh dict per iteration would lose state).
        shared_storage = shared_storage or {}
        if self.start_node:
            current_node = self.start_node
            while current_node:
                current_node = await current_node.run_one(shared_storage)
        return "Flow done"
class BatchFlow(BaseNode):
    """
    For each param_dict in the batch, merges it into self.parameters,
    then runs the sub-flow from `start_node`.
    """
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    async def preprocess(self, shared_storage):
        """
        Return a list of param_dict objects.
        """
        return []

    async def process_one(self, shared_storage, param_dict):
        """
        Merge param_dict into the node's parameters,
        then run the sub-flow.
        """
        node_parameters = self.parameters.copy()
        node_parameters.update(param_dict)
        shared_storage = shared_storage or {}
        if self.start_node:
            current_node = self.start_node
            while current_node:
                # push the combined parameters into each node before it runs
                current_node.set_parameters(node_parameters)
                current_node = await current_node.run_one(shared_storage)

    async def _process(self, shared_storage, items):
        """
        For each param_dict in items, run the sub-flow once.
        """
        results = []
        for param_dict in items:
            await self.process_one(shared_storage, param_dict)
            results.append(f"Ran sub-flow for param_dict={param_dict}")
        return results
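
# Example (illustrative sketch; RenderReport and ReportBatch are
# hypothetical subclasses; RenderReport reads self.parameters, which
# BatchFlow pushes in via set_parameters before each run):
#
#     class RenderReport(Node):
#         async def process(self, shared_storage, data):
#             print("rendering for user", self.parameters["user_id"])
#
#     class ReportBatch(BatchFlow):
#         async def preprocess(self, shared_storage):
#             return [{"user_id": 1}, {"user_id": 2}]
#
#     ReportBatch(start_node=RenderReport()).run({})
#     # runs the sub-flow once per param_dict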