refactor
This commit is contained in:
parent
11d85d78c1
commit
9af98f00f3
|
|
@ -4,7 +4,7 @@ class BaseNode:
|
||||||
def __init__(self): self.params,self.successors={},{}
|
def __init__(self): self.params,self.successors={},{}
|
||||||
def set_params(self,params): self.params=params
|
def set_params(self,params): self.params=params
|
||||||
def add_successor(self,node,cond="default"):
|
def add_successor(self,node,cond="default"):
|
||||||
if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'")
|
if cond in self.successors: warnings.warn(f"Overwriting successor for condition '{cond}'")
|
||||||
self.successors[cond]=node;return node
|
self.successors[cond]=node;return node
|
||||||
def prep(self,shared): return None
|
def prep(self,shared): return None
|
||||||
def exec(self,shared,prep_res): return None
|
def exec(self,shared,prep_res): return None
|
||||||
|
|
@ -15,7 +15,7 @@ class BaseNode:
|
||||||
exec_res=self._exec(shared,prep_res)
|
exec_res=self._exec(shared,prep_res)
|
||||||
return self.post(shared,prep_res,exec_res)
|
return self.post(shared,prep_res,exec_res)
|
||||||
def run(self,shared):
|
def run(self,shared):
|
||||||
if self.successors: warnings.warn("This node has successors. Create a parent Flow instead.")
|
if self.successors: warnings.warn("Node won't run successors. Use a parent Flow instead.")
|
||||||
return self._run(shared)
|
return self._run(shared)
|
||||||
def __rshift__(self,other): return self.add_successor(other)
|
def __rshift__(self,other): return self.add_successor(other)
|
||||||
def __sub__(self,cond):
|
def __sub__(self,cond):
|
||||||
|
|
@ -27,7 +27,9 @@ class _ConditionalTransition:
|
||||||
def __rshift__(self,tgt): return self.src.add_successor(tgt,self.cond)
|
def __rshift__(self,tgt): return self.src.add_successor(tgt,self.cond)
|
||||||
|
|
||||||
class Node(BaseNode):
|
class Node(BaseNode):
|
||||||
def __init__(self,max_retries=1): super().__init__();self.max_retries=max_retries
|
def __init__(self,max_retries=1):
|
||||||
|
super().__init__()
|
||||||
|
self.max_retries=max_retries
|
||||||
def process_after_fail(self,shared,prep_res,exc): raise exc
|
def process_after_fail(self,shared,prep_res,exc): raise exc
|
||||||
def _exec(self,shared,prep_res):
|
def _exec(self,shared,prep_res):
|
||||||
for i in range(self.max_retries):
|
for i in range(self.max_retries):
|
||||||
|
|
@ -37,50 +39,49 @@ class Node(BaseNode):
|
||||||
|
|
||||||
class BatchNode(Node):
|
class BatchNode(Node):
|
||||||
def prep(self,shared): return []
|
def prep(self,shared): return []
|
||||||
def exec(self,shared,item): return None
|
|
||||||
def _exec(self,shared,items): return [super(Node,self)._exec(shared,i) for i in items]
|
def _exec(self,shared,items): return [super(Node,self)._exec(shared,i) for i in items]
|
||||||
|
|
||||||
class BaseFlow(BaseNode):
|
class Flow(BaseNode):
|
||||||
def __init__(self,start_node):
|
def __init__(self,start_node):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.start_node=start_node
|
self.start_node=start_node
|
||||||
def get_next_node(self,curr,cond):
|
def get_next_node(self,curr,cond):
|
||||||
nxt=curr.successors.get(cond)
|
nxt=curr.successors.get(cond if cond is not None else "default")
|
||||||
if not nxt and curr.successors: warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors)}")
|
if not nxt and curr.successors:
|
||||||
|
warnings.warn(f"Flow ends: condition '{cond}' not found in {list(curr.successors)}")
|
||||||
return nxt
|
return nxt
|
||||||
|
|
||||||
class Flow(BaseFlow):
|
|
||||||
def _exec(self,shared,params=None):
|
def _exec(self,shared,params=None):
|
||||||
curr,p=self.start_node,(params if params else {**self.params})
|
curr,p=self.start_node,(params if params else {**self.params})
|
||||||
while curr:
|
while curr:
|
||||||
curr.set_params(p)
|
curr.set_params(p)
|
||||||
c=curr._run(shared)
|
c=curr._run(shared)
|
||||||
curr=self.get_next_node(curr,c)
|
curr=self.get_next_node(curr,c)
|
||||||
def exec(self,shared,prep_res): raise NotImplementedError
|
def exec(self,shared,prep_res):
|
||||||
|
raise RuntimeError("Flow should not exec directly. Create a child Node instead.")
|
||||||
|
|
||||||
class BaseBatchFlow(BaseFlow):
|
class BatchFlow(Flow):
|
||||||
def prep(self,shared): return []
|
def prep(self,shared): return []
|
||||||
|
|
||||||
class BatchFlow(BaseBatchFlow,Flow):
|
|
||||||
def _run(self,shared):
|
def _run(self,shared):
|
||||||
prep_res=self.prep(shared)
|
prep_res=self.prep(shared)
|
||||||
for batch_params in prep_res:self._exec(shared,{**self.params,**batch_params})
|
for batch_params in prep_res:self._exec(shared,{**self.params,**batch_params})
|
||||||
return self.post(shared,prep_res,None)
|
return self.post(shared,prep_res,None)
|
||||||
|
|
||||||
class AsyncNode(Node):
|
class AsyncNode(Node):
|
||||||
def post(self,shared,prep_res,exec_res): raise NotImplementedError("Use post_async")
|
def post(self,shared,prep_res,exec_res):
|
||||||
|
raise RuntimeError("AsyncNode should post using post_async instead.")
|
||||||
async def post_async(self,shared,prep_res,exec_res):
|
async def post_async(self,shared,prep_res,exec_res):
|
||||||
await asyncio.sleep(0);return "default"
|
await asyncio.sleep(0);return "default"
|
||||||
async def run_async(self,shared):
|
async def run_async(self,shared):
|
||||||
if self.successors: warnings.warn("This node has successors. Create a parent AsyncFlow.")
|
if self.successors:
|
||||||
|
warnings.warn("Node won't run successors. Use a parent AsyncFlow instead.")
|
||||||
return await self._run_async(shared)
|
return await self._run_async(shared)
|
||||||
async def _run_async(self,shared):
|
async def _run_async(self,shared):
|
||||||
prep_res=self.prep(shared)
|
prep_res=self.prep(shared)
|
||||||
exec_res=self._exec(shared,prep_res)
|
exec_res=self._exec(shared,prep_res)
|
||||||
return await self.post_async(shared,prep_res,exec_res)
|
return await self.post_async(shared,prep_res,exec_res)
|
||||||
def _run(self,shared): raise RuntimeError("AsyncNode requires async execution")
|
def _run(self,shared): raise RuntimeError("AsyncNode should run using run_async instead.")
|
||||||
|
|
||||||
class AsyncFlow(BaseFlow,AsyncNode):
|
class AsyncFlow(Flow,AsyncNode):
|
||||||
async def _exec_async(self,shared,params=None):
|
async def _exec_async(self,shared,params=None):
|
||||||
curr,p=self.start_node,(params if params else {**self.params})
|
curr,p=self.start_node,(params if params else {**self.params})
|
||||||
while curr:
|
while curr:
|
||||||
|
|
@ -92,9 +93,8 @@ class AsyncFlow(BaseFlow,AsyncNode):
|
||||||
await self._exec_async(shared)
|
await self._exec_async(shared)
|
||||||
return await self.post_async(shared,prep_res,None)
|
return await self.post_async(shared,prep_res,None)
|
||||||
|
|
||||||
class BatchAsyncFlow(BaseBatchFlow,AsyncFlow):
|
class BatchAsyncFlow(BatchFlow,AsyncFlow):
|
||||||
async def _run_async(self,shared):
|
async def _run_async(self,shared):
|
||||||
prep_res=self.prep(shared)
|
prep_res=self.prep(shared)
|
||||||
for batch_params in prep_res:await self._exec_async(shared,{**self.params,**batch_params})
|
for batch_params in prep_res:await self._exec_async(shared,{**self.params,**batch_params})
|
||||||
return await self.post_async(shared,prep_res,None)
|
return await self.post_async(shared,prep_res,None)
|
||||||
def exec(self,shared,prep_res): raise NotImplementedError("BatchAsyncFlow does not support exec")
|
|
||||||
Loading…
Reference in New Issue