Created
June 20, 2025 18:02
-
-
Save bukzor/7915814febdde9e9c068d6711c9bcb5c to your computer and use it in GitHub Desktop.
libcst: prepend a decorator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
usage: add-decorator-cst DECORATOR_NAME FUNCTION_NAME FILE_PATH | |
Automatically adds a decorator to a Python function's source code, | |
preserving all comments and formatting. | |
Requires: pip install libcst | |
Examples: | |
add-decorator-cst "cache" "expensive_function" my_module.py | |
add-decorator-cst "dataclass" "MyClass" models.py | |
""" | |
from __future__ import annotations | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import TypeAlias | |
try:
    import libcst as cst
except ImportError as err:
    # Chain explicitly so the traceback presents the missing package as the
    # direct cause instead of "another exception occurred during handling".
    raise ImportError("libcst is required: pip install libcst") from err

# Value returned by main() and passed to SystemExit: None exits 0, an int is
# used as the exit status, and a str is printed to stderr with exit status 1.
ExitCode: TypeAlias = None | int | str
# The module docstring doubles as the command-line usage text.
USAGE = __doc__
@dataclass(frozen=True)
class DecoratorAdder(cst.CSTTransformer):
    """CST transformer that prepends a decorator while preserving formatting.

    Every ``FunctionDef`` or ``ClassDef`` whose name equals ``target_name`` is
    matched at any nesting depth, so methods and nested definitions with that
    name are included. Each match gets ``@decorator_name`` inserted as its
    first decorator, unless a decorator with the same dotted name is already
    present. All other nodes are returned unchanged.

    NOTE(review): libcst's ``MetadataDependent.resolve`` assigns
    ``self.metadata``, which a frozen dataclass forbids. This transformer
    presumably cannot be run through ``MetadataWrapper.visit``; confirm before
    adding metadata dependencies.
    """

    # Decorator expression to add, e.g. "cache", "functools.cache" or
    # "lru_cache(maxsize=None)".
    decorator_name: str
    # Name of the function(s) and/or class(es) to decorate.
    target_name: str

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Decorate ``updated_node`` if its name is the target name."""
        if updated_node.name.value == self.target_name:
            return self._add_decorator(updated_node)
        return updated_node

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Decorate ``updated_node`` if its name is the target name."""
        if updated_node.name.value == self.target_name:
            return self._add_decorator(updated_node)
        return updated_node

    def _add_decorator(
        self, node: cst.FunctionDef | cst.ClassDef
    ) -> cst.FunctionDef | cst.ClassDef:
        """Return ``node`` with the decorator prepended, unless already present.

        Presence is decided by dotted name with call arguments ignored. For
        example, an existing ``@lru_cache(maxsize=None)`` counts as
        ``lru_cache`` already being applied.
        """
        new_expr = self._create_decorator_node(self.decorator_name)
        wanted = self._get_decorator_name(new_expr)
        # An empty name means "no comparable name". Comparing "" == "" would
        # wrongly treat unrelated decorators (e.g. subscripts) as duplicates.
        if wanted and any(
            self._get_decorator_name(existing.decorator) == wanted
            for existing in node.decorators
        ):
            return node
        # Comments above the definition belong to the node's leading_lines,
        # which libcst renders before the first decorator. They therefore stay
        # above the newly inserted decorator.
        return node.with_changes(
            decorators=[cst.Decorator(decorator=new_expr), *node.decorators]
        )

    def _get_decorator_name(self, decorator: cst.BaseExpression) -> str:
        """Return the dotted name of a decorator expression, or "" if it has none.

        - ``@name`` gives "name".
        - ``@pkg.mod.name`` gives "pkg.mod.name".
        - A call such as ``@name(...)`` or ``@pkg.name(...)`` resolves to its
          callee's name, so arguments are ignored when comparing.
        - Anything else, including attribute chains not rooted at a plain name,
          gives "".
        """
        if isinstance(decorator, cst.Call):
            return self._get_decorator_name(decorator.func)
        if isinstance(decorator, cst.Name):
            return decorator.value
        if isinstance(decorator, cst.Attribute):
            parts: list[str] = []
            current: cst.BaseExpression = decorator
            while isinstance(current, cst.Attribute):
                parts.append(current.attr.value)
                current = current.value
            if not isinstance(current, cst.Name):
                # e.g. ``@registry[0].cache``: a partial name like "cache"
                # could collide with a genuine plain ``@cache``.
                return ""
            parts.append(current.value)
            return ".".join(reversed(parts))
        return ""

    def _create_decorator_node(self, decorator_name: str) -> cst.BaseExpression:
        """Parse ``decorator_name`` into a libcst expression node.

        Plain names ("cache") and dotted names ("functools.cache") produce the
        same ``Name``/``Attribute`` nodes as before. Call expressions such as
        "lru_cache(maxsize=None)" are now parsed properly too.

        Raises:
            libcst.ParserSyntaxError: if the text is not a valid expression.
        """
        return cst.parse_expression(decorator_name.strip())
def add_decorator_to_file(
    file_path: Path, decorator_name: str, target_name: str
) -> str:
    """Return the source of ``file_path`` with ``target_name`` decorated.

    The file on disk is only read, never written; persisting the returned text
    is left to the caller. Every function or class named ``target_name`` gets
    ``@decorator_name`` prepended, unless it already carries that decorator.

    NOTE(review): ``read_text()`` uses the locale's default encoding and
    universal-newline translation, so CRLF line endings come back as LF.
    Confirm this is acceptable given the "preserve formatting" goal.
    """
    module = cst.parse_module(file_path.read_text())
    return module.visit(DecoratorAdder(decorator_name, target_name)).code
def main() -> ExitCode:
    """Command-line entry point: add a decorator to FILE_PATH in place.

    Expects exactly three arguments in ``sys.argv``: DECORATOR_NAME,
    FUNCTION_NAME and FILE_PATH.

    Returns:
        ``None`` (exit status 0) on success, or when the file needed no change
        because the target is absent or already decorated. In the no-change
        case the file is left untouched.
        ``1`` on bad usage, a missing file, or any error while reading,
        transforming or writing.

    Diagnostics go to stderr; only the success message goes to stdout.
    """
    import sys

    argv = sys.argv
    if len(argv) != 4:
        print(USAGE, file=sys.stderr)
        return 1
    _, decorator_name, target_name, file_path_str = argv
    file_path = Path(file_path_str)
    if not file_path.exists():
        print(f"File not found: {file_path}", file=sys.stderr)
        return 1
    try:
        # Read the same way add_decorator_to_file does, so that an unchanged
        # result compares equal and we can avoid a pointless rewrite.
        original_source = file_path.read_text()
        modified_source = add_decorator_to_file(file_path, decorator_name, target_name)
        if modified_source == original_source:
            print(
                f"No changes made: {target_name} not found in {file_path} "
                f"or already decorated with @{decorator_name}",
                file=sys.stderr,
            )
            return None
        file_path.write_text(modified_source)
    except Exception as e:  # CLI boundary: report any failure as exit status 1
        print(f"Error: {e}", file=sys.stderr)
        return 1
    print(f"Added @{decorator_name} to {target_name} in {file_path}")
    return None
if __name__ == "__main__":
    # main() returns an ExitCode; SystemExit turns it into the process status.
    exit_status = main()
    raise SystemExit(exit_status)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Tests for add_decorator module - showing perfect comment preservation.""" | |
from __future__ import annotations | |
from pathlib import Path | |
import libcst as cst | |
from add_decorator import DecoratorAdder | |
from add_decorator import add_decorator_to_file | |
class DescribeCommentAndWhitespacePreservation:
    """Adding a decorator must leave every comment and formatting detail intact."""

    def it_preserves_all_comments_and_whitespace(self, tmp_path: Path):
        """Round-trip a comment-heavy module; only the new decorator line differs."""
        source_text = '''# Top-level comment
"""Module docstring."""
import os # Import comment
# Comment before function
def foo( # Parameter comment
    x: int, # Type comment
    y: str = "default" # Default comment
) -> bool: # Return comment
    """Function docstring with
    multiple lines."""
    # Inside comment
    if x > 0: # Inline condition comment
        return True # Return comment
    # Another inside comment
    return False
# Comment between functions
def bar():
    pass # Simple pass comment
# Final comment'''
        module_file = tmp_path / "test.py"
        module_file.write_text(source_text)

        actual = add_decorator_to_file(module_file, "cache", "foo")

        # Exactly one change is allowed: "@cache" appears directly above
        # `def foo(`, below the comment that precedes the function. Everything
        # else must match the input byte for byte.
        expected = source_text.replace("def foo(", "@cache\ndef foo(", 1)
        assert actual == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment