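# pocketflow/cookbook/pocketflow-fastapi-background/main.py
#
# FastAPI server that runs the article-generation flow as a background task
# and streams its progress to clients over Server-Sent Events (SSE).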

import asyncio
import json
import logging
import uuid

from fastapi import BackgroundTasks, FastAPI, Form
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles

from flow import create_article_flow
app = FastAPI()

# Expose the local ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Registry of in-flight jobs: job_id -> asyncio.Queue of progress messages
# that the /progress/{job_id} SSE endpoint drains.
active_jobs = dict()
def run_article_workflow(job_id: str, topic: str) -> None:
    """Run the article-generation flow for one job.

    Scheduled by ``start_job`` via FastAPI ``BackgroundTasks``. ``start_job``
    registers the job's queue in ``active_jobs`` before scheduling this call.
    The queue is handed to the flow as ``shared["sse_queue"]``. The flow
    presumably posts its own progress messages there; verify in flow.py.

    Args:
        job_id: Key of the job in ``active_jobs``; its value is the job's queue.
        topic: Article topic, passed to the flow as ``shared["topic"]``.

    Any exception raised while running the flow is not propagated. It is
    logged with its traceback and forwarded to the job's queue as a
    ``{"step": "error", "progress": 0, "data": {"error": ...}}`` message,
    provided the job is still registered.
    """
    # NOTE(review): as a plain ``def`` background task, this presumably runs in
    # Starlette's worker threadpool, while the queue is an ``asyncio.Queue``
    # created on the event loop. ``asyncio.Queue`` is not thread-safe, so
    # ``put_nowait`` from here (and from the flow nodes) may not promptly wake
    # the awaiting SSE consumer. Consider
    # ``loop.call_soon_threadsafe(queue.put_nowait, msg)``. Confirm.
    try:
        # The queue was created and registered by start_job before scheduling.
        sse_queue = active_jobs[job_id]
        shared = {
            "topic": topic,
            "sse_queue": sse_queue,
            "sections": [],
            "draft": "",
            "final_article": "",
        }
        flow = create_article_flow()
        flow.run(shared)
    except Exception as e:
        # Keep the full traceback server-side; the client only gets a summary.
        logging.getLogger(__name__).exception(
            "Article workflow failed for job %s", job_id
        )
        # str() of some exceptions is empty (e.g. TimeoutError()); fall back
        # to the exception type name so the client sees something useful.
        error_text = str(e) or type(e).__name__
        error_msg = {"step": "error", "progress": 0, "data": {"error": error_text}}
        if job_id in active_jobs:
            active_jobs[job_id].put_nowait(error_msg)
@app.post("/start-job")
async def start_job(background_tasks: BackgroundTasks, topic: str = Form(...)):
    """Create a job for ``topic`` and schedule its article workflow.

    The job's progress queue is registered in ``active_jobs`` before the
    background task is scheduled. That way ``run_article_workflow`` and any
    ``/progress/{job_id}`` stream can look it up right away.

    Returns:
        A dict with the new ``job_id``, the submitted ``topic`` and
        ``status`` set to ``"started"``.
    """
    new_job_id = f"{uuid.uuid4()}"
    active_jobs[new_job_id] = asyncio.Queue()
    background_tasks.add_task(run_article_workflow, new_job_id, topic)
    return dict(job_id=new_job_id, topic=topic, status="started")
@app.get("/progress/{job_id}")
async def get_progress(job_id: str):
    """Stream a job's progress to the client as Server-Sent Events.

    Every event is a single ``data: <json>`` frame followed by a blank line.
    The stream proceeds as follows:

    - It first sends a ``connected`` message.
    - It then relays every message taken from the job's queue.
    - It emits ``{"heartbeat": true}`` after each second with no update.
    - It ends after relaying a ``complete`` or ``error`` message. At that point
      the job is removed from ``active_jobs``; no further updates will arrive.

    An unknown ``job_id`` receives a single ``{"error": "Job not found"}``
    frame. Unexpected failures while streaming are reported as a final
    ``{"error": ...}`` frame.
    """
    # Steps after which the workflow sends nothing more. Stopping on "error"
    # too prevents an endless heartbeat stream and a leaked active_jobs entry.
    terminal_steps = ("complete", "error")

    def sse_frame(payload) -> str:
        """Encode ``payload`` as one SSE ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    async def event_stream():
        sse_queue = active_jobs.get(job_id)
        if sse_queue is None:
            yield sse_frame({"error": "Job not found"})
            return
        yield sse_frame(
            {
                "step": "connected",
                "progress": 0,
                "data": {"message": "Connected to job progress"},
            }
        )
        try:
            while True:
                try:
                    # Bounded wait so a heartbeat can be sent while idle.
                    progress_msg = await asyncio.wait_for(sse_queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    yield sse_frame({"heartbeat": True})
                    continue
                yield sse_frame(progress_msg)
                if progress_msg.get("step") in terminal_steps:
                    # pop() instead of del: another stream for the same job
                    # may already have removed the entry.
                    active_jobs.pop(job_id, None)
                    break
        except Exception as e:
            yield sse_frame({"error": str(e)})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
@app.get("/")
async def get_index():
    """Return the main page, ``static/index.html``."""
    return FileResponse(path="static/index.html")
@app.get("/progress.html")
async def get_progress_page():
    """Return the progress page, ``static/progress.html``."""
    return FileResponse(path="static/progress.html")
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Listen on all network interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")