pocketflow/cookbook/pocketflow-text2sql/nodes.py

177 lines
6.5 KiB
Python

import sqlite3
import time
import yaml # Import yaml here as nodes use it
from pocketflow import Node
from utils.call_llm import call_llm
class GetSchema(Node):
    """Read the database's tables and columns into ``shared["schema"]``."""

    def prep(self, shared):
        """Return the SQLite database path from the shared store."""
        return shared["db_path"]

    def exec(self, db_path):
        """Build a plain-text schema listing for every table in ``db_path``.

        Output format, one block per table separated by a blank line::

            Table: <name>
             - <column> (<declared type>)

        Returns:
            The listing with surrounding whitespace stripped.
        """
        conn = sqlite3.connect(db_path)
        # try/finally so the connection is released even if a query fails.
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [row[0] for row in cursor.fetchall()]
            schema = []
            for table_name in table_names:
                schema.append(f"Table: {table_name}")
                # PRAGMA arguments cannot be bound as parameters, so quote the
                # identifier (doubling embedded quotes). Names containing spaces
                # or quotes, or reserved words, would otherwise break the statement.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                cursor.execute(f"PRAGMA table_info({quoted_name});")
                # table_info rows are (cid, name, type, notnull, dflt_value, pk).
                for col in cursor.fetchall():
                    schema.append(f" - {col[1]} ({col[2]})")
                schema.append("")
        finally:
            conn.close()
        return "\n".join(schema).strip()

    def post(self, shared, prep_res, exec_res):
        """Store the schema text in ``shared["schema"]`` and echo it."""
        shared["schema"] = exec_res
        print("\n===== DB SCHEMA =====\n")
        print(exec_res)
        print("\n=====================\n")
class GenerateSQL(Node):
    """Ask the LLM to translate ``shared["natural_query"]`` into SQLite SQL."""

    def prep(self, shared):
        """Return ``(natural_query, schema)`` from the shared store."""
        return shared["natural_query"], shared["schema"]

    def exec(self, prep_res):
        """Prompt the LLM with the schema and question, and return its SQL.

        Returns:
            The SQL string with surrounding whitespace and trailing ';' removed.

        Raises:
            ValueError: if the reply has no non-empty string under 'sql'.
            yaml.YAMLError: if the extracted text is not valid YAML.
        """
        natural_query, schema = prep_res
        # The example under `sql: |` must be indented. A YAML block scalar with
        # unindented content is invalid, and the LLM imitates this shape.
        prompt = f"""
Given SQLite schema:
{schema}
Question: "{natural_query}"
Respond ONLY with a YAML block containing the SQL query under the key 'sql':
```yaml
sql: |
  SELECT ...
```"""
        llm_response = call_llm(prompt)
        return self._extract_sql(llm_response)

    @staticmethod
    def _extract_sql(llm_response):
        """Parse the 'sql' value out of a (preferably ```yaml-fenced) LLM reply."""
        if "```yaml" in llm_response:
            yaml_str = llm_response.split("```yaml", 1)[1].split("```", 1)[0]
        else:
            # No fence: the model may have replied with bare YAML.
            yaml_str = llm_response
        structured_result = yaml.safe_load(yaml_str.strip())
        sql = structured_result.get("sql") if isinstance(structured_result, dict) else None
        if not isinstance(sql, str) or not sql.strip():
            raise ValueError(f"LLM response did not contain a usable 'sql' entry: {llm_response!r}")
        return sql.strip().rstrip(";").strip()

    def post(self, shared, prep_res, exec_res):
        """Store the generated SQL and restart the debug-attempt counter."""
        shared["generated_sql"] = exec_res
        # Freshly generated SQL starts a new debugging cycle.
        shared["debug_attempts"] = 0
        print(f"\n===== GENERATED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")
class ExecuteSQL(Node):
    """Run ``shared["generated_sql"]`` and route to debugging on failure."""

    def prep(self, shared):
        """Return ``(db_path, generated_sql)`` from the shared store."""
        return shared["db_path"], shared["generated_sql"]

    def exec(self, prep_res):
        """Execute the SQL against the database.

        Returns:
            ``(True, rows, column_names)`` for a statement that returns rows,
            with rows as a list of tuples.
            ``(True, "Query OK. Rows affected: N", [])`` for any other statement.
            ``(False, error_message, [])`` if SQLite rejected the statement.
        """
        db_path, sql_query = prep_res
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            start_time = time.time()
            cursor.execute(sql_query)
            # cursor.description is set for every row-returning statement (even
            # with zero matches), regardless of leading comments or keyword, so
            # it is more reliable than sniffing the SQL text.
            rows = cursor.fetchall() if cursor.description is not None else None
            # Persist writes (including INSERT ... RETURNING). This is a no-op
            # when no transaction is open.
            conn.commit()
            if rows is not None:
                results = rows
                column_names = [desc[0] for desc in cursor.description]
            else:
                results = f"Query OK. Rows affected: {cursor.rowcount}"
                column_names = []
            duration = time.time() - start_time
            print(f"SQL executed in {duration:.3f} seconds.")
            return (True, results, column_names)
        except (sqlite3.Error, sqlite3.Warning) as e:
            # Before Python 3.12, passing several statements raises
            # sqlite3.Warning, which is not a subclass of sqlite3.Error.
            print(f"SQLite Error during execution: {e}")
            return (False, str(e), [])
        finally:
            if conn is not None:
                conn.close()

    @staticmethod
    def _print_result(result, column_names):
        """Print rows as a ' | '-separated table, or print a status message."""
        if not isinstance(result, list):
            print(result)
            return
        if column_names:
            header = " | ".join(column_names)
            print(header)
            print("-" * len(header))
        if not result:
            print("(No results found)")
        for row in result:
            print(" | ".join(map(str, row)))

    def post(self, shared, prep_res, exec_res):
        """Record the outcome and choose the next action.

        Returns:
            None on success or once ``max_debug_attempts`` (default 3) is
            reached, otherwise ``"error_retry"`` to request DebugSQL.
        """
        success, result_or_error, column_names = exec_res
        if success:
            shared["final_result"] = result_or_error
            shared["result_columns"] = column_names
            print("\n===== SQL EXECUTION SUCCESS =====\n")
            self._print_result(result_or_error, column_names)
            print("\n=================================\n")
            return None

        shared["execution_error"] = result_or_error
        shared["debug_attempts"] = shared.get("debug_attempts", 0) + 1
        max_attempts = shared.get("max_debug_attempts", 3)
        print(f"\n===== SQL EXECUTION FAILED (Attempt {shared['debug_attempts']}) =====\n")
        print(f"Error: {shared['execution_error']}")
        print("=========================================\n")
        if shared["debug_attempts"] >= max_attempts:
            print(f"Max debug attempts ({max_attempts}) reached. Stopping.")
            shared["final_error"] = f"Failed to execute SQL after {max_attempts} attempts. Last error: {shared['execution_error']}"
            return None
        print("Attempting to debug the SQL...")
        return "error_retry"  # Signal to go to DebugSQL
class DebugSQL(Node):
    """Ask the LLM to repair SQL that failed in ExecuteSQL."""

    def prep(self, shared):
        """Return ``(natural_query, schema, failed_sql, execution_error)``."""
        return (
            shared.get("natural_query"),
            shared.get("schema"),
            shared.get("generated_sql"),
            shared.get("execution_error"),
        )

    def exec(self, prep_res):
        """Send the failed SQL and its error to the LLM, and return the fix.

        Returns:
            The corrected SQL with surrounding whitespace and trailing ';' removed.

        Raises:
            ValueError: if the reply has no non-empty string under 'sql'.
            yaml.YAMLError: if the extracted text is not valid YAML.
        """
        natural_query, schema, failed_sql, error_message = prep_res
        # The example under `sql: |` must be indented. A YAML block scalar with
        # unindented content is invalid, and the LLM imitates this shape.
        prompt = f"""
The following SQLite SQL query failed:
```sql
{failed_sql}
```
It was generated for: "{natural_query}"
Schema:
{schema}
Error: "{error_message}"
Provide a corrected SQLite query.
Respond ONLY with a YAML block containing the corrected SQL under the key 'sql':
```yaml
sql: |
  SELECT ... -- corrected query
```"""
        llm_response = call_llm(prompt)
        return self._extract_sql(llm_response)

    @staticmethod
    def _extract_sql(llm_response):
        """Parse the 'sql' value out of a (preferably ```yaml-fenced) LLM reply."""
        if "```yaml" in llm_response:
            yaml_str = llm_response.split("```yaml", 1)[1].split("```", 1)[0]
        else:
            # No fence: the model may have replied with bare YAML.
            yaml_str = llm_response
        structured_result = yaml.safe_load(yaml_str.strip())
        sql = structured_result.get("sql") if isinstance(structured_result, dict) else None
        if not isinstance(sql, str) or not sql.strip():
            raise ValueError(f"LLM response did not contain a usable 'sql' entry: {llm_response!r}")
        return sql.strip().rstrip(";").strip()

    def post(self, shared, prep_res, exec_res):
        """Replace the failed SQL with the revision and clear the stale error."""
        shared["generated_sql"] = exec_res  # Overwrite with the new attempt
        # Clear the previous error before the next ExecuteSQL attempt.
        shared.pop("execution_error", None)
        print(f"\n===== REVISED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")