# pocketflow/cookbook/pocketflow-fastapi-hitl/server.py
# (253 lines, 11 KiB, Python)
import asyncio
import uuid
import json
import os
from fastapi import FastAPI, Request, HTTPException, status, BackgroundTasks # Import BackgroundTasks
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field # Import Pydantic for request/response models
from typing import Dict, Any, Literal # For type hinting
from flow import create_feedback_flow # PocketFlow imports
# --- Configuration ---
app = FastAPI(title="Minimal Feedback Loop API")

# Mount ./static for frontend assets if present; otherwise warn and keep running.
static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'static'))
if os.path.isdir(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
else:
    print(f"Warning: Static directory '{static_dir}' not found.")

# Jinja2 templates for the HTML frontend; `templates` stays None when the
# directory is missing, and the "/" route returns a 500 in that case.
template_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'templates'))
if os.path.isdir(template_dir):
    templates = Jinja2Templates(directory=template_dir)
else:
    print(f"Warning: Template directory '{template_dir}' not found.")
    templates = None

# --- State Management (In-Memory - NOT FOR PRODUCTION) ---
# Global dictionary to store task state. In production, use Redis, DB, etc.
tasks: Dict[str, Dict[str, Any]] = {}
# Structure: task_id -> {"shared": dict, "status": str, "task_obj": asyncio.Task | None}
# --- Background Flow Runner ---
# Scheduled by FastAPI's BackgroundTasks after /submit responds; reports
# progress to the SSE endpoint by putting dicts on shared["sse_queue"].
async def run_flow_background(task_id: str, flow, shared: Dict[str, Any]):
    """Run the PocketFlow flow for one task, streaming status via its SSE queue.

    Args:
        task_id: Key into the module-level ``tasks`` registry.
        flow: Flow object exposing ``run_async(shared)`` (from create_feedback_flow).
        shared: Per-task shared store; must contain an asyncio.Queue under "sse_queue".

    Side effects:
        Updates ``tasks[task_id]["status"]``, pushes status dicts onto the queue,
        and always pushes a ``None`` sentinel last so the SSE stream can close.
    """
    # Check if task exists (might have been cancelled/deleted before we ran)
    if task_id not in tasks:
        print(f"Background task {task_id}: Task not found, aborting.")
        return
    queue = shared.get("sse_queue")
    if not queue:
        print(f"ERROR: Task {task_id} missing sse_queue in shared store!")
        tasks[task_id]["status"] = "failed"
        # Cannot report failure via SSE if queue is missing
        return

    tasks[task_id]["status"] = "running"
    await queue.put({"status": "running"})
    print(f"Task {task_id}: Background flow starting.")

    final_status = "unknown"
    error_message = None
    try:
        # Execute the potentially long-running PocketFlow
        await flow.run_async(shared)
        # Determine final status based on shared state after flow completion
        if shared.get("final_result") is not None:
            final_status = "completed"
        else:
            # If flow ends without setting final_result
            final_status = "finished_incomplete"
        print(f"Task {task_id}: Flow finished with status: {final_status}")
    except Exception as e:
        final_status = "failed"
        error_message = str(e)
        print(f"Task {task_id}: Flow execution failed: {e}")
        # Log the full traceback so failures are diagnosable (previously a TODO).
        import traceback
        traceback.print_exc()
    finally:
        # Ensure task still exists before updating state
        if task_id in tasks:
            tasks[task_id]["status"] = final_status
            final_update = {"status": final_status}
            if final_status == "completed":
                final_update["final_result"] = shared.get("final_result")
            elif error_message:
                final_update["error"] = error_message
            # Put final status update onto the queue
            await queue.put(final_update)
        # Signal the end of the SSE stream by putting None.
        # Must happen regardless of whether task was deleted mid-run.
        if queue:
            await queue.put(None)
        print(f"Task {task_id}: Background task ended. Final update sentinel put on queue.")
        # Drop the reference to the completed/failed asyncio Task object
        if task_id in tasks:
            tasks[task_id]["task_obj"] = None
# --- Pydantic Models for Request/Response Validation ---
class SubmitRequest(BaseModel):
    """Request body for POST /submit: the raw input the flow will process."""
    data: str = Field(..., min_length=1, description="Input data for the task")
class SubmitResponse(BaseModel):
    """202 response from POST /submit; task_id keys all later API calls."""
    message: str = "Task submitted"
    task_id: str
class FeedbackRequest(BaseModel):
    """Request body for POST /feedback/{task_id}."""
    feedback: Literal["approved", "rejected"] # Use Literal for specific choices
class FeedbackResponse(BaseModel):
    """Confirmation that the feedback value was accepted and delivered."""
    message: str
# --- FastAPI Routes ---
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def get_index(request: Request):
    """Render the single-page frontend from templates/index.html."""
    if templates is not None:
        return templates.TemplateResponse("index.html", {"request": request})
    # Template directory was missing at startup; nothing to serve.
    raise HTTPException(status_code=500, detail="Templates directory not configured.")
@app.post("/submit", response_model=SubmitResponse, status_code=status.HTTP_202_ACCEPTED)
async def submit_task(
    submit_request: SubmitRequest, # Use Pydantic model for validation
    background_tasks: BackgroundTasks # Inject BackgroundTasks instance
):
    """Accept a new task and schedule its flow to run after the response is sent.

    Responds immediately with a 202 and the generated task id; all further
    progress is delivered over the /stream/{task_id} SSE endpoint.
    """
    task_id = str(uuid.uuid4())

    # Per-task coordination primitives: the event unblocks the review step,
    # the queue carries status updates to the SSE stream.
    review_event = asyncio.Event()
    sse_queue = asyncio.Queue()

    shared_store = {
        "task_input": submit_request.data,
        "processed_output": None,
        "feedback": None,
        "review_event": review_event,
        "sse_queue": sse_queue,
        "final_result": None,
        "task_id": task_id,
    }

    # Register state BEFORE scheduling so the background runner can find it.
    tasks[task_id] = {
        "shared": shared_store,
        "status": "pending",
        "task_obj": None,  # placeholder; BackgroundTasks gives no Task handle
    }

    await sse_queue.put({"status": "pending", "task_id": task_id})

    # BackgroundTasks runs the coroutine only after the 202 has been sent.
    # Since no asyncio.Task reference is kept, cancellation is not supported
    # here; use asyncio.create_task manually if that is ever needed.
    background_tasks.add_task(run_flow_background, task_id, create_feedback_flow(), shared_store)

    print(f"Task {task_id}: Submitted, scheduled for background execution.")
    return SubmitResponse(task_id=task_id)
@app.post("/feedback/{task_id}", response_model=FeedbackResponse)
async def provide_feedback(task_id: str, feedback_request: FeedbackRequest):
    """Deliver an 'approved'/'rejected' decision to a task blocked on review."""
    if task_id not in tasks:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")

    shared = tasks[task_id]["shared"]
    queue = shared.get("sse_queue")
    review_event = shared.get("review_event")

    async def fail(message, status_code=status.HTTP_400_BAD_REQUEST):
        """Log the problem, surface it on the SSE queue, then abort the request."""
        print(f"Task {task_id}: Feedback error - {message}")
        if queue:
            await queue.put({"status": "feedback_error", "error": message})
        raise HTTPException(status_code=status_code, detail=message)

    if not review_event:
        # Task exists but was never wired with an event -> server-side setup bug.
        await fail("Task not configured for feedback", status.HTTP_500_INTERNAL_SERVER_ERROR)
    if review_event.is_set():
        # Either nobody is waiting, or feedback was already delivered once.
        await fail("Task not awaiting feedback or feedback already sent", status.HTTP_409_CONFLICT)

    feedback = feedback_request.feedback  # constrained to approved/rejected by Pydantic
    print(f"Task {task_id}: Received feedback via POST: {feedback}")

    # Publish 'processing_feedback' before waking the flow so SSE clients see it first.
    if queue:
        await queue.put({"status": "processing_feedback", "feedback_value": feedback})
    tasks[task_id]["status"] = "processing_feedback"

    # Hand the decision to the waiting ReviewNode and release it.
    shared["feedback"] = feedback
    review_event.set()
    return FeedbackResponse(message=f"Feedback '{feedback}' received")
# --- SSE Endpoint ---
@app.get("/stream/{task_id}")
async def stream_status(task_id: str):
    """Streams status updates for a given task using Server-Sent Events.

    Drains the task's asyncio.Queue and forwards each update as a
    ``data: <json>`` SSE frame until the ``None`` sentinel arrives.
    NOTE(review): each queue item is consumed exactly once, so only a single
    concurrent SSE client per task receives every update — confirm intended.
    """
    if task_id not in tasks or "sse_queue" not in tasks[task_id]["shared"]:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task or queue not found")
    queue = tasks[task_id]["shared"]["sse_queue"]

    async def event_generator():
        """Yields SSE messages from the task's queue."""
        print(f"SSE Stream: Client connected for {task_id}")
        try:
            while True:
                # Wait for the next status update from the queue
                update = await queue.get()
                if update is None: # Sentinel value indicates end of stream
                    print(f"SSE Stream: Sentinel received for {task_id}, closing stream.")
                    yield f"data: {json.dumps({'status': 'stream_closed'})}\n\n"
                    break
                sse_data = json.dumps(update)
                print(f"SSE Stream: Sending for {task_id}: {sse_data}")
                yield f"data: {sse_data}\n\n" # SSE format: "data: <json>\n\n"
                queue.task_done() # Acknowledge processing the queue item
        except asyncio.CancelledError:
            # This happens if the client disconnects
            print(f"SSE Stream: Client disconnected for {task_id}.")
        except Exception as e:
            # Log unexpected errors during streaming
            print(f"SSE Stream: Error in generator for {task_id}: {e}")
            # Optionally send an error message to the client if possible
            try:
                yield f"data: {json.dumps({'status': 'stream_error', 'error': str(e)})}\n\n"
            except Exception: # Catch errors if yield fails (e.g., connection already closed)
                pass
        finally:
            print(f"SSE Stream: Generator finished for {task_id}.")
            # Consider cleanup here (e.g., removing task if no longer needed)
            # if task_id in tasks: del tasks[task_id]

    # Use FastAPI/Starlette's StreamingResponse for SSE.
    # X-Accel-Buffering: no — prevents nginx-style proxies from buffering frames.
    headers = {'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'}
    return StreamingResponse(event_generator(), media_type="text/event-stream", headers=headers)
# --- Main Execution Guard (for running with uvicorn) ---
if __name__ == "__main__":
    # Running via the uvicorn CLI is the supported path; this guard only
    # prints the recommended command instead of starting a server itself.
    print("Starting FastAPI server using Uvicorn is recommended:")
    print("uvicorn server:app --reload --host 0.0.0.0 --port 8000")
    # Example using uvicorn programmatically (less common than CLI)
    # import uvicorn
    # uvicorn.run(app, host="0.0.0.0", port=8000)