diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index 79250b5..eecd859 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -57,7 +57,7 @@ class Flow(BaseFlow): curr.set_params(p) c = curr._run(shared) curr = self.get_next_node(curr, c) - def exec(self, shared, prep_res): raise NotImplementedError("Flow exec not used directly") + def exec(self, shared, prep_res): raise NotImplementedError class BaseBatchFlow(BaseFlow): def prep(self, shared): return [] @@ -65,9 +65,10 @@ class BaseBatchFlow(BaseFlow): class BatchFlow(BaseBatchFlow, Flow): def _run(self, shared): prep_res = self.prep(shared) - for d in prep_res: - mp = self.params.copy(); mp.update(d) - self._exec(shared, mp) + for batch_params in prep_res: + flow_params = self.params.copy() + flow_params.update(batch_params) + self._exec(shared, flow_params) return self.post(shared, prep_res, None) class AsyncNode(Node): @@ -88,7 +89,7 @@ class AsyncFlow(BaseFlow, AsyncNode): curr, p = self.start_node, (params if params else self.params.copy()) while curr: curr.set_params(p) - c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared) + c = (await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)) curr = self.get_next_node(curr, c) async def _run_async(self, shared): prep_res = self.prep(shared) @@ -98,7 +99,8 @@ class AsyncFlow(BaseFlow, AsyncNode): class BatchAsyncFlow(BaseBatchFlow, AsyncFlow): async def _run_async(self, shared): prep_res = self.prep(shared) - for d in prep_res: - mp = self.params.copy(); mp.update(d) - await self._exec_async(shared, mp) + for batch_params in prep_res: + flow_params = self.params.copy() + flow_params.update(batch_params) + await self._exec_async(shared, flow_params) return await self.post_async(shared, prep_res, None) \ No newline at end of file