diff --git a/cookbook/pocketflow_a2a/README.md b/cookbook/pocketflow_a2a/README.md new file mode 100644 index 0000000..36aaeca --- /dev/null +++ b/cookbook/pocketflow_a2a/README.md @@ -0,0 +1,67 @@ +# Research Agent + +This project demonstrates a simple yet powerful LLM-powered research agent. This implementation is based directly on the tutorial: [LLM Agents are simply Graph — Tutorial For Dummies](https://zacharyhuang.substack.com/p/llm-agent-internal-as-a-graph-tutorial). + +šŸ‘‰ Run the tutorial in your browser: [Try Google Colab Notebook]( +https://colab.research.google.com/github/The-Pocket/PocketFlow/blob/main/cookbook/pocketflow-agent/demo.ipynb) + +## Features + +- Performs web searches to gather information +- Makes decisions about when to search vs. when to answer +- Generates comprehensive answers based on research findings + +## Getting Started + +1. Install the packages you need with this simple command: +```bash +pip install -r requirements.txt +``` + +2. Let's get your OpenAI API key ready: + +```bash +export OPENAI_API_KEY="your-api-key-here" +``` + +3. Let's do a quick check to make sure your API key is working properly: + +```bash +python utils.py +``` + +This will test both the LLM call and web search features. If you see responses, you're good to go! + +4. Try out the agent with the default question (about Nobel Prize winners): + +```bash +python main.py +``` + +5. Got a burning question? Ask anything you want by using the `--` prefix: + +```bash +python main.py --"What is quantum computing?" +``` + +## How It Works? + +The magic happens through a simple but powerful graph structure with three main parts: + +```mermaid +graph TD + A[DecideAction] -->|"search"| B[SearchWeb] + A -->|"answer"| C[AnswerQuestion] + B -->|"decide"| A +``` + +Here's what each part does: +1. **DecideAction**: The brain that figures out whether to search or answer +2. **SearchWeb**: The researcher that goes out and finds information +3. **AnswerQuestion**: The writer that crafts the final answer + +Here's what's in each file: +- [`main.py`](./main.py): The starting point - runs the whole show! +- [`flow.py`](./flow.py): Connects everything together into a smart agent +- [`nodes.py`](./nodes.py): The building blocks that make decisions and take actions +- [`utils.py`](./utils.py): Helper functions for talking to the LLM and searching the web diff --git a/cookbook/pocketflow_a2a/a2a_client.py b/cookbook/pocketflow_a2a/a2a_client.py new file mode 100644 index 0000000..04c295e --- /dev/null +++ b/cookbook/pocketflow_a2a/a2a_client.py @@ -0,0 +1,143 @@ +# FILE: minimal_a2a_client.py +import asyncio +import asyncclick as click # Using asyncclick for async main +from uuid import uuid4 +import json # For potentially inspecting raw errors + +# Import from the common directory placed alongside this script +from common.client import A2AClient +from common.types import ( + TaskState, + A2AClientError, + TextPart, # Used to construct the message + JSONRPCResponse # Potentially useful for error checking +) + +# --- ANSI Colors (Optional but helpful) --- +C_RED = "\x1b[31m" +C_GREEN = "\x1b[32m" +C_YELLOW = "\x1b[33m" +C_BLUE = "\x1b[34m" +C_MAGENTA = "\x1b[35m" +C_CYAN = "\x1b[36m" +C_WHITE = "\x1b[37m" +C_GRAY = "\x1b[90m" +C_BRIGHT_MAGENTA = "\x1b[95m" +C_RESET = "\x1b[0m" +C_BOLD = "\x1b[1m" + +def colorize(color, text): + return f"{color}{text}{C_RESET}" + +@click.command() +@click.option( + "--agent-url", + default="http://localhost:10003", # Default to the port used in server __main__ + help="URL of the PocketFlow A2A agent server.", +) +async def cli(agent_url: str): + """Minimal CLI client to interact with an A2A agent.""" + + print(colorize(C_BRIGHT_MAGENTA, f"Connecting to agent at: {agent_url}")) + + # Instantiate the client - only URL is needed if not fetching card first + # Note: The PocketFlow wrapper doesn't expose much via the AgentCard, + # so we skip fetching it for this minimal client. + client = A2AClient(url=agent_url) + + sessionId = uuid4().hex # Generate a new session ID for this run + print(colorize(C_GRAY, f"Using Session ID: {sessionId}")) + + while True: + taskId = uuid4().hex # Generate a new task ID for each interaction + try: + prompt = await click.prompt( + colorize(C_CYAN, "\nEnter your question (:q or quit to exit)"), + prompt_suffix=" > ", + type=str # Ensure prompt returns string + ) + except RuntimeError: + # This can happen if stdin is closed, e.g., in some test runners + print(colorize(C_RED, "Failed to read input. Exiting.")) + break + + + if prompt.lower() in [":q", "quit"]: + print(colorize(C_YELLOW, "Exiting client.")) + break + + # --- Construct A2A Request Payload --- + payload = { + "id": taskId, + "sessionId": sessionId, + "message": { + "role": "user", + "parts": [ + { + "type": "text", # Explicitly match TextPart structure + "text": prompt, + } + ], + }, + "acceptedOutputModes": ["text", "text/plain"], # What the client wants back + # historyLength could be added if needed + } + + print(colorize(C_GRAY, f"Sending task {taskId}...")) + + try: + # --- Send Task (Non-Streaming) --- + response = await client.send_task(payload) + + # --- Process Response --- + if response.error: + print(colorize(C_RED, f"Error from agent (Code: {response.error.code}): {response.error.message}")) + if response.error.data: + print(colorize(C_GRAY, f"Error Data: {response.error.data}")) + elif response.result: + task_result = response.result + print(colorize(C_GREEN, f"Task {task_result.id} finished with state: {task_result.status.state}")) + + final_answer = "Agent did not provide a final artifact." + # Extract answer from artifacts (as implemented in PocketFlowTaskManager) + if task_result.artifacts: + try: + # Find the first text part in the first artifact + first_artifact = task_result.artifacts[0] + first_text_part = next( + (p for p in first_artifact.parts if isinstance(p, TextPart)), + None + ) + if first_text_part: + final_answer = first_text_part.text + else: + final_answer = f"(Non-text artifact received: {first_artifact.parts})" + except (IndexError, AttributeError, TypeError) as e: + final_answer = f"(Error parsing artifact: {e})" + elif task_result.status.message and task_result.status.message.parts: + # Fallback to status message if no artifact + try: + first_text_part = next( + (p for p in task_result.status.message.parts if isinstance(p, TextPart)), + None + ) + if first_text_part: + final_answer = f"(Final status message: {first_text_part.text})" + + except (AttributeError, TypeError) as e: + final_answer = f"(Error parsing status message: {e})" + + + print(colorize(C_BOLD + C_WHITE, f"\nAgent Response:\n{final_answer}")) + + else: + # Should not happen if error is None + print(colorize(C_YELLOW, "Received response with no result and no error.")) + + except A2AClientError as e: + print(colorize(C_RED, f"\nClient Error: {e}")) + except Exception as e: + print(colorize(C_RED, f"\nAn unexpected error occurred: {e}")) + +if __name__ == "__main__": + asyncio.run(cli()) \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/a2a_server.py b/cookbook/pocketflow_a2a/a2a_server.py new file mode 100644 index 0000000..9ef6127 --- /dev/null +++ b/cookbook/pocketflow_a2a/a2a_server.py @@ -0,0 +1,81 @@ +# FILE: pocketflow_a2a_agent/__main__.py +import click +import logging +import os + +# Import from the common code you copied +from common.server import A2AServer +from common.types import AgentCard, AgentCapabilities, AgentSkill, MissingAPIKeyError + +# Import your custom TaskManager (which now imports from your original files) +from .task_manager import PocketFlowTaskManager + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +@click.command() +@click.option("--host", "host", default="localhost") +@click.option("--port", "port", default=10003) # Use a different port from other agents +def main(host, port): + """Starts the PocketFlow A2A Agent server.""" + try: + # Check for necessary API keys (add others if needed) + if not os.getenv("OPENAI_API_KEY"): + raise MissingAPIKeyError("OPENAI_API_KEY environment variable not set.") + + # --- Define the Agent Card --- + capabilities = AgentCapabilities( + streaming=False, # This simple implementation is synchronous + pushNotifications=False, + stateTransitionHistory=False # PocketFlow state isn't exposed via A2A history + ) + skill = AgentSkill( + id="web_research_qa", + name="Web Research and Answering", + description="Answers questions using web search results when necessary.", + tags=["research", "qa", "web search"], + examples=[ + "Who won the Nobel Prize in Physics 2024?", + "What is quantum computing?", + "Summarize the latest news about AI.", + ], + # Input/Output modes defined in the TaskManager + inputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES, + outputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES, + ) + agent_card = AgentCard( + name="PocketFlow Research Agent (A2A Wrapped)", + description="A simple research agent based on PocketFlow, made accessible via A2A.", + url=f"http://{host}:{port}/", # The endpoint A2A clients will use + version="0.1.0-a2a", + capabilities=capabilities, + skills=[skill], + # Assuming no specific provider or auth for this example + provider=None, + authentication=None, + defaultInputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES, + defaultOutputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES, + ) + + # --- Initialize and Start Server --- + task_manager = PocketFlowTaskManager() # Instantiate your custom manager + server = A2AServer( + agent_card=agent_card, + task_manager=task_manager, + host=host, + port=port, + ) + + logger.info(f"Starting PocketFlow A2A server on http://{host}:{port}") + server.start() + + except MissingAPIKeyError as e: + logger.error(f"Configuration Error: {e}") + exit(1) + except Exception as e: + logger.error(f"An error occurred during server startup: {e}", exc_info=True) + exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/common/__init__.py b/cookbook/pocketflow_a2a/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cookbook/pocketflow_a2a/common/client/__init__.py b/cookbook/pocketflow_a2a/common/client/__init__.py new file mode 100644 index 0000000..5aa0e0d --- /dev/null +++ b/cookbook/pocketflow_a2a/common/client/__init__.py @@ -0,0 +1,4 @@ +from .client import A2AClient +from .card_resolver import A2ACardResolver + +__all__ = ["A2AClient", "A2ACardResolver"] diff --git a/cookbook/pocketflow_a2a/common/client/card_resolver.py b/cookbook/pocketflow_a2a/common/client/card_resolver.py new file mode 100644 index 0000000..dd5873c --- /dev/null +++ b/cookbook/pocketflow_a2a/common/client/card_resolver.py @@ -0,0 +1,21 @@ +import httpx +from common.types import ( + AgentCard, + A2AClientJSONError, +) +import json + + +class A2ACardResolver: + def __init__(self, base_url, agent_card_path="/.well-known/agent.json"): + self.base_url = base_url.rstrip("/") + self.agent_card_path = agent_card_path.lstrip("/") + + def get_agent_card(self) -> AgentCard: + with httpx.Client() as client: + response = client.get(self.base_url + "/" + self.agent_card_path) + response.raise_for_status() + try: + return AgentCard(**response.json()) + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e diff --git a/cookbook/pocketflow_a2a/common/client/client.py b/cookbook/pocketflow_a2a/common/client/client.py new file mode 100644 index 0000000..5e969d1 --- /dev/null +++ b/cookbook/pocketflow_a2a/common/client/client.py @@ -0,0 +1,86 @@ +import httpx +from httpx_sse import connect_sse +from typing import Any, AsyncIterable +from common.types import ( + AgentCard, + GetTaskRequest, + SendTaskRequest, + SendTaskResponse, + JSONRPCRequest, + GetTaskResponse, + CancelTaskResponse, + CancelTaskRequest, + SetTaskPushNotificationRequest, + SetTaskPushNotificationResponse, + GetTaskPushNotificationRequest, + GetTaskPushNotificationResponse, + A2AClientHTTPError, + A2AClientJSONError, + SendTaskStreamingRequest, + SendTaskStreamingResponse, +) +import json + + +class A2AClient: + def __init__(self, agent_card: AgentCard = None, url: str = None): + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError("Must provide either agent_card or url") + + async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse: + request = SendTaskRequest(params=payload) + return SendTaskResponse(**await self._send_request(request)) + + async def send_task_streaming( + self, payload: dict[str, Any] + ) -> AsyncIterable[SendTaskStreamingResponse]: + request = SendTaskStreamingRequest(params=payload) + with httpx.Client(timeout=None) as client: + with connect_sse( + client, "POST", self.url, json=request.model_dump() + ) as event_source: + try: + for sse in event_source.iter_sse(): + yield SendTaskStreamingResponse(**json.loads(sse.data)) + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError(400, str(e)) from e + + async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]: + async with httpx.AsyncClient() as client: + try: + # Image generation could take time, adding timeout + response = await client.post( + self.url, json=request.model_dump(), timeout=30 + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + + async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse: + request = GetTaskRequest(params=payload) + return GetTaskResponse(**await self._send_request(request)) + + async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse: + request = CancelTaskRequest(params=payload) + return CancelTaskResponse(**await self._send_request(request)) + + async def set_task_callback( + self, payload: dict[str, Any] + ) -> SetTaskPushNotificationResponse: + request = SetTaskPushNotificationRequest(params=payload) + return SetTaskPushNotificationResponse(**await self._send_request(request)) + + async def get_task_callback( + self, payload: dict[str, Any] + ) -> GetTaskPushNotificationResponse: + request = GetTaskPushNotificationRequest(params=payload) + return GetTaskPushNotificationResponse(**await self._send_request(request)) diff --git a/cookbook/pocketflow_a2a/common/server/__init__.py b/cookbook/pocketflow_a2a/common/server/__init__.py new file mode 100644 index 0000000..10f5fa4 --- /dev/null +++ b/cookbook/pocketflow_a2a/common/server/__init__.py @@ -0,0 +1,4 @@ +from .server import A2AServer +from .task_manager import TaskManager, InMemoryTaskManager + +__all__ = ["A2AServer", "TaskManager", "InMemoryTaskManager"] diff --git a/cookbook/pocketflow_a2a/common/server/server.py b/cookbook/pocketflow_a2a/common/server/server.py new file mode 100644 index 0000000..62740ef --- /dev/null +++ b/cookbook/pocketflow_a2a/common/server/server.py @@ -0,0 +1,120 @@ +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from sse_starlette.sse import EventSourceResponse +from starlette.requests import Request +from common.types import ( + A2ARequest, + JSONRPCResponse, + InvalidRequestError, + JSONParseError, + GetTaskRequest, + CancelTaskRequest, + SendTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + InternalError, + AgentCard, + TaskResubscriptionRequest, + SendTaskStreamingRequest, +) +from pydantic import ValidationError +import json +from typing import AsyncIterable, Any +from common.server.task_manager import TaskManager + +import logging + +logger = logging.getLogger(__name__) + + +class A2AServer: + def __init__( + self, + host="0.0.0.0", + port=5000, + endpoint="/", + agent_card: AgentCard = None, + task_manager: TaskManager = None, + ): + self.host = host + self.port = port + self.endpoint = endpoint + self.task_manager = task_manager + self.agent_card = agent_card + self.app = Starlette() + self.app.add_route(self.endpoint, self._process_request, methods=["POST"]) + self.app.add_route( + "/.well-known/agent.json", self._get_agent_card, methods=["GET"] + ) + + def start(self): + if self.agent_card is None: + raise ValueError("agent_card is not defined") + + if self.task_manager is None: + raise ValueError("request_handler is not defined") + + import uvicorn + + uvicorn.run(self.app, host=self.host, port=self.port) + + def _get_agent_card(self, request: Request) -> JSONResponse: + return JSONResponse(self.agent_card.model_dump(exclude_none=True)) + + async def _process_request(self, request: Request): + try: + body = await request.json() + json_rpc_request = A2ARequest.validate_python(body) + + if isinstance(json_rpc_request, GetTaskRequest): + result = await self.task_manager.on_get_task(json_rpc_request) + elif isinstance(json_rpc_request, SendTaskRequest): + result = await self.task_manager.on_send_task(json_rpc_request) + elif isinstance(json_rpc_request, SendTaskStreamingRequest): + result = await self.task_manager.on_send_task_subscribe( + json_rpc_request + ) + elif isinstance(json_rpc_request, CancelTaskRequest): + result = await self.task_manager.on_cancel_task(json_rpc_request) + elif isinstance(json_rpc_request, SetTaskPushNotificationRequest): + result = await self.task_manager.on_set_task_push_notification(json_rpc_request) + elif isinstance(json_rpc_request, GetTaskPushNotificationRequest): + result = await self.task_manager.on_get_task_push_notification(json_rpc_request) + elif isinstance(json_rpc_request, TaskResubscriptionRequest): + result = await self.task_manager.on_resubscribe_to_task( + json_rpc_request + ) + else: + logger.warning(f"Unexpected request type: {type(json_rpc_request)}") + raise ValueError(f"Unexpected request type: {type(request)}") + + return self._create_response(result) + + except Exception as e: + return self._handle_exception(e) + + def _handle_exception(self, e: Exception) -> JSONResponse: + if isinstance(e, json.decoder.JSONDecodeError): + json_rpc_error = JSONParseError() + elif isinstance(e, ValidationError): + json_rpc_error = InvalidRequestError(data=json.loads(e.json())) + else: + logger.error(f"Unhandled exception: {e}") + json_rpc_error = InternalError() + + response = JSONRPCResponse(id=None, error=json_rpc_error) + return JSONResponse(response.model_dump(exclude_none=True), status_code=400) + + def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse: + if isinstance(result, AsyncIterable): + + async def event_generator(result) -> AsyncIterable[dict[str, str]]: + async for item in result: + yield {"data": item.model_dump_json(exclude_none=True)} + + return EventSourceResponse(event_generator(result)) + elif isinstance(result, JSONRPCResponse): + return JSONResponse(result.model_dump(exclude_none=True)) + else: + logger.error(f"Unexpected result type: {type(result)}") + raise ValueError(f"Unexpected result type: {type(result)}") diff --git a/cookbook/pocketflow_a2a/common/server/task_manager.py b/cookbook/pocketflow_a2a/common/server/task_manager.py new file mode 100644 index 0000000..6c4c91b --- /dev/null +++ b/cookbook/pocketflow_a2a/common/server/task_manager.py @@ -0,0 +1,277 @@ +from abc import ABC, abstractmethod +from typing import Union, AsyncIterable, List +from common.types import Task +from common.types import ( + JSONRPCResponse, + TaskIdParams, + TaskQueryParams, + GetTaskRequest, + TaskNotFoundError, + SendTaskRequest, + CancelTaskRequest, + TaskNotCancelableError, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + GetTaskResponse, + CancelTaskResponse, + SendTaskResponse, + SetTaskPushNotificationResponse, + GetTaskPushNotificationResponse, + PushNotificationNotSupportedError, + TaskSendParams, + TaskStatus, + TaskState, + TaskResubscriptionRequest, + SendTaskStreamingRequest, + SendTaskStreamingResponse, + Artifact, + PushNotificationConfig, + TaskStatusUpdateEvent, + JSONRPCError, + TaskPushNotificationConfig, + InternalError, +) +from common.server.utils import new_not_implemented_error +import asyncio +import logging + +logger = logging.getLogger(__name__) + +class TaskManager(ABC): + @abstractmethod + async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: + pass + + @abstractmethod + async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: + pass + + @abstractmethod + async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: + pass + + @abstractmethod + async def on_send_task_subscribe( + self, request: SendTaskStreamingRequest + ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: + pass + + @abstractmethod + async def on_set_task_push_notification( + self, request: SetTaskPushNotificationRequest + ) -> SetTaskPushNotificationResponse: + pass + + @abstractmethod + async def on_get_task_push_notification( + self, request: GetTaskPushNotificationRequest + ) -> GetTaskPushNotificationResponse: + pass + + @abstractmethod + async def on_resubscribe_to_task( + self, request: TaskResubscriptionRequest + ) -> Union[AsyncIterable[SendTaskResponse], JSONRPCResponse]: + pass + + +class InMemoryTaskManager(TaskManager): + def __init__(self): + self.tasks: dict[str, Task] = {} + self.push_notification_infos: dict[str, PushNotificationConfig] = {} + self.lock = asyncio.Lock() + self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {} + self.subscriber_lock = asyncio.Lock() + + async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: + logger.info(f"Getting task {request.params.id}") + task_query_params: TaskQueryParams = request.params + + async with self.lock: + task = self.tasks.get(task_query_params.id) + if task is None: + return GetTaskResponse(id=request.id, error=TaskNotFoundError()) + + task_result = self.append_task_history( + task, task_query_params.historyLength + ) + + return GetTaskResponse(id=request.id, result=task_result) + + async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: + logger.info(f"Cancelling task {request.params.id}") + task_id_params: TaskIdParams = request.params + + async with self.lock: + task = self.tasks.get(task_id_params.id) + if task is None: + return CancelTaskResponse(id=request.id, error=TaskNotFoundError()) + + return CancelTaskResponse(id=request.id, error=TaskNotCancelableError()) + + @abstractmethod + async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: + pass + + @abstractmethod + async def on_send_task_subscribe( + self, request: SendTaskStreamingRequest + ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: + pass + + async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig): + async with self.lock: + task = self.tasks.get(task_id) + if task is None: + raise ValueError(f"Task not found for {task_id}") + + self.push_notification_infos[task_id] = notification_config + + return + + async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig: + async with self.lock: + task = self.tasks.get(task_id) + if task is None: + raise ValueError(f"Task not found for {task_id}") + + return self.push_notification_infos[task_id] + + return + + async def has_push_notification_info(self, task_id: str) -> bool: + async with self.lock: + return task_id in self.push_notification_infos + + + async def on_set_task_push_notification( + self, request: SetTaskPushNotificationRequest + ) -> SetTaskPushNotificationResponse: + logger.info(f"Setting task push notification {request.params.id}") + task_notification_params: TaskPushNotificationConfig = request.params + + try: + await self.set_push_notification_info(task_notification_params.id, task_notification_params.pushNotificationConfig) + except Exception as e: + logger.error(f"Error while setting push notification info: {e}") + return JSONRPCResponse( + id=request.id, + error=InternalError( + message="An error occurred while setting push notification info" + ), + ) + + return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params) + + async def on_get_task_push_notification( + self, request: GetTaskPushNotificationRequest + ) -> GetTaskPushNotificationResponse: + logger.info(f"Getting task push notification {request.params.id}") + task_params: TaskIdParams = request.params + + try: + notification_info = await self.get_push_notification_info(task_params.id) + except Exception as e: + logger.error(f"Error while getting push notification info: {e}") + return GetTaskPushNotificationResponse( + id=request.id, + error=InternalError( + message="An error occurred while getting push notification info" + ), + ) + + return GetTaskPushNotificationResponse(id=request.id, result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info)) + + async def upsert_task(self, task_send_params: TaskSendParams) -> Task: + logger.info(f"Upserting task {task_send_params.id}") + async with self.lock: + task = self.tasks.get(task_send_params.id) + if task is None: + task = Task( + id=task_send_params.id, + sessionId = task_send_params.sessionId, + messages=[task_send_params.message], + status=TaskStatus(state=TaskState.SUBMITTED), + history=[task_send_params.message], + ) + self.tasks[task_send_params.id] = task + else: + task.history.append(task_send_params.message) + + return task + + async def on_resubscribe_to_task( + self, request: TaskResubscriptionRequest + ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: + return new_not_implemented_error(request.id) + + async def update_store( + self, task_id: str, status: TaskStatus, artifacts: list[Artifact] + ) -> Task: + async with self.lock: + try: + task = self.tasks[task_id] + except KeyError: + logger.error(f"Task {task_id} not found for updating the task") + raise ValueError(f"Task {task_id} not found") + + task.status = status + + if status.message is not None: + task.history.append(status.message) + + if artifacts is not None: + if task.artifacts is None: + task.artifacts = [] + task.artifacts.extend(artifacts) + + return task + + def append_task_history(self, task: Task, historyLength: int | None): + new_task = task.model_copy() + if historyLength is not None and historyLength > 0: + new_task.history = new_task.history[-historyLength:] + else: + new_task.history = [] + + return new_task + + async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False): + async with self.subscriber_lock: + if task_id not in self.task_sse_subscribers: + if is_resubscribe: + raise ValueError("Task not found for resubscription") + else: + self.task_sse_subscribers[task_id] = [] + + sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited + self.task_sse_subscribers[task_id].append(sse_event_queue) + return sse_event_queue + + async def enqueue_events_for_sse(self, task_id, task_update_event): + async with self.subscriber_lock: + if task_id not in self.task_sse_subscribers: + return + + current_subscribers = self.task_sse_subscribers[task_id] + for subscriber in current_subscribers: + await subscriber.put(task_update_event) + + async def dequeue_events_for_sse( + self, request_id, task_id, sse_event_queue: asyncio.Queue + ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: + try: + while True: + event = await sse_event_queue.get() + if isinstance(event, JSONRPCError): + yield SendTaskStreamingResponse(id=request_id, error=event) + break + + yield SendTaskStreamingResponse(id=request_id, result=event) + if isinstance(event, TaskStatusUpdateEvent) and event.final: + break + finally: + async with self.subscriber_lock: + if task_id in self.task_sse_subscribers: + self.task_sse_subscribers[task_id].remove(sse_event_queue) + diff --git a/cookbook/pocketflow_a2a/common/server/utils.py b/cookbook/pocketflow_a2a/common/server/utils.py new file mode 100644 index 0000000..50e3985 --- /dev/null +++ b/cookbook/pocketflow_a2a/common/server/utils.py @@ -0,0 +1,28 @@ +from common.types import ( + JSONRPCResponse, + ContentTypeNotSupportedError, + UnsupportedOperationError, +) +from typing import List + + +def are_modalities_compatible( + server_output_modes: List[str], client_output_modes: List[str] +): + """Modalities are compatible if they are both non-empty + and there is at least one common element.""" + if client_output_modes is None or len(client_output_modes) == 0: + return True + + if server_output_modes is None or len(server_output_modes) == 0: + return True + + return any(x in server_output_modes for x in client_output_modes) + + +def new_incompatible_types_error(request_id): + return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError()) + + +def new_not_implemented_error(request_id): + return JSONRPCResponse(id=request_id, error=UnsupportedOperationError()) diff --git a/cookbook/pocketflow_a2a/common/types.py b/cookbook/pocketflow_a2a/common/types.py new file mode 100644 index 0000000..585fec3 --- /dev/null +++ b/cookbook/pocketflow_a2a/common/types.py @@ -0,0 +1,365 @@ +from typing import Union, Any +from pydantic import BaseModel, Field, TypeAdapter +from typing import Literal, List, Annotated, Optional +from datetime import datetime +from pydantic import model_validator, ConfigDict, field_serializer +from uuid import uuid4 +from enum import Enum +from typing_extensions import Self + + +class TaskState(str, Enum): + SUBMITTED = "submitted" + WORKING = "working" + INPUT_REQUIRED = "input-required" + COMPLETED = "completed" + CANCELED = "canceled" + FAILED = "failed" + UNKNOWN = "unknown" + + +class TextPart(BaseModel): + type: Literal["text"] = "text" + text: str + metadata: dict[str, Any] | None = None + + +class FileContent(BaseModel): + name: str | None = None + mimeType: str | None = None + bytes: str | None = None + uri: str | None = None + + @model_validator(mode="after") + def check_content(self) -> Self: + if not (self.bytes or self.uri): + raise ValueError("Either 'bytes' or 'uri' must be present in the file data") + if self.bytes and self.uri: + raise ValueError( + "Only one of 'bytes' or 'uri' can be present in the file data" + ) + return self + + +class FilePart(BaseModel): + type: Literal["file"] = "file" + file: FileContent + metadata: dict[str, Any] | None = None + + +class DataPart(BaseModel): + type: Literal["data"] = "data" + data: dict[str, Any] + metadata: dict[str, Any] | None = None + + +Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")] + + +class Message(BaseModel): + role: Literal["user", "agent"] + parts: List[Part] + metadata: dict[str, Any] | None = None + + +class TaskStatus(BaseModel): + state: TaskState + message: Message | None = None + timestamp: datetime = Field(default_factory=datetime.now) + + @field_serializer("timestamp") + def serialize_dt(self, dt: datetime, _info): + return dt.isoformat() + + +class Artifact(BaseModel): + name: str | None = None + description: str | None = None + parts: List[Part] + metadata: dict[str, Any] | None = None + index: int = 0 + append: bool | None = None + lastChunk: bool | None = None + + +class Task(BaseModel): + id: str + sessionId: str | None = None + status: TaskStatus + artifacts: List[Artifact] | None = None + history: List[Message] | None = None + metadata: dict[str, Any] | None = None + + +class TaskStatusUpdateEvent(BaseModel): + id: str + status: TaskStatus + final: bool = False + metadata: dict[str, Any] | None = None + + +class TaskArtifactUpdateEvent(BaseModel): + id: str + artifact: Artifact + metadata: dict[str, Any] | None = None + + +class AuthenticationInfo(BaseModel): + model_config = ConfigDict(extra="allow") + + schemes: List[str] + credentials: str | None = None + + +class PushNotificationConfig(BaseModel): + url: str + token: str | None = None + authentication: AuthenticationInfo | None = None + + +class TaskIdParams(BaseModel): + id: str + metadata: dict[str, Any] | None = None + + +class TaskQueryParams(TaskIdParams): + historyLength: int | None = None + + +class TaskSendParams(BaseModel): + id: str + sessionId: str = Field(default_factory=lambda: uuid4().hex) + message: Message + acceptedOutputModes: Optional[List[str]] = None + pushNotification: PushNotificationConfig | None = None + historyLength: int | None = None + metadata: dict[str, Any] | None = None + + +class TaskPushNotificationConfig(BaseModel): + id: str + pushNotificationConfig: PushNotificationConfig + + +## RPC Messages + + +class JSONRPCMessage(BaseModel): + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None = Field(default_factory=lambda: uuid4().hex) + + +class JSONRPCRequest(JSONRPCMessage): + method: str + params: dict[str, Any] | None = None + + +class JSONRPCError(BaseModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCResponse(JSONRPCMessage): + result: Any | None = None + error: JSONRPCError | None = None + + +class SendTaskRequest(JSONRPCRequest): + method: Literal["tasks/send"] = "tasks/send" + params: TaskSendParams + + +class SendTaskResponse(JSONRPCResponse): + result: Task | None = None + + +class SendTaskStreamingRequest(JSONRPCRequest): + method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe" + params: TaskSendParams + + +class SendTaskStreamingResponse(JSONRPCResponse): + result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None + + +class GetTaskRequest(JSONRPCRequest): + method: Literal["tasks/get"] = "tasks/get" + params: TaskQueryParams + + +class GetTaskResponse(JSONRPCResponse): + result: Task | None = None + + +class CancelTaskRequest(JSONRPCRequest): + method: Literal["tasks/cancel",] = "tasks/cancel" + params: TaskIdParams + + +class CancelTaskResponse(JSONRPCResponse): + result: Task | None = None + + +class SetTaskPushNotificationRequest(JSONRPCRequest): + method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set" + params: TaskPushNotificationConfig + + +class SetTaskPushNotificationResponse(JSONRPCResponse): + result: TaskPushNotificationConfig | None = None + + +class GetTaskPushNotificationRequest(JSONRPCRequest): + method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get" + params: TaskIdParams + + +class GetTaskPushNotificationResponse(JSONRPCResponse): + result: TaskPushNotificationConfig | None = None + + +class TaskResubscriptionRequest(JSONRPCRequest): + method: Literal["tasks/resubscribe",] = "tasks/resubscribe" + params: TaskIdParams + + +A2ARequest = TypeAdapter( + Annotated[ + Union[ + SendTaskRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + TaskResubscriptionRequest, + SendTaskStreamingRequest, + ], + Field(discriminator="method"), + ] +) + +## Error types + + +class JSONParseError(JSONRPCError): + code: int = -32700 + message: str = "Invalid JSON payload" + data: Any | None = None + + +class InvalidRequestError(JSONRPCError): + code: int = -32600 + message: str = "Request payload validation error" + data: Any | None = None + + +class MethodNotFoundError(JSONRPCError): + code: int = -32601 + message: str = "Method not found" + data: None = None + + +class InvalidParamsError(JSONRPCError): + code: int = -32602 + message: str = "Invalid parameters" + data: Any | None = None + + +class InternalError(JSONRPCError): + code: int = -32603 + message: str = "Internal error" + data: Any | None = None + + +class TaskNotFoundError(JSONRPCError): + code: int = -32001 + message: str = "Task not found" + data: None = None + + +class TaskNotCancelableError(JSONRPCError): + code: int = -32002 + message: str = "Task cannot be canceled" + data: None = None + + +class PushNotificationNotSupportedError(JSONRPCError): + code: int = -32003 + message: str = "Push Notification is not supported" + data: None = None + + +class UnsupportedOperationError(JSONRPCError): + code: int = -32004 + message: str = "This operation is not supported" + data: None = None + + +class ContentTypeNotSupportedError(JSONRPCError): + code: int = -32005 + message: str = "Incompatible content types" + data: None = None + + +class AgentProvider(BaseModel): + organization: str + url: str | None = None + + +class AgentCapabilities(BaseModel): + streaming: bool = False + pushNotifications: bool = False + stateTransitionHistory: bool = False + + +class AgentAuthentication(BaseModel): + schemes: List[str] + credentials: str | None = None + + +class AgentSkill(BaseModel): + id: str + name: str + description: str | None = None + tags: List[str] | None = None + examples: List[str] | None = None + inputModes: List[str] | None = None + outputModes: List[str] | None = None + + +class AgentCard(BaseModel): + name: str + description: str | None = None + url: str + provider: AgentProvider | None = None + version: str + documentationUrl: str | None = None + capabilities: AgentCapabilities + authentication: AgentAuthentication | None = None + defaultInputModes: List[str] = ["text"] + defaultOutputModes: List[str] = ["text"] + skills: List[AgentSkill] + + +class A2AClientError(Exception): + pass + + +class A2AClientHTTPError(A2AClientError): + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + super().__init__(f"HTTP Error {status_code}: {message}") + + +class A2AClientJSONError(A2AClientError): + def __init__(self, message: str): + self.message = message + super().__init__(f"JSON Error: {message}") + + +class MissingAPIKeyError(Exception): + """Exception for missing API key.""" + + pass \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/common/utils/in_memory_cache.py b/cookbook/pocketflow_a2a/common/utils/in_memory_cache.py new file mode 100644 index 0000000..06a29fe --- /dev/null +++ b/cookbook/pocketflow_a2a/common/utils/in_memory_cache.py @@ -0,0 +1,109 @@ +"""In Memory Cache utility.""" + +import threading +import time +from typing import Any, Dict, Optional + + +class InMemoryCache: + """A thread-safe Singleton class to manage cache data. + + Ensures only one instance of the cache exists across the application. + """ + + _instance: Optional["InMemoryCache"] = None + _lock: threading.Lock = threading.Lock() + _initialized: bool = False + + def __new__(cls): + """Override __new__ to control instance creation (Singleton pattern). + + Uses a lock to ensure thread safety during the first instantiation. + + Returns: + The singleton instance of InMemoryCache. + """ + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize the cache storage. + + Uses a flag (_initialized) to ensure this logic runs only on the very first + creation of the singleton instance. + """ + if not self._initialized: + with self._lock: + if not self._initialized: + # print("Initializing SessionCache storage") + self._cache_data: Dict[str, Dict[str, Any]] = {} + self._ttl: Dict[str, float] = {} + self._data_lock: threading.Lock = threading.Lock() + self._initialized = True + + def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + """Set a key-value pair. + + Args: + key: The key for the data. + value: The data to store. + ttl: Time to live in seconds. If None, data will not expire. + """ + with self._data_lock: + self._cache_data[key] = value + + if ttl is not None: + self._ttl[key] = time.time() + ttl + else: + if key in self._ttl: + del self._ttl[key] + + def get(self, key: str, default: Any = None) -> Any: + """Get the value associated with a key. + + Args: + key: The key for the data within the session. + default: The value to return if the session or key is not found. + + Returns: + The cached value, or the default value if not found. + """ + with self._data_lock: + if key in self._ttl and time.time() > self._ttl[key]: + del self._cache_data[key] + del self._ttl[key] + return default + return self._cache_data.get(key, default) + + def delete(self, key: str) -> None: + """Delete a specific key-value pair from a cache. + + Args: + key: The key to delete. + + Returns: + True if the key was found and deleted, False otherwise. + """ + + with self._data_lock: + if key in self._cache_data: + del self._cache_data[key] + if key in self._ttl: + del self._ttl[key] + return True + return False + + def clear(self) -> bool: + """Remove all data. + + Returns: + True if the data was cleared, False otherwise. + """ + with self._data_lock: + self._cache_data.clear() + self._ttl.clear() + return True + return False diff --git a/cookbook/pocketflow_a2a/common/utils/push_notification_auth.py b/cookbook/pocketflow_a2a/common/utils/push_notification_auth.py new file mode 100644 index 0000000..cc74ade --- /dev/null +++ b/cookbook/pocketflow_a2a/common/utils/push_notification_auth.py @@ -0,0 +1,135 @@ +from jwcrypto import jwk +import uuid +from starlette.responses import JSONResponse +from starlette.requests import Request +from typing import Any + +import jwt +import time +import json +import hashlib +import httpx +import logging + +from jwt import PyJWK, PyJWKClient + +logger = logging.getLogger(__name__) +AUTH_HEADER_PREFIX = 'Bearer ' + +class PushNotificationAuth: + def _calculate_request_body_sha256(self, data: dict[str, Any]): + """Calculates the SHA256 hash of a request body. + + This logic needs to be same for both the agent who signs the payload and the client verifier. + """ + body_str = json.dumps( + data, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ) + return hashlib.sha256(body_str.encode()).hexdigest() + +class PushNotificationSenderAuth(PushNotificationAuth): + def __init__(self): + self.public_keys = [] + self.private_key_jwk: PyJWK = None + + @staticmethod + async def verify_push_notification_url(url: str) -> bool: + async with httpx.AsyncClient(timeout=10) as client: + try: + validation_token = str(uuid.uuid4()) + response = await client.get( + url, + params={"validationToken": validation_token} + ) + response.raise_for_status() + is_verified = response.text == validation_token + + logger.info(f"Verified push-notification URL: {url} => {is_verified}") + return is_verified + except Exception as e: + logger.warning(f"Error during sending push-notification for URL {url}: {e}") + + return False + + def generate_jwk(self): + key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig") + self.public_keys.append(key.export_public(as_dict=True)) + self.private_key_jwk = PyJWK.from_json(key.export_private()) + + def handle_jwks_endpoint(self, _request: Request): + """Allow clients to fetch public keys. + """ + return JSONResponse({ + "keys": self.public_keys + }) + + def _generate_jwt(self, data: dict[str, Any]): + """JWT is generated by signing both the request payload SHA digest and time of token generation. + + Payload is signed with private key and it ensures the integrity of payload for client. + Including iat prevents from replay attack. + """ + + iat = int(time.time()) + + return jwt.encode( + {"iat": iat, "request_body_sha256": self._calculate_request_body_sha256(data)}, + key=self.private_key_jwk, + headers={"kid": self.private_key_jwk.key_id}, + algorithm="RS256" + ) + + async def send_push_notification(self, url: str, data: dict[str, Any]): + jwt_token = self._generate_jwt(data) + headers = {'Authorization': f"Bearer {jwt_token}"} + async with httpx.AsyncClient(timeout=10) as client: + try: + response = await client.post( + url, + json=data, + headers=headers + ) + response.raise_for_status() + logger.info(f"Push-notification sent for URL: {url}") + except Exception as e: + logger.warning(f"Error during sending push-notification for URL {url}: {e}") + +class PushNotificationReceiverAuth(PushNotificationAuth): + def __init__(self): + self.public_keys_jwks = [] + self.jwks_client = None + + async def load_jwks(self, jwks_url: str): + self.jwks_client = PyJWKClient(jwks_url) + + async def verify_push_notification(self, request: Request) -> bool: + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX): + print("Invalid authorization header") + return False + + token = auth_header[len(AUTH_HEADER_PREFIX):] + signing_key = self.jwks_client.get_signing_key_from_jwt(token) + + decode_token = jwt.decode( + token, + signing_key, + options={"require": ["iat", "request_body_sha256"]}, + algorithms=["RS256"], + ) + + actual_body_sha256 = self._calculate_request_body_sha256(await request.json()) + if actual_body_sha256 != decode_token["request_body_sha256"]: + # Payload signature does not match the digest in signed token. + raise ValueError("Invalid request body") + + if time.time() - decode_token["iat"] > 60 * 5: + # Do not allow push-notifications older than 5 minutes. + # This is to prevent replay attack. + raise ValueError("Token is expired") + + return True diff --git a/cookbook/pocketflow_a2a/flow.py b/cookbook/pocketflow_a2a/flow.py new file mode 100644 index 0000000..bc0cb80 --- /dev/null +++ b/cookbook/pocketflow_a2a/flow.py @@ -0,0 +1,33 @@ +from pocketflow import Flow +from nodes import DecideAction, SearchWeb, AnswerQuestion + +def create_agent_flow(): + """ + Create and connect the nodes to form a complete agent flow. + + The flow works like this: + 1. DecideAction node decides whether to search or answer + 2. If search, go to SearchWeb node + 3. If answer, go to AnswerQuestion node + 4. After SearchWeb completes, go back to DecideAction + + Returns: + Flow: A complete research agent flow + """ + # Create instances of each node + decide = DecideAction() + search = SearchWeb() + answer = AnswerQuestion() + + # Connect the nodes + # If DecideAction returns "search", go to SearchWeb + decide - "search" >> search + + # If DecideAction returns "answer", go to AnswerQuestion + decide - "answer" >> answer + + # After SearchWeb completes and returns "decide", go back to DecideAction + search - "decide" >> decide + + # Create and return the flow, starting with the DecideAction node + return Flow(start=decide) \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/main.py b/cookbook/pocketflow_a2a/main.py new file mode 100644 index 0000000..4d2cbc4 --- /dev/null +++ b/cookbook/pocketflow_a2a/main.py @@ -0,0 +1,27 @@ +import sys +from flow import create_agent_flow + +def main(): + """Simple function to process a question.""" + # Default question + default_question = "Who won the Nobel Prize in Physics 2024?" + + # Get question from command line if provided with -- + question = default_question + for arg in sys.argv[1:]: + if arg.startswith("--"): + question = arg[2:] + break + + # Create the agent flow + agent_flow = create_agent_flow() + + # Process the question + shared = {"question": question} + print(f"šŸ¤” Processing question: {question}") + agent_flow.run(shared) + print("\nšŸŽÆ Final Answer:") + print(shared.get("answer", "No answer found")) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/nodes.py b/cookbook/pocketflow_a2a/nodes.py new file mode 100644 index 0000000..d777bde --- /dev/null +++ b/cookbook/pocketflow_a2a/nodes.py @@ -0,0 +1,135 @@ +from pocketflow import Node +from utils import call_llm, search_web +import yaml + +class DecideAction(Node): + def prep(self, shared): + """Prepare the context and question for the decision-making process.""" + # Get the current context (default to "No previous search" if none exists) + context = shared.get("context", "No previous search") + # Get the question from the shared store + question = shared["question"] + # Return both for the exec step + return question, context + + def exec(self, inputs): + """Call the LLM to decide whether to search or answer.""" + question, context = inputs + + print(f"šŸ¤” Agent deciding what to do next...") + + # Create a prompt to help the LLM decide what to do next with proper yaml formatting + prompt = f""" +### CONTEXT +You are a research assistant that can search the web. +Question: {question} +Previous Research: {context} + +### ACTION SPACE +[1] search + Description: Look up more information on the web + Parameters: + - query (str): What to search for + +[2] answer + Description: Answer the question with current knowledge + Parameters: + - answer (str): Final answer to the question + +## NEXT ACTION +Decide the next action based on the context and available actions. +Return your response in this format: + +```yaml +thinking: | + +action: search OR answer +reason: +answer: +search_query: +``` +IMPORTANT: Make sure to: +1. Use proper indentation (4 spaces) for all multi-line fields +2. Use the | character for multi-line text fields +3. Keep single-line fields without the | character +""" + + # Call the LLM to make a decision + response = call_llm(prompt) + + # Parse the response to get the decision + yaml_str = response.split("```yaml")[1].split("```")[0].strip() + decision = yaml.safe_load(yaml_str) + + return decision + + def post(self, shared, prep_res, exec_res): + """Save the decision and determine the next step in the flow.""" + # If LLM decided to search, save the search query + if exec_res["action"] == "search": + shared["search_query"] = exec_res["search_query"] + print(f"šŸ” Agent decided to search for: {exec_res['search_query']}") + else: + shared["context"] = exec_res["answer"] #save the context if LLM gives the answer without searching. + print(f"šŸ’” Agent decided to answer the question") + + # Return the action to determine the next node in the flow + return exec_res["action"] + +class SearchWeb(Node): + def prep(self, shared): + """Get the search query from the shared store.""" + return shared["search_query"] + + def exec(self, search_query): + """Search the web for the given query.""" + # Call the search utility function + print(f"🌐 Searching the web for: {search_query}") + results = search_web(search_query) + return results + + def post(self, shared, prep_res, exec_res): + """Save the search results and go back to the decision node.""" + # Add the search results to the context in the shared store + previous = shared.get("context", "") + shared["context"] = previous + "\n\nSEARCH: " + shared["search_query"] + "\nRESULTS: " + exec_res + + print(f"šŸ“š Found information, analyzing results...") + + # Always go back to the decision node after searching + return "decide" + +class AnswerQuestion(Node): + def prep(self, shared): + """Get the question and context for answering.""" + return shared["question"], shared.get("context", "") + + def exec(self, inputs): + """Call the LLM to generate a final answer.""" + question, context = inputs + + print(f"āœļø Crafting final answer...") + + # Create a prompt for the LLM to answer the question + prompt = f""" +### CONTEXT +Based on the following information, answer the question. +Question: {question} +Research: {context} + +## YOUR ANSWER: +Provide a comprehensive answer using the research results. +""" + # Call the LLM to generate an answer + answer = call_llm(prompt) + return answer + + def post(self, shared, prep_res, exec_res): + """Save the final answer and complete the flow.""" + # Save the answer in the shared store + shared["answer"] = exec_res + + print(f"āœ… Answer generated successfully") + + # We're done - no need to continue the flow + return "done" diff --git a/cookbook/pocketflow_a2a/requirements.txt b/cookbook/pocketflow_a2a/requirements.txt new file mode 100644 index 0000000..5134cff --- /dev/null +++ b/cookbook/pocketflow_a2a/requirements.txt @@ -0,0 +1,21 @@ +# For PocketFlow Agent Logic +pocketflow>=0.0.1 +openai>=1.0.0 +duckduckgo-search>=7.5.2 +pyyaml>=5.1 + +# For A2A Server Infrastructure (from common) +starlette>=0.37.2,<0.38.0 +uvicorn[standard]>=0.29.0,<0.30.0 +sse-starlette>=1.8.2,<2.0.0 +pydantic>=2.0.0,<3.0.0 +httpx>=0.27.0,<0.28.0 +anyio>=3.0.0,<5.0.0 # Dependency of starlette/httpx + +# For running __main__.py +click>=8.0.0,<9.0.0 + +# For A2A Client +httpx>=0.27.0,<0.28.0 +asyncclick>=8.1.8 # Or just 'click' if you prefer asyncio.run +pydantic>=2.0.0,<3.0.0 # For common.types \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/task_manager.py b/cookbook/pocketflow_a2a/task_manager.py new file mode 100644 index 0000000..039f0f8 --- /dev/null +++ b/cookbook/pocketflow_a2a/task_manager.py @@ -0,0 +1,112 @@ +# FILE: pocketflow_a2a_agent/task_manager.py +import logging +from typing import AsyncIterable, Union +import asyncio + +# Import from the common code you copied +from common.server.task_manager import InMemoryTaskManager +from common.types import ( + JSONRPCResponse, SendTaskRequest, SendTaskResponse, + SendTaskStreamingRequest, SendTaskStreamingResponse, Task, TaskSendParams, + TaskState, TaskStatus, TextPart, Artifact, UnsupportedOperationError, + InternalError, InvalidParamsError +) +import common.server.utils as server_utils + +# Import directly from your original PocketFlow files +from .flow import create_agent_flow # Assumes flow.py is in the same directory +from .utils import call_llm, search_web # Make utils functions available if needed elsewhere + +logger = logging.getLogger(__name__) + +class PocketFlowTaskManager(InMemoryTaskManager): + """ TaskManager implementation that runs the PocketFlow agent. """ + + SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] # Define what the agent accepts/outputs + + async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: + """Handles non-streaming task requests.""" + logger.info(f"Received task send request: {request.params.id}") + + # Validate output modes + if not server_utils.are_modalities_compatible( + request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES + ): + logger.warning( + "Unsupported output mode. Received %s, Support %s", + request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES + ) + return SendTaskResponse(id=request.id, error=server_utils.new_incompatible_types_error(request.id).error) + + # Upsert the task in the store (initial state: submitted) + # We create the task first so its state can be tracked, even if the sync execution fails + await self.upsert_task(request.params) + # Update state to working before running + await self.update_store(request.params.id, TaskStatus(state=TaskState.WORKING), []) + + + # --- Run the PocketFlow logic --- + task_params: TaskSendParams = request.params + query = self._get_user_query(task_params) + if query is None: + fail_status = TaskStatus(state=TaskState.FAILED, message=Message(role="agent", parts=[TextPart(text="No text query found")])) + await self.update_store(task_params.id, fail_status, []) + return SendTaskResponse(id=request.id, error=InvalidParamsError(message="No text query found in message parts")) + + shared_data = {"question": query} + agent_flow = create_agent_flow() # Create the flow instance + + try: + # Run the synchronous PocketFlow + # In a real async server, you might run this in a separate thread/process + # executor to avoid blocking the event loop. For simplicity here, we run it directly. + # Consider adding a timeout if flows can hang. + logger.info(f"Running PocketFlow for task {task_params.id}...") + final_state_dict = agent_flow.run(shared_data) + logger.info(f"PocketFlow completed for task {task_params.id}") + answer_text = final_state_dict.get("answer", "Agent did not produce a final answer text.") + + # --- Package result into A2A Task --- + final_task_status = TaskStatus(state=TaskState.COMPLETED) + # Package the answer as an artifact + final_artifact = Artifact(parts=[TextPart(text=answer_text)]) + + # Update the task in the store with final status and artifact + final_task = await self.update_store( + task_params.id, final_task_status, [final_artifact] + ) + + # Prepare and return the A2A response + task_result = self.append_task_history(final_task, task_params.historyLength) + return SendTaskResponse(id=request.id, result=task_result) + + except Exception as e: + logger.error(f"Error executing PocketFlow for task {task_params.id}: {e}", exc_info=True) + # Update task state to FAILED + fail_status = TaskStatus( + state=TaskState.FAILED, + message=Message(role="agent", parts=[TextPart(text=f"Agent execution failed: {e}")]) + ) + await self.update_store(task_params.id, fail_status, []) + return SendTaskResponse(id=request.id, error=InternalError(message=f"Agent error: {e}")) + + async def on_send_task_subscribe( + self, request: SendTaskStreamingRequest + ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: + """Handles streaming requests - Not implemented for this synchronous agent.""" + logger.warning(f"Streaming requested for task {request.params.id}, but not supported by this PocketFlow agent implementation.") + # Return an error indicating streaming is not supported + return JSONRPCResponse(id=request.id, error=UnsupportedOperationError(message="Streaming not supported by this agent")) + + def _get_user_query(self, task_send_params: TaskSendParams) -> str | None: + """Extracts the first text part from the user message.""" + if not task_send_params.message or not task_send_params.message.parts: + logger.warning(f"No message parts found for task {task_send_params.id}") + return None + for part in task_send_params.message.parts: + # Ensure part is treated as a dictionary if it came from JSON + part_dict = part if isinstance(part, dict) else part.model_dump() + if part_dict.get("type") == "text" and "text" in part_dict: + return part_dict["text"] + logger.warning(f"No text part found in message for task {task_send_params.id}") + return None # No text part found \ No newline at end of file diff --git a/cookbook/pocketflow_a2a/utils.py b/cookbook/pocketflow_a2a/utils.py new file mode 100644 index 0000000..c56a175 --- /dev/null +++ b/cookbook/pocketflow_a2a/utils.py @@ -0,0 +1,30 @@ +from openai import OpenAI +import os +from duckduckgo_search import DDGS + +def call_llm(prompt): + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key")) + r = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": prompt}] + ) + return r.choices[0].message.content + +def search_web(query): + results = DDGS().text(query, max_results=5) + # Convert results to a string + results_str = "\n\n".join([f"Title: {r['title']}\nURL: {r['href']}\nSnippet: {r['body']}" for r in results]) + return results_str + +if __name__ == "__main__": + print("## Testing call_llm") + prompt = "In a few words, what is the meaning of life?" + print(f"## Prompt: {prompt}") + response = call_llm(prompt) + print(f"## Response: {response}") + + print("## Testing search_web") + query = "Who won the Nobel Prize in Physics 2024?" + print(f"## Query: {query}") + results = search_web(query) + print(f"## Results: {results}") \ No newline at end of file