Skip to content

Instantly share code, notes, and snippets.

@bukzor
Created June 20, 2025 18:02
Show Gist options
  • Save bukzor/7915814febdde9e9c068d6711c9bcb5c to your computer and use it in GitHub Desktop.
Save bukzor/7915814febdde9e9c068d6711c9bcb5c to your computer and use it in GitHub Desktop.
libcst: prepend a decorator
#!/usr/bin/env python3
"""
usage: add-decorator-cst DECORATOR_NAME FUNCTION_NAME FILE_PATH
Automatically adds a decorator to a Python function's source code,
preserving all comments and formatting.
Requires: pip install libcst
Examples:
add-decorator-cst "cache" "expensive_function" my_module.py
add-decorator-cst "dataclass" "MyClass" models.py
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import TypeAlias
try:
    import libcst as cst
except ImportError as err:
    # Re-raise with an install hint, keeping the original error as the
    # explicit cause instead of an implicit "during handling of" context.
    raise ImportError("libcst is required: pip install libcst") from err

# What main() returns; it is passed straight to SystemExit, so None means
# exit status 0, an int is the exit status, and a str is printed to stderr
# with exit status 1.
ExitCode: TypeAlias = None | int | str
# The module docstring doubles as the command-line usage text.
USAGE = __doc__
@dataclass(frozen=True)
class DecoratorAdder(cst.CSTTransformer):
    """CST transformer that prepends a decorator while preserving formatting.

    Every function (``FunctionDef``) and class (``ClassDef``) named
    ``target_name`` that the transformer visits -- methods and nested
    definitions included -- gets ``@decorator_name`` as its first (outermost)
    decorator, unless a decorator with the same dotted name is already there.

    NOTE(review): as a frozen dataclass this never runs the CSTTransformer
    base ``__init__``; confirm that is acceptable before ever running it
    through ``cst.MetadataWrapper``.
    """

    # Decorator to add, parsed as a Python expression: e.g. "cache",
    # "functools.cache" or "lru_cache(maxsize=None)".
    decorator_name: str
    # Name of the function or class to decorate.
    target_name: str

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Decorate a function whose name is ``target_name``."""
        if updated_node.name.value == self.target_name:
            return self._add_decorator(updated_node)
        return updated_node

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Decorate a class whose name is ``target_name``."""
        if updated_node.name.value == self.target_name:
            return self._add_decorator(updated_node)
        return updated_node

    def _add_decorator(self, node):
        """Return ``node`` (a FunctionDef or ClassDef) with the decorator added.

        If a decorator with the same dotted name is already present -- with or
        without call arguments -- ``node`` is returned unchanged.
        """
        new_expr = self._create_decorator_node(self.decorator_name)
        wanted = self._get_decorator_name(new_expr)
        # Compare dotted names so "@dataclass(frozen=True)" counts as an
        # existing "dataclass". An unrecognisable name ("") never matches,
        # so such a decorator is always added.
        if wanted and any(
            self._get_decorator_name(existing.decorator) == wanted
            for existing in node.decorators
        ):
            return node
        # Prepending makes the new decorator the outermost one; comments that
        # sit above the definition remain above the new decorator.
        return node.with_changes(
            decorators=[cst.Decorator(decorator=new_expr), *node.decorators]
        )

    def _get_decorator_name(self, decorator: cst.BaseExpression) -> str:
        """Return the dotted name of a decorator expression, or "" if unknown.

        Recognises ``name``, ``pkg.mod.name`` and calls of either
        (``name(...)``); any other expression yields "".
        """
        if isinstance(decorator, cst.Call):
            # "@lru_cache(maxsize=None)" is still an "lru_cache" decorator.
            decorator = decorator.func
        parts: list[str] = []
        current = decorator
        while isinstance(current, cst.Attribute):
            parts.append(current.attr.value)
            current = current.value
        if not isinstance(current, cst.Name):
            # e.g. "@factory().hook": no static dotted name to compare, and a
            # partial name like "hook" could produce a false match.
            return ""
        parts.append(current.value)
        return ".".join(reversed(parts))

    def _create_decorator_node(self, decorator_name: str) -> cst.BaseExpression:
        """Parse ``decorator_name`` into the expression that follows ``@``.

        Dotted names and calls are handled by the parser. Text that is not a
        valid expression raises libcst's parser error instead of being wrapped
        verbatim in a ``Name`` and emitted as uncompilable code.
        """
        return cst.parse_expression(decorator_name)
def _decorate_source(source: str, decorator_name: str, target_name: str) -> str:
    """Return ``source`` with ``@decorator_name`` added to ``target_name``.

    Comments and formatting are preserved; when no definition is changed
    (target absent or already decorated) the text comes back unchanged.
    """
    tree = cst.parse_module(source)
    return tree.visit(DecoratorAdder(decorator_name, target_name)).code


def add_decorator_to_file(
    file_path: Path, decorator_name: str, target_name: str
) -> str:
    """Add a decorator to the target function/class in a file.

    The file is only read; the caller decides whether to write back the
    returned text.

    Args:
        file_path: Python source file, decoded as UTF-8.
        decorator_name: Decorator to add, e.g. ``"cache"`` or
            ``"functools.cache"``.
        target_name: Name of the function or class to decorate.

    Returns:
        The modified source text.

    Raises:
        OSError or UnicodeDecodeError if the file cannot be read, and libcst
        parser errors if it is not valid Python.
    """
    # PEP 3120: Python source defaults to UTF-8, so don't let the locale's
    # preferred encoding decide how the file is decoded.
    source = file_path.read_text(encoding="utf-8")
    return _decorate_source(source, decorator_name, target_name)


def main() -> ExitCode:
    """Command-line entry point; see USAGE.

    Returns a SystemExit-compatible code. A returned string is an error
    message, which SystemExit prints to stderr with exit status 1.
    """
    from sys import argv, stderr

    if len(argv) != 4:
        print(USAGE, file=stderr)
        return 1
    _, decorator_name, target_name, file_path_str = argv
    file_path = Path(file_path_str)
    # is_file() also rejects directories, which exists() would accept.
    if not file_path.is_file():
        return f"File not found: {file_path}"
    try:
        source = file_path.read_text(encoding="utf-8")
        modified_source = _decorate_source(source, decorator_name, target_name)
        # Skip the write (and the mtime bump) when nothing changed.
        if modified_source != source:
            file_path.write_text(modified_source, encoding="utf-8")
    except Exception as e:  # top-level boundary: report any failure, exit 1
        return f"Error: {e}"
    if modified_source == source:
        # Either target_name was not found or it already has the decorator.
        print(
            f"No change: @{decorator_name} not added to {target_name} "
            f"in {file_path}"
        )
    else:
        print(f"Added @{decorator_name} to {target_name} in {file_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
#!/usr/bin/env python3
"""Tests for add_decorator module - showing perfect comment preservation."""
from __future__ import annotations
from pathlib import Path
import libcst as cst
from add_decorator import DecoratorAdder
from add_decorator import add_decorator_to_file
class DescribeCommentAndWhitespacePreservation:
    """Test that libcst preserves comments and formatting perfectly."""

    def it_preserves_all_comments_and_whitespace(self, tmp_path: Path):
        """Mega test for comprehensive comment and whitespace preservation."""
        content = '''# Top-level comment
"""Module docstring."""

import os # Import comment


# Comment before function
def foo( # Parameter comment
    x: int, # Type comment
    y: str = "default" # Default comment
) -> bool: # Return comment
    """Function docstring with
    multiple lines."""
    # Inside comment
    if x > 0: # Inline condition comment
        return True # Return comment
    # Another inside comment
    return False


# Comment between functions
def bar():
    pass # Simple pass comment

# Final comment'''
        test_file = tmp_path / "test.py"
        test_file.write_text(content)

        result = add_decorator_to_file(test_file, "cache", "foo")

        # Should preserve ALL comments and formatting exactly; the only
        # difference is the new "@cache" line directly above "def foo(".
        expected = '''# Top-level comment
"""Module docstring."""

import os # Import comment


# Comment before function
@cache
def foo( # Parameter comment
    x: int, # Type comment
    y: str = "default" # Default comment
) -> bool: # Return comment
    """Function docstring with
    multiple lines."""
    # Inside comment
    if x > 0: # Inline condition comment
        return True # Return comment
    # Another inside comment
    return False


# Comment between functions
def bar():
    pass # Simple pass comment

# Final comment'''
        assert result == expected

    def it_keeps_existing_decorators_and_their_comments(self, tmp_path: Path):
        """The new decorator goes first; existing ones stay byte-identical."""
        content = (
            "import functools\n"
            "\n"
            "# Comment before class\n"
            "@functools.total_ordering  # Existing decorator comment\n"
            "class Point:\n"
            "    x: int  # Field comment\n"
        )
        test_file = tmp_path / "test.py"
        test_file.write_text(content)

        result = add_decorator_to_file(test_file, "dataclass", "Point")

        expected = (
            "import functools\n"
            "\n"
            "# Comment before class\n"
            "@dataclass\n"
            "@functools.total_ordering  # Existing decorator comment\n"
            "class Point:\n"
            "    x: int  # Field comment\n"
        )
        assert result == expected

    def it_leaves_an_already_decorated_target_untouched(self, tmp_path: Path):
        """Re-adding a decorator that is already present is a no-op."""
        content = (
            "import functools\n"
            "\n"
            "\n"
            "@functools.cache  # Already cached\n"
            "def foo(x):\n"
            "    return x  # Body comment\n"
        )
        test_file = tmp_path / "test.py"
        test_file.write_text(content)

        result = add_decorator_to_file(test_file, "functools.cache", "foo")

        assert result == content
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment