rename params

This commit is contained in:
zachary62 2024-12-26 03:32:30 +00:00
parent 4b9b357608
commit 2b024d0ec2
1 changed files with 10 additions and 8 deletions

View File

@ -57,7 +57,7 @@ class Flow(BaseFlow):
curr.set_params(p)
c = curr._run(shared)
curr = self.get_next_node(curr, c)
def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")
def exec(self, shared, prep_res): raise NotImplementedError
class BaseBatchFlow(BaseFlow):
def prep(self, shared): return []
@ -65,9 +65,10 @@ class BaseBatchFlow(BaseFlow):
class BatchFlow(BaseBatchFlow, Flow):
def _run(self, shared):
prep_res = self.prep(shared)
for d in prep_res:
mp = self.params.copy(); mp.update(d)
self._exec(shared, mp)
for batch_params in prep_res:
flow_params = self.params.copy()
flow_params.update(batch_params)
self._exec(shared, flow_params)
return self.post(shared, prep_res, None)
class AsyncNode(Node):
@ -88,7 +89,7 @@ class AsyncFlow(BaseFlow, AsyncNode):
curr, p = self.start_node, (params if params else self.params.copy())
while curr:
curr.set_params(p)
c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
c = (await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared))
curr = self.get_next_node(curr, c)
async def _run_async(self, shared):
prep_res = self.prep(shared)
@ -98,7 +99,8 @@ class AsyncFlow(BaseFlow, AsyncNode):
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
async def _run_async(self, shared):
prep_res = self.prep(shared)
for d in prep_res:
mp = self.params.copy(); mp.update(d)
await self._exec_async(shared, mp)
for batch_params in prep_res:
flow_params = self.params.copy()
flow_params.update(batch_params)
await self._exec_async(shared, flow_params)
return await self.post_async(shared, prep_res, None)