rename params
This commit is contained in:
parent
4b9b357608
commit
2b024d0ec2
|
|
@ -57,7 +57,7 @@ class Flow(BaseFlow):
|
||||||
curr.set_params(p)
|
curr.set_params(p)
|
||||||
c = curr._run(shared)
|
c = curr._run(shared)
|
||||||
curr = self.get_next_node(curr, c)
|
curr = self.get_next_node(curr, c)
|
||||||
def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly")
|
def exec(self, shared, prep_res): raise NotImplementedError
|
||||||
|
|
||||||
class BaseBatchFlow(BaseFlow):
|
class BaseBatchFlow(BaseFlow):
|
||||||
def prep(self, shared): return []
|
def prep(self, shared): return []
|
||||||
|
|
@ -65,9 +65,10 @@ class BaseBatchFlow(BaseFlow):
|
||||||
class BatchFlow(BaseBatchFlow, Flow):
|
class BatchFlow(BaseBatchFlow, Flow):
|
||||||
def _run(self, shared):
|
def _run(self, shared):
|
||||||
prep_res = self.prep(shared)
|
prep_res = self.prep(shared)
|
||||||
for d in prep_res:
|
for batch_params in prep_res:
|
||||||
mp = self.params.copy(); mp.update(d)
|
flow_params = self.params.copy()
|
||||||
self._exec(shared, mp)
|
flow_params.update(batch_params)
|
||||||
|
self._exec(shared, flow_params)
|
||||||
return self.post(shared, prep_res, None)
|
return self.post(shared, prep_res, None)
|
||||||
|
|
||||||
class AsyncNode(Node):
|
class AsyncNode(Node):
|
||||||
|
|
@ -88,7 +89,7 @@ class AsyncFlow(BaseFlow, AsyncNode):
|
||||||
curr, p = self.start_node, (params if params else self.params.copy())
|
curr, p = self.start_node, (params if params else self.params.copy())
|
||||||
while curr:
|
while curr:
|
||||||
curr.set_params(p)
|
curr.set_params(p)
|
||||||
c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)
|
c = (await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared))
|
||||||
curr = self.get_next_node(curr, c)
|
curr = self.get_next_node(curr, c)
|
||||||
async def _run_async(self, shared):
|
async def _run_async(self, shared):
|
||||||
prep_res = self.prep(shared)
|
prep_res = self.prep(shared)
|
||||||
|
|
@ -98,7 +99,8 @@ class AsyncFlow(BaseFlow, AsyncNode):
|
||||||
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
|
class BatchAsyncFlow(BaseBatchFlow, AsyncFlow):
|
||||||
async def _run_async(self, shared):
|
async def _run_async(self, shared):
|
||||||
prep_res = self.prep(shared)
|
prep_res = self.prep(shared)
|
||||||
for d in prep_res:
|
for batch_params in prep_res:
|
||||||
mp = self.params.copy(); mp.update(d)
|
flow_params = self.params.copy()
|
||||||
await self._exec_async(shared, mp)
|
flow_params.update(batch_params)
|
||||||
|
await self._exec_async(shared, flow_params)
|
||||||
return await self.post_async(shared, prep_res, None)
|
return await self.post_async(shared, prep_res, None)
|
||||||
Loading…
Reference in New Issue