This commit is contained in:
zachary62 2025-04-12 15:35:16 -04:00
parent 447a2d47d2
commit cd02d65efe
20 changed files with 1798 additions and 0 deletions

View File

@ -0,0 +1,67 @@
# Research Agent
This project demonstrates a simple yet powerful LLM-powered research agent. This implementation is based directly on the tutorial: [LLM Agents are simply Graph — Tutorial For Dummies](https://zacharyhuang.substack.com/p/llm-agent-internal-as-a-graph-tutorial).
👉 Run the tutorial in your browser: [Try Google Colab Notebook](
https://colab.research.google.com/github/The-Pocket/PocketFlow/blob/main/cookbook/pocketflow-agent/demo.ipynb)
## Features
- Performs web searches to gather information
- Makes decisions about when to search vs. when to answer
- Generates comprehensive answers based on research findings
## Getting Started
1. Install the packages you need with this simple command:
```bash
pip install -r requirements.txt
```
2. Let's get your OpenAI API key ready:
```bash
export OPENAI_API_KEY="your-api-key-here"
```
3. Let's do a quick check to make sure your API key is working properly:
```bash
python utils.py
```
This will test both the LLM call and web search features. If you see responses, you're good to go!
4. Try out the agent with the default question (about Nobel Prize winners):
```bash
python main.py
```
5. Got a burning question? Ask anything you want by using the `--` prefix:
```bash
python main.py --"What is quantum computing?"
```
## How It Works?
The magic happens through a simple but powerful graph structure with three main parts:
```mermaid
graph TD
A[DecideAction] -->|"search"| B[SearchWeb]
A -->|"answer"| C[AnswerQuestion]
B -->|"decide"| A
```
Here's what each part does:
1. **DecideAction**: The brain that figures out whether to search or answer
2. **SearchWeb**: The researcher that goes out and finds information
3. **AnswerQuestion**: The writer that crafts the final answer
Here's what's in each file:
- [`main.py`](./main.py): The starting point - runs the whole show!
- [`flow.py`](./flow.py): Connects everything together into a smart agent
- [`nodes.py`](./nodes.py): The building blocks that make decisions and take actions
- [`utils.py`](./utils.py): Helper functions for talking to the LLM and searching the web

View File

@ -0,0 +1,143 @@
# FILE: minimal_a2a_client.py
import asyncio
import asyncclick as click # Using asyncclick for async main
from uuid import uuid4
import json # For potentially inspecting raw errors
# Import from the common directory placed alongside this script
from common.client import A2AClient
from common.types import (
TaskState,
A2AClientError,
TextPart, # Used to construct the message
JSONRPCResponse # Potentially useful for error checking
)
# --- ANSI Colors (Optional but helpful) ---
C_RED = "\x1b[31m"
C_GREEN = "\x1b[32m"
C_YELLOW = "\x1b[33m"
C_BLUE = "\x1b[34m"
C_MAGENTA = "\x1b[35m"
C_CYAN = "\x1b[36m"
C_WHITE = "\x1b[37m"
C_GRAY = "\x1b[90m"
C_BRIGHT_MAGENTA = "\x1b[95m"
C_RESET = "\x1b[0m"
C_BOLD = "\x1b[1m"
def colorize(color, text):
return f"{color}{text}{C_RESET}"
@click.command()
@click.option(
"--agent-url",
default="http://localhost:10003", # Default to the port used in server __main__
help="URL of the PocketFlow A2A agent server.",
)
async def cli(agent_url: str):
"""Minimal CLI client to interact with an A2A agent."""
print(colorize(C_BRIGHT_MAGENTA, f"Connecting to agent at: {agent_url}"))
# Instantiate the client - only URL is needed if not fetching card first
# Note: The PocketFlow wrapper doesn't expose much via the AgentCard,
# so we skip fetching it for this minimal client.
client = A2AClient(url=agent_url)
sessionId = uuid4().hex # Generate a new session ID for this run
print(colorize(C_GRAY, f"Using Session ID: {sessionId}"))
while True:
taskId = uuid4().hex # Generate a new task ID for each interaction
try:
prompt = await click.prompt(
colorize(C_CYAN, "\nEnter your question (:q or quit to exit)"),
prompt_suffix=" > ",
type=str # Ensure prompt returns string
)
except RuntimeError:
# This can happen if stdin is closed, e.g., in some test runners
print(colorize(C_RED, "Failed to read input. Exiting."))
break
if prompt.lower() in [":q", "quit"]:
print(colorize(C_YELLOW, "Exiting client."))
break
# --- Construct A2A Request Payload ---
payload = {
"id": taskId,
"sessionId": sessionId,
"message": {
"role": "user",
"parts": [
{
"type": "text", # Explicitly match TextPart structure
"text": prompt,
}
],
},
"acceptedOutputModes": ["text", "text/plain"], # What the client wants back
# historyLength could be added if needed
}
print(colorize(C_GRAY, f"Sending task {taskId}..."))
try:
# --- Send Task (Non-Streaming) ---
response = await client.send_task(payload)
# --- Process Response ---
if response.error:
print(colorize(C_RED, f"Error from agent (Code: {response.error.code}): {response.error.message}"))
if response.error.data:
print(colorize(C_GRAY, f"Error Data: {response.error.data}"))
elif response.result:
task_result = response.result
print(colorize(C_GREEN, f"Task {task_result.id} finished with state: {task_result.status.state}"))
final_answer = "Agent did not provide a final artifact."
# Extract answer from artifacts (as implemented in PocketFlowTaskManager)
if task_result.artifacts:
try:
# Find the first text part in the first artifact
first_artifact = task_result.artifacts[0]
first_text_part = next(
(p for p in first_artifact.parts if isinstance(p, TextPart)),
None
)
if first_text_part:
final_answer = first_text_part.text
else:
final_answer = f"(Non-text artifact received: {first_artifact.parts})"
except (IndexError, AttributeError, TypeError) as e:
final_answer = f"(Error parsing artifact: {e})"
elif task_result.status.message and task_result.status.message.parts:
# Fallback to status message if no artifact
try:
first_text_part = next(
(p for p in task_result.status.message.parts if isinstance(p, TextPart)),
None
)
if first_text_part:
final_answer = f"(Final status message: {first_text_part.text})"
except (AttributeError, TypeError) as e:
final_answer = f"(Error parsing status message: {e})"
print(colorize(C_BOLD + C_WHITE, f"\nAgent Response:\n{final_answer}"))
else:
# Should not happen if error is None
print(colorize(C_YELLOW, "Received response with no result and no error."))
except A2AClientError as e:
print(colorize(C_RED, f"\nClient Error: {e}"))
except Exception as e:
print(colorize(C_RED, f"\nAn unexpected error occurred: {e}"))
if __name__ == "__main__":
asyncio.run(cli())

View File

@ -0,0 +1,81 @@
# FILE: pocketflow_a2a_agent/__main__.py
import click
import logging
import os
# Import from the common code you copied
from common.server import A2AServer
from common.types import AgentCard, AgentCapabilities, AgentSkill, MissingAPIKeyError
# Import your custom TaskManager (which now imports from your original files)
from .task_manager import PocketFlowTaskManager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@click.command()
@click.option("--host", "host", default="localhost")
@click.option("--port", "port", default=10003) # Use a different port from other agents
def main(host, port):
"""Starts the PocketFlow A2A Agent server."""
try:
# Check for necessary API keys (add others if needed)
if not os.getenv("OPENAI_API_KEY"):
raise MissingAPIKeyError("OPENAI_API_KEY environment variable not set.")
# --- Define the Agent Card ---
capabilities = AgentCapabilities(
streaming=False, # This simple implementation is synchronous
pushNotifications=False,
stateTransitionHistory=False # PocketFlow state isn't exposed via A2A history
)
skill = AgentSkill(
id="web_research_qa",
name="Web Research and Answering",
description="Answers questions using web search results when necessary.",
tags=["research", "qa", "web search"],
examples=[
"Who won the Nobel Prize in Physics 2024?",
"What is quantum computing?",
"Summarize the latest news about AI.",
],
# Input/Output modes defined in the TaskManager
inputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
outputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
)
agent_card = AgentCard(
name="PocketFlow Research Agent (A2A Wrapped)",
description="A simple research agent based on PocketFlow, made accessible via A2A.",
url=f"http://{host}:{port}/", # The endpoint A2A clients will use
version="0.1.0-a2a",
capabilities=capabilities,
skills=[skill],
# Assuming no specific provider or auth for this example
provider=None,
authentication=None,
defaultInputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
defaultOutputModes=PocketFlowTaskManager.SUPPORTED_CONTENT_TYPES,
)
# --- Initialize and Start Server ---
task_manager = PocketFlowTaskManager() # Instantiate your custom manager
server = A2AServer(
agent_card=agent_card,
task_manager=task_manager,
host=host,
port=port,
)
logger.info(f"Starting PocketFlow A2A server on http://{host}:{port}")
server.start()
except MissingAPIKeyError as e:
logger.error(f"Configuration Error: {e}")
exit(1)
except Exception as e:
logger.error(f"An error occurred during server startup: {e}", exc_info=True)
exit(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,4 @@
from .client import A2AClient
from .card_resolver import A2ACardResolver
__all__ = ["A2AClient", "A2ACardResolver"]

View File

@ -0,0 +1,21 @@
import httpx
from common.types import (
AgentCard,
A2AClientJSONError,
)
import json
class A2ACardResolver:
def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
self.base_url = base_url.rstrip("/")
self.agent_card_path = agent_card_path.lstrip("/")
def get_agent_card(self) -> AgentCard:
with httpx.Client() as client:
response = client.get(self.base_url + "/" + self.agent_card_path)
response.raise_for_status()
try:
return AgentCard(**response.json())
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e

View File

@ -0,0 +1,86 @@
import httpx
from httpx_sse import connect_sse
from typing import Any, AsyncIterable
from common.types import (
AgentCard,
GetTaskRequest,
SendTaskRequest,
SendTaskResponse,
JSONRPCRequest,
GetTaskResponse,
CancelTaskResponse,
CancelTaskRequest,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
A2AClientHTTPError,
A2AClientJSONError,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
)
import json
class A2AClient:
def __init__(self, agent_card: AgentCard = None, url: str = None):
if agent_card:
self.url = agent_card.url
elif url:
self.url = url
else:
raise ValueError("Must provide either agent_card or url")
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
request = SendTaskRequest(params=payload)
return SendTaskResponse(**await self._send_request(request))
async def send_task_streaming(
self, payload: dict[str, Any]
) -> AsyncIterable[SendTaskStreamingResponse]:
request = SendTaskStreamingRequest(params=payload)
with httpx.Client(timeout=None) as client:
with connect_sse(
client, "POST", self.url, json=request.model_dump()
) as event_source:
try:
for sse in event_source.iter_sse():
yield SendTaskStreamingResponse(**json.loads(sse.data))
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(400, str(e)) from e
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
async with httpx.AsyncClient() as client:
try:
# Image generation could take time, adding timeout
response = await client.post(
self.url, json=request.model_dump(), timeout=30
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
request = GetTaskRequest(params=payload)
return GetTaskResponse(**await self._send_request(request))
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
request = CancelTaskRequest(params=payload)
return CancelTaskResponse(**await self._send_request(request))
async def set_task_callback(
self, payload: dict[str, Any]
) -> SetTaskPushNotificationResponse:
request = SetTaskPushNotificationRequest(params=payload)
return SetTaskPushNotificationResponse(**await self._send_request(request))
async def get_task_callback(
self, payload: dict[str, Any]
) -> GetTaskPushNotificationResponse:
request = GetTaskPushNotificationRequest(params=payload)
return GetTaskPushNotificationResponse(**await self._send_request(request))

View File

@ -0,0 +1,4 @@
from .server import A2AServer
from .task_manager import TaskManager, InMemoryTaskManager
__all__ = ["A2AServer", "TaskManager", "InMemoryTaskManager"]

View File

@ -0,0 +1,120 @@
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from common.types import (
A2ARequest,
JSONRPCResponse,
InvalidRequestError,
JSONParseError,
GetTaskRequest,
CancelTaskRequest,
SendTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
InternalError,
AgentCard,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
)
from pydantic import ValidationError
import json
from typing import AsyncIterable, Any
from common.server.task_manager import TaskManager
import logging
logger = logging.getLogger(__name__)
class A2AServer:
def __init__(
self,
host="0.0.0.0",
port=5000,
endpoint="/",
agent_card: AgentCard = None,
task_manager: TaskManager = None,
):
self.host = host
self.port = port
self.endpoint = endpoint
self.task_manager = task_manager
self.agent_card = agent_card
self.app = Starlette()
self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
self.app.add_route(
"/.well-known/agent.json", self._get_agent_card, methods=["GET"]
)
def start(self):
if self.agent_card is None:
raise ValueError("agent_card is not defined")
if self.task_manager is None:
raise ValueError("request_handler is not defined")
import uvicorn
uvicorn.run(self.app, host=self.host, port=self.port)
def _get_agent_card(self, request: Request) -> JSONResponse:
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
async def _process_request(self, request: Request):
try:
body = await request.json()
json_rpc_request = A2ARequest.validate_python(body)
if isinstance(json_rpc_request, GetTaskRequest):
result = await self.task_manager.on_get_task(json_rpc_request)
elif isinstance(json_rpc_request, SendTaskRequest):
result = await self.task_manager.on_send_task(json_rpc_request)
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
result = await self.task_manager.on_send_task_subscribe(
json_rpc_request
)
elif isinstance(json_rpc_request, CancelTaskRequest):
result = await self.task_manager.on_cancel_task(json_rpc_request)
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
result = await self.task_manager.on_resubscribe_to_task(
json_rpc_request
)
else:
logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
raise ValueError(f"Unexpected request type: {type(request)}")
return self._create_response(result)
except Exception as e:
return self._handle_exception(e)
def _handle_exception(self, e: Exception) -> JSONResponse:
if isinstance(e, json.decoder.JSONDecodeError):
json_rpc_error = JSONParseError()
elif isinstance(e, ValidationError):
json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
else:
logger.error(f"Unhandled exception: {e}")
json_rpc_error = InternalError()
response = JSONRPCResponse(id=None, error=json_rpc_error)
return JSONResponse(response.model_dump(exclude_none=True), status_code=400)
def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
if isinstance(result, AsyncIterable):
async def event_generator(result) -> AsyncIterable[dict[str, str]]:
async for item in result:
yield {"data": item.model_dump_json(exclude_none=True)}
return EventSourceResponse(event_generator(result))
elif isinstance(result, JSONRPCResponse):
return JSONResponse(result.model_dump(exclude_none=True))
else:
logger.error(f"Unexpected result type: {type(result)}")
raise ValueError(f"Unexpected result type: {type(result)}")

View File

@ -0,0 +1,277 @@
from abc import ABC, abstractmethod
from typing import Union, AsyncIterable, List
from common.types import Task
from common.types import (
JSONRPCResponse,
TaskIdParams,
TaskQueryParams,
GetTaskRequest,
TaskNotFoundError,
SendTaskRequest,
CancelTaskRequest,
TaskNotCancelableError,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
GetTaskResponse,
CancelTaskResponse,
SendTaskResponse,
SetTaskPushNotificationResponse,
GetTaskPushNotificationResponse,
PushNotificationNotSupportedError,
TaskSendParams,
TaskStatus,
TaskState,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
Artifact,
PushNotificationConfig,
TaskStatusUpdateEvent,
JSONRPCError,
TaskPushNotificationConfig,
InternalError,
)
from common.server.utils import new_not_implemented_error
import asyncio
import logging
logger = logging.getLogger(__name__)
class TaskManager(ABC):
@abstractmethod
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
pass
@abstractmethod
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
pass
@abstractmethod
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
pass
@abstractmethod
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
pass
@abstractmethod
async def on_set_task_push_notification(
self, request: SetTaskPushNotificationRequest
) -> SetTaskPushNotificationResponse:
pass
@abstractmethod
async def on_get_task_push_notification(
self, request: GetTaskPushNotificationRequest
) -> GetTaskPushNotificationResponse:
pass
@abstractmethod
async def on_resubscribe_to_task(
self, request: TaskResubscriptionRequest
) -> Union[AsyncIterable[SendTaskResponse], JSONRPCResponse]:
pass
class InMemoryTaskManager(TaskManager):
def __init__(self):
self.tasks: dict[str, Task] = {}
self.push_notification_infos: dict[str, PushNotificationConfig] = {}
self.lock = asyncio.Lock()
self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
self.subscriber_lock = asyncio.Lock()
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
logger.info(f"Getting task {request.params.id}")
task_query_params: TaskQueryParams = request.params
async with self.lock:
task = self.tasks.get(task_query_params.id)
if task is None:
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
task_result = self.append_task_history(
task, task_query_params.historyLength
)
return GetTaskResponse(id=request.id, result=task_result)
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
logger.info(f"Cancelling task {request.params.id}")
task_id_params: TaskIdParams = request.params
async with self.lock:
task = self.tasks.get(task_id_params.id)
if task is None:
return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
@abstractmethod
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
pass
@abstractmethod
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
pass
async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig):
async with self.lock:
task = self.tasks.get(task_id)
if task is None:
raise ValueError(f"Task not found for {task_id}")
self.push_notification_infos[task_id] = notification_config
return
async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
async with self.lock:
task = self.tasks.get(task_id)
if task is None:
raise ValueError(f"Task not found for {task_id}")
return self.push_notification_infos[task_id]
return
async def has_push_notification_info(self, task_id: str) -> bool:
async with self.lock:
return task_id in self.push_notification_infos
async def on_set_task_push_notification(
self, request: SetTaskPushNotificationRequest
) -> SetTaskPushNotificationResponse:
logger.info(f"Setting task push notification {request.params.id}")
task_notification_params: TaskPushNotificationConfig = request.params
try:
await self.set_push_notification_info(task_notification_params.id, task_notification_params.pushNotificationConfig)
except Exception as e:
logger.error(f"Error while setting push notification info: {e}")
return JSONRPCResponse(
id=request.id,
error=InternalError(
message="An error occurred while setting push notification info"
),
)
return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)
async def on_get_task_push_notification(
self, request: GetTaskPushNotificationRequest
) -> GetTaskPushNotificationResponse:
logger.info(f"Getting task push notification {request.params.id}")
task_params: TaskIdParams = request.params
try:
notification_info = await self.get_push_notification_info(task_params.id)
except Exception as e:
logger.error(f"Error while getting push notification info: {e}")
return GetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message="An error occurred while getting push notification info"
),
)
return GetTaskPushNotificationResponse(id=request.id, result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info))
async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
logger.info(f"Upserting task {task_send_params.id}")
async with self.lock:
task = self.tasks.get(task_send_params.id)
if task is None:
task = Task(
id=task_send_params.id,
sessionId = task_send_params.sessionId,
messages=[task_send_params.message],
status=TaskStatus(state=TaskState.SUBMITTED),
history=[task_send_params.message],
)
self.tasks[task_send_params.id] = task
else:
task.history.append(task_send_params.message)
return task
async def on_resubscribe_to_task(
self, request: TaskResubscriptionRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
return new_not_implemented_error(request.id)
async def update_store(
self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
) -> Task:
async with self.lock:
try:
task = self.tasks[task_id]
except KeyError:
logger.error(f"Task {task_id} not found for updating the task")
raise ValueError(f"Task {task_id} not found")
task.status = status
if status.message is not None:
task.history.append(status.message)
if artifacts is not None:
if task.artifacts is None:
task.artifacts = []
task.artifacts.extend(artifacts)
return task
def append_task_history(self, task: Task, historyLength: int | None):
new_task = task.model_copy()
if historyLength is not None and historyLength > 0:
new_task.history = new_task.history[-historyLength:]
else:
new_task.history = []
return new_task
async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
async with self.subscriber_lock:
if task_id not in self.task_sse_subscribers:
if is_resubscribe:
raise ValueError("Task not found for resubscription")
else:
self.task_sse_subscribers[task_id] = []
sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited
self.task_sse_subscribers[task_id].append(sse_event_queue)
return sse_event_queue
async def enqueue_events_for_sse(self, task_id, task_update_event):
async with self.subscriber_lock:
if task_id not in self.task_sse_subscribers:
return
current_subscribers = self.task_sse_subscribers[task_id]
for subscriber in current_subscribers:
await subscriber.put(task_update_event)
async def dequeue_events_for_sse(
self, request_id, task_id, sse_event_queue: asyncio.Queue
) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
try:
while True:
event = await sse_event_queue.get()
if isinstance(event, JSONRPCError):
yield SendTaskStreamingResponse(id=request_id, error=event)
break
yield SendTaskStreamingResponse(id=request_id, result=event)
if isinstance(event, TaskStatusUpdateEvent) and event.final:
break
finally:
async with self.subscriber_lock:
if task_id in self.task_sse_subscribers:
self.task_sse_subscribers[task_id].remove(sse_event_queue)

View File

@ -0,0 +1,28 @@
from common.types import (
JSONRPCResponse,
ContentTypeNotSupportedError,
UnsupportedOperationError,
)
from typing import List
def are_modalities_compatible(
server_output_modes: List[str], client_output_modes: List[str]
):
"""Modalities are compatible if they are both non-empty
and there is at least one common element."""
if client_output_modes is None or len(client_output_modes) == 0:
return True
if server_output_modes is None or len(server_output_modes) == 0:
return True
return any(x in server_output_modes for x in client_output_modes)
def new_incompatible_types_error(request_id):
return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())
def new_not_implemented_error(request_id):
return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())

View File

@ -0,0 +1,365 @@
from typing import Union, Any
from pydantic import BaseModel, Field, TypeAdapter
from typing import Literal, List, Annotated, Optional
from datetime import datetime
from pydantic import model_validator, ConfigDict, field_serializer
from uuid import uuid4
from enum import Enum
from typing_extensions import Self
class TaskState(str, Enum):
SUBMITTED = "submitted"
WORKING = "working"
INPUT_REQUIRED = "input-required"
COMPLETED = "completed"
CANCELED = "canceled"
FAILED = "failed"
UNKNOWN = "unknown"
class TextPart(BaseModel):
type: Literal["text"] = "text"
text: str
metadata: dict[str, Any] | None = None
class FileContent(BaseModel):
name: str | None = None
mimeType: str | None = None
bytes: str | None = None
uri: str | None = None
@model_validator(mode="after")
def check_content(self) -> Self:
if not (self.bytes or self.uri):
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
if self.bytes and self.uri:
raise ValueError(
"Only one of 'bytes' or 'uri' can be present in the file data"
)
return self
class FilePart(BaseModel):
type: Literal["file"] = "file"
file: FileContent
metadata: dict[str, Any] | None = None
class DataPart(BaseModel):
type: Literal["data"] = "data"
data: dict[str, Any]
metadata: dict[str, Any] | None = None
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
class Message(BaseModel):
role: Literal["user", "agent"]
parts: List[Part]
metadata: dict[str, Any] | None = None
class TaskStatus(BaseModel):
state: TaskState
message: Message | None = None
timestamp: datetime = Field(default_factory=datetime.now)
@field_serializer("timestamp")
def serialize_dt(self, dt: datetime, _info):
return dt.isoformat()
class Artifact(BaseModel):
name: str | None = None
description: str | None = None
parts: List[Part]
metadata: dict[str, Any] | None = None
index: int = 0
append: bool | None = None
lastChunk: bool | None = None
class Task(BaseModel):
id: str
sessionId: str | None = None
status: TaskStatus
artifacts: List[Artifact] | None = None
history: List[Message] | None = None
metadata: dict[str, Any] | None = None
class TaskStatusUpdateEvent(BaseModel):
id: str
status: TaskStatus
final: bool = False
metadata: dict[str, Any] | None = None
class TaskArtifactUpdateEvent(BaseModel):
id: str
artifact: Artifact
metadata: dict[str, Any] | None = None
class AuthenticationInfo(BaseModel):
model_config = ConfigDict(extra="allow")
schemes: List[str]
credentials: str | None = None
class PushNotificationConfig(BaseModel):
url: str
token: str | None = None
authentication: AuthenticationInfo | None = None
class TaskIdParams(BaseModel):
id: str
metadata: dict[str, Any] | None = None
class TaskQueryParams(TaskIdParams):
historyLength: int | None = None
class TaskSendParams(BaseModel):
id: str
sessionId: str = Field(default_factory=lambda: uuid4().hex)
message: Message
acceptedOutputModes: Optional[List[str]] = None
pushNotification: PushNotificationConfig | None = None
historyLength: int | None = None
metadata: dict[str, Any] | None = None
class TaskPushNotificationConfig(BaseModel):
id: str
pushNotificationConfig: PushNotificationConfig
## RPC Messages
class JSONRPCMessage(BaseModel):
jsonrpc: Literal["2.0"] = "2.0"
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
class JSONRPCRequest(JSONRPCMessage):
method: str
params: dict[str, Any] | None = None
class JSONRPCError(BaseModel):
code: int
message: str
data: Any | None = None
class JSONRPCResponse(JSONRPCMessage):
result: Any | None = None
error: JSONRPCError | None = None
class SendTaskRequest(JSONRPCRequest):
method: Literal["tasks/send"] = "tasks/send"
params: TaskSendParams
class SendTaskResponse(JSONRPCResponse):
result: Task | None = None
class SendTaskStreamingRequest(JSONRPCRequest):
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
params: TaskSendParams
class SendTaskStreamingResponse(JSONRPCResponse):
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
class GetTaskRequest(JSONRPCRequest):
method: Literal["tasks/get"] = "tasks/get"
params: TaskQueryParams
class GetTaskResponse(JSONRPCResponse):
result: Task | None = None
class CancelTaskRequest(JSONRPCRequest):
method: Literal["tasks/cancel",] = "tasks/cancel"
params: TaskIdParams
class CancelTaskResponse(JSONRPCResponse):
result: Task | None = None
class SetTaskPushNotificationRequest(JSONRPCRequest):
method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
params: TaskPushNotificationConfig
class SetTaskPushNotificationResponse(JSONRPCResponse):
result: TaskPushNotificationConfig | None = None
class GetTaskPushNotificationRequest(JSONRPCRequest):
method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
params: TaskIdParams
class GetTaskPushNotificationResponse(JSONRPCResponse):
result: TaskPushNotificationConfig | None = None
class TaskResubscriptionRequest(JSONRPCRequest):
method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
params: TaskIdParams
A2ARequest = TypeAdapter(
Annotated[
Union[
SendTaskRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
],
Field(discriminator="method"),
]
)
## Error types
class JSONParseError(JSONRPCError):
code: int = -32700
message: str = "Invalid JSON payload"
data: Any | None = None
class InvalidRequestError(JSONRPCError):
code: int = -32600
message: str = "Request payload validation error"
data: Any | None = None
class MethodNotFoundError(JSONRPCError):
code: int = -32601
message: str = "Method not found"
data: None = None
class InvalidParamsError(JSONRPCError):
code: int = -32602
message: str = "Invalid parameters"
data: Any | None = None
class InternalError(JSONRPCError):
code: int = -32603
message: str = "Internal error"
data: Any | None = None
class TaskNotFoundError(JSONRPCError):
code: int = -32001
message: str = "Task not found"
data: None = None
class TaskNotCancelableError(JSONRPCError):
code: int = -32002
message: str = "Task cannot be canceled"
data: None = None
class PushNotificationNotSupportedError(JSONRPCError):
code: int = -32003
message: str = "Push Notification is not supported"
data: None = None
class UnsupportedOperationError(JSONRPCError):
code: int = -32004
message: str = "This operation is not supported"
data: None = None
class ContentTypeNotSupportedError(JSONRPCError):
code: int = -32005
message: str = "Incompatible content types"
data: None = None
class AgentProvider(BaseModel):
organization: str
url: str | None = None
class AgentCapabilities(BaseModel):
streaming: bool = False
pushNotifications: bool = False
stateTransitionHistory: bool = False
class AgentAuthentication(BaseModel):
schemes: List[str]
credentials: str | None = None
class AgentSkill(BaseModel):
id: str
name: str
description: str | None = None
tags: List[str] | None = None
examples: List[str] | None = None
inputModes: List[str] | None = None
outputModes: List[str] | None = None
class AgentCard(BaseModel):
name: str
description: str | None = None
url: str
provider: AgentProvider | None = None
version: str
documentationUrl: str | None = None
capabilities: AgentCapabilities
authentication: AgentAuthentication | None = None
defaultInputModes: List[str] = ["text"]
defaultOutputModes: List[str] = ["text"]
skills: List[AgentSkill]
class A2AClientError(Exception):
pass
class A2AClientHTTPError(A2AClientError):
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f"HTTP Error {status_code}: {message}")
class A2AClientJSONError(A2AClientError):
def __init__(self, message: str):
self.message = message
super().__init__(f"JSON Error: {message}")
class MissingAPIKeyError(Exception):
"""Exception for missing API key."""
pass

View File

@ -0,0 +1,109 @@
"""In Memory Cache utility."""
import threading
import time
from typing import Any, Dict, Optional
class InMemoryCache:
"""A thread-safe Singleton class to manage cache data.
Ensures only one instance of the cache exists across the application.
"""
_instance: Optional["InMemoryCache"] = None
_lock: threading.Lock = threading.Lock()
_initialized: bool = False
def __new__(cls):
"""Override __new__ to control instance creation (Singleton pattern).
Uses a lock to ensure thread safety during the first instantiation.
Returns:
The singleton instance of InMemoryCache.
"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""Initialize the cache storage.
Uses a flag (_initialized) to ensure this logic runs only on the very first
creation of the singleton instance.
"""
if not self._initialized:
with self._lock:
if not self._initialized:
# print("Initializing SessionCache storage")
self._cache_data: Dict[str, Dict[str, Any]] = {}
self._ttl: Dict[str, float] = {}
self._data_lock: threading.Lock = threading.Lock()
self._initialized = True
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""Set a key-value pair.
Args:
key: The key for the data.
value: The data to store.
ttl: Time to live in seconds. If None, data will not expire.
"""
with self._data_lock:
self._cache_data[key] = value
if ttl is not None:
self._ttl[key] = time.time() + ttl
else:
if key in self._ttl:
del self._ttl[key]
def get(self, key: str, default: Any = None) -> Any:
"""Get the value associated with a key.
Args:
key: The key for the data within the session.
default: The value to return if the session or key is not found.
Returns:
The cached value, or the default value if not found.
"""
with self._data_lock:
if key in self._ttl and time.time() > self._ttl[key]:
del self._cache_data[key]
del self._ttl[key]
return default
return self._cache_data.get(key, default)
def delete(self, key: str) -> None:
"""Delete a specific key-value pair from a cache.
Args:
key: The key to delete.
Returns:
True if the key was found and deleted, False otherwise.
"""
with self._data_lock:
if key in self._cache_data:
del self._cache_data[key]
if key in self._ttl:
del self._ttl[key]
return True
return False
def clear(self) -> bool:
"""Remove all data.
Returns:
True if the data was cleared, False otherwise.
"""
with self._data_lock:
self._cache_data.clear()
self._ttl.clear()
return True
return False

View File

@ -0,0 +1,135 @@
from jwcrypto import jwk
import uuid
from starlette.responses import JSONResponse
from starlette.requests import Request
from typing import Any
import jwt
import time
import json
import hashlib
import httpx
import logging
from jwt import PyJWK, PyJWKClient
logger = logging.getLogger(__name__)
AUTH_HEADER_PREFIX = 'Bearer '
class PushNotificationAuth:
def _calculate_request_body_sha256(self, data: dict[str, Any]):
"""Calculates the SHA256 hash of a request body.
This logic needs to be same for both the agent who signs the payload and the client verifier.
"""
body_str = json.dumps(
data,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
)
return hashlib.sha256(body_str.encode()).hexdigest()
class PushNotificationSenderAuth(PushNotificationAuth):
def __init__(self):
self.public_keys = []
self.private_key_jwk: PyJWK = None
@staticmethod
async def verify_push_notification_url(url: str) -> bool:
async with httpx.AsyncClient(timeout=10) as client:
try:
validation_token = str(uuid.uuid4())
response = await client.get(
url,
params={"validationToken": validation_token}
)
response.raise_for_status()
is_verified = response.text == validation_token
logger.info(f"Verified push-notification URL: {url} => {is_verified}")
return is_verified
except Exception as e:
logger.warning(f"Error during sending push-notification for URL {url}: {e}")
return False
def generate_jwk(self):
key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig")
self.public_keys.append(key.export_public(as_dict=True))
self.private_key_jwk = PyJWK.from_json(key.export_private())
def handle_jwks_endpoint(self, _request: Request):
"""Allow clients to fetch public keys.
"""
return JSONResponse({
"keys": self.public_keys
})
def _generate_jwt(self, data: dict[str, Any]):
"""JWT is generated by signing both the request payload SHA digest and time of token generation.
Payload is signed with private key and it ensures the integrity of payload for client.
Including iat prevents from replay attack.
"""
iat = int(time.time())
return jwt.encode(
{"iat": iat, "request_body_sha256": self._calculate_request_body_sha256(data)},
key=self.private_key_jwk,
headers={"kid": self.private_key_jwk.key_id},
algorithm="RS256"
)
async def send_push_notification(self, url: str, data: dict[str, Any]):
jwt_token = self._generate_jwt(data)
headers = {'Authorization': f"Bearer {jwt_token}"}
async with httpx.AsyncClient(timeout=10) as client:
try:
response = await client.post(
url,
json=data,
headers=headers
)
response.raise_for_status()
logger.info(f"Push-notification sent for URL: {url}")
except Exception as e:
logger.warning(f"Error during sending push-notification for URL {url}: {e}")
class PushNotificationReceiverAuth(PushNotificationAuth):
def __init__(self):
self.public_keys_jwks = []
self.jwks_client = None
async def load_jwks(self, jwks_url: str):
self.jwks_client = PyJWKClient(jwks_url)
async def verify_push_notification(self, request: Request) -> bool:
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
print("Invalid authorization header")
return False
token = auth_header[len(AUTH_HEADER_PREFIX):]
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
decode_token = jwt.decode(
token,
signing_key,
options={"require": ["iat", "request_body_sha256"]},
algorithms=["RS256"],
)
actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
if actual_body_sha256 != decode_token["request_body_sha256"]:
# Payload signature does not match the digest in signed token.
raise ValueError("Invalid request body")
if time.time() - decode_token["iat"] > 60 * 5:
# Do not allow push-notifications older than 5 minutes.
# This is to prevent replay attack.
raise ValueError("Token is expired")
return True

View File

@ -0,0 +1,33 @@
from pocketflow import Flow
from nodes import DecideAction, SearchWeb, AnswerQuestion
def create_agent_flow():
"""
Create and connect the nodes to form a complete agent flow.
The flow works like this:
1. DecideAction node decides whether to search or answer
2. If search, go to SearchWeb node
3. If answer, go to AnswerQuestion node
4. After SearchWeb completes, go back to DecideAction
Returns:
Flow: A complete research agent flow
"""
# Create instances of each node
decide = DecideAction()
search = SearchWeb()
answer = AnswerQuestion()
# Connect the nodes
# If DecideAction returns "search", go to SearchWeb
decide - "search" >> search
# If DecideAction returns "answer", go to AnswerQuestion
decide - "answer" >> answer
# After SearchWeb completes and returns "decide", go back to DecideAction
search - "decide" >> decide
# Create and return the flow, starting with the DecideAction node
return Flow(start=decide)

View File

@ -0,0 +1,27 @@
import sys
from flow import create_agent_flow
def main():
"""Simple function to process a question."""
# Default question
default_question = "Who won the Nobel Prize in Physics 2024?"
# Get question from command line if provided with --
question = default_question
for arg in sys.argv[1:]:
if arg.startswith("--"):
question = arg[2:]
break
# Create the agent flow
agent_flow = create_agent_flow()
# Process the question
shared = {"question": question}
print(f"🤔 Processing question: {question}")
agent_flow.run(shared)
print("\n🎯 Final Answer:")
print(shared.get("answer", "No answer found"))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,135 @@
from pocketflow import Node
from utils import call_llm, search_web
import yaml
class DecideAction(Node):
def prep(self, shared):
"""Prepare the context and question for the decision-making process."""
# Get the current context (default to "No previous search" if none exists)
context = shared.get("context", "No previous search")
# Get the question from the shared store
question = shared["question"]
# Return both for the exec step
return question, context
def exec(self, inputs):
"""Call the LLM to decide whether to search or answer."""
question, context = inputs
print(f"🤔 Agent deciding what to do next...")
# Create a prompt to help the LLM decide what to do next with proper yaml formatting
prompt = f"""
### CONTEXT
You are a research assistant that can search the web.
Question: {question}
Previous Research: {context}
### ACTION SPACE
[1] search
Description: Look up more information on the web
Parameters:
- query (str): What to search for
[2] answer
Description: Answer the question with current knowledge
Parameters:
- answer (str): Final answer to the question
## NEXT ACTION
Decide the next action based on the context and available actions.
Return your response in this format:
```yaml
thinking: |
<your step-by-step reasoning process>
action: search OR answer
reason: <why you chose this action>
answer: <if action is answer>
search_query: <specific search query if action is search>
```
IMPORTANT: Make sure to:
1. Use proper indentation (4 spaces) for all multi-line fields
2. Use the | character for multi-line text fields
3. Keep single-line fields without the | character
"""
# Call the LLM to make a decision
response = call_llm(prompt)
# Parse the response to get the decision
yaml_str = response.split("```yaml")[1].split("```")[0].strip()
decision = yaml.safe_load(yaml_str)
return decision
def post(self, shared, prep_res, exec_res):
"""Save the decision and determine the next step in the flow."""
# If LLM decided to search, save the search query
if exec_res["action"] == "search":
shared["search_query"] = exec_res["search_query"]
print(f"🔍 Agent decided to search for: {exec_res['search_query']}")
else:
shared["context"] = exec_res["answer"] #save the context if LLM gives the answer without searching.
print(f"💡 Agent decided to answer the question")
# Return the action to determine the next node in the flow
return exec_res["action"]
class SearchWeb(Node):
def prep(self, shared):
"""Get the search query from the shared store."""
return shared["search_query"]
def exec(self, search_query):
"""Search the web for the given query."""
# Call the search utility function
print(f"🌐 Searching the web for: {search_query}")
results = search_web(search_query)
return results
def post(self, shared, prep_res, exec_res):
"""Save the search results and go back to the decision node."""
# Add the search results to the context in the shared store
previous = shared.get("context", "")
shared["context"] = previous + "\n\nSEARCH: " + shared["search_query"] + "\nRESULTS: " + exec_res
print(f"📚 Found information, analyzing results...")
# Always go back to the decision node after searching
return "decide"
class AnswerQuestion(Node):
def prep(self, shared):
"""Get the question and context for answering."""
return shared["question"], shared.get("context", "")
def exec(self, inputs):
"""Call the LLM to generate a final answer."""
question, context = inputs
print(f"✍️ Crafting final answer...")
# Create a prompt for the LLM to answer the question
prompt = f"""
### CONTEXT
Based on the following information, answer the question.
Question: {question}
Research: {context}
## YOUR ANSWER:
Provide a comprehensive answer using the research results.
"""
# Call the LLM to generate an answer
answer = call_llm(prompt)
return answer
def post(self, shared, prep_res, exec_res):
"""Save the final answer and complete the flow."""
# Save the answer in the shared store
shared["answer"] = exec_res
print(f"✅ Answer generated successfully")
# We're done - no need to continue the flow
return "done"

View File

@ -0,0 +1,21 @@
# For PocketFlow Agent Logic
pocketflow>=0.0.1
openai>=1.0.0
duckduckgo-search>=7.5.2
pyyaml>=5.1
# For A2A Server Infrastructure (from common)
starlette>=0.37.2,<0.38.0
uvicorn[standard]>=0.29.0,<0.30.0
sse-starlette>=1.8.2,<2.0.0
pydantic>=2.0.0,<3.0.0
httpx>=0.27.0,<0.28.0
anyio>=3.0.0,<5.0.0 # Dependency of starlette/httpx
# For running __main__.py
click>=8.0.0,<9.0.0
# For A2A Client
httpx>=0.27.0,<0.28.0
asyncclick>=8.1.8 # Or just 'click' if you prefer asyncio.run
pydantic>=2.0.0,<3.0.0 # For common.types

View File

@ -0,0 +1,112 @@
# FILE: pocketflow_a2a_agent/task_manager.py
import logging
from typing import AsyncIterable, Union
import asyncio
# Import from the common code you copied
from common.server.task_manager import InMemoryTaskManager
from common.types import (
JSONRPCResponse, SendTaskRequest, SendTaskResponse,
SendTaskStreamingRequest, SendTaskStreamingResponse, Task, TaskSendParams,
TaskState, TaskStatus, TextPart, Artifact, UnsupportedOperationError,
InternalError, InvalidParamsError
)
import common.server.utils as server_utils
# Import directly from your original PocketFlow files
from .flow import create_agent_flow # Assumes flow.py is in the same directory
from .utils import call_llm, search_web # Make utils functions available if needed elsewhere
logger = logging.getLogger(__name__)
class PocketFlowTaskManager(InMemoryTaskManager):
""" TaskManager implementation that runs the PocketFlow agent. """
SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] # Define what the agent accepts/outputs
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
"""Handles non-streaming task requests."""
logger.info(f"Received task send request: {request.params.id}")
# Validate output modes
if not server_utils.are_modalities_compatible(
request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES
):
logger.warning(
"Unsupported output mode. Received %s, Support %s",
request.params.acceptedOutputModes, self.SUPPORTED_CONTENT_TYPES
)
return SendTaskResponse(id=request.id, error=server_utils.new_incompatible_types_error(request.id).error)
# Upsert the task in the store (initial state: submitted)
# We create the task first so its state can be tracked, even if the sync execution fails
await self.upsert_task(request.params)
# Update state to working before running
await self.update_store(request.params.id, TaskStatus(state=TaskState.WORKING), [])
# --- Run the PocketFlow logic ---
task_params: TaskSendParams = request.params
query = self._get_user_query(task_params)
if query is None:
fail_status = TaskStatus(state=TaskState.FAILED, message=Message(role="agent", parts=[TextPart(text="No text query found")]))
await self.update_store(task_params.id, fail_status, [])
return SendTaskResponse(id=request.id, error=InvalidParamsError(message="No text query found in message parts"))
shared_data = {"question": query}
agent_flow = create_agent_flow() # Create the flow instance
try:
# Run the synchronous PocketFlow
# In a real async server, you might run this in a separate thread/process
# executor to avoid blocking the event loop. For simplicity here, we run it directly.
# Consider adding a timeout if flows can hang.
logger.info(f"Running PocketFlow for task {task_params.id}...")
final_state_dict = agent_flow.run(shared_data)
logger.info(f"PocketFlow completed for task {task_params.id}")
answer_text = final_state_dict.get("answer", "Agent did not produce a final answer text.")
# --- Package result into A2A Task ---
final_task_status = TaskStatus(state=TaskState.COMPLETED)
# Package the answer as an artifact
final_artifact = Artifact(parts=[TextPart(text=answer_text)])
# Update the task in the store with final status and artifact
final_task = await self.update_store(
task_params.id, final_task_status, [final_artifact]
)
# Prepare and return the A2A response
task_result = self.append_task_history(final_task, task_params.historyLength)
return SendTaskResponse(id=request.id, result=task_result)
except Exception as e:
logger.error(f"Error executing PocketFlow for task {task_params.id}: {e}", exc_info=True)
# Update task state to FAILED
fail_status = TaskStatus(
state=TaskState.FAILED,
message=Message(role="agent", parts=[TextPart(text=f"Agent execution failed: {e}")])
)
await self.update_store(task_params.id, fail_status, [])
return SendTaskResponse(id=request.id, error=InternalError(message=f"Agent error: {e}"))
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
"""Handles streaming requests - Not implemented for this synchronous agent."""
logger.warning(f"Streaming requested for task {request.params.id}, but not supported by this PocketFlow agent implementation.")
# Return an error indicating streaming is not supported
return JSONRPCResponse(id=request.id, error=UnsupportedOperationError(message="Streaming not supported by this agent"))
def _get_user_query(self, task_send_params: TaskSendParams) -> str | None:
"""Extracts the first text part from the user message."""
if not task_send_params.message or not task_send_params.message.parts:
logger.warning(f"No message parts found for task {task_send_params.id}")
return None
for part in task_send_params.message.parts:
# Ensure part is treated as a dictionary if it came from JSON
part_dict = part if isinstance(part, dict) else part.model_dump()
if part_dict.get("type") == "text" and "text" in part_dict:
return part_dict["text"]
logger.warning(f"No text part found in message for task {task_send_params.id}")
return None # No text part found

View File

@ -0,0 +1,30 @@
from openai import OpenAI
import os
from duckduckgo_search import DDGS
def call_llm(prompt):
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
r = client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": prompt}]
)
return r.choices[0].message.content
def search_web(query):
results = DDGS().text(query, max_results=5)
# Convert results to a string
results_str = "\n\n".join([f"Title: {r['title']}\nURL: {r['href']}\nSnippet: {r['body']}" for r in results])
return results_str
if __name__ == "__main__":
print("## Testing call_llm")
prompt = "In a few words, what is the meaning of life?"
print(f"## Prompt: {prompt}")
response = call_llm(prompt)
print(f"## Response: {response}")
print("## Testing search_web")
query = "Who won the Nobel Prize in Physics 2024?"
print(f"## Query: {query}")
results = search_web(query)
print(f"## Results: {results}")