Add type hints via .pyi stub file
- Create pocketflow/__init__.pyi with improved type definitions - Remove type hints from runtime code to keep it clean - Use TypeVars and Generics for better type relationships - Replace many 'Any' types with more specific types like ParamValue, SharedData, Params - Maintain original compact code style in __init__.py
This commit is contained in:
parent
58bf50c87e
commit
e58fbab70f
|
|
@ -1,33 +1,32 @@
|
|||
import asyncio, warnings, copy, time
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
class BaseNode:
|
||||
def __init__(self) -> None: self.params,self.successors={},{}
|
||||
def set_params(self,params: Dict[str, Any]) -> None: self.params=params
|
||||
def next(self,node: 'BaseNode',action: str="default") -> 'BaseNode':
|
||||
def __init__(self): self.params,self.successors={},{}
|
||||
def set_params(self,params): self.params=params
|
||||
def next(self,node,action="default"):
|
||||
if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
|
||||
self.successors[action]=node; return node
|
||||
def prep(self,shared: Dict[str, Any]) -> Any: pass
|
||||
def exec(self,prep_res: Any) -> Any: pass
|
||||
def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass
|
||||
def _exec(self,prep_res: Any) -> Any: return self.exec(prep_res)
|
||||
def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e)
|
||||
def run(self,shared: Dict[str, Any]) -> Any:
|
||||
def prep(self,shared): pass
|
||||
def exec(self,prep_res): pass
|
||||
def post(self,shared,prep_res,exec_res): pass
|
||||
def _exec(self,prep_res): return self.exec(prep_res)
|
||||
def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e)
|
||||
def run(self,shared):
|
||||
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
|
||||
return self._run(shared)
|
||||
def __rshift__(self,other: 'BaseNode') -> 'BaseNode': return self.next(other)
|
||||
def __sub__(self,action: str) -> '_ConditionalTransition':
|
||||
def __rshift__(self,other): return self.next(other)
|
||||
def __sub__(self,action):
|
||||
if isinstance(action,str): return _ConditionalTransition(self,action)
|
||||
raise TypeError("Action must be a string")
|
||||
|
||||
class _ConditionalTransition:
|
||||
def __init__(self,src: BaseNode,action: str) -> None: self.src,self.action=src,action
|
||||
def __rshift__(self,tgt: BaseNode) -> BaseNode: return self.src.next(tgt,self.action)
|
||||
def __init__(self,src,action): self.src,self.action=src,action
|
||||
def __rshift__(self,tgt): return self.src.next(tgt,self.action)
|
||||
|
||||
class Node(BaseNode):
|
||||
def __init__(self,max_retries: int=1,wait: Union[int, float]=0) -> None: super().__init__(); self.max_retries,self.wait=max_retries,wait
|
||||
def exec_fallback(self,prep_res: Any,exc: Exception) -> Any: raise exc
|
||||
def _exec(self,prep_res: Any) -> Any:
|
||||
def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait
|
||||
def exec_fallback(self,prep_res,exc): raise exc
|
||||
def _exec(self,prep_res):
|
||||
for self.cur_retry in range(self.max_retries):
|
||||
try: return self.exec(prep_res)
|
||||
except Exception as e:
|
||||
|
|
@ -35,67 +34,67 @@ class Node(BaseNode):
|
|||
if self.wait>0: time.sleep(self.wait)
|
||||
|
||||
class BatchNode(Node):
|
||||
def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [super(BatchNode,self)._exec(i) for i in (items or [])]
|
||||
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
|
||||
|
||||
class Flow(BaseNode):
|
||||
def __init__(self,start: Optional[BaseNode]=None) -> None: super().__init__(); self.start_node=start
|
||||
def start(self,start: BaseNode) -> BaseNode: self.start_node=start; return start
|
||||
def get_next_node(self,curr: BaseNode,action: Optional[str]) -> Optional[BaseNode]:
|
||||
def __init__(self,start=None): super().__init__(); self.start_node=start
|
||||
def start(self,start): self.start_node=start; return start
|
||||
def get_next_node(self,curr,action):
|
||||
nxt=curr.successors.get(action or "default")
|
||||
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
|
||||
return nxt
|
||||
def _orch(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any:
|
||||
def _orch(self,shared,params=None):
|
||||
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
|
||||
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
||||
return last_action
|
||||
def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
|
||||
def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res
|
||||
def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
|
||||
def post(self,shared,prep_res,exec_res): return exec_res
|
||||
|
||||
class BatchFlow(Flow):
|
||||
def _run(self,shared: Dict[str, Any]) -> Any:
|
||||
def _run(self,shared):
|
||||
pr=self.prep(shared) or []
|
||||
for bp in pr: self._orch(shared,{**self.params,**bp})
|
||||
return self.post(shared,pr,None)
|
||||
|
||||
class AsyncNode(Node):
|
||||
async def prep_async(self,shared: Dict[str, Any]) -> Any: pass
|
||||
async def exec_async(self,prep_res: Any) -> Any: pass
|
||||
async def exec_fallback_async(self,prep_res: Any,exc: Exception) -> Any: raise exc
|
||||
async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass
|
||||
async def _exec(self,prep_res: Any) -> Any:
|
||||
async def prep_async(self,shared): pass
|
||||
async def exec_async(self,prep_res): pass
|
||||
async def exec_fallback_async(self,prep_res,exc): raise exc
|
||||
async def post_async(self,shared,prep_res,exec_res): pass
|
||||
async def _exec(self,prep_res):
|
||||
for i in range(self.max_retries):
|
||||
try: return await self.exec_async(prep_res)
|
||||
except Exception as e:
|
||||
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
|
||||
if self.wait>0: await asyncio.sleep(self.wait)
|
||||
async def run_async(self,shared: Dict[str, Any]) -> Any:
|
||||
async def run_async(self,shared):
|
||||
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
|
||||
return await self._run_async(shared)
|
||||
async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e)
|
||||
def _run(self,shared: Dict[str, Any]) -> Any: raise RuntimeError("Use run_async.")
|
||||
async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e)
|
||||
def _run(self,shared): raise RuntimeError("Use run_async.")
|
||||
|
||||
class AsyncBatchNode(AsyncNode,BatchNode):
|
||||
async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [await super(AsyncBatchNode,self)._exec(i) for i in items]
|
||||
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
|
||||
|
||||
class AsyncParallelBatchNode(AsyncNode,BatchNode):
|
||||
async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
|
||||
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
|
||||
|
||||
class AsyncFlow(Flow,AsyncNode):
|
||||
async def _orch_async(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any:
|
||||
async def _orch_async(self,shared,params=None):
|
||||
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
|
||||
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
|
||||
return last_action
|
||||
async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
|
||||
async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res
|
||||
async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
|
||||
async def post_async(self,shared,prep_res,exec_res): return exec_res
|
||||
|
||||
class AsyncBatchFlow(AsyncFlow,BatchFlow):
|
||||
async def _run_async(self,shared: Dict[str, Any]) -> Any:
|
||||
async def _run_async(self,shared):
|
||||
pr=await self.prep_async(shared) or []
|
||||
for bp in pr: await self._orch_async(shared,{**self.params,**bp})
|
||||
return await self.post_async(shared,pr,None)
|
||||
|
||||
class AsyncParallelBatchFlow(AsyncFlow,BatchFlow):
|
||||
async def _run_async(self,shared: Dict[str, Any]) -> Any:
|
||||
async def _run_async(self,shared):
|
||||
pr=await self.prep_async(shared) or []
|
||||
await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr))
|
||||
return await self.post_async(shared,pr,None)
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic
|
||||
|
||||
# Type variables for better type relationships
|
||||
_PrepResult = TypeVar('_PrepResult')
|
||||
_ExecResult = TypeVar('_ExecResult')
|
||||
_PostResult = TypeVar('_PostResult')
|
||||
|
||||
# More specific parameter types
|
||||
ParamValue = Union[str, int, float, bool, None, List[Any], Dict[str, Any]]
|
||||
SharedData = Dict[str, Any]
|
||||
Params = Dict[str, ParamValue]
|
||||
|
||||
class BaseNode(Generic[_PrepResult, _ExecResult, _PostResult]):
|
||||
params: Params
|
||||
successors: Dict[str, BaseNode[Any, Any, Any]]
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
def set_params(self, params: Params) -> None: ...
|
||||
def next(self, node: BaseNode[Any, Any, Any], action: str = "default") -> BaseNode[Any, Any, Any]: ...
|
||||
def prep(self, shared: SharedData) -> _PrepResult: ...
|
||||
def exec(self, prep_res: _PrepResult) -> _ExecResult: ...
|
||||
def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult) -> _PostResult: ...
|
||||
def _exec(self, prep_res: _PrepResult) -> _ExecResult: ...
|
||||
def _run(self, shared: SharedData) -> _PostResult: ...
|
||||
def run(self, shared: SharedData) -> _PostResult: ...
|
||||
def __rshift__(self, other: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ...
|
||||
def __sub__(self, action: str) -> _ConditionalTransition: ...
|
||||
|
||||
class _ConditionalTransition:
|
||||
src: BaseNode[Any, Any, Any]
|
||||
action: str
|
||||
|
||||
def __init__(self, src: BaseNode[Any, Any, Any], action: str) -> None: ...
|
||||
def __rshift__(self, tgt: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ...
|
||||
|
||||
class Node(BaseNode[_PrepResult, _ExecResult, _PostResult]):
|
||||
max_retries: int
|
||||
wait: Union[int, float]
|
||||
cur_retry: int
|
||||
|
||||
def __init__(self, max_retries: int = 1, wait: Union[int, float] = 0) -> None: ...
|
||||
def exec_fallback(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ...
|
||||
def _exec(self, prep_res: _PrepResult) -> _ExecResult: ...
|
||||
|
||||
class BatchNode(Node[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]):
|
||||
def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ...
|
||||
|
||||
class Flow(BaseNode[_PrepResult, Any, _PostResult]):
|
||||
start_node: Optional[BaseNode[Any, Any, Any]]
|
||||
|
||||
def __init__(self, start: Optional[BaseNode[Any, Any, Any]] = None) -> None: ...
|
||||
def start(self, start: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ...
|
||||
def get_next_node(
|
||||
self, curr: BaseNode[Any, Any, Any], action: Optional[str]
|
||||
) -> Optional[BaseNode[Any, Any, Any]]: ...
|
||||
def _orch(
|
||||
self, shared: SharedData, params: Optional[Params] = None
|
||||
) -> Any: ...
|
||||
def _run(self, shared: SharedData) -> _PostResult: ...
|
||||
def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: Any) -> _PostResult: ...
|
||||
|
||||
class BatchFlow(Flow[Optional[List[Params]], Any, _PostResult]):
|
||||
def _run(self, shared: SharedData) -> _PostResult: ...
|
||||
|
||||
class AsyncNode(Node[_PrepResult, _ExecResult, _PostResult]):
|
||||
async def prep_async(self, shared: SharedData) -> _PrepResult: ...
|
||||
async def exec_async(self, prep_res: _PrepResult) -> _ExecResult: ...
|
||||
async def exec_fallback_async(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ...
|
||||
async def post_async(
|
||||
self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult
|
||||
) -> _PostResult: ...
|
||||
async def _exec(self, prep_res: _PrepResult) -> _ExecResult: ...
|
||||
async def run_async(self, shared: SharedData) -> _PostResult: ...
|
||||
async def _run_async(self, shared: SharedData) -> _PostResult: ...
|
||||
def _run(self, shared: SharedData) -> _PostResult: ...
|
||||
|
||||
class AsyncBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]):
|
||||
async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ...
|
||||
|
||||
class AsyncParallelBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]):
|
||||
async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ...
|
||||
|
||||
class AsyncFlow(Flow[_PrepResult, Any, _PostResult], AsyncNode[_PrepResult, Any, _PostResult]):
|
||||
async def _orch_async(
|
||||
self, shared: SharedData, params: Optional[Params] = None
|
||||
) -> Any: ...
|
||||
async def _run_async(self, shared: SharedData) -> _PostResult: ...
|
||||
async def post_async(
|
||||
self, shared: SharedData, prep_res: _PrepResult, exec_res: Any
|
||||
) -> _PostResult: ...
|
||||
|
||||
class AsyncBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]):
|
||||
async def _run_async(self, shared: SharedData) -> _PostResult: ...
|
||||
|
||||
class AsyncParallelBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]):
|
||||
async def _run_async(self, shared: SharedData) -> _PostResult: ...
|
||||
Loading…
Reference in New Issue