add text-to-sql

parent 2b45e97514
commit 1f779bfdf7

README.md
@@ -0,0 +1,165 @@
# Text-to-SQL Workflow

A PocketFlow example demonstrating a text-to-SQL workflow that converts natural language questions into executable SQL queries for an SQLite database, including an LLM-powered debugging loop for failed queries.

## Features

- **Schema Awareness**: Automatically retrieves the database schema to provide context to the LLM.
- **LLM-Powered SQL Generation**: Uses an LLM (GPT-4o) to translate natural language questions into SQLite queries, requesting YAML structured output (see the sketch after this list).
- **Automated Debugging Loop**: If SQL execution fails, an LLM attempts to correct the query based on the error message. This process repeats up to a configurable number of times.
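
The structured-output contract is small: the LLM is asked to reply with a fenced YAML block holding the query under the key `sql`, which `GenerateSQL` then extracts and parses. A minimal sketch of that round trip, using the same extraction logic as `nodes.py`; the sample reply string below is illustrative, not a real model response:

```python
import yaml

# FENCE keeps literal triple backticks out of this README snippet;
# a real LLM reply contains them directly.
FENCE = "```"

# Illustrative LLM reply, in the format the prompt in nodes.py requests.
llm_response = (
    f"{FENCE}yaml\n"
    "sql: |\n"
    "  SELECT first_name, email\n"
    "  FROM customers\n"
    "  WHERE city = 'New York'\n"
    f"{FENCE}"
)

# Same extraction GenerateSQL performs: take the fenced YAML block,
# parse it, and strip any trailing semicolon before execution.
yaml_str = llm_response.split(f"{FENCE}yaml")[1].split(FENCE)[0].strip()
sql_query = yaml.safe_load(yaml_str)["sql"].strip().rstrip(";")
print(sql_query)
```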
## Getting Started

1. **Install Packages:**

   ```bash
   pip install -r requirements.txt
   ```

2. **Set API Key:**

   Set the environment variable for your OpenAI API key:

   ```bash
   export OPENAI_API_KEY="your-api-key-here"
   ```

   *(Replace `"your-api-key-here"` with your actual key.)*

3. **Verify API Key (Optional):**

   Run a quick check using the utility script. If successful, it prints a short joke:

   ```bash
   python utils.py
   ```

   *(Note: This requires a valid API key to be set.)*

4. **Run Default Example:**

   Execute the main script. This creates the sample `ecommerce.db` if it doesn't exist and runs the workflow with a default query:

   ```bash
   python main.py
   ```

   The default query is:

   > Show me the names and email addresses of customers from New York

5. **Run Custom Query:**

   Provide your own natural language query as command-line arguments after the script name. `main.py` joins all arguments into a single query string, so spaces do not need quoting:

   ```bash
   python main.py What is the total stock quantity for products in the 'Accessories' category?
   ```

   If your query contains characters your shell treats specially (quotes, `?`, `*`), wrap the whole query in quotes so it reaches the script unchanged:

   ```bash
   python main.py "List orders placed in the last 30 days with status 'shipped'"
   ```
## How It Works

The workflow uses several nodes connected in a sequence, with a loop for debugging failed SQL queries.

```mermaid
graph LR
    A[Get Schema] --> B[Generate SQL]
    B --> C[Execute SQL]
    C -- Success --> E[End]
    C -- SQLite Error --> D{Debug SQL Attempt}
    D -- Corrected SQL --> C
    C -- Max Retries Reached --> F[End with Error]

    style E fill:#dff,stroke:#333,stroke-width:2px
    style F fill:#fdd,stroke:#333,stroke-width:2px
```
**Node Descriptions:**

1. **`GetSchema`**: Connects to the SQLite database (`ecommerce.db` by default) and extracts the schema (table names and columns).
2. **`GenerateSQL`**: Takes the natural language query and the database schema, prompts the LLM to generate an SQLite query (expecting YAML output with the SQL), and parses the result.
3. **`ExecuteSQL`**: Attempts to run the generated SQL against the database.
   * If successful, the results are stored, and the flow ends successfully.
   * If an `sqlite3.Error` occurs (e.g., a syntax error), it captures the error message and triggers the debug loop.
4. **`DebugSQL`**: If `ExecuteSQL` failed, this node takes the original query, schema, failed SQL, and error message, and prompts the LLM to generate a *corrected* SQL query (again expecting YAML).
5. **(Loop)**: The corrected SQL from `DebugSQL` is passed back to `ExecuteSQL` for another attempt.
6. **(End Conditions)**: The loop continues until `ExecuteSQL` succeeds or the maximum number of debug attempts (default: 3) is reached; both limits live in the shared store sketched below.
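
The nodes communicate through PocketFlow's shared store. A minimal sketch of the keys this example uses, mirroring the dictionary that `main.py` builds before calling `flow.run(shared)` (the values shown are just the defaults from this example):

```python
# Keys the nodes read and write in the shared store (see main.py / nodes.py).
shared = {
    "db_path": "ecommerce.db",                       # read by GetSchema / ExecuteSQL
    "natural_query": "total products per category",  # the user's question
    "max_debug_attempts": 3,                         # upper bound for the debug loop
    "debug_attempts": 0,                             # incremented by ExecuteSQL on failure
    "final_result": None,                            # filled on success
    "final_error": None,                             # filled when retries are exhausted
}

# Written along the way by the nodes themselves:
#   shared["schema"]          <- GetSchema
#   shared["generated_sql"]   <- GenerateSQL / DebugSQL
#   shared["execution_error"] <- ExecuteSQL (on failure)
#   shared["result_columns"]  <- ExecuteSQL (on success)
```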
## Files

- [`main.py`](./main.py): Main entry point to run the workflow. Handles command-line arguments for the query (a programmatic usage sketch follows this list).
- [`flow.py`](./flow.py): Defines the PocketFlow `Flow` connecting the different nodes, including the debug loop logic.
- [`nodes.py`](./nodes.py): Contains the `Node` classes for each step (`GetSchema`, `GenerateSQL`, `ExecuteSQL`, `DebugSQL`).
- [`utils.py`](./utils.py): Contains the minimal `call_llm` utility function.
- [`populate_db.py`](./populate_db.py): Script to create and populate the sample `ecommerce.db` SQLite database.
- [`requirements.txt`](./requirements.txt): Lists Python package dependencies.
- [`README.md`](./README.md): This file.
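
The command line is just a thin wrapper; if you would rather drive the workflow from your own Python code, the entry point in `main.py` can be imported directly. A minimal sketch (the question string is only an example):

```python
from main import run_text_to_sql

# Runs the same GetSchema -> GenerateSQL -> ExecuteSQL (-> DebugSQL) flow
# and returns the shared store with the results or the final error.
shared = run_text_to_sql(
    "Which five products have the lowest stock quantity?",
    db_path="ecommerce.db",   # created and populated automatically on first run
    max_debug_retries=3,      # how many LLM repair attempts on SQL errors
)

print(shared.get("final_result") or shared.get("final_error"))
```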
## Example Output (Successful Run)

```
=== Starting Text-to-SQL Workflow ===
Query: 'total products per category'
Database: ecommerce.db
Max Debug Retries on SQL Error: 3
=============================================

===== DB SCHEMA =====

Table: customers
  - customer_id (INTEGER)
  - first_name (TEXT)
  - last_name (TEXT)
  - email (TEXT)
  - registration_date (DATE)
  - city (TEXT)
  - country (TEXT)

Table: sqlite_sequence
  - name ()
  - seq ()

Table: products
  - product_id (INTEGER)
  - name (TEXT)
  - description (TEXT)
  - category (TEXT)
  - price (REAL)
  - stock_quantity (INTEGER)

Table: orders
  - order_id (INTEGER)
  - customer_id (INTEGER)
  - order_date (TIMESTAMP)
  - status (TEXT)
  - total_amount (REAL)
  - shipping_address (TEXT)

Table: order_items
  - order_item_id (INTEGER)
  - order_id (INTEGER)
  - product_id (INTEGER)
  - quantity (INTEGER)
  - price_per_unit (REAL)

=====================


===== GENERATED SQL (Attempt 1) =====

SELECT category, COUNT(*) AS total_products
FROM products
GROUP BY category

====================================

SQL executed in 0.000 seconds.

===== SQL EXECUTION SUCCESS =====

category | total_products
-------------------------
Accessories | 3
Apparel | 1
Electronics | 3
Home Goods | 2
Sports | 1

=================================

/home/zh2408/.venv/lib/python3.9/site-packages/pocketflow/__init__.py:43: UserWarning: Flow ends: 'None' not found in ['error_retry']
  if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")

=== Workflow Completed Successfully ===
====================================
```
Binary file not shown.

flow.py
@@ -0,0 +1,25 @@
from pocketflow import Flow
from nodes import GetSchema, GenerateSQL, ExecuteSQL, DebugSQL


def create_text_to_sql_flow():
    """Creates the text-to-SQL workflow with a debug loop."""
    get_schema_node = GetSchema()
    generate_sql_node = GenerateSQL()
    execute_sql_node = ExecuteSQL()
    debug_sql_node = DebugSQL()

    # Main flow sequence using the default transition operator
    get_schema_node >> generate_sql_node >> execute_sql_node

    # --- Debug loop connections ---
    # If ExecuteSQL returns "error_retry", go to DebugSQL
    execute_sql_node - "error_retry" >> debug_sql_node

    # If DebugSQL returns "default", go back to ExecuteSQL
    # (debug_sql_node - "default" >> execute_sql_node is the explicit form;
    # the shorthand below is equivalent.)
    debug_sql_node >> execute_sql_node

    # Create the flow
    text_to_sql_flow = Flow(start=get_schema_node)
    return text_to_sql_flow

main.py
@@ -0,0 +1,49 @@
import sys
import os
from flow import create_text_to_sql_flow
from populate_db import populate_database, DB_FILE


def run_text_to_sql(natural_query, db_path=DB_FILE, max_debug_retries=3):
    if not os.path.exists(db_path) or os.path.getsize(db_path) == 0:
        print(f"Database at {db_path} missing or empty. Populating...")
        populate_database(db_path)

    shared = {
        "db_path": db_path,
        "natural_query": natural_query,
        "max_debug_attempts": max_debug_retries,
        "debug_attempts": 0,
        "final_result": None,
        "final_error": None
    }

    print(f"\n=== Starting Text-to-SQL Workflow ===")
    print(f"Query: '{natural_query}'")
    print(f"Database: {db_path}")
    print(f"Max Debug Retries on SQL Error: {max_debug_retries}")
    print("=" * 45)

    flow = create_text_to_sql_flow()
    flow.run(shared)  # Let errors inside the loop be handled by the flow logic

    # Check final state based on shared data
    if shared.get("final_error"):
        print("\n=== Workflow Completed with Error ===")
        print(f"Error: {shared['final_error']}")
    elif shared.get("final_result") is not None:
        print("\n=== Workflow Completed Successfully ===")
        # Result already printed by ExecuteSQL node
    else:
        # Should not happen if flow logic is correct and covers all end states
        print("\n=== Workflow Completed (Unknown State) ===")

    print("=" * 36)
    return shared


if __name__ == "__main__":
    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
    else:
        query = "total products per category"

    run_text_to_sql(query)

nodes.py
@@ -0,0 +1,177 @@
import sqlite3
import time
import yaml  # used to parse the LLM's structured YAML output
from pocketflow import Node
from utils import call_llm


class GetSchema(Node):
    def prep(self, shared):
        return shared["db_path"]

    def exec(self, db_path):
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schema = []
        for table_name_tuple in tables:
            table_name = table_name_tuple[0]
            schema.append(f"Table: {table_name}")
            cursor.execute(f"PRAGMA table_info({table_name});")
            columns = cursor.fetchall()
            for col in columns:
                schema.append(f"  - {col[1]} ({col[2]})")
            schema.append("")
        conn.close()
        return "\n".join(schema).strip()

    def post(self, shared, prep_res, exec_res):
        shared["schema"] = exec_res
        print("\n===== DB SCHEMA =====\n")
        print(exec_res)
        print("\n=====================\n")
        # return "default"


class GenerateSQL(Node):
    def prep(self, shared):
        return shared["natural_query"], shared["schema"]

    def exec(self, prep_res):
        natural_query, schema = prep_res
        prompt = f"""
Given SQLite schema:
{schema}

Question: "{natural_query}"

Respond ONLY with a YAML block containing the SQL query under the key 'sql':
```yaml
sql: |
  SELECT ...
```"""
        llm_response = call_llm(prompt)
        yaml_str = llm_response.split("```yaml")[1].split("```")[0].strip()
        structured_result = yaml.safe_load(yaml_str)
        sql_query = structured_result["sql"].strip().rstrip(';')
        return sql_query

    def post(self, shared, prep_res, exec_res):
        # exec_res is the parsed SQL query string
        shared["generated_sql"] = exec_res
        # Reset debug attempts when *successfully* generating new SQL
        shared["debug_attempts"] = 0
        print(f"\n===== GENERATED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")
        # return "default"


class ExecuteSQL(Node):
    def prep(self, shared):
        return shared["db_path"], shared["generated_sql"]

    def exec(self, prep_res):
        db_path, sql_query = prep_res
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            start_time = time.time()
            cursor.execute(sql_query)

            is_select = sql_query.strip().upper().startswith(("SELECT", "WITH"))
            if is_select:
                results = cursor.fetchall()
                column_names = [desc[0] for desc in cursor.description] if cursor.description else []
            else:
                conn.commit()
                results = f"Query OK. Rows affected: {cursor.rowcount}"
                column_names = []
            conn.close()
            duration = time.time() - start_time
            print(f"SQL executed in {duration:.3f} seconds.")
            return (True, results, column_names)
        except sqlite3.Error as e:
            print(f"SQLite Error during execution: {e}")
            if 'conn' in locals() and conn:
                try:
                    conn.close()
                except Exception:
                    pass
            return (False, str(e), [])

    def post(self, shared, prep_res, exec_res):
        success, result_or_error, column_names = exec_res

        if success:
            shared["final_result"] = result_or_error
            shared["result_columns"] = column_names
            print("\n===== SQL EXECUTION SUCCESS =====\n")
            if isinstance(result_or_error, list):
                if column_names:
                    print(" | ".join(column_names))
                    print("-" * (sum(len(str(c)) for c in column_names) + 3 * (len(column_names) - 1)))
                if not result_or_error:
                    print("(No results found)")
                else:
                    for row in result_or_error:
                        print(" | ".join(map(str, row)))
            else:
                print(result_or_error)
            print("\n=================================\n")
            return
        else:
            # Execution failed (SQLite error caught in exec)
            shared["execution_error"] = result_or_error  # Store the error message
            shared["debug_attempts"] = shared.get("debug_attempts", 0) + 1
            max_attempts = shared.get("max_debug_attempts", 3)  # Get max attempts from shared

            print(f"\n===== SQL EXECUTION FAILED (Attempt {shared['debug_attempts']}) =====\n")
            print(f"Error: {shared['execution_error']}")
            print("=========================================\n")

            if shared["debug_attempts"] >= max_attempts:
                print(f"Max debug attempts ({max_attempts}) reached. Stopping.")
                shared["final_error"] = f"Failed to execute SQL after {max_attempts} attempts. Last error: {shared['execution_error']}"
                return
            else:
                print("Attempting to debug the SQL...")
                return "error_retry"  # Signal to go to DebugSQL


class DebugSQL(Node):
    def prep(self, shared):
        return (
            shared.get("natural_query"),
            shared.get("schema"),
            shared.get("generated_sql"),
            shared.get("execution_error")
        )

    def exec(self, prep_res):
        natural_query, schema, failed_sql, error_message = prep_res
        prompt = f"""
The following SQLite SQL query failed:
```sql
{failed_sql}
```
It was generated for: "{natural_query}"
Schema:
{schema}
Error: "{error_message}"

Provide a corrected SQLite query.

Respond ONLY with a YAML block containing the corrected SQL under the key 'sql':
```yaml
sql: |
  SELECT ... -- corrected query
```"""
        llm_response = call_llm(prompt)

        yaml_str = llm_response.split("```yaml")[1].split("```")[0].strip()
        structured_result = yaml.safe_load(yaml_str)
        corrected_sql = structured_result["sql"].strip().rstrip(';')
        return corrected_sql

    def post(self, shared, prep_res, exec_res):
        # exec_res is the corrected SQL string
        shared["generated_sql"] = exec_res  # Overwrite with the new attempt
        shared.pop("execution_error", None)  # Clear the previous error for the next ExecuteSQL attempt

        print(f"\n===== REVISED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")

populate_db.py
@@ -0,0 +1,141 @@
import sqlite3
import os
import random
from datetime import datetime, timedelta

DB_FILE = "ecommerce.db"


def populate_database(db_file=DB_FILE):
    """Creates and populates the SQLite database."""
    if os.path.exists(db_file):
        os.remove(db_file)
        print(f"Removed existing database: {db_file}")

    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()

    # Create Tables
    cursor.execute("""
    CREATE TABLE customers (
        customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
        first_name TEXT NOT NULL,
        last_name TEXT NOT NULL,
        email TEXT UNIQUE NOT NULL,
        registration_date DATE NOT NULL,
        city TEXT,
        country TEXT DEFAULT 'USA'
    );
    """)
    print("Created 'customers' table.")

    cursor.execute("""
    CREATE TABLE products (
        product_id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT NOT NULL,
        description TEXT,
        category TEXT NOT NULL,
        price REAL NOT NULL CHECK (price > 0),
        stock_quantity INTEGER NOT NULL DEFAULT 0 CHECK (stock_quantity >= 0)
    );
    """)
    print("Created 'products' table.")

    cursor.execute("""
    CREATE TABLE orders (
        order_id INTEGER PRIMARY KEY AUTOINCREMENT,
        customer_id INTEGER NOT NULL,
        order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        status TEXT NOT NULL CHECK (status IN ('pending', 'processing', 'shipped', 'delivered', 'cancelled')),
        total_amount REAL,
        shipping_address TEXT,
        FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
    );
    """)
    print("Created 'orders' table.")

    cursor.execute("""
    CREATE TABLE order_items (
        order_item_id INTEGER PRIMARY KEY AUTOINCREMENT,
        order_id INTEGER NOT NULL,
        product_id INTEGER NOT NULL,
        quantity INTEGER NOT NULL CHECK (quantity > 0),
        price_per_unit REAL NOT NULL,
        FOREIGN KEY (order_id) REFERENCES orders (order_id),
        FOREIGN KEY (product_id) REFERENCES products (product_id)
    );
    """)
    print("Created 'order_items' table.")

    # Insert Sample Data
    customers_data = [
        ('Alice', 'Smith', 'alice.s@email.com', '2023-01-15', 'New York', 'USA'),
        ('Bob', 'Johnson', 'b.johnson@email.com', '2023-02-20', 'Los Angeles', 'USA'),
        ('Charlie', 'Williams', 'charlie.w@email.com', '2023-03-10', 'Chicago', 'USA'),
        ('Diana', 'Brown', 'diana.b@email.com', '2023-04-05', 'Houston', 'USA'),
        ('Ethan', 'Davis', 'ethan.d@email.com', '2023-05-12', 'Phoenix', 'USA'),
        ('Fiona', 'Miller', 'fiona.m@email.com', '2023-06-18', 'Philadelphia', 'USA'),
        ('George', 'Wilson', 'george.w@email.com', '2023-07-22', 'San Antonio', 'USA'),
        ('Hannah', 'Moore', 'hannah.m@email.com', '2023-08-30', 'San Diego', 'USA'),
        ('Ian', 'Taylor', 'ian.t@email.com', '2023-09-05', 'Dallas', 'USA'),
        ('Julia', 'Anderson', 'julia.a@email.com', '2023-10-11', 'San Jose', 'USA')
    ]
    cursor.executemany("INSERT INTO customers (first_name, last_name, email, registration_date, city, country) VALUES (?, ?, ?, ?, ?, ?)", customers_data)
    print(f"Inserted {len(customers_data)} customers.")

    products_data = [
        ('Laptop Pro', 'High-end laptop for professionals', 'Electronics', 1200.00, 50),
        ('Wireless Mouse', 'Ergonomic wireless mouse', 'Accessories', 25.50, 200),
        ('Mechanical Keyboard', 'RGB backlit mechanical keyboard', 'Accessories', 75.00, 150),
        ('4K Monitor', '27-inch 4K UHD Monitor', 'Electronics', 350.00, 80),
        ('Smartphone X', 'Latest generation smartphone', 'Electronics', 999.00, 120),
        ('Coffee Maker', 'Drip coffee maker', 'Home Goods', 50.00, 300),
        ('Running Shoes', 'Comfortable running shoes', 'Apparel', 90.00, 250),
        ('Yoga Mat', 'Eco-friendly yoga mat', 'Sports', 30.00, 400),
        ('Desk Lamp', 'Adjustable LED desk lamp', 'Home Goods', 45.00, 180),
        ('Backpack', 'Durable backpack for travel', 'Accessories', 60.00, 220)
    ]
    cursor.executemany("INSERT INTO products (name, description, category, price, stock_quantity) VALUES (?, ?, ?, ?, ?)", products_data)
    print(f"Inserted {len(products_data)} products.")

    orders_data = []
    start_date = datetime.now() - timedelta(days=60)
    order_statuses = ['pending', 'processing', 'shipped', 'delivered', 'cancelled']
    for i in range(1, 21):  # Create 20 orders
        customer_id = random.randint(1, 10)
        order_date = start_date + timedelta(days=random.randint(0, 59), hours=random.randint(0, 23))
        status = random.choice(order_statuses)
        shipping_address = f"{random.randint(100, 999)} Main St, Anytown"
        orders_data.append((customer_id, order_date.strftime('%Y-%m-%d %H:%M:%S'), status, None, shipping_address))  # Total amount calculated later

    cursor.executemany("INSERT INTO orders (customer_id, order_date, status, total_amount, shipping_address) VALUES (?, ?, ?, ?, ?)", orders_data)
    print(f"Inserted {len(orders_data)} orders.")

    order_items_data = []
    order_totals = {}  # Keep track of totals per order
    for order_id in range(1, 21):
        num_items = random.randint(1, 4)
        order_total = 0
        for _ in range(num_items):
            product_id = random.randint(1, 10)
            quantity = random.randint(1, 5)
            # Get product price
            cursor.execute("SELECT price FROM products WHERE product_id = ?", (product_id,))
            price_per_unit = cursor.fetchone()[0]
            order_items_data.append((order_id, product_id, quantity, price_per_unit))
            order_total += quantity * price_per_unit
        order_totals[order_id] = round(order_total, 2)

    cursor.executemany("INSERT INTO order_items (order_id, product_id, quantity, price_per_unit) VALUES (?, ?, ?, ?)", order_items_data)
    print(f"Inserted {len(order_items_data)} order items.")

    # Update order totals
    for order_id, total_amount in order_totals.items():
        cursor.execute("UPDATE orders SET total_amount = ? WHERE order_id = ?", (total_amount, order_id))
    print("Updated order totals.")

    conn.commit()
    conn.close()
    print(f"Database '{db_file}' created and populated successfully.")


if __name__ == "__main__":
    populate_database()

requirements.txt
@@ -0,0 +1,4 @@
pocketflow>=0.0.1
openai>=1.0.0
pyyaml>=6.0
# sqlite3 ships with the Python standard library; no pip install needed

utils.py
@@ -0,0 +1,14 @@
import os
from openai import OpenAI


def call_llm(prompt):
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
    r = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}]
    )
    return r.choices[0].message.content


# Example usage
if __name__ == "__main__":
    print(call_llm("Tell me a short joke"))