# pocketflow/cookbook/pocketflow-a2a/common/client/client.py
#
# 183 lines
# 7.9 KiB
# Python

import httpx
from httpx_sse import connect_sse
from typing import Any, AsyncIterable
from common.types import (
AgentCard,
GetTaskRequest,
SendTaskRequest,
SendTaskResponse,
JSONRPCRequest,
JSONRPCResponse,
JSONRPCError,
GetTaskResponse,
CancelTaskResponse,
CancelTaskRequest,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
A2AClientHTTPError,
A2AClientJSONError,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
Task,
TaskPushNotificationConfig,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
)
import json
import logging
# Module-level logger for the client. It is registered under the fixed name
# "A2AClient" (rather than the module's __name__), so its level and handlers
# can be configured by that name.
logger = logging.getLogger("A2AClient")
class A2AClientError(Exception):
    """Base class for A2A client errors.

    The constructor takes exactly one argument, ``message``, which is
    forwarded to ``Exception`` and therefore becomes ``str(error)`` and
    ``error.args[0]``.
    """

    def __init__(self, message):
        # Forward the message unchanged so str(error) is the message itself.
        super().__init__(message)
class RpcError(Exception):
    """Exception carrying the fields of a JSON-RPC error object.

    Attributes:
        code: Numeric error code.
        data: Optional extra payload attached to the error; ``None`` when
            absent (also the class-level default).
        name: Always the string ``"RpcError"``.

    The ``message`` argument is passed to ``Exception`` and is therefore
    what ``str(error)`` returns.
    """

    code: int
    data: Any = None

    def __init__(self, code: int, message: str, data: Any = None):
        super().__init__(message)
        # Plain attribute assignments; their order is irrelevant.
        self.data = data
        self.code = code
        self.name = "RpcError"
class A2AClient:
    """Asynchronous JSON-RPC client for a single A2A agent endpoint.

    Every request is POSTed to ``self.url`` through one shared
    ``httpx.AsyncClient`` (``self.fetchImpl``). That client holds network
    resources, so release it with ``await client.aclose()`` or by using the
    instance as an async context manager::

        async with A2AClient(url="http://localhost:10000") as client:
            response = await client.send_task(payload)
    """

    def __init__(self, agent_card: AgentCard = None, url: str = None):
        """Create a client for the agent at ``agent_card.url`` or ``url``.

        Args:
            agent_card: Agent card whose ``url`` attribute is used as the
                endpoint. Takes precedence over ``url`` when truthy.
            url: Endpoint URL, used only when ``agent_card`` is not given.

        Raises:
            ValueError: If neither ``agent_card`` nor ``url`` is provided.
        """
        if agent_card:
            self.url = agent_card.url.rstrip("/")
        elif url:
            self.url = url.rstrip("/")
        else:
            raise ValueError("Must provide either agent_card or url")
        # timeout=None disables the client-wide default timeout; individual
        # calls pass their own timeout where one is wanted.
        self.fetchImpl = httpx.AsyncClient(timeout=None)

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        await self.fetchImpl.aclose()

    async def __aenter__(self) -> "A2AClient":
        """Return the client itself for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Close the underlying HTTP client when the ``async with`` block exits."""
        await self.aclose()

    def _generateRequestId(self):
        """Return the current Unix time in milliseconds as an ``int``."""
        import time
        return int(time.time() * 1000)

    async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
        """POST one JSON-RPC request and return the decoded JSON response.

        Args:
            request: Request model; it is serialized with
                ``model_dump(exclude_none=True)`` and sent as the JSON body.

        Returns:
            The decoded response body when it carries no non-null ``error``
            member.

        Raises:
            A2AClientHTTPError: The server answered with an HTTP error status.
            RpcError: The response body contains a non-null ``error`` member.
            A2AClientError: Network/request failure, or any other unexpected
                error while handling the response.
        """
        req_id = request.id
        req_method = request.method
        req_dump = request.model_dump(exclude_none=True)
        logger.info(f"-> Sending Request (ID: {req_id}, Method: {req_method}):\n{json.dumps(req_dump, indent=2)}")
        try:
            response = await self.fetchImpl.post(
                self.url, json=req_dump, timeout=60.0
            )
            logger.info(f"<- Received HTTP Status {response.status_code} for Request (ID: {req_id})")
            response_body = await response.aread()
            logger.debug(f"Raw Response Body (ID: {req_id}):\n{response_body.decode('utf-8', errors='replace')}")
            response.raise_for_status()
            try:
                json_response = json.loads(response_body)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to decode JSON response (ID: {req_id}): {e}")
                # NOTE(review): this is raised inside the outer ``try``, so
                # unless A2AClientJSONError derives from RpcError or one of
                # the httpx errors handled below, the generic handler catches
                # it and re-raises it as A2AClientError. Left as-is so the
                # exception type callers see does not change.
                raise A2AClientJSONError(f"Failed to decode JSON: {e}") from e
            if "error" in json_response and json_response["error"] is not None:
                rpc_error = json_response["error"]
                logger.warning(f"<- Received JSON-RPC Error (ID: {req_id}): Code={rpc_error.get('code')}, Msg='{rpc_error.get('message')}'")
                raise RpcError(rpc_error.get("code", -32000), rpc_error.get("message", "Unknown RPC Error"), rpc_error.get("data"))
            logger.info(f"<- Received Success Response (ID: {req_id}):\n{json.dumps(json_response, indent=2)}")
            return json_response
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP Error for Request (ID: {req_id}): {e.response.status_code} - {e.request.url}")
            error_body = await e.response.aread()
            raise A2AClientHTTPError(e.response.status_code, f"{e}. Body: {error_body.decode('utf-8', errors='replace')}") from e
        except httpx.RequestError as e:
            logger.error(f"Request Error for (ID: {req_id}): {e}")
            raise A2AClientError(f"Network or request error: {e}") from e
        except RpcError:
            # Already the right type; let it through without re-wrapping.
            raise
        except Exception as e:
            logger.error(f"Unexpected error during request (ID: {req_id}): {e}", exc_info=True)
            raise A2AClientError(f"Unexpected error: {e}") from e

    async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
        """Send ``payload`` as the params of a ``SendTaskRequest``.

        Returns the response wrapped in ``SendTaskResponse``; raises whatever
        ``_send_request`` raises.
        """
        request = SendTaskRequest(params=payload)
        response_dict = await self._send_request(request)
        return SendTaskResponse(**response_dict)

    async def send_task_streaming(
        self, payload: dict[str, Any]
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Send a streaming task request and yield each server-sent event.

        The response body is read as a Server-Sent Events stream. Lines are
        grouped into events terminated by a blank line; the values of all
        ``data`` fields of an event are joined with newlines, decoded as JSON
        and yielded as ``SendTaskStreamingResponse``. Other fields (``event``,
        ``id``, ``retry``) and comment lines are ignored. An event whose data
        cannot be decoded or validated is logged and skipped. An event still
        incomplete when the stream ends is discarded with a warning.

        Args:
            payload: Dict passed as the ``params`` of the
                ``SendTaskStreamingRequest``.

        Raises:
            A2AClientHTTPError: The server answered with an HTTP error status.
            A2AClientError: Network/request failure or any other unexpected
                error while streaming.
        """
        request = SendTaskStreamingRequest(params=payload)
        req_id = request.id
        req_dump = request.model_dump(exclude_none=True)
        logger.info(f"-> Sending Streaming Request (ID: {req_id}, Method: {request.method}):\n{json.dumps(req_dump, indent=2)}")
        try:
            async with self.fetchImpl.stream("POST", self.url, json=req_dump, timeout=None) as response:
                logger.info(f"<- Received HTTP Status {response.status_code} for Streaming Request (ID: {req_id})")
                try:
                    response.raise_for_status()
                except httpx.HTTPStatusError:
                    # A streamed body can only be read while the stream is
                    # open. Leaving this ``async with`` closes the response,
                    # so load the error body now; the handler below then gets
                    # the cached content instead of failing on a closed
                    # stream.
                    await response.aread()
                    raise
                event_lines: list[str] = []  # raw lines of the event being assembled
                data_parts: list[str] = []   # values of its "data" fields, in order
                async for line in response.aiter_lines():
                    if line:
                        event_lines.append(line)
                        # "field: value"; a line starting with ":" is a comment
                        # and yields an empty field name, so it is ignored.
                        field, _, value = line.partition(":")
                        if field == "data":
                            # A single leading space after the colon is not
                            # part of the value.
                            data_parts.append(value[1:] if value.startswith(" ") else value)
                        continue
                    # A blank line terminates the current event.
                    parsed = None
                    if data_parts:
                        data_str = "\n".join(data_parts).strip()
                        logger.debug(f"Received SSE Data Line (ID: {req_id}): {data_str}")
                        try:
                            parsed = SendTaskStreamingResponse(**json.loads(data_str))
                        except json.JSONDecodeError as e:
                            logger.error(f"Failed to decode SSE JSON (ID: {req_id}): {e}. Data: '{data_str}'")
                        except Exception as e:
                            logger.error(f"Error processing SSE data (ID: {req_id}): {e}. Data: '{data_str}'", exc_info=True)
                    elif event_lines:
                        skipped = "\n".join(event_lines)
                        logger.debug(f"Received non-data SSE line (ID: {req_id}): {skipped}")
                    event_lines = []
                    data_parts = []
                    if parsed is not None:
                        # Yield outside the parsing ``try`` so an exception
                        # thrown into the generator at this point is not
                        # mistaken for a parse failure and swallowed.
                        yield parsed
                if event_lines:
                    pending = "\n".join(event_lines)
                    logger.warning(f"SSE stream ended with partial data in buffer (ID: {req_id}): {pending}")
                logger.info(f"SSE Stream ended for request ID: {req_id}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP Error during streaming connection (ID: {req_id}): {e.response.status_code} - {e.request.url}")
            # The body was loaded before the stream closed, so this returns
            # the cached content.
            error_body = await e.response.aread()
            raise A2AClientHTTPError(e.response.status_code, f"{e}. Body: {error_body.decode('utf-8', errors='replace')}") from e
        except httpx.RequestError as e:
            logger.error(f"Request Error during streaming (ID: {req_id}): {e}")
            raise A2AClientError(f"Network or request error during streaming: {e}") from e
        except Exception as e:
            logger.error(f"Unexpected error during streaming (ID: {req_id}): {e}", exc_info=True)
            raise A2AClientError(f"Unexpected streaming error: {e}") from e

    async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
        """Send ``payload`` as the params of a ``GetTaskRequest``.

        Returns the response wrapped in ``GetTaskResponse``; raises whatever
        ``_send_request`` raises.
        """
        request = GetTaskRequest(params=payload)
        response_dict = await self._send_request(request)
        return GetTaskResponse(**response_dict)

    async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
        """Send ``payload`` as the params of a ``CancelTaskRequest``.

        Returns the response wrapped in ``CancelTaskResponse``; raises
        whatever ``_send_request`` raises.
        """
        request = CancelTaskRequest(params=payload)
        response_dict = await self._send_request(request)
        return CancelTaskResponse(**response_dict)

    async def set_task_callback(
        self, payload: dict[str, Any]
    ) -> SetTaskPushNotificationResponse:
        """Send ``payload`` as the params of a ``SetTaskPushNotificationRequest``.

        Returns the response wrapped in ``SetTaskPushNotificationResponse``;
        raises whatever ``_send_request`` raises.
        """
        request = SetTaskPushNotificationRequest(params=payload)
        response_dict = await self._send_request(request)
        return SetTaskPushNotificationResponse(**response_dict)

    async def get_task_callback(
        self, payload: dict[str, Any]
    ) -> GetTaskPushNotificationResponse:
        """Send ``payload`` as the params of a ``GetTaskPushNotificationRequest``.

        Returns the response wrapped in ``GetTaskPushNotificationResponse``;
        raises whatever ``_send_request`` raises.
        """
        request = GetTaskPushNotificationRequest(params=payload)
        response_dict = await self._send_request(request)
        return GetTaskPushNotificationResponse(**response_dict)