diff --git a/cookbook/pocketflow-text2sql/README.md b/cookbook/pocketflow-text2sql/README.md
new file mode 100644
index 0000000..9224fe5
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/README.md
@@ -0,0 +1,165 @@
+# Text-to-SQL Workflow
+
+A PocketFlow example demonstrating a text-to-SQL workflow that converts natural language questions into executable SQL queries for an SQLite database, including an LLM-powered debugging loop for failed queries.
+
+## Features
+
+- **Schema Awareness**: Automatically retrieves the database schema to provide context to the LLM.
+- **LLM-Powered SQL Generation**: Uses an LLM (GPT-4o) to translate natural language questions into SQLite queries (using YAML structured output).
+- **Automated Debugging Loop**: If SQL execution fails, an LLM attempts to correct the query based on the error message. This process repeats up to a configurable number of times.
+
+## Getting Started
+
+1. **Install Packages:**
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+2. **Set API Key:**
+   Set the environment variable for your OpenAI API key.
+   ```bash
+   export OPENAI_API_KEY="your-api-key-here"
+   ```
+   *(Replace `"your-api-key-here"` with your actual key.)*
+
+3. **Verify API Key (Optional):**
+   Run a quick check using the utility script. If successful, it will print a short joke.
+   ```bash
+   python utils.py
+   ```
+   *(Note: This requires a valid API key to be set.)*
+
+4. **Run the Default Example:**
+   Execute the main script. This will create the sample `ecommerce.db` if it doesn't exist and run the workflow with a default query.
+   ```bash
+   python main.py
+   ```
+   The default query is:
+   > Show me the names and email addresses of customers from New York
+
+5. **Run a Custom Query:**
+   Provide your own natural language query as command-line arguments after the script name; all arguments are joined into a single query string.
+   ```bash
+   python main.py What is the total stock quantity for products in the 'Accessories' category?
+   ```
+   If the query contains characters your shell treats specially (such as `?` or quotes), wrap it in double quotes so it is passed through unchanged:
+   ```bash
+   python main.py "List orders placed in the last 30 days with status 'shipped'"
+   ```
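+
+You can also drive the workflow from Python instead of the shell. A minimal sketch (assuming you run it from this directory so `main.py` and its imports resolve):
+
+```python
+from main import run_text_to_sql
+
+# Runs the same flow as the CLI and returns the shared store used by the nodes.
+shared = run_text_to_sql(
+    "Show me the names and email addresses of customers from New York",
+    db_path="ecommerce.db",    # created and populated automatically if missing or empty
+    max_debug_retries=3,       # LLM debug attempts allowed after a SQL error
+)
+
+if shared.get("final_error"):
+    print("Failed:", shared["final_error"])
+else:
+    print("Rows:", shared.get("final_result"))
+```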
+
+## How It Works
+
+The workflow uses several nodes connected in a sequence, with a loop for debugging failed SQL queries.
+
+```mermaid
+graph LR
+    A[Get Schema] --> B[Generate SQL]
+    B --> C[Execute SQL]
+    C -- Success --> E[End]
+    C -- SQLite Error --> D{Debug SQL Attempt}
+    D -- Corrected SQL --> C
+    C -- Max Retries Reached --> F[End with Error]
+
+    style E fill:#dff,stroke:#333,stroke-width:2px
+    style F fill:#fdd,stroke:#333,stroke-width:2px
+```
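+
+In code, this diagram maps onto the transitions wired up in [`flow.py`](./flow.py); a condensed sketch:
+
+```python
+from pocketflow import Flow
+from nodes import GetSchema, GenerateSQL, ExecuteSQL, DebugSQL
+
+get_schema, generate_sql, execute_sql, debug_sql = GetSchema(), GenerateSQL(), ExecuteSQL(), DebugSQL()
+
+get_schema >> generate_sql >> execute_sql   # main sequence (default transitions)
+execute_sql - "error_retry" >> debug_sql    # on sqlite3.Error, ask the LLM to fix the SQL
+debug_sql >> execute_sql                    # re-run the corrected query
+
+flow = Flow(start=get_schema)               # success or max retries ends the flow
+```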
+
+**Node Descriptions:**
+
+1. **`GetSchema`**: Connects to the SQLite database (`ecommerce.db` by default) and extracts the schema (table names and columns).
+2. **`GenerateSQL`**: Takes the natural language query and the database schema, prompts the LLM to generate an SQLite query (expecting YAML output with the SQL), and parses the result.
+3. **`ExecuteSQL`**: Attempts to run the generated SQL against the database.
+    * If successful, the results are stored and the flow ends.
+    * If an `sqlite3.Error` occurs (e.g., a syntax error), it captures the error message and triggers the debug loop.
+4. **`DebugSQL`**: If `ExecuteSQL` failed, this node takes the original query, schema, failed SQL, and error message, and prompts the LLM to generate a *corrected* SQL query (again as YAML).
+5. **(Loop)**: The corrected SQL from `DebugSQL` is passed back to `ExecuteSQL` for another attempt.
+6. **(End Conditions)**: The loop continues until `ExecuteSQL` succeeds or the maximum number of debug attempts (default: 3) is reached.
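+
+Both `GenerateSQL` and `DebugSQL` ask the LLM for a fenced YAML block whose single `sql` key holds the query. A minimal sketch of how `nodes.py` reads it (shown here with the fence markers already stripped from the LLM response):
+
+```python
+import yaml
+
+yaml_str = """
+sql: |
+  SELECT category, COUNT(*) AS total_products
+  FROM products
+  GROUP BY category
+"""
+
+# Same parsing as nodes.py: load the YAML, take the 'sql' key, drop any trailing semicolon.
+sql_query = yaml.safe_load(yaml_str)["sql"].strip().rstrip(";")
+print(sql_query)
+```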
+
+## Files
+
+- [`main.py`](./main.py): Main entry point to run the workflow. Handles command-line arguments for the query.
+- [`flow.py`](./flow.py): Defines the PocketFlow `Flow` connecting the different nodes, including the debug loop logic.
+- [`nodes.py`](./nodes.py): Contains the `Node` classes for each step (`GetSchema`, `GenerateSQL`, `ExecuteSQL`, `DebugSQL`).
+- [`utils.py`](./utils.py): Contains the minimal `call_llm` utility function.
+- [`populate_db.py`](./populate_db.py): Script to create and populate the sample `ecommerce.db` SQLite database.
+- [`requirements.txt`](./requirements.txt): Lists Python package dependencies.
+- [`README.md`](./README.md): This file.
+
+## Example Output (Successful Run)
+
+```
+=== Starting Text-to-SQL Workflow ===
+Query: 'total products per category'
+Database: ecommerce.db
+Max Debug Retries on SQL Error: 3
+=============================================
+
+===== DB SCHEMA =====
+
+Table: customers
+  - customer_id (INTEGER)
+  - first_name (TEXT)
+  - last_name (TEXT)
+  - email (TEXT)
+  - registration_date (DATE)
+  - city (TEXT)
+  - country (TEXT)
+
+Table: sqlite_sequence
+  - name ()
+  - seq ()
+
+Table: products
+  - product_id (INTEGER)
+  - name (TEXT)
+  - description (TEXT)
+  - category (TEXT)
+  - price (REAL)
+  - stock_quantity (INTEGER)
+
+Table: orders
+  - order_id (INTEGER)
+  - customer_id (INTEGER)
+  - order_date (TIMESTAMP)
+  - status (TEXT)
+  - total_amount (REAL)
+  - shipping_address (TEXT)
+
+Table: order_items
+  - order_item_id (INTEGER)
+  - order_id (INTEGER)
+  - product_id (INTEGER)
+  - quantity (INTEGER)
+  - price_per_unit (REAL)
+
+=====================
+
+
+===== GENERATED SQL (Attempt 1) =====
+
+SELECT category, COUNT(*) AS total_products
+FROM products
+GROUP BY category
+
+====================================
+
+SQL executed in 0.000 seconds.
+
+===== SQL EXECUTION SUCCESS =====
+
+category | total_products
+-------------------------
+Accessories | 3
+Apparel | 1
+Electronics | 3
+Home Goods | 2
+Sports | 1
+
+=================================
+
+/home/zh2408/.venv/lib/python3.9/site-packages/pocketflow/__init__.py:43: UserWarning: Flow ends: 'None' not found in ['error_retry']
+  if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
+
+=== Workflow Completed Successfully ===
+====================================
+```
+
+The `UserWarning` near the end is expected: on success, `ExecuteSQL` returns no action, and `error_retry` is its only registered transition, so PocketFlow simply notes that the flow ends there.
\ No newline at end of file
diff --git a/cookbook/pocketflow-text2sql/ecommerce.db b/cookbook/pocketflow-text2sql/ecommerce.db
new file mode 100644
index 0000000..2003797
Binary files /dev/null and b/cookbook/pocketflow-text2sql/ecommerce.db differ
diff --git a/cookbook/pocketflow-text2sql/flow.py b/cookbook/pocketflow-text2sql/flow.py
new file mode 100644
index 0000000..7843571
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/flow.py
@@ -0,0 +1,25 @@
+from pocketflow import Flow
+from nodes import GetSchema, GenerateSQL, ExecuteSQL, DebugSQL
+
+def create_text_to_sql_flow():
+    """Creates the text-to-SQL workflow with a debug loop."""
+    get_schema_node = GetSchema()
+    generate_sql_node = GenerateSQL()
+    execute_sql_node = ExecuteSQL()
+    debug_sql_node = DebugSQL()
+
+    # Main sequence using the default transition operator
+    get_schema_node >> generate_sql_node >> execute_sql_node
+
+    # Debug loop: if ExecuteSQL returns "error_retry", go to DebugSQL
+    execute_sql_node - "error_retry" >> debug_sql_node
+
+    # DebugSQL's default transition goes back to ExecuteSQL for another attempt
+    # (equivalent to: debug_sql_node - "default" >> execute_sql_node)
+    debug_sql_node >> execute_sql_node
+
+    # Create the flow
+    text_to_sql_flow = Flow(start=get_schema_node)
+    return text_to_sql_flow
\ No newline at end of file
diff --git a/cookbook/pocketflow-text2sql/main.py b/cookbook/pocketflow-text2sql/main.py
new file mode 100644
index 0000000..6d5503f
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/main.py
@@ -0,0 +1,49 @@
+import sys
+import os
+from flow import create_text_to_sql_flow
+from populate_db import populate_database, DB_FILE
+
+def run_text_to_sql(natural_query, db_path=DB_FILE, max_debug_retries=3):
+    if not os.path.exists(db_path) or os.path.getsize(db_path) == 0:
+        print(f"Database at {db_path} missing or empty. Populating...")
+        populate_database(db_path)
+
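+    # Shared store passed through the flow; nodes add more keys as they run
+    # (e.g. "schema", "generated_sql", "execution_error", "result_columns").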
+    shared = {
+        "db_path": db_path,
+        "natural_query": natural_query,
+        "max_debug_attempts": max_debug_retries,
+        "debug_attempts": 0,
+        "final_result": None,
+        "final_error": None
+    }
+
+    print(f"\n=== Starting Text-to-SQL Workflow ===")
+    print(f"Query: '{natural_query}'")
+    print(f"Database: {db_path}")
+    print(f"Max Debug Retries on SQL Error: {max_debug_retries}")
+    print("=" * 45)
+
+    flow = create_text_to_sql_flow()
+    flow.run(shared)  # Let errors inside the loop be handled by the flow logic
+
+    # Check the final state based on shared data
+    if shared.get("final_error"):
+        print("\n=== Workflow Completed with Error ===")
+        print(f"Error: {shared['final_error']}")
+    elif shared.get("final_result") is not None:
+        print("\n=== Workflow Completed Successfully ===")
+        # Result already printed by the ExecuteSQL node
+    else:
+        # Should not happen if the flow logic covers all end states
+        print("\n=== Workflow Completed (Unknown State) ===")
+
+    print("=" * 36)
+    return shared
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        query = " ".join(sys.argv[1:])
+    else:
+        query = "total products per category"
+
+    run_text_to_sql(query)
\ No newline at end of file
diff --git a/cookbook/pocketflow-text2sql/nodes.py b/cookbook/pocketflow-text2sql/nodes.py
new file mode 100644
index 0000000..723c262
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/nodes.py
@@ -0,0 +1,177 @@
+import sqlite3
+import time
+import yaml  # Nodes parse the LLM's YAML-formatted responses
+from pocketflow import Node
+from utils import call_llm
+
+class GetSchema(Node):
+    def prep(self, shared):
+        return shared["db_path"]
+
+    def exec(self, db_path):
+        conn = sqlite3.connect(db_path)
+        cursor = conn.cursor()
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+        tables = cursor.fetchall()
+        schema = []
+        for table_name_tuple in tables:
+            table_name = table_name_tuple[0]
+            schema.append(f"Table: {table_name}")
+            cursor.execute(f"PRAGMA table_info({table_name});")
+            columns = cursor.fetchall()
+            for col in columns:
+                schema.append(f"  - {col[1]} ({col[2]})")
+            schema.append("")
+        conn.close()
+        return "\n".join(schema).strip()
+
+    def post(self, shared, prep_res, exec_res):
+        shared["schema"] = exec_res
+        print("\n===== DB SCHEMA =====\n")
+        print(exec_res)
+        print("\n=====================\n")
+        # Returning None -> default transition to GenerateSQL
+
+class GenerateSQL(Node):
+    def prep(self, shared):
+        return shared["natural_query"], shared["schema"]
+
+    def exec(self, prep_res):
+        natural_query, schema = prep_res
+        prompt = f"""
+Given SQLite schema:
+{schema}
+
+Question: "{natural_query}"
+
+Respond ONLY with a YAML block containing the SQL query under the key 'sql':
+```yaml
+sql: |
+  SELECT ...
+```"""
+        llm_response = call_llm(prompt)
+        yaml_str = llm_response.split("```yaml")[1].split("```")[0].strip()
+        structured_result = yaml.safe_load(yaml_str)
+        sql_query = structured_result["sql"].strip().rstrip(';')
+        return sql_query
+
+    def post(self, shared, prep_res, exec_res):
+        # exec_res is the parsed SQL query string
+        shared["generated_sql"] = exec_res
+        # Reset debug attempts when *successfully* generating new SQL
+        shared["debug_attempts"] = 0
+        print(f"\n===== GENERATED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
+        print(exec_res)
+        print("\n====================================\n")
+        # Returning None -> default transition to ExecuteSQL
+
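+# ExecuteSQL runs the generated query and decides the next transition: no action
+# on success (the flow ends) or "error_retry" while debug attempts remain.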
+class ExecuteSQL(Node):
+    def prep(self, shared):
+        return shared["db_path"], shared["generated_sql"]
+
+    def exec(self, prep_res):
+        db_path, sql_query = prep_res
+        try:
+            conn = sqlite3.connect(db_path)
+            cursor = conn.cursor()
+            start_time = time.time()
+            cursor.execute(sql_query)
+
+            is_select = sql_query.strip().upper().startswith(("SELECT", "WITH"))
+            if is_select:
+                results = cursor.fetchall()
+                column_names = [desc[0] for desc in cursor.description] if cursor.description else []
+            else:
+                conn.commit()
+                results = f"Query OK. Rows affected: {cursor.rowcount}"
+                column_names = []
+            conn.close()
+            duration = time.time() - start_time
+            print(f"SQL executed in {duration:.3f} seconds.")
+            return (True, results, column_names)
+        except sqlite3.Error as e:
+            print(f"SQLite Error during execution: {e}")
+            if 'conn' in locals() and conn:
+                try:
+                    conn.close()
+                except Exception:
+                    pass
+            return (False, str(e), [])
+
+    def post(self, shared, prep_res, exec_res):
+        success, result_or_error, column_names = exec_res
+
+        if success:
+            shared["final_result"] = result_or_error
+            shared["result_columns"] = column_names
+            print("\n===== SQL EXECUTION SUCCESS =====\n")
+            # Pretty-print SELECT results; other statements return a status string
+            if isinstance(result_or_error, list):
+                if column_names:
+                    print(" | ".join(column_names))
+                    print("-" * (sum(len(str(c)) for c in column_names) + 3 * (len(column_names) - 1)))
+                if not result_or_error:
+                    print("(No results found)")
+                else:
+                    for row in result_or_error:
+                        print(" | ".join(map(str, row)))
+            else:
+                print(result_or_error)
+            print("\n=================================\n")
+            return  # No action -> flow ends on success
+        else:
+            # Execution failed (SQLite error caught in exec)
+            shared["execution_error"] = result_or_error  # Store the error message
+            shared["debug_attempts"] = shared.get("debug_attempts", 0) + 1
+            max_attempts = shared.get("max_debug_attempts", 3)  # Get max attempts from shared
+
+            print(f"\n===== SQL EXECUTION FAILED (Attempt {shared['debug_attempts']}) =====\n")
+            print(f"Error: {shared['execution_error']}")
+            print("=========================================\n")
+
+            if shared["debug_attempts"] >= max_attempts:
+                print(f"Max debug attempts ({max_attempts}) reached. Stopping.")
+                shared["final_error"] = f"Failed to execute SQL after {max_attempts} attempts. Last error: {shared['execution_error']}"
+                return  # No action -> flow ends; main.py reports final_error
+            else:
+                print("Attempting to debug the SQL...")
+                return "error_retry"  # Signal to go to DebugSQL
+
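+# DebugSQL feeds the failed SQL, schema, and error message back to the LLM and
+# stores the corrected query for ExecuteSQL to retry.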
+class DebugSQL(Node):
+    def prep(self, shared):
+        return (
+            shared.get("natural_query"),
+            shared.get("schema"),
+            shared.get("generated_sql"),
+            shared.get("execution_error")
+        )
+
+    def exec(self, prep_res):
+        natural_query, schema, failed_sql, error_message = prep_res
+        prompt = f"""
+The following SQLite SQL query failed:
+```sql
+{failed_sql}
+```
+It was generated for: "{natural_query}"
+Schema:
+{schema}
+Error: "{error_message}"
+
+Provide a corrected SQLite query.
+
+Respond ONLY with a YAML block containing the corrected SQL under the key 'sql':
+```yaml
+sql: |
+  SELECT ... -- corrected query
+```"""
+        llm_response = call_llm(prompt)
+
+        yaml_str = llm_response.split("```yaml")[1].split("```")[0].strip()
+        structured_result = yaml.safe_load(yaml_str)
+        corrected_sql = structured_result["sql"].strip().rstrip(';')
+        return corrected_sql
+
+    def post(self, shared, prep_res, exec_res):
+        # exec_res is the corrected SQL string
+        shared["generated_sql"] = exec_res  # Overwrite with the new attempt
+        shared.pop("execution_error", None)  # Clear the previous error for the next ExecuteSQL attempt
+
+        print(f"\n===== REVISED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
+        print(exec_res)
+        print("\n====================================\n")
\ No newline at end of file
diff --git a/cookbook/pocketflow-text2sql/populate_db.py b/cookbook/pocketflow-text2sql/populate_db.py
new file mode 100644
index 0000000..b1088d5
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/populate_db.py
@@ -0,0 +1,141 @@
+import sqlite3
+import os
+import random
+from datetime import datetime, timedelta
+
+DB_FILE = "ecommerce.db"
+
+def populate_database(db_file=DB_FILE):
+    """Creates and populates the SQLite database."""
+    if os.path.exists(db_file):
+        os.remove(db_file)
+        print(f"Removed existing database: {db_file}")
+
+    conn = sqlite3.connect(db_file)
+    cursor = conn.cursor()
+
+    # Create Tables
+    cursor.execute("""
+    CREATE TABLE customers (
+        customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
+        first_name TEXT NOT NULL,
+        last_name TEXT NOT NULL,
+        email TEXT UNIQUE NOT NULL,
+        registration_date DATE NOT NULL,
+        city TEXT,
+        country TEXT DEFAULT 'USA'
+    );
+    """)
+    print("Created 'customers' table.")
+
+    cursor.execute("""
+    CREATE TABLE products (
+        product_id INTEGER PRIMARY KEY AUTOINCREMENT,
+        name TEXT NOT NULL,
+        description TEXT,
+        category TEXT NOT NULL,
+        price REAL NOT NULL CHECK (price > 0),
+        stock_quantity INTEGER NOT NULL DEFAULT 0 CHECK (stock_quantity >= 0)
+    );
+    """)
+    print("Created 'products' table.")
+
+    cursor.execute("""
+    CREATE TABLE orders (
+        order_id INTEGER PRIMARY KEY AUTOINCREMENT,
+        customer_id INTEGER NOT NULL,
+        order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+        status TEXT NOT NULL CHECK (status IN ('pending', 'processing', 'shipped', 'delivered', 'cancelled')),
+        total_amount REAL,
+        shipping_address TEXT,
+        FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
+    );
+    """)
+    print("Created 'orders' table.")
+
+    cursor.execute("""
+    CREATE TABLE order_items (
+        order_item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+        order_id INTEGER NOT NULL,
+        product_id INTEGER NOT NULL,
+        quantity INTEGER NOT NULL CHECK (quantity > 0),
+        price_per_unit REAL NOT NULL,
+        FOREIGN KEY (order_id) REFERENCES orders (order_id),
+        FOREIGN KEY (product_id) REFERENCES products (product_id)
+    );
+    """)
+    print("Created 'order_items' table.")
+
+    # Insert Sample Data
+    customers_data = [
+        ('Alice', 'Smith', 'alice.s@email.com', '2023-01-15', 'New York', 'USA'),
+        ('Bob', 'Johnson', 'b.johnson@email.com', '2023-02-20', 'Los Angeles', 'USA'),
+        ('Charlie', 'Williams', 'charlie.w@email.com', '2023-03-10', 'Chicago', 'USA'),
+        ('Diana', 'Brown', 'diana.b@email.com', '2023-04-05', 'Houston', 'USA'),
+        ('Ethan', 'Davis', 'ethan.d@email.com', '2023-05-12', 'Phoenix', 'USA'),
+        ('Fiona', 'Miller', 'fiona.m@email.com', '2023-06-18', 'Philadelphia', 'USA'),
+        ('George', 'Wilson', 'george.w@email.com', '2023-07-22', 'San Antonio', 'USA'),
+        ('Hannah', 'Moore', 'hannah.m@email.com', '2023-08-30', 'San Diego', 'USA'),
+        ('Ian', 'Taylor', 'ian.t@email.com', '2023-09-05', 'Dallas', 'USA'),
+        ('Julia', 'Anderson', 'julia.a@email.com', '2023-10-11', 'San Jose', 'USA')
+    ]
+    cursor.executemany("INSERT INTO customers (first_name, last_name, email, registration_date, city, country) VALUES (?, ?, ?, ?, ?, ?)", customers_data)
+    print(f"Inserted {len(customers_data)} customers.")
+
+    products_data = [
+        ('Laptop Pro', 'High-end laptop for professionals', 'Electronics', 1200.00, 50),
+        ('Wireless Mouse', 'Ergonomic wireless mouse', 'Accessories', 25.50, 200),
+        ('Mechanical Keyboard', 'RGB backlit mechanical keyboard', 'Accessories', 75.00, 150),
+        ('4K Monitor', '27-inch 4K UHD Monitor', 'Electronics', 350.00, 80),
+        ('Smartphone X', 'Latest generation smartphone', 'Electronics', 999.00, 120),
+        ('Coffee Maker', 'Drip coffee maker', 'Home Goods', 50.00, 300),
+        ('Running Shoes', 'Comfortable running shoes', 'Apparel', 90.00, 250),
+        ('Yoga Mat', 'Eco-friendly yoga mat', 'Sports', 30.00, 400),
+        ('Desk Lamp', 'Adjustable LED desk lamp', 'Home Goods', 45.00, 180),
+        ('Backpack', 'Durable backpack for travel', 'Accessories', 60.00, 220)
+    ]
+    cursor.executemany("INSERT INTO products (name, description, category, price, stock_quantity) VALUES (?, ?, ?, ?, ?)", products_data)
+    print(f"Inserted {len(products_data)} products.")
+
+    orders_data = []
+    start_date = datetime.now() - timedelta(days=60)
+    order_statuses = ['pending', 'processing', 'shipped', 'delivered', 'cancelled']
+    for i in range(1, 21):  # Create 20 orders
+        customer_id = random.randint(1, 10)
+        order_date = start_date + timedelta(days=random.randint(0, 59), hours=random.randint(0, 23))
+        status = random.choice(order_statuses)
+        shipping_address = f"{random.randint(100, 999)} Main St, Anytown"
+        orders_data.append((customer_id, order_date.strftime('%Y-%m-%d %H:%M:%S'), status, None, shipping_address))  # Total amount calculated later
+
+    cursor.executemany("INSERT INTO orders (customer_id, order_date, status, total_amount, shipping_address) VALUES (?, ?, ?, ?, ?)", orders_data)
+    print(f"Inserted {len(orders_data)} orders.")
+
+    order_items_data = []
+    order_totals = {}  # Keep track of totals per order
+    for order_id in range(1, 21):
+        num_items = random.randint(1, 4)
+        order_total = 0
+        for _ in range(num_items):
+            product_id = random.randint(1, 10)
+            quantity = random.randint(1, 5)
+            # Get product price
+            cursor.execute("SELECT price FROM products WHERE product_id = ?", (product_id,))
+            price_per_unit = cursor.fetchone()[0]
+            order_items_data.append((order_id, product_id, quantity, price_per_unit))
+            order_total += quantity * price_per_unit
+        order_totals[order_id] = round(order_total, 2)
+
+    cursor.executemany("INSERT INTO order_items (order_id, product_id, quantity, price_per_unit) VALUES (?, ?, ?, ?)", order_items_data)
+    print(f"Inserted {len(order_items_data)} order items.")
+
+    # Update order totals
+    for order_id, total_amount in order_totals.items():
+        cursor.execute("UPDATE orders SET total_amount = ? WHERE order_id = ?", (total_amount, order_id))
+    print("Updated order totals.")
+
+    conn.commit()
+    conn.close()
+    print(f"Database '{db_file}' created and populated successfully.")
+
+if __name__ == "__main__":
+    populate_database()
\ No newline at end of file
diff --git a/cookbook/pocketflow-text2sql/requirements.txt b/cookbook/pocketflow-text2sql/requirements.txt
new file mode 100644
index 0000000..aa149b6
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/requirements.txt
@@ -0,0 +1,4 @@
+pocketflow>=0.0.1
+openai>=1.0.0
+pyyaml>=6.0
+# sqlite3 ships with the Python standard library; no separate install is needed
diff --git a/cookbook/pocketflow-text2sql/utils.py b/cookbook/pocketflow-text2sql/utils.py
new file mode 100644
index 0000000..3427cbb
--- /dev/null
+++ b/cookbook/pocketflow-text2sql/utils.py
@@ -0,0 +1,14 @@
+import os
+from openai import OpenAI
+
+def call_llm(prompt):
+    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
+    r = client.chat.completions.create(
+        model="gpt-4o",
+        messages=[{"role": "user", "content": prompt}]
+    )
+    return r.choices[0].message.content
+
+# Example usage
+if __name__ == "__main__":
+    print(call_llm("Tell me a short joke"))
\ No newline at end of file