enrich warning

This commit is contained in:
zachary62 2024-12-26 03:38:46 +00:00
parent 2b024d0ec2
commit 63172fafa8
1 changed files with 11 additions and 20 deletions

View File

@ -11,12 +11,10 @@ class BaseNode:
def _exec(self, shared, prep_res): return self.exec(shared, prep_res) def _exec(self, shared, prep_res): return self.exec(shared, prep_res)
def post(self, shared, prep_res, exec_res): return "default" def post(self, shared, prep_res, exec_res): return "default"
def _run(self, shared): def _run(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared); exec_res = self._exec(shared, prep_res)
exec_res = self._exec(shared, prep_res)
return self.post(shared, prep_res, exec_res) return self.post(shared, prep_res, exec_res)
def run(self, shared): def run(self, shared):
if self.successors: if self.successors: warnings.warn("This node has successors. Create a parent Flow instead to run them.")
warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.")
return self._run(shared) return self._run(shared)
def __rshift__(self, other): return self.add_successor(other) def __rshift__(self, other): return self.add_successor(other)
def __sub__(self, cond): def __sub__(self, cond):
@ -45,17 +43,14 @@ class BaseFlow(BaseNode):
def __init__(self, start_node): super().__init__(); self.start_node = start_node def __init__(self, start_node): super().__init__(); self.start_node = start_node
def get_next_node(self, curr, cond): def get_next_node(self, curr, cond):
nxt = curr.successors.get(cond) nxt = curr.successors.get(cond)
if nxt is None and curr.successors: if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}")
return nxt return nxt
class Flow(BaseFlow): class Flow(BaseFlow):
def _exec(self, shared, params=None): def _exec(self, shared, params=None):
curr = self.start_node curr, p = self.start_node, (params if params else self.params.copy())
p = params if params else self.params.copy()
while curr: while curr:
curr.set_params(p) curr.set_params(p); c = curr._run(shared)
c = curr._run(shared)
curr = self.get_next_node(curr, c) curr = self.get_next_node(curr, c)
def exec(self, shared, prep_res): raise NotImplementedError def exec(self, shared, prep_res): raise NotImplementedError
@ -66,8 +61,7 @@ class BatchFlow(BaseBatchFlow, Flow):
def _run(self, shared): def _run(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared)
for batch_params in prep_res: for batch_params in prep_res:
flow_params = self.params.copy() flow_params = self.params.copy(); flow_params.update(batch_params)
flow_params.update(batch_params)
self._exec(shared, flow_params) self._exec(shared, flow_params)
return self.post(shared, prep_res, None) return self.post(shared, prep_res, None)
@ -75,21 +69,19 @@ class AsyncNode(Node):
def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async") def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async")
async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default" async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default"
async def run_async(self, shared): async def run_async(self, shared):
if self.successors: if self.successors: warnings.warn("This node has successors. Create a parent AsyncFlow instead to run them.")
warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.")
return await self._run_async(shared) return await self._run_async(shared)
async def _run_async(self, shared): async def _run_async(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared); exec_res = self._exec(shared, prep_res)
exec_res = self._exec(shared, prep_res)
return await self.post_async(shared, prep_res, exec_res) return await self.post_async(shared, prep_res, exec_res)
def _run(self, shared): raise RuntimeError("AsyncNode requires async execution") def _run(self, shared): raise RuntimeError("AsyncNode requires async execution. Use run_async instead.")
class AsyncFlow(BaseFlow, AsyncNode): class AsyncFlow(BaseFlow, AsyncNode):
async def _exec_async(self, shared, params=None): async def _exec_async(self, shared, params=None):
curr, p = self.start_node, (params if params else self.params.copy()) curr, p = self.start_node, (params if params else self.params.copy())
while curr: while curr:
curr.set_params(p) curr.set_params(p)
c = (await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)) c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
curr = self.get_next_node(curr, c) curr = self.get_next_node(curr, c)
async def _run_async(self, shared): async def _run_async(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared)
@ -100,7 +92,6 @@ class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, shared): async def _run_async(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared)
for batch_params in prep_res: for batch_params in prep_res:
flow_params = self.params.copy() flow_params = self.params.copy(); flow_params.update(batch_params)
flow_params.update(batch_params)
await self._exec_async(shared, flow_params) await self._exec_async(shared, flow_params)
return await self.post_async(shared, prep_res, None) return await self.post_async(shared, prep_res, None)