import yaml
from pocketflow import Node, BatchNode
from utils.call_llm import call_llm
from utils.code_executor import execute_python


def _require(condition, message):
    """Raise ``AssertionError(message)`` if *condition* is falsy.

    Used instead of bare ``assert`` statements for validating LLM output:
    ``assert`` is stripped when Python runs with ``-O``, which would silently
    disable every check. The exception type is kept as ``AssertionError`` so
    failures look the same as with the original assert-based checks.
    """
    if not condition:
        raise AssertionError(message)


def _parse_yaml_response(response):
    """Extract the first ```yaml fenced block from an LLM response and parse it.

    Args:
        response: Raw text returned by ``call_llm``.

    Returns:
        The parsed YAML document, guaranteed to be a dict.

    Raises:
        AssertionError: if the response has no ```yaml fence or the block does
            not parse to a mapping.
        yaml.YAMLError: if the fenced text is not valid YAML.
    """
    _require("```yaml" in response, "LLM response must contain a ```yaml block")
    yaml_str = response.split("```yaml", 1)[1].split("```", 1)[0].strip()
    # safe_load (not yaml.load) because the text comes from an LLM.
    result = yaml.safe_load(yaml_str)
    # safe_load yields None for an empty block and scalars/lists for other
    # shapes; every caller does key lookups, so insist on a mapping.
    _require(isinstance(result, dict), "YAML block must be a mapping")
    return result


def _validate_test_case(test_case, label):
    """Check a test case is a mapping with a str 'name', dict 'input' and an 'expected' key.

    *label* prefixes each error message, e.g. "Test case 0".
    """
    _require(isinstance(test_case, dict), f"{label} must be a mapping")
    _require("name" in test_case, f"{label} missing 'name' field")
    _require(isinstance(test_case["name"], str), f"{label} 'name' must be string")
    _require("input" in test_case, f"{label} missing 'input' field")
    _require(isinstance(test_case["input"], dict), f"{label} 'input' must be dict")
    _require("expected" in test_case, f"{label} missing 'expected' field")


def _validate_function_code(code):
    """Check LLM-produced code is a string that defines ``run_code``."""
    _require(isinstance(code, str), "function_code must be string")
    _require("def run_code" in code, "Function must be named 'run_code'")


def _to_index(key):
    """Return *key* converted with ``int()``, or None if that conversion fails."""
    try:
        return int(key)
    except (TypeError, ValueError):
        return None


def _format_test_cases(test_cases):
    """Render test cases as a numbered, human-readable list for an LLM prompt."""
    return "".join(
        f"{i}. {test['name']}\n"
        f"   input: {test['input']}\n"
        f"   expected: {test['expected']}\n\n"
        for i, test in enumerate(test_cases, 1)
    )


def _format_failures(failed_results):
    """Render failed RunTests results as a numbered list for an LLM prompt."""
    parts = []
    for i, result in enumerate(failed_results, 1):
        parts.append(f"{i}. {result['test_case']['name']}:\n")
        if result["error"]:
            parts.append(f"   error: {result['error']}\n")
        else:
            parts.append(f"   output: {result['actual']}\n")
        parts.append(f"   expected: {result['expected']}\n\n")
    return "".join(parts)


class GenerateTestCases(Node):
    """Ask the LLM for test cases for ``shared["problem"]``; store them in ``shared["test_cases"]``."""

    def prep(self, shared):
        return shared["problem"]

    def exec(self, problem):
        prompt = f"""Generate 5-7 test cases for this coding problem:

{problem}

Output in this YAML format with reasoning:
```yaml
reasoning: |
    The input parameters should be: param1 as a string, and param2 as a number.
    To test the function, I will consider basic cases, edge cases, and corner cases.
    For this problem, I need to test...
test_cases:
  - name: "Basic case"
    input: {{param1: value1, param2: value2}}
    expected: result1
  - name: "Edge case - empty"
    input: {{param1: value3, param2: value4}}
    expected: result2
```"""
        result = _parse_yaml_response(call_llm(prompt))

        _require("test_cases" in result, "Result must have 'test_cases' field")
        _require(isinstance(result["test_cases"], list), "test_cases must be a list")
        # An empty list would make RunTests report vacuous success (all([]) is True).
        _require(result["test_cases"], "test_cases must not be empty")
        for i, test_case in enumerate(result["test_cases"]):
            _validate_test_case(test_case, f"Test case {i}")
        return result

    def post(self, shared, prep_res, exec_res):
        shared["test_cases"] = exec_res["test_cases"]

        # Print all generated test cases
        print(f"\n=== Generated {len(exec_res['test_cases'])} Test Cases ===")
        for i, test_case in enumerate(exec_res["test_cases"], 1):
            print(f"{i}. {test_case['name']}")
            print(f"   input: {test_case['input']}")
            print(f"   expected: {test_case['expected']}")


class ImplementFunction(Node):
    """Ask the LLM for a ``run_code`` implementation; store it in ``shared["function_code"]``."""

    def prep(self, shared):
        return shared["problem"], shared["test_cases"]

    def exec(self, inputs):
        problem, test_cases = inputs
        formatted_tests = _format_test_cases(test_cases)

        prompt = f"""Implement a solution for this problem:

{problem}

Test cases to consider:
{formatted_tests}

IMPORTANT: The function name must be exactly "run_code"

Output in this YAML format:
```yaml
reasoning: |
    To implement this function, I will...
    My approach is...
function_code: |
    def run_code(...):
        # your implementation
        return result
```"""
        result = _parse_yaml_response(call_llm(prompt))

        _require("function_code" in result, "Result must have 'function_code' field")
        _validate_function_code(result["function_code"])
        return result["function_code"]

    def post(self, shared, prep_res, exec_res):
        shared["function_code"] = exec_res

        # Print the implemented function
        print("\n=== Implemented Function ===")
        print(exec_res)


class RunTests(BatchNode):
    """Run ``shared["function_code"]`` against every test case and route on the outcome.

    ``post`` returns "success" when every test passes, "max_iterations" once
    ``shared["iteration_count"]`` reaches ``shared.get("max_iterations", 5)``,
    and "failure" otherwise.
    """

    def prep(self, shared):
        function_code = shared["function_code"]
        # One (function_code, test_case) item per batch element.
        return [(function_code, test_case) for test_case in shared["test_cases"]]

    def exec(self, test_data):
        function_code, test_case = test_data
        expected = test_case["expected"]
        output, error = execute_python(function_code, test_case["input"])

        if error:
            return {
                "test_case": test_case,
                "passed": False,
                "actual": None,
                "expected": expected,
                "error": error,
            }

        passed = output == expected
        return {
            "test_case": test_case,
            "passed": passed,
            "actual": output,
            "expected": expected,
            "error": None if passed else f"Expected {expected}, got {output}",
        }

    def post(self, shared, prep_res, exec_res_list):
        shared["test_results"] = exec_res_list
        all_passed = all(result["passed"] for result in exec_res_list)
        shared["iteration_count"] = shared.get("iteration_count", 0) + 1

        # Print test results
        passed_count = sum(1 for r in exec_res_list if r["passed"])
        total_count = len(exec_res_list)
        print(f"\n=== Test Results: {passed_count}/{total_count} Passed ===")

        failed_tests = [r for r in exec_res_list if not r["passed"]]
        if failed_tests:
            print("Failed tests:")
            for i, result in enumerate(failed_tests, 1):
                test_case = result["test_case"]
                print(f"{i}. {test_case['name']}:")
                if result["error"]:
                    print(f"   error: {result['error']}")
                else:
                    print(f"   output: {result['actual']}")
                print(f"   expected: {result['expected']}")

        if all_passed:
            return "success"
        if shared["iteration_count"] >= shared.get("max_iterations", 5):
            return "max_iterations"
        return "failure"


class Revise(Node):
    """Ask the LLM to revise failing test cases and/or the function, then apply the revisions."""

    def prep(self, shared):
        failed_tests = [r for r in shared["test_results"] if not r["passed"]]
        return {
            "problem": shared["problem"],
            "test_cases": shared["test_cases"],
            "function_code": shared["function_code"],
            "failed_tests": failed_tests,
        }

    def exec(self, inputs):
        formatted_tests = _format_test_cases(inputs["test_cases"])
        formatted_failures = _format_failures(inputs["failed_tests"])

        prompt = f"""Problem: {inputs['problem']}

Current test cases:
{formatted_tests}

Current function:
```python
{inputs['function_code']}
```

Failed tests:
{formatted_failures}

Analyze the failures and output revisions in YAML. You can revise test cases, function code, or both:
```yaml
reasoning: |
    Looking at the failures, I see that...
    The issue appears to be...
    I will revise...
test_cases:  # Dictionary mapping test case index (1-based) to revised test case
  1:
    name: "Revised test name"
    input: {{...}}
    expected: ...
function_code: |  # Include this if revising function
  def run_code(...):
    return ...
```"""
        result = _parse_yaml_response(call_llm(prompt))

        if "test_cases" in result:
            _require(isinstance(result["test_cases"], dict), "test_cases must be a dictionary")
            for index_str, test_case in result["test_cases"].items():
                _require(isinstance(index_str, (str, int)), "test_cases keys must be strings or ints")
                # post() converts each key with int(); reject keys such as
                # "first" here, where validation happens, instead of crashing there.
                _require(
                    _to_index(index_str) is not None,
                    f"test_cases key {index_str!r} must be an integer index",
                )
                _validate_test_case(test_case, f"Revised test case {index_str}")

        if "function_code" in result:
            _validate_function_code(result["function_code"])

        return result

    def post(self, shared, prep_res, exec_res):
        # Print what is being revised
        print(f"\n=== Revisions (Iteration {shared['iteration_count']}) ===")

        # Handle test case revisions - map 1-based indices to actual test cases
        if "test_cases" in exec_res:
            current_tests = shared["test_cases"].copy()
            print("Revising test cases:")
            for index_str, revised_test in exec_res["test_cases"].items():
                index = int(index_str) - 1  # Convert to 0-based
                if not 0 <= index < len(current_tests):
                    # Revisions only replace existing tests; report instead of silently dropping.
                    print(f"  Test {index_str}: index out of range, skipped")
                    continue
                old_test = current_tests[index]
                print(f"  Test {index_str}: '{old_test['name']}' -> '{revised_test['name']}'")
                print(f"    old input: {old_test['input']}")
                print(f"    new input: {revised_test['input']}")
                print(f"    old expected: {old_test['expected']}")
                print(f"    new expected: {revised_test['expected']}")
                current_tests[index] = revised_test
            shared["test_cases"] = current_tests

        if "function_code" in exec_res:
            print("Revising function code:")
            print("New function:")
            print(exec_res["function_code"])
            shared["function_code"] = exec_res["function_code"]