# pocketflow/minillmflow/__init__.py

import asyncio
import warnings

class BaseNode:
    """
    Lifecycle of a node:
    - preprocess(): compute-intensive preparation before the LLM call
    - process(): the LLM call itself; should be idempotent so retries are safe
    - postprocess(): summarize the result and return the condition for the successor node
    """
    def __init__(self):
        self.parameters, self.successors = {}, {}

    def set_parameters(self, params):
        # params must stay immutable during preprocess/process/postprocess
        self.parameters = params

    def add_successor(self, node, condition="default"):
        if condition in self.successors:
            warnings.warn(f"Overwriting existing successor for condition '{condition}'")
        self.successors[condition] = node  # maps condition -> successor node
        return node

    def preprocess(self, shared_storage):
        return None  # result is passed to process() and postprocess()

    def process(self, shared_storage, prep_result):
        return None  # result is passed to postprocess()

    def _process(self, shared_storage, prep_result):
        # hook for retry or other wrapping logic (see Node)
        return self.process(shared_storage, prep_result)

    def postprocess(self, shared_storage, prep_result, proc_result):
        return "default"  # condition selecting the next node

    def run(self, shared_storage=None):
        prep = self.preprocess(shared_storage)
        proc = self._process(shared_storage, prep)
        return self.postprocess(shared_storage, prep, proc)

    def __rshift__(self, other):
        # chaining: node1 >> node2
        return self.add_successor(other)

    def __sub__(self, condition):
        # conditional chaining: node - "some_condition" >> next_node
        if isinstance(condition, str):
            return _ConditionalTransition(self, condition)
        raise TypeError("Condition must be a string")

class _ConditionalTransition:
    def __init__(self, source_node, condition):
        self.source_node = source_node
        self.condition = condition

    def __rshift__(self, target_node):
        return self.source_node.add_successor(target_node, self.condition)
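
# Illustrative sketch (not part of the module): how `>>` and `-` wire nodes
# together. ReviewNode, ApproveNode, and RejectNode are hypothetical subclasses
# that override process()/postprocess().
#
#   review, approve, reject = ReviewNode(), ApproveNode(), RejectNode()
#   review - "approved" >> approve   # taken when postprocess() returns "approved"
#   review - "rejected" >> reject    # taken when postprocess() returns "rejected"
#   review >> approve                # shorthand for review - "default" >> approve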

class Node(BaseNode):
    def __init__(self, max_retries=1):
        super().__init__()
        self.max_retries = max_retries

    def process_after_fail(self, shared_storage, data, exc):
        # fallback hook once all retries are exhausted; re-raises by default
        raise exc

    def _process(self, shared_storage, data):
        for attempt in range(self.max_retries):
            try:
                return super()._process(shared_storage, data)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    return self.process_after_fail(shared_storage, data, e)
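
# Illustrative sketch (not part of the module): a node that retries a flaky
# LLM call. `call_llm` is a hypothetical helper, not provided by this file.
#
#   class SummarizeNode(Node):
#       def process(self, shared_storage, prep_result):
#           return call_llm("Summarize: " + shared_storage["text"])
#
#   node = SummarizeNode(max_retries=3)  # process() is attempted up to 3 times
#   node.run(shared_storage={"text": "..."})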

class BatchNode(Node):
    def preprocess(self, shared_storage):
        # return an iterable of items, one per process() call
        return []

    def process(self, shared_storage, item):
        # called once for each item
        return None

    def _process(self, shared_storage, items):
        results = []
        for item in items:
            # each 'item' takes the place of 'prep_result' from BaseNode's perspective
            r = super()._process(shared_storage, item)
            results.append(r)
        return results
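
# Illustrative sketch (not part of the module): preprocess() yields the items,
# process() handles each one, and run() hands the collected results to
# postprocess(). `call_llm` is again a hypothetical helper.
#
#   class TranslateAll(BatchNode):
#       def preprocess(self, shared_storage):
#           return shared_storage["sentences"]
#       def process(self, shared_storage, item):
#           return call_llm(f"Translate to French: {item}")
#       def postprocess(self, shared_storage, prep_result, proc_result):
#           shared_storage["translations"] = proc_result  # one result per item
#           return "default"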

class AsyncNode(Node):
    """
    A Node whose postprocess step is async.
    process() may also be overridden as a coroutine; run_async() awaits it.
    """
    def postprocess(self, shared_storage, prep_result, proc_result):
        # not used in async workflows; define postprocess_async() instead
        raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")

    async def postprocess_async(self, shared_storage, prep_result, proc_result):
        """
        Async version of postprocess. Returns "default" by default; override as needed.
        """
        await asyncio.sleep(0)  # trivial async pause (no-op)
        return "default"

    async def run_async(self, shared_storage=None):
        prep = self.preprocess(shared_storage)
        proc = self._process(shared_storage, prep)
        if asyncio.iscoroutine(proc):  # support async process() overrides
            proc = await proc
        return await self.postprocess_async(shared_storage, prep, proc)
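
# Illustrative sketch (not part of the module): a fully async node, assuming a
# hypothetical coroutine `call_llm_async`.
#
#   class AsyncSummarize(AsyncNode):
#       async def process(self, shared_storage, prep_result):
#           return await call_llm_async("Summarize: " + shared_storage["text"])
#       async def postprocess_async(self, shared_storage, prep_result, proc_result):
#           shared_storage["summary"] = proc_result
#           return "default"
#
#   # asyncio.run(AsyncSummarize().run_async({"text": "..."}))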

class BaseFlow(BaseNode):
    """
    Abstract base flow that provides the main loop:
    - start from self.start_node
    - follow successors until a node has no matching successor
    Subclasses define how each node is *called* (sync or async).
    """
    def __init__(self, start_node=None):
        super().__init__()
        self.start_node = start_node

    def get_next_node(self, current_node, condition):
        next_node = current_node.successors.get(condition, None)
        if next_node is None and current_node.successors:
            warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
        return next_node

    def run(self, shared_storage=None):
        raise NotImplementedError("BaseFlow.run must be implemented by subclasses (Flow, AsyncFlow)")

    async def run_async(self, shared_storage=None):
        raise NotImplementedError("BaseFlow.run_async must be implemented by subclasses (Flow, AsyncFlow)")

class Flow(BaseFlow):
    """
    Synchronous flow: each node is called with .run(shared_storage).
    """
    def _process_flow(self, shared_storage):
        current_node = self.start_node
        while current_node:
            # pass the Flow's parameters down to the current node
            current_node.set_parameters(self.parameters)
            condition = current_node.run(shared_storage)
            current_node = self.get_next_node(current_node, condition)

    def run(self, shared_storage=None):
        prep_result = self.preprocess(shared_storage)
        self._process_flow(shared_storage)
        return self.postprocess(shared_storage, prep_result, None)
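
# Illustrative sketch (not part of the module): wiring three hypothetical
# nodes into a flow and running it end to end.
#
#   load, summarize, save = LoadNode(), SummarizeNode(), SaveNode()
#   load >> summarize >> save        # default transitions
#   flow = Flow(start_node=load)
#   shared = {}
#   flow.run(shared)                 # nodes communicate via `shared`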

class AsyncFlow(BaseFlow):
    """
    Asynchronous flow: if a node has .run_async, await it; otherwise fall back to .run.
    """
    async def _process_flow_async(self, shared_storage):
        current_node = self.start_node
        while current_node:
            current_node.set_parameters(self.parameters)
            # await async-capable nodes, run the rest synchronously
            if hasattr(current_node, "run_async") and callable(current_node.run_async):
                condition = await current_node.run_async(shared_storage)
            else:
                condition = current_node.run(shared_storage)
            current_node = self.get_next_node(current_node, condition)

    async def run_async(self, shared_storage=None):
        prep_result = self.preprocess(shared_storage)
        await self._process_flow_async(shared_storage)
        return self.postprocess(shared_storage, prep_result, None)

    def run(self, shared_storage=None):
        try:
            return asyncio.run(self.run_async(shared_storage))
        except RuntimeError as e:
            raise RuntimeError("If you are running in Jupyter, please use `await run_async()` instead of `run()`.") from e
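
# Illustrative sketch (not part of the module): mixing sync and async nodes in
# one flow. AsyncSummarize and SaveNode are hypothetical.
#
#   summarize, save = AsyncSummarize(), SaveNode()
#   summarize >> save
#   flow = AsyncFlow(start_node=summarize)
#   flow.run({"text": "..."})        # or `await flow.run_async(...)` in Jupyter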

class BaseBatchFlow(BaseFlow):
    """
    Abstract base for a flow that runs multiple times (a batch),
    once per parameter set or item returned by preprocess().
    """
    def preprocess(self, shared_storage):
        """
        Returns an iterable of parameter dicts (or items)
        for the flow to process as a batch.
        """
        return []

class BatchFlow(BaseBatchFlow, Flow):
    """
    Synchronous batch flow: runs the whole flow once
    for each parameter set from preprocess().
    """
    def run(self, shared_storage=None):
        prep_result = self.preprocess(shared_storage)
        all_results = []
        for param_dict in prep_result:
            # merge this run's params into the Flow's parameters
            original_params = self.parameters.copy()
            self.parameters.update(param_dict)
            # run from the start node to the end
            self._process_flow(shared_storage)
            # optionally collect results from shared_storage or a custom method
            all_results.append(f"Finished run with parameters: {param_dict}")
            # restore the original parameters
            self.parameters = original_params
        return all_results
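
# Illustrative sketch (not part of the module): one flow executed once per
# parameter dict; nodes read self.parameters inside process(). `load` is a
# hypothetical start node.
#
#   class PerFileFlow(BatchFlow):
#       def preprocess(self, shared_storage):
#           return [{"filename": f} for f in shared_storage["files"]]
#
#   flow = PerFileFlow(start_node=load)
#   flow.run({"files": ["a.txt", "b.txt"]})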

class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
    """
    Asynchronous batch flow: runs the whole flow once, asynchronously,
    for each parameter set from preprocess().
    """
    async def run_async(self, shared_storage=None):
        prep_result = self.preprocess(shared_storage)
        all_results = []
        for param_dict in prep_result:
            original_params = self.parameters.copy()
            self.parameters.update(param_dict)
            await self._process_flow_async(shared_storage)
            all_results.append(f"Finished async run with parameters: {param_dict}")
            # restore the original parameters
            self.parameters = original_params
        return all_results
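
# Illustrative sketch (not part of the module): an async batch over URLs, with
# a hypothetical FetchNode (an AsyncNode) as the start node.
#
#   class CrawlFlow(BatchAsyncFlow):
#       def preprocess(self, shared_storage):
#           return [{"url": u} for u in shared_storage["urls"]]
#
#   CrawlFlow(start_node=FetchNode()).run({"urls": ["https://example.com"]})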