add text-to-sql

zachary62 2025-04-23 13:29:53 -04:00
parent 2b45e97514
commit 1f779bfdf7
8 changed files with 575 additions and 0 deletions


@ -0,0 +1,165 @@
# Text-to-SQL Workflow
A PocketFlow example demonstrating a text-to-SQL workflow that converts natural language questions into executable SQL queries for an SQLite database, including an LLM-powered debugging loop for failed queries.
## Features
- **Schema Awareness**: Automatically retrieves the database schema to provide context to the LLM.
- **LLM-Powered SQL Generation**: Uses an LLM (GPT-4o) to translate natural language questions into SQLite queries (using YAML structured output).
- **Automated Debugging Loop**: If SQL execution fails, an LLM attempts to correct the query based on the error message. This process repeats up to a configurable number of times.
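As a rough illustration of the YAML structured output mentioned above, the fenced block can be pulled out of an LLM reply before it is handed to a YAML parser. This is a minimal sketch and not the code in `nodes.py`: the helper name `extract_yaml_block` and the sample reply are hypothetical, and only the standard library is used.

```python
import re

FENCE = "`" * 3  # three backticks, built here so the snippet nests cleanly in Markdown

def extract_yaml_block(llm_response):
    """Return the text inside the first yaml fence, or None if there is none."""
    pattern = FENCE + r"yaml\s*\n(.*?)" + FENCE
    match = re.search(pattern, llm_response, re.DOTALL)
    return match.group(1).strip() if match else None

# Hypothetical reply in the shape the prompts ask for
reply = f"Here you go:\n{FENCE}yaml\nsql: |\n  SELECT name FROM products\n{FENCE}\n"
print(extract_yaml_block(reply))
print(extract_yaml_block("no fenced block here"))
```

The returned string would then go to `yaml.safe_load`. A `None` result can be treated as a failed generation rather than surfacing as an `IndexError` from a bare `split`.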
## Getting Started
1. **Install Packages:**
```bash
pip install -r requirements.txt
```
2. **Set API Key:**
Set the environment variable for your OpenAI API key.
```bash
export OPENAI_API_KEY="your-api-key-here"
```
*(Replace `"your-api-key-here"` with your actual key)*
3. **Verify API Key (Optional):**
Run a quick check using the utility script. If successful, it will print a short joke.
```bash
python utils.py
```
*(Note: This requires a valid API key to be set.)*
4. **Run Default Example:**
Execute the main script. This will create the sample `ecommerce.db` if it doesn't exist and run the workflow with a default query.
```bash
python main.py
```
The default query is:
> total products per category
5. **Run Custom Query:**
Provide your own natural language query after the script name. Wrap it in double quotes so the shell passes characters such as `'` and `?` through unchanged:
```bash
python main.py "What is the total stock quantity for products in the 'Accessories' category?"
```
Unquoted words also work, because `main.py` joins all arguments with single spaces, but the shell strips inner quotes first. Another quoted example:
```bash
python main.py "List orders placed in the last 30 days with status 'shipped'"
```
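To see why quoting matters in step 5, the sketch below approximates what `main.py` receives after the shell has split the command line. It assumes POSIX-style quoting as modelled by the standard `shlex` module (which does not model glob expansion of `?`), and the helper name `query_seen_by_main` is hypothetical.

```python
import shlex

def query_seen_by_main(command_line):
    """Approximate the query main.py builds: shell-split, then join the arguments."""
    tokens = shlex.split(command_line)   # e.g. ["python", "main.py", "What", ...]
    return " ".join(tokens[2:])          # skip "python" and "main.py"

unquoted = "python main.py What is the stock for the 'Accessories' category?"
quoted = "python main.py \"List orders with status 'shipped'\""

print(query_seen_by_main(unquoted))  # inner single quotes are consumed by the shell
print(query_seen_by_main(quoted))    # inner single quotes survive
```

In the unquoted case the LLM sees `Accessories` without quotes, which is usually harmless but is not the text that was typed.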
## How It Works
The workflow uses several nodes connected in a sequence, with a loop for debugging failed SQL queries.
```mermaid
graph LR
A[Get Schema] --> B[Generate SQL]
B --> C[Execute SQL]
C -- Success --> E[End]
C -- SQLite Error --> D{Debug SQL Attempt}
D -- Corrected SQL --> C
C -- Max Retries Reached --> F[End with Error]
style E fill:#dff,stroke:#333,stroke-width:2px
style F fill:#fdd,stroke:#333,stroke-width:2px
```
**Node Descriptions:**
1. **`GetSchema`**: Connects to the SQLite database (`ecommerce.db` by default) and extracts the schema (table names and columns).
2. **`GenerateSQL`**: Takes the natural language query and the database schema, prompts the LLM to generate an SQLite query (expecting YAML output with the SQL), and parses the result.
3. **`ExecuteSQL`**: Attempts to run the generated SQL against the database.
* If successful, the results are stored, and the flow ends successfully.
* If an `sqlite3.Error` occurs (e.g., syntax error), it captures the error message and triggers the debug loop.
4. **`DebugSQL`**: If `ExecuteSQL` failed, this node takes the original query, schema, failed SQL, and error message, and prompts the LLM to generate a *corrected* SQL query (again, expecting YAML).
5. **(Loop)**: The corrected SQL from `DebugSQL` is passed back to `ExecuteSQL` for another attempt.
6. **(End Conditions)**: The loop continues until `ExecuteSQL` succeeds or the maximum number of debug attempts (default: 3) is reached.
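The execute/debug loop in steps 3 to 6 can be sketched without PocketFlow or an LLM. This is an illustrative approximation under stated assumptions: `run_with_debug_loop` and `fix_typo` are hypothetical names, and `fix_typo` stands in for the LLM call made by `DebugSQL`.

```python
import sqlite3

def run_with_debug_loop(conn, sql, fix_sql, max_attempts=3):
    """Run sql; after each sqlite3.Error ask fix_sql for a correction.

    Gives up once max_attempts executions have failed, mirroring the
    debug-attempt counter described above.
    """
    attempts = 0
    while True:
        try:
            return conn.execute(sql).fetchall(), attempts
        except sqlite3.Error as exc:
            attempts += 1
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Failed to execute SQL after {max_attempts} attempts. Last error: {exc}"
                ) from exc
            sql = fix_sql(sql, str(exc))  # stand-in for the DebugSQL LLM call

def fix_typo(failed_sql, error_message):
    """Hypothetical fixer: repairs one known typo instead of calling an LLM."""
    return failed_sql.replace("SELEC ", "SELECT ", 1)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (name TEXT, category TEXT)")
conn.execute("INSERT INTO products VALUES ('Yoga Mat', 'Sports')")
rows, failed_attempts = run_with_debug_loop(conn, "SELEC name FROM products", fix_typo)
print(rows, failed_attempts)
conn.close()
```

With the default of 3, the query is executed at most three times and the fixer is consulted at most twice, since the third failure ends the loop.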
## Files
- [`main.py`](./main.py): Main entry point to run the workflow. Handles command-line arguments for the query.
- [`flow.py`](./flow.py): Defines the PocketFlow `Flow` connecting the different nodes, including the debug loop logic.
- [`nodes.py`](./nodes.py): Contains the `Node` classes for each step (`GetSchema`, `GenerateSQL`, `ExecuteSQL`, `DebugSQL`).
- [`utils.py`](./utils.py): Contains the minimal `call_llm` utility function.
- [`populate_db.py`](./populate_db.py): Script to create and populate the sample `ecommerce.db` SQLite database.
- [`requirements.txt`](./requirements.txt): Lists Python package dependencies.
- [`README.md`](./README.md): This file.
## Example Output (Successful Run)
```
=== Starting Text-to-SQL Workflow ===
Query: 'total products per category'
Database: ecommerce.db
Max Debug Retries on SQL Error: 3
=============================================
===== DB SCHEMA =====
Table: customers
- customer_id (INTEGER)
- first_name (TEXT)
- last_name (TEXT)
- email (TEXT)
- registration_date (DATE)
- city (TEXT)
- country (TEXT)
Table: sqlite_sequence
- name ()
- seq ()
Table: products
- product_id (INTEGER)
- name (TEXT)
- description (TEXT)
- category (TEXT)
- price (REAL)
- stock_quantity (INTEGER)
Table: orders
- order_id (INTEGER)
- customer_id (INTEGER)
- order_date (TIMESTAMP)
- status (TEXT)
- total_amount (REAL)
- shipping_address (TEXT)
Table: order_items
- order_item_id (INTEGER)
- order_id (INTEGER)
- product_id (INTEGER)
- quantity (INTEGER)
- price_per_unit (REAL)
=====================
===== GENERATED SQL (Attempt 1) =====
SELECT category, COUNT(*) AS total_products
FROM products
GROUP BY category
====================================
SQL executed in 0.000 seconds.
===== SQL EXECUTION SUCCESS =====
category | total_products
-------------------------
Accessories | 3
Apparel | 1
Electronics | 3
Home Goods | 2
Sports | 1
=================================
/home/zh2408/.venv/lib/python3.9/site-packages/pocketflow/__init__.py:43: UserWarning: Flow ends: 'None' not found in ['error_retry']
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
=== Workflow Completed Successfully ===
====================================
```
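The `sqlite_sequence` entry in the schema dump above is an internal table that SQLite creates for `AUTOINCREMENT` columns, which is why its columns show no declared type. If that noise is unwanted in the prompt, one option is to filter internal tables out. The sketch below is a suggestion rather than what `GetSchema` does, and the helper name `user_tables` is hypothetical.

```python
import sqlite3

def user_tables(conn):
    """List table names, skipping SQLite's internal sqlite_* tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    )
    return [name for (name,) in rows]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE products (product_id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT)"
)
all_tables = [
    name for (name,) in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
]
filtered = user_tables(conn)
print(sorted(all_tables))
print(filtered)
conn.close()
```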

Binary file not shown.


@ -0,0 +1,25 @@
from pocketflow import Flow, Node
from nodes import GetSchema, GenerateSQL, ExecuteSQL, DebugSQL

def create_text_to_sql_flow():
    """Creates the text-to-SQL workflow with a debug loop."""
    get_schema_node = GetSchema()
    generate_sql_node = GenerateSQL()
    execute_sql_node = ExecuteSQL()
    debug_sql_node = DebugSQL()

    # Main sequence, using the default transition operator
    get_schema_node >> generate_sql_node >> execute_sql_node

    # Debug loop: if ExecuteSQL returns "error_retry", go to DebugSQL
    execute_sql_node - "error_retry" >> debug_sql_node

    # DebugSQL returns the default action, so route it back to ExecuteSQL.
    # Equivalent explicit form: debug_sql_node - "default" >> execute_sql_node
    debug_sql_node >> execute_sql_node

    # Create the flow
    text_to_sql_flow = Flow(start=get_schema_node)
    return text_to_sql_flow


@ -0,0 +1,49 @@
import sys
import os
from flow import create_text_to_sql_flow
from populate_db import populate_database, DB_FILE

def run_text_to_sql(natural_query, db_path=DB_FILE, max_debug_retries=3):
    if not os.path.exists(db_path) or os.path.getsize(db_path) == 0:
        print(f"Database at {db_path} missing or empty. Populating...")
        populate_database(db_path)

    shared = {
        "db_path": db_path,
        "natural_query": natural_query,
        "max_debug_attempts": max_debug_retries,
        "debug_attempts": 0,
        "final_result": None,
        "final_error": None
    }

    print("\n=== Starting Text-to-SQL Workflow ===")
    print(f"Query: '{natural_query}'")
    print(f"Database: {db_path}")
    print(f"Max Debug Retries on SQL Error: {max_debug_retries}")
    print("=" * 45)

    flow = create_text_to_sql_flow()
    flow.run(shared)  # Let errors inside the loop be handled by the flow logic

    # Check final state based on shared data
    if shared.get("final_error"):
        print("\n=== Workflow Completed with Error ===")
        print(f"Error: {shared['final_error']}")
    elif shared.get("final_result") is not None:
        print("\n=== Workflow Completed Successfully ===")
        # Result already printed by ExecuteSQL node
    else:
        # Should not happen if flow logic is correct and covers all end states
        print("\n=== Workflow Completed (Unknown State) ===")
    print("=" * 36)
    return shared

if __name__ == "__main__":
    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
    else:
        query = "total products per category"
    run_text_to_sql(query)


@ -0,0 +1,177 @@
import sqlite3
import time
import yaml # Import yaml here as nodes use it
from pocketflow import Node
from utils import call_llm
class GetSchema(Node):
    def prep(self, shared):
        return shared["db_path"]

    def exec(self, db_path):
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schema = []
        for table_name_tuple in tables:
            table_name = table_name_tuple[0]
            schema.append(f"Table: {table_name}")
            cursor.execute(f"PRAGMA table_info({table_name});")
            columns = cursor.fetchall()
            for col in columns:
                schema.append(f" - {col[1]} ({col[2]})")
            schema.append("")
        conn.close()
        return "\n".join(schema).strip()

    def post(self, shared, prep_res, exec_res):
        shared["schema"] = exec_res
        print("\n===== DB SCHEMA =====\n")
        print(exec_res)
        print("\n=====================\n")
        # return "default"
class GenerateSQL(Node):
    def prep(self, shared):
        return shared["natural_query"], shared["schema"]

    def exec(self, prep_res):
        natural_query, schema = prep_res
        prompt = f"""
Given SQLite schema:
{schema}
Question: "{natural_query}"
Respond ONLY with a YAML block containing the SQL query under the key 'sql':
```yaml
sql: |
SELECT ...
```"""
        llm_response = call_llm(prompt)
        yaml_str = llm_response.split("```yaml")[1].split("```")[0].strip()
        structured_result = yaml.safe_load(yaml_str)
        sql_query = structured_result["sql"].strip().rstrip(';')
        return sql_query

    def post(self, shared, prep_res, exec_res):
        # exec_res is now the parsed SQL query string
        shared["generated_sql"] = exec_res
        # Reset debug attempts when *successfully* generating new SQL
        shared["debug_attempts"] = 0
        print(f"\n===== GENERATED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")
        # return "default"
class ExecuteSQL(Node):
    def prep(self, shared):
        return shared["db_path"], shared["generated_sql"]

    def exec(self, prep_res):
        db_path, sql_query = prep_res
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            start_time = time.time()
            cursor.execute(sql_query)

            is_select = sql_query.strip().upper().startswith(("SELECT", "WITH"))
            if is_select:
                results = cursor.fetchall()
                column_names = [desc[0] for desc in cursor.description] if cursor.description else []
            else:
                conn.commit()
                results = f"Query OK. Rows affected: {cursor.rowcount}"
                column_names = []
            duration = time.time() - start_time
            print(f"SQL executed in {duration:.3f} seconds.")
            return (True, results, column_names)
        except sqlite3.Error as e:
            print(f"SQLite Error during execution: {e}")
            return (False, str(e), [])
        finally:
            # Always release the connection, on success or failure
            if conn is not None:
                conn.close()

    def post(self, shared, prep_res, exec_res):
        success, result_or_error, column_names = exec_res

        if success:
            shared["final_result"] = result_or_error
            shared["result_columns"] = column_names
            print("\n===== SQL EXECUTION SUCCESS =====\n")
            # Print SELECT results as a simple pipe-separated table
            if isinstance(result_or_error, list):
                if column_names:
                    print(" | ".join(column_names))
                    print("-" * (sum(len(str(c)) for c in column_names) + 3 * (len(column_names) - 1)))
                if not result_or_error:
                    print("(No results found)")
                else:
                    for row in result_or_error:
                        print(" | ".join(map(str, row)))
            else:
                print(result_or_error)
            print("\n=================================\n")
            return  # No action returned: the flow ends here
        else:
            # Execution failed (SQLite error caught in exec)
            shared["execution_error"] = result_or_error  # Store the error message
            shared["debug_attempts"] = shared.get("debug_attempts", 0) + 1
            max_attempts = shared.get("max_debug_attempts", 3)  # Get max attempts from shared

            print(f"\n===== SQL EXECUTION FAILED (Attempt {shared['debug_attempts']}) =====\n")
            print(f"Error: {shared['execution_error']}")
            print("=========================================\n")

            if shared["debug_attempts"] >= max_attempts:
                print(f"Max debug attempts ({max_attempts}) reached. Stopping.")
                shared["final_error"] = f"Failed to execute SQL after {max_attempts} attempts. Last error: {shared['execution_error']}"
                return
            else:
                print("Attempting to debug the SQL...")
                return "error_retry"  # Signal to go to DebugSQL
class DebugSQL(Node):
    def prep(self, shared):
        return (
            shared.get("natural_query"),
            shared.get("schema"),
            shared.get("generated_sql"),
            shared.get("execution_error")
        )

    def exec(self, prep_res):
        natural_query, schema, failed_sql, error_message = prep_res
        prompt = f"""
The following SQLite SQL query failed:
```sql
{failed_sql}
```
It was generated for: "{natural_query}"
Schema:
{schema}
Error: "{error_message}"
Provide a corrected SQLite query.
Respond ONLY with a YAML block containing the corrected SQL under the key 'sql':
```yaml
sql: |
SELECT ... -- corrected query
```"""
        llm_response = call_llm(prompt)
        yaml_str = llm_response.split("```yaml")[1].split("```")[0].strip()
        structured_result = yaml.safe_load(yaml_str)
        corrected_sql = structured_result["sql"].strip().rstrip(';')
        return corrected_sql

    def post(self, shared, prep_res, exec_res):
        # exec_res is the corrected SQL string
        shared["generated_sql"] = exec_res  # Overwrite with the new attempt
        shared.pop("execution_error", None)  # Clear the previous error for the next ExecuteSQL attempt
        print(f"\n===== REVISED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")


@ -0,0 +1,141 @@
import sqlite3
import os
import random
from datetime import datetime, timedelta

DB_FILE = "ecommerce.db"

def populate_database(db_file=DB_FILE):
    """Creates and populates the SQLite database."""
    if os.path.exists(db_file):
        os.remove(db_file)
        print(f"Removed existing database: {db_file}")

    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()

    # Create Tables
    cursor.execute("""
    CREATE TABLE customers (
        customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
        first_name TEXT NOT NULL,
        last_name TEXT NOT NULL,
        email TEXT UNIQUE NOT NULL,
        registration_date DATE NOT NULL,
        city TEXT,
        country TEXT DEFAULT 'USA'
    );
    """)
    print("Created 'customers' table.")

    cursor.execute("""
    CREATE TABLE products (
        product_id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT NOT NULL,
        description TEXT,
        category TEXT NOT NULL,
        price REAL NOT NULL CHECK (price > 0),
        stock_quantity INTEGER NOT NULL DEFAULT 0 CHECK (stock_quantity >= 0)
    );
    """)
    print("Created 'products' table.")

    cursor.execute("""
    CREATE TABLE orders (
        order_id INTEGER PRIMARY KEY AUTOINCREMENT,
        customer_id INTEGER NOT NULL,
        order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        status TEXT NOT NULL CHECK (status IN ('pending', 'processing', 'shipped', 'delivered', 'cancelled')),
        total_amount REAL,
        shipping_address TEXT,
        FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
    );
    """)
    print("Created 'orders' table.")

    cursor.execute("""
    CREATE TABLE order_items (
        order_item_id INTEGER PRIMARY KEY AUTOINCREMENT,
        order_id INTEGER NOT NULL,
        product_id INTEGER NOT NULL,
        quantity INTEGER NOT NULL CHECK (quantity > 0),
        price_per_unit REAL NOT NULL,
        FOREIGN KEY (order_id) REFERENCES orders (order_id),
        FOREIGN KEY (product_id) REFERENCES products (product_id)
    );
    """)
    print("Created 'order_items' table.")

    # Insert Sample Data
    customers_data = [
        ('Alice', 'Smith', 'alice.s@email.com', '2023-01-15', 'New York', 'USA'),
        ('Bob', 'Johnson', 'b.johnson@email.com', '2023-02-20', 'Los Angeles', 'USA'),
        ('Charlie', 'Williams', 'charlie.w@email.com', '2023-03-10', 'Chicago', 'USA'),
        ('Diana', 'Brown', 'diana.b@email.com', '2023-04-05', 'Houston', 'USA'),
        ('Ethan', 'Davis', 'ethan.d@email.com', '2023-05-12', 'Phoenix', 'USA'),
        ('Fiona', 'Miller', 'fiona.m@email.com', '2023-06-18', 'Philadelphia', 'USA'),
        ('George', 'Wilson', 'george.w@email.com', '2023-07-22', 'San Antonio', 'USA'),
        ('Hannah', 'Moore', 'hannah.m@email.com', '2023-08-30', 'San Diego', 'USA'),
        ('Ian', 'Taylor', 'ian.t@email.com', '2023-09-05', 'Dallas', 'USA'),
        ('Julia', 'Anderson', 'julia.a@email.com', '2023-10-11', 'San Jose', 'USA')
    ]
    cursor.executemany("INSERT INTO customers (first_name, last_name, email, registration_date, city, country) VALUES (?, ?, ?, ?, ?, ?)", customers_data)
    print(f"Inserted {len(customers_data)} customers.")

    products_data = [
        ('Laptop Pro', 'High-end laptop for professionals', 'Electronics', 1200.00, 50),
        ('Wireless Mouse', 'Ergonomic wireless mouse', 'Accessories', 25.50, 200),
        ('Mechanical Keyboard', 'RGB backlit mechanical keyboard', 'Accessories', 75.00, 150),
        ('4K Monitor', '27-inch 4K UHD Monitor', 'Electronics', 350.00, 80),
        ('Smartphone X', 'Latest generation smartphone', 'Electronics', 999.00, 120),
        ('Coffee Maker', 'Drip coffee maker', 'Home Goods', 50.00, 300),
        ('Running Shoes', 'Comfortable running shoes', 'Apparel', 90.00, 250),
        ('Yoga Mat', 'Eco-friendly yoga mat', 'Sports', 30.00, 400),
        ('Desk Lamp', 'Adjustable LED desk lamp', 'Home Goods', 45.00, 180),
        ('Backpack', 'Durable backpack for travel', 'Accessories', 60.00, 220)
    ]
    cursor.executemany("INSERT INTO products (name, description, category, price, stock_quantity) VALUES (?, ?, ?, ?, ?)", products_data)
    print(f"Inserted {len(products_data)} products.")

    orders_data = []
    start_date = datetime.now() - timedelta(days=60)
    order_statuses = ['pending', 'processing', 'shipped', 'delivered', 'cancelled']
    for i in range(1, 21):  # Create 20 orders
        customer_id = random.randint(1, 10)
        order_date = start_date + timedelta(days=random.randint(0, 59), hours=random.randint(0, 23))
        status = random.choice(order_statuses)
        shipping_address = f"{random.randint(100, 999)} Main St, Anytown"
        orders_data.append((customer_id, order_date.strftime('%Y-%m-%d %H:%M:%S'), status, None, shipping_address))  # Total amount calculated later
    cursor.executemany("INSERT INTO orders (customer_id, order_date, status, total_amount, shipping_address) VALUES (?, ?, ?, ?, ?)", orders_data)
    print(f"Inserted {len(orders_data)} orders.")

    order_items_data = []
    order_totals = {}  # Keep track of totals per order
    for order_id in range(1, 21):
        num_items = random.randint(1, 4)
        order_total = 0
        for _ in range(num_items):
            product_id = random.randint(1, 10)
            quantity = random.randint(1, 5)
            # Get product price
            cursor.execute("SELECT price FROM products WHERE product_id = ?", (product_id,))
            price_per_unit = cursor.fetchone()[0]
            order_items_data.append((order_id, product_id, quantity, price_per_unit))
            order_total += quantity * price_per_unit
        order_totals[order_id] = round(order_total, 2)
    cursor.executemany("INSERT INTO order_items (order_id, product_id, quantity, price_per_unit) VALUES (?, ?, ?, ?)", order_items_data)
    print(f"Inserted {len(order_items_data)} order items.")

    # Update order totals
    for order_id, total_amount in order_totals.items():
        cursor.execute("UPDATE orders SET total_amount = ? WHERE order_id = ?", (total_amount, order_id))
    print("Updated order totals.")

    conn.commit()
    conn.close()
    print(f"Database '{db_file}' created and populated successfully.")

if __name__ == "__main__":
    populate_database()


@ -0,0 +1,4 @@
pocketflow>=0.0.1
openai>=1.0.0
pyyaml>=6.0
# sqlite3 is part of the Python standard library and is not installed via pip


@ -0,0 +1,14 @@
import os
from openai import OpenAI

def call_llm(prompt):
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
    r = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}]
    )
    return r.choices[0].message.content

# Example usage
if __name__ == "__main__":
    print(call_llm("Tell me a short joke"))