add guardrail

parent 32e8902c8b
commit 9c093df100

README.md
@@ -0,0 +1,71 @@
# Travel Advisor Chat with Guardrails

A travel-focused chat application built with PocketFlow and OpenAI's GPT-4o model, enhanced with input validation so that only travel-related queries are processed.

## Features

- Travel advisor chatbot that answers questions about destinations, planning, accommodations, and more
- **Topic-specific guardrails** that ensure only travel-related queries are accepted

## Run It

1. Make sure your OpenAI API key is set:

   ```bash
   export OPENAI_API_KEY="your-api-key-here"
   ```

   Alternatively, you can edit `utils.py` to include your API key directly.

2. Install the requirements and run the application:

   ```bash
   pip install -r requirements.txt
   python main.py
   ```

## How It Works

```mermaid
flowchart LR
    user[UserInputNode] -->|validate| guardrail[GuardrailNode]
    guardrail -->|retry| user
    guardrail -->|process| llm[LLMNode]
    llm -->|continue| user
```

The chat application uses:

- A `UserInputNode` that collects user input in its `exec` method
- A `GuardrailNode` that validates whether a query is travel-related, using:
  - Basic validation checks (empty input, input that is too short)
  - LLM-based validation to determine whether the query relates to travel
- An `LLMNode` that processes valid travel queries using GPT-4o with a travel advisor system prompt
- Flow connections that route input through validation before processing, and that retry on non-travel queries (see the wiring excerpt below)
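
The node wiring in `main.py` maps one-to-one onto the diagram above; each quoted action string is the value a node's `post` method returns to select the next node:

```python
user_input_node = UserInputNode()
guardrail_node = GuardrailNode()
llm_node = LLMNode()

user_input_node - "validate" >> guardrail_node
guardrail_node - "retry" >> user_input_node   # invalid input: ask again
guardrail_node - "process" >> llm_node        # valid input: answer it
llm_node - "continue" >> user_input_node      # loop for the next turn

flow = Flow(start=user_input_node)
```
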
## Files

- [`main.py`](./main.py): Implementation of the nodes and the chat flow
- [`utils.py`](./utils.py): Utilities for calling the OpenAI API

## Example Outputs

```
Welcome to the Travel Advisor Chat! Type 'exit' to end the conversation.

You: Plan my trip to Thailand

Travel Advisor: Thailand offers a blend of vibrant cities, serene beaches, and rich culture. Begin in Bangkok to explore the Grand Palace and local markets. Head north to Chiang Mai for temples and elephant sanctuaries. Fly south to the islands; Phuket or Koh Samui for beaches, diving, and nightlife. Consider visiting during the cool season (November-February) for pleasant weather. Accommodation ranges from budget hostels to luxury resorts. Internal flights, trains, and buses connect major destinations. Don’t miss local cuisine, from street food to fine dining. Ensure your passport is valid for six months and consider travel insurance for peace of mind. Enjoy!

You: exit

Goodbye! Safe travels!
```

```
Welcome to the Travel Advisor Chat! Type 'exit' to end the conversation.

You: How to study large language models?

Travel Advisor: The query is not related to travel advice, destinations, planning, or other travel topics. It is about studying large language models, which is a topic related to artificial intelligence and machine learning.

You: exit

Goodbye! Safe travels!
```

main.py
@@ -0,0 +1,130 @@
import yaml

from pocketflow import Node, Flow

from utils import call_llm


class UserInputNode(Node):
    def prep(self, shared):
        # Initialize the message history on the first run
        if "messages" not in shared:
            shared["messages"] = []
            print("Welcome to the Travel Advisor Chat! Type 'exit' to end the conversation.")

        return None

    def exec(self, _):
        # Get the user's input
        user_input = input("\nYou: ")
        return user_input

    def post(self, shared, prep_res, exec_res):
        user_input = exec_res

        # Check whether the user wants to exit
        if user_input and user_input.lower() == 'exit':
            print("\nGoodbye! Safe travels!")
            return None  # End the conversation

        # Store the user input in the shared store
        shared["user_input"] = user_input

        # Move on to guardrail validation
        return "validate"


class GuardrailNode(Node):
    def prep(self, shared):
        # Get the user input from the shared store
        user_input = shared.get("user_input", "")
        return user_input

    def exec(self, user_input):
        # Basic validation checks
        if not user_input or user_input.strip() == "":
            return False, "Your query is empty. Please provide a travel-related question."

        if len(user_input.strip()) < 3:
            return False, "Your query is too short. Please provide more details about your travel question."

        # LLM-based validation for travel topics
        prompt = f"""
Evaluate if the following user query is related to travel advice, destinations, planning, or other travel topics.
The chat should ONLY answer travel-related questions and reject any off-topic, harmful, or inappropriate queries.
User query: {user_input}
Return your evaluation in YAML format:
```yaml
valid: true/false
reason: [Explain why the query is valid or invalid]
```"""

        # Call the LLM with the validation prompt
        messages = [{"role": "user", "content": prompt}]
        response = call_llm(messages)

        # Extract the YAML block from the response, if one is fenced
        yaml_content = response.split("```yaml")[1].split("```")[0].strip() if "```yaml" in response else response

        result = yaml.safe_load(yaml_content)
        assert result is not None, "Error: Invalid YAML format"
        assert "valid" in result and "reason" in result, "Error: Invalid YAML format"
        is_valid = result.get("valid", False)
        reason = result.get("reason", "Missing reason in YAML response")

        return is_valid, reason
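
    # For reference, a guardrail verdict parsed above typically looks like
    # this (illustrative example, not captured model output):
    #
    #   valid: false
    #   reason: The query is about machine learning, not travel.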

    def post(self, shared, prep_res, exec_res):
        is_valid, message = exec_res

        if not is_valid:
            # Show the rejection reason to the user and ask again
            print(f"\nTravel Advisor: {message}")
            return "retry"

        # Valid input: add it to the message history
        shared["messages"].append({"role": "user", "content": shared["user_input"]})
        # Proceed to LLM processing
        return "process"


class LLMNode(Node):
    def prep(self, shared):
        # Add the system message if it is not already present
        if not any(msg.get("role") == "system" for msg in shared["messages"]):
            shared["messages"].insert(0, {
                "role": "system",
                "content": "You are a helpful travel advisor that provides information about destinations, travel planning, accommodations, transportation, activities, and other travel-related topics. Only respond to travel-related queries and keep responses informative and friendly. Keep your responses concise, around 100 words."
            })

        # Return the full message history for the LLM
        return shared["messages"]

    def exec(self, messages):
        # Call the LLM with the entire conversation history
        response = call_llm(messages)
        return response

    def post(self, shared, prep_res, exec_res):
        # Print the assistant's response
        print(f"\nTravel Advisor: {exec_res}")

        # Add the assistant message to the history
        shared["messages"].append({"role": "assistant", "content": exec_res})

        # Loop back to continue the conversation
        return "continue"


# Create the nodes
user_input_node = UserInputNode()
guardrail_node = GuardrailNode()
llm_node = LLMNode()

# Create the flow connections; each quoted action matches a value returned by a node's post()
user_input_node - "validate" >> guardrail_node
guardrail_node - "retry" >> user_input_node  # Loop back if the input is invalid
guardrail_node - "process" >> llm_node
llm_node - "continue" >> user_input_node  # Continue the conversation

flow = Flow(start=user_input_node)

# Start the chat
if __name__ == "__main__":
    shared = {}
    flow.run(shared)

requirements.txt
@@ -0,0 +1,3 @@
pocketflow>=0.0.1
openai>=1.0.0
pyyaml>=5.1  # required by main.py's yaml.safe_load

utils.py
@@ -0,0 +1,20 @@
import os

from openai import OpenAI


def call_llm(messages):
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        temperature=0.7
    )

    return response.choices[0].message.content
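
# Note: call_llm performs no retry or error handling. A caller could add a
# simple guard like the following (illustrative sketch; openai.APIError is
# the v1 SDK's base API exception):
#
#     import openai
#     try:
#         reply = call_llm(messages)
#     except openai.APIError as e:
#         reply = f"(LLM call failed: {e})"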

if __name__ == "__main__":
    # Test the LLM call
    messages = [{"role": "user", "content": "In a few words, what's the meaning of life?"}]
    response = call_llm(messages)
    print(f"Prompt: {messages[0]['content']}")
    print(f"Response: {response}")

@@ -57,7 +57,6 @@ reasons:
    # Extract YAML content
    yaml_content = response.split("```yaml")[1].split("```")[0].strip() if "```yaml" in response else response
    result = yaml.safe_load(yaml_content)


    return (filename, result)