pocketflow/minillmflow/__init__.py

184 lines
7.4 KiB
Python

import asyncio
import warnings
class BaseNode:
# preprocess(): this is for compute intensive preparation tasks, before the LLM call
# process(): this is for the LLM call, and should be idempotent for retries
# postprocess(): this is to summarize the result and retrun the condition for the successor node
def __init__(self):
self.params, self.successors = {}, {}
def set_params(self, params): # make sure params is immutable
self.params = params # must be immutable during pre/post/process
def add_successor(self, node, condition="default"):
if condition in self.successors:
warnings.warn(f"Overwriting existing successor for condition '{condition}'")
self.successors[condition] = node # maps condition -> successor node
return node
def preprocess(self, shared_storage):
return None # will be passed to process() and postprocess()
def process(self, shared_storage, prep_result):
return None # will be passed to postprocess()
def _process(self, shared_storage, prep_result):
# Could have retry logic or other wrap logic
return self.process(shared_storage, prep_result)
def postprocess(self, shared_storage, prep_result, proc_result):
return "default" # condition for next node
def _run(self, shared_storage):
prep_result = self.preprocess(shared_storage)
proc_result = self._process(shared_storage, prep_result)
return self.postprocess(shared_storage, prep_result, proc_result)
def run(self, shared_storage):
if self.successors:
warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent Flow and use that Flow.run() instead.")
return self._run(shared_storage)
def __rshift__(self, other):
# chaining: node1 >> node2
return self.add_successor(other)
def __sub__(self, condition):
# condition-based chaining: node - "some_condition" >> next_node
if isinstance(condition, str):
return _ConditionalTransition(self, condition)
raise TypeError("Condition must be a string")
class _ConditionalTransition:
def __init__(self, source_node, condition):
self.source_node = source_node
self.condition = condition
def __rshift__(self, target_node):
return self.source_node.add_successor(target_node, self.condition)
class Node(BaseNode):
def __init__(self, max_retries=1):
super().__init__()
self.max_retries = max_retries
def process_after_fail(self, shared_storage, data, exc):
raise exc
def _process(self, shared_storage, data):
for attempt in range(self.max_retries):
try:
return super()._process(shared_storage, data)
except Exception as e:
if attempt == self.max_retries - 1:
return self.process_after_fail(shared_storage, data, e)
class BatchNode(Node):
def preprocess(self, shared_storage):
# return an iterable of items, one for each run
return []
def process(self, shared_storage, item): # process() is called for each item
return None
def _process(self, shared_storage, items):
results = []
for item in items:
# Here, 'item' is passed in place of 'prep_result' from the BaseNode's perspective.
r = super()._process(shared_storage, item)
results.append(r)
return results
class AsyncNode(Node):
def postprocess(self, shared_storage, prep_result, proc_result):
raise NotImplementedError("AsyncNode requires postprocess_async, and should be run in an AsyncFlow")
async def postprocess_async(self, shared_storage, prep_result, proc_result):
await asyncio.sleep(0) # trivial async pause (no-op)
return "default"
async def run_async(self, shared_storage):
if self.successors:
warnings.warn("This node has successor nodes. To run its successors, wrap this node in a parent AsyncFlow and use that AsyncFlow.run_async() instead.")
return await self._run_async(shared_storage)
async def _run_async(self, shared_storage):
prep_result = self.preprocess(shared_storage)
proc_result = self._process(shared_storage, prep_result)
return await self.postprocess_async(shared_storage, prep_result, proc_result)
def _run(self, shared_storage):
raise RuntimeError("AsyncNode requires asynchronous execution. Use 'await node.run_async()' if inside an async function, or 'asyncio.run(node.run_async())' if in synchronous code.")
class BaseFlow(BaseNode):
def __init__(self, start_node):
super().__init__()
self.start_node = start_node
def get_next_node(self, current_node, condition):
next_node = current_node.successors.get(condition, None)
if next_node is None and current_node.successors:
warnings.warn(f"Flow will end. Condition '{condition}' not found among possible conditions: {list(current_node.successors.keys())}")
return next_node
class Flow(BaseFlow):
def _process(self, shared_storage, params=None):
current_node = self.start_node
params = params if params is not None else self.params.copy()
while current_node:
current_node.set_params(params)
condition = current_node._run(shared_storage)
current_node = self.get_next_node(current_node, condition)
def process(self, shared_storage, prep_result):
raise NotImplementedError("Flow should not process directly")
class AsyncFlow(BaseFlow, AsyncNode):
async def _process_async(self, shared_storage, params=None):
current_node = self.start_node
params = params if params is not None else self.params.copy()
while current_node:
current_node.set_params(params)
if hasattr(current_node, "run_async") and callable(current_node.run_async):
condition = await current_node._run_async(shared_storage)
else:
condition = current_node._run(shared_storage)
current_node = self.get_next_node(current_node, condition)
async def _run_async(self, shared_storage):
prep_result = self.preprocess(shared_storage)
await self._process_async(shared_storage)
return await self.postprocess_async(shared_storage, prep_result, None)
class BaseBatchFlow(BaseFlow):
def preprocess(self, shared_storage):
return [] # return an iterable of parameter dictionaries
class BatchFlow(BaseBatchFlow, Flow):
def _run(self, shared_storage):
prep_result = self.preprocess(shared_storage)
for param_dict in prep_result:
merged_params = self.params.copy()
merged_params.update(param_dict)
self._process(shared_storage, params=merged_params)
return self.postprocess(shared_storage, prep_result, None)
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, shared_storage):
prep_result = self.preprocess(shared_storage)
for param_dict in prep_result:
merged_params = self.params.copy()
merged_params.update(param_dict)
await self._process_async(shared_storage, params=merged_params)
return await self.postprocess_async(shared_storage, prep_result, None)