Last active
January 15, 2025 06:09
-
-
Save wolph/5198729733a3d9175466589060220f2c to your computer and use it in GitHub Desktop.
Automatically improve code quality by having an LLM fix ruff/pyright/mypy errors and by having it refactor code. GPT-o1 or stronger models recommended.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import ast | |
import asyncio | |
import difflib | |
import logging | |
import os | |
import pathlib | |
import re | |
import sys | |
from datetime import datetime | |
import aiofiles | |
from openai import AsyncAzureOpenAI | |
from openai.types.chat import ChatCompletion | |
from rich.console import Console | |
from rich.layout import Layout | |
from rich.logging import RichHandler | |
from rich.panel import Panel | |
from rich.syntax import Syntax | |
# Prompt template appended to each "improve" request sent to the model.
# NOTE(review): this constant appears unused downstream — the visible call
# site only evaluates `PROMPT.strip()` and discards the result; confirm
# whether it should be prepended to the user message.
PROMPT = """Please fix the following code to address all issues, ensuring that all
public functions and classes from the original code are preserved.
You are tasked with generating high-quality, modern, Pythonic Python code with type hints and documentation. Follow these instructions carefully:
### Code Preservation
- **Preserve all public or global functions, classes, and definitions.**
- Do not remove or rename any public/global functions or classes, as they might be externally referenced.
- Retain all existing comments and docstrings.
### Code Style
- **Adhere to PEP8 guidelines:**
- Adhere to the Zen of Python.
- Limit lines to **79 characters**.
- Use **single quotes** (`'`) for strings.
- Ensure the overall style reflects modern Python practices.
### Type Hinting
- Add type hints to every function and every variable where possible.
- Add **complete type annotations** compatible with **Python 3.13 and above**:
- Use **native type annotations**: `list[T]` instead of `typing.List[T]`.
- For `Optional` annotations, use `Union` shorthand with the `|` operator (e.g., `int | None` instead of `typing.Optional[int]`).
### Code Optimization
- Refactor code as needed for **performance and clarity**, while ensuring:
- Functions are split when useful to reduce complexity.
- Code and variables are reused where possible.
- Avoided using `global` variables, `eval()`, or `exec()`.
- Maintain or improve runtime efficiency over readability when trade-offs arise.
### Imports
- Be concise and explicit when importing:
- Avoid `from module import *`.
- Import only what is necessary.
- Prefer `import module` over `from module import function` when practical to avoid namespace collisions.
- Prefer `from spam import eggs` over `from eggs.spam import foo` for readability.
- Prefer `pathlib` over `os.path`.
### Documentation
- Ensure the code is well-documented like a public-facing open-source library:
- Retain all **existing documentation** (comments, docstrings, and doctests).
- Add docstrings for any public function, class, or module without one.
- Include new **doctests** where appropriate, without modifying existing doctests.
- Be descriptive but concise.
# Output Format
- Provide the updated Python code as plain text.
- Ensure all changes adhere to the guidelines.
# Notes
- Avoid altering the purpose or behavior of existing code.
- Ensure the final product maintains compliance with all stated rules.
- If you encounter areas where optimizations would be speculative or where changes could disrupt backward compatibility, leave the code as-is and explain briefly within comments.
# Example
### Input Code
```python
def add(x, y):
    return x + y
def perform_operation(op, a, b):
    if op == 'add':
        return add(a, b)
    elif op == 'subtract':
        return a - b
    elif op == 'multiply':
        return a * b
    return None
```
### Optimized Output
```python
def add(x: int | float, y: int | float) -> int | float:
    '''Add two numbers.
    Args:
        x: The first number.
        y: The second number.
    Returns:
        The sum of x and y.
    >>> add(3, 7)
    10
    '''
    return x + y
def perform_operation(op: str, a: int | float, b: int | float) -> int | float | None:
    '''Perform a mathematical operation on two numbers.
    Args:
        op: The operation to perform ('add', 'subtract', 'multiply').
        a: The first number.
        b: The second number.
    Returns:
        The result of the operation, or None if the operation is invalid.
    >>> perform_operation('add', 3, 7)
    10
    >>> perform_operation('subtract', 10, 3)
    7
    >>> perform_operation('multiply', 4, 5)
    20
    >>> perform_operation('divide', 4, 5) is None
    True
    '''
    match op:
        case 'add':
            return add(a, b)
        case 'subtract':
            return a - b
        case 'multiply':
            return a * b
        case _:
            return None
```
The output demonstrates refactored code with Python 3.13 type hints, better performance (using `match`), and added documentation.
### Output Format
Provide only the updated code for each file, enclosed within special markers as
follows:
For each file:
<<FILE: file_path>>
<code>
<<ENDFILE>>
Do not include any additional text or comments in the output.
"""
# Developer/system message sent with every chat completion request.  The
# <<FILE: ...>> / <<ENDFILE>> markers it mandates are what
# parse_improved_code() later extracts from the model's response.
SYSTEM_PROMPT = """You are tasked with generating high-quality, modern, Pythonic
Python code with type hints and documentation. Follow these instructions
carefully:
### Code Preservation
- Preserve all public or global functions, classes, and definitions.
- Do not remove or rename any public/global functions or classes, as they might
be externally referenced.
- Retain all existing comments and docstrings.
### Code Style
- Adhere to PEP8 guidelines:
- Limit lines to 79 characters.
- Use single quotes (`'`) for strings.
- Ensure the overall style reflects modern Python practices.
### Type Hinting
- Add complete type annotations compatible with Python 3.13 and above:
- Use native type annotations: `list[T]` instead of `typing.List[T]`.
- For Optional annotations, use Union shorthand with the `|` operator (e.g.,
`int | None`).
### Code Optimization
- Refactor code as needed for performance and clarity, while ensuring:
- Functions are split when useful to reduce complexity.
- Code and variables are reused where possible.
- Avoided using `global` variables, `eval()`, or `exec()`.
- Maintain or improve runtime efficiency over readability when trade-offs
arise.
### Imports
- Be concise and explicit when importing:
- Avoid `from module import *`.
- Import only what is necessary.
- Prefer `import module` over `from module import function` when practical to
avoid namespace collisions.
- Prefer `pathlib` over `os.path`.
### Documentation
- Ensure the code is well-documented like a public-facing open-source library:
- Retain all existing documentation (comments, docstrings, and doctests).
- Add docstrings for any public function, class, or module without one.
- Include new doctests where appropriate, without modifying existing doctests.
- Be descriptive but concise.
# Output Format
- Provide the updated Python code as plain text.
- Ensure all changes adhere to the guidelines.
# Notes
- Avoid altering the purpose or behavior of existing code.
- Ensure the final product maintains compliance with all stated rules.
- If you encounter areas where optimizations would be speculative or where
changes could disrupt backward compatibility, leave the code as-is and explain
briefly within comments.
Encase code between "```" and "```".
"""
# Read .env file if it exists
DOT_ENV_FILE = pathlib.Path('.env')


def load_dotenv(path: pathlib.Path = DOT_ENV_FILE) -> None:
    """Load ``KEY=VALUE`` pairs from *path* into ``os.environ``.

    Existing environment variables are never overwritten
    (``os.environ.setdefault``).

    Blank lines, comment lines (starting with ``#``), and lines without
    an ``=`` separator are skipped.  The previous inline loop raised
    ``ValueError`` (unpack error) on a blank line or any line lacking
    ``=``, which made an otherwise valid ``.env`` file crash the script
    at import time.

    Args:
        path (pathlib.Path): The dotenv file to read; defaults to
            ``./.env``.
    """
    if not path.exists():
        return
    with path.open() as f:
        for raw_line in f:
            line = raw_line.strip()
            # Skip blanks, comments, and malformed lines instead of
            # crashing on the tuple unpack below.
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, value = line.split('=', 1)
            os.environ.setdefault(key, value)


load_dotenv()
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: Parser accepting the file list and all
        processing options (iterations, tools, verbosity, etc.).
    """
    parser = argparse.ArgumentParser(
        description='Automatically improve Python files using Azure OpenAI API.'
    )
    parser.add_argument(
        'files', metavar='FILE', nargs='+', help='Python files to process'
    )
    parser.add_argument(
        '--max-iterations',
        type=int,
        default=1,
        help='Maximum number of iterations per file (default: 1)',
    )
    parser.add_argument(
        '--tools',
        nargs='+',
        choices=['ruff', 'pyright', 'mypy'],
        default=['ruff', 'pyright'],
        help='Specify code quality tools to run (default: ruff/pyright).',
    )
    parser.add_argument(
        '--verbose',
        '-v',
        action='store_true',
        help='Enable verbose logging',
    )
    parser.add_argument(
        '--force',
        '-f',
        action='store_true',
        help='Force write of files even with no improvements',
    )
    parser.add_argument(
        '--test',
        '-t',
        action='store_true',
        help='Run pytest after processing',
    )
    parser.add_argument(
        '--no-backup',
        '-B',
        action='store_true',
        help='Prevent the creation of backups before each step',
    )
    parser.add_argument(
        '--parallel-files',
        action='store_true',
        help='Process files individually and in parallel',
    )
    return parser


async def _ruff_pass(files: list[str]) -> None:
    """Run ``ruff format`` and ``ruff check --fix`` on each file."""
    for file_path in files:
        logging.info(f'Formatting {file_path} with ruff...')
        await run_ruff_format(file_path)
        await run_ruff_fix(file_path)


async def main() -> None:
    """
    Main function to parse arguments and run the code improvement process.
    """
    console = Console()
    args = build_arg_parser().parse_args()
    # Set up logging with color support
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format='%(asctime)s %(name)s [%(levelname)s] %(message)s',
        datefmt='%H:%M:%S',
        handlers=[RichHandler(rich_tracebacks=True)],
    )
    # Suppress logging from azure/openai/http libraries
    for noisy_logger in ('azure', 'openai', 'httpcore'):
        logging.getLogger(noisy_logger).setLevel(logging.ERROR)
    # Environment variables
    azure_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
    api_key = os.getenv('AZURE_OPENAI_API_KEY')
    openai_deployment = os.getenv('OPENAI_DEPLOYMENT')
    # Overridable via the environment; the default matches the previously
    # hard-coded value, so existing setups behave identically.
    api_version = os.getenv('AZURE_OPENAI_API_VERSION', '2024-12-01-preview')
    if not azure_endpoint or not api_key or not openai_deployment:
        logging.error(
            'Azure OpenAI environment variables are not set properly.'
        )
        sys.exit(1)
    # Set up Azure OpenAI API client
    client = AsyncAzureOpenAI(
        azure_endpoint=azure_endpoint,
        api_key=api_key,
        api_version=api_version,
    )
    # Read contents of all files
    original_code: dict[str, str] = {}
    for file_path in args.files:
        try:
            async with aiofiles.open(file_path) as f:
                original_code[file_path] = await f.read()
        except FileNotFoundError:
            # logging.exception already appends the traceback, so the
            # exception text is not repeated in the message.
            logging.exception('File not found: %s', file_path)
            sys.exit(1)
        except OSError:
            logging.exception('Error reading file %s', file_path)
            sys.exit(1)
    # Run ruff format and fix on all files before processing
    await _ruff_pass(args.files)
    async with client:
        # Process files based on the parallel_files flag
        await process_files(
            file_paths=args.files,
            max_iterations=args.max_iterations,
            model=openai_deployment,
            tools=args.tools,
            console=console,
            original_code=original_code,
            force_write=args.force,
            no_backup=args.no_backup,
            client=client,
            parallel_files=args.parallel_files,
        )
    # Run ruff format and fix on all files after processing
    await _ruff_pass(args.files)
    # Run pytest to verify that no code was broken if --test is specified
    if args.test:
        await run_pytest(console)
async def run_ruff_format(file_path: str) -> None:
    """
    Format the given file using ruff format.

    All tool output is discarded; a failure to launch ruff is logged
    rather than raised.

    Args:
        file_path (str): Path to the file to format.
    """
    command = ('ruff', 'format', file_path)
    try:
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.DEVNULL,
        )
        await process.communicate()
    except OSError as e:
        # Covers a missing ruff executable (FileNotFoundError) as well.
        logging.exception(f'Error running ruff format on {file_path}: {e}')
async def run_ruff_fix(file_path: str) -> None:
    """
    Fix the given file using ruff fix.

    Runs ``ruff check --fix`` with all output suppressed; a failure to
    launch ruff is logged rather than raised.

    Args:
        file_path (str): Path to the file to fix.
    """
    command = ('ruff', 'check', '--fix', file_path)
    try:
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.DEVNULL,
        )
        await process.communicate()
    except OSError as e:
        # Covers a missing ruff executable (FileNotFoundError) as well.
        logging.exception(f'Error running ruff fix on {file_path}: {e}')
async def process_files(
    file_paths: list[str],
    max_iterations: int,
    model: str,
    tools: list[str],
    console: Console,
    original_code: dict[str, str],
    force_write: bool,
    no_backup: bool,
    client: AsyncAzureOpenAI,
    parallel_files: bool,
) -> None:
    """
    Process the files, improving their code with the help of OpenAI.

    Args:
        file_paths (list[str]): List of file paths to process.
        max_iterations (int): Maximum number of iterations to attempt improvements.
        model (str): Deployment ID for the Azure OpenAI service.
        tools (list[str]): List of code quality tools to use.
        console (Console): Rich console instance for output.
        original_code (dict[str, str]): Original code for each file.
        force_write (bool): Whether to write files even if there are no improvements.
        no_backup (bool): Whether to skip creating backups before writing files.
        client (AsyncAzureOpenAI): The async Azure OpenAI client.
        parallel_files (bool): Whether to process files individually and in parallel.
    """
    if parallel_files:
        # Process files individually and in parallel
        tasks = [
            process_single_file(
                file_path,
                max_iterations,
                model,
                tools,
                console,
                original_code[file_path],
                force_write,
                no_backup,
                client,
            )
            for file_path in file_paths
        ]
        await asyncio.gather(*tasks)
    else:
        # Process files together
        code_to_improve = original_code.copy()
        improved_code = code_to_improve.copy()
        # FIX: pre-initialize so the cleanup in the for-else below cannot
        # raise NameError when the loop body never runs
        # (e.g. --max-iterations 0).
        temp_files: list[str] = []
        for iteration in range(1, max_iterations + 1):
            logging.info(f'Iteration {iteration} of {max_iterations}')
            logging.info('Improving code using %s', model)
            # Run code quality tools on current code
            logging.info('Running code quality tools before improvement...')
            quality_errors = await run_quality_tools(file_paths, tools)
            if quality_errors:
                logging.info('Current code quality issues:')
                logging.info(quality_errors)
            else:
                logging.info('No code quality issues in current code.')
            if not quality_errors.strip():
                # Clean already: write out any pending improvements and stop.
                logging.info('No code quality issues found.')
                unchanged = all(
                    improved_code[fp] == original_code[fp] for fp in file_paths
                )
                if unchanged and not force_write:
                    logging.info('No changes made to the code.')
                    break
                for file_path in file_paths:
                    if (
                        improved_code[file_path] != original_code[file_path]
                        or force_write
                    ):
                        # Refuse to write a file whose improved version lost
                        # global definitions, unless forced.
                        missing_globals_error = check_global_definitions(
                            original_code[file_path],
                            improved_code[file_path],
                        )
                        if missing_globals_error:
                            logging.error(
                                f'Globals missing in the improved code '
                                f'for {file_path}:'
                            )
                            logging.error(missing_globals_error)
                            logging.error(
                                f'Aborting write operation for {file_path} '
                                'due to missing globals.'
                            )
                            if not force_write:
                                continue
                        if not no_backup:
                            await create_backup_file(
                                file_path, original_code[file_path]
                            )
                        show_diff_with_syntax_highlighting(
                            console,
                            original_code[file_path],
                            improved_code[file_path],
                            file_path,
                        )
                        logging.info(
                            f'Writing improved code to original file {file_path}...'
                        )
                        async with aiofiles.open(file_path, 'w') as f:
                            await f.write(improved_code[file_path])
                        logging.info(
                            f'No errors found in {file_path} after {iteration} '
                            'iteration(s).'
                        )
                    else:
                        logging.info(
                            f'No changes made to the code in {file_path}.'
                        )
                break
            # Ask the model for an improved version of all files at once.
            improved_code_text = await process_code_with_openai(
                code_to_improve,
                model,
                client,
                quality_errors=quality_errors,
                mode='improve',
            )
            improved_code = parse_improved_code(
                improved_code_text, file_paths, code_to_improve
            )
            unchanged = all(
                improved_code[fp] == code_to_improve[fp] for fp in file_paths
            )
            if unchanged:
                logging.error('Failed to improve code. Exiting iteration.')
                break
            logging.info('Checking for missing global definitions...')
            globals_errors = ''
            for file_path in file_paths:
                missing_globals_error = check_global_definitions(
                    original_code[file_path], improved_code[file_path]
                )
                if missing_globals_error:
                    logging.error(
                        f'Globals missing in the improved code for {file_path}:'
                    )
                    logging.error(missing_globals_error)
                    globals_errors += missing_globals_error + '\n'
                else:
                    logging.info(
                        f'All global definitions are present in {file_path}.'
                    )
            logging.info('Checking for missing public definitions...')
            errors = ''
            for file_path in file_paths:
                missing_defs_error = check_public_definitions(
                    original_code[file_path], improved_code[file_path]
                )
                if missing_defs_error:
                    logging.warning(
                        f'Public definitions missing in {file_path}:'
                    )
                    logging.warning(missing_defs_error)
                    errors += missing_defs_error + '\n'
                else:
                    logging.info(
                        f'All public definitions are present in {file_path}.'
                    )
            errors += globals_errors
            # Write the improved code to temp files so the quality tools can
            # analyze it without touching the originals.
            temp_files = []
            for file_path in file_paths:
                temp_file_path = file_path + '.temp'
                temp_files.append(temp_file_path)
                async with aiofiles.open(temp_file_path, 'w') as f:
                    await f.write(improved_code[file_path])
            logging.info('Running code quality tools after improvement...')
            quality_errors = await run_quality_tools(temp_files, tools)
            if quality_errors:
                logging.warning('Code quality issues found:')
                logging.warning(quality_errors)
                errors += quality_errors
            else:
                logging.info('No code quality issues found after improvement.')
                all_globals_present = True
                for file_path in file_paths:
                    missing_globals_error = check_global_definitions(
                        original_code[file_path], improved_code[file_path]
                    )
                    if missing_globals_error:
                        logging.error(
                            f'Globals missing in the improved code '
                            f'for {file_path}:'
                        )
                        logging.error(missing_globals_error)
                        logging.error(
                            f'Aborting write operation for {file_path} '
                            'due to missing globals.'
                        )
                        all_globals_present = False
                        break
                if not all_globals_present and not force_write:
                    # Feed the flawed version back to the model for another
                    # pass instead of writing a broken file.
                    code_to_improve = improved_code.copy()
                    for temp_file in temp_files:
                        if os.path.exists(temp_file):
                            os.remove(temp_file)
                    continue
                for file_path in file_paths:
                    if (
                        improved_code[file_path] != original_code[file_path]
                        or force_write
                    ):
                        if not no_backup:
                            await create_backup_file(
                                file_path, original_code[file_path]
                            )
                        show_diff_with_syntax_highlighting(
                            console,
                            original_code[file_path],
                            improved_code[file_path],
                            file_path,
                        )
                        logging.info(
                            'Writing improved code to original file '
                            f'{file_path}...'
                        )
                        async with aiofiles.open(file_path, 'w') as f:
                            await f.write(improved_code[file_path])
                        logging.info(
                            f'No errors found in {file_path} after '
                            f'{iteration} iteration(s).'
                        )
                for temp_file in temp_files:
                    if os.path.exists(temp_file):
                        os.remove(temp_file)
                break
            # Errors remain: ask the model to fix them before iterating.
            logging.info('Fixing code based on errors...')
            code_to_improve_text = await process_code_with_openai(
                improved_code,
                model,
                client,
                errors=errors,
                mode='fix',
            )
            code_to_improve = parse_improved_code(
                code_to_improve_text, file_paths, improved_code
            )
            unchanged = all(
                code_to_improve[fp] == improved_code[fp] for fp in file_paths
            )
            if unchanged:
                logging.error('Failed to fix code. Exiting iteration.')
                break
            for temp_file in temp_files:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
        else:
            # for-else: loop exhausted without break, so errors may remain.
            logging.warning(
                f'Maximum iterations ({max_iterations}) reached. Errors may remain.'
            )
            for file_path in file_paths:
                if (
                    improved_code[file_path] != original_code[file_path]
                    or force_write
                ):
                    if not no_backup:
                        await create_backup_file(
                            file_path, original_code[file_path]
                        )
                    show_diff_with_syntax_highlighting(
                        console,
                        original_code[file_path],
                        improved_code[file_path],
                        file_path,
                    )
                    logging.info(
                        f'Writing last improved code to original file {file_path}.'
                    )
                    missing_globals_error = check_global_definitions(
                        original_code[file_path], improved_code[file_path]
                    )
                    if missing_globals_error and not force_write:
                        logging.error(
                            f'Globals missing in the improved code for {file_path}:'
                        )
                        logging.error(missing_globals_error)
                        logging.error(
                            f'Aborting write operation for {file_path} '
                            'due to missing globals.'
                        )
                        continue
                    async with aiofiles.open(file_path, 'w') as f:
                        await f.write(improved_code[file_path])
            # Best-effort cleanup of temp files from the final iteration.
            for temp_file in temp_files:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
async def process_single_file(
    file_path: str,
    max_iterations: int,
    model: str,
    tools: list[str],
    console: Console,
    original_code: str,
    force_write: bool,
    no_backup: bool,
    client: AsyncAzureOpenAI,
) -> None:
    """
    Process a single file, improving its code with the help of OpenAI.

    Args:
        file_path (str): Path to the file to process.
        max_iterations (int): Maximum number of iterations to attempt improvements.
        model (str): Deployment ID for the Azure OpenAI service.
        tools (list[str]): List of code quality tools to use.
        console (Console): Rich console instance for output.
        original_code (str): Original code for the file.
        force_write (bool): Whether to write files even if there are no improvements.
        no_backup (bool): Whether to skip creating backups before writing files.
        client (AsyncAzureOpenAI): The async Azure OpenAI client.
    """
    # `code_to_improve` is what we send to the model next; `improved_code`
    # is the model's latest accepted answer.
    code_to_improve = original_code
    improved_code = code_to_improve
    for iteration in range(1, max_iterations + 1):
        logging.info(
            f'Processing {file_path} - Iteration {iteration} of {max_iterations}'
        )
        logging.info('Improving code using %s', model)
        logging.info(
            f'Running code quality tools on {file_path} before improvement...'
        )
        # Run the quality tools on a temporary copy so the original file is
        # never modified during analysis.
        temp_file_path = file_path + '.temp.before'
        async with aiofiles.open(temp_file_path, 'w') as f:
            await f.write(code_to_improve)
        quality_errors = await run_quality_tools([temp_file_path], tools)
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        if quality_errors:
            logging.info('Current code quality issues:')
            logging.info(quality_errors)
        else:
            logging.info('No code quality issues in current code.')
        if not quality_errors.strip():
            # Already clean: write any pending improvement and stop.
            logging.info('No code quality issues found.')
            if improved_code == original_code and not force_write:
                logging.info('No changes made to the code.')
                break
            # Never write a version that dropped global definitions,
            # unless forced.
            missing_globals_error = check_global_definitions(
                original_code, improved_code
            )
            if missing_globals_error:
                logging.error(
                    f'Globals missing in the improved code for {file_path}:'
                )
                logging.error(missing_globals_error)
                logging.error(
                    f'Aborting write operation for {file_path} '
                    'due to missing globals.'
                )
                if not force_write:
                    # Retry on the next iteration rather than writing.
                    continue
            if not no_backup:
                await create_backup_file(file_path, original_code)
            show_diff_with_syntax_highlighting(
                console, original_code, improved_code, file_path
            )
            logging.info(
                f'Writing improved code to original file {file_path}...'
            )
            async with aiofiles.open(file_path, 'w') as f:
                await f.write(improved_code)
            logging.info(
                f'No errors found in {file_path} after {iteration} iteration(s).'
            )
            break
        # Ask the model for an improved version of this file.
        improved_code_text = await process_code_with_openai(
            {file_path: code_to_improve},
            model,
            client,
            quality_errors=quality_errors,
            mode='improve',
        )
        improved_code_dict = parse_improved_code(
            improved_code_text, [file_path], {file_path: code_to_improve}
        )
        improved_code = improved_code_dict[file_path]
        if improved_code == code_to_improve:
            # Model returned the input unchanged; no point iterating.
            logging.error('Failed to improve code. Exiting iteration.')
            break
        logging.info('Checking for missing global definitions...')
        missing_globals_error = check_global_definitions(
            original_code, improved_code
        )
        if missing_globals_error:
            logging.error(
                f'Globals missing in the improved code for {file_path}:'
            )
            logging.error(missing_globals_error)
            globals_errors = missing_globals_error + '\n'
        else:
            logging.info(f'All global definitions are present in {file_path}.')
            globals_errors = ''
        logging.info('Checking for missing public definitions...')
        missing_defs_error = check_public_definitions(
            original_code, improved_code
        )
        if missing_defs_error:
            logging.warning(f'Public definitions missing in {file_path}:')
            logging.warning(missing_defs_error)
            errors = missing_defs_error + '\n' + globals_errors
        else:
            logging.info(f'All public definitions are present in {file_path}.')
            errors = globals_errors
        # Re-run the quality tools on the improved code via a temp copy.
        temp_file_path = file_path + '.temp'
        async with aiofiles.open(temp_file_path, 'w') as f:
            await f.write(improved_code)
        logging.info(
            f'Running code quality tools on {file_path} after improvement...'
        )
        quality_errors = await run_quality_tools([temp_file_path], tools)
        if quality_errors:
            logging.warning('Code quality issues found:')
            logging.warning(quality_errors)
            errors += quality_errors
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        if not errors.strip():
            # Improved code is clean: final missing-globals gate, then write.
            missing_globals_error = check_global_definitions(
                original_code, improved_code
            )
            if missing_globals_error and not force_write:
                logging.error(
                    f'Globals missing in the improved code for {file_path}:'
                )
                logging.error(missing_globals_error)
                logging.error(
                    f'Aborting write operation for {file_path} '
                    'due to missing globals.'
                )
                continue
            if not no_backup:
                await create_backup_file(file_path, original_code)
            show_diff_with_syntax_highlighting(
                console, original_code, improved_code, file_path
            )
            logging.info(
                'Writing improved code to original file ' f'{file_path}...'
            )
            async with aiofiles.open(file_path, 'w') as f:
                await f.write(improved_code)
            logging.info(
                f'No errors found in {file_path} after {iteration} iteration(s).'
            )
            break
        # Errors remain: ask the model to fix them, then loop again.
        logging.info(f'Fixing code for {file_path} based on errors...')
        code_to_improve_text = await process_code_with_openai(
            {file_path: improved_code},
            model,
            client,
            errors=errors,
            mode='fix',
        )
        code_to_improve_dict = parse_improved_code(
            code_to_improve_text,
            [file_path],
            {file_path: improved_code},
        )
        code_to_improve = code_to_improve_dict[file_path]
        if code_to_improve == improved_code:
            logging.error('Failed to fix code. Exiting iteration.')
            break
        improved_code = code_to_improve
    else:
        # for-else: loop exhausted without break, so errors may remain.
        logging.warning(
            f'Maximum iterations ({max_iterations}) reached for {file_path}. '
            'Errors may remain.'
        )
    # Write whatever the latest improved version is (subject to the
    # missing-globals gate), covering early breaks above that did not write.
    if improved_code != original_code or force_write:
        if not no_backup:
            await create_backup_file(file_path, original_code)
        show_diff_with_syntax_highlighting(
            console, original_code, improved_code, file_path
        )
        logging.info(
            f'Writing last improved code to original file {file_path}.'
        )
        missing_globals_error = check_global_definitions(
            original_code, improved_code
        )
        if missing_globals_error and not force_write:
            logging.error(
                f'Globals missing in the improved code for {file_path}:'
            )
            logging.error(missing_globals_error)
            logging.error(
                f'Aborting write operation for {file_path} due to '
                'missing globals.'
            )
        else:
            async with aiofiles.open(file_path, 'w') as f:
                await f.write(improved_code)
async def create_backup_file(file_path: str, code: str) -> None:
    """
    Create a backup of the specified file, naming the backup file with a
    timestamp.

    The backup is a hidden file in the same directory:
    ``.<timestamp>.<original name>``.  The timestamp is formatted with
    ``strftime`` so it contains no spaces or colons — ``str(datetime.now())``
    produced ``:`` characters, which are invalid in Windows file names and
    made backups fail there.

    Args:
        file_path (str): Path of the file being backed up.
        code (str): The content to write to the backup file.
    """
    source = pathlib.Path(file_path)
    # Microseconds keep names unique across rapid successive backups.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S.%f')
    backup_file_path = source.parent / f'.{timestamp}.{source.name}'
    logging.info(f'Creating backup of {file_path} at {backup_file_path}')
    async with aiofiles.open(backup_file_path, 'w') as f:
        await f.write(code)
def parse_improved_code(
    response_text: str, file_paths: list[str], fallback_code: dict[str, str]
) -> dict[str, str]:
    """
    Parses the assistant's response to extract improved code for each file.

    The response is expected to contain one
    ``<<FILE: path>> ... <<ENDFILE>>`` section per file; anything for an
    unexpected path is ignored, and files absent from the response fall
    back to their original code.

    Args:
        response_text (str): The response text from the assistant.
        file_paths (list[str]): List of file paths.
        fallback_code (dict[str, str]): Original code to fallback on
            if parsing fails.

    Returns:
        dict[str, str]: Dictionary mapping file paths to improved code.
    """
    marker_re = re.compile(r'<<FILE:(.*?)>>(.*?)<<ENDFILE>>', re.DOTALL)
    improved_code: dict[str, str] = {}
    for raw_name, raw_code in marker_re.findall(response_text):
        name = raw_name.strip()
        if name not in file_paths:
            logging.warning(
                f'Received code for unexpected file {name}. Ignoring.'
            )
            continue
        improved_code[name] = raw_code.strip()
    missing = (fp for fp in file_paths if fp not in improved_code)
    for path in missing:
        logging.warning(
            f"No improved code for {path} found in the assistant's "
            'response.'
        )
        improved_code[path] = fallback_code[path]
    return improved_code
def show_diff_with_syntax_highlighting(
    console: Console, original_code: str, improved_code: str, file_path: str
) -> None:
    """
    Shows a syntax-highlighted diff between the original and improved code.

    When the two versions differ, the full original and improved files are
    rendered side by side as highlighted panels; otherwise only a log
    message is emitted.

    Args:
        console (Console): Rich console instance for output.
        original_code (str): The original code.
        improved_code (str): The improved code.
        file_path (str): The path to the file.
    """
    logging.info(
        f'Showing diff between original and improved code for {file_path}:'
    )
    has_changes = any(
        difflib.unified_diff(
            original_code.splitlines(),
            improved_code.splitlines(),
            fromfile=f'Original {file_path}',
            tofile=f'Improved {file_path}',
            lineterm='',
        )
    )
    if not has_changes:
        logging.info(
            f'No differences found between original and improved code for '
            f'{file_path}.'
        )
        return
    panels = [
        Layout(
            Panel(
                Syntax(code, 'python', line_numbers=True, word_wrap=True),
                title=f'{label} {file_path}',
            )
        )
        for label, code in (
            ('Original', original_code),
            ('Improved', improved_code),
        )
    ]
    layout = Layout()
    layout.split_row(*panels)
    console.print(layout)
async def process_code_with_openai(
    code_dict: dict[str, str],
    model: str,
    client: AsyncAzureOpenAI,
    errors: str = '',
    quality_errors: str = '',
    mode: str = 'improve',  # 'improve' or 'fix'
) -> str:
    """
    Use OpenAI API to improve or fix code for multiple files based on errors.

    Args:
        code_dict (dict[str, str]): Dictionary mapping file paths to code.
        model (str): Deployment ID for the Azure OpenAI service.
        client (AsyncAzureOpenAI): The async Azure OpenAI client.
        errors (str): String containing code errors to address.
        quality_errors (str): String containing code quality errors.
        mode (str): Mode of operation, 'improve' or 'fix'.

    Returns:
        str: Improved or fixed code as a string with markers indicating
            files. On API failure, the original code joined with
            newlines is returned as a fallback.
    """
    # Start the user message from the base instructions. The previous
    # version evaluated `PROMPT.strip()` without using the result, so
    # the instruction text was silently dropped from the request.
    prompt = PROMPT.strip()
    if mode == 'improve':
        prompt += (
            '\n\nHere are the code quality issues found by ruff and pyright '
        )
        prompt += '(if any):\n'
        prompt += quality_errors or 'No issues found.'
        prompt += '\n\nOriginal code for all files:\n'
    elif mode == 'fix':
        prompt += '\n\nHere are the errors:\n'
        prompt += errors
        prompt += '\n\nHere is the code that needs fixing:\n'
    # Each file is wrapped in <<FILE:...>>/<<ENDFILE>> markers so the
    # response can be split back into per-file code.
    for file_path, code in code_dict.items():
        prompt += f'\n<<FILE: {file_path}>>\n{code}\n<<ENDFILE>>'
    messages = [
        {'role': 'developer', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': prompt},
    ]
    logging.debug(f'Sending {mode} request to OpenAI API...')
    try:
        response: ChatCompletion = await client.chat.completions.create(
            messages=messages,
            model=model,
            max_completion_tokens=65536,
            stop=None,
        )
    except Exception:
        # Best-effort fallback: return the unmodified code so the
        # caller can continue instead of crashing the whole run.
        logging.exception('Azure OpenAI API error')
        return '\n'.join(code_dict.values())
    code_text = response.choices[0].message.content
    logging.info('Usage statistics: %r', response.usage)
    logging.debug(f'Received {mode} response from OpenAI API')
    logging.debug('Code received:')
    logging.debug(code_text)
    return code_text
async def run_quality_tools(file_paths: list[str], tools: list[str]) -> str:
    """
    Runs code quality tools on the given files and collects errors.

    Note: ruff is invoked with ``--fix``, which modifies files in place.

    Args:
        file_paths (list[str]): List of file paths to check.
        tools (list[str]): List of tools to use ('ruff', 'pyright', 'mypy').

    Returns:
        str: Combined error messages from all tools, stripped of
            surrounding whitespace; empty when no tool reported issues.
    """
    # Per-tool extra arguments; the file path is appended per file
    # below. Iterating this mapping keeps the ruff/pyright/mypy order
    # and removes the previous copy-pasted per-tool branches.
    tool_args: dict[str, list[str]] = {
        'ruff': ['check', '--fix'],
        'pyright': [],
        'mypy': [],
    }
    tasks = []
    for file_path in file_paths:
        for tool, extra_args in tool_args.items():
            if tool in tools:
                logging.debug(f'Running {tool} on {file_path}...')
                tasks.append(
                    run_tool(
                        tool,
                        [*extra_args, file_path],
                        f'{tool.capitalize()} errors in {file_path}:\n',
                    )
                )
    # All tool invocations run concurrently; empty outputs are dropped.
    results = await asyncio.gather(*tasks)
    return '\n'.join(r for r in results if r).strip()
async def run_tool(tool_name: str, args: list[str], error_prefix: str) -> str:
    """
    Run a code quality tool and capture its output.

    Args:
        tool_name (str): Name of the tool to run (e.g., 'ruff',
            'pyright', 'mypy').
        args (list[str]): List of arguments to pass to the tool.
        error_prefix (str): Prefix to include in the error messages.

    Returns:
        str: Captured errors from the tool's output; empty string when
            the tool reports no problems.
    """
    try:
        process = await asyncio.create_subprocess_exec(
            tool_name,
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout_bytes, stderr_bytes = await process.communicate()
        returncode = process.returncode
    except OSError as e:
        # The tool binary is missing or not executable; report the
        # failure as an error message instead of crashing the run.
        logging.exception(f'Error running {tool_name}')
        return f'{error_prefix}Error running {tool_name}: {e}\n'
    stdout = (
        stdout_bytes.decode('utf-8', errors='replace') if stdout_bytes else ''
    )
    stderr = (
        stderr_bytes.decode('utf-8', errors='replace') if stderr_bytes else ''
    )
    errors = ''
    if tool_name == 'ruff':
        # Ruff exits non-zero when unfixable issues remain.
        if returncode != 0:
            errors = error_prefix + stdout + stderr + '\n'
    elif tool_name == 'pyright':
        # Pyright ends its report with a summary line; anything other
        # than a fully clean summary is treated as an error report.
        if stdout and not stdout.strip().endswith(
            '0 errors, 0 warnings, 0 infos'
        ):
            errors = error_prefix + stdout + stderr + '\n'
    elif tool_name == 'mypy' and returncode != 0:
        # Check the exit code only: mypy prints a success message
        # ('Success: no issues found ...') on stdout even for clean
        # runs, so the previous `or stdout or stderr` test reported
        # false positives on every successful check.
        errors = error_prefix + stdout + stderr + '\n'
    return errors
def extract_global_definitions(code: str) -> set[str]:
    """
    Extracts the names of all global functions and classes from code.

    Args:
        code (str): The code to analyze.

    Returns:
        set[str]: Set of global function and class names; empty when
            the code cannot be parsed.
    """
    try:
        module = ast.parse(code)
    except SyntaxError:
        logging.exception('Syntax error during AST parsing')
        return set()
    # Only direct children of the module count as global definitions;
    # nested functions/classes are deliberately excluded.
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    return {
        node.name
        for node in module.body
        if isinstance(node, definition_types)
    }
def check_global_definitions(original_code: str, improved_code: str) -> str:
    """
    Checks that all global functions and classes from the original code
    exist in the improved code. Returns an error message if any are
    missing.

    Args:
        original_code (str): The original code.
        improved_code (str): The improved code.

    Returns:
        str: Error message if globals are missing; empty string otherwise.
    """
    missing = extract_global_definitions(
        original_code
    ) - extract_global_definitions(improved_code)
    if not missing:
        return ''
    listing = '\n'.join(sorted(missing))
    return (
        'The following global functions or classes are missing in the '
        'improved code:\n'
        + listing
        + '\nPlease ensure that all global functions and classes from '
        'the original code are preserved.'
    )
def extract_public_definitions_with_parents(code: str) -> set[str]:
    """
    Extracts the names of public functions and classes from the code,
    considering their parent scopes.

    Args:
        code (str): The code to analyze.

    Returns:
        set[str]: Set of dotted names (e.g. 'Class.method') of public
            functions and classes; empty when the code cannot be parsed.
    """
    public_defs: set[str] = set()
    # Include async functions, matching extract_global_definitions;
    # omitting ast.AsyncFunctionDef made public async definitions get
    # falsely reported as missing from improved code.
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    try:
        tree = ast.parse(code)
        add_parent_pointers(tree)
    except SyntaxError:
        logging.exception('Syntax error during AST parsing')
        return public_defs
    for node in ast.walk(tree):
        if isinstance(node, definition_types):
            # Build the dotted path by walking parent pointers up to
            # (but not including) the module node.
            names = [node.name]
            parent = getattr(node, 'parent', None)
            while parent and not isinstance(parent, ast.Module):
                if isinstance(parent, definition_types):
                    names.append(parent.name)
                parent = getattr(parent, 'parent', None)
            full_name = '.'.join(reversed(names))
            # Only the leaf name decides visibility; a public method of
            # a private class is still collected under its full path.
            if not node.name.startswith('_'):
                public_defs.add(full_name)
    return public_defs
def check_public_definitions(original_code: str, improved_code: str) -> str:
    """
    Checks that all public functions and classes from the original code
    exist in the improved code. Returns an error message if any are
    missing.

    Args:
        original_code (str): The original code.
        improved_code (str): The improved code.

    Returns:
        str: Error message if public definitions are missing; empty
            string otherwise.
    """
    before = extract_public_definitions_with_parents(original_code)
    after = extract_public_definitions_with_parents(improved_code)
    missing = sorted(before - after)
    if not missing:
        return ''
    return (
        'The following public functions or classes are missing in the '
        'improved code:\n'
        + '\n'.join(missing)
        + '\nPlease ensure that all public functions and classes from the '
        'original code are preserved.'
    )
def add_parent_pointers(tree: ast.AST) -> None:
    """
    Modify AST nodes to keep track of parent nodes.

    Every child node gains a ``parent`` attribute pointing at its
    direct parent, enabling upward traversal later on.

    Args:
        tree (ast.AST): The AST tree to modify.
    """
    for parent in ast.walk(tree):
        for child in ast.iter_child_nodes(parent):
            child.parent = parent
async def run_pytest(console: Console) -> None:
    """
    Runs pytest to verify that the code changes did not break any tests.

    Args:
        console (Console): Rich console instance for output.
    """
    logging.info('Running pytest to verify that no code was broken...')
    try:
        # stderr is merged into stdout so the full report is shown in
        # one stream when tests fail.
        process = await asyncio.create_subprocess_exec(
            'pytest',
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        output_bytes, _ = await process.communicate()
        output = output_bytes.decode('utf-8', errors='replace')
        if process.returncode == 0:
            logging.info('All tests passed successfully.')
        else:
            logging.error('Tests failed. Please check the output below:')
            console.print(output)
    except OSError:
        # pytest binary missing or not executable.
        logging.exception('Error running pytest')
    except Exception:
        # Boundary catch-all: a broken test run must not abort the
        # surrounding improvement workflow.
        logging.exception('Unexpected error running pytest')
if __name__ == '__main__':
    # Script entry point: run the async workflow and exit quietly on
    # Ctrl-C instead of dumping a KeyboardInterrupt traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logging.info('Process interrupted by user.')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment