diff --git a/pocketflow/__init__.pyi b/pocketflow/__init__.pyi new file mode 100644 index 0000000..220d2b4 --- /dev/null +++ b/pocketflow/__init__.pyi @@ -0,0 +1,97 @@ +import asyncio +from typing import Any, Dict, List, Optional, Union, TypeVar, Generic + +# Type variables for better type relationships +_PrepResult = TypeVar('_PrepResult') +_ExecResult = TypeVar('_ExecResult') +_PostResult = TypeVar('_PostResult') + +# More specific parameter types +ParamValue = Union[str, int, float, bool, None, List[Any], Dict[str, Any]] +SharedData = Dict[str, Any] +Params = Dict[str, ParamValue] + +class BaseNode(Generic[_PrepResult, _ExecResult, _PostResult]): + params: Params + successors: Dict[str, BaseNode[Any, Any, Any]] + + def __init__(self) -> None: ... + def set_params(self, params: Params) -> None: ... + def next(self, node: BaseNode[Any, Any, Any], action: str = "default") -> BaseNode[Any, Any, Any]: ... + def prep(self, shared: SharedData) -> _PrepResult: ... + def exec(self, prep_res: _PrepResult) -> _ExecResult: ... + def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult) -> _PostResult: ... + def _exec(self, prep_res: _PrepResult) -> _ExecResult: ... + def _run(self, shared: SharedData) -> _PostResult: ... + def run(self, shared: SharedData) -> _PostResult: ... + def __rshift__(self, other: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ... + def __sub__(self, action: str) -> _ConditionalTransition: ... + +class _ConditionalTransition: + src: BaseNode[Any, Any, Any] + action: str + + def __init__(self, src: BaseNode[Any, Any, Any], action: str) -> None: ... + def __rshift__(self, tgt: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ... + +class Node(BaseNode[_PrepResult, _ExecResult, _PostResult]): + max_retries: int + wait: Union[int, float] + cur_retry: int + + def __init__(self, max_retries: int = 1, wait: Union[int, float] = 0) -> None: ... + def exec_fallback(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ... + def _exec(self, prep_res: _PrepResult) -> _ExecResult: ... + +class BatchNode(Node[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]): + def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ... + +class Flow(BaseNode[_PrepResult, Any, _PostResult]): + start_node: Optional[BaseNode[Any, Any, Any]] + + def __init__(self, start: Optional[BaseNode[Any, Any, Any]] = None) -> None: ... + def start(self, start: BaseNode[Any, Any, Any]) -> BaseNode[Any, Any, Any]: ... + def get_next_node( + self, curr: BaseNode[Any, Any, Any], action: Optional[str] + ) -> Optional[BaseNode[Any, Any, Any]]: ... + def _orch( + self, shared: SharedData, params: Optional[Params] = None + ) -> Any: ... + def _run(self, shared: SharedData) -> _PostResult: ... + def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: Any) -> _PostResult: ... + +class BatchFlow(Flow[Optional[List[Params]], Any, _PostResult]): + def _run(self, shared: SharedData) -> _PostResult: ... + +class AsyncNode(Node[_PrepResult, _ExecResult, _PostResult]): + async def prep_async(self, shared: SharedData) -> _PrepResult: ... + async def exec_async(self, prep_res: _PrepResult) -> _ExecResult: ... + async def exec_fallback_async(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ... + async def post_async( + self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult + ) -> _PostResult: ... + async def _exec(self, prep_res: _PrepResult) -> _ExecResult: ... + async def run_async(self, shared: SharedData) -> _PostResult: ... + async def _run_async(self, shared: SharedData) -> _PostResult: ... + def _run(self, shared: SharedData) -> _PostResult: ... + +class AsyncBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]): + async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ... + +class AsyncParallelBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]): + async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ... + +class AsyncFlow(Flow[_PrepResult, Any, _PostResult], AsyncNode[_PrepResult, Any, _PostResult]): + async def _orch_async( + self, shared: SharedData, params: Optional[Params] = None + ) -> Any: ... + async def _run_async(self, shared: SharedData) -> _PostResult: ... + async def post_async( + self, shared: SharedData, prep_res: _PrepResult, exec_res: Any + ) -> _PostResult: ... + +class AsyncBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]): + async def _run_async(self, shared: SharedData) -> _PostResult: ... + +class AsyncParallelBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]): + async def _run_async(self, shared: SharedData) -> _PostResult: ... \ No newline at end of file