From 3c329082127df6bf56c6d7aa225af95b22b13aed Mon Sep 17 00:00:00 2001 From: zachary62 Date: Fri, 11 Apr 2025 14:12:55 -0400 Subject: [PATCH] update design --- pocketflow/__init__.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/pocketflow/__init__.py b/pocketflow/__init__.py index 27a76e1..6f4f533 100644 --- a/pocketflow/__init__.py +++ b/pocketflow/__init__.py @@ -3,28 +3,28 @@ import asyncio, warnings, copy, time class BaseNode: def __init__(self): self.params,self.successors={},{} def set_params(self,params): self.params=params - def add_successor(self,node,action="default"): + def next(self,node,action="default"): if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'") - self.successors[action]=node;return node + self.successors[action]=node; return node def prep(self,shared): pass def exec(self,prep_res): pass def post(self,shared,prep_res,exec_res): pass def _exec(self,prep_res): return self.exec(prep_res) - def _run(self,shared): p=self.prep(shared);e=self._exec(p);return self.post(shared,p,e) + def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e) def run(self,shared): if self.successors: warnings.warn("Node won't run successors. Use Flow.") return self._run(shared) - def __rshift__(self,other): return self.add_successor(other) + def __rshift__(self,other): return self.next(other) def __sub__(self,action): if isinstance(action,str): return _ConditionalTransition(self,action) raise TypeError("Action must be a string") class _ConditionalTransition: def __init__(self,src,action): self.src,self.action=src,action - def __rshift__(self,tgt): return self.src.add_successor(tgt,self.action) + def __rshift__(self,tgt): return self.src.next(tgt,self.action) class Node(BaseNode): - def __init__(self,max_retries=1,wait=0): super().__init__();self.max_retries,self.wait=max_retries,wait + def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait def exec_fallback(self,prep_res,exc): raise exc def _exec(self,prep_res): for self.cur_retry in range(self.max_retries): @@ -37,16 +37,18 @@ class BatchNode(Node): def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])] class Flow(BaseNode): - def __init__(self,start): super().__init__();self.start=start + def __init__(self,start): super().__init__(); self.start=start + def start(self,start): self.start=start; return start def get_next_node(self,curr,action): nxt=curr.successors.get(action or "default") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") return nxt def _orch(self,shared,params=None): - curr,p=copy.copy(self.start),(params or {**self.params}) - while curr: curr.set_params(p);c=curr._run(shared);curr=copy.copy(self.get_next_node(curr,c)) - def _run(self,shared): pr=self.prep(shared);self._orch(shared);return self.post(shared,pr,None) - def exec(self,prep_res): raise RuntimeError("Flow can't exec.") + curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None + while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) + return last_action + def _run(self,shared): pr=self.prep(shared); self._orch(shared); return self.post(shared,pr,None) + def post(self,shared,prep_res,exec_res): return exec_res class BatchFlow(Flow): def _run(self,shared): @@ -55,11 +57,6 @@ class BatchFlow(Flow): return self.post(shared,pr,None) class AsyncNode(Node): - def prep(self,shared): raise RuntimeError("Use prep_async.") - def exec(self,prep_res): raise RuntimeError("Use exec_async.") - def post(self,shared,prep_res,exec_res): raise RuntimeError("Use post_async.") - def exec_fallback(self,prep_res,exc): raise RuntimeError("Use exec_fallback_async.") - def _run(self,shared): raise RuntimeError("Use run_async.") async def prep_async(self,shared): pass async def exec_async(self,prep_res): pass async def exec_fallback_async(self,prep_res,exc): raise exc @@ -73,7 +70,8 @@ class AsyncNode(Node): async def run_async(self,shared): if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") return await self._run_async(shared) - async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e) + async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e) + def _run(self,shared): raise RuntimeError("Use run_async.") class AsyncBatchNode(AsyncNode,BatchNode): async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] @@ -83,9 +81,11 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode): class AsyncFlow(Flow,AsyncNode): async def _orch_async(self,shared,params=None): - curr,p=copy.copy(self.start),(params or {**self.params}) - while curr:curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared);curr=copy.copy(self.get_next_node(curr,c)) - async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None) + curr,p,last_action =copy.copy(self.start),(params or {**self.params}),None + while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) + return last_action + async def _run_async(self,shared): p=await self.prep_async(shared); await self._orch_async(shared); return await self.post_async(shared,p,None) + async def post_async(self,shared,prep_res,exec_res): return exec_res class AsyncBatchFlow(AsyncFlow,BatchFlow): async def _run_async(self,shared): @@ -94,7 +94,7 @@ class AsyncBatchFlow(AsyncFlow,BatchFlow): return await self.post_async(shared,pr,None) class AsyncParallelBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared): + async def _run_async(self,shared): pr=await self.prep_async(shared) or [] await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr)) return await self.post_async(shared,pr,None) \ No newline at end of file