Finish WebSocket example

zachary62 2025-05-26 17:34:21 -04:00
parent b8817f08d1
commit 953a506c05
5 changed files with 200 additions and 225 deletions

View File: README.md

@@ -1,135 +1,49 @@
# PocketFlow FastAPI WebSocket Chat

Real-time chat interface with streaming LLM responses using PocketFlow, FastAPI, and WebSocket.

## Features

- **Real-time Streaming**: See AI responses typed out in real-time as the LLM generates them
- **Conversation Memory**: Maintains chat history across messages
- **Modern UI**: Clean, responsive chat interface with gradient design
- **WebSocket Connection**: Persistent connection for instant communication
- **PocketFlow Integration**: Uses PocketFlow `AsyncNode` and `AsyncFlow` for streaming

## How to Run

1. **Set OpenAI API Key:**
```bash
export OPENAI_API_KEY="your-openai-api-key"
```
2. **Install Dependencies:**
```bash
pip install -r requirements.txt
```
3. **Run the Application:**
```bash
python main.py
```
4. **Access the Web UI:**
Navigate to: `http://localhost:8000`
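
If you want to exercise the endpoint without a browser, a small script works too. This is a sketch, not part of the commit: it assumes the third-party `websockets` package (`pip install websockets`) and guesses the route as `/ws`, since neither the route decorator nor the client-side URL appears in the hunks below. The server only reads the `content` field of incoming JSON and streams back events whose `type` is `start`, then `chunk`, then a final completion event.

```python
# Hypothetical smoke test for the chat endpoint (not part of this commit).
# Assumes `pip install websockets` and that the endpoint is mounted at /ws.
import asyncio
import json

import websockets

async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # main.py only reads message.get("content", ""), so one field suffices.
        await ws.send(json.dumps({"content": "Hello!"}))
        while True:
            event = json.loads(await ws.recv())
            if event["type"] == "chunk":
                print(event["content"], end="", flush=True)
            elif event["type"] != "start":
                break  # treat any other event type as end-of-stream
        print()

asyncio.run(main())
```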
## Usage

1. **Type Message**: Enter your message in the input field
2. **Send**: Press Enter or click the Send button
3. **Watch Streaming**: See the AI response appear in real-time
4. **Continue Chat**: Conversation history is maintained automatically (sketched below)
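
Step 4 works because each WebSocket connection owns a single shared store that outlives individual flow runs (see `main.py` below); only `user_message` is overwritten per turn. Illustratively, after two exchanges the store might look like this (values invented; where exactly the user turn is appended is not visible in this diff, presumably in `prep_async`):

```python
# Illustrative per-connection store after two exchanges (invented values).
shared_store = {
    "websocket": ...,               # placeholder for the live FastAPI WebSocket
    "user_message": "And you?",     # overwritten on every incoming message
    "conversation_history": [       # grows across the whole connection
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "And you?"},
    ],
}
```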
## Files

- [`main.py`](./main.py): FastAPI application with WebSocket endpoint
- [`nodes.py`](./nodes.py): PocketFlow `StreamingChatNode` definition
- [`flow.py`](./flow.py): PocketFlow `AsyncFlow` for chat processing (see the sketch after this list)
- [`utils/stream_llm.py`](./utils/stream_llm.py): OpenAI streaming utility
- [`static/index.html`](./static/index.html): Modern chat interface
- [`requirements.txt`](./requirements.txt): Project dependencies
- [`docs/design.md`](./docs/design.md): System design documentation
- [`README.md`](./README.md): This file
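
`flow.py` itself is not among the files this commit touches, so for orientation here is a minimal sketch of what a single-node flow factory could look like; the `AsyncFlow(start=...)` wiring is an assumption based on PocketFlow's documented API, and only `create_streaming_chat_flow` and `StreamingChatNode` are names taken from this diff.

```python
# flow.py (sketch, assuming PocketFlow's AsyncFlow(start=...) constructor):
# one node does everything, so no transitions between nodes are declared.
from pocketflow import AsyncFlow
from nodes import StreamingChatNode

def create_streaming_chat_flow():
    return AsyncFlow(start=StreamingChatNode())
```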

View File: main.py

@@ -15,15 +15,19 @@ async def get_chat_interface():
```python
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    # Initialize conversation history for this connection
    shared_store = {
        "websocket": websocket,
        "conversation_history": []
    }

    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)

            # Update only the current message, keep conversation history
            shared_store["user_message"] = message.get("content", "")

            flow = create_streaming_chat_flow()
            await flow.run_async(shared_store)
```
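
The hunk starts at line 15, so the module-level scaffolding is not shown. Here is a sketch of what presumably sits above it; the imports, route paths, and `FileResponse` choice are educated guesses, and the `try:` in the hunk presumably pairs with an `except WebSocketDisconnect:` below the excerpt.

```python
# Hypothetical top of main.py -- lines 1-14 sit outside the hunk above.
import json  # used by json.loads(...) in the hunk

from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse

from flow import create_streaming_chat_flow

app = FastAPI()

@app.get("/")
async def get_chat_interface():
    # Name taken from the hunk header; serving the static page is a guess.
    return FileResponse("static/index.html")

# The endpoint in the hunk is presumably registered as:
# @app.websocket("/ws")
# async def websocket_endpoint(websocket: WebSocket): ...
```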

View File: nodes.py

@@ -4,7 +4,7 @@ from pocketflow import AsyncNode

```python
from utils.stream_llm import stream_llm

class StreamingChatNode(AsyncNode):
    async def prep_async(self, shared):
        user_message = shared.get("user_message", "")
        websocket = shared.get("websocket")
```

@@ -19,7 +19,7 @@ class StreamingChatNode(AsyncNode):

```python
        await websocket.send_text(json.dumps({"type": "start", "content": ""}))

        full_response = ""
        async for chunk_content in stream_llm(messages):
            full_response += chunk_content
            await websocket.send_text(json.dumps({
                "type": "chunk",
```

@@ -30,11 +30,9 @@

```python
        return full_response, websocket

    async def post_async(self, shared, prep_res, exec_res):
        full_response, websocket = exec_res
        conversation_history = shared.get("conversation_history", [])
        conversation_history.append({"role": "assistant", "content": full_response})
        shared["conversation_history"] = conversation_history
```

View File: static/index.html

@@ -2,82 +2,155 @@
```html
<html>
<head>
    <title>PocketFlow Chat</title>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }

        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            display: flex;
            align-items: center;
            justify-content: center;
            padding: 20px;
        }

        .chat-container {
            background: rgba(255, 255, 255, 0.95);
            backdrop-filter: blur(10px);
            border-radius: 20px;
            width: 100%;
            max-width: 600px;
            height: 80vh;
            display: flex;
            flex-direction: column;
            box-shadow: 0 20px 40px rgba(0,0,0,0.1);
            overflow: hidden;
        }

        .header {
            padding: 20px;
            background: rgba(255, 255, 255, 0.1);
            border-bottom: 1px solid rgba(255, 255, 255, 0.2);
            text-align: center;
        }

        .header h1 {
            font-size: 24px;
            font-weight: 600;
            color: #333;
            margin-bottom: 5px;
        }

        .status {
            font-size: 14px;
            color: #666;
            font-weight: 500;
        }

        .messages {
            flex: 1;
            overflow-y: auto;
            padding: 20px;
            display: flex;
            flex-direction: column;
            gap: 16px;
        }

        .message {
            max-width: 80%;
            padding: 12px 16px;
            border-radius: 18px;
            font-size: 15px;
            line-height: 1.4;
            word-wrap: break-word;
        }

        .user-message {
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white;
            align-self: flex-end;
            border-bottom-right-radius: 4px;
        }

        .ai-message {
            background: #f1f3f4;
            color: #333;
            align-self: flex-start;
            border-bottom-left-radius: 4px;
        }

        .input-container {
            padding: 20px;
            background: rgba(255, 255, 255, 0.1);
            border-top: 1px solid rgba(255, 255, 255, 0.2);
            display: flex;
            gap: 12px;
        }

        #messageInput {
            flex: 1;
            padding: 12px 16px;
            border: none;
            border-radius: 25px;
            background: white;
            font-size: 15px;
            outline: none;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }

        #messageInput::placeholder {
            color: #999;
        }

        #sendButton {
            padding: 12px 24px;
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white;
            border: none;
            border-radius: 25px;
            cursor: pointer;
            font-size: 15px;
            font-weight: 600;
            transition: all 0.2s ease;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }

        #sendButton:hover:not(:disabled) {
            transform: translateY(-1px);
            box-shadow: 0 4px 15px rgba(0,0,0,0.2);
        }

        #sendButton:disabled {
            background: #ccc;
            cursor: not-allowed;
            transform: none;
        }

        .messages::-webkit-scrollbar {
            width: 6px;
        }

        .messages::-webkit-scrollbar-track {
            background: transparent;
        }

        .messages::-webkit-scrollbar-thumb {
            background: rgba(0,0,0,0.2);
            border-radius: 3px;
        }
    </style>
</head>
<body>
    <div class="chat-container">
        <div class="header">
            <h1>PocketFlow Chat</h1>
            <div class="status" id="status">Connecting...</div>
        </div>
        <div class="messages" id="messages"></div>
        <div class="input-container">
            <input type="text" id="messageInput" placeholder="Type your message..." disabled>
            <button id="sendButton" disabled>Send</button>
```
@@ -94,8 +167,8 @@

```javascript
let isStreaming = false;
let currentAiMessage = null;

ws.onopen = function() {
    statusDiv.textContent = 'Connected';
    messageInput.disabled = false;
    sendButton.disabled = false;
    messageInput.focus();
```
@@ -108,12 +181,10 @@

```javascript
        isStreaming = true;
        currentAiMessage = document.createElement('div');
        currentAiMessage.className = 'message ai-message';
        messagesDiv.appendChild(currentAiMessage);
        messagesDiv.scrollTop = messagesDiv.scrollHeight;
        sendButton.disabled = true;
        statusDiv.textContent = 'AI is typing...';
    } else if (data.type === 'chunk') {
        if (currentAiMessage) {
```
@@ -125,34 +196,17 @@

```javascript
        isStreaming = false;
        currentAiMessage = null;
        sendButton.disabled = false;
        statusDiv.textContent = 'Connected';
        messageInput.focus();
    }
};

ws.onclose = function() {
    statusDiv.textContent = 'Disconnected';
    messageInput.disabled = true;
    sendButton.disabled = true;
};

function sendMessage() {
    const message = messageInput.value.trim();
    if (message && !isStreaming) {
```
@@ -168,13 +222,13 @@

```javascript
        }));
        messageInput.value = '';
        statusDiv.textContent = 'Sending...';
    }
}

sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keypress', function(e) {
    if (e.key === 'Enter') {
        e.preventDefault();
        sendMessage();
    }
```

View File: utils/stream_llm.py

@@ -1,22 +1,27 @@
```python
import os
from openai import AsyncOpenAI

async def stream_llm(messages):
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        stream=True,
        temperature=0.7
    )
    async for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            yield chunk.choices[0].delta.content

if __name__ == "__main__":
    import asyncio

    async def test():
        messages = [{"role": "user", "content": "Hello!"}]
        async for chunk in stream_llm(messages):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(test())
```
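
One design note on the utility above: it constructs a fresh `AsyncOpenAI` client on every call, which keeps the example self-contained but redoes client setup per message. A sketch of the obvious alternative, a module-level client reused across requests (behavior otherwise unchanged):

```python
import os
from openai import AsyncOpenAI

# Created once at import time; reused by every streaming request.
_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))

async def stream_llm(messages):
    stream = await _client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        stream=True,
        temperature=0.7,
    )
    async for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            yield chunk.choices[0].delta.content
```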