From 58bf50c87e30a7a115cd5c2b4535b99e86804ea9 Mon Sep 17 00:00:00 2001 From: raceychan Date: Sun, 6 Jul 2025 14:06:25 +0800 Subject: [PATCH 1/2] adding type hints --- pocketflow/__init__.py | 79 +++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/pocketflow/__init__.py b/pocketflow/__init__.py index a7203df..5534c23 100644 --- a/pocketflow/__init__.py +++ b/pocketflow/__init__.py @@ -1,32 +1,33 @@ import asyncio, warnings, copy, time +from typing import Dict, Any, Optional, List, Union class BaseNode: - def __init__(self): self.params,self.successors={},{} - def set_params(self,params): self.params=params - def next(self,node,action="default"): + def __init__(self) -> None: self.params,self.successors={},{} + def set_params(self,params: Dict[str, Any]) -> None: self.params=params + def next(self,node: 'BaseNode',action: str="default") -> 'BaseNode': if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'") self.successors[action]=node; return node - def prep(self,shared): pass - def exec(self,prep_res): pass - def post(self,shared,prep_res,exec_res): pass - def _exec(self,prep_res): return self.exec(prep_res) - def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e) - def run(self,shared): + def prep(self,shared: Dict[str, Any]) -> Any: pass + def exec(self,prep_res: Any) -> Any: pass + def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass + def _exec(self,prep_res: Any) -> Any: return self.exec(prep_res) + def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e) + def run(self,shared: Dict[str, Any]) -> Any: if self.successors: warnings.warn("Node won't run successors. Use Flow.") return self._run(shared) - def __rshift__(self,other): return self.next(other) - def __sub__(self,action): + def __rshift__(self,other: 'BaseNode') -> 'BaseNode': return self.next(other) + def __sub__(self,action: str) -> '_ConditionalTransition': if isinstance(action,str): return _ConditionalTransition(self,action) raise TypeError("Action must be a string") class _ConditionalTransition: - def __init__(self,src,action): self.src,self.action=src,action - def __rshift__(self,tgt): return self.src.next(tgt,self.action) + def __init__(self,src: BaseNode,action: str) -> None: self.src,self.action=src,action + def __rshift__(self,tgt: BaseNode) -> BaseNode: return self.src.next(tgt,self.action) class Node(BaseNode): - def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait - def exec_fallback(self,prep_res,exc): raise exc - def _exec(self,prep_res): + def __init__(self,max_retries: int=1,wait: Union[int, float]=0) -> None: super().__init__(); self.max_retries,self.wait=max_retries,wait + def exec_fallback(self,prep_res: Any,exc: Exception) -> Any: raise exc + def _exec(self,prep_res: Any) -> Any: for self.cur_retry in range(self.max_retries): try: return self.exec(prep_res) except Exception as e: @@ -34,67 +35,67 @@ class Node(BaseNode): if self.wait>0: time.sleep(self.wait) class BatchNode(Node): - def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])] + def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [super(BatchNode,self)._exec(i) for i in (items or [])] class Flow(BaseNode): - def __init__(self,start=None): super().__init__(); self.start_node=start - def start(self,start): self.start_node=start; return start - def get_next_node(self,curr,action): + def __init__(self,start: Optional[BaseNode]=None) -> None: super().__init__(); self.start_node=start + def start(self,start: BaseNode) -> BaseNode: self.start_node=start; return start + def get_next_node(self,curr: BaseNode,action: Optional[str]) -> Optional[BaseNode]: nxt=curr.successors.get(action or "default") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") return nxt - def _orch(self,shared,params=None): + def _orch(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any: curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) return last_action - def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o) - def post(self,shared,prep_res,exec_res): return exec_res + def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o) + def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res class BatchFlow(Flow): - def _run(self,shared): + def _run(self,shared: Dict[str, Any]) -> Any: pr=self.prep(shared) or [] for bp in pr: self._orch(shared,{**self.params,**bp}) return self.post(shared,pr,None) class AsyncNode(Node): - async def prep_async(self,shared): pass - async def exec_async(self,prep_res): pass - async def exec_fallback_async(self,prep_res,exc): raise exc - async def post_async(self,shared,prep_res,exec_res): pass - async def _exec(self,prep_res): + async def prep_async(self,shared: Dict[str, Any]) -> Any: pass + async def exec_async(self,prep_res: Any) -> Any: pass + async def exec_fallback_async(self,prep_res: Any,exc: Exception) -> Any: raise exc + async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass + async def _exec(self,prep_res: Any) -> Any: for i in range(self.max_retries): try: return await self.exec_async(prep_res) except Exception as e: if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e) if self.wait>0: await asyncio.sleep(self.wait) - async def run_async(self,shared): + async def run_async(self,shared: Dict[str, Any]) -> Any: if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") return await self._run_async(shared) - async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e) - def _run(self,shared): raise RuntimeError("Use run_async.") + async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e) + def _run(self,shared: Dict[str, Any]) -> Any: raise RuntimeError("Use run_async.") class AsyncBatchNode(AsyncNode,BatchNode): - async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] + async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [await super(AsyncBatchNode,self)._exec(i) for i in items] class AsyncParallelBatchNode(AsyncNode,BatchNode): - async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) + async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) class AsyncFlow(Flow,AsyncNode): - async def _orch_async(self,shared,params=None): + async def _orch_async(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any: curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) return last_action - async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o) - async def post_async(self,shared,prep_res,exec_res): return exec_res + async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o) + async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res class AsyncBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared): + async def _run_async(self,shared: Dict[str, Any]) -> Any: pr=await self.prep_async(shared) or [] for bp in pr: await self._orch_async(shared,{**self.params,**bp}) return await self.post_async(shared,pr,None) class AsyncParallelBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared): + async def _run_async(self,shared: Dict[str, Any]) -> Any: pr=await self.prep_async(shared) or [] await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr)) return await self.post_async(shared,pr,None) \ No newline at end of file From e58fbab70fb9d65c459ae5dbf469d8168bfa0f25 Mon Sep 17 00:00:00 2001 From: raceychan Date: Mon, 7 Jul 2025 01:24:50 +0800 Subject: [PATCH 2/2] Add type hints via .pyi stub file - Create pocketflow/__init__.pyi with improved type definitions - Remove type hints from runtime code to keep it clean - Use TypeVars and Generics for better type relationships - Replace many 'Any' types with more specific types like ParamValue, SharedData, Params - Maintain original compact code style in __init__.py --- pocketflow/__init__.py | 79 +++++++++++++++++---------------- pocketflow/__init__.pyi | 97 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 40 deletions(-) create mode 100644 pocketflow/__init__.pyi diff --git a/pocketflow/__init__.py b/pocketflow/__init__.py index 5534c23..a7203df 100644 --- a/pocketflow/__init__.py +++ b/pocketflow/__init__.py @@ -1,33 +1,32 @@ import asyncio, warnings, copy, time -from typing import Dict, Any, Optional, List, Union class BaseNode: - def __init__(self) -> None: self.params,self.successors={},{} - def set_params(self,params: Dict[str, Any]) -> None: self.params=params - def next(self,node: 'BaseNode',action: str="default") -> 'BaseNode': + def __init__(self): self.params,self.successors={},{} + def set_params(self,params): self.params=params + def next(self,node,action="default"): if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'") self.successors[action]=node; return node - def prep(self,shared: Dict[str, Any]) -> Any: pass - def exec(self,prep_res: Any) -> Any: pass - def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass - def _exec(self,prep_res: Any) -> Any: return self.exec(prep_res) - def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e) - def run(self,shared: Dict[str, Any]) -> Any: + def prep(self,shared): pass + def exec(self,prep_res): pass + def post(self,shared,prep_res,exec_res): pass + def _exec(self,prep_res): return self.exec(prep_res) + def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e) + def run(self,shared): if self.successors: warnings.warn("Node won't run successors. Use Flow.") return self._run(shared) - def __rshift__(self,other: 'BaseNode') -> 'BaseNode': return self.next(other) - def __sub__(self,action: str) -> '_ConditionalTransition': + def __rshift__(self,other): return self.next(other) + def __sub__(self,action): if isinstance(action,str): return _ConditionalTransition(self,action) raise TypeError("Action must be a string") class _ConditionalTransition: - def __init__(self,src: BaseNode,action: str) -> None: self.src,self.action=src,action - def __rshift__(self,tgt: BaseNode) -> BaseNode: return self.src.next(tgt,self.action) + def __init__(self,src,action): self.src,self.action=src,action + def __rshift__(self,tgt): return self.src.next(tgt,self.action) class Node(BaseNode): - def __init__(self,max_retries: int=1,wait: Union[int, float]=0) -> None: super().__init__(); self.max_retries,self.wait=max_retries,wait - def exec_fallback(self,prep_res: Any,exc: Exception) -> Any: raise exc - def _exec(self,prep_res: Any) -> Any: + def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait + def exec_fallback(self,prep_res,exc): raise exc + def _exec(self,prep_res): for self.cur_retry in range(self.max_retries): try: return self.exec(prep_res) except Exception as e: @@ -35,67 +34,67 @@ class Node(BaseNode): if self.wait>0: time.sleep(self.wait) class BatchNode(Node): - def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [super(BatchNode,self)._exec(i) for i in (items or [])] + def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])] class Flow(BaseNode): - def __init__(self,start: Optional[BaseNode]=None) -> None: super().__init__(); self.start_node=start - def start(self,start: BaseNode) -> BaseNode: self.start_node=start; return start - def get_next_node(self,curr: BaseNode,action: Optional[str]) -> Optional[BaseNode]: + def __init__(self,start=None): super().__init__(); self.start_node=start + def start(self,start): self.start_node=start; return start + def get_next_node(self,curr,action): nxt=curr.successors.get(action or "default") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") return nxt - def _orch(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any: + def _orch(self,shared,params=None): curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) return last_action - def _run(self,shared: Dict[str, Any]) -> Any: p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o) - def post(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res + def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o) + def post(self,shared,prep_res,exec_res): return exec_res class BatchFlow(Flow): - def _run(self,shared: Dict[str, Any]) -> Any: + def _run(self,shared): pr=self.prep(shared) or [] for bp in pr: self._orch(shared,{**self.params,**bp}) return self.post(shared,pr,None) class AsyncNode(Node): - async def prep_async(self,shared: Dict[str, Any]) -> Any: pass - async def exec_async(self,prep_res: Any) -> Any: pass - async def exec_fallback_async(self,prep_res: Any,exc: Exception) -> Any: raise exc - async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: pass - async def _exec(self,prep_res: Any) -> Any: + async def prep_async(self,shared): pass + async def exec_async(self,prep_res): pass + async def exec_fallback_async(self,prep_res,exc): raise exc + async def post_async(self,shared,prep_res,exec_res): pass + async def _exec(self,prep_res): for i in range(self.max_retries): try: return await self.exec_async(prep_res) except Exception as e: if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e) if self.wait>0: await asyncio.sleep(self.wait) - async def run_async(self,shared: Dict[str, Any]) -> Any: + async def run_async(self,shared): if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") return await self._run_async(shared) - async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e) - def _run(self,shared: Dict[str, Any]) -> Any: raise RuntimeError("Use run_async.") + async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e) + def _run(self,shared): raise RuntimeError("Use run_async.") class AsyncBatchNode(AsyncNode,BatchNode): - async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return [await super(AsyncBatchNode,self)._exec(i) for i in items] + async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items] class AsyncParallelBatchNode(AsyncNode,BatchNode): - async def _exec(self,items: Optional[List[Any]]) -> List[Any]: return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) + async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) class AsyncFlow(Flow,AsyncNode): - async def _orch_async(self,shared: Dict[str, Any],params: Optional[Dict[str, Any]]=None) -> Any: + async def _orch_async(self,shared,params=None): curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action)) return last_action - async def _run_async(self,shared: Dict[str, Any]) -> Any: p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o) - async def post_async(self,shared: Dict[str, Any],prep_res: Any,exec_res: Any) -> Any: return exec_res + async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o) + async def post_async(self,shared,prep_res,exec_res): return exec_res class AsyncBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared: Dict[str, Any]) -> Any: + async def _run_async(self,shared): pr=await self.prep_async(shared) or [] for bp in pr: await self._orch_async(shared,{**self.params,**bp}) return await self.post_async(shared,pr,None) class AsyncParallelBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared: Dict[str, Any]) -> Any: + async def _run_async(self,shared): pr=await self.prep_async(shared) or [] await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr)) return await self.post_async(shared,pr,None) \ No newline at end of file diff --git a/pocketflow/__init__.pyi b/pocketflow/__init__.pyi new file mode 100644 index 0000000..220d2b4 --- /dev/null +++ b/pocketflow/__init__.pyi @@ -0,0 +1,97 @@ +import asyncio +from typing import Any, Dict, List, Optional, Union, TypeVar, Generic + +# Type variables for better type relationships +_PrepResult = TypeVar('_PrepResult') +_ExecResult = TypeVar('_ExecResult') +_PostResult = TypeVar('_PostResult') + +# More specific parameter types +ParamValue = Union[str, int, float, bool, None, List[Any], Dict[str, Any]] +SharedData = Dict[str, Any] +Params = Dict[str, ParamValue] + +class BaseNode(Generic[_PrepResult, _ExecResult, _PostResult]): + params: Params + successors: Dict[str, BaseNode[Any, Any, Any]] + + def __init__(self) -> None: ... + def set_params(self, params: Params) -> None: ... + def next(self, node: BaseNode[Any, Any, Any], action: str = "default") -> BaseNode[Any, Any, Any]: ... + def prep(self, shared: SharedData) -> _PrepResult: ... + def exec(self, prep_res: _PrepResult) -> _ExecResult: ... + def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult) -> _PostResult: ... + def _exec(self, prep_res: _PrepResult) -> _ExecResult: ... + def _run(self, shared: SharedData) -> _PostResult: ... + def run(self, shared: SharedData) -> _PostResult: ... + def __rshift__(self, other: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ... + def __sub__(self, action: str) -> _ConditionalTransition: ... + +class _ConditionalTransition: + src: BaseNode[Any, Any, Any] + action: str + + def __init__(self, src: BaseNode[Any, Any, Any], action: str) -> None: ... + def __rshift__(self, tgt: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ... + +class Node(BaseNode[_PrepResult, _ExecResult, _PostResult]): + max_retries: int + wait: Union[int, float] + cur_retry: int + + def __init__(self, max_retries: int = 1, wait: Union[int, float] = 0) -> None: ... + def exec_fallback(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ... + def _exec(self, prep_res: _PrepResult) -> _ExecResult: ... + +class BatchNode(Node[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]): + def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ... + +class Flow(BaseNode[_PrepResult, Any, _PostResult]): + start_node: Optional[BaseNode[Any, Any, Any]] + + def __init__(self, start: Optional[BaseNode[Any, Any, Any]] = None) -> None: ... + def start(self, start: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ... + def get_next_node( + self, curr: BaseNode[Any, Any, Any], action: Optional[str] + ) -> Optional[BaseNode[Any, Any, Any]]: ... + def _orch( + self, shared: SharedData, params: Optional[Params] = None + ) -> Any: ... + def _run(self, shared: SharedData) -> _PostResult: ... + def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: Any) -> _PostResult: ... + +class BatchFlow(Flow[Optional[List[Params]], Any, _PostResult]): + def _run(self, shared: SharedData) -> _PostResult: ... + +class AsyncNode(Node[_PrepResult, _ExecResult, _PostResult]): + async def prep_async(self, shared: SharedData) -> _PrepResult: ... + async def exec_async(self, prep_res: _PrepResult) -> _ExecResult: ... + async def exec_fallback_async(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ... + async def post_async( + self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult + ) -> _PostResult: ... + async def _exec(self, prep_res: _PrepResult) -> _ExecResult: ... + async def run_async(self, shared: SharedData) -> _PostResult: ... + async def _run_async(self, shared: SharedData) -> _PostResult: ... + def _run(self, shared: SharedData) -> _PostResult: ... + +class AsyncBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]): + async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ... + +class AsyncParallelBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]): + async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ... + +class AsyncFlow(Flow[_PrepResult, Any, _PostResult], AsyncNode[_PrepResult, Any, _PostResult]): + async def _orch_async( + self, shared: SharedData, params: Optional[Params] = None + ) -> Any: ... + async def _run_async(self, shared: SharedData) -> _PostResult: ... + async def post_async( + self, shared: SharedData, prep_res: _PrepResult, exec_res: Any + ) -> _PostResult: ... + +class AsyncBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]): + async def _run_async(self, shared: SharedData) -> _PostResult: ... + +class AsyncParallelBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]): + async def _run_async(self, shared: SharedData) -> _PostResult: ... \ No newline at end of file