finish a2a
This commit is contained in:
parent
91dc7bbc99
commit
5610fbe9d2
|
|
@ -0,0 +1,228 @@
|
|||
# PocketFlow Agent with A2A Protocol
|
||||
|
||||
This project demonstrates how to take an existing agent built with the PocketFlow library and make it accessible to other agents using the **Agent-to-Agent (A2A) communication protocol**.
|
||||
|
||||
## How it Works: A2A Integration
|
||||
|
||||
This project combines two main parts:
|
||||
|
||||
1. **PocketFlow Agent Logic:** The original agent code ([`nodes.py`](nodes.py), [`utils.py`](utils.py), [`flow.py`](flow.py)) defines the internal workflow (Decide -> Search -> Answer). This code is taken directly from the [PocketFlow Agent Tutorial](https://github.com/The-Pocket/PocketFlow/tree/main/cookbook/pocketflow-agent).
|
||||
2. **A2A Server Wrapper:** Code from the [google/A2A samples repository](https://github.com/google/A2A/tree/main/samples/python) (`common/` directory) provides the necessary infrastructure to host the agent as an A2A-compliant server. *Note: Minor modifications were made to the common server/client code to add detailed logging for educational purposes.*
|
||||
3. **The Bridge ([`task_manager.py`](task_manager.py)):** A custom `PocketFlowTaskManager` class acts as the bridge. It receives A2A requests (like `tasks/send`), extracts the user query, runs the PocketFlow `agent_flow`, takes the final result from the flow's shared state, and packages it back into an A2A `Task` object with the answer as an `Artifact`.
|
||||
|
||||
This demonstrates how a non-A2A agent framework can be exposed over the A2A protocol by implementing a specific `TaskManager`.
|
||||
|
||||
## Simplified Interaction Sequence
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client as "Client ([minimal_a2a_client.py](a2a_client.py))"
|
||||
participant Server as "Server (localhost:10003)"
|
||||
|
||||
Note over Client: User enters question
|
||||
Client->>+Server: POST / (JSON-RPC Request: tasks/send)
|
||||
Note over Server: Processes request internally (runs PocketFlow)
|
||||
Server-->>-Client: HTTP 200 OK (JSON-RPC Response: result=Task)
|
||||
Note over Client: Displays final answer
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
* Python 3.10+ (due to type hinting used in the A2A `common` code)
|
||||
* An OpenAI API Key
|
||||
|
||||
### Installation
|
||||
|
||||
|
||||
1. Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Set your OpenAI API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
Let's do a quick check to make sure your API key is working properly:
|
||||
|
||||
```bash
|
||||
python utils.py
|
||||
```
|
||||
3. Run the server from this directory:
|
||||
|
||||
```bash
|
||||
python a2a_server.py --port 10003
|
||||
```
|
||||
|
||||
You should see logs indicating the server has started on `http://localhost:10003`.
|
||||
|
||||
|
||||
4. Run the Client in a *separate terminal*
|
||||
|
||||
```bash
|
||||
python a2a_client.py --agent-url http://localhost:10003
|
||||
```
|
||||
|
||||
5. Follow the instructions in the client terminal to ask questions. Type `:q` or `quit` to exit the client.
|
||||
|
||||
## Example Interaction Logs
|
||||
|
||||
**(Server Log - showing internal PocketFlow steps)**
|
||||
|
||||
```
|
||||
2025-04-12 17:20:40,893 - __main__ - INFO - Starting PocketFlow A2A server on http://localhost:10003
|
||||
INFO: Started server process [677223]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://localhost:10003 (Press CTRL+C to quit)
|
||||
2025-04-12 17:20:57,647 - A2AServer - INFO - <- Received Request (ID: d3f3fb93350d47d9a94ca12bb62b656b):
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "d3f3fb93350d47d9a94ca12bb62b656b",
|
||||
"method": "tasks/send",
|
||||
"params": {
|
||||
"id": "46c3ce7b941a4fff9b8e3b644d6db5f4",
|
||||
"sessionId": "f3e12b8424c44241be881cd4bb8a269f",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Who won the Nobel Prize in Physics 2024?"
|
||||
}
|
||||
]
|
||||
},
|
||||
"acceptedOutputModes": [
|
||||
"text",
|
||||
"text/plain"
|
||||
]
|
||||
}
|
||||
}
|
||||
2025-04-12 17:20:57,647 - task_manager - INFO - Received task send request: 46c3ce7b941a4fff9b8e3b644d6db5f4
|
||||
2025-04-12 17:20:57,647 - common.server.task_manager - INFO - Upserting task 46c3ce7b941a4fff9b8e3b644d6db5f4
|
||||
2025-04-12 17:20:57,647 - task_manager - INFO - Running PocketFlow for task 46c3ce7b941a4fff9b8e3b644d6db5f4...
|
||||
🤔 Agent deciding what to do next...
|
||||
2025-04-12 17:20:59,213 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
|
||||
🔍 Agent decided to search for: 2024 Nobel Prize in Physics winner
|
||||
🌐 Searching the web for: 2024 Nobel Prize in Physics winner
|
||||
2025-04-12 17:20:59,974 - primp - INFO - response: https://lite.duckduckgo.com/lite/ 200
|
||||
📚 Found information, analyzing results...
|
||||
🤔 Agent deciding what to do next...
|
||||
2025-04-12 17:21:01,619 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
|
||||
💡 Agent decided to answer the question
|
||||
✍️ Crafting final answer...
|
||||
2025-04-12 17:21:03,833 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
|
||||
✅ Answer generated successfully
|
||||
2025-04-12 17:21:03,834 - task_manager - INFO - PocketFlow completed for task 46c3ce7b941a4fff9b8e3b644d6db5f4
|
||||
2025-04-12 17:21:03,834 - A2AServer - INFO - -> Response (ID: d3f3fb93350d47d9a94ca12bb62b656b):
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "d3f3fb93350d47d9a94ca12bb62b656b",
|
||||
"result": {
|
||||
"id": "46c3ce7b941a4fff9b8e3b644d6db5f4",
|
||||
"sessionId": "f3e12b8424c44241be881cd4bb8a269f",
|
||||
"status": {
|
||||
"state": "completed",
|
||||
"timestamp": "2025-04-12T17:21:03.834542"
|
||||
},
|
||||
"artifacts": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The 2024 Nobel Prize in Physics was awarded to John J. Hopfield and Geoffrey Hinton for their foundational discoveries and inventions that have significantly advanced the field of machine learning through the use of artificial neural networks. Their pioneering work has been crucial in the development and implementation of algorithms that enable machines to learn and process information in a manner that mimics human cognitive functions. This advancement in artificial intelligence technology has had a profound impact on numerous industries, facilitating innovations across various applications, from image and speech recognition to self-driving cars."
|
||||
}
|
||||
],
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"history": []
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**(Client Log - showing request/response)**
|
||||
|
||||
```
|
||||
Connecting to agent at: http://localhost:10003
|
||||
Using Session ID: f3e12b8424c44241be881cd4bb8a269f
|
||||
|
||||
Enter your question (:q or quit to exit) > Who won the Nobel Prize in Physics 2024?
|
||||
Sending task 46c3ce7b941a4fff9b8e3b644d6db5f4...
|
||||
2025-04-12 17:20:57,643 - A2AClient - INFO - -> Sending Request (ID: d3f3fb93350d47d9a94ca12bb62b656b, Method: tasks/send):
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "d3f3fb93350d47d9a94ca12bb62b656b",
|
||||
"method": "tasks/send",
|
||||
"params": {
|
||||
"id": "46c3ce7b941a4fff9b8e3b644d6db5f4",
|
||||
"sessionId": "f3e12b8424c44241be881cd4bb8a269f",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Who won the Nobel Prize in Physics 2024?"
|
||||
}
|
||||
]
|
||||
},
|
||||
"acceptedOutputModes": [
|
||||
"text",
|
||||
"text/plain"
|
||||
]
|
||||
}
|
||||
}
|
||||
2025-04-12 17:21:03,835 - httpx - INFO - HTTP Request: POST http://localhost:10003 "HTTP/1.1 200 OK"
|
||||
2025-04-12 17:21:03,836 - A2AClient - INFO - <- Received HTTP Status 200 for Request (ID: d3f3fb93350d47d9a94ca12bb62b656b)
|
||||
2025-04-12 17:21:03,836 - A2AClient - INFO - <- Received Success Response (ID: d3f3fb93350d47d9a94ca12bb62b656b):
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "d3f3fb93350d47d9a94ca12bb62b656b",
|
||||
"result": {
|
||||
"id": "46c3ce7b941a4fff9b8e3b644d6db5f4",
|
||||
"sessionId": "f3e12b8424c44241be881cd4bb8a269f",
|
||||
"status": {
|
||||
"state": "completed",
|
||||
"timestamp": "2025-04-12T17:21:03.834542"
|
||||
},
|
||||
"artifacts": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The 2024 Nobel Prize in Physics was awarded to John J. Hopfield and Geoffrey Hinton for their foundational discoveries and inventions that have significantly advanced the field of machine learning through the use of artificial neural networks. Their pioneering work has been crucial in the development and implementation of algorithms that enable machines to learn and process information in a manner that mimics human cognitive functions. This advancement in artificial intelligence technology has had a profound impact on numerous industries, facilitating innovations across various applications, from image and speech recognition to self-driving cars."
|
||||
}
|
||||
],
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"history": []
|
||||
}
|
||||
}
|
||||
Task 46c3ce7b941a4fff9b8e3b644d6db5f4 finished with state: completed
|
||||
|
||||
Agent Response:
|
||||
The 2024 Nobel Prize in Physics was awarded to John J. Hopfield and Geoffrey Hinton for their foundational discoveries and inventions that have significantly advanced the field of machine learning through the use of artificial neural networks. Their pioneering work has been crucial in the development and implementation of algorithms that enable machines to learn and process information in a manner that mimics human cognitive functions. This advancement in artificial intelligence technology has had a profound impact on numerous industries, facilitating innovations across various applications, from image and speech recognition to self-driving cars.
|
||||
```
|
||||
|
||||
## Key A2A Integration Points
|
||||
|
||||
To make the PocketFlow agent A2A-compatible, the following were essential:
|
||||
|
||||
1. **A2A Server ([`common/server/server.py`](common/server/server.py)):** An ASGI application (using Starlette/Uvicorn) that listens for HTTP POST requests, parses JSON-RPC, and routes requests based on the `method` field.
|
||||
2. **A2A Data Types ([`common/types.py`](common/types.py)):** Pydantic models defining the structure of A2A messages, tasks, artifacts, errors, and the agent card, ensuring compliance with the `a2a.json` specification.
|
||||
3. **Task Manager ([`task_manager.py`](task_manager.py)):** A custom class (`PocketFlowTaskManager`) inheriting from the common `InMemoryTaskManager`. Its primary role is implementing the `on_send_task` method (and potentially others like `on_send_task_subscribe` if streaming were supported). This method:
|
||||
* Receives the validated A2A `SendTaskRequest`.
|
||||
* Extracts the user's query (`TextPart`) from the request's `message`.
|
||||
* Initializes the PocketFlow `shared_data` dictionary.
|
||||
* Creates and runs the PocketFlow `agent_flow`.
|
||||
* Retrieves the final answer from the `shared_data` dictionary *after* the flow completes.
|
||||
* Updates the task's state (e.g., to `COMPLETED` or `FAILED`) in the `InMemoryTaskManager`'s store.
|
||||
* Packages the final answer into an A2A `Artifact` containing a `TextPart`.
|
||||
* Constructs the final A2A `Task` object for the response.
|
||||
4. **Agent Card ([`a2a_server.py`](a2a_server.py)):** A Pydantic model (`AgentCard`) defining the agent's metadata (name, description, URL, capabilities, skills) served at `/.well-known/agent.json`.
|
||||
5. **Server Entry Point ([`a2a_server.py`](a2a_server.py)):** A script that initializes the `AgentCard`, the `PocketFlowTaskManager`, and the `A2AServer`, then starts the Uvicorn server process.
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
# FILE: minimal_a2a_client.py
|
||||
import asyncio
|
||||
import asyncclick as click # Using asyncclick for async main
|
||||
from uuid import uuid4
|
||||
import json # For potentially inspecting raw errors
|
||||
import anyio
|
||||
import functools
|
||||
import logging
|
||||
|
||||
# Import from the common directory placed alongside this script
|
||||
from common.client import A2AClient
|
||||
|
|
@ -15,6 +15,17 @@ from common.types import (
|
|||
JSONRPCResponse # Potentially useful for error checking
|
||||
)
|
||||
|
||||
# --- Configure logging ---
|
||||
# Set level to INFO to see client requests and responses
|
||||
# Set level to DEBUG to see raw response bodies and SSE data lines
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
# Optionally silence overly verbose libraries
|
||||
# logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
# logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
# --- ANSI Colors (Optional but helpful) ---
|
||||
C_RED = "\x1b[31m"
|
||||
C_GREEN = "\x1b[32m"
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
# FILE: pocketflow_a2a_agent/__main__.py
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -10,7 +9,18 @@ from common.types import AgentCard, AgentCapabilities, AgentSkill, MissingAPIKey
|
|||
# Import your custom TaskManager (which now imports from your original files)
|
||||
from task_manager import PocketFlowTaskManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# --- Configure logging ---
|
||||
# Set level to INFO to see server start, requests, responses
|
||||
# Set level to DEBUG to see raw response bodies from client
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
# Optionally silence overly verbose libraries
|
||||
# logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
# logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
# logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@click.command()
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
from typing import Any, AsyncIterable
|
||||
from common.types import (
|
||||
AgentCard,
|
||||
GetTaskRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
JSONRPCError,
|
||||
GetTaskResponse,
|
||||
CancelTaskResponse,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
Task,
|
||||
TaskPushNotificationConfig,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
)
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Configure a logger specific to the client
|
||||
logger = logging.getLogger("A2AClient")
|
||||
|
||||
class A2AClientError(Exception):
|
||||
"""Base class for A2A client errors"""
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
|
||||
class RpcError(Exception):
|
||||
code: int
|
||||
data: Any = None
|
||||
def __init__(self, code: int, message: str, data: Any = None):
|
||||
super().__init__(message)
|
||||
self.name = "RpcError"
|
||||
self.code = code
|
||||
self.data = data
|
||||
|
||||
class A2AClient:
|
||||
def __init__(self, agent_card: AgentCard = None, url: str = None):
|
||||
if agent_card:
|
||||
self.url = agent_card.url.rstrip("/")
|
||||
elif url:
|
||||
self.url = url.rstrip("/")
|
||||
else:
|
||||
raise ValueError("Must provide either agent_card or url")
|
||||
self.fetchImpl = httpx.AsyncClient(timeout=None)
|
||||
|
||||
def _generateRequestId(self):
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
|
||||
req_id = request.id
|
||||
req_method = request.method
|
||||
req_dump = request.model_dump(exclude_none=True)
|
||||
|
||||
logger.info(f"-> Sending Request (ID: {req_id}, Method: {req_method}):\n{json.dumps(req_dump, indent=2)}")
|
||||
|
||||
try:
|
||||
response = await self.fetchImpl.post(
|
||||
self.url, json=req_dump, timeout=60.0
|
||||
)
|
||||
logger.info(f"<- Received HTTP Status {response.status_code} for Request (ID: {req_id})")
|
||||
response_text = await response.aread()
|
||||
logger.debug(f"Raw Response Body (ID: {req_id}):\n{response_text.decode('utf-8', errors='replace')}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
json_response = json.loads(response_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode JSON response (ID: {req_id}): {e}")
|
||||
raise A2AClientJSONError(f"Failed to decode JSON: {e}") from e
|
||||
|
||||
if "error" in json_response and json_response["error"] is not None:
|
||||
rpc_error = json_response["error"]
|
||||
logger.warning(f"<- Received JSON-RPC Error (ID: {req_id}): Code={rpc_error.get('code')}, Msg='{rpc_error.get('message')}'")
|
||||
raise RpcError(rpc_error.get("code", -32000), rpc_error.get("message", "Unknown RPC Error"), rpc_error.get("data"))
|
||||
|
||||
logger.info(f"<- Received Success Response (ID: {req_id}):\n{json.dumps(json_response, indent=2)}")
|
||||
return json_response
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP Error for Request (ID: {req_id}): {e.response.status_code} - {e.request.url}")
|
||||
error_body = await e.response.aread()
|
||||
raise A2AClientHTTPError(e.response.status_code, f"{e}. Body: {error_body.decode('utf-8', errors='replace')}") from e
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request Error for (ID: {req_id}): {e}")
|
||||
raise A2AClientError(f"Network or request error: {e}") from e
|
||||
except RpcError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during request (ID: {req_id}): {e}", exc_info=True)
|
||||
raise A2AClientError(f"Unexpected error: {e}") from e
|
||||
|
||||
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
|
||||
request = SendTaskRequest(params=payload)
|
||||
response_dict = await self._send_request(request)
|
||||
return SendTaskResponse(**response_dict)
|
||||
|
||||
async def send_task_streaming(
|
||||
self, payload: dict[str, Any]
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
request = SendTaskStreamingRequest(params=payload)
|
||||
req_id = request.id
|
||||
req_dump = request.model_dump(exclude_none=True)
|
||||
|
||||
logger.info(f"-> Sending Streaming Request (ID: {req_id}, Method: {request.method}):\n{json.dumps(req_dump, indent=2)}")
|
||||
|
||||
try:
|
||||
async with self.fetchImpl.stream("POST", self.url, json=req_dump, timeout=None) as response:
|
||||
logger.info(f"<- Received HTTP Status {response.status_code} for Streaming Request (ID: {req_id})")
|
||||
response.raise_for_status()
|
||||
|
||||
buffer = ""
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
if buffer.startswith("data:"):
|
||||
data_str = buffer[len("data:"):].strip()
|
||||
logger.debug(f"Received SSE Data Line (ID: {req_id}): {data_str}")
|
||||
try:
|
||||
sse_data_dict = json.loads(data_str)
|
||||
yield SendTaskStreamingResponse(**sse_data_dict)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode SSE JSON (ID: {req_id}): {e}. Data: '{data_str}'")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SSE data (ID: {req_id}): {e}. Data: '{data_str}'", exc_info=True)
|
||||
elif buffer:
|
||||
logger.debug(f"Received non-data SSE line (ID: {req_id}): {buffer}")
|
||||
buffer = ""
|
||||
else:
|
||||
buffer += line + "\n"
|
||||
|
||||
if buffer:
|
||||
logger.warning(f"SSE stream ended with partial data in buffer (ID: {req_id}): {buffer}")
|
||||
|
||||
logger.info(f"SSE Stream ended for request ID: {req_id}")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP Error during streaming connection (ID: {req_id}): {e.response.status_code} - {e.request.url}")
|
||||
error_body = await e.response.aread()
|
||||
raise A2AClientHTTPError(e.response.status_code, f"{e}. Body: {error_body.decode('utf-8', errors='replace')}") from e
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request Error during streaming (ID: {req_id}): {e}")
|
||||
raise A2AClientError(f"Network or request error during streaming: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during streaming (ID: {req_id}): {e}", exc_info=True)
|
||||
raise A2AClientError(f"Unexpected streaming error: {e}") from e
|
||||
|
||||
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
|
||||
request = GetTaskRequest(params=payload)
|
||||
response_dict = await self._send_request(request)
|
||||
return GetTaskResponse(**response_dict)
|
||||
|
||||
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
|
||||
request = CancelTaskRequest(params=payload)
|
||||
response_dict = await self._send_request(request)
|
||||
return CancelTaskResponse(**response_dict)
|
||||
|
||||
async def set_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
request = SetTaskPushNotificationRequest(params=payload)
|
||||
response_dict = await self._send_request(request)
|
||||
return SetTaskPushNotificationResponse(**response_dict)
|
||||
|
||||
async def get_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
request = GetTaskPushNotificationRequest(params=payload)
|
||||
response_dict = await self._send_request(request)
|
||||
return GetTaskPushNotificationResponse(**response_dict)
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from common.types import (
|
||||
A2ARequest,
|
||||
JSONRPCResponse,
|
||||
InvalidRequestError,
|
||||
JSONParseError,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SendTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
InternalError,
|
||||
AgentCard,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
Message,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
import json
|
||||
from typing import AsyncIterable, Any
|
||||
from common.server.task_manager import TaskManager
|
||||
|
||||
import logging
|
||||
|
||||
# Configure a logger specific to the server
|
||||
logger = logging.getLogger("A2AServer")
|
||||
|
||||
|
||||
class A2AServer:
|
||||
def __init__(
|
||||
self,
|
||||
host="0.0.0.0",
|
||||
port=5000,
|
||||
endpoint="/",
|
||||
agent_card: AgentCard = None,
|
||||
task_manager: TaskManager = None,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.endpoint = endpoint
|
||||
self.task_manager = task_manager
|
||||
self.agent_card = agent_card
|
||||
self.app = Starlette()
|
||||
self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
|
||||
self.app.add_route(
|
||||
"/.well-known/agent.json", self._get_agent_card, methods=["GET"]
|
||||
)
|
||||
|
||||
def start(self):
|
||||
if self.agent_card is None:
|
||||
raise ValueError("agent_card is not defined")
|
||||
|
||||
if self.task_manager is None:
|
||||
raise ValueError("request_handler is not defined")
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Basic logging config moved to __main__.py for application-level control
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
def _get_agent_card(self, request: Request) -> JSONResponse:
|
||||
logger.info("Serving Agent Card request")
|
||||
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
|
||||
|
||||
async def _process_request(self, request: Request):
|
||||
request_id_for_log = "N/A" # Default if parsing fails early
|
||||
raw_body = b""
|
||||
try:
|
||||
# Log raw body first
|
||||
raw_body = await request.body()
|
||||
body = json.loads(raw_body) # Attempt parsing
|
||||
request_id_for_log = body.get("id", "N/A") # Get ID if possible
|
||||
logger.info(f"<- Received Request (ID: {request_id_for_log}):\n{json.dumps(body, indent=2)}")
|
||||
|
||||
json_rpc_request = A2ARequest.validate_python(body)
|
||||
|
||||
# Route based on method (same as before)
|
||||
if isinstance(json_rpc_request, GetTaskRequest):
|
||||
result = await self.task_manager.on_get_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskRequest):
|
||||
result = await self.task_manager.on_send_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
|
||||
result = await self.task_manager.on_send_task_subscribe(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, CancelTaskRequest):
|
||||
result = await self.task_manager.on_cancel_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
|
||||
result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
|
||||
result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
|
||||
result = await self.task_manager.on_resubscribe_to_task(
|
||||
json_rpc_request
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
|
||||
raise ValueError(f"Unexpected request type: {type(request)}")
|
||||
|
||||
return self._create_response(result) # Pass result to response creation
|
||||
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
logger.error(f"JSON Parse Error for Request body: <<<{raw_body.decode('utf-8', errors='replace')}>>>\nError: {e}")
|
||||
return self._handle_exception(e, request_id_for_log) # Pass ID if known
|
||||
except ValidationError as e:
|
||||
logger.error(f"Request Validation Error (ID: {request_id_for_log}): {e.json()}")
|
||||
return self._handle_exception(e, request_id_for_log)
|
||||
except Exception as e:
|
||||
logger.error(f"Unhandled Exception processing request (ID: {request_id_for_log}): {e}", exc_info=True)
|
||||
return self._handle_exception(e, request_id_for_log) # Pass ID if known
|
||||
|
||||
def _handle_exception(self, e: Exception, req_id=None) -> JSONResponse: # Accept req_id
|
||||
if isinstance(e, json.decoder.JSONDecodeError):
|
||||
json_rpc_error = JSONParseError()
|
||||
elif isinstance(e, ValidationError):
|
||||
json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
|
||||
else:
|
||||
# Log the full exception details
|
||||
logger.error(f"Internal Server Error (ReqID: {req_id}): {e}", exc_info=True)
|
||||
json_rpc_error = InternalError(message=f"Internal Server Error: {type(e).__name__}")
|
||||
|
||||
response = JSONRPCResponse(id=req_id, error=json_rpc_error)
|
||||
response_dump = response.model_dump(exclude_none=True)
|
||||
logger.info(f"-> Sending Error Response (ReqID: {req_id}):\n{json.dumps(response_dump, indent=2)}")
|
||||
# A2A errors are still sent with HTTP 200
|
||||
return JSONResponse(response_dump, status_code=200)
|
||||
|
||||
def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
|
||||
if isinstance(result, AsyncIterable):
|
||||
# Streaming response
|
||||
async def event_generator(result_stream) -> AsyncIterable[dict[str, str]]:
|
||||
stream_request_id = None # Capture ID from the first event if possible
|
||||
try:
|
||||
async for item in result_stream:
|
||||
# Log each streamed item
|
||||
response_json = item.model_dump_json(exclude_none=True)
|
||||
stream_request_id = item.id # Update ID
|
||||
logger.info(f"-> Sending SSE Event (ID: {stream_request_id}):\n{json.dumps(json.loads(response_json), indent=2)}")
|
||||
yield {"data": response_json}
|
||||
logger.info(f"SSE Stream ended for request ID: {stream_request_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during SSE generation (ReqID: {stream_request_id}): {e}", exc_info=True)
|
||||
# Optionally yield an error event if the protocol allows/requires it
|
||||
# error_payload = JSONRPCResponse(id=stream_request_id, error=InternalError(message=f"SSE Error: {e}"))
|
||||
# yield {"data": error_payload.model_dump_json(exclude_none=True)}
|
||||
|
||||
logger.info("Starting SSE stream...") # Log stream start
|
||||
return EventSourceResponse(event_generator(result))
|
||||
elif isinstance(result, JSONRPCResponse):
|
||||
# Standard JSON response
|
||||
response_dump = result.model_dump(exclude_none=True)
|
||||
log_id = result.id if result.id is not None else "N/A (Notification?)"
|
||||
log_prefix = "->"
|
||||
log_type = "Response"
|
||||
if result.error:
|
||||
log_prefix = "-> Sending Error"
|
||||
log_type = "Error Response"
|
||||
|
||||
logger.info(f"{log_prefix} {log_type} (ID: {log_id}):\n{json.dumps(response_dump, indent=2)}")
|
||||
return JSONResponse(response_dump)
|
||||
else:
|
||||
# This should ideally not happen if task manager returns correctly
|
||||
logger.error(f"Task manager returned unexpected type: {type(result)}")
|
||||
err_resp = JSONRPCResponse(id=None, error=InternalError(message="Invalid internal response type"))
|
||||
return JSONResponse(err_resp.model_dump(exclude_none=True), status_code=500)
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
# Research Agent
|
||||
|
||||
This project demonstrates a simple yet powerful LLM-powered research agent. This implementation is based directly on the tutorial: [LLM Agents are simply Graph — Tutorial For Dummies](https://zacharyhuang.substack.com/p/llm-agent-internal-as-a-graph-tutorial).
|
||||
|
||||
👉 Run the tutorial in your browser: [Try Google Colab Notebook](
|
||||
https://colab.research.google.com/github/The-Pocket/PocketFlow/blob/main/cookbook/pocketflow-agent/demo.ipynb)
|
||||
|
||||
## Features
|
||||
|
||||
- Performs web searches to gather information
|
||||
- Makes decisions about when to search vs. when to answer
|
||||
- Generates comprehensive answers based on research findings
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install the packages you need with this simple command:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Let's get your OpenAI API key ready:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
3. Let's do a quick check to make sure your API key is working properly:
|
||||
|
||||
```bash
|
||||
python utils.py
|
||||
```
|
||||
|
||||
This will test both the LLM call and web search features. If you see responses, you're good to go!
|
||||
|
||||
4. Try out the agent with the default question (about Nobel Prize winners):
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
5. Got a burning question? Ask anything you want by using the `--` prefix:
|
||||
|
||||
```bash
|
||||
python main.py --"What is quantum computing?"
|
||||
```
|
||||
|
||||
## How It Works?
|
||||
|
||||
The magic happens through a simple but powerful graph structure with three main parts:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[DecideAction] -->|"search"| B[SearchWeb]
|
||||
A -->|"answer"| C[AnswerQuestion]
|
||||
B -->|"decide"| A
|
||||
```
|
||||
|
||||
Here's what each part does:
|
||||
1. **DecideAction**: The brain that figures out whether to search or answer
|
||||
2. **SearchWeb**: The researcher that goes out and finds information
|
||||
3. **AnswerQuestion**: The writer that crafts the final answer
|
||||
|
||||
Here's what's in each file:
|
||||
- [`main.py`](./main.py): The starting point - runs the whole show!
|
||||
- [`flow.py`](./flow.py): Connects everything together into a smart agent
|
||||
- [`nodes.py`](./nodes.py): The building blocks that make decisions and take actions
|
||||
- [`utils.py`](./utils.py): Helper functions for talking to the LLM and searching the web
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
from typing import Any, AsyncIterable
|
||||
from common.types import (
|
||||
AgentCard,
|
||||
GetTaskRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
JSONRPCRequest,
|
||||
GetTaskResponse,
|
||||
CancelTaskResponse,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
)
|
||||
import json
|
||||
|
||||
|
||||
class A2AClient:
|
||||
def __init__(self, agent_card: AgentCard = None, url: str = None):
|
||||
if agent_card:
|
||||
self.url = agent_card.url
|
||||
elif url:
|
||||
self.url = url
|
||||
else:
|
||||
raise ValueError("Must provide either agent_card or url")
|
||||
|
||||
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
|
||||
request = SendTaskRequest(params=payload)
|
||||
return SendTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def send_task_streaming(
|
||||
self, payload: dict[str, Any]
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
request = SendTaskStreamingRequest(params=payload)
|
||||
with httpx.Client(timeout=None) as client:
|
||||
with connect_sse(
|
||||
client, "POST", self.url, json=request.model_dump()
|
||||
) as event_source:
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
yield SendTaskStreamingResponse(**json.loads(sse.data))
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
except httpx.RequestError as e:
|
||||
raise A2AClientHTTPError(400, str(e)) from e
|
||||
|
||||
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# Image generation could take time, adding timeout
|
||||
response = await client.post(
|
||||
self.url, json=request.model_dump(), timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
|
||||
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
|
||||
request = GetTaskRequest(params=payload)
|
||||
return GetTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
|
||||
request = CancelTaskRequest(params=payload)
|
||||
return CancelTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def set_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
request = SetTaskPushNotificationRequest(params=payload)
|
||||
return SetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
||||
async def get_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
request = GetTaskPushNotificationRequest(params=payload)
|
||||
return GetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from common.types import (
|
||||
A2ARequest,
|
||||
JSONRPCResponse,
|
||||
InvalidRequestError,
|
||||
JSONParseError,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SendTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
InternalError,
|
||||
AgentCard,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
import json
|
||||
from typing import AsyncIterable, Any
|
||||
from common.server.task_manager import TaskManager
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2AServer:
|
||||
def __init__(
|
||||
self,
|
||||
host="0.0.0.0",
|
||||
port=5000,
|
||||
endpoint="/",
|
||||
agent_card: AgentCard = None,
|
||||
task_manager: TaskManager = None,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.endpoint = endpoint
|
||||
self.task_manager = task_manager
|
||||
self.agent_card = agent_card
|
||||
self.app = Starlette()
|
||||
self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
|
||||
self.app.add_route(
|
||||
"/.well-known/agent.json", self._get_agent_card, methods=["GET"]
|
||||
)
|
||||
|
||||
def start(self):
|
||||
if self.agent_card is None:
|
||||
raise ValueError("agent_card is not defined")
|
||||
|
||||
if self.task_manager is None:
|
||||
raise ValueError("request_handler is not defined")
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
def _get_agent_card(self, request: Request) -> JSONResponse:
|
||||
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
|
||||
|
||||
async def _process_request(self, request: Request):
|
||||
try:
|
||||
body = await request.json()
|
||||
json_rpc_request = A2ARequest.validate_python(body)
|
||||
|
||||
if isinstance(json_rpc_request, GetTaskRequest):
|
||||
result = await self.task_manager.on_get_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskRequest):
|
||||
result = await self.task_manager.on_send_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
|
||||
result = await self.task_manager.on_send_task_subscribe(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, CancelTaskRequest):
|
||||
result = await self.task_manager.on_cancel_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
|
||||
result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
|
||||
result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
|
||||
result = await self.task_manager.on_resubscribe_to_task(
|
||||
json_rpc_request
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
|
||||
raise ValueError(f"Unexpected request type: {type(request)}")
|
||||
|
||||
return self._create_response(result)
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_exception(e)
|
||||
|
||||
def _handle_exception(self, e: Exception) -> JSONResponse:
|
||||
if isinstance(e, json.decoder.JSONDecodeError):
|
||||
json_rpc_error = JSONParseError()
|
||||
elif isinstance(e, ValidationError):
|
||||
json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
|
||||
else:
|
||||
logger.error(f"Unhandled exception: {e}")
|
||||
json_rpc_error = InternalError()
|
||||
|
||||
response = JSONRPCResponse(id=None, error=json_rpc_error)
|
||||
return JSONResponse(response.model_dump(exclude_none=True), status_code=400)
|
||||
|
||||
def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
|
||||
if isinstance(result, AsyncIterable):
|
||||
|
||||
async def event_generator(result) -> AsyncIterable[dict[str, str]]:
|
||||
async for item in result:
|
||||
yield {"data": item.model_dump_json(exclude_none=True)}
|
||||
|
||||
return EventSourceResponse(event_generator(result))
|
||||
elif isinstance(result, JSONRPCResponse):
|
||||
return JSONResponse(result.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.error(f"Unexpected result type: {type(result)}")
|
||||
raise ValueError(f"Unexpected result type: {type(result)}")
|
||||
Loading…
Reference in New Issue