diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index 58e0358..a2b5fb3 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -39,8 +39,7 @@ class Flow(BaseNode): def __init__(self,start): super().__init__();self.start=start def get_next_node(self,curr,action): nxt=curr.successors.get(action or "default") - if not nxt and curr.successors: - warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") + if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") return nxt def _orch(self,shared,params=None): curr,p=copy.copy(self.start),(params or {**self.params}) @@ -72,9 +71,7 @@ class AsyncNode(Node): async def run_async(self,shared): if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") return await self._run_async(shared) - async def _run_async(self,shared): - p=await self.prep_async(shared);e=await self._exec(p) - return await self.post_async(shared,p,e) + async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e) class AsyncBatchNode(AsyncNode): async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] @@ -86,12 +83,9 @@ class AsyncFlow(Flow,AsyncNode): async def _orch_async(self,shared,params=None): curr,p=copy.copy(self.start),(params or {**self.params}) while curr: - curr.set_params(p) - c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared) + curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared) curr=copy.copy(self.get_next_node(curr,c)) - async def _run_async(self,shared): - pr=await self.prep_async(shared);await self._orch_async(shared) - return await self.post_async(shared,pr,None) + async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None) class AsyncBatchFlow(AsyncFlow): async def _run_async(self,shared):