adding type hints

This commit is contained in:
raceychan 2025-07-06 14:06:25 +08:00
parent d360ba8d08
commit 58bf50c87e
1 changed files with 40 additions and 39 deletions

View File

@ -1,32 +1,33 @@
import asyncio, warnings, copy, time import asyncio, warnings, copy, time
from typing import Dict, Any, Optional, List, Union
class BaseNode: class BaseNode:
def __init__(self): self.params,self.successors={},{} def __init__(self) -> None: self.params,self.successors={},{}
def set_params(self,params): self.params=params def set_params(self,params: Dict[str, Any]) -> None: self.params=params
def next(self,node,action="default"): def next(self,node: 'BaseNode',action: str="default") -> 'BaseNode':
if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'") if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
self.successors[action]=node; return node self.successors[action]=node; return node
def prep(self,shared): pass def prep(self,shared: Dict[str, Any]) -> Any: pass
def exec(self,prep_res): pass def exec(self,prep_res: Any) -> Any: pass
def post(self,shared,prep_res,exec_res): pass def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass
def _exec(self,prep_res): return self.exec(prep_res) def _exec(self,prep_res: Any) -> Any: return self.exec(prep_res)
def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e) def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e)
def run(self,shared): def run(self,shared: Dict[str, Any]) -> Any:
if self.successors: warnings.warn("Node won't run successors. Use Flow.") if self.successors: warnings.warn("Node won't run successors. Use Flow.")
return self._run(shared) return self._run(shared)
def __rshift__(self,other): return self.next(other) def __rshift__(self,other: 'BaseNode') -> 'BaseNode': return self.next(other)
def __sub__(self,action): def __sub__(self,action: str) -> '_ConditionalTransition':
if isinstance(action,str): return _ConditionalTransition(self,action) if isinstance(action,str): return _ConditionalTransition(self,action)
raise TypeError("Action must be a string") raise TypeError("Action must be a string")
class _ConditionalTransition: class _ConditionalTransition:
def __init__(self,src,action): self.src,self.action=src,action def __init__(self,src: BaseNode,action: str) -> None: self.src,self.action=src,action
def __rshift__(self,tgt): return self.src.next(tgt,self.action) def __rshift__(self,tgt: BaseNode) -> BaseNode: return self.src.next(tgt,self.action)
class Node(BaseNode): class Node(BaseNode):
def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait def __init__(self,max_retries: int=1,wait: Union[int, float]=0) -> None: super().__init__(); self.max_retries,self.wait=max_retries,wait
def exec_fallback(self,prep_res,exc): raise exc def exec_fallback(self,prep_res: Any,exc: Exception) -> Any: raise exc
def _exec(self,prep_res): def _exec(self,prep_res: Any) -> Any:
for self.cur_retry in range(self.max_retries): for self.cur_retry in range(self.max_retries):
try: return self.exec(prep_res) try: return self.exec(prep_res)
except Exception as e: except Exception as e:
@ -34,67 +35,67 @@ class Node(BaseNode):
if self.wait>0: time.sleep(self.wait) if self.wait>0: time.sleep(self.wait)
class BatchNode(Node): class BatchNode(Node):
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])] def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [super(BatchNode,self)._exec(i) for i in (items or [])]
class Flow(BaseNode): class Flow(BaseNode):
def __init__(self,start=None): super().__init__(); self.start_node=start def __init__(self,start: Optional[BaseNode]=None) -> None: super().__init__(); self.start_node=start
def start(self,start): self.start_node=start; return start def start(self,start: BaseNode) -> BaseNode: self.start_node=start; return start
def get_next_node(self,curr,action): def get_next_node(self,curr: BaseNode,action: Optional[str]) -> Optional[BaseNode]:
nxt=curr.successors.get(action or "default") nxt=curr.successors.get(action or "default")
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return nxt return nxt
def _orch(self,shared,params=None): def _orch(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any:
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
return last_action return last_action
def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o) def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
def post(self,shared,prep_res,exec_res): return exec_res def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res
class BatchFlow(Flow): class BatchFlow(Flow):
def _run(self,shared): def _run(self,shared: Dict[str, Any]) -> Any:
pr=self.prep(shared) or [] pr=self.prep(shared) or []
for bp in pr: self._orch(shared,{**self.params,**bp}) for bp in pr: self._orch(shared,{**self.params,**bp})
return self.post(shared,pr,None) return self.post(shared,pr,None)
class AsyncNode(Node): class AsyncNode(Node):
async def prep_async(self,shared): pass async def prep_async(self,shared: Dict[str, Any]) -> Any: pass
async def exec_async(self,prep_res): pass async def exec_async(self,prep_res: Any) -> Any: pass
async def exec_fallback_async(self,prep_res,exc): raise exc async def exec_fallback_async(self,prep_res: Any,exc: Exception) -> Any: raise exc
async def post_async(self,shared,prep_res,exec_res): pass async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass
async def _exec(self,prep_res): async def _exec(self,prep_res: Any) -> Any:
for i in range(self.max_retries): for i in range(self.max_retries):
try: return await self.exec_async(prep_res) try: return await self.exec_async(prep_res)
except Exception as e: except Exception as e:
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e) if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
if self.wait>0: await asyncio.sleep(self.wait) if self.wait>0: await asyncio.sleep(self.wait)
async def run_async(self,shared): async def run_async(self,shared: Dict[str, Any]) -> Any:
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
return await self._run_async(shared) return await self._run_async(shared)
async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e) async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e)
def _run(self,shared): raise RuntimeError("Use run_async.") def _run(self,shared: Dict[str, Any]) -> Any: raise RuntimeError("Use run_async.")
class AsyncBatchNode(AsyncNode,BatchNode): class AsyncBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [await super(AsyncBatchNode,self)._exec(i) for i in items]
class AsyncParallelBatchNode(AsyncNode,BatchNode): class AsyncParallelBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
class AsyncFlow(Flow,AsyncNode): class AsyncFlow(Flow,AsyncNode):
async def _orch_async(self,shared,params=None): async def _orch_async(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any:
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
return last_action return last_action
async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o) async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
async def post_async(self,shared,prep_res,exec_res): return exec_res async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res
class AsyncBatchFlow(AsyncFlow,BatchFlow): class AsyncBatchFlow(AsyncFlow,BatchFlow):
async def _run_async(self,shared): async def _run_async(self,shared: Dict[str, Any]) -> Any:
pr=await self.prep_async(shared) or [] pr=await self.prep_async(shared) or []
for bp in pr: await self._orch_async(shared,{**self.params,**bp}) for bp in pr: await self._orch_async(shared,{**self.params,**bp})
return await self.post_async(shared,pr,None) return await self.post_async(shared,pr,None)
class AsyncParallelBatchFlow(AsyncFlow,BatchFlow): class AsyncParallelBatchFlow(AsyncFlow,BatchFlow):
async def _run_async(self,shared): async def _run_async(self,shared: Dict[str, Any]) -> Any:
pr=await self.prep_async(shared) or [] pr=await self.prep_async(shared) or []
await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr)) await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr))
return await self.post_async(shared,pr,None) return await self.post_async(shared,pr,None)