code generator implementation
This commit is contained in:
parent
ac4381cae9
commit
a8dadfa946
|
|
@ -0,0 +1,161 @@
|
|||
# PocketFlow Code Generator
|
||||
|
||||
An intelligent AI system that takes LeetCode-style coding problems and automatically generates comprehensive test cases, implements solutions, and iteratively improves them until all tests pass.
|
||||
|
||||
## Features
|
||||
|
||||
- **Automatic Test Case Generation**: Creates diverse test cases including edge cases
|
||||
- **Intelligent Code Implementation**: Generates `run_code` functions with proper algorithms
|
||||
- **Iterative Improvement**: Analyzes failures and decides whether to revise tests or code
|
||||
- **Rich Debugging Output**: Detailed progress tracking and validation
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install required dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Set up your Anthropic API key:
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY="your-api-key-here"
|
||||
```
|
||||
Test your API key is working:
|
||||
```bash
|
||||
python utils/call_llm.py
|
||||
```
|
||||
|
||||
3. Run the code generator with the default Two Sum problem:
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
4. Or provide your own problem:
|
||||
```bash
|
||||
python main.py "Reverse a linked list. Given the head of a singly linked list, reverse the list and return the reversed list."
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
The system follows an intelligent workflow combining **Agent** and **Workflow** design patterns:
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
start[Problem Input] --> generateTests[Generate Test Cases]
|
||||
generateTests --> implement[Implement Function]
|
||||
implement --> runTests[Run Tests - Batch]
|
||||
runTests --> decision{All Tests Pass?}
|
||||
decision -->|Yes| success[Success!]
|
||||
decision -->|No| revise[Revise - Agent Decision]
|
||||
revise --> runTests
|
||||
decision -->|Max Iterations| maxIter[Max Iterations Reached]
|
||||
```
|
||||
|
||||
### The Process
|
||||
|
||||
1. **GenerateTestCases**: Creates 5-7 comprehensive test cases from problem description
|
||||
2. **ImplementFunction**: Writes a `run_code` function based on problem and test cases
|
||||
3. **RunTests**: Executes function against all test cases using batch processing
|
||||
4. **Revise**: Analyzes failures and makes intelligent decisions to revise test cases and/or function code
|
||||
5. **Loop**: Continues until all tests pass or max iterations reached
|
||||
|
||||
## Sample Output
|
||||
|
||||
Here's what you'll see when running the Two Sum example:
|
||||
|
||||
```
|
||||
Starting PocketFlow Code Generator...
|
||||
|
||||
=== Generated 7 Test Cases ===
|
||||
1. Basic case - solution at beginning
|
||||
input: {'nums': [2, 7, 11, 15], 'target': 9}
|
||||
expected: [0, 1]
|
||||
2. Basic case - solution in middle
|
||||
input: {'nums': [3, 2, 4], 'target': 6}
|
||||
expected: [1, 2]
|
||||
3. Edge case - minimum array size with duplicates
|
||||
input: {'nums': [3, 3], 'target': 6}
|
||||
expected: [0, 1]
|
||||
4. Case with negative numbers
|
||||
input: {'nums': [-1, -2, -3, -4, -5], 'target': -8}
|
||||
expected: [2, 4]
|
||||
5. Case with zero and negative target
|
||||
input: {'nums': [0, 4, 3, 0], 'target': 0}
|
||||
expected: [0, 3]
|
||||
6. Case with solution at the end
|
||||
input: {'nums': [1, 2, 3, 4, 5, 6], 'target': 11}
|
||||
expected: [4, 5]
|
||||
7. Larger array case
|
||||
input: {'nums': [5, 75, 25, 45, 42, 2, 11, 9, 55, 12], 'target': 14}
|
||||
expected: [2, 6]
|
||||
|
||||
=== Implemented Function ===
|
||||
def run_code(nums, target):
|
||||
# Dictionary to store number -> index mapping
|
||||
num_to_index = {}
|
||||
|
||||
# Iterate through the array
|
||||
for i, num in enumerate(nums):
|
||||
# Calculate what number we need to reach the target
|
||||
complement = target - num
|
||||
|
||||
# Check if the complement exists in our map
|
||||
if complement in num_to_index:
|
||||
# Found the pair! Return indices
|
||||
return [num_to_index[complement], i]
|
||||
|
||||
# Store current number and its index
|
||||
num_to_index[num] = i
|
||||
|
||||
# Should never reach here given problem constraints
|
||||
return []
|
||||
|
||||
=== Test Results: 6/7 Passed ===
|
||||
Failed tests:
|
||||
1. Larger array case:
|
||||
error: Expected [2, 6], got [0, 7]
|
||||
expected: [2, 6]
|
||||
|
||||
=== Revisions (Iteration 1) ===
|
||||
Revising test cases:
|
||||
Test 7: 'Larger array case' -> 'Larger array case'
|
||||
old input: {'nums': [5, 75, 25, 45, 42, 2, 11, 9, 55, 12], 'target': 14}
|
||||
new input: {'nums': [5, 75, 25, 45, 42, 2, 11, 9, 55, 12], 'target': 14}
|
||||
old expected: [2, 6]
|
||||
new expected: [0, 7]
|
||||
|
||||
=== Test Results: 7/7 Passed ===
|
||||
```
|
||||
|
||||
## Key Features
|
||||
|
||||
### Intelligent Decision Making
|
||||
The **Revise** node acts as an agent that analyzes test failures and decides whether to:
|
||||
- Fix test cases (if they have incorrect expected outputs)
|
||||
- Fix the function implementation (if the logic is wrong)
|
||||
- Or both
|
||||
|
||||
### Structured Output with Validation
|
||||
All LLM interactions use YAML format with:
|
||||
- **Reasoning fields**: Transparent decision-making process
|
||||
- **Validation asserts**: Ensures outputs match expected structure
|
||||
- **Rich debugging**: Comprehensive logging of all steps
|
||||
|
||||
### Batch Processing
|
||||
The **RunTests** node uses PocketFlow's BatchNode to efficiently test the function against all test cases in parallel.
|
||||
|
||||
## Files
|
||||
|
||||
- [`main.py`](./main.py): Entry point with sample Two Sum problem
|
||||
- [`flow.py`](./flow.py): Connects all nodes into the complete workflow
|
||||
- [`nodes.py`](./nodes.py): Core logic nodes with validation and debugging
|
||||
- [`utils/call_llm.py`](./utils/call_llm.py): Anthropic Claude API wrapper
|
||||
- [`utils/code_executor.py`](./utils/code_executor.py): Safe Python code execution utility
|
||||
- [`doc/design.md`](./doc/design.md): Detailed system design documentation
|
||||
|
||||
## Design Patterns Used
|
||||
|
||||
- **[Workflow](https://the-pocket.github.io/PocketFlow/design_pattern/workflow.html)**: Sequential steps of test generation → coding → testing
|
||||
- **[Agent](https://the-pocket.github.io/PocketFlow/design_pattern/agent.html)**: Intelligent decision-making when tests fail
|
||||
- **[Batch](https://the-pocket.github.io/PocketFlow/core_abstraction/batch.html)**: Efficient parallel test execution
|
||||
- **[Structured Output](https://the-pocket.github.io/PocketFlow/design_pattern/structure.html)**: YAML validation for reliable LLM outputs
|
||||
|
|
@ -81,7 +81,8 @@ The shared memory structure is organized as follows:
|
|||
shared = {
|
||||
"problem": "Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.",
|
||||
"test_cases": [
|
||||
{"input": {"nums": [2,7,11,15], "target": 9}, "expected": [0,1]},
|
||||
{"name": "Basic case", "input": {"nums": [2,7,11,15], "target": 9}, "expected": [0,1]},
|
||||
{"name": "Different order", "input": {"nums": [3,2,4], "target": 6}, "expected": [1,2]},
|
||||
# ... more test cases
|
||||
],
|
||||
"function_code": "def run_code(nums, target): ...",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
from pocketflow import Flow
|
||||
from nodes import GenerateTestCases, ImplementFunction, RunTests, Revise
|
||||
|
||||
def create_code_generator_flow():
|
||||
"""Creates and returns the code generator flow."""
|
||||
# Create nodes
|
||||
generate_tests = GenerateTestCases()
|
||||
implement_function = ImplementFunction()
|
||||
run_tests = RunTests()
|
||||
revise = Revise()
|
||||
|
||||
# Define transitions
|
||||
generate_tests >> implement_function
|
||||
implement_function >> run_tests
|
||||
run_tests - "failure" >> revise
|
||||
revise >> run_tests
|
||||
|
||||
# Create flow starting with test generation
|
||||
flow = Flow(start=generate_tests)
|
||||
return flow
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
import sys
|
||||
from flow import create_code_generator_flow
|
||||
|
||||
def main():
|
||||
"""Runs the PocketFlow Code Generator application."""
|
||||
print("Starting PocketFlow Code Generator...")
|
||||
|
||||
# Check if problem is provided as argument
|
||||
if len(sys.argv) > 1:
|
||||
problem = " ".join(sys.argv[1:])
|
||||
else:
|
||||
# Default Two Sum problem
|
||||
problem = """Two Sum
|
||||
|
||||
Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.
|
||||
|
||||
You may assume that each input would have exactly one solution, and you may not use the same element twice.
|
||||
|
||||
Example 1:
|
||||
Input: nums = [2,7,11,15], target = 9
|
||||
Output: [0,1]
|
||||
|
||||
Example 2:
|
||||
Input: nums = [3,2,4], target = 6
|
||||
Output: [1,2]
|
||||
|
||||
Example 3:
|
||||
Input: nums = [3,3], target = 6
|
||||
Output: [0,1]"""
|
||||
|
||||
shared = {
|
||||
"problem": problem,
|
||||
"test_cases": [], # Will be populated with [{name, input, expected}, ...]
|
||||
"function_code": "",
|
||||
"test_results": [],
|
||||
"iteration_count": 0,
|
||||
"max_iterations": 5
|
||||
}
|
||||
|
||||
# Create and run the flow
|
||||
flow = create_code_generator_flow()
|
||||
flow.run(shared)
|
||||
|
||||
print("\n=== Final Results ===")
|
||||
print(f"Problem: {shared['problem'][:50]}...")
|
||||
print(f"Iterations: {shared['iteration_count']}")
|
||||
print(f"Function:\n{shared['function_code']}")
|
||||
print(f"Test Results: {len([r for r in shared['test_results'] if r['passed']])}/{len(shared['test_results'])} passed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
import yaml
|
||||
from pocketflow import Node, BatchNode
|
||||
from utils.call_llm import call_llm
|
||||
from utils.code_executor import execute_python
|
||||
|
||||
class GenerateTestCases(Node):
|
||||
def prep(self, shared):
|
||||
return shared["problem"]
|
||||
|
||||
def exec(self, problem):
|
||||
prompt = f"""Generate 5-7 test cases for this coding problem:
|
||||
|
||||
{problem}
|
||||
|
||||
Output in this YAML format with reasoning:
|
||||
```yaml
|
||||
reasoning: |
|
||||
The input parameters should be: param1 as a string, and param2 as a number.
|
||||
To test the function, I will consider basic cases, edge cases, and corner cases.
|
||||
For this problem, I need to test...
|
||||
test_cases:
|
||||
- name: "Basic case"
|
||||
input: {{param1: value1, param2: value2}}
|
||||
expected: result1
|
||||
- name: "Edge case - empty"
|
||||
input: {{param1: value3, param2: value4}}
|
||||
expected: result2
|
||||
```"""
|
||||
response = call_llm(prompt)
|
||||
yaml_str = response.split("```yaml")[1].split("```")[0].strip()
|
||||
result = yaml.safe_load(yaml_str)
|
||||
|
||||
# Validation asserts
|
||||
assert "test_cases" in result, "Result must have 'test_cases' field"
|
||||
assert isinstance(result["test_cases"], list), "test_cases must be a list"
|
||||
|
||||
for i, test_case in enumerate(result["test_cases"]):
|
||||
assert "name" in test_case, f"Test case {i} missing 'name' field"
|
||||
assert isinstance(test_case["name"], str), f"Test case {i} 'name' must be string"
|
||||
assert "input" in test_case, f"Test case {i} missing 'input' field"
|
||||
assert isinstance(test_case["input"], dict), f"Test case {i} 'input' must be dict"
|
||||
assert "expected" in test_case, f"Test case {i} missing 'expected' field"
|
||||
|
||||
return result
|
||||
|
||||
def post(self, shared, prep_res, exec_res):
|
||||
shared["test_cases"] = exec_res["test_cases"]
|
||||
|
||||
# Print all generated test cases
|
||||
print(f"\n=== Generated {len(exec_res['test_cases'])} Test Cases ===")
|
||||
for i, test_case in enumerate(exec_res["test_cases"], 1):
|
||||
print(f"{i}. {test_case['name']}")
|
||||
print(f" input: {test_case['input']}")
|
||||
print(f" expected: {test_case['expected']}")
|
||||
|
||||
class ImplementFunction(Node):
|
||||
def prep(self, shared):
|
||||
return shared["problem"], shared["test_cases"]
|
||||
|
||||
def exec(self, inputs):
|
||||
problem, test_cases = inputs
|
||||
|
||||
# Format test cases nicely for the prompt
|
||||
formatted_tests = ""
|
||||
for i, test in enumerate(test_cases, 1):
|
||||
formatted_tests += f"{i}. {test['name']}\n"
|
||||
formatted_tests += f" input: {test['input']}\n"
|
||||
formatted_tests += f" expected: {test['expected']}\n\n"
|
||||
|
||||
prompt = f"""Implement a solution for this problem:
|
||||
|
||||
{problem}
|
||||
|
||||
Test cases to consider:
|
||||
{formatted_tests}
|
||||
|
||||
IMPORTANT: The function name must be exactly "run_code"
|
||||
|
||||
Output in this YAML format:
|
||||
```yaml
|
||||
reasoning: |
|
||||
To implement this function, I will...
|
||||
My approach is...
|
||||
function_code: |
|
||||
def run_code(...):
|
||||
# your implementation
|
||||
return result
|
||||
```"""
|
||||
response = call_llm(prompt)
|
||||
yaml_str = response.split("```yaml")[1].split("```")[0].strip()
|
||||
result = yaml.safe_load(yaml_str)
|
||||
|
||||
# Validation asserts
|
||||
assert "function_code" in result, "Result must have 'function_code' field"
|
||||
assert isinstance(result["function_code"], str), "function_code must be string"
|
||||
assert "def run_code" in result["function_code"], "Function must be named 'run_code'"
|
||||
|
||||
return result["function_code"]
|
||||
|
||||
def post(self, shared, prep_res, exec_res):
|
||||
shared["function_code"] = exec_res
|
||||
|
||||
# Print the implemented function
|
||||
print(f"\n=== Implemented Function ===")
|
||||
print(exec_res)
|
||||
|
||||
class RunTests(BatchNode):
|
||||
def prep(self, shared):
|
||||
function_code = shared["function_code"]
|
||||
test_cases = shared["test_cases"]
|
||||
# Return list of tuples (function_code, test_case)
|
||||
return [(function_code, test_case) for test_case in test_cases]
|
||||
|
||||
def exec(self, test_data):
|
||||
function_code, test_case = test_data
|
||||
output, error = execute_python(function_code, test_case["input"])
|
||||
|
||||
if error:
|
||||
return {
|
||||
"test_case": test_case,
|
||||
"passed": False,
|
||||
"actual": None,
|
||||
"expected": test_case["expected"],
|
||||
"error": error
|
||||
}
|
||||
|
||||
passed = output == test_case["expected"]
|
||||
return {
|
||||
"test_case": test_case,
|
||||
"passed": passed,
|
||||
"actual": output,
|
||||
"expected": test_case["expected"],
|
||||
"error": None if passed else f"Expected {test_case['expected']}, got {output}"
|
||||
}
|
||||
|
||||
def post(self, shared, prep_res, exec_res_list):
|
||||
shared["test_results"] = exec_res_list
|
||||
all_passed = all(result["passed"] for result in exec_res_list)
|
||||
shared["iteration_count"] = shared.get("iteration_count", 0) + 1
|
||||
|
||||
# Print test results
|
||||
passed_count = len([r for r in exec_res_list if r["passed"]])
|
||||
total_count = len(exec_res_list)
|
||||
print(f"\n=== Test Results: {passed_count}/{total_count} Passed ===")
|
||||
|
||||
failed_tests = [r for r in exec_res_list if not r["passed"]]
|
||||
if failed_tests:
|
||||
print("Failed tests:")
|
||||
for i, result in enumerate(failed_tests, 1):
|
||||
test_case = result['test_case']
|
||||
print(f"{i}. {test_case['name']}:")
|
||||
if result['error']:
|
||||
print(f" error: {result['error']}")
|
||||
else:
|
||||
print(f" output: {result['actual']}")
|
||||
print(f" expected: {result['expected']}")
|
||||
|
||||
if all_passed:
|
||||
return "success"
|
||||
elif shared["iteration_count"] >= shared.get("max_iterations", 5):
|
||||
return "max_iterations"
|
||||
else:
|
||||
return "failure"
|
||||
|
||||
class Revise(Node):
|
||||
def prep(self, shared):
|
||||
failed_tests = [r for r in shared["test_results"] if not r["passed"]]
|
||||
return {
|
||||
"problem": shared["problem"],
|
||||
"test_cases": shared["test_cases"],
|
||||
"function_code": shared["function_code"],
|
||||
"failed_tests": failed_tests
|
||||
}
|
||||
|
||||
def exec(self, inputs):
|
||||
# Format current test cases nicely
|
||||
formatted_tests = ""
|
||||
for i, test in enumerate(inputs['test_cases'], 1):
|
||||
formatted_tests += f"{i}. {test['name']}\n"
|
||||
formatted_tests += f" input: {test['input']}\n"
|
||||
formatted_tests += f" expected: {test['expected']}\n\n"
|
||||
|
||||
# Format failed tests nicely
|
||||
formatted_failures = ""
|
||||
for i, result in enumerate(inputs['failed_tests'], 1):
|
||||
test_case = result['test_case']
|
||||
formatted_failures += f"{i}. {test_case['name']}:\n"
|
||||
if result['error']:
|
||||
formatted_failures += f" error: {result['error']}\n"
|
||||
else:
|
||||
formatted_failures += f" output: {result['actual']}\n"
|
||||
formatted_failures += f" expected: {result['expected']}\n\n"
|
||||
|
||||
prompt = f"""Problem: {inputs['problem']}
|
||||
|
||||
Current test cases:
|
||||
{formatted_tests}
|
||||
|
||||
Current function:
|
||||
```python
|
||||
{inputs['function_code']}
|
||||
```
|
||||
|
||||
Failed tests:
|
||||
{formatted_failures}
|
||||
|
||||
Analyze the failures and output revisions in YAML. You can revise test cases, function code, or both:
|
||||
|
||||
```yaml
|
||||
reasoning: |
|
||||
Looking at the failures, I see that...
|
||||
The issue appears to be...
|
||||
I will revise...
|
||||
test_cases: # Dictionary mapping test case index (1-based) to revised test case
|
||||
1:
|
||||
name: "Revised test name"
|
||||
input: {{...}}
|
||||
expected: ...
|
||||
function_code: | # Include this if revising function
|
||||
def run_code(...):
|
||||
return ...
|
||||
```"""
|
||||
response = call_llm(prompt)
|
||||
yaml_str = response.split("```yaml")[1].split("```")[0].strip()
|
||||
result = yaml.safe_load(yaml_str)
|
||||
|
||||
# Validation asserts
|
||||
if "test_cases" in result:
|
||||
assert isinstance(result["test_cases"], dict), "test_cases must be a dictionary"
|
||||
for index_str, test_case in result["test_cases"].items():
|
||||
assert isinstance(index_str, (str, int)), "test_cases keys must be strings or ints"
|
||||
assert "name" in test_case, f"Revised test case {index_str} missing 'name' field"
|
||||
assert "input" in test_case, f"Revised test case {index_str} missing 'input' field"
|
||||
assert "expected" in test_case, f"Revised test case {index_str} missing 'expected' field"
|
||||
|
||||
if "function_code" in result:
|
||||
assert isinstance(result["function_code"], str), "function_code must be string"
|
||||
assert "def run_code" in result["function_code"], "Function must be named 'run_code'"
|
||||
|
||||
return result
|
||||
|
||||
def post(self, shared, prep_res, exec_res):
|
||||
# Print what is being revised
|
||||
print(f"\n=== Revisions (Iteration {shared['iteration_count']}) ===")
|
||||
|
||||
# Handle test case revisions - map indices to actual test cases
|
||||
if "test_cases" in exec_res:
|
||||
current_tests = shared["test_cases"].copy()
|
||||
print("Revising test cases:")
|
||||
for index_str, revised_test in exec_res["test_cases"].items():
|
||||
index = int(index_str) - 1 # Convert to 0-based
|
||||
if 0 <= index < len(current_tests):
|
||||
old_test = current_tests[index]
|
||||
print(f" Test {index_str}: '{old_test['name']}' -> '{revised_test['name']}'")
|
||||
print(f" old input: {old_test['input']}")
|
||||
print(f" new input: {revised_test['input']}")
|
||||
print(f" old expected: {old_test['expected']}")
|
||||
print(f" new expected: {revised_test['expected']}")
|
||||
current_tests[index] = revised_test
|
||||
shared["test_cases"] = current_tests
|
||||
|
||||
if "function_code" in exec_res:
|
||||
print("Revising function code:")
|
||||
print("New function:")
|
||||
print(exec_res["function_code"])
|
||||
shared["function_code"] = exec_res["function_code"]
|
||||
Loading…
Reference in New Issue