rename params

This commit is contained in:
zachary62 2024-12-26 03:32:30 +00:00
parent 4b9b357608
commit 2b024d0ec2
1 changed files with 10 additions and 8 deletions

View File

@ -57,7 +57,7 @@ class Flow(BaseFlow):
curr.set_params(p) curr.set_params(p)
c = curr._run(shared) c = curr._run(shared)
curr = self.get_next_node(curr, c) curr = self.get_next_node(curr, c)
def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly") def exec(self, shared, prep_res): raise NotImplementedError
class BaseBatchFlow(BaseFlow): class BaseBatchFlow(BaseFlow):
def prep(self, shared): return [] def prep(self, shared): return []
@ -65,9 +65,10 @@ class BaseBatchFlow(BaseFlow):
class BatchFlow(BaseBatchFlow, Flow): class BatchFlow(BaseBatchFlow, Flow):
def _run(self, shared): def _run(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared)
for d in prep_res: for batch_params in prep_res:
mp = self.params.copy(); mp.update(d) flow_params = self.params.copy()
self._exec(shared, mp) flow_params.update(batch_params)
self._exec(shared, flow_params)
return self.post(shared, prep_res, None) return self.post(shared, prep_res, None)
class AsyncNode(Node): class AsyncNode(Node):
@ -88,7 +89,7 @@ class AsyncFlow(BaseFlow, AsyncNode):
curr, p = self.start_node, (params if params else self.params.copy()) curr, p = self.start_node, (params if params else self.params.copy())
while curr: while curr:
curr.set_params(p) curr.set_params(p)
c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared) c = (await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared))
curr = self.get_next_node(curr, c) curr = self.get_next_node(curr, c)
async def _run_async(self, shared): async def _run_async(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared)
@ -98,7 +99,8 @@ class AsyncFlow(BaseFlow, AsyncNode):
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow): class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, shared): async def _run_async(self, shared):
prep_res = self.prep(shared) prep_res = self.prep(shared)
for d in prep_res: for batch_params in prep_res:
mp = self.params.copy(); mp.update(d) flow_params = self.params.copy()
await self._exec_async(shared, mp) flow_params.update(batch_params)
await self._exec_async(shared, flow_params)
return await self.post_async(shared, prep_res, None) return await self.post_async(shared, prep_res, None)