a2a init
This commit is contained in:
parent
447a2d47d2
commit
cd02d65efe
|
|
@ -0,0 +1,67 @@
|
||||||
|
# Research Agent
|
||||||
|
|
||||||
|
This project demonstrates a simple yet powerful LLM-powered research agent. This implementation is based directly on the tutorial: [LLM Agents are simply Graph — Tutorial For Dummies](https://zacharyhuang.substack.com/p/llm-agent-internal-as-a-graph-tutorial).
|
||||||
|
|
||||||
|
👉 Run the tutorial in your browser: [Try Google Colab Notebook](
|
||||||
|
https://colab.research.google.com/github/The-Pocket/PocketFlow/blob/main/cookbook/pocketflow-agent/demo.ipynb)
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Performs web searches to gather information
|
||||||
|
- Makes decisions about when to search vs. when to answer
|
||||||
|
- Generates comprehensive answers based on research findings
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
1. Install the packages you need with this simple command:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Let's get your OpenAI API key ready:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Let's do a quick check to make sure your API key is working properly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will test both the LLM call and web search features. If you see responses, you're good to go!
|
||||||
|
|
||||||
|
4. Try out the agent with the default question (about Nobel Prize winners):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Got a burning question? Ask anything you want by using the `--` prefix:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py --"What is quantum computing?"
|
||||||
|
```
|
||||||
|
|
||||||
|
## How It Works?
|
||||||
|
|
||||||
|
The magic happens through a simple but powerful graph structure with three main parts:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
A[DecideAction] -->|"search"| B[SearchWeb]
|
||||||
|
A -->|"answer"| C[AnswerQuestion]
|
||||||
|
B -->|"decide"| A
|
||||||
|
```
|
||||||
|
|
||||||
|
Here's what each part does:
|
||||||
|
1. **DecideAction**: The brain that figures out whether to search or answer
|
||||||
|
2. **SearchWeb**: The researcher that goes out and finds information
|
||||||
|
3. **AnswerQuestion**: The writer that crafts the final answer
|
||||||
|
|
||||||
|
Here's what's in each file:
|
||||||
|
- [`main.py`](./main.py): The starting point - runs the whole show!
|
||||||
|
- [`flow.py`](./flow.py): Connects everything together into a smart agent
|
||||||
|
- [`nodes.py`](./nodes.py): The building blocks that make decisions and take actions
|
||||||
|
- [`utils.py`](./utils.py): Helper functions for talking to the LLM and searching the web
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
# FILE: minimal_a2a_client.py
|
||||||
|
import asyncio
|
||||||
|
import asyncclick as click # Using asyncclick for async main
|
||||||
|
from uuid import uuid4
|
||||||
|
import json # For potentially inspecting raw errors
|
||||||
|
|
||||||
|
# Import from the common directory placed alongside this script
|
||||||
|
from common.client import A2AClient
|
||||||
|
from common.types import (
|
||||||
|
TaskState,
|
||||||
|
A2AClientError,
|
||||||
|
TextPart, # Used to construct the message
|
||||||
|
JSONRPCResponse # Potentially useful for error checking
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- ANSI Colors (Optional but helpful) ---
|
||||||
|
C_RED = "\x1b[31m"
|
||||||
|
C_GREEN = "\x1b[32m"
|
||||||
|
C_YELLOW = "\x1b[33m"
|
||||||
|
C_BLUE = "\x1b[34m"
|
||||||
|
C_MAGENTA = "\x1b[35m"
|
||||||
|
C_CYAN = "\x1b[36m"
|
||||||
|
C_WHITE = "\x1b[37m"
|
||||||
|
C_GRAY = "\x1b[90m"
|
||||||
|
C_BRIGHT_MAGENTA = "\x1b[95m"
|
||||||
|
C_RESET = "\x1b[0m"
|
||||||
|
C_BOLD = "\x1b[1m"
|
||||||
|
|
||||||
|
def colorize(color, text):
|
||||||
|
return f"{color}{text}{C_RESET}"
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option(
|
||||||
|
"--agent-url",
|
||||||
|
default="http://localhost:10003", # Default to the port used in server __main__
|
||||||
|
help="URL of the PocketFlow A2A agent server.",
|
||||||
|
)
|
||||||
|
async def cli(agent_url: str):
|
||||||
|
"""Minimal CLI client to interact with an A2A agent."""
|
||||||
|
|
||||||
|
print(colorize(C_BRIGHT_MAGENTA, f"Connecting to agent at: {agent_url}"))
|
||||||
|
|
||||||
|
# Instantiate the client - only URL is needed if not fetching card first
|
||||||
|
# Note: The PocketFlow wrapper doesn't expose much via the AgentCard,
|
||||||
|
# so we skip fetching it for this minimal client.
|
||||||
|
client = A2AClient(url=agent_url)
|
||||||
|
|
||||||
|
sessionId = uuid4().hex # Generate a new session ID for this run
|
||||||
|
print(colorize(C_GRAY, f"Using Session ID: {sessionId}"))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
taskId = uuid4().hex # Generate a new task ID for each interaction
|
||||||
|
try:
|
||||||
|
prompt = await click.prompt(
|
||||||
|
colorize(C_CYAN, "\nEnter your question (:q or quit to exit)"),
|
||||||
|
prompt_suffix=" > ",
|
||||||
|
type=str # Ensure prompt returns string
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
# This can happen if stdin is closed, e.g., in some test runners
|
||||||
|
print(colorize(C_RED, "Failed to read input. Exiting."))
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if prompt.lower() in [":q", "quit"]:
|
||||||
|
print(colorize(C_YELLOW, "Exiting client."))
|
||||||
|
break
|
||||||
|
|
||||||
|
# --- Construct A2A Request Payload ---
|
||||||
|
payload = {
|
||||||
|
"id": taskId,
|
||||||
|
"sessionId": sessionId,
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"type": "text", # Explicitly match TextPart structure
|
||||||
|
"text": prompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"acceptedOutputModes": ["text", "text/plain"], # What the client wants back
|
||||||
|
# historyLength could be added if needed
|
||||||
|
}
|
||||||
|
|
||||||
|
print(colorize(C_GRAY, f"Sending task {taskId}..."))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# --- Send Task (Non-Streaming) ---
|
||||||
|
response = await client.send_task(payload)
|
||||||
|
|
||||||
|
# --- Process Response ---
|
||||||
|
if response.error:
|
||||||
|
print(colorize(C_RED, f"Error from agent (Code: {response.error.code}): {response.error.message}"))
|
||||||
|
if response.error.data:
|
||||||
|
print(colorize(C_GRAY, f"Error Data: {response.error.data}"))
|
||||||
|
elif response.result:
|
||||||
|
task_result = response.result
|
||||||
|
print(colorize(C_GREEN, f"Task {task_result.id} finished with state: {task_result.status.state}"))
|
||||||
|
|
||||||
|
final_answer = "Agent did not provide a final artifact."
|
||||||
|
# Extract answer from artifacts (as implemented in PocketFlowTaskManager)
|
||||||
|
if task_result.artifacts:
|
||||||
|
try:
|
||||||
|
# Find the first text part in the first artifact
|
||||||
|
first_artifact = task_result.artifacts[0]
|
||||||
|
first_text_part = next(
|
||||||
|
(p for p in first_artifact.parts if isinstance(p, TextPart)),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if first_text_part:
|
||||||
|
final_answer = first_text_part.text
|
||||||
|
else:
|
||||||
|
final_answer = f"(Non-text artifact received: {first_artifact.parts})"
|
||||||
|
except (IndexError, AttributeError, TypeError) as e:
|
||||||
|
final_answer = f"(Error parsing artifact: {e})"
|
||||||
|
elif task_result.status.message and task_result.status.message.parts:
|
||||||
|
# Fallback to status message if no artifact
|
||||||
|
try:
|
||||||
|
first_text_part = next(
|
||||||
|
(p for p in task_result.status.message.parts if isinstance(p, TextPart)),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if first_text_part:
|
||||||
|
final_answer = f"(Final status message: {first_text_part.text})"
|
||||||
|
|
||||||
|
except (AttributeError, TypeError) as e:
|
||||||
|
final_answer = f"(Error parsing status message: {e})"
|
||||||
|
|
||||||
|
|
||||||
|
print(colorize(C_BOLD + C_WHITE, f"\nAgent Response:\n{final_answer}"))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Should not happen if error is None
|
||||||
|
print(colorize(C_YELLOW, "Received response with no result and no error."))
|
||||||
|
|
||||||
|
except A2AClientError as e:
|
||||||
|
print(colorize(C_RED, f"\nClient Error: {e}"))
|
||||||
|
except Exception as e:
|
||||||
|
print(colorize(C_RED, f"\nAn unexpected error occurred: {e}"))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(cli())
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
# FILE: pocketflow_a2a_agent/__main__.py
|
||||||
|
import click
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Import from the common code you copied
|
||||||
|
from common.server import A2AServer
|
||||||
|
from common.types import AgentCard, AgentCapabilities, AgentSkill, MissingAPIKeyError
|
||||||
|
|
||||||
|
# Import your custom TaskManager (which now imports from your original files)
|
||||||
|
from .task_manager import PocketFlowTaskManager
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option("--host", "host", default="localhost")
|
||||||
|
@click.option("--port", "port", default=10003) # Use a different port from other agents
|
||||||
|
def main(host, port):
|
||||||
|
"""Starts the PocketFlow A2A Agent server."""
|
||||||
|
try:
|
||||||
|
# Check for necessary API keys (add others if needed)
|
||||||
|
if not os.getenv("OPENAI_API_KEY"):
|
||||||
|
raise MissingAPIKeyError("OPENAI_API_KEY environment variable not set.")
|
||||||
|
|
||||||
|
# --- Define the Agent Card ---
|
||||||
|
capabilities = AgentCapabilities(
|
||||||
|
streaming=False, # This simple implementation is synchronous
|
||||||
|
pushNotifications=False,
|
||||||
|
stateTransitionHistory=False # PocketFlow state isn't exposed via A2A history
|
||||||
|
)
|
||||||
|
skill = AgentSkill(
|
||||||
|
id="web_research_qa",
|
||||||
|
name="Web Research and Answering",
|
||||||
|
description="Answers questions using web search results when necessary.",
|
||||||
|
tags=["research", "qa", "web search"],
|
||||||
|
examples=[
|
||||||
|
"Who won the Nobel Prize in Physics 2024?",
|
||||||
|
"What is quantum computing?",
|
||||||
|
"Summarize the latest news about AI.",
|
||||||
|
],
|
||||||
|
# Input/Output modes defined in the TaskManager
|
||||||
|
inputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
|
||||||
|
outputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
|
||||||
|
)
|
||||||
|
agent_card = AgentCard(
|
||||||
|
name="PocketFlow Research Agent (A2A Wrapped)",
|
||||||
|
description="A simple research agent based on PocketFlow, made accessible via A2A.",
|
||||||
|
url=f"http://{host}:{port}/", # The endpoint A2A clients will use
|
||||||
|
version="0.1.0-a2a",
|
||||||
|
capabilities=capabilities,
|
||||||
|
skills=[skill],
|
||||||
|
# Assuming no specific provider or auth for this example
|
||||||
|
provider=None,
|
||||||
|
authentication=None,
|
||||||
|
defaultInputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
|
||||||
|
defaultOutputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Initialize and Start Server ---
|
||||||
|
task_manager = PocketFlowTaskManager() # Instantiate your custom manager
|
||||||
|
server = A2AServer(
|
||||||
|
agent_card=agent_card,
|
||||||
|
task_manager=task_manager,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Starting PocketFlow A2A server on http://{host}:{port}")
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
except MissingAPIKeyError as e:
|
||||||
|
logger.error(f"Configuration Error: {e}")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"An error occurred during server startup: {e}", exc_info=True)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .client import A2AClient
|
||||||
|
from .card_resolver import A2ACardResolver
|
||||||
|
|
||||||
|
__all__ = ["A2AClient", "A2ACardResolver"]
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
import httpx
|
||||||
|
from common.types import (
|
||||||
|
AgentCard,
|
||||||
|
A2AClientJSONError,
|
||||||
|
)
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class A2ACardResolver:
|
||||||
|
def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.agent_card_path = agent_card_path.lstrip("/")
|
||||||
|
|
||||||
|
def get_agent_card(self) -> AgentCard:
|
||||||
|
with httpx.Client() as client:
|
||||||
|
response = client.get(self.base_url + "/" + self.agent_card_path)
|
||||||
|
response.raise_for_status()
|
||||||
|
try:
|
||||||
|
return AgentCard(**response.json())
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise A2AClientJSONError(str(e)) from e
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
import httpx
|
||||||
|
from httpx_sse import connect_sse
|
||||||
|
from typing import Any, AsyncIterable
|
||||||
|
from common.types import (
|
||||||
|
AgentCard,
|
||||||
|
GetTaskRequest,
|
||||||
|
SendTaskRequest,
|
||||||
|
SendTaskResponse,
|
||||||
|
JSONRPCRequest,
|
||||||
|
GetTaskResponse,
|
||||||
|
CancelTaskResponse,
|
||||||
|
CancelTaskRequest,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
SetTaskPushNotificationResponse,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationResponse,
|
||||||
|
A2AClientHTTPError,
|
||||||
|
A2AClientJSONError,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
SendTaskStreamingResponse,
|
||||||
|
)
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClient:
|
||||||
|
def __init__(self, agent_card: AgentCard = None, url: str = None):
|
||||||
|
if agent_card:
|
||||||
|
self.url = agent_card.url
|
||||||
|
elif url:
|
||||||
|
self.url = url
|
||||||
|
else:
|
||||||
|
raise ValueError("Must provide either agent_card or url")
|
||||||
|
|
||||||
|
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
|
||||||
|
request = SendTaskRequest(params=payload)
|
||||||
|
return SendTaskResponse(**await self._send_request(request))
|
||||||
|
|
||||||
|
async def send_task_streaming(
|
||||||
|
self, payload: dict[str, Any]
|
||||||
|
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||||
|
request = SendTaskStreamingRequest(params=payload)
|
||||||
|
with httpx.Client(timeout=None) as client:
|
||||||
|
with connect_sse(
|
||||||
|
client, "POST", self.url, json=request.model_dump()
|
||||||
|
) as event_source:
|
||||||
|
try:
|
||||||
|
for sse in event_source.iter_sse():
|
||||||
|
yield SendTaskStreamingResponse(**json.loads(sse.data))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise A2AClientJSONError(str(e)) from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise A2AClientHTTPError(400, str(e)) from e
|
||||||
|
|
||||||
|
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
# Image generation could take time, adding timeout
|
||||||
|
response = await client.post(
|
||||||
|
self.url, json=request.model_dump(), timeout=30
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise A2AClientJSONError(str(e)) from e
|
||||||
|
|
||||||
|
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
|
||||||
|
request = GetTaskRequest(params=payload)
|
||||||
|
return GetTaskResponse(**await self._send_request(request))
|
||||||
|
|
||||||
|
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
|
||||||
|
request = CancelTaskRequest(params=payload)
|
||||||
|
return CancelTaskResponse(**await self._send_request(request))
|
||||||
|
|
||||||
|
async def set_task_callback(
|
||||||
|
self, payload: dict[str, Any]
|
||||||
|
) -> SetTaskPushNotificationResponse:
|
||||||
|
request = SetTaskPushNotificationRequest(params=payload)
|
||||||
|
return SetTaskPushNotificationResponse(**await self._send_request(request))
|
||||||
|
|
||||||
|
async def get_task_callback(
|
||||||
|
self, payload: dict[str, Any]
|
||||||
|
) -> GetTaskPushNotificationResponse:
|
||||||
|
request = GetTaskPushNotificationRequest(params=payload)
|
||||||
|
return GetTaskPushNotificationResponse(**await self._send_request(request))
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .server import A2AServer
|
||||||
|
from .task_manager import TaskManager, InMemoryTaskManager
|
||||||
|
|
||||||
|
__all__ = ["A2AServer", "TaskManager", "InMemoryTaskManager"]
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from starlette.requests import Request
|
||||||
|
from common.types import (
|
||||||
|
A2ARequest,
|
||||||
|
JSONRPCResponse,
|
||||||
|
InvalidRequestError,
|
||||||
|
JSONParseError,
|
||||||
|
GetTaskRequest,
|
||||||
|
CancelTaskRequest,
|
||||||
|
SendTaskRequest,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
InternalError,
|
||||||
|
AgentCard,
|
||||||
|
TaskResubscriptionRequest,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
)
|
||||||
|
from pydantic import ValidationError
|
||||||
|
import json
|
||||||
|
from typing import AsyncIterable, Any
|
||||||
|
from common.server.task_manager import TaskManager
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class A2AServer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=5000,
|
||||||
|
endpoint="/",
|
||||||
|
agent_card: AgentCard = None,
|
||||||
|
task_manager: TaskManager = None,
|
||||||
|
):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.task_manager = task_manager
|
||||||
|
self.agent_card = agent_card
|
||||||
|
self.app = Starlette()
|
||||||
|
self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
|
||||||
|
self.app.add_route(
|
||||||
|
"/.well-known/agent.json", self._get_agent_card, methods=["GET"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if self.agent_card is None:
|
||||||
|
raise ValueError("agent_card is not defined")
|
||||||
|
|
||||||
|
if self.task_manager is None:
|
||||||
|
raise ValueError("request_handler is not defined")
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||||
|
|
||||||
|
def _get_agent_card(self, request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
|
||||||
|
|
||||||
|
async def _process_request(self, request: Request):
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
json_rpc_request = A2ARequest.validate_python(body)
|
||||||
|
|
||||||
|
if isinstance(json_rpc_request, GetTaskRequest):
|
||||||
|
result = await self.task_manager.on_get_task(json_rpc_request)
|
||||||
|
elif isinstance(json_rpc_request, SendTaskRequest):
|
||||||
|
result = await self.task_manager.on_send_task(json_rpc_request)
|
||||||
|
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
|
||||||
|
result = await self.task_manager.on_send_task_subscribe(
|
||||||
|
json_rpc_request
|
||||||
|
)
|
||||||
|
elif isinstance(json_rpc_request, CancelTaskRequest):
|
||||||
|
result = await self.task_manager.on_cancel_task(json_rpc_request)
|
||||||
|
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
|
||||||
|
result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
|
||||||
|
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
|
||||||
|
result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
|
||||||
|
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
|
||||||
|
result = await self.task_manager.on_resubscribe_to_task(
|
||||||
|
json_rpc_request
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
|
||||||
|
raise ValueError(f"Unexpected request type: {type(request)}")
|
||||||
|
|
||||||
|
return self._create_response(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self._handle_exception(e)
|
||||||
|
|
||||||
|
def _handle_exception(self, e: Exception) -> JSONResponse:
|
||||||
|
if isinstance(e, json.decoder.JSONDecodeError):
|
||||||
|
json_rpc_error = JSONParseError()
|
||||||
|
elif isinstance(e, ValidationError):
|
||||||
|
json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
|
||||||
|
else:
|
||||||
|
logger.error(f"Unhandled exception: {e}")
|
||||||
|
json_rpc_error = InternalError()
|
||||||
|
|
||||||
|
response = JSONRPCResponse(id=None, error=json_rpc_error)
|
||||||
|
return JSONResponse(response.model_dump(exclude_none=True), status_code=400)
|
||||||
|
|
||||||
|
def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
|
||||||
|
if isinstance(result, AsyncIterable):
|
||||||
|
|
||||||
|
async def event_generator(result) -> AsyncIterable[dict[str, str]]:
|
||||||
|
async for item in result:
|
||||||
|
yield {"data": item.model_dump_json(exclude_none=True)}
|
||||||
|
|
||||||
|
return EventSourceResponse(event_generator(result))
|
||||||
|
elif isinstance(result, JSONRPCResponse):
|
||||||
|
return JSONResponse(result.model_dump(exclude_none=True))
|
||||||
|
else:
|
||||||
|
logger.error(f"Unexpected result type: {type(result)}")
|
||||||
|
raise ValueError(f"Unexpected result type: {type(result)}")
|
||||||
|
|
@ -0,0 +1,277 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Union, AsyncIterable, List
|
||||||
|
from common.types import Task
|
||||||
|
from common.types import (
|
||||||
|
JSONRPCResponse,
|
||||||
|
TaskIdParams,
|
||||||
|
TaskQueryParams,
|
||||||
|
GetTaskRequest,
|
||||||
|
TaskNotFoundError,
|
||||||
|
SendTaskRequest,
|
||||||
|
CancelTaskRequest,
|
||||||
|
TaskNotCancelableError,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
GetTaskResponse,
|
||||||
|
CancelTaskResponse,
|
||||||
|
SendTaskResponse,
|
||||||
|
SetTaskPushNotificationResponse,
|
||||||
|
GetTaskPushNotificationResponse,
|
||||||
|
PushNotificationNotSupportedError,
|
||||||
|
TaskSendParams,
|
||||||
|
TaskStatus,
|
||||||
|
TaskState,
|
||||||
|
TaskResubscriptionRequest,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
SendTaskStreamingResponse,
|
||||||
|
Artifact,
|
||||||
|
PushNotificationConfig,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
JSONRPCError,
|
||||||
|
TaskPushNotificationConfig,
|
||||||
|
InternalError,
|
||||||
|
)
|
||||||
|
from common.server.utils import new_not_implemented_error
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class TaskManager(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_send_task_subscribe(
|
||||||
|
self, request: SendTaskStreamingRequest
|
||||||
|
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_set_task_push_notification(
|
||||||
|
self, request: SetTaskPushNotificationRequest
|
||||||
|
) -> SetTaskPushNotificationResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_get_task_push_notification(
|
||||||
|
self, request: GetTaskPushNotificationRequest
|
||||||
|
) -> GetTaskPushNotificationResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_resubscribe_to_task(
|
||||||
|
self, request: TaskResubscriptionRequest
|
||||||
|
) -> Union[AsyncIterable[SendTaskResponse], JSONRPCResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryTaskManager(TaskManager):
|
||||||
|
def __init__(self):
|
||||||
|
self.tasks: dict[str, Task] = {}
|
||||||
|
self.push_notification_infos: dict[str, PushNotificationConfig] = {}
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
|
||||||
|
self.subscriber_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
|
||||||
|
logger.info(f"Getting task {request.params.id}")
|
||||||
|
task_query_params: TaskQueryParams = request.params
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
task = self.tasks.get(task_query_params.id)
|
||||||
|
if task is None:
|
||||||
|
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||||
|
|
||||||
|
task_result = self.append_task_history(
|
||||||
|
task, task_query_params.historyLength
|
||||||
|
)
|
||||||
|
|
||||||
|
return GetTaskResponse(id=request.id, result=task_result)
|
||||||
|
|
||||||
|
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
|
||||||
|
logger.info(f"Cancelling task {request.params.id}")
|
||||||
|
task_id_params: TaskIdParams = request.params
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
task = self.tasks.get(task_id_params.id)
|
||||||
|
if task is None:
|
||||||
|
return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||||
|
|
||||||
|
return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_send_task_subscribe(
|
||||||
|
self, request: SendTaskStreamingRequest
|
||||||
|
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig):
|
||||||
|
async with self.lock:
|
||||||
|
task = self.tasks.get(task_id)
|
||||||
|
if task is None:
|
||||||
|
raise ValueError(f"Task not found for {task_id}")
|
||||||
|
|
||||||
|
self.push_notification_infos[task_id] = notification_config
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
|
||||||
|
async with self.lock:
|
||||||
|
task = self.tasks.get(task_id)
|
||||||
|
if task is None:
|
||||||
|
raise ValueError(f"Task not found for {task_id}")
|
||||||
|
|
||||||
|
return self.push_notification_infos[task_id]
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
async def has_push_notification_info(self, task_id: str) -> bool:
|
||||||
|
async with self.lock:
|
||||||
|
return task_id in self.push_notification_infos
|
||||||
|
|
||||||
|
|
||||||
|
async def on_set_task_push_notification(
|
||||||
|
self, request: SetTaskPushNotificationRequest
|
||||||
|
) -> SetTaskPushNotificationResponse:
|
||||||
|
logger.info(f"Setting task push notification {request.params.id}")
|
||||||
|
task_notification_params: TaskPushNotificationConfig = request.params
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.set_push_notification_info(task_notification_params.id, task_notification_params.pushNotificationConfig)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error while setting push notification info: {e}")
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request.id,
|
||||||
|
error=InternalError(
|
||||||
|
message="An error occurred while setting push notification info"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)
|
||||||
|
|
||||||
|
async def on_get_task_push_notification(
|
||||||
|
self, request: GetTaskPushNotificationRequest
|
||||||
|
) -> GetTaskPushNotificationResponse:
|
||||||
|
logger.info(f"Getting task push notification {request.params.id}")
|
||||||
|
task_params: TaskIdParams = request.params
|
||||||
|
|
||||||
|
try:
|
||||||
|
notification_info = await self.get_push_notification_info(task_params.id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error while getting push notification info: {e}")
|
||||||
|
return GetTaskPushNotificationResponse(
|
||||||
|
id=request.id,
|
||||||
|
error=InternalError(
|
||||||
|
message="An error occurred while getting push notification info"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return GetTaskPushNotificationResponse(id=request.id, result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info))
|
||||||
|
|
||||||
|
async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
|
||||||
|
logger.info(f"Upserting task {task_send_params.id}")
|
||||||
|
async with self.lock:
|
||||||
|
task = self.tasks.get(task_send_params.id)
|
||||||
|
if task is None:
|
||||||
|
task = Task(
|
||||||
|
id=task_send_params.id,
|
||||||
|
sessionId = task_send_params.sessionId,
|
||||||
|
messages=[task_send_params.message],
|
||||||
|
status=TaskStatus(state=TaskState.SUBMITTED),
|
||||||
|
history=[task_send_params.message],
|
||||||
|
)
|
||||||
|
self.tasks[task_send_params.id] = task
|
||||||
|
else:
|
||||||
|
task.history.append(task_send_params.message)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def on_resubscribe_to_task(
|
||||||
|
self, request: TaskResubscriptionRequest
|
||||||
|
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||||
|
return new_not_implemented_error(request.id)
|
||||||
|
|
||||||
|
async def update_store(
|
||||||
|
self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
|
||||||
|
) -> Task:
|
||||||
|
async with self.lock:
|
||||||
|
try:
|
||||||
|
task = self.tasks[task_id]
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Task {task_id} not found for updating the task")
|
||||||
|
raise ValueError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
task.status = status
|
||||||
|
|
||||||
|
if status.message is not None:
|
||||||
|
task.history.append(status.message)
|
||||||
|
|
||||||
|
if artifacts is not None:
|
||||||
|
if task.artifacts is None:
|
||||||
|
task.artifacts = []
|
||||||
|
task.artifacts.extend(artifacts)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
def append_task_history(self, task: Task, historyLength: int | None):
|
||||||
|
new_task = task.model_copy()
|
||||||
|
if historyLength is not None and historyLength > 0:
|
||||||
|
new_task.history = new_task.history[-historyLength:]
|
||||||
|
else:
|
||||||
|
new_task.history = []
|
||||||
|
|
||||||
|
return new_task
|
||||||
|
|
||||||
|
async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
|
||||||
|
async with self.subscriber_lock:
|
||||||
|
if task_id not in self.task_sse_subscribers:
|
||||||
|
if is_resubscribe:
|
||||||
|
raise ValueError("Task not found for resubscription")
|
||||||
|
else:
|
||||||
|
self.task_sse_subscribers[task_id] = []
|
||||||
|
|
||||||
|
sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited
|
||||||
|
self.task_sse_subscribers[task_id].append(sse_event_queue)
|
||||||
|
return sse_event_queue
|
||||||
|
|
||||||
|
async def enqueue_events_for_sse(self, task_id, task_update_event):
|
||||||
|
async with self.subscriber_lock:
|
||||||
|
if task_id not in self.task_sse_subscribers:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_subscribers = self.task_sse_subscribers[task_id]
|
||||||
|
for subscriber in current_subscribers:
|
||||||
|
await subscriber.put(task_update_event)
|
||||||
|
|
||||||
|
async def dequeue_events_for_sse(
|
||||||
|
self, request_id, task_id, sse_event_queue: asyncio.Queue
|
||||||
|
) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
event = await sse_event_queue.get()
|
||||||
|
if isinstance(event, JSONRPCError):
|
||||||
|
yield SendTaskStreamingResponse(id=request_id, error=event)
|
||||||
|
break
|
||||||
|
|
||||||
|
yield SendTaskStreamingResponse(id=request_id, result=event)
|
||||||
|
if isinstance(event, TaskStatusUpdateEvent) and event.final:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
async with self.subscriber_lock:
|
||||||
|
if task_id in self.task_sse_subscribers:
|
||||||
|
self.task_sse_subscribers[task_id].remove(sse_event_queue)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
from common.types import (
|
||||||
|
JSONRPCResponse,
|
||||||
|
ContentTypeNotSupportedError,
|
||||||
|
UnsupportedOperationError,
|
||||||
|
)
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def are_modalities_compatible(
|
||||||
|
server_output_modes: List[str], client_output_modes: List[str]
|
||||||
|
):
|
||||||
|
"""Modalities are compatible if they are both non-empty
|
||||||
|
and there is at least one common element."""
|
||||||
|
if client_output_modes is None or len(client_output_modes) == 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if server_output_modes is None or len(server_output_modes) == 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return any(x in server_output_modes for x in client_output_modes)
|
||||||
|
|
||||||
|
|
||||||
|
def new_incompatible_types_error(request_id):
|
||||||
|
return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())
|
||||||
|
|
||||||
|
|
||||||
|
def new_not_implemented_error(request_id):
|
||||||
|
return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
|
||||||
|
|
@ -0,0 +1,365 @@
|
||||||
|
from typing import Union, Any
|
||||||
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
|
from typing import Literal, List, Annotated, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import model_validator, ConfigDict, field_serializer
|
||||||
|
from uuid import uuid4
|
||||||
|
from enum import Enum
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
class TaskState(str, Enum):
|
||||||
|
SUBMITTED = "submitted"
|
||||||
|
WORKING = "working"
|
||||||
|
INPUT_REQUIRED = "input-required"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
CANCELED = "canceled"
|
||||||
|
FAILED = "failed"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class TextPart(BaseModel):
|
||||||
|
type: Literal["text"] = "text"
|
||||||
|
text: str
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileContent(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
mimeType: str | None = None
|
||||||
|
bytes: str | None = None
|
||||||
|
uri: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_content(self) -> Self:
|
||||||
|
if not (self.bytes or self.uri):
|
||||||
|
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
|
||||||
|
if self.bytes and self.uri:
|
||||||
|
raise ValueError(
|
||||||
|
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class FilePart(BaseModel):
|
||||||
|
type: Literal["file"] = "file"
|
||||||
|
file: FileContent
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DataPart(BaseModel):
|
||||||
|
type: Literal["data"] = "data"
|
||||||
|
data: dict[str, Any]
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: Literal["user", "agent"]
|
||||||
|
parts: List[Part]
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(BaseModel):
|
||||||
|
state: TaskState
|
||||||
|
message: Message | None = None
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
@field_serializer("timestamp")
|
||||||
|
def serialize_dt(self, dt: datetime, _info):
|
||||||
|
return dt.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
class Artifact(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
parts: List[Part]
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
index: int = 0
|
||||||
|
append: bool | None = None
|
||||||
|
lastChunk: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Task(BaseModel):
|
||||||
|
id: str
|
||||||
|
sessionId: str | None = None
|
||||||
|
status: TaskStatus
|
||||||
|
artifacts: List[Artifact] | None = None
|
||||||
|
history: List[Message] | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusUpdateEvent(BaseModel):
|
||||||
|
id: str
|
||||||
|
status: TaskStatus
|
||||||
|
final: bool = False
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskArtifactUpdateEvent(BaseModel):
|
||||||
|
id: str
|
||||||
|
artifact: Artifact
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationInfo(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
schemes: List[str]
|
||||||
|
credentials: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PushNotificationConfig(BaseModel):
|
||||||
|
url: str
|
||||||
|
token: str | None = None
|
||||||
|
authentication: AuthenticationInfo | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskIdParams(BaseModel):
|
||||||
|
id: str
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskQueryParams(TaskIdParams):
|
||||||
|
historyLength: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskSendParams(BaseModel):
|
||||||
|
id: str
|
||||||
|
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||||
|
message: Message
|
||||||
|
acceptedOutputModes: Optional[List[str]] = None
|
||||||
|
pushNotification: PushNotificationConfig | None = None
|
||||||
|
historyLength: int | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskPushNotificationConfig(BaseModel):
|
||||||
|
id: str
|
||||||
|
pushNotificationConfig: PushNotificationConfig
|
||||||
|
|
||||||
|
|
||||||
|
## RPC Messages
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCMessage(BaseModel):
|
||||||
|
jsonrpc: Literal["2.0"] = "2.0"
|
||||||
|
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCRequest(JSONRPCMessage):
|
||||||
|
method: str
|
||||||
|
params: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCError(BaseModel):
|
||||||
|
code: int
|
||||||
|
message: str
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCResponse(JSONRPCMessage):
|
||||||
|
result: Any | None = None
|
||||||
|
error: JSONRPCError | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/send"] = "tasks/send"
|
||||||
|
params: TaskSendParams
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskResponse(JSONRPCResponse):
|
||||||
|
result: Task | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
|
||||||
|
params: TaskSendParams
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||||
|
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/get"] = "tasks/get"
|
||||||
|
params: TaskQueryParams
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskResponse(JSONRPCResponse):
|
||||||
|
result: Task | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelTaskRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/cancel",] = "tasks/cancel"
|
||||||
|
params: TaskIdParams
|
||||||
|
|
||||||
|
|
||||||
|
class CancelTaskResponse(JSONRPCResponse):
|
||||||
|
result: Task | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
|
||||||
|
params: TaskPushNotificationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||||
|
result: TaskPushNotificationConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
|
||||||
|
params: TaskIdParams
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||||
|
result: TaskPushNotificationConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||||
|
method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
|
||||||
|
params: TaskIdParams
|
||||||
|
|
||||||
|
|
||||||
|
A2ARequest = TypeAdapter(
|
||||||
|
Annotated[
|
||||||
|
Union[
|
||||||
|
SendTaskRequest,
|
||||||
|
GetTaskRequest,
|
||||||
|
CancelTaskRequest,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
TaskResubscriptionRequest,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
],
|
||||||
|
Field(discriminator="method"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
## Error types
|
||||||
|
|
||||||
|
|
||||||
|
class JSONParseError(JSONRPCError):
|
||||||
|
code: int = -32700
|
||||||
|
message: str = "Invalid JSON payload"
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRequestError(JSONRPCError):
|
||||||
|
code: int = -32600
|
||||||
|
message: str = "Request payload validation error"
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MethodNotFoundError(JSONRPCError):
|
||||||
|
code: int = -32601
|
||||||
|
message: str = "Method not found"
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidParamsError(JSONRPCError):
|
||||||
|
code: int = -32602
|
||||||
|
message: str = "Invalid parameters"
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InternalError(JSONRPCError):
|
||||||
|
code: int = -32603
|
||||||
|
message: str = "Internal error"
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskNotFoundError(JSONRPCError):
|
||||||
|
code: int = -32001
|
||||||
|
message: str = "Task not found"
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskNotCancelableError(JSONRPCError):
|
||||||
|
code: int = -32002
|
||||||
|
message: str = "Task cannot be canceled"
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PushNotificationNotSupportedError(JSONRPCError):
|
||||||
|
code: int = -32003
|
||||||
|
message: str = "Push Notification is not supported"
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedOperationError(JSONRPCError):
|
||||||
|
code: int = -32004
|
||||||
|
message: str = "This operation is not supported"
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypeNotSupportedError(JSONRPCError):
|
||||||
|
code: int = -32005
|
||||||
|
message: str = "Incompatible content types"
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentProvider(BaseModel):
|
||||||
|
organization: str
|
||||||
|
url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCapabilities(BaseModel):
|
||||||
|
streaming: bool = False
|
||||||
|
pushNotifications: bool = False
|
||||||
|
stateTransitionHistory: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AgentAuthentication(BaseModel):
|
||||||
|
schemes: List[str]
|
||||||
|
credentials: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSkill(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
tags: List[str] | None = None
|
||||||
|
examples: List[str] | None = None
|
||||||
|
inputModes: List[str] | None = None
|
||||||
|
outputModes: List[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCard(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
url: str
|
||||||
|
provider: AgentProvider | None = None
|
||||||
|
version: str
|
||||||
|
documentationUrl: str | None = None
|
||||||
|
capabilities: AgentCapabilities
|
||||||
|
authentication: AgentAuthentication | None = None
|
||||||
|
defaultInputModes: List[str] = ["text"]
|
||||||
|
defaultOutputModes: List[str] = ["text"]
|
||||||
|
skills: List[AgentSkill]
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClientError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClientHTTPError(A2AClientError):
|
||||||
|
def __init__(self, status_code: int, message: str):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
super().__init__(f"HTTP Error {status_code}: {message}")
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClientJSONError(A2AClientError):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(f"JSON Error: {message}")
|
||||||
|
|
||||||
|
|
||||||
|
class MissingAPIKeyError(Exception):
|
||||||
|
"""Exception for missing API key."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""In Memory Cache utility."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryCache:
|
||||||
|
"""A thread-safe Singleton class to manage cache data.
|
||||||
|
|
||||||
|
Ensures only one instance of the cache exists across the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional["InMemoryCache"] = None
|
||||||
|
_lock: threading.Lock = threading.Lock()
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""Override __new__ to control instance creation (Singleton pattern).
|
||||||
|
|
||||||
|
Uses a lock to ensure thread safety during the first instantiation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The singleton instance of InMemoryCache.
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the cache storage.
|
||||||
|
|
||||||
|
Uses a flag (_initialized) to ensure this logic runs only on the very first
|
||||||
|
creation of the singleton instance.
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
# print("Initializing SessionCache storage")
|
||||||
|
self._cache_data: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self._ttl: Dict[str, float] = {}
|
||||||
|
self._data_lock: threading.Lock = threading.Lock()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
|
||||||
|
"""Set a key-value pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key for the data.
|
||||||
|
value: The data to store.
|
||||||
|
ttl: Time to live in seconds. If None, data will not expire.
|
||||||
|
"""
|
||||||
|
with self._data_lock:
|
||||||
|
self._cache_data[key] = value
|
||||||
|
|
||||||
|
if ttl is not None:
|
||||||
|
self._ttl[key] = time.time() + ttl
|
||||||
|
else:
|
||||||
|
if key in self._ttl:
|
||||||
|
del self._ttl[key]
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""Get the value associated with a key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key for the data within the session.
|
||||||
|
default: The value to return if the session or key is not found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The cached value, or the default value if not found.
|
||||||
|
"""
|
||||||
|
with self._data_lock:
|
||||||
|
if key in self._ttl and time.time() > self._ttl[key]:
|
||||||
|
del self._cache_data[key]
|
||||||
|
del self._ttl[key]
|
||||||
|
return default
|
||||||
|
return self._cache_data.get(key, default)
|
||||||
|
|
||||||
|
def delete(self, key: str) -> None:
|
||||||
|
"""Delete a specific key-value pair from a cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the key was found and deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with self._data_lock:
|
||||||
|
if key in self._cache_data:
|
||||||
|
del self._cache_data[key]
|
||||||
|
if key in self._ttl:
|
||||||
|
del self._ttl[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self) -> bool:
|
||||||
|
"""Remove all data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the data was cleared, False otherwise.
|
||||||
|
"""
|
||||||
|
with self._data_lock:
|
||||||
|
self._cache_data.clear()
|
||||||
|
self._ttl.clear()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,135 @@
|
||||||
|
from jwcrypto import jwk
|
||||||
|
import uuid
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.requests import Request
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import httpx
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from jwt import PyJWK, PyJWKClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
AUTH_HEADER_PREFIX = 'Bearer '
|
||||||
|
|
||||||
|
class PushNotificationAuth:
|
||||||
|
def _calculate_request_body_sha256(self, data: dict[str, Any]):
|
||||||
|
"""Calculates the SHA256 hash of a request body.
|
||||||
|
|
||||||
|
This logic needs to be same for both the agent who signs the payload and the client verifier.
|
||||||
|
"""
|
||||||
|
body_str = json.dumps(
|
||||||
|
data,
|
||||||
|
ensure_ascii=False,
|
||||||
|
allow_nan=False,
|
||||||
|
indent=None,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
return hashlib.sha256(body_str.encode()).hexdigest()
|
||||||
|
|
||||||
|
class PushNotificationSenderAuth(PushNotificationAuth):
|
||||||
|
def __init__(self):
|
||||||
|
self.public_keys = []
|
||||||
|
self.private_key_jwk: PyJWK = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def verify_push_notification_url(url: str) -> bool:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
try:
|
||||||
|
validation_token = str(uuid.uuid4())
|
||||||
|
response = await client.get(
|
||||||
|
url,
|
||||||
|
params={"validationToken": validation_token}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
is_verified = response.text == validation_token
|
||||||
|
|
||||||
|
logger.info(f"Verified push-notification URL: {url} => {is_verified}")
|
||||||
|
return is_verified
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during sending push-notification for URL {url}: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def generate_jwk(self):
|
||||||
|
key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig")
|
||||||
|
self.public_keys.append(key.export_public(as_dict=True))
|
||||||
|
self.private_key_jwk = PyJWK.from_json(key.export_private())
|
||||||
|
|
||||||
|
def handle_jwks_endpoint(self, _request: Request):
|
||||||
|
"""Allow clients to fetch public keys.
|
||||||
|
"""
|
||||||
|
return JSONResponse({
|
||||||
|
"keys": self.public_keys
|
||||||
|
})
|
||||||
|
|
||||||
|
def _generate_jwt(self, data: dict[str, Any]):
|
||||||
|
"""JWT is generated by signing both the request payload SHA digest and time of token generation.
|
||||||
|
|
||||||
|
Payload is signed with private key and it ensures the integrity of payload for client.
|
||||||
|
Including iat prevents from replay attack.
|
||||||
|
"""
|
||||||
|
|
||||||
|
iat = int(time.time())
|
||||||
|
|
||||||
|
return jwt.encode(
|
||||||
|
{"iat": iat, "request_body_sha256": self._calculate_request_body_sha256(data)},
|
||||||
|
key=self.private_key_jwk,
|
||||||
|
headers={"kid": self.private_key_jwk.key_id},
|
||||||
|
algorithm="RS256"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_push_notification(self, url: str, data: dict[str, Any]):
|
||||||
|
jwt_token = self._generate_jwt(data)
|
||||||
|
headers = {'Authorization': f"Bearer {jwt_token}"}
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
json=data,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.info(f"Push-notification sent for URL: {url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during sending push-notification for URL {url}: {e}")
|
||||||
|
|
||||||
|
class PushNotificationReceiverAuth(PushNotificationAuth):
|
||||||
|
def __init__(self):
|
||||||
|
self.public_keys_jwks = []
|
||||||
|
self.jwks_client = None
|
||||||
|
|
||||||
|
async def load_jwks(self, jwks_url: str):
|
||||||
|
self.jwks_client = PyJWKClient(jwks_url)
|
||||||
|
|
||||||
|
async def verify_push_notification(self, request: Request) -> bool:
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
|
||||||
|
print("Invalid authorization header")
|
||||||
|
return False
|
||||||
|
|
||||||
|
token = auth_header[len(AUTH_HEADER_PREFIX):]
|
||||||
|
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
|
||||||
|
|
||||||
|
decode_token = jwt.decode(
|
||||||
|
token,
|
||||||
|
signing_key,
|
||||||
|
options={"require": ["iat", "request_body_sha256"]},
|
||||||
|
algorithms=["RS256"],
|
||||||
|
)
|
||||||
|
|
||||||
|
actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
|
||||||
|
if actual_body_sha256 != decode_token["request_body_sha256"]:
|
||||||
|
# Payload signature does not match the digest in signed token.
|
||||||
|
raise ValueError("Invalid request body")
|
||||||
|
|
||||||
|
if time.time() - decode_token["iat"] > 60 * 5:
|
||||||
|
# Do not allow push-notifications older than 5 minutes.
|
||||||
|
# This is to prevent replay attack.
|
||||||
|
raise ValueError("Token is expired")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
from pocketflow import Flow
|
||||||
|
from nodes import DecideAction, SearchWeb, AnswerQuestion
|
||||||
|
|
||||||
|
def create_agent_flow():
|
||||||
|
"""
|
||||||
|
Create and connect the nodes to form a complete agent flow.
|
||||||
|
|
||||||
|
The flow works like this:
|
||||||
|
1. DecideAction node decides whether to search or answer
|
||||||
|
2. If search, go to SearchWeb node
|
||||||
|
3. If answer, go to AnswerQuestion node
|
||||||
|
4. After SearchWeb completes, go back to DecideAction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Flow: A complete research agent flow
|
||||||
|
"""
|
||||||
|
# Create instances of each node
|
||||||
|
decide = DecideAction()
|
||||||
|
search = SearchWeb()
|
||||||
|
answer = AnswerQuestion()
|
||||||
|
|
||||||
|
# Connect the nodes
|
||||||
|
# If DecideAction returns "search", go to SearchWeb
|
||||||
|
decide - "search" >> search
|
||||||
|
|
||||||
|
# If DecideAction returns "answer", go to AnswerQuestion
|
||||||
|
decide - "answer" >> answer
|
||||||
|
|
||||||
|
# After SearchWeb completes and returns "decide", go back to DecideAction
|
||||||
|
search - "decide" >> decide
|
||||||
|
|
||||||
|
# Create and return the flow, starting with the DecideAction node
|
||||||
|
return Flow(start=decide)
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
import sys
|
||||||
|
from flow import create_agent_flow
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Simple function to process a question."""
|
||||||
|
# Default question
|
||||||
|
default_question = "Who won the Nobel Prize in Physics 2024?"
|
||||||
|
|
||||||
|
# Get question from command line if provided with --
|
||||||
|
question = default_question
|
||||||
|
for arg in sys.argv[1:]:
|
||||||
|
if arg.startswith("--"):
|
||||||
|
question = arg[2:]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create the agent flow
|
||||||
|
agent_flow = create_agent_flow()
|
||||||
|
|
||||||
|
# Process the question
|
||||||
|
shared = {"question": question}
|
||||||
|
print(f"🤔 Processing question: {question}")
|
||||||
|
agent_flow.run(shared)
|
||||||
|
print("\n🎯 Final Answer:")
|
||||||
|
print(shared.get("answer", "No answer found"))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,135 @@
|
||||||
|
from pocketflow import Node
|
||||||
|
from utils import call_llm, search_web
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
class DecideAction(Node):
|
||||||
|
def prep(self, shared):
|
||||||
|
"""Prepare the context and question for the decision-making process."""
|
||||||
|
# Get the current context (default to "No previous search" if none exists)
|
||||||
|
context = shared.get("context", "No previous search")
|
||||||
|
# Get the question from the shared store
|
||||||
|
question = shared["question"]
|
||||||
|
# Return both for the exec step
|
||||||
|
return question, context
|
||||||
|
|
||||||
|
def exec(self, inputs):
|
||||||
|
"""Call the LLM to decide whether to search or answer."""
|
||||||
|
question, context = inputs
|
||||||
|
|
||||||
|
print(f"🤔 Agent deciding what to do next...")
|
||||||
|
|
||||||
|
# Create a prompt to help the LLM decide what to do next with proper yaml formatting
|
||||||
|
prompt = f"""
|
||||||
|
### CONTEXT
|
||||||
|
You are a research assistant that can search the web.
|
||||||
|
Question: {question}
|
||||||
|
Previous Research: {context}
|
||||||
|
|
||||||
|
### ACTION SPACE
|
||||||
|
[1] search
|
||||||
|
Description: Look up more information on the web
|
||||||
|
Parameters:
|
||||||
|
- query (str): What to search for
|
||||||
|
|
||||||
|
[2] answer
|
||||||
|
Description: Answer the question with current knowledge
|
||||||
|
Parameters:
|
||||||
|
- answer (str): Final answer to the question
|
||||||
|
|
||||||
|
## NEXT ACTION
|
||||||
|
Decide the next action based on the context and available actions.
|
||||||
|
Return your response in this format:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
thinking: |
|
||||||
|
<your step-by-step reasoning process>
|
||||||
|
action: search OR answer
|
||||||
|
reason: <why you chose this action>
|
||||||
|
answer: <if action is answer>
|
||||||
|
search_query: <specific search query if action is search>
|
||||||
|
```
|
||||||
|
IMPORTANT: Make sure to:
|
||||||
|
1. Use proper indentation (4 spaces) for all multi-line fields
|
||||||
|
2. Use the | character for multi-line text fields
|
||||||
|
3. Keep single-line fields without the | character
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Call the LLM to make a decision
|
||||||
|
response = call_llm(prompt)
|
||||||
|
|
||||||
|
# Parse the response to get the decision
|
||||||
|
yaml_str = response.split("```yaml")[1].split("```")[0].strip()
|
||||||
|
decision = yaml.safe_load(yaml_str)
|
||||||
|
|
||||||
|
return decision
|
||||||
|
|
||||||
|
def post(self, shared, prep_res, exec_res):
|
||||||
|
"""Save the decision and determine the next step in the flow."""
|
||||||
|
# If LLM decided to search, save the search query
|
||||||
|
if exec_res["action"] == "search":
|
||||||
|
shared["search_query"] = exec_res["search_query"]
|
||||||
|
print(f"🔍 Agent decided to search for: {exec_res['search_query']}")
|
||||||
|
else:
|
||||||
|
shared["context"] = exec_res["answer"] #save the context if LLM gives the answer without searching.
|
||||||
|
print(f"💡 Agent decided to answer the question")
|
||||||
|
|
||||||
|
# Return the action to determine the next node in the flow
|
||||||
|
return exec_res["action"]
|
||||||
|
|
||||||
|
class SearchWeb(Node):
|
||||||
|
def prep(self, shared):
|
||||||
|
"""Get the search query from the shared store."""
|
||||||
|
return shared["search_query"]
|
||||||
|
|
||||||
|
def exec(self, search_query):
|
||||||
|
"""Search the web for the given query."""
|
||||||
|
# Call the search utility function
|
||||||
|
print(f"🌐 Searching the web for: {search_query}")
|
||||||
|
results = search_web(search_query)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def post(self, shared, prep_res, exec_res):
|
||||||
|
"""Save the search results and go back to the decision node."""
|
||||||
|
# Add the search results to the context in the shared store
|
||||||
|
previous = shared.get("context", "")
|
||||||
|
shared["context"] = previous + "\n\nSEARCH: " + shared["search_query"] + "\nRESULTS: " + exec_res
|
||||||
|
|
||||||
|
print(f"📚 Found information, analyzing results...")
|
||||||
|
|
||||||
|
# Always go back to the decision node after searching
|
||||||
|
return "decide"
|
||||||
|
|
||||||
|
class AnswerQuestion(Node):
|
||||||
|
def prep(self, shared):
|
||||||
|
"""Get the question and context for answering."""
|
||||||
|
return shared["question"], shared.get("context", "")
|
||||||
|
|
||||||
|
def exec(self, inputs):
|
||||||
|
"""Call the LLM to generate a final answer."""
|
||||||
|
question, context = inputs
|
||||||
|
|
||||||
|
print(f"✍️ Crafting final answer...")
|
||||||
|
|
||||||
|
# Create a prompt for the LLM to answer the question
|
||||||
|
prompt = f"""
|
||||||
|
### CONTEXT
|
||||||
|
Based on the following information, answer the question.
|
||||||
|
Question: {question}
|
||||||
|
Research: {context}
|
||||||
|
|
||||||
|
## YOUR ANSWER:
|
||||||
|
Provide a comprehensive answer using the research results.
|
||||||
|
"""
|
||||||
|
# Call the LLM to generate an answer
|
||||||
|
answer = call_llm(prompt)
|
||||||
|
return answer
|
||||||
|
|
||||||
|
def post(self, shared, prep_res, exec_res):
|
||||||
|
"""Save the final answer and complete the flow."""
|
||||||
|
# Save the answer in the shared store
|
||||||
|
shared["answer"] = exec_res
|
||||||
|
|
||||||
|
print(f"✅ Answer generated successfully")
|
||||||
|
|
||||||
|
# We're done - no need to continue the flow
|
||||||
|
return "done"
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
# For PocketFlow Agent Logic
|
||||||
|
pocketflow>=0.0.1
|
||||||
|
openai>=1.0.0
|
||||||
|
duckduckgo-search>=7.5.2
|
||||||
|
pyyaml>=5.1
|
||||||
|
|
||||||
|
# For A2A Server Infrastructure (from common)
|
||||||
|
starlette>=0.37.2,<0.38.0
|
||||||
|
uvicorn[standard]>=0.29.0,<0.30.0
|
||||||
|
sse-starlette>=1.8.2,<2.0.0
|
||||||
|
pydantic>=2.0.0,<3.0.0
|
||||||
|
httpx>=0.27.0,<0.28.0
|
||||||
|
anyio>=3.0.0,<5.0.0 # Dependency of starlette/httpx
|
||||||
|
|
||||||
|
# For running __main__.py
|
||||||
|
click>=8.0.0,<9.0.0
|
||||||
|
|
||||||
|
# For A2A Client
|
||||||
|
httpx>=0.27.0,<0.28.0
|
||||||
|
asyncclick>=8.1.8 # Or just 'click' if you prefer asyncio.run
|
||||||
|
pydantic>=2.0.0,<3.0.0 # For common.types
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
# FILE: pocketflow_a2a_agent/task_manager.py
|
||||||
|
import logging
|
||||||
|
from typing import AsyncIterable, Union
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Import from the common code you copied
|
||||||
|
from common.server.task_manager import InMemoryTaskManager
|
||||||
|
from common.types import (
|
||||||
|
JSONRPCResponse, SendTaskRequest, SendTaskResponse,
|
||||||
|
SendTaskStreamingRequest, SendTaskStreamingResponse, Task, TaskSendParams,
|
||||||
|
TaskState, TaskStatus, TextPart, Artifact, UnsupportedOperationError,
|
||||||
|
InternalError, InvalidParamsError
|
||||||
|
)
|
||||||
|
import common.server.utils as server_utils
|
||||||
|
|
||||||
|
# Import directly from your original PocketFlow files
|
||||||
|
from .flow import create_agent_flow # Assumes flow.py is in the same directory
|
||||||
|
from .utils import call_llm, search_web # Make utils functions available if needed elsewhere
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class PocketFlowTaskManager(InMemoryTaskManager):
|
||||||
|
""" TaskManager implementation that runs the PocketFlow agent. """
|
||||||
|
|
||||||
|
SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] # Define what the agent accepts/outputs
|
||||||
|
|
||||||
|
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
|
||||||
|
"""Handles non-streaming task requests."""
|
||||||
|
logger.info(f"Received task send request: {request.params.id}")
|
||||||
|
|
||||||
|
# Validate output modes
|
||||||
|
if not server_utils.are_modalities_compatible(
|
||||||
|
request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Unsupported output mode. Received %s, Support %s",
|
||||||
|
request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES
|
||||||
|
)
|
||||||
|
return SendTaskResponse(id=request.id, error=server_utils.new_incompatible_types_error(request.id).error)
|
||||||
|
|
||||||
|
# Upsert the task in the store (initial state: submitted)
|
||||||
|
# We create the task first so its state can be tracked, even if the sync execution fails
|
||||||
|
await self.upsert_task(request.params)
|
||||||
|
# Update state to working before running
|
||||||
|
await self.update_store(request.params.id, TaskStatus(state=TaskState.WORKING), [])
|
||||||
|
|
||||||
|
|
||||||
|
# --- Run the PocketFlow logic ---
|
||||||
|
task_params: TaskSendParams = request.params
|
||||||
|
query = self._get_user_query(task_params)
|
||||||
|
if query is None:
|
||||||
|
fail_status = TaskStatus(state=TaskState.FAILED, message=Message(role="agent", parts=[TextPart(text="No text query found")]))
|
||||||
|
await self.update_store(task_params.id, fail_status, [])
|
||||||
|
return SendTaskResponse(id=request.id, error=InvalidParamsError(message="No text query found in message parts"))
|
||||||
|
|
||||||
|
shared_data = {"question": query}
|
||||||
|
agent_flow = create_agent_flow() # Create the flow instance
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run the synchronous PocketFlow
|
||||||
|
# In a real async server, you might run this in a separate thread/process
|
||||||
|
# executor to avoid blocking the event loop. For simplicity here, we run it directly.
|
||||||
|
# Consider adding a timeout if flows can hang.
|
||||||
|
logger.info(f"Running PocketFlow for task {task_params.id}...")
|
||||||
|
final_state_dict = agent_flow.run(shared_data)
|
||||||
|
logger.info(f"PocketFlow completed for task {task_params.id}")
|
||||||
|
answer_text = final_state_dict.get("answer", "Agent did not produce a final answer text.")
|
||||||
|
|
||||||
|
# --- Package result into A2A Task ---
|
||||||
|
final_task_status = TaskStatus(state=TaskState.COMPLETED)
|
||||||
|
# Package the answer as an artifact
|
||||||
|
final_artifact = Artifact(parts=[TextPart(text=answer_text)])
|
||||||
|
|
||||||
|
# Update the task in the store with final status and artifact
|
||||||
|
final_task = await self.update_store(
|
||||||
|
task_params.id, final_task_status, [final_artifact]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare and return the A2A response
|
||||||
|
task_result = self.append_task_history(final_task, task_params.historyLength)
|
||||||
|
return SendTaskResponse(id=request.id, result=task_result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing PocketFlow for task {task_params.id}: {e}", exc_info=True)
|
||||||
|
# Update task state to FAILED
|
||||||
|
fail_status = TaskStatus(
|
||||||
|
state=TaskState.FAILED,
|
||||||
|
message=Message(role="agent", parts=[TextPart(text=f"Agent execution failed: {e}")])
|
||||||
|
)
|
||||||
|
await self.update_store(task_params.id, fail_status, [])
|
||||||
|
return SendTaskResponse(id=request.id, error=InternalError(message=f"Agent error: {e}"))
|
||||||
|
|
||||||
|
async def on_send_task_subscribe(
|
||||||
|
self, request: SendTaskStreamingRequest
|
||||||
|
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||||
|
"""Handles streaming requests - Not implemented for this synchronous agent."""
|
||||||
|
logger.warning(f"Streaming requested for task {request.params.id}, but not supported by this PocketFlow agent implementation.")
|
||||||
|
# Return an error indicating streaming is not supported
|
||||||
|
return JSONRPCResponse(id=request.id, error=UnsupportedOperationError(message="Streaming not supported by this agent"))
|
||||||
|
|
||||||
|
def _get_user_query(self, task_send_params: TaskSendParams) -> str | None:
|
||||||
|
"""Extracts the first text part from the user message."""
|
||||||
|
if not task_send_params.message or not task_send_params.message.parts:
|
||||||
|
logger.warning(f"No message parts found for task {task_send_params.id}")
|
||||||
|
return None
|
||||||
|
for part in task_send_params.message.parts:
|
||||||
|
# Ensure part is treated as a dictionary if it came from JSON
|
||||||
|
part_dict = part if isinstance(part, dict) else part.model_dump()
|
||||||
|
if part_dict.get("type") == "text" and "text" in part_dict:
|
||||||
|
return part_dict["text"]
|
||||||
|
logger.warning(f"No text part found in message for task {task_send_params.id}")
|
||||||
|
return None # No text part found
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
from openai import OpenAI
|
||||||
|
import os
|
||||||
|
from duckduckgo_search import DDGS
|
||||||
|
|
||||||
|
def call_llm(prompt):
|
||||||
|
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
|
||||||
|
r = client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=[{"role": "user", "content": prompt}]
|
||||||
|
)
|
||||||
|
return r.choices[0].message.content
|
||||||
|
|
||||||
|
def search_web(query):
|
||||||
|
results = DDGS().text(query, max_results=5)
|
||||||
|
# Convert results to a string
|
||||||
|
results_str = "\n\n".join([f"Title: {r['title']}\nURL: {r['href']}\nSnippet: {r['body']}" for r in results])
|
||||||
|
return results_str
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("## Testing call_llm")
|
||||||
|
prompt = "In a few words, what is the meaning of life?"
|
||||||
|
print(f"## Prompt: {prompt}")
|
||||||
|
response = call_llm(prompt)
|
||||||
|
print(f"## Response: {response}")
|
||||||
|
|
||||||
|
print("## Testing search_web")
|
||||||
|
query = "Who won the Nobel Prize in Physics 2024?"
|
||||||
|
print(f"## Query: {query}")
|
||||||
|
results = search_web(query)
|
||||||
|
print(f"## Results: {results}")
|
||||||
Loading…
Reference in New Issue