refactor
This commit is contained in:
parent
86297ea2a8
commit
1ca48ddfce
|
|
@ -39,8 +39,7 @@ class Flow(BaseNode):
|
||||||
def __init__(self,start): super().__init__();self.start=start
|
def __init__(self,start): super().__init__();self.start=start
|
||||||
def get_next_node(self,curr,action):
|
def get_next_node(self,curr,action):
|
||||||
nxt=curr.successors.get(action or "default")
|
nxt=curr.successors.get(action or "default")
|
||||||
if not nxt and curr.successors:
|
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||||
warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
|
||||||
return nxt
|
return nxt
|
||||||
def _orch(self,shared,params=None):
|
def _orch(self,shared,params=None):
|
||||||
curr,p=copy.copy(self.start),(params or {**self.params})
|
curr,p=copy.copy(self.start),(params or {**self.params})
|
||||||
|
|
@ -72,9 +71,7 @@ class AsyncNode(Node):
|
||||||
async def run_async(self,shared):
|
async def run_async(self,shared):
|
||||||
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
|
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
|
||||||
return await self._run_async(shared)
|
return await self._run_async(shared)
|
||||||
async def _run_async(self,shared):
|
async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e)
|
||||||
p=await self.prep_async(shared);e=await self._exec(p)
|
|
||||||
return await self.post_async(shared,p,e)
|
|
||||||
|
|
||||||
class AsyncBatchNode(AsyncNode):
|
class AsyncBatchNode(AsyncNode):
|
||||||
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
|
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
|
||||||
|
|
@ -86,12 +83,9 @@ class AsyncFlow(Flow,AsyncNode):
|
||||||
async def _orch_async(self,shared,params=None):
|
async def _orch_async(self,shared,params=None):
|
||||||
curr,p=copy.copy(self.start),(params or {**self.params})
|
curr,p=copy.copy(self.start),(params or {**self.params})
|
||||||
while curr:
|
while curr:
|
||||||
curr.set_params(p)
|
curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
|
||||||
c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared)
|
|
||||||
curr=copy.copy(self.get_next_node(curr,c))
|
curr=copy.copy(self.get_next_node(curr,c))
|
||||||
async def _run_async(self,shared):
|
async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None)
|
||||||
pr=await self.prep_async(shared);await self._orch_async(shared)
|
|
||||||
return await self.post_async(shared,pr,None)
|
|
||||||
|
|
||||||
class AsyncBatchFlow(AsyncFlow):
|
class AsyncBatchFlow(AsyncFlow):
|
||||||
async def _run_async(self,shared):
|
async def _run_async(self,shared):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue