177 lines
6.5 KiB
Python
177 lines
6.5 KiB
Python
import sqlite3
|
|
import time
|
|
import yaml # Import yaml here as nodes use it
|
|
from pocketflow import Node
|
|
from utils import call_llm
|
|
|
|
class GetSchema(Node):
    """Node that reads the table/column layout of a SQLite database.

    Reads ``shared["db_path"]`` and stores a plain-text schema description
    under ``shared["schema"]``.
    """

    def prep(self, shared):
        """Return the database path stored under ``shared["db_path"]``."""
        return shared["db_path"]

    def exec(self, db_path):
        """Build a plain-text description of every table in the database.

        Args:
            db_path: Path handed to ``sqlite3.connect``.

        Returns:
            A string made of one ``Table: <name>`` line per table, followed
            by one `` - <column> (<type>)`` line per column and a blank
            separator line. The joined text is stripped of leading and
            trailing whitespace.

        Raises:
            sqlite3.Error: Propagated unchanged if the database cannot be
                opened or queried; the connection is closed first.
        """
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [row[0] for row in cursor.fetchall()]

            schema = []
            for table_name in table_names:
                schema.append(f"Table: {table_name}")
                # Quote the identifier (doubling embedded quotes) so names
                # containing spaces, keywords or quotes do not break the PRAGMA.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                cursor.execute(f"PRAGMA table_info({quoted_name});")
                # In each table_info row, index 1 is the column name and
                # index 2 is its declared type.
                schema.extend(
                    f" - {col[1]} ({col[2]})" for col in cursor.fetchall()
                )
                schema.append("")
        finally:
            # Always release the connection, even when a query fails.
            conn.close()
        return "\n".join(schema).strip()

    def post(self, shared, prep_res, exec_res):
        """Store the schema text in ``shared["schema"]`` and print it.

        Returns None (no explicit action string).
        """
        shared["schema"] = exec_res
        print("\n===== DB SCHEMA =====\n")
        print(exec_res)
        print("\n=====================\n")
|
|
|
|
class GenerateSQL(Node):
    """Node that asks the LLM to turn a natural-language question into SQL.

    Reads ``shared["natural_query"]`` and ``shared["schema"]``; writes the
    generated statement to ``shared["generated_sql"]`` and resets
    ``shared["debug_attempts"]`` to 0.
    """

    def prep(self, shared):
        """Return ``(natural_query, schema)`` taken from ``shared``."""
        return shared["natural_query"], shared["schema"]

    @staticmethod
    def _parse_sql_reply(llm_response):
        """Extract the SQL string from a reply containing a ```yaml block.

        Raises:
            ValueError: If the reply has no ```yaml fence, or the fenced YAML
                is not a mapping with a string ``sql`` entry.
        """
        _, fence, after_fence = llm_response.partition("```yaml")
        if not fence:
            raise ValueError("LLM response does not contain a ```yaml block")
        yaml_str = after_fence.split("```")[0].strip()
        structured_result = yaml.safe_load(yaml_str)
        if not isinstance(structured_result, dict):
            raise ValueError("LLM response YAML is not a mapping")
        sql_text = structured_result.get("sql")
        if not isinstance(sql_text, str):
            raise ValueError("LLM response YAML has no string 'sql' entry")
        # Drop trailing semicolons, then any whitespace that preceded them.
        return sql_text.strip().rstrip(";").rstrip()

    def exec(self, prep_res):
        """Prompt the LLM and return the SQL it proposes.

        Args:
            prep_res: The ``(natural_query, schema)`` tuple from ``prep``.

        Returns:
            The SQL string with surrounding whitespace and trailing
            semicolons removed.

        Raises:
            ValueError: If the LLM reply cannot be parsed into a SQL string.
        """
        natural_query, schema = prep_res
        # The example query is indented so the YAML block scalar is valid.
        prompt = f"""
Given SQLite schema:
{schema}

Question: "{natural_query}"

Respond ONLY with a YAML block containing the SQL query under the key 'sql':
```yaml
sql: |
  SELECT ...
```"""
        llm_response = call_llm(prompt)
        return self._parse_sql_reply(llm_response)

    def post(self, shared, prep_res, exec_res):
        """Store the generated SQL, reset the debug counter and print it.

        Returns None (no explicit action string).
        """
        # exec_res is the parsed SQL query string.
        shared["generated_sql"] = exec_res
        # Reset debug attempts when *successfully* generating new SQL.
        shared["debug_attempts"] = 0
        print(f"\n===== GENERATED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")
|
|
|
|
class ExecuteSQL(Node):
    """Node that runs the generated SQL and reports success or failure.

    Reads ``shared["db_path"]`` and ``shared["generated_sql"]``. On success
    it writes ``shared["final_result"]`` and ``shared["result_columns"]``.
    On failure it writes ``shared["execution_error"]``, increments
    ``shared["debug_attempts"]`` and, once the attempt limit is reached,
    writes ``shared["final_error"]``.
    """

    def prep(self, shared):
        """Return ``(db_path, generated_sql)`` taken from ``shared``."""
        return shared["db_path"], shared["generated_sql"]

    def exec(self, prep_res):
        """Execute the SQL statement against the SQLite database.

        Args:
            prep_res: The ``(db_path, sql_query)`` tuple from ``prep``.

        Returns:
            ``(True, rows, column_names)`` when the statement produced a
            result set, ``(True, "Query OK. Rows affected: N", [])`` when it
            did not, or ``(False, error_message, [])`` when SQLite rejected
            the statement.
        """
        db_path, sql_query = prep_res
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            start_time = time.time()
            cursor.execute(sql_query)

            # cursor.description is None for statements that return no rows,
            # so it classifies the statement more reliably than its leading
            # keyword (e.g. "WITH ... INSERT", "PRAGMA ...", "... RETURNING").
            if cursor.description is not None:
                results = cursor.fetchall()
                column_names = [desc[0] for desc in cursor.description]
            else:
                results = f"Query OK. Rows affected: {cursor.rowcount}"
                column_names = []

            # Persist any pending change regardless of how the statement
            # was classified above.
            if conn.in_transaction:
                conn.commit()

            duration = time.time() - start_time
            print(f"SQL executed in {duration:.3f} seconds.")
            return (True, results, column_names)
        except (sqlite3.Error, sqlite3.Warning) as e:
            # sqlite3.Warning is not a subclass of sqlite3.Error; some Python
            # versions raise it for multi-statement input, which should be
            # routed to the debug loop like any other SQL failure.
            print(f"SQLite Error during execution: {e}")
            return (False, str(e), [])
        finally:
            # Close the connection on every path, including unexpected errors.
            if conn is not None:
                conn.close()

    def post(self, shared, prep_res, exec_res):
        """Record the outcome in ``shared`` and choose the next action.

        Returns:
            ``"error_retry"`` when execution failed and the attempt limit
            (``shared["max_debug_attempts"]``, default 3) has not been
            reached; otherwise None.
        """
        success, result_or_error, column_names = exec_res

        if success:
            shared["final_result"] = result_or_error
            shared["result_columns"] = column_names
            print("\n===== SQL EXECUTION SUCCESS =====\n")
            if isinstance(result_or_error, list):
                if column_names:
                    print(" | ".join(column_names))
                    # Underline as wide as the header line printed above.
                    header_width = sum(len(str(c)) for c in column_names)
                    separator_width = 3 * (len(column_names) - 1)
                    print("-" * (header_width + separator_width))
                if not result_or_error:
                    print("(No results found)")
                else:
                    for row in result_or_error:
                        print(" | ".join(map(str, row)))
            else:
                print(result_or_error)
            print("\n=================================\n")
            return

        # Execution failed (SQLite error caught in exec).
        shared["execution_error"] = result_or_error  # Store the error message
        shared["debug_attempts"] = shared.get("debug_attempts", 0) + 1
        max_attempts = shared.get("max_debug_attempts", 3)

        print(f"\n===== SQL EXECUTION FAILED (Attempt {shared['debug_attempts']}) =====\n")
        print(f"Error: {shared['execution_error']}")
        print("=========================================\n")

        if shared["debug_attempts"] >= max_attempts:
            print(f"Max debug attempts ({max_attempts}) reached. Stopping.")
            shared["final_error"] = f"Failed to execute SQL after {max_attempts} attempts. Last error: {shared['execution_error']}"
            return

        print("Attempting to debug the SQL...")
        return "error_retry"  # Signal to go to DebugSQL
|
|
|
|
class DebugSQL(Node):
    """Node that asks the LLM to repair a SQL statement that failed.

    Reads ``natural_query``, ``schema``, ``generated_sql`` and
    ``execution_error`` from ``shared``; overwrites
    ``shared["generated_sql"]`` with the corrected statement and removes
    ``shared["execution_error"]``.
    """

    def prep(self, shared):
        """Return ``(natural_query, schema, failed_sql, error_message)``.

        Each value is read with ``shared.get`` and is None when absent.
        """
        return (
            shared.get("natural_query"),
            shared.get("schema"),
            shared.get("generated_sql"),
            shared.get("execution_error"),
        )

    @staticmethod
    def _parse_sql_reply(llm_response):
        """Extract the SQL string from a reply containing a ```yaml block.

        Raises:
            ValueError: If the reply has no ```yaml fence, or the fenced YAML
                is not a mapping with a string ``sql`` entry.
        """
        _, fence, after_fence = llm_response.partition("```yaml")
        if not fence:
            raise ValueError("LLM response does not contain a ```yaml block")
        yaml_str = after_fence.split("```")[0].strip()
        structured_result = yaml.safe_load(yaml_str)
        if not isinstance(structured_result, dict):
            raise ValueError("LLM response YAML is not a mapping")
        sql_text = structured_result.get("sql")
        if not isinstance(sql_text, str):
            raise ValueError("LLM response YAML has no string 'sql' entry")
        # Drop trailing semicolons, then any whitespace that preceded them.
        return sql_text.strip().rstrip(";").rstrip()

    def exec(self, prep_res):
        """Prompt the LLM with the failed query and return its correction.

        Args:
            prep_res: The 4-tuple produced by ``prep``.

        Returns:
            The corrected SQL string with surrounding whitespace and trailing
            semicolons removed.

        Raises:
            ValueError: If the LLM reply cannot be parsed into a SQL string.
        """
        natural_query, schema, failed_sql, error_message = prep_res
        # The example query is indented so the YAML block scalar is valid.
        prompt = f"""
The following SQLite SQL query failed:
```sql
{failed_sql}
```
It was generated for: "{natural_query}"
Schema:
{schema}
Error: "{error_message}"

Provide a corrected SQLite query.

Respond ONLY with a YAML block containing the corrected SQL under the key 'sql':
```yaml
sql: |
  SELECT ... -- corrected query
```"""
        llm_response = call_llm(prompt)
        return self._parse_sql_reply(llm_response)

    def post(self, shared, prep_res, exec_res):
        """Store the corrected SQL, clear the old error and print the query.

        Returns None (no explicit action string).
        """
        # exec_res is the corrected SQL string.
        shared["generated_sql"] = exec_res  # Overwrite with the new attempt
        # Clear the previous error for the next ExecuteSQL attempt.
        shared.pop("execution_error", None)

        print(f"\n===== REVISED SQL (Attempt {shared.get('debug_attempts', 0) + 1}) =====\n")
        print(exec_res)
        print("\n====================================\n")