87 lines
3.3 KiB
Python
87 lines
3.3 KiB
Python
import httpx
|
|
from httpx_sse import connect_sse
|
|
from typing import Any, AsyncIterable
|
|
from common.types import (
|
|
AgentCard,
|
|
GetTaskRequest,
|
|
SendTaskRequest,
|
|
SendTaskResponse,
|
|
JSONRPCRequest,
|
|
GetTaskResponse,
|
|
CancelTaskResponse,
|
|
CancelTaskRequest,
|
|
SetTaskPushNotificationRequest,
|
|
SetTaskPushNotificationResponse,
|
|
GetTaskPushNotificationRequest,
|
|
GetTaskPushNotificationResponse,
|
|
A2AClientHTTPError,
|
|
A2AClientJSONError,
|
|
SendTaskStreamingRequest,
|
|
SendTaskStreamingResponse,
|
|
)
|
|
import json
|
|
|
|
|
|
class A2AClient:
    """Client that POSTs task requests to a single agent URL over HTTP.

    Each public method wraps its ``payload`` in the matching request model,
    sends the request's ``model_dump()`` output as the JSON body and builds
    the matching response model from the decoded JSON reply.
    """

    def __init__(self, agent_card: AgentCard = None, url: str = None):
        """Resolve the endpoint URL the client talks to.

        Args:
            agent_card: Object whose ``url`` attribute is used when the card
                is truthy. Takes precedence over ``url``.
            url: Explicit endpoint, used only when ``agent_card`` is falsy.

        Raises:
            ValueError: If neither ``agent_card`` nor ``url`` is truthy.
        """
        if agent_card:
            self.url = agent_card.url
        elif url:
            self.url = url
        else:
            raise ValueError("Must provide either agent_card or url")

    async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
        """Send ``payload`` as a ``SendTaskRequest`` and return the reply.

        Raises:
            A2AClientHTTPError: Propagated from ``_send_request``.
            A2AClientJSONError: Propagated from ``_send_request``.
        """
        request = SendTaskRequest(params=payload)
        return SendTaskResponse(**await self._send_request(request))

    async def send_task_streaming(
        self, payload: dict[str, Any]
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Send ``payload`` as a streaming request and yield each SSE reply.

        The stream is consumed with ``httpx.AsyncClient`` and the async SSE
        helpers so that waiting for the next event suspends this coroutine
        instead of blocking the event loop (a synchronous client inside an
        ``async def`` generator would stall every other task while waiting).

        Args:
            payload: Parameters passed to ``SendTaskStreamingRequest``.

        Yields:
            One ``SendTaskStreamingResponse`` per server-sent event, built
            from the JSON-decoded event data.

        Raises:
            A2AClientJSONError: If an event's data is not valid JSON.
            A2AClientHTTPError: If ``httpx.RequestError`` is raised while
                iterating over events (reported with status code 400).
        """
        # Function-scope import: the module only imports the synchronous
        # ``connect_sse`` helper from the same package.
        from httpx_sse import aconnect_sse

        request = SendTaskStreamingRequest(params=payload)
        # timeout=None: the stream stays open for as long as the server
        # keeps sending events.
        async with httpx.AsyncClient(timeout=None) as client:
            async with aconnect_sse(
                client, "POST", self.url, json=request.model_dump()
            ) as event_source:
                # Only errors raised while reading events are translated;
                # errors raised while opening the connection happen outside
                # this ``try`` and propagate unchanged.
                try:
                    async for sse in event_source.aiter_sse():
                        yield SendTaskStreamingResponse(**json.loads(sse.data))
                except json.JSONDecodeError as e:
                    raise A2AClientJSONError(str(e)) from e
                except httpx.RequestError as e:
                    raise A2AClientHTTPError(400, str(e)) from e

    async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
        """POST ``request`` to ``self.url`` and return the decoded JSON body.

        Args:
            request: Request model; its ``model_dump()`` output is sent as
                the JSON body.

        Returns:
            The value produced by ``response.json()``.

        Raises:
            A2AClientHTTPError: If ``raise_for_status()`` rejects the
                response; carries the response status code.
            A2AClientJSONError: If decoding the body raises
                ``json.JSONDecodeError``.
        """
        async with httpx.AsyncClient() as client:
            try:
                # Image generation could take time, adding timeout
                response = await client.post(
                    self.url, json=request.model_dump(), timeout=30
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                raise A2AClientHTTPError(e.response.status_code, str(e)) from e
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e

    async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
        """Send ``payload`` as a ``GetTaskRequest`` and return the reply."""
        request = GetTaskRequest(params=payload)
        return GetTaskResponse(**await self._send_request(request))

    async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
        """Send ``payload`` as a ``CancelTaskRequest`` and return the reply."""
        request = CancelTaskRequest(params=payload)
        return CancelTaskResponse(**await self._send_request(request))

    async def set_task_callback(
        self, payload: dict[str, Any]
    ) -> SetTaskPushNotificationResponse:
        """Send ``payload`` as a ``SetTaskPushNotificationRequest``."""
        request = SetTaskPushNotificationRequest(params=payload)
        return SetTaskPushNotificationResponse(**await self._send_request(request))

    async def get_task_callback(
        self, payload: dict[str, Any]
    ) -> GetTaskPushNotificationResponse:
        """Send ``payload`` as a ``GetTaskPushNotificationRequest``."""
        request = GetTaskPushNotificationRequest(params=payload)
        return GetTaskPushNotificationResponse(**await self._send_request(request))
|