278 lines
10 KiB
Python
278 lines
10 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Union, AsyncIterable, List
|
|
from common.types import Task
|
|
from common.types import (
|
|
JSONRPCResponse,
|
|
TaskIdParams,
|
|
TaskQueryParams,
|
|
GetTaskRequest,
|
|
TaskNotFoundError,
|
|
SendTaskRequest,
|
|
CancelTaskRequest,
|
|
TaskNotCancelableError,
|
|
SetTaskPushNotificationRequest,
|
|
GetTaskPushNotificationRequest,
|
|
GetTaskResponse,
|
|
CancelTaskResponse,
|
|
SendTaskResponse,
|
|
SetTaskPushNotificationResponse,
|
|
GetTaskPushNotificationResponse,
|
|
PushNotificationNotSupportedError,
|
|
TaskSendParams,
|
|
TaskStatus,
|
|
TaskState,
|
|
TaskResubscriptionRequest,
|
|
SendTaskStreamingRequest,
|
|
SendTaskStreamingResponse,
|
|
Artifact,
|
|
PushNotificationConfig,
|
|
TaskStatusUpdateEvent,
|
|
JSONRPCError,
|
|
TaskPushNotificationConfig,
|
|
InternalError,
|
|
)
|
|
from common.server.utils import new_not_implemented_error
|
|
import asyncio
|
|
import logging
|
|
|
|
# Module-level logger used by the task managers below for request and error logging.
logger = logging.getLogger(__name__)
|
|
|
|
class TaskManager(ABC):
    """Abstract interface for handling task-related JSON-RPC requests.

    Each ``on_*`` coroutine receives a parsed request object and returns the
    matching response object. The streaming handlers
    (``on_send_task_subscribe`` and ``on_resubscribe_to_task``) may return an
    async iterable of streaming responses, or a single ``JSONRPCResponse``
    (presumably to report an error instead of opening a stream).
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Handle a get-task request and return the task or an error."""

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle a cancel-task request."""

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a non-streaming send-task request."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a send-task request whose updates are streamed back."""

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Register a push-notification config for a task."""

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config registered for a task."""

    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Re-attach a client to the update stream of an existing task.

        Annotated with the streaming response type, the same as
        ``on_send_task_subscribe``, because resubscription yields the same
        kind of stream.
        """
|
|
|
|
|
|
class InMemoryTaskManager(TaskManager):
    """TaskManager that keeps all state in process memory.

    Tasks, push-notification configs and SSE subscriber queues live in plain
    dicts. This class implements the get/cancel/push-notification/resubscribe
    handlers. ``on_send_task`` and ``on_send_task_subscribe`` remain abstract
    and must be provided by a subclass.
    """

    def __init__(self):
        # task id -> Task
        self.tasks: dict[str, Task] = {}
        # task id -> push-notification config registered for that task
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        # Guards ``tasks`` and ``push_notification_infos``.
        self.lock = asyncio.Lock()
        # task id -> queues of the SSE consumers currently attached to that task
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        # Guards ``task_sse_subscribers``.
        self.subscriber_lock = asyncio.Lock()

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return a copy of the requested task with its history trimmed.

        The history is cut to ``params.historyLength`` (see
        ``append_task_history``). Responds with ``TaskNotFoundError`` when the
        task id is unknown.
        """
        logger.info(f"Getting task {request.params.id}")
        task_query_params: TaskQueryParams = request.params

        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())

            # Copy under the lock so the snapshot is consistent.
            task_result = self.append_task_history(
                task, task_query_params.historyLength
            )

        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject cancellation: this implementation never cancels tasks.

        Responds with ``TaskNotFoundError`` for unknown ids and with
        ``TaskNotCancelableError`` otherwise.
        """
        logger.info(f"Cancelling task {request.params.id}")
        task_id_params: TaskIdParams = request.params

        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a non-streaming send-task request (subclass responsibility)."""
        pass

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming send-task request (subclass responsibility)."""
        pass

    async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig):
        """Store ``notification_config`` for an existing task.

        Raises:
            ValueError: if ``task_id`` is not a known task.
        """
        async with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                raise ValueError(f"Task not found for {task_id}")

            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push-notification config stored for ``task_id``.

        Raises:
            ValueError: if ``task_id`` is not a known task.
            KeyError: if the task exists but has no config registered.
        """
        async with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                raise ValueError(f"Task not found for {task_id}")

            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return True if a push-notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Register the push-notification config carried by ``request``.

        On success the stored config is echoed back as the result. Any failure
        (e.g. unknown task) is logged and reported as ``InternalError`` inside
        a ``SetTaskPushNotificationResponse``. This matches how the other
        handlers wrap errors in their own typed response.
        """
        logger.info(f"Setting task push notification {request.params.id}")
        task_notification_params: TaskPushNotificationConfig = request.params

        try:
            await self.set_push_notification_info(task_notification_params.id, task_notification_params.pushNotificationConfig)
        except Exception as e:
            logger.error(f"Error while setting push notification info: {e}")
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )

        return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config registered for the requested task.

        Any failure (unknown task, or no config registered) is logged and
        reported as ``InternalError``.
        """
        logger.info(f"Getting task push notification {request.params.id}")
        task_params: TaskIdParams = request.params

        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error(f"Error while getting push notification info: {e}")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )

        return GetTaskPushNotificationResponse(id=request.id, result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info))

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create the task in SUBMITTED state, or append the message to its history.

        Returns the stored Task object (not a copy).
        """
        logger.info(f"Upserting task {task_send_params.id}")
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                # NOTE(review): ``messages`` is passed alongside ``history``;
                # confirm the Task model defines it (it may be ignored as an
                # extra field).
                task = Task(
                    id=task_send_params.id,
                    sessionId = task_send_params.sessionId,
                    messages=[task_send_params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_send_params.message],
                )
                self.tasks[task_send_params.id] = task
            else:
                task.history.append(task_send_params.message)

            return task

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Resubscription is not supported here; respond with a not-implemented error."""
        return new_not_implemented_error(request.id)

    async def update_store(
        self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
    ) -> Task:
        """Apply a status change and optional artifacts to a stored task.

        The new status replaces the old one. ``status.message`` (if any) is
        appended to the history, and ``artifacts`` (if not None) extend the
        task's artifact list.

        Raises:
            ValueError: if ``task_id`` is not a known task.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError:
                logger.error(f"Task {task_id} not found for updating the task")
                raise ValueError(f"Task {task_id} not found")

            task.status = status

            if status.message is not None:
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def append_task_history(self, task: Task, historyLength: int | None) -> Task:
        """Return a shallow copy of ``task`` with its history trimmed.

        Keeps only the last ``historyLength`` entries. When ``historyLength``
        is None or <= 0 the copy's history is empty. The original task is not
        modified, because only the copy's ``history`` attribute is reassigned.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0:
            new_task.history = new_task.history[-historyLength:]
        else:
            new_task.history = []

        return new_task

    async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
        """Register and return a new unbounded event queue for ``task_id``.

        Raises:
            ValueError: if ``is_resubscribe`` is True and no subscriber list
                exists yet for the task.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                else:
                    self.task_sse_subscribers[task_id] = []

            sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id, task_update_event):
        """Put ``task_update_event`` on every queue subscribed to ``task_id``.

        This is a no-op when the task has no subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return

            current_subscribers = self.task_sse_subscribers[task_id]
            for subscriber in current_subscribers:
                # Queues are created unbounded, so put() does not block here.
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Yield queued events as ``SendTaskStreamingResponse`` objects.

        The stream stops after a ``JSONRPCError`` (sent as the response error)
        or after a ``TaskStatusUpdateEvent`` with ``final`` set. The queue is
        always unregistered on exit, including when the consumer stops early.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)
|
|
|