diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index eecd859..dc5a9d6 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -11,12 +11,10 @@ class BaseNode: def _exec(self, shared, prep_res): return self.exec(shared, prep_res) def post(self, shared, prep_res, exec_res): return "default" def _run(self, shared): - prep_res = self.prep(shared) - exec_res = self._exec(shared, prep_res) + prep_res = self.prep(shared); exec_res = self._exec(shared, prep_res) return self.post(shared, prep_res, exec_res) def run(self, shared): - if self.successors: - warnings.warn("This node has successors. Create a parent Flow to run them or run that Flow instead.") + if self.successors: warnings.warn("This node has successors. Create a parent Flow instead to run them.") return self._run(shared) def __rshift__(self, other): return self.add_successor(other) def __sub__(self, cond): @@ -45,17 +43,14 @@ class BaseFlow(BaseNode): def __init__(self, start_node): super().__init__(); self.start_node = start_node def get_next_node(self, curr, cond): nxt = curr.successors.get(cond) - if nxt is None and curr.successors: - warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}") + if nxt is None and curr.successors: warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors.keys())}") return nxt class Flow(BaseFlow): def _exec(self, shared, params=None): - curr = self.start_node - p = params if params else self.params.copy() + curr, p = self.start_node, (params if params else self.params.copy()) while curr: - curr.set_params(p) - c = curr._run(shared) + curr.set_params(p); c = curr._run(shared) curr = self.get_next_node(curr, c) def exec(self, shared, prep_res): raise NotImplementedError @@ -66,8 +61,7 @@ class BatchFlow(BaseBatchFlow, Flow): def _run(self, shared): prep_res = self.prep(shared) for batch_params in prep_res: - flow_params = self.params.copy() - flow_params.update(batch_params) + flow_params = self.params.copy(); flow_params.update(batch_params) self._exec(shared, flow_params) return self.post(shared, prep_res, None) @@ -75,21 +69,19 @@ class AsyncNode(Node): def post(self, shared, prep_res, exec_res): raise NotImplementedError("Use post_async") async def post_async(self, shared, prep_res, exec_res): await asyncio.sleep(0); return "default" async def run_async(self, shared): - if self.successors: - warnings.warn("This node has successors. Create a parent AsyncFlow to run them or run that Flow instead.") + if self.successors: warnings.warn("This node has successors. Create a parent AsyncFlow instead to run them.") return await self._run_async(shared) async def _run_async(self, shared): - prep_res = self.prep(shared) - exec_res = self._exec(shared, prep_res) + prep_res = self.prep(shared); exec_res = self._exec(shared, prep_res) return await self.post_async(shared, prep_res, exec_res) - def _run(self, shared): raise RuntimeError("AsyncNode requires async execution") + def _run(self, shared): raise RuntimeError("AsyncNode requires async execution. Use run_async instead.") class AsyncFlow(BaseFlow, AsyncNode): async def _exec_async(self, shared, params=None): curr, p = self.start_node, (params if params else self.params.copy()) while curr: curr.set_params(p) - c = (await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared)) + c = await curr._run_async(shared) if hasattr(curr, "run_async") else curr._run(shared) curr = self.get_next_node(curr, c) async def _run_async(self, shared): prep_res = self.prep(shared) @@ -100,7 +92,6 @@ class BatchAsyncFlow(BaseBatchFlow, AsyncFlow): async def _run_async(self, shared): prep_res = self.prep(shared) for batch_params in prep_res: - flow_params = self.params.copy() - flow_params.update(batch_params) + flow_params = self.params.copy(); flow_params.update(batch_params) await self._exec_async(shared, flow_params) return await self.post_async(shared, prep_res, None) \ No newline at end of file