# pocketflow/minillmflow/__init__.py — Python (100 lines, 4.3 KiB)

import asyncio, warnings
class BaseNode:
    """Base unit of work with a prep -> exec -> post lifecycle.

    Subclasses override the three hooks:

    * ``prep(shared)`` gathers input from the shared store,
    * ``exec(prep_res)`` does the work,
    * ``post(shared, prep_res, exec_res)`` writes results back and returns an
      action name (``"default"`` unless overridden).

    ``successors`` maps action names to the next node; a parent ``Flow``
    follows them. Wire successors with ``a >> b`` (the ``"default"`` action)
    or ``a - "action" >> b`` (a named action).
    """

    def __init__(self):
        self.params = {}       # per-run parameters, replaced via set_params
        self.successors = {}   # action name -> next node

    def set_params(self, params):
        """Replace (not merge) this node's parameter dict."""
        self.params = params

    def add_successor(self, node, action="default"):
        """Register ``node`` as the successor for ``action`` and return ``node``.

        Returning ``node`` lets ``a >> b >> c`` chain. Re-registering an
        existing action replaces the old successor and emits a warning.
        """
        if action in self.successors:
            warnings.warn(f"Overwriting successor for action '{action}'")
        self.successors[action] = node
        return node

    # --- lifecycle hooks (override in subclasses) ---------------------------
    def prep(self, shared):
        return None

    def exec(self, prep_res):
        return None

    def _exec(self, prep_res):
        # Indirection point: subclasses (retrying / batching nodes) wrap exec here.
        return self.exec(prep_res)

    def post(self, shared, prep_res, exec_res):
        return "default"

    def _run(self, shared):
        """Run prep, exec and post once; return the action chosen by post."""
        prepared = self.prep(shared)
        executed = self._exec(prepared)
        return self.post(shared, prepared, executed)

    def run(self, shared):
        """Run this node standalone. Successors are NOT followed (warns if any)."""
        if self.successors:
            warnings.warn("Node won't run successors. Use a parent Flow instead.")
        return self._run(shared)

    # --- wiring operators ---------------------------------------------------
    def __rshift__(self, other):
        """``a >> b``: make ``b`` the default successor of ``a``."""
        return self.add_successor(other)

    def __sub__(self, action):
        """``a - "act"``: start a named transition, completed by ``>> b``."""
        if not isinstance(action, str):
            raise TypeError("Action must be a string")
        return _ConditionalTransition(self, action)
class _ConditionalTransition:
    """Half-built edge produced by ``node - "action"``.

    Completing it with ``>> target`` registers ``target`` as the source
    node's successor for that action and returns ``target``.
    """

    def __init__(self, src, action):
        self.src = src        # node the transition starts from
        self.action = action  # action name the edge is keyed on

    def __rshift__(self, tgt):
        source, label = self.src, self.action
        return source.add_successor(tgt, label)
class Node(BaseNode):
    """Node whose exec step is attempted up to ``max_retries`` times.

    ``max_retries`` is the total number of attempts (1 = no retry). Values
    below 1 are treated as a single attempt, so ``exec`` always runs at least
    once. After the final failed attempt, ``process_after_fail`` decides the
    outcome; by default it re-raises the exception.
    """

    def __init__(self, max_retries=1):
        super().__init__()
        self.max_retries = max_retries

    def process_after_fail(self, prep_res, exc):
        """Handle the last failure; override to return a fallback result.

        The default implementation re-raises ``exc``.
        """
        raise exc

    def _exec(self, prep_res):
        # Clamp: with max_retries <= 0 the original loop never ran, so exec was
        # silently skipped and None returned.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                return super()._exec(prep_res)
            except Exception as exc:
                if attempt == attempts - 1:
                    return self.process_after_fail(prep_res, exc)
class BatchNode(Node):
    """Node whose exec step runs once per item returned by ``prep``.

    ``prep`` should return an iterable of items (default: empty list); the
    exec result is the list of per-item results, in order.
    """

    def prep(self, shared):
        return []

    def _exec(self, items):
        # super(BatchNode, self) resolves to Node._exec, so each item gets the
        # node's max_retries / process_after_fail handling. The previous
        # super(Node, self) skipped Node and silently disabled retries.
        # A None prep result is treated as "no items".
        return [super(BatchNode, self)._exec(item) for item in (items or [])]
class Flow(BaseNode):
    """Runs a graph of nodes, starting at ``start`` and following successors.

    After each node runs, the action it returns selects the next node from
    its ``successors``. The flow stops when no successor matches.
    """

    def __init__(self, start):
        super().__init__()
        self.start = start

    def get_next_node(self, curr, action):
        """Return the successor of ``curr`` for ``action`` (``None`` → ``"default"``).

        Warns when ``curr`` has successors but none for this action, since
        that usually indicates a typo in an action name.
        """
        nxt = curr.successors.get(action if action is not None else "default")
        if not nxt and curr.successors:
            warnings.warn(f"Flow ends: action '{action}' not found in {list(curr.successors)}")
        return nxt

    def _orchestrate(self, shared, params=None):
        """Walk the node graph once.

        Every visited node receives the same params dict: ``params`` if given
        and non-empty, otherwise a shallow copy of this flow's params.
        """
        curr, p = self.start, (params if params else {**self.params})
        while curr:
            curr.set_params(p)
            curr = self.get_next_node(curr, curr._run(shared))

    def _run(self, shared):
        # prep must run BEFORE the child nodes (prep -> run -> post), matching
        # BatchFlow._run; previously it was called after orchestration.
        prep_res = self.prep(shared)
        self._orchestrate(shared)
        return self.post(shared, prep_res, None)

    def exec(self, prep_res):
        raise RuntimeError("Flow should not exec directly. Create a child Node instead.")
class BatchFlow(Flow):
    """Flow that runs its whole node graph once per parameter dict from ``prep``.

    ``prep`` returns an iterable of dicts (default: empty list). For each pass,
    that dict is layered over the flow's own params.
    """

    def prep(self, shared):
        return []

    def _run(self, shared):
        prep_res = self.prep(shared)
        # A None prep result means "no batches" instead of a TypeError.
        for batch_params in prep_res or []:
            self._orchestrate(shared, {**self.params, **batch_params})
        return self.post(shared, prep_res, None)
class AsyncNode(Node):
    """Node whose post step is a coroutine; drive it with ``run_async``.

    ``prep`` and ``exec`` (including Node's retry handling) still run
    synchronously; only ``post_async`` is awaited. The synchronous
    ``post``/``_run`` entry points deliberately raise.
    """

    def _run(self, shared):
        raise RuntimeError("AsyncNode should run using run_async instead.")

    def post(self, shared, prep_res, exec_res):
        raise RuntimeError("AsyncNode should post using post_async instead.")

    async def post_async(self, shared, prep_res, exec_res):
        """Async counterpart of ``post``; returns ``"default"`` unless overridden."""
        await asyncio.sleep(0)  # yield to the event loop once
        return "default"

    async def _run_async(self, shared):
        prepared = self.prep(shared)
        executed = self._exec(prepared)
        return await self.post_async(shared, prepared, executed)

    async def run_async(self, shared):
        """Run this node standalone. Successors are NOT followed (warns if any)."""
        if self.successors:
            warnings.warn("Node won't run successors. Use a parent AsyncFlow instead.")
        return await self._run_async(shared)
class AsyncFlow(Flow, AsyncNode):
    """Flow that can contain both sync and async nodes; drive it with ``run_async``.

    Nodes exposing ``run_async`` are awaited; all other nodes run synchronously.
    """

    async def _orchestrate_async(self, shared, params=None):
        """Async version of ``Flow._orchestrate``; all nodes share one params dict."""
        curr, p = self.start, (params if params else {**self.params})
        while curr:
            curr.set_params(p)
            if hasattr(curr, "run_async"):
                action = await curr._run_async(shared)
            else:
                action = curr._run(shared)
            curr = self.get_next_node(curr, action)

    async def _run_async(self, shared):
        # prep runs BEFORE the child nodes (prep -> run -> post); previously it
        # was called after orchestration.
        prep_res = self.prep(shared)
        await self._orchestrate_async(shared)
        return await self.post_async(shared, prep_res, None)

    def _run(self, shared):
        # The MRO otherwise resolves _run to Flow._run, which would execute all
        # child nodes synchronously and only then fail in AsyncNode.post.
        # Fail fast instead.
        raise RuntimeError("AsyncFlow should run using run_async instead.")
class BatchAsyncFlow(BatchFlow, AsyncFlow):
    """Async BatchFlow: awaits one orchestration pass per parameter dict from ``prep``."""

    async def _run_async(self, shared):
        prep_res = self.prep(shared)
        # A None prep result means "no batches" instead of a TypeError.
        for batch_params in prep_res or []:
            await self._orchestrate_async(shared, {**self.params, **batch_params})
        return await self.post_async(shared, prep_res, None)

    def _run(self, shared):
        # The MRO otherwise resolves _run to the synchronous BatchFlow._run,
        # which would run every batch and only then fail in AsyncNode.post.
        # Fail fast instead.
        raise RuntimeError("BatchAsyncFlow should run using run_async instead.")