Generate a template of tests for all torch bindings exposed in a .so
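The script scans the shared object for printable strings and regex-matches anything that looks like a torch op schema. As a rough illustration (the op name and parameter names here are made up), a registered binding typically embeds schema text along these lines:

# Hypothetical schema string of the kind the regex below is meant to match;
# real extensions embed strings like this via their op registrations.
"my_op(Tensor input, int dim, bool keepdim) -> Tensor"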
import os
import re
import sys


def extract_strings(binary_path, min_length=4):
    """Extract printable ASCII strings from a binary, like the Unix `strings` tool."""
    with open(binary_path, "rb") as f:
        data = f.read()
    strings = []
    current = []
    for byte in data:
        if 32 <= byte <= 126:  # printable ASCII
            current.append(chr(byte))
        else:
            if len(current) >= min_length:
                strings.append("".join(current))
            current = []
    if len(current) >= min_length:  # flush a trailing run, if any
        strings.append("".join(current))
    return strings
def find_torch_signatures(strings):
    """Pull candidate op signatures (name, raw params, typed params) out of the extracted strings."""
    signatures = []
    # Matches `name(params)` with an optional `-> return` suffix.
    pattern = r"(\w+)\s*\(([^)]*)\)(?:\s*->\s*([^,\n\r]+))?"
    for string in strings:
        # Only consider strings that look like torch schema text.
        if not any(k in string for k in ["Tensor", "int", "bool", "->", "!"]):
            continue
        if len(string) > 1000:
            continue
        for match in re.finditer(pattern, string):
            func_name = match.group(1)
            params_str = match.group(2) or ""
            # Skip C/C++ keywords, private names, and C++-qualified parameter lists.
            if (
                len(func_name) < 3
                or func_name
                in ["if", "for", "while", "return", "sizeof", "class", "struct"]
                or func_name.startswith("_")
                or "::" in params_str
            ):
                continue
            params = []
            if params_str.strip():
                # Split the parameter list on top-level commas only.
                parts = []
                depth = 0
                current = []
                for char in params_str:
                    if char == "(":
                        depth += 1
                    elif char == ")":
                        depth -= 1
                    elif char == "," and depth == 0:
                        parts.append("".join(current).strip())
                        current = []
                        continue
                    current.append(char)
                if current:
                    parts.append("".join(current).strip())
                for part in parts:
                    # e.g. "Tensor(a!) out" or "int dim" -> (name, type)
                    param_match = re.match(
                        r"(Tensor|int|bool)\s*(?:\([^)]*\))?\s*(\w+)", part
                    )
                    if param_match:
                        params.append((param_match.group(2), param_match.group(1)))
            if params:
                signatures.append((func_name, params_str, params))
    # Deduplicate by (function name, parameter names).
    seen = {}
    for name, sig, params in signatures:
        key = (name, tuple(p[0] for p in params))
        if key not in seen:
            seen[key] = (name, sig, params)
    return [v for v in seen.values() if len(v[0]) > 2 and v[0][0].islower()]
def generate_tests(signatures, module_name):
    """Render a pytest file with one smoke test per discovered signature."""
    func_names = sorted(set(name for name, _, _ in signatures))
    imports = f"from {module_name} import {', '.join(func_names)}"
    lines = ["import pytest", "import torch", "", imports, ""]
    for func_name, sig, params in signatures:
        lines.append(f"def test_{func_name}():")
        args = []
        for pname, ptype in params:
            if ptype == "Tensor":
                lines.append(f"    {pname} = torch.randn([4])")
                args.append(pname)
            elif ptype == "int":
                args.append("1")
            elif ptype == "bool":
                args.append("True")
            else:
                args.append("None")
        lines.extend(
            [
                f"    result = {func_name}({', '.join(args)})",
                "    assert result is not None",
                "",
            ]
        )
    return "\n".join(lines)
def main():
    if len(sys.argv) != 2:
        print("Usage: python test_gen.py <path_to_binary>")
        sys.exit(1)
    binary_path = sys.argv[1]
    # Derive the importable module name from the shared object's file name,
    # e.g. "./_my_ext.cpython-311-x86_64-linux-gnu.so" -> "my_ext".
    module_name = os.path.basename(binary_path).split(".")[0].lstrip("_")
    strings = extract_strings(binary_path)
    signatures = find_torch_signatures(strings)
    print(f"Found {len(signatures)} signatures")
    with open("test_torch_functions.py", "w") as f:
        f.write(generate_tests(signatures, module_name))
    print("Generated test_torch_functions.py")


if __name__ == "__main__":
    main()
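For reference, a minimal usage sketch. The shared-object name and the `my_op` binding are hypothetical, carried over from the schema example above; against a real extension the import line and test bodies follow whatever signatures the scan actually recovers.

# Hypothetical run against a made-up extension:
#   python test_gen.py _my_kernels.cpython-311-x86_64-linux-gnu.so
#
# test_torch_functions.py would then contain roughly:
import pytest
import torch

from my_kernels import my_op

def test_my_op():
    input = torch.randn([4])
    result = my_op(input, 1, True)
    assert result is not None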