pocketflow/cookbook/pocketflow_a2a/common/server/server.py

121 lines
4.5 KiB
Python

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from common.types import (
A2ARequest,
JSONRPCResponse,
InvalidRequestError,
JSONParseError,
GetTaskRequest,
CancelTaskRequest,
SendTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
InternalError,
AgentCard,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
)
from pydantic import ValidationError
import json
from typing import AsyncIterable, Any
from common.server.task_manager import TaskManager
import logging
logger = logging.getLogger(__name__)


class A2AServer:
    """Starlette-based HTTP server exposing an A2A agent over JSON-RPC.

    Two routes are registered on ``self.app``:

    * ``POST <endpoint>`` -- the JSON body is validated against
      ``A2ARequest`` and forwarded to the matching ``TaskManager`` coroutine.
      Async-iterable results are streamed as Server-Sent Events;
      ``JSONRPCResponse`` results are returned as a JSON body.
    * ``GET /.well-known/agent.json`` -- returns the configured ``AgentCard``.
    """

    # Request model -> name of the TaskManager coroutine that serves it.
    # Checked in order with isinstance(), exactly like an if/elif chain.
    # Handlers are looked up by name on the instance (getattr) so that
    # subclass overrides of these methods are honoured.
    _REQUEST_HANDLERS = (
        (GetTaskRequest, "on_get_task"),
        (SendTaskRequest, "on_send_task"),
        (SendTaskStreamingRequest, "on_send_task_subscribe"),
        (CancelTaskRequest, "on_cancel_task"),
        (SetTaskPushNotificationRequest, "on_set_task_push_notification"),
        (GetTaskPushNotificationRequest, "on_get_task_push_notification"),
        (TaskResubscriptionRequest, "on_resubscribe_to_task"),
    )

    def __init__(
        self,
        host="0.0.0.0",
        port=5000,
        endpoint="/",
        agent_card: AgentCard | None = None,
        task_manager: TaskManager | None = None,
    ):
        """Create the Starlette app and register its routes.

        Args:
            host: Interface passed to uvicorn by ``start()``.
            port: Port passed to uvicorn by ``start()``.
            endpoint: Path on which JSON-RPC POST requests are accepted.
            agent_card: Card served at ``/.well-known/agent.json``; required
                by ``start()``.
            task_manager: Object whose ``on_*`` coroutines handle requests;
                required by ``start()``.
        """
        self.host = host
        self.port = port
        self.endpoint = endpoint
        self.task_manager = task_manager
        self.agent_card = agent_card
        self.app = Starlette()
        self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
        self.app.add_route(
            "/.well-known/agent.json", self._get_agent_card, methods=["GET"]
        )

    def start(self):
        """Validate configuration, then run the app with ``uvicorn.run``.

        Raises:
            ValueError: if ``agent_card`` or ``task_manager`` was not supplied.
        """
        if self.agent_card is None:
            raise ValueError("agent_card is not defined")
        if self.task_manager is None:
            # Name the attribute that is actually missing.
            raise ValueError("task_manager is not defined")
        # Imported here rather than at module level, so importing this module
        # does not require uvicorn.
        import uvicorn

        uvicorn.run(self.app, host=self.host, port=self.port)

    def _get_agent_card(self, request: Request) -> JSONResponse:
        """Serve the agent card as JSON, omitting fields whose value is None."""
        return JSONResponse(self.agent_card.model_dump(exclude_none=True))

    async def _process_request(self, request: Request):
        """Handle one JSON-RPC POST: parse, validate, dispatch and respond.

        Any exception raised before a response object is returned is turned
        into a JSON-RPC error response by ``_handle_exception``.
        """
        request_id = None
        try:
            body = await request.json()
            json_rpc_request = A2ARequest.validate_python(body)
            # Once the request is validated its id is known and can be echoed
            # on errors. JSON-RPC 2.0 requires a null id only when the id
            # could not be determined (parse error / invalid request).
            request_id = getattr(json_rpc_request, "id", None)
            handler = self._resolve_handler(json_rpc_request)
            result = await handler(json_rpc_request)
            return self._create_response(result)
        except Exception as e:
            return self._handle_exception(e, request_id)

    def _resolve_handler(self, json_rpc_request):
        """Return the bound TaskManager coroutine serving ``json_rpc_request``.

        Raises:
            ValueError: if no handler is registered for the request's type.
        """
        for request_type, method_name in self._REQUEST_HANDLERS:
            if isinstance(json_rpc_request, request_type):
                return getattr(self.task_manager, method_name)
        logger.warning("Unexpected request type: %s", type(json_rpc_request))
        # Report the parsed JSON-RPC model's type, not the Starlette Request's.
        raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")

    def _handle_exception(self, e: Exception, request_id=None) -> JSONResponse:
        """Translate an exception into a JSON-RPC error response.

        * ``json.JSONDecodeError`` -> ``JSONParseError``
        * pydantic ``ValidationError`` -> ``InvalidRequestError`` carrying the
          validation details
        * anything else -> ``InternalError``, logged with its traceback

        Args:
            e: The exception raised while handling the request.
            request_id: JSON-RPC id of the request if it had already been
                validated, otherwise None (the default, as before).

        Returns:
            A ``JSONResponse`` with HTTP status 400.
        """
        if isinstance(e, json.JSONDecodeError):
            json_rpc_error = JSONParseError()
        elif isinstance(e, ValidationError):
            json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
        else:
            # exc_info keeps the traceback, which is otherwise lost.
            logger.error("Unhandled exception: %s", e, exc_info=e)
            json_rpc_error = InternalError()
        response = JSONRPCResponse(id=request_id, error=json_rpc_error)
        # NOTE(review): every error, including InternalError, is returned with
        # HTTP 400; confirm clients do not expect 500 for internal errors.
        return JSONResponse(response.model_dump(exclude_none=True), status_code=400)

    def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
        """Wrap a TaskManager result in the appropriate HTTP response.

        An async iterable is streamed as Server-Sent Events, one event per
        item, each serialized with ``model_dump_json(exclude_none=True)``.
        The items are presumably pydantic models; verify against TaskManager.
        A ``JSONRPCResponse`` is returned as a regular JSON body.

        Raises:
            ValueError: for any other result type.
        """
        if isinstance(result, AsyncIterable):

            async def event_generator(result) -> AsyncIterable[dict[str, str]]:
                async for item in result:
                    yield {"data": item.model_dump_json(exclude_none=True)}

            return EventSourceResponse(event_generator(result))
        if isinstance(result, JSONRPCResponse):
            return JSONResponse(result.model_dump(exclude_none=True))
        logger.error("Unexpected result type: %s", type(result))
        raise ValueError(f"Unexpected result type: {type(result)}")