diff --git a/minillmflow/__init__.py b/minillmflow/__init__.py index 513fb16..96305bc 100644 --- a/minillmflow/__init__.py +++ b/minillmflow/__init__.py @@ -4,7 +4,7 @@ class BaseNode: def __init__(self): self.params,self.successors={},{} def set_params(self,params): self.params=params def add_successor(self,node,cond="default"): - if cond in self.successors: warnings.warn(f"Overwriting existing successor for '{cond}'") + if cond in self.successors: warnings.warn(f"Overwriting successor for condition '{cond}'") self.successors[cond]=node;return node def prep(self,shared): return None def exec(self,shared,prep_res): return None @@ -15,7 +15,7 @@ class BaseNode: exec_res=self._exec(shared,prep_res) return self.post(shared,prep_res,exec_res) def run(self,shared): - if self.successors: warnings.warn("This node has successors. Create a parent Flow instead.") + if self.successors: warnings.warn("Node won't run successors. Use a parent Flow instead.") return self._run(shared) def __rshift__(self,other): return self.add_successor(other) def __sub__(self,cond): @@ -27,7 +27,9 @@ class _ConditionalTransition: def __rshift__(self,tgt): return self.src.add_successor(tgt,self.cond) class Node(BaseNode): - def __init__(self,max_retries=1): super().__init__();self.max_retries=max_retries + def __init__(self,max_retries=1): + super().__init__() + self.max_retries=max_retries def process_after_fail(self,shared,prep_res,exc): raise exc def _exec(self,shared,prep_res): for i in range(self.max_retries): @@ -37,50 +39,49 @@ class Node(BaseNode): class BatchNode(Node): def prep(self,shared): return [] - def exec(self,shared,item): return None def _exec(self,shared,items): return [super(Node,self)._exec(shared,i) for i in items] -class BaseFlow(BaseNode): +class Flow(BaseNode): def __init__(self,start_node): super().__init__() self.start_node=start_node def get_next_node(self,curr,cond): - nxt=curr.successors.get(cond) - if not nxt and curr.successors: warnings.warn(f"Flow ends. '{cond}' not among {list(curr.successors)}") + nxt=curr.successors.get(cond if cond is not None else "default") + if not nxt and curr.successors: + warnings.warn(f"Flow ends: condition '{cond}' not found in {list(curr.successors)}") return nxt - -class Flow(BaseFlow): def _exec(self,shared,params=None): curr,p=self.start_node,(params if params else {**self.params}) while curr: curr.set_params(p) c=curr._run(shared) curr=self.get_next_node(curr,c) - def exec(self,shared,prep_res): raise NotImplementedError + def exec(self,shared,prep_res): + raise RuntimeError("Flow should not exec directly. Create a child Node instead.") -class BaseBatchFlow(BaseFlow): +class BatchFlow(Flow): def prep(self,shared): return [] - -class BatchFlow(BaseBatchFlow,Flow): def _run(self,shared): prep_res=self.prep(shared) for batch_params in prep_res:self._exec(shared,{**self.params,**batch_params}) return self.post(shared,prep_res,None) class AsyncNode(Node): - def post(self,shared,prep_res,exec_res): raise NotImplementedError("Use post_async") + def post(self,shared,prep_res,exec_res): + raise RuntimeError("AsyncNode should post using post_async instead.") async def post_async(self,shared,prep_res,exec_res): await asyncio.sleep(0);return "default" async def run_async(self,shared): - if self.successors: warnings.warn("This node has successors. Create a parent AsyncFlow.") + if self.successors: + warnings.warn("Node won't run successors. Use a parent AsyncFlow instead.") return await self._run_async(shared) async def _run_async(self,shared): prep_res=self.prep(shared) exec_res=self._exec(shared,prep_res) return await self.post_async(shared,prep_res,exec_res) - def _run(self,shared): raise RuntimeError("AsyncNode requires async execution") + def _run(self,shared): raise RuntimeError("AsyncNode should run using run_async instead.") -class AsyncFlow(BaseFlow,AsyncNode): +class AsyncFlow(Flow,AsyncNode): async def _exec_async(self,shared,params=None): curr,p=self.start_node,(params if params else {**self.params}) while curr: @@ -92,9 +93,8 @@ class AsyncFlow(BaseFlow,AsyncNode): await self._exec_async(shared) return await self.post_async(shared,prep_res,None) -class BatchAsyncFlow(BaseBatchFlow,AsyncFlow): +class BatchAsyncFlow(BatchFlow,AsyncFlow): async def _run_async(self,shared): prep_res=self.prep(shared) for batch_params in prep_res:await self._exec_async(shared,{**self.params,**batch_params}) - return await self.post_async(shared,prep_res,None) - def exec(self,shared,prep_res): raise NotImplementedError("BatchAsyncFlow does not support exec") \ No newline at end of file + return await self.post_async(shared,prep_res,None) \ No newline at end of file