a2a init
This commit is contained in:
parent
447a2d47d2
commit
cd02d65efe
|
|
@ -0,0 +1,67 @@
|
|||
# Research Agent
|
||||
|
||||
This project demonstrates a simple yet powerful LLM-powered research agent. This implementation is based directly on the tutorial: [LLM Agents are simply Graph — Tutorial For Dummies](https://zacharyhuang.substack.com/p/llm-agent-internal-as-a-graph-tutorial).
|
||||
|
||||
👉 Run the tutorial in your browser: [Try Google Colab Notebook](https://colab.research.google.com/github/The-Pocket/PocketFlow/blob/main/cookbook/pocketflow-agent/demo.ipynb)
|
||||
|
||||
## Features
|
||||
|
||||
- Performs web searches to gather information
|
||||
- Makes decisions about when to search vs. when to answer
|
||||
- Generates comprehensive answers based on research findings
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install the packages you need with this simple command:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Let's get your OpenAI API key ready:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
3. Let's do a quick check to make sure your API key is working properly:
|
||||
|
||||
```bash
|
||||
python utils.py
|
||||
```
|
||||
|
||||
This will test both the LLM call and web search features. If you see responses, you're good to go!
|
||||
|
||||
4. Try out the agent with the default question (about Nobel Prize winners):
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
5. Got a burning question? Ask anything you want by using the `--` prefix:
|
||||
|
||||
```bash
|
||||
python main.py --"What is quantum computing?"
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
The magic happens through a simple but powerful graph structure with three main parts:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[DecideAction] -->|"search"| B[SearchWeb]
|
||||
A -->|"answer"| C[AnswerQuestion]
|
||||
B -->|"decide"| A
|
||||
```
|
||||
|
||||
Here's what each part does:
|
||||
1. **DecideAction**: The brain that figures out whether to search or answer
|
||||
2. **SearchWeb**: The researcher that goes out and finds information
|
||||
3. **AnswerQuestion**: The writer that crafts the final answer
|
||||
|
||||
Here's what's in each file:
|
||||
- [`main.py`](./main.py): The starting point - runs the whole show!
|
||||
- [`flow.py`](./flow.py): Connects everything together into a smart agent
|
||||
- [`nodes.py`](./nodes.py): The building blocks that make decisions and take actions
|
||||
- [`utils.py`](./utils.py): Helper functions for talking to the LLM and searching the web
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
# FILE: minimal_a2a_client.py
|
||||
import asyncio
|
||||
import asyncclick as click # Using asyncclick for async main
|
||||
from uuid import uuid4
|
||||
import json # For potentially inspecting raw errors
|
||||
|
||||
# Import from the common directory placed alongside this script
|
||||
from common.client import A2AClient
|
||||
from common.types import (
|
||||
TaskState,
|
||||
A2AClientError,
|
||||
TextPart, # Used to construct the message
|
||||
JSONRPCResponse # Potentially useful for error checking
|
||||
)
|
||||
|
||||
# --- ANSI Colors (Optional but helpful) ---
|
||||
C_RED = "\x1b[31m"
|
||||
C_GREEN = "\x1b[32m"
|
||||
C_YELLOW = "\x1b[33m"
|
||||
C_BLUE = "\x1b[34m"
|
||||
C_MAGENTA = "\x1b[35m"
|
||||
C_CYAN = "\x1b[36m"
|
||||
C_WHITE = "\x1b[37m"
|
||||
C_GRAY = "\x1b[90m"
|
||||
C_BRIGHT_MAGENTA = "\x1b[95m"
|
||||
C_RESET = "\x1b[0m"
|
||||
C_BOLD = "\x1b[1m"
|
||||
|
||||
def colorize(color, text):
    """Return *text* wrapped in the ANSI escape *color*, followed by a reset."""
    return color + text + C_RESET
|
||||
|
||||
@click.command()
@click.option(
    "--agent-url",
    default="http://localhost:10003",  # Default to the port used in server __main__
    help="URL of the PocketFlow A2A agent server.",
)
async def cli(agent_url: str):
    """Minimal CLI client to interact with an A2A agent.

    Loops reading questions from stdin, sends each one as a fresh A2A task
    (non-streaming ``tasks/send``), and prints the text of the first artifact
    — or the final status message — returned by the agent.  All turns share
    one session ID; each turn gets a new task ID.
    """

    print(colorize(C_BRIGHT_MAGENTA, f"Connecting to agent at: {agent_url}"))

    # Instantiate the client - only URL is needed if not fetching card first
    # Note: The PocketFlow wrapper doesn't expose much via the AgentCard,
    # so we skip fetching it for this minimal client.
    client = A2AClient(url=agent_url)

    sessionId = uuid4().hex  # Generate a new session ID for this run
    print(colorize(C_GRAY, f"Using Session ID: {sessionId}"))

    while True:
        taskId = uuid4().hex  # Generate a new task ID for each interaction
        try:
            prompt = await click.prompt(
                colorize(C_CYAN, "\nEnter your question (:q or quit to exit)"),
                prompt_suffix=" > ",
                type=str  # Ensure prompt returns string
            )
        except RuntimeError:
            # This can happen if stdin is closed, e.g., in some test runners
            print(colorize(C_RED, "Failed to read input. Exiting."))
            break

        if prompt.lower() in [":q", "quit"]:
            print(colorize(C_YELLOW, "Exiting client."))
            break

        # --- Construct A2A Request Payload ---
        # Shape mirrors TaskSendParams: one user message with a single TextPart.
        payload = {
            "id": taskId,
            "sessionId": sessionId,
            "message": {
                "role": "user",
                "parts": [
                    {
                        "type": "text",  # Explicitly match TextPart structure
                        "text": prompt,
                    }
                ],
            },
            "acceptedOutputModes": ["text", "text/plain"],  # What the client wants back
            # historyLength could be added if needed
        }

        print(colorize(C_GRAY, f"Sending task {taskId}..."))

        try:
            # --- Send Task (Non-Streaming) ---
            response = await client.send_task(payload)

            # --- Process Response ---
            # JSON-RPC: exactly one of error / result is expected to be set.
            if response.error:
                print(colorize(C_RED, f"Error from agent (Code: {response.error.code}): {response.error.message}"))
                if response.error.data:
                    print(colorize(C_GRAY, f"Error Data: {response.error.data}"))
            elif response.result:
                task_result = response.result
                print(colorize(C_GREEN, f"Task {task_result.id} finished with state: {task_result.status.state}"))

                final_answer = "Agent did not provide a final artifact."
                # Extract answer from artifacts (as implemented in PocketFlowTaskManager)
                if task_result.artifacts:
                    try:
                        # Find the first text part in the first artifact
                        first_artifact = task_result.artifacts[0]
                        first_text_part = next(
                            (p for p in first_artifact.parts if isinstance(p, TextPart)),
                            None
                        )
                        if first_text_part:
                            final_answer = first_text_part.text
                        else:
                            final_answer = f"(Non-text artifact received: {first_artifact.parts})"
                    except (IndexError, AttributeError, TypeError) as e:
                        final_answer = f"(Error parsing artifact: {e})"
                elif task_result.status.message and task_result.status.message.parts:
                    # Fallback to status message if no artifact
                    try:
                        first_text_part = next(
                            (p for p in task_result.status.message.parts if isinstance(p, TextPart)),
                            None
                        )
                        if first_text_part:
                            final_answer = f"(Final status message: {first_text_part.text})"
                    except (AttributeError, TypeError) as e:
                        final_answer = f"(Error parsing status message: {e})"

                print(colorize(C_BOLD + C_WHITE, f"\nAgent Response:\n{final_answer}"))

            else:
                # Should not happen if error is None
                print(colorize(C_YELLOW, "Received response with no result and no error."))

        except A2AClientError as e:
            print(colorize(C_RED, f"\nClient Error: {e}"))
        except Exception as e:
            print(colorize(C_RED, f"\nAn unexpected error occurred: {e}"))

if __name__ == "__main__":
    # asyncclick commands return a coroutine from __call__, so asyncio.run drives it.
    asyncio.run(cli())
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
# FILE: pocketflow_a2a_agent/__main__.py
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Import from the common code you copied
|
||||
from common.server import A2AServer
|
||||
from common.types import AgentCard, AgentCapabilities, AgentSkill, MissingAPIKeyError
|
||||
|
||||
# Import your custom TaskManager (which now imports from your original files)
|
||||
from .task_manager import PocketFlowTaskManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@click.command()
@click.option("--host", "host", default="localhost")
@click.option("--port", "port", default=10003)  # Use a different port from other agents
def main(host, port):
    """Starts the PocketFlow A2A Agent server.

    Builds the AgentCard advertised at /.well-known/agent.json, wires the
    PocketFlowTaskManager into an A2AServer, and blocks serving requests.
    Exits with status 1 on missing configuration or startup failure.
    """
    try:
        # Check for necessary API keys (add others if needed)
        if not os.getenv("OPENAI_API_KEY"):
            raise MissingAPIKeyError("OPENAI_API_KEY environment variable not set.")

        # --- Define the Agent Card ---
        capabilities = AgentCapabilities(
            streaming=False,  # This simple implementation is synchronous
            pushNotifications=False,
            stateTransitionHistory=False  # PocketFlow state isn't exposed via A2A history
        )
        skill = AgentSkill(
            id="web_research_qa",
            name="Web Research and Answering",
            description="Answers questions using web search results when necessary.",
            tags=["research", "qa", "web search"],
            examples=[
                "Who won the Nobel Prize in Physics 2024?",
                "What is quantum computing?",
                "Summarize the latest news about AI.",
            ],
            # Input/Output modes defined in the TaskManager
            inputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
            outputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
        )
        agent_card = AgentCard(
            name="PocketFlow Research Agent (A2A Wrapped)",
            description="A simple research agent based on PocketFlow, made accessible via A2A.",
            url=f"http://{host}:{port}/",  # The endpoint A2A clients will use
            version="0.1.0-a2a",
            capabilities=capabilities,
            skills=[skill],
            # Assuming no specific provider or auth for this example
            provider=None,
            authentication=None,
            defaultInputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
            defaultOutputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
        )

        # --- Initialize and Start Server ---
        task_manager = PocketFlowTaskManager()  # Instantiate your custom manager
        server = A2AServer(
            agent_card=agent_card,
            task_manager=task_manager,
            host=host,
            port=port,
        )

        logger.info(f"Starting PocketFlow A2A server on http://{host}:{port}")
        server.start()  # Blocks until the server is stopped

    except MissingAPIKeyError as e:
        logger.error(f"Configuration Error: {e}")
        # raise SystemExit instead of the site-module exit() helper, which is
        # not guaranteed to exist in all execution modes (e.g. python -S).
        raise SystemExit(1)
    except Exception as e:
        logger.error(f"An error occurred during server startup: {e}", exc_info=True)
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from .client import A2AClient
|
||||
from .card_resolver import A2ACardResolver
|
||||
|
||||
__all__ = ["A2AClient", "A2ACardResolver"]
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import httpx
|
||||
from common.types import (
|
||||
AgentCard,
|
||||
A2AClientJSONError,
|
||||
)
|
||||
import json
|
||||
|
||||
|
||||
class A2ACardResolver:
    """Resolves an agent's AgentCard from its well-known discovery endpoint."""

    def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
        # Strip the adjoining slashes so the later join cannot produce "//".
        self.base_url = base_url.rstrip("/")
        self.agent_card_path = agent_card_path.lstrip("/")

    def get_agent_card(self) -> AgentCard:
        """Fetch and parse the agent card.

        Raises httpx.HTTPStatusError on a non-2xx response and
        A2AClientJSONError when the body is not valid JSON.
        """
        endpoint = f"{self.base_url}/{self.agent_card_path}"
        with httpx.Client() as client:
            response = client.get(endpoint)
            response.raise_for_status()
            try:
                return AgentCard(**response.json())
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
import json
from typing import Any, AsyncIterable

import httpx
from httpx_sse import aconnect_sse, connect_sse

from common.types import (
    AgentCard,
    GetTaskRequest,
    SendTaskRequest,
    SendTaskResponse,
    JSONRPCRequest,
    GetTaskResponse,
    CancelTaskResponse,
    CancelTaskRequest,
    SetTaskPushNotificationRequest,
    SetTaskPushNotificationResponse,
    GetTaskPushNotificationRequest,
    GetTaskPushNotificationResponse,
    A2AClientHTTPError,
    A2AClientJSONError,
    SendTaskStreamingRequest,
    SendTaskStreamingResponse,
)
|
||||
|
||||
|
||||
class A2AClient:
    """Minimal JSON-RPC 2.0 client for the A2A protocol.

    The endpoint is taken from an explicit ``url`` or from the ``url``
    advertised on an ``AgentCard``; exactly one must be provided.
    """

    def __init__(self, agent_card: AgentCard = None, url: str = None):
        if agent_card:
            self.url = agent_card.url
        elif url:
            self.url = url
        else:
            raise ValueError("Must provide either agent_card or url")

    async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
        """Submit a task (tasks/send) and return the final response."""
        request = SendTaskRequest(params=payload)
        return SendTaskResponse(**await self._send_request(request))

    async def send_task_streaming(
        self, payload: dict[str, Any]
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Submit a task and yield streaming (SSE) updates as they arrive.

        Fixed: the original used a synchronous ``httpx.Client`` with
        ``connect_sse`` inside this async generator, which blocked the event
        loop for the whole duration of the stream.  An ``AsyncClient`` with
        ``aconnect_sse``/``aiter_sse`` lets other tasks run between events.
        """
        request = SendTaskStreamingRequest(params=payload)
        async with httpx.AsyncClient(timeout=None) as client:
            async with aconnect_sse(
                client, "POST", self.url, json=request.model_dump()
            ) as event_source:
                try:
                    async for sse in event_source.aiter_sse():
                        yield SendTaskStreamingResponse(**json.loads(sse.data))
                except json.JSONDecodeError as e:
                    raise A2AClientJSONError(str(e)) from e
                except httpx.RequestError as e:
                    raise A2AClientHTTPError(400, str(e)) from e

    async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
        """POST a JSON-RPC request and return the decoded response body.

        Raises A2AClientHTTPError on non-2xx status and A2AClientJSONError
        on an unparseable body.
        """
        async with httpx.AsyncClient() as client:
            try:
                # Agent work (e.g. web research) can take a while: allow 30s.
                response = await client.post(
                    self.url, json=request.model_dump(), timeout=30
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                raise A2AClientHTTPError(e.response.status_code, str(e)) from e
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e

    async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
        """Fetch the current state of a task (tasks/get)."""
        request = GetTaskRequest(params=payload)
        return GetTaskResponse(**await self._send_request(request))

    async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
        """Request cancellation of a task (tasks/cancel)."""
        request = CancelTaskRequest(params=payload)
        return CancelTaskResponse(**await self._send_request(request))

    async def set_task_callback(
        self, payload: dict[str, Any]
    ) -> SetTaskPushNotificationResponse:
        """Register a push-notification callback for a task."""
        request = SetTaskPushNotificationRequest(params=payload)
        return SetTaskPushNotificationResponse(**await self._send_request(request))

    async def get_task_callback(
        self, payload: dict[str, Any]
    ) -> GetTaskPushNotificationResponse:
        """Read the push-notification callback registered for a task."""
        request = GetTaskPushNotificationRequest(params=payload)
        return GetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from .server import A2AServer
|
||||
from .task_manager import TaskManager, InMemoryTaskManager
|
||||
|
||||
__all__ = ["A2AServer", "TaskManager", "InMemoryTaskManager"]
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from common.types import (
|
||||
A2ARequest,
|
||||
JSONRPCResponse,
|
||||
InvalidRequestError,
|
||||
JSONParseError,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SendTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
InternalError,
|
||||
AgentCard,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
import json
|
||||
from typing import AsyncIterable, Any
|
||||
from common.server.task_manager import TaskManager
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2AServer:
    """Starlette-based HTTP server exposing a TaskManager over the A2A protocol.

    Routes: POST ``endpoint`` for JSON-RPC requests, GET
    ``/.well-known/agent.json`` for agent-card discovery.
    """

    def __init__(
        self,
        host="0.0.0.0",
        port=5000,
        endpoint="/",
        agent_card: AgentCard = None,
        task_manager: TaskManager = None,
    ):
        self.host = host
        self.port = port
        self.endpoint = endpoint
        self.task_manager = task_manager
        self.agent_card = agent_card
        self.app = Starlette()
        self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
        self.app.add_route(
            "/.well-known/agent.json", self._get_agent_card, methods=["GET"]
        )

    def start(self):
        """Validate configuration and run uvicorn (blocks until shutdown)."""
        if self.agent_card is None:
            raise ValueError("agent_card is not defined")

        if self.task_manager is None:
            # Fixed: the original message said "request_handler", which is not
            # the name of the missing attribute.
            raise ValueError("task_manager is not defined")

        # Imported lazily so the server dependency is only needed at runtime.
        import uvicorn

        uvicorn.run(self.app, host=self.host, port=self.port)

    def _get_agent_card(self, request: Request) -> JSONResponse:
        """Serve the agent card for A2A discovery."""
        return JSONResponse(self.agent_card.model_dump(exclude_none=True))

    async def _process_request(self, request: Request):
        """Parse a JSON-RPC body and dispatch it to the task manager."""
        try:
            body = await request.json()
            json_rpc_request = A2ARequest.validate_python(body)

            if isinstance(json_rpc_request, GetTaskRequest):
                result = await self.task_manager.on_get_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskRequest):
                result = await self.task_manager.on_send_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskStreamingRequest):
                result = await self.task_manager.on_send_task_subscribe(
                    json_rpc_request
                )
            elif isinstance(json_rpc_request, CancelTaskRequest):
                result = await self.task_manager.on_cancel_task(json_rpc_request)
            elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
                result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
                result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, TaskResubscriptionRequest):
                result = await self.task_manager.on_resubscribe_to_task(
                    json_rpc_request
                )
            else:
                logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
                # Fixed: the original raised with type(request) — the Starlette
                # Request — instead of the parsed JSON-RPC request type.
                raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")

            return self._create_response(result)

        except Exception as e:
            return self._handle_exception(e)

    def _handle_exception(self, e: Exception) -> JSONResponse:
        """Map an exception to a JSON-RPC error response (HTTP 400)."""
        if isinstance(e, json.decoder.JSONDecodeError):
            json_rpc_error = JSONParseError()
        elif isinstance(e, ValidationError):
            json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
        else:
            logger.error(f"Unhandled exception: {e}")
            json_rpc_error = InternalError()

        response = JSONRPCResponse(id=None, error=json_rpc_error)
        return JSONResponse(response.model_dump(exclude_none=True), status_code=400)

    def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
        """Wrap a task-manager result: SSE for async iterables, JSON otherwise."""
        if isinstance(result, AsyncIterable):

            async def event_generator(result) -> AsyncIterable[dict[str, str]]:
                async for item in result:
                    yield {"data": item.model_dump_json(exclude_none=True)}

            return EventSourceResponse(event_generator(result))
        elif isinstance(result, JSONRPCResponse):
            return JSONResponse(result.model_dump(exclude_none=True))
        else:
            logger.error(f"Unexpected result type: {type(result)}")
            raise ValueError(f"Unexpected result type: {type(result)}")
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Union, AsyncIterable, List
|
||||
from common.types import Task
|
||||
from common.types import (
|
||||
JSONRPCResponse,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
GetTaskRequest,
|
||||
TaskNotFoundError,
|
||||
SendTaskRequest,
|
||||
CancelTaskRequest,
|
||||
TaskNotCancelableError,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskResponse,
|
||||
CancelTaskResponse,
|
||||
SendTaskResponse,
|
||||
SetTaskPushNotificationResponse,
|
||||
GetTaskPushNotificationResponse,
|
||||
PushNotificationNotSupportedError,
|
||||
TaskSendParams,
|
||||
TaskStatus,
|
||||
TaskState,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
Artifact,
|
||||
PushNotificationConfig,
|
||||
TaskStatusUpdateEvent,
|
||||
JSONRPCError,
|
||||
TaskPushNotificationConfig,
|
||||
InternalError,
|
||||
)
|
||||
from common.server.utils import new_not_implemented_error
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TaskManager(ABC):
    """Abstract interface for handling A2A JSON-RPC task operations.

    Each method receives the parsed request object and returns the matching
    response object (or an async iterable of streaming responses).
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the current state of a task (tasks/get)."""
        pass

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Attempt to cancel a task (tasks/cancel)."""
        pass

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Execute a task and return its final result (tasks/send)."""
        pass

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Execute a task, streaming updates; may return an error response instead."""
        pass

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Register a push-notification callback for a task."""
        pass

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Read the push-notification callback registered for a task."""
        pass

    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskResponse], JSONRPCResponse]:
        """Re-attach a streaming subscription to an existing task."""
        pass
|
||||
|
||||
|
||||
class InMemoryTaskManager(TaskManager):
    """TaskManager base that stores tasks, push-notification configs and SSE
    subscriber queues in process memory, guarded by asyncio locks.

    Subclasses must implement on_send_task / on_send_task_subscribe.
    """

    def __init__(self):
        self.tasks: dict[str, Task] = {}
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return a copy of the task with history trimmed to historyLength."""
        logger.info(f"Getting task {request.params.id}")
        task_query_params: TaskQueryParams = request.params

        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())

            task_result = self.append_task_history(
                task, task_query_params.historyLength
            )

        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Cancellation is unsupported here: TaskNotFoundError if unknown,
        otherwise TaskNotCancelableError."""
        logger.info(f"Cancelling task {request.params.id}")
        task_id_params: TaskIdParams = request.params

        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        pass

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        pass

    async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig):
        """Store the push-notification config for an existing task.

        Raises ValueError when the task is unknown.
        """
        async with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                raise ValueError(f"Task not found for {task_id}")

            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the stored push-notification config for an existing task.

        Raises ValueError when the task is unknown, KeyError when the task
        exists but has no config.  (The original had an unreachable bare
        ``return`` after this one — removed.)
        """
        async with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                raise ValueError(f"Task not found for {task_id}")

            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """True when a push-notification config is stored for task_id."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store a callback config; maps storage errors to InternalError."""
        logger.info(f"Setting task push notification {request.params.id}")
        task_notification_params: TaskPushNotificationConfig = request.params

        try:
            await self.set_push_notification_info(task_notification_params.id, task_notification_params.pushNotificationConfig)
        except Exception as e:
            logger.error(f"Error while setting push notification info: {e}")
            # NOTE(review): returns the base JSONRPCResponse rather than the
            # annotated SetTaskPushNotificationResponse; serialized shape is
            # the same — confirm whether the subclass type is expected.
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )

        return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Look up a callback config; maps lookup errors to InternalError."""
        logger.info(f"Getting task push notification {request.params.id}")
        task_params: TaskIdParams = request.params

        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error(f"Error while getting push notification info: {e}")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )

        return GetTaskPushNotificationResponse(id=request.id, result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info))

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create a SUBMITTED task, or append the message to an existing one."""
        logger.info(f"Upserting task {task_send_params.id}")
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                task = Task(
                    id=task_send_params.id,
                    sessionId=task_send_params.sessionId,
                    messages=[task_send_params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_send_params.message],
                )
                self.tasks[task_send_params.id] = task
            else:
                task.history.append(task_send_params.message)

            return task

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Resubscription is not implemented in this base class."""
        return new_not_implemented_error(request.id)

    async def update_store(
        self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
    ) -> Task:
        """Replace the task's status (recording any status message in history)
        and append artifacts.  Raises ValueError for unknown task ids."""
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError:
                logger.error(f"Task {task_id} not found for updating the task")
                raise ValueError(f"Task {task_id} not found")

            task.status = status

            if status.message is not None:
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def append_task_history(self, task: Task, historyLength: int | None):
        """Return a copy of the task keeping only the last historyLength
        messages (empty history when historyLength is None or <= 0)."""
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0:
            new_task.history = new_task.history[-historyLength:]
        else:
            new_task.history = []

        return new_task

    async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
        """Register and return a new SSE event queue for task_id.

        Raises ValueError on resubscription to a task with no subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                else:
                    self.task_sse_subscribers[task_id] = []

            sse_event_queue = asyncio.Queue(maxsize=0)  # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id, task_update_event):
        """Broadcast an update event to every queue subscribed to task_id."""
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return

            current_subscribers = self.task_sse_subscribers[task_id]
            for subscriber in current_subscribers:
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses from the queue until an error event or a
        final status update arrives; always unregisters the queue on exit."""
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)
||||
|
|
@ -0,0 +1,28 @@
|
|||
from common.types import (
|
||||
JSONRPCResponse,
|
||||
ContentTypeNotSupportedError,
|
||||
UnsupportedOperationError,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
|
||||
def are_modalities_compatible(
    server_output_modes: List[str], client_output_modes: List[str]
):
    """Return True when the two mode lists share at least one element, or
    when either list is empty/None (an unspecified side accepts anything)."""
    if not client_output_modes:
        return True

    if not server_output_modes:
        return True

    shared = set(client_output_modes) & set(server_output_modes)
    return len(shared) > 0
|
||||
|
||||
|
||||
def new_incompatible_types_error(request_id):
    """Build the JSON-RPC error response for incompatible content types."""
    error = ContentTypeNotSupportedError()
    return JSONRPCResponse(id=request_id, error=error)
|
||||
|
||||
|
||||
def new_not_implemented_error(request_id):
    """Build the JSON-RPC error response for an unimplemented operation."""
    error = UnsupportedOperationError()
    return JSONRPCResponse(id=request_id, error=error)
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
from typing import Union, Any
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from typing import Literal, List, Annotated, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import model_validator, ConfigDict, field_serializer
|
||||
from uuid import uuid4
|
||||
from enum import Enum
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class TaskState(str, Enum):
    """Lifecycle states of an A2A task (string-valued so they serialize
    directly into JSON payloads)."""

    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"
    UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class TextPart(BaseModel):
    """Plain-text message part; `type` is the discriminator used by `Part`."""

    type: Literal["text"] = "text"
    text: str
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FileContent(BaseModel):
    """File payload carried either inline (`bytes`) or by reference (`uri`).

    Exactly one of `bytes` / `uri` must be set — enforced by the validator.
    """

    name: str | None = None
    mimeType: str | None = None
    # NOTE(review): `bytes` here is a str field (presumably base64-encoded
    # content) shadowing the builtin name — confirm encoding with callers.
    bytes: str | None = None
    uri: str | None = None

    @model_validator(mode="after")
    def check_content(self) -> Self:
        # Reject both the empty case and the ambiguous both-set case.
        if not (self.bytes or self.uri):
            raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
        if self.bytes and self.uri:
            raise ValueError(
                "Only one of 'bytes' or 'uri' can be present in the file data"
            )
        return self
|
||||
|
||||
|
||||
class FilePart(BaseModel):
|
||||
type: Literal["file"] = "file"
|
||||
file: FileContent
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DataPart(BaseModel):
|
||||
type: Literal["data"] = "data"
|
||||
data: dict[str, Any]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
|
||||
|
||||
|
||||
class Message(BaseModel):
    """A single conversational turn from the user or the agent."""

    role: Literal["user", "agent"]
    parts: List[Part]
    metadata: dict[str, Any] | None = None


class TaskStatus(BaseModel):
    """Current state of a task plus an optional status message."""

    state: TaskState
    message: Message | None = None
    timestamp: datetime = Field(default_factory=datetime.now)

    @field_serializer("timestamp")
    def serialize_dt(self, dt: datetime, _info):
        # Emit ISO-8601 strings instead of datetime objects when serializing.
        return dt.isoformat()


class Artifact(BaseModel):
    """A chunk of task output; may arrive incrementally when streaming."""

    name: str | None = None
    description: str | None = None
    parts: List[Part]
    metadata: dict[str, Any] | None = None
    index: int = 0                 # position among the task's artifacts
    append: bool | None = None     # presumably: append to the artifact at `index` — confirm
    lastChunk: bool | None = None  # presumably: marks the final chunk of a stream — confirm


class Task(BaseModel):
    """A unit of work tracked by the server."""

    id: str
    sessionId: str | None = None
    status: TaskStatus
    artifacts: List[Artifact] | None = None
    history: List[Message] | None = None
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
    """Streaming event carrying a task's new status."""

    id: str
    status: TaskStatus
    final: bool = False  # presumably True on the last event of the stream — confirm
    metadata: dict[str, Any] | None = None


class TaskArtifactUpdateEvent(BaseModel):
    """Streaming event carrying a new or updated artifact."""

    id: str
    artifact: Artifact
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AuthenticationInfo(BaseModel):
    """Authentication details for push notifications.

    ``extra="allow"`` lets scheme-specific fields pass through unmodeled.
    """

    model_config = ConfigDict(extra="allow")

    schemes: List[str]
    credentials: str | None = None


class PushNotificationConfig(BaseModel):
    """Where (and how) the server should push task updates."""

    url: str
    token: str | None = None
    authentication: AuthenticationInfo | None = None


class TaskIdParams(BaseModel):
    """Parameters identifying a task by id."""

    id: str
    metadata: dict[str, Any] | None = None


class TaskQueryParams(TaskIdParams):
    """Task lookup parameters with optional history truncation."""

    historyLength: int | None = None  # presumably max history messages to return — confirm


class TaskSendParams(BaseModel):
    """Parameters for tasks/send and tasks/sendSubscribe."""

    id: str
    sessionId: str = Field(default_factory=lambda: uuid4().hex)  # fresh session when omitted
    message: Message
    acceptedOutputModes: Optional[List[str]] = None
    pushNotification: PushNotificationConfig | None = None
    historyLength: int | None = None
    metadata: dict[str, Any] | None = None


class TaskPushNotificationConfig(BaseModel):
    """Associates a push-notification config with a task id."""

    id: str
    pushNotificationConfig: PushNotificationConfig
|
||||
|
||||
|
||||
## RPC Messages
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
    """Base for all JSON-RPC 2.0 messages; supplies version and message id."""

    jsonrpc: Literal["2.0"] = "2.0"
    id: int | str | None = Field(default_factory=lambda: uuid4().hex)


class JSONRPCRequest(JSONRPCMessage):
    """Generic JSON-RPC request: method name plus optional params."""

    method: str
    params: dict[str, Any] | None = None


class JSONRPCError(BaseModel):
    """JSON-RPC error object (code / message / optional data)."""

    code: int
    message: str
    data: Any | None = None


class JSONRPCResponse(JSONRPCMessage):
    """Generic JSON-RPC response; carries either a result or an error."""

    result: Any | None = None
    error: JSONRPCError | None = None
|
||||
|
||||
|
||||
class SendTaskRequest(JSONRPCRequest):
    """tasks/send — run a task and return the result."""

    method: Literal["tasks/send"] = "tasks/send"
    params: TaskSendParams


class SendTaskResponse(JSONRPCResponse):
    """Response to tasks/send, carrying the resulting Task."""

    result: Task | None = None


class SendTaskStreamingRequest(JSONRPCRequest):
    """tasks/sendSubscribe — run a task and stream status/artifact events."""

    method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
    params: TaskSendParams


class SendTaskStreamingResponse(JSONRPCResponse):
    """One streamed event for a tasks/sendSubscribe request."""

    result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None


class GetTaskRequest(JSONRPCRequest):
    """tasks/get — fetch the current state of a task."""

    method: Literal["tasks/get"] = "tasks/get"
    params: TaskQueryParams


class GetTaskResponse(JSONRPCResponse):
    """Response to tasks/get."""

    result: Task | None = None


class CancelTaskRequest(JSONRPCRequest):
    """tasks/cancel — request cancellation of a running task."""

    method: Literal["tasks/cancel",] = "tasks/cancel"
    params: TaskIdParams


class CancelTaskResponse(JSONRPCResponse):
    """Response to tasks/cancel."""

    result: Task | None = None


class SetTaskPushNotificationRequest(JSONRPCRequest):
    """tasks/pushNotification/set — register a push-notification config."""

    method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
    params: TaskPushNotificationConfig


class SetTaskPushNotificationResponse(JSONRPCResponse):
    """Response to tasks/pushNotification/set."""

    result: TaskPushNotificationConfig | None = None


class GetTaskPushNotificationRequest(JSONRPCRequest):
    """tasks/pushNotification/get — fetch a task's push-notification config."""

    method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
    params: TaskIdParams


class GetTaskPushNotificationResponse(JSONRPCResponse):
    """Response to tasks/pushNotification/get."""

    result: TaskPushNotificationConfig | None = None


class TaskResubscriptionRequest(JSONRPCRequest):
    """tasks/resubscribe — re-attach to the event stream of an existing task."""

    method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
    params: TaskIdParams
|
||||
|
||||
|
||||
# TypeAdapter that parses any incoming JSON-RPC payload into the matching
# concrete request class, discriminating on the "method" field.
A2ARequest = TypeAdapter(
    Annotated[
        Union[
            SendTaskRequest,
            GetTaskRequest,
            CancelTaskRequest,
            SetTaskPushNotificationRequest,
            GetTaskPushNotificationRequest,
            TaskResubscriptionRequest,
            SendTaskStreamingRequest,
        ],
        Field(discriminator="method"),
    ]
)
|
||||
|
||||
## Error types
|
||||
|
||||
|
||||
class JSONParseError(JSONRPCError):
    """JSON-RPC -32700: request body was not valid JSON."""

    code: int = -32700
    message: str = "Invalid JSON payload"
    data: Any | None = None


class InvalidRequestError(JSONRPCError):
    """JSON-RPC -32600: request failed schema validation."""

    code: int = -32600
    message: str = "Request payload validation error"
    data: Any | None = None


class MethodNotFoundError(JSONRPCError):
    """JSON-RPC -32601: unknown method name."""

    code: int = -32601
    message: str = "Method not found"
    data: None = None


class InvalidParamsError(JSONRPCError):
    """JSON-RPC -32602: method parameters are invalid."""

    code: int = -32602
    message: str = "Invalid parameters"
    data: Any | None = None


class InternalError(JSONRPCError):
    """JSON-RPC -32603: server-side failure while handling the request."""

    code: int = -32603
    message: str = "Internal error"
    data: Any | None = None


class TaskNotFoundError(JSONRPCError):
    """A2A -32001: no task with the given id."""

    code: int = -32001
    message: str = "Task not found"
    data: None = None


class TaskNotCancelableError(JSONRPCError):
    """A2A -32002: the task is in a state that cannot be canceled."""

    code: int = -32002
    message: str = "Task cannot be canceled"
    data: None = None


class PushNotificationNotSupportedError(JSONRPCError):
    """A2A -32003: agent does not support push notifications."""

    code: int = -32003
    message: str = "Push Notification is not supported"
    data: None = None


class UnsupportedOperationError(JSONRPCError):
    """A2A -32004: the requested operation is not implemented."""

    code: int = -32004
    message: str = "This operation is not supported"
    data: None = None


class ContentTypeNotSupportedError(JSONRPCError):
    """A2A -32005: no overlap between requested and supported content types."""

    code: int = -32005
    message: str = "Incompatible content types"
    data: None = None
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
    """Organization that provides the agent."""

    organization: str
    url: str | None = None


class AgentCapabilities(BaseModel):
    """Optional protocol features the agent supports."""

    streaming: bool = False
    pushNotifications: bool = False
    stateTransitionHistory: bool = False


class AgentAuthentication(BaseModel):
    """Authentication schemes the agent accepts."""

    schemes: List[str]
    credentials: str | None = None


class AgentSkill(BaseModel):
    """A single advertised capability of the agent."""

    id: str
    name: str
    description: str | None = None
    tags: List[str] | None = None
    examples: List[str] | None = None
    inputModes: List[str] | None = None   # per-skill overrides of the card defaults
    outputModes: List[str] | None = None  # per-skill overrides of the card defaults


class AgentCard(BaseModel):
    """Public metadata describing an agent — presumably served as the
    agent discovery document; confirm against the server setup."""

    name: str
    description: str | None = None
    url: str
    provider: AgentProvider | None = None
    version: str
    documentationUrl: str | None = None
    capabilities: AgentCapabilities
    authentication: AgentAuthentication | None = None
    defaultInputModes: List[str] = ["text"]
    defaultOutputModes: List[str] = ["text"]
    skills: List[AgentSkill]
|
||||
|
||||
|
||||
class A2AClientError(Exception):
    """Base class for A2A client-side errors."""

    pass


class A2AClientHTTPError(A2AClientError):
    """HTTP-level failure while talking to an A2A server."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(f"HTTP Error {status_code}: {message}")


class A2AClientJSONError(A2AClientError):
    """Response payload could not be parsed or validated as JSON."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(f"JSON Error: {message}")


class MissingAPIKeyError(Exception):
    """Exception for missing API key."""

    pass
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""In Memory Cache utility."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class InMemoryCache:
    """A thread-safe Singleton class to manage cache data.

    Ensures only one instance of the cache exists across the application.
    Values may carry an optional TTL; expired entries are evicted lazily
    when read.
    """

    _instance: Optional["InMemoryCache"] = None
    _lock: threading.Lock = threading.Lock()
    _initialized: bool = False

    def __new__(cls):
        """Override __new__ to control instance creation (Singleton pattern).

        Uses double-checked locking so the class lock is only taken on the
        first instantiation.

        Returns:
            The singleton instance of InMemoryCache.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the cache storage.

        Uses a flag (_initialized) to ensure this logic runs only on the very
        first creation of the singleton instance (``__init__`` runs on every
        ``InMemoryCache()`` call, even though ``__new__`` returns the same
        object).
        """
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    # _cache_data maps keys to arbitrary cached values;
                    # _ttl maps keys to absolute expiry times (epoch seconds).
                    self._cache_data: Dict[str, Any] = {}
                    self._ttl: Dict[str, float] = {}
                    self._data_lock: threading.Lock = threading.Lock()
                    self._initialized = True

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set a key-value pair.

        Args:
            key: The key for the data.
            value: The data to store.
            ttl: Time to live in seconds. If None, data will not expire.
        """
        with self._data_lock:
            self._cache_data[key] = value

            if ttl is not None:
                self._ttl[key] = time.time() + ttl
            else:
                # Overwriting without a TTL clears any previous expiry.
                self._ttl.pop(key, None)

    def get(self, key: str, default: Any = None) -> Any:
        """Get the value associated with a key.

        Expired entries are deleted on access (lazy eviction).

        Args:
            key: The key for the data.
            default: The value to return if the key is not found or expired.

        Returns:
            The cached value, or the default value if not found.
        """
        with self._data_lock:
            if key in self._ttl and time.time() > self._ttl[key]:
                del self._cache_data[key]
                del self._ttl[key]
                return default
            return self._cache_data.get(key, default)

    def delete(self, key: str) -> bool:
        """Delete a specific key-value pair from the cache.

        Args:
            key: The key to delete.

        Returns:
            True if the key was found and deleted, False otherwise.
        """
        # Fixed: was annotated ``-> None`` while returning a bool.
        with self._data_lock:
            if key in self._cache_data:
                del self._cache_data[key]
                self._ttl.pop(key, None)
                return True
            return False

    def clear(self) -> bool:
        """Remove all data.

        Returns:
            True once the data has been cleared.
        """
        # Fixed: removed the unreachable ``return False`` that followed
        # the unconditional ``return True``.
        with self._data_lock:
            self._cache_data.clear()
            self._ttl.clear()
            return True
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
from jwcrypto import jwk
|
||||
import uuid
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.requests import Request
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
import time
|
||||
import json
|
||||
import hashlib
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from jwt import PyJWK, PyJWKClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
AUTH_HEADER_PREFIX = 'Bearer '
|
||||
|
||||
class PushNotificationAuth:
    """Shared base for push-notification signing/verification helpers."""

    def _calculate_request_body_sha256(self, data: dict[str, Any]):
        """Return the SHA-256 hex digest of *data* serialized as compact JSON.

        The serialization options must match exactly on the signing (agent)
        side and the verifying (client) side, otherwise the digests will
        not agree.
        """
        canonical = json.dumps(
            data,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return hashlib.sha256(canonical.encode()).hexdigest()
|
||||
|
||||
class PushNotificationSenderAuth(PushNotificationAuth):
    """Agent-side helper: owns an RSA signing key and sends signed push notifications."""

    def __init__(self):
        self.public_keys = []  # public JWKs served via handle_jwks_endpoint()
        self.private_key_jwk: PyJWK = None  # populated by generate_jwk()

    @staticmethod
    async def verify_push_notification_url(url: str) -> bool:
        """Probe *url* with a validation token and check it is echoed back.

        Returns False (after logging) on any HTTP or network error.
        """
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                validation_token = str(uuid.uuid4())
                response = await client.get(
                    url,
                    params={"validationToken": validation_token}
                )
                response.raise_for_status()
                is_verified = response.text == validation_token

                logger.info(f"Verified push-notification URL: {url} => {is_verified}")
                return is_verified
            except Exception as e:
                logger.warning(f"Error during sending push-notification for URL {url}: {e}")

        return False

    def generate_jwk(self):
        """Create a fresh 2048-bit RSA key; publish the public half, keep the private half for signing."""
        key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig")
        self.public_keys.append(key.export_public(as_dict=True))
        self.private_key_jwk = PyJWK.from_json(key.export_private())

    def handle_jwks_endpoint(self, _request: Request):
        """Allow clients to fetch public keys.
        """
        return JSONResponse({
            "keys": self.public_keys
        })

    def _generate_jwt(self, data: dict[str, Any]):
        """JWT is generated by signing both the request payload SHA digest and time of token generation.

        Payload is signed with private key and it ensures the integrity of payload for client.
        Including iat prevents from replay attack.
        """

        iat = int(time.time())

        return jwt.encode(
            {"iat": iat, "request_body_sha256": self._calculate_request_body_sha256(data)},
            key=self.private_key_jwk,
            headers={"kid": self.private_key_jwk.key_id},
            algorithm="RS256"
        )

    async def send_push_notification(self, url: str, data: dict[str, Any]):
        """POST *data* to *url* with a Bearer JWT binding the payload digest.

        Failures are logged, not raised — sending is best-effort.
        """
        jwt_token = self._generate_jwt(data)
        headers = {'Authorization': f"Bearer {jwt_token}"}
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                response = await client.post(
                    url,
                    json=data,
                    headers=headers
                )
                response.raise_for_status()
                logger.info(f"Push-notification sent for URL: {url}")
            except Exception as e:
                logger.warning(f"Error during sending push-notification for URL {url}: {e}")
|
||||
|
||||
class PushNotificationReceiverAuth(PushNotificationAuth):
    """Client-side helper: verifies signed push notifications against the agent's JWKS."""

    def __init__(self):
        self.public_keys_jwks = []
        self.jwks_client = None  # set by load_jwks()

    async def load_jwks(self, jwks_url: str):
        """Point the verifier at the agent's JWKS endpoint."""
        self.jwks_client = PyJWKClient(jwks_url)

    async def verify_push_notification(self, request: Request) -> bool:
        """Validate the Bearer JWT on *request* and check it matches the body.

        Returns False for a missing/malformed Authorization header.

        Raises:
            ValueError: if the payload digest does not match the signed
                digest, or the token is older than five minutes
                (replay protection).
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
            # NOTE(review): uses print instead of the module logger — consider logger.warning.
            print("Invalid authorization header")
            return False

        token = auth_header[len(AUTH_HEADER_PREFIX):]
        signing_key = self.jwks_client.get_signing_key_from_jwt(token)

        # Require the claims checked below; restrict to RS256 signatures.
        decode_token = jwt.decode(
            token,
            signing_key,
            options={"require": ["iat", "request_body_sha256"]},
            algorithms=["RS256"],
        )

        actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
        if actual_body_sha256 != decode_token["request_body_sha256"]:
            # Payload signature does not match the digest in signed token.
            raise ValueError("Invalid request body")

        if time.time() - decode_token["iat"] > 60 * 5:
            # Do not allow push-notifications older than 5 minutes.
            # This is to prevent replay attack.
            raise ValueError("Token is expired")

        return True
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from pocketflow import Flow
|
||||
from nodes import DecideAction, SearchWeb, AnswerQuestion
|
||||
|
||||
def create_agent_flow():
    """Build the research-agent flow graph.

    Wiring:
        DecideAction --"search"--> SearchWeb --"decide"--> DecideAction
        DecideAction --"answer"--> AnswerQuestion

    Returns:
        Flow: the complete agent flow, entered at DecideAction.
    """
    decide_node = DecideAction()
    search_node = SearchWeb()
    answer_node = AnswerQuestion()

    # Route on the action string each node returns from post():
    # "search" runs a web search, "answer" produces the final answer,
    # and a completed search always loops back to the decision node.
    decide_node - "search" >> search_node
    decide_node - "answer" >> answer_node
    search_node - "decide" >> decide_node

    return Flow(start=decide_node)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
import sys
|
||||
from flow import create_agent_flow
|
||||
|
||||
def main():
    """Run the research agent on a question from the CLI (or a default one)."""
    # Used when no --"..." argument is supplied.
    default_question = "Who won the Nobel Prize in Physics 2024?"

    # The first argument starting with "--" (minus that prefix) is the question.
    question = default_question
    for raw_arg in sys.argv[1:]:
        if raw_arg.startswith("--"):
            question = raw_arg[2:]
            break

    agent_flow = create_agent_flow()

    # The flow communicates through this shared store; the final answer
    # is written back under the "answer" key.
    shared = {"question": question}
    print(f"🤔 Processing question: {question}")
    agent_flow.run(shared)
    print("\n🎯 Final Answer:")
    print(shared.get("answer", "No answer found"))

if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
from pocketflow import Node
|
||||
from utils import call_llm, search_web
|
||||
import yaml
|
||||
|
||||
class DecideAction(Node):
    """Decision node: asks the LLM whether to search the web or answer now."""

    def prep(self, shared):
        """Prepare the context and question for the decision-making process."""
        # Get the current context (default to "No previous search" if none exists)
        context = shared.get("context", "No previous search")
        # Get the question from the shared store
        question = shared["question"]
        # Return both for the exec step
        return question, context

    def exec(self, inputs):
        """Call the LLM to decide whether to search or answer."""
        question, context = inputs

        print(f"🤔 Agent deciding what to do next...")

        # Create a prompt to help the LLM decide what to do next with proper yaml formatting
        prompt = f"""
### CONTEXT
You are a research assistant that can search the web.
Question: {question}
Previous Research: {context}

### ACTION SPACE
[1] search
Description: Look up more information on the web
Parameters:
- query (str): What to search for

[2] answer
Description: Answer the question with current knowledge
Parameters:
- answer (str): Final answer to the question

## NEXT ACTION
Decide the next action based on the context and available actions.
Return your response in this format:

```yaml
thinking: |
    <your step-by-step reasoning process>
action: search OR answer
reason: <why you chose this action>
answer: <if action is answer>
search_query: <specific search query if action is search>
```
IMPORTANT: Make sure to:
1. Use proper indentation (4 spaces) for all multi-line fields
2. Use the | character for multi-line text fields
3. Keep single-line fields without the | character
"""

        # Call the LLM to make a decision
        response = call_llm(prompt)

        # Parse the response to get the decision
        # NOTE(review): assumes the reply contains a fenced ```yaml block;
        # an IndexError is raised otherwise — consider guarding.
        yaml_str = response.split("```yaml")[1].split("```")[0].strip()
        decision = yaml.safe_load(yaml_str)

        return decision

    def post(self, shared, prep_res, exec_res):
        """Save the decision and determine the next step in the flow."""
        # If LLM decided to search, save the search query
        if exec_res["action"] == "search":
            shared["search_query"] = exec_res["search_query"]
            print(f"🔍 Agent decided to search for: {exec_res['search_query']}")
        else:
            shared["context"] = exec_res["answer"] #save the context if LLM gives the answer without searching.
            print(f"💡 Agent decided to answer the question")

        # Return the action to determine the next node in the flow
        return exec_res["action"]
|
||||
|
||||
class SearchWeb(Node):
    """Node that runs a web search and feeds the results back into the context."""

    def prep(self, shared):
        """Pull the pending search query out of the shared store."""
        return shared["search_query"]

    def exec(self, search_query):
        """Run the web search for *search_query* and return the raw results."""
        print(f"🌐 Searching the web for: {search_query}")
        return search_web(search_query)

    def post(self, shared, prep_res, exec_res):
        """Append the results to the running context, then loop back to decide."""
        accumulated = shared.get("context", "")
        shared["context"] = accumulated + "\n\nSEARCH: " + shared["search_query"] + "\nRESULTS: " + exec_res

        print(f"📚 Found information, analyzing results...")

        # Hand control back to DecideAction with the enriched context.
        return "decide"
|
||||
|
||||
class AnswerQuestion(Node):
    """Terminal node: produces the final answer from question plus research context."""

    def prep(self, shared):
        """Get the question and context for answering."""
        return shared["question"], shared.get("context", "")

    def exec(self, inputs):
        """Call the LLM to generate a final answer."""
        question, context = inputs

        print(f"✍️ Crafting final answer...")

        # Create a prompt for the LLM to answer the question
        prompt = f"""
### CONTEXT
Based on the following information, answer the question.
Question: {question}
Research: {context}

## YOUR ANSWER:
Provide a comprehensive answer using the research results.
"""
        # Call the LLM to generate an answer
        answer = call_llm(prompt)
        return answer

    def post(self, shared, prep_res, exec_res):
        """Save the final answer and complete the flow."""
        # Save the answer in the shared store
        shared["answer"] = exec_res

        print(f"✅ Answer generated successfully")

        # We're done - no need to continue the flow
        return "done"
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# For PocketFlow Agent Logic
|
||||
pocketflow>=0.0.1
|
||||
openai>=1.0.0
|
||||
duckduckgo-search>=7.5.2
|
||||
pyyaml>=5.1
|
||||
|
||||
# For A2A Server Infrastructure (from common)
|
||||
starlette>=0.37.2,<0.38.0
|
||||
uvicorn[standard]>=0.29.0,<0.30.0
|
||||
sse-starlette>=1.8.2,<2.0.0
|
||||
pydantic>=2.0.0,<3.0.0
|
||||
httpx>=0.27.0,<0.28.0
|
||||
anyio>=3.0.0,<5.0.0 # Dependency of starlette/httpx
|
||||
|
||||
# For running __main__.py
|
||||
click>=8.0.0,<9.0.0
|
||||
|
||||
# For A2A Client
asyncclick>=8.1.8 # Or just 'click' if you prefer asyncio.run
# httpx and pydantic for the client are already pinned above (server
# infrastructure section) with the same constraints, so they are not repeated.
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
# FILE: pocketflow_a2a_agent/task_manager.py
|
||||
import logging
|
||||
from typing import AsyncIterable, Union
|
||||
import asyncio
|
||||
|
||||
# Import from the common code you copied
|
||||
from common.server.task_manager import InMemoryTaskManager
|
||||
from common.types import (
    Artifact,
    InternalError,
    InvalidParamsError,
    JSONRPCResponse,
    Message,
    SendTaskRequest,
    SendTaskResponse,
    SendTaskStreamingRequest,
    SendTaskStreamingResponse,
    Task,
    TaskSendParams,
    TaskState,
    TaskStatus,
    TextPart,
    UnsupportedOperationError,
)
|
||||
import common.server.utils as server_utils
|
||||
|
||||
# Import directly from your original PocketFlow files
|
||||
from .flow import create_agent_flow # Assumes flow.py is in the same directory
|
||||
from .utils import call_llm, search_web # Make utils functions available if needed elsewhere
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PocketFlowTaskManager(InMemoryTaskManager):
    """ TaskManager implementation that runs the PocketFlow agent. """

    # Output modalities this agent can produce; checked against each
    # request's acceptedOutputModes.
    SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] # Define what the agent accepts/outputs

    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handles non-streaming task requests.

        Runs the synchronous PocketFlow agent to completion and packages
        the answer as a COMPLETED task carrying one text artifact.
        """
        logger.info(f"Received task send request: {request.params.id}")

        # Validate output modes
        if not server_utils.are_modalities_compatible(
            request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES
        ):
            logger.warning(
                "Unsupported output mode. Received %s, Support %s",
                request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES
            )
            return SendTaskResponse(id=request.id, error=server_utils.new_incompatible_types_error(request.id).error)

        # Upsert the task in the store (initial state: submitted)
        # We create the task first so its state can be tracked, even if the sync execution fails
        await self.upsert_task(request.params)
        # Update state to working before running
        await self.update_store(request.params.id, TaskStatus(state=TaskState.WORKING), [])


        # --- Run the PocketFlow logic ---
        task_params: TaskSendParams = request.params
        query = self._get_user_query(task_params)
        if query is None:
            # NOTE(review): Message is referenced here and below but is not in
            # the common.types import list at the top of this file — add it there.
            fail_status = TaskStatus(state=TaskState.FAILED, message=Message(role="agent", parts=[TextPart(text="No text query found")]))
            await self.update_store(task_params.id, fail_status, [])
            return SendTaskResponse(id=request.id, error=InvalidParamsError(message="No text query found in message parts"))

        shared_data = {"question": query}
        agent_flow = create_agent_flow() # Create the flow instance

        try:
            # Run the synchronous PocketFlow
            # In a real async server, you might run this in a separate thread/process
            # executor to avoid blocking the event loop. For simplicity here, we run it directly.
            # Consider adding a timeout if flows can hang.
            logger.info(f"Running PocketFlow for task {task_params.id}...")
            final_state_dict = agent_flow.run(shared_data)
            logger.info(f"PocketFlow completed for task {task_params.id}")
            # NOTE(review): assumes Flow.run returns the shared store — confirm;
            # the flow also mutates shared_data in place.
            answer_text = final_state_dict.get("answer", "Agent did not produce a final answer text.")

            # --- Package result into A2A Task ---
            final_task_status = TaskStatus(state=TaskState.COMPLETED)
            # Package the answer as an artifact
            final_artifact = Artifact(parts=[TextPart(text=answer_text)])

            # Update the task in the store with final status and artifact
            final_task = await self.update_store(
                task_params.id, final_task_status, [final_artifact]
            )

            # Prepare and return the A2A response
            task_result = self.append_task_history(final_task, task_params.historyLength)
            return SendTaskResponse(id=request.id, result=task_result)

        except Exception as e:
            logger.error(f"Error executing PocketFlow for task {task_params.id}: {e}", exc_info=True)
            # Update task state to FAILED
            fail_status = TaskStatus(
                state=TaskState.FAILED,
                message=Message(role="agent", parts=[TextPart(text=f"Agent execution failed: {e}")])
            )
            await self.update_store(task_params.id, fail_status, [])
            return SendTaskResponse(id=request.id, error=InternalError(message=f"Agent error: {e}"))

    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handles streaming requests - Not implemented for this synchronous agent."""
        logger.warning(f"Streaming requested for task {request.params.id}, but not supported by this PocketFlow agent implementation.")
        # Return an error indicating streaming is not supported
        return JSONRPCResponse(id=request.id, error=UnsupportedOperationError(message="Streaming not supported by this agent"))

    def _get_user_query(self, task_send_params: TaskSendParams) -> str | None:
        """Extracts the first text part from the user message."""
        if not task_send_params.message or not task_send_params.message.parts:
            logger.warning(f"No message parts found for task {task_send_params.id}")
            return None
        for part in task_send_params.message.parts:
            # Ensure part is treated as a dictionary if it came from JSON
            part_dict = part if isinstance(part, dict) else part.model_dump()
            if part_dict.get("type") == "text" and "text" in part_dict:
                return part_dict["text"]
        logger.warning(f"No text part found in message for task {task_send_params.id}")
        return None # No text part found
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from openai import OpenAI
|
||||
import os
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
def call_llm(prompt):
    """Send *prompt* to gpt-4o and return the text of the first choice."""
    # Falls back to a placeholder key so the error surfaces at request time.
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content
|
||||
|
||||
def search_web(query):
    """Search DuckDuckGo for *query* and return the top 5 hits as one string."""
    hits = DDGS().text(query, max_results=5)
    # Flatten each hit into a readable Title/URL/Snippet paragraph.
    formatted = [
        f"Title: {hit['title']}\nURL: {hit['href']}\nSnippet: {hit['body']}"
        for hit in hits
    ]
    return "\n\n".join(formatted)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test both utilities; requires OPENAI_API_KEY and network access.
    print("## Testing call_llm")
    prompt = "In a few words, what is the meaning of life?"
    print(f"## Prompt: {prompt}")
    response = call_llm(prompt)
    print(f"## Response: {response}")

    print("## Testing search_web")
    query = "Who won the Nobel Prize in Physics 2024?"
    print(f"## Query: {query}")
    results = search_web(query)
    print(f"## Results: {results}")
|
||||
Loading…
Reference in New Issue