refactor
This commit is contained in:
parent
86297ea2a8
commit
1ca48ddfce
|
|
@ -39,8 +39,7 @@ class Flow(BaseNode):
|
|||
def __init__(self,start): super().__init__();self.start=start
|
||||
def get_next_node(self,curr,action):
|
||||
nxt=curr.successors.get(action or "default")
|
||||
if not nxt and curr.successors:
|
||||
warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||
return nxt
|
||||
def _orch(self,shared,params=None):
|
||||
curr,p=copy.copy(self.start),(params or {**self.params})
|
||||
|
|
@ -72,9 +71,7 @@ class AsyncNode(Node):
|
|||
async def run_async(self,shared):
|
||||
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
|
||||
return await self._run_async(shared)
|
||||
async def _run_async(self,shared):
|
||||
p=await self.prep_async(shared);e=await self._exec(p)
|
||||
return await self.post_async(shared,p,e)
|
||||
async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e)
|
||||
|
||||
class AsyncBatchNode(AsyncNode):
|
||||
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
|
||||
|
|
@ -86,12 +83,9 @@ class AsyncFlow(Flow,AsyncNode):
|
|||
async def _orch_async(self,shared,params=None):
|
||||
curr,p=copy.copy(self.start),(params or {**self.params})
|
||||
while curr:
|
||||
curr.set_params(p)
|
||||
c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
|
||||
curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
|
||||
curr=copy.copy(self.get_next_node(curr,c))
|
||||
async def _run_async(self,shared):
|
||||
pr=await self.prep_async(shared);await self._orch_async(shared)
|
||||
return await self.post_async(shared,pr,None)
|
||||
async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None)
|
||||
|
||||
class AsyncBatchFlow(AsyncFlow):
|
||||
async def _run_async(self,shared):
|
||||
|
|
|
|||
Loading…
Reference in New Issue