169 lines
7.8 KiB
Python
169 lines
7.8 KiB
Python
from starlette.applications import Starlette
|
|
from starlette.responses import JSONResponse
|
|
from sse_starlette.sse import EventSourceResponse
|
|
from starlette.requests import Request
|
|
from common.types import (
|
|
A2ARequest,
|
|
JSONRPCResponse,
|
|
InvalidRequestError,
|
|
JSONParseError,
|
|
GetTaskRequest,
|
|
CancelTaskRequest,
|
|
SendTaskRequest,
|
|
SetTaskPushNotificationRequest,
|
|
GetTaskPushNotificationRequest,
|
|
InternalError,
|
|
AgentCard,
|
|
TaskResubscriptionRequest,
|
|
SendTaskStreamingRequest,
|
|
Message,
|
|
)
|
|
from pydantic import ValidationError
|
|
import json
|
|
from typing import AsyncIterable, Any
|
|
from common.server.task_manager import TaskManager
|
|
|
|
import logging
|
|
|
|
# Logger used by everything in this module. It is named "A2AServer" explicitly
# (not after the module path), so logging configuration addresses it by that name.
logger: logging.Logger = logging.getLogger("A2AServer")
|
|
|
|
|
|
class A2AServer:
    """HTTP front end for an A2A agent, built on Starlette.

    Registers two routes on ``self.app``:

    * ``POST {endpoint}`` -- JSON-RPC requests, validated against
      ``A2ARequest`` and dispatched to the configured ``TaskManager``.
    * ``GET /.well-known/agent.json`` -- the configured ``AgentCard``.

    Exceptions raised while handling a JSON-RPC request are returned to the
    client as JSON-RPC error responses with HTTP status 200.
    """

    def __init__(
        self,
        host="0.0.0.0",
        port=5000,
        endpoint="/",
        agent_card: AgentCard | None = None,
        task_manager: TaskManager | None = None,
    ):
        """Create the Starlette app and register its routes.

        Args:
            host: Interface handed to ``uvicorn.run`` by ``start``.
            port: Port handed to ``uvicorn.run`` by ``start``.
            endpoint: Path that accepts JSON-RPC POST requests.
            agent_card: Card served at ``/.well-known/agent.json``;
                ``start`` refuses to run without it.
            task_manager: Receives the dispatched task requests;
                ``start`` refuses to run without it.
        """
        self.host = host
        self.port = port
        self.endpoint = endpoint
        self.task_manager = task_manager
        self.agent_card = agent_card
        self.app = Starlette()
        self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
        self.app.add_route(
            "/.well-known/agent.json", self._get_agent_card, methods=["GET"]
        )

    def start(self):
        """Check the configuration, then run the app under uvicorn.

        Raises:
            ValueError: if ``agent_card`` or ``task_manager`` is missing.
        """
        if self.agent_card is None:
            raise ValueError("agent_card is not defined")

        if self.task_manager is None:
            raise ValueError("task_manager is not defined")

        # Imported lazily so that merely constructing the server does not
        # require uvicorn.
        import uvicorn

        # Basic logging config moved to __main__.py for application-level control
        uvicorn.run(self.app, host=self.host, port=self.port)

    def _get_agent_card(self, request: Request) -> JSONResponse:
        """Serve the agent card as JSON, omitting fields whose value is None."""
        logger.info("Serving Agent Card request")
        return JSONResponse(self.agent_card.model_dump(exclude_none=True))

    @staticmethod
    def _extract_request_id(body: Any) -> str | int | None:
        """Return the JSON-RPC ``id`` of a parsed body, or None if unusable.

        Only string and integer ids are returned (JSON ``true``/``false`` are
        excluded even though ``bool`` subclasses ``int``). A non-object body,
        a missing id, or an id of any other type yields None. This lets the
        error response fall back to the ``null`` id JSON-RPC prescribes.
        NOTE(review): assumes JSONRPCResponse.id accepts str | int | None;
        confirm against common.types.
        """
        if not isinstance(body, dict):
            return None
        req_id = body.get("id")
        if isinstance(req_id, bool) or not isinstance(req_id, (str, int)):
            return None
        return req_id

    @staticmethod
    def _log_id(req_id: Any) -> Any:
        """Render a request id for log messages ("N/A" when unknown)."""
        return "N/A" if req_id is None else req_id

    async def _process_request(self, request: Request):
        """Handle one JSON-RPC POST: parse, validate, dispatch, respond.

        The body is parsed as JSON and validated against ``A2ARequest``. The
        resulting request object is routed by type to the matching
        ``TaskManager`` coroutine, and its result is converted to an HTTP
        response by ``_create_response``. Exceptions raised along the way are
        logged and converted to JSON-RPC error responses by
        ``_handle_exception``.
        """
        # Stays None until a usable id is found. JSON-RPC error responses must
        # carry a null id when the request id cannot be determined, so a
        # placeholder such as "N/A" must never reach the response.
        request_id = None
        raw_body = b""
        try:
            raw_body = await request.body()
            body = json.loads(raw_body)
            request_id = self._extract_request_id(body)
            logger.info(
                f"<- Received Request (ID: {self._log_id(request_id)}):\n"
                f"{json.dumps(body, indent=2)}"
            )

            json_rpc_request = A2ARequest.validate_python(body)

            # Route on the concrete request type produced by validation.
            if isinstance(json_rpc_request, GetTaskRequest):
                result = await self.task_manager.on_get_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskRequest):
                result = await self.task_manager.on_send_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskStreamingRequest):
                result = await self.task_manager.on_send_task_subscribe(
                    json_rpc_request
                )
            elif isinstance(json_rpc_request, CancelTaskRequest):
                result = await self.task_manager.on_cancel_task(json_rpc_request)
            elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
                result = await self.task_manager.on_set_task_push_notification(
                    json_rpc_request
                )
            elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
                result = await self.task_manager.on_get_task_push_notification(
                    json_rpc_request
                )
            elif isinstance(json_rpc_request, TaskResubscriptionRequest):
                result = await self.task_manager.on_resubscribe_to_task(
                    json_rpc_request
                )
            else:
                logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
                raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")

            return self._create_response(result)

        except json.decoder.JSONDecodeError as e:
            logger.error(
                "JSON Parse Error for Request body: "
                f"<<<{raw_body.decode('utf-8', errors='replace')}>>>\nError: {e}"
            )
            return self._handle_exception(e, request_id)
        except ValidationError as e:
            logger.error(
                f"Request Validation Error (ID: {self._log_id(request_id)}): {e.json()}"
            )
            return self._handle_exception(e, request_id)
        except Exception as e:
            # The single place the traceback is logged for unexpected errors.
            logger.error(
                "Unhandled Exception processing request "
                f"(ID: {self._log_id(request_id)}): {e}",
                exc_info=True,
            )
            return self._handle_exception(e, request_id)

    def _handle_exception(self, e: Exception, req_id=None) -> JSONResponse:
        """Convert an exception into a JSON-RPC error response sent with HTTP 200.

        Mapping:

        * ``json.decoder.JSONDecodeError`` -> ``JSONParseError``
        * ``ValidationError`` -> ``InvalidRequestError`` whose ``data`` is the
          validation error list
        * anything else -> ``InternalError`` naming only the exception type

        Args:
            e: The exception to report.
            req_id: Id to echo back. None produces ``"id": null``.
        """
        if isinstance(e, json.decoder.JSONDecodeError):
            json_rpc_error = JSONParseError()
        elif isinstance(e, ValidationError):
            json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
        else:
            # No exc_info here: the caller has already logged the traceback.
            logger.error(
                f"Internal Server Error (ReqID: {self._log_id(req_id)}): {e}"
            )
            json_rpc_error = InternalError(
                message=f"Internal Server Error: {type(e).__name__}"
            )

        response = JSONRPCResponse(id=req_id, error=json_rpc_error)
        response_dump = response.model_dump(exclude_none=True)
        # exclude_none drops a None id, but JSON-RPC 2.0 requires the "id"
        # member in every response (null when it could not be determined).
        response_dump.setdefault("id", None)
        logger.info(
            f"-> Sending Error Response (ReqID: {self._log_id(req_id)}):\n"
            f"{json.dumps(response_dump, indent=2)}"
        )
        # A2A errors are still sent with HTTP 200
        return JSONResponse(response_dump, status_code=200)

    def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
        """Turn a ``TaskManager`` result into an HTTP response.

        * An async iterable is streamed as Server-Sent Events. Each item's
          ``model_dump_json(exclude_none=True)`` becomes one ``data`` field.
        * A ``JSONRPCResponse`` is returned as a JSON body.
        * Anything else is logged and answered with a JSON-RPC internal error
          and HTTP 500.
        """
        if isinstance(result, AsyncIterable):

            async def event_generator(result_stream) -> AsyncIterable[dict[str, str]]:
                # Id of the most recent event; used only for log context.
                stream_request_id = None
                try:
                    async for item in result_stream:
                        response_json = item.model_dump_json(exclude_none=True)
                        stream_request_id = item.id
                        # Pretty-printing re-parses every event, so skip it
                        # entirely when INFO logging is disabled.
                        if logger.isEnabledFor(logging.INFO):
                            logger.info(
                                f"-> Sending SSE Event (ID: {stream_request_id}):\n"
                                f"{json.dumps(json.loads(response_json), indent=2)}"
                            )
                        yield {"data": response_json}
                    logger.info(f"SSE Stream ended for request ID: {stream_request_id}")
                except Exception as e:
                    # NOTE(review): the failure is only logged; the client just
                    # sees the stream end. Yielding a JSONRPCResponse carrying an
                    # InternalError here would surface it, if the protocol allows.
                    logger.error(
                        f"Error during SSE generation (ReqID: {stream_request_id}): {e}",
                        exc_info=True,
                    )

            logger.info("Starting SSE stream...")
            return EventSourceResponse(event_generator(result))

        if isinstance(result, JSONRPCResponse):
            response_dump = result.model_dump(exclude_none=True)
            log_id = result.id if result.id is not None else "N/A (Notification?)"
            log_type = "Error Response" if result.error is not None else "Response"
            logger.info(
                f"-> Sending {log_type} (ID: {log_id}):\n"
                f"{json.dumps(response_dump, indent=2)}"
            )
            return JSONResponse(response_dump)

        # Reaching this point means the task manager returned something that
        # is neither a stream nor a JSONRPCResponse.
        logger.error(f"Task manager returned unexpected type: {type(result)}")
        err_resp = JSONRPCResponse(
            id=None, error=InternalError(message="Invalid internal response type")
        )
        return JSONResponse(err_resp.model_dump(exclude_none=True), status_code=500)
|