This commit is contained in:
zachary62 2024-12-26 06:18:43 +00:00
parent 3e42cb2a9c
commit 11d85d78c1
1 changed files with 78 additions and 75 deletions

View File

@ -11,10 +11,11 @@ class BaseNode:
def _exec(self,shared,prep_res): return self.exec(shared,prep_res)
def post(self,shared,prep_res,exec_res): return "default"
def _run(self,shared):
prep_res = self.prep(shared); exec_res = self._exec(shared, prep_res)
prep_res=self.prep(shared)
exec_res=self._exec(shared,prep_res)
return self.post(shared,prep_res,exec_res)
def run(self,shared):
if self.successors: warnings.warn("This node has successors. Create a parent Flow instead to run them.")
if self.successors: warnings.warn("This node has successors. Create a parent Flow instead.")
return self._run(shared)
def __rshift__(self,other): return self.add_successor(other)
def __sub__(self,cond):
@ -40,17 +41,20 @@ class BatchNode(Node):
def _exec(self,shared,items): return [super(Node,self)._exec(shared,i) for i in items]
class BaseFlow(BaseNode):
def __init__(self, start_node): super().__init__(); self.start_node = start_node
def __init__(self,start_node):
super().__init__()
self.start_node=start_node
def get_next_node(self,curr,cond):
nxt=curr.successors.get(cond)
if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
if not nxt and curr.successors: warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors)}")
return nxt
class Flow(BaseFlow):
def _exec(self,shared,params=None):
curr, p = self.start_node, (params if params else self.params.copy())
curr,p=self.start_node,(params if params else {**self.params})
while curr:
curr.set_params(p); c = curr._run(shared)
curr.set_params(p)
c=curr._run(shared)
curr=self.get_next_node(curr,c)
def exec(self,shared,prep_res): raise NotImplementedError
@ -60,25 +64,25 @@ class BaseBatchFlow(BaseFlow):
class BatchFlow(BaseBatchFlow,Flow):
def _run(self,shared):
prep_res=self.prep(shared)
for batch_params in prep_res:
flow_params = self.params.copy(); flow_params.update(batch_params)
self._exec(shared, flow_params)
for batch_params in prep_res:self._exec(shared,{**self.params,**batch_params})
return self.post(shared,prep_res,None)
class AsyncNode(Node):
def post(self,shared,prep_res,exec_res): raise NotImplementedError("Use post_async")
async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
async def post_async(self,shared,prep_res,exec_res):
await asyncio.sleep(0);return "default"
async def run_async(self,shared):
if self.successors: warnings.warn("This node has successors. Create a parent AsyncFlow instead to run them.")
if self.successors: warnings.warn("This node has successors. Create a parent AsyncFlow.")
return await self._run_async(shared)
async def _run_async(self,shared):
prep_res = self.prep(shared); exec_res = self._exec(shared, prep_res)
prep_res=self.prep(shared)
exec_res=self._exec(shared,prep_res)
return await self.post_async(shared,prep_res,exec_res)
def _run(self, shared): raise RuntimeError("AsyncNode requires async execution. Use run_async instead.")
def _run(self,shared): raise RuntimeError("AsyncNode requires async execution")
class AsyncFlow(BaseFlow,AsyncNode):
async def _exec_async(self,shared,params=None):
curr, p = self.start_node, (params if params else self.params.copy())
curr,p=self.start_node,(params if params else {**self.params})
while curr:
curr.set_params(p)
c=await curr._run_async(shared) if hasattr(curr,"run_async") else curr._run(shared)
@ -91,7 +95,6 @@ class AsyncFlow(BaseFlow, AsyncNode):
class BatchAsyncFlow(BaseBatchFlow,AsyncFlow):
async def _run_async(self,shared):
prep_res=self.prep(shared)
for batch_params in prep_res:
flow_params = self.params.copy(); flow_params.update(batch_params)
await self._exec_async(shared, flow_params)
for batch_params in prep_res:await self._exec_async(shared,{**self.params,**batch_params})
return await self.post_async(shared,prep_res,None)
def exec(self,shared,prep_res): raise NotImplementedError("BatchAsyncFlow does not support exec")