Skip to content

Instantly share code, notes, and snippets.

@drbh
Created August 25, 2025 18:41
Show Gist options
  • Save drbh/b5ef9b3466b1ddf8388053a00cfe59ce to your computer and use it in GitHub Desktop.
Save drbh/b5ef9b3466b1ddf8388053a00cfe59ce to your computer and use it in GitHub Desktop.
Generate a template of tests for all torch bindings exposed in a shared object (.so) file
import re
def extract_strings(binary_path, min_length=4):
    """Return the printable-ASCII runs found in the file at *binary_path*.

    The file is read as raw bytes and scanned for maximal runs of bytes in
    the range 32..126 (space through ``~``).  A run is kept when it is at
    least *min_length* bytes long; runs are returned as ``str`` in the order
    they appear in the file.
    """
    with open(binary_path, "rb") as handle:
        blob = handle.read()

    found = []
    run = bytearray()
    for value in blob:
        if 32 <= value <= 126:
            run.append(value)
            continue
        # Any non-printable byte terminates the run collected so far.
        if len(run) >= min_length:
            found.append(run.decode("ascii"))
        run = bytearray()

    # The file may end in the middle of a printable run.
    if run and len(run) >= min_length:
        found.append(run.decode("ascii"))
    return found
# The parameter list may contain one level of nested parentheses so that alias
# annotations such as ``Tensor(a!) self`` stay inside the capture instead of
# cutting the match off at the first ")".
_SIGNATURE_RE = re.compile(
    r"(\w+)\s*\(((?:[^()]|\([^()]*\))*)\)(?:\s*->\s*([^,\n\r]+))?"
)
# ``\b`` keeps identifiers that merely start with a known type name
# (``int64_t``, ``boolean``) from being parsed as that type.
_PARAM_RE = re.compile(r"(Tensor|int|bool)\b\s*(?:\([^)]*\))?\s*(\w+)")
_SCHEMA_HINTS = ("Tensor", "int", "bool", "->", "!")
_NON_FUNCTION_NAMES = frozenset(
    ["if", "for", "while", "return", "sizeof", "class", "struct"]
)
_MAX_STRING_LENGTH = 1000


def _split_top_level_commas(params_str):
    """Split *params_str* on commas that are not nested inside parentheses."""
    parts = []
    depth = 0
    current = []
    for char in params_str:
        if char == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
            continue
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        current.append(char)
    if current:
        parts.append("".join(current).strip())
    return parts


def find_torch_signatures(strings):
    """Find function-schema-like signatures in a list of strings.

    Each string is searched for ``name(params)`` optionally followed by
    ``-> return``.  A candidate is kept when its name is at least three
    characters long, starts with a lowercase letter, does not start with an
    underscore, is not one of a few C/C++ keywords, its parameter text does
    not contain ``::``, and at least one parameter parses as
    ``Tensor``/``int``/``bool`` followed by a name.  Parameters of any other
    type are skipped, so the returned parameter list can be shorter than the
    real signature.

    Args:
        strings: Iterable of ``str`` to scan.  Strings longer than 1000
            characters, or without any type/arrow hint, are ignored.

    Returns:
        A list of ``(func_name, params_str, params)`` tuples in first-seen
        order, where ``params`` is a list of ``(param_name, param_type)``.
        Entries are de-duplicated on the function name plus the tuple of
        parsed parameter names.
    """
    unique = {}
    for string in strings:
        if len(string) > _MAX_STRING_LENGTH:
            continue
        if not any(hint in string for hint in _SCHEMA_HINTS):
            continue

        for match in _SIGNATURE_RE.finditer(string):
            func_name = match.group(1)
            params_str = match.group(2) or ""
            if (
                len(func_name) < 3
                or func_name in _NON_FUNCTION_NAMES
                or func_name.startswith("_")
                or not func_name[0].islower()
                or "::" in params_str
            ):
                continue
            if not params_str.strip():
                continue

            params = []
            for part in _split_top_level_commas(params_str):
                param_match = _PARAM_RE.match(part)
                if param_match:
                    params.append((param_match.group(2), param_match.group(1)))
            if not params:
                continue

            # First occurrence wins; later duplicates are ignored.
            key = (func_name, tuple(pname for pname, _ in params))
            unique.setdefault(key, (func_name, params_str, params))
    return list(unique.values())
def generate_tests(signatures, module_name):
func_names = sorted(set(name for name, _, _ in signatures))
imports = f"from {module_name} import {', '.join(func_names)}"
lines = ["import pytest", "import torch", "", imports, ""]
test_num = 0
for func_name, sig, params in signatures:
test_num += 1
lines.append(f"def test_{func_name}():")
args = []
for pname, ptype in params:
if ptype == "Tensor":
lines.append(f" {pname} = torch.randn([4])")
args.append(pname)
elif ptype == "int":
args.append("1")
elif ptype == "bool":
args.append("True")
else:
args.append("None")
lines.extend(
[
f" result = {func_name}({', '.join(args)})",
" assert result is not None",
"",
]
)
return "\n".join(lines)
def main():
    """Command-line entry point: generate tests for one compiled extension.

    Expects exactly one argument, the path to the binary.  The module name
    used in the generated import is the binary's file name up to its first
    dot, with leading underscores removed.  The generated tests are written
    to ``test_torch_functions.py`` in the current working directory.
    Exits with status 1 after printing a usage line when the argument count
    is wrong.
    """
    import os
    import sys

    if len(sys.argv) != 2:
        print("Usage: python test_gen.py <path_to_binary>")
        sys.exit(1)

    binary_path = sys.argv[1]
    # Use only the file name: splitting the whole path on "." breaks for
    # inputs such as "./build/_ops.so" (empty name) or paths with directories.
    module_name = os.path.basename(binary_path).split(".")[0].lstrip("_")

    strings = extract_strings(binary_path)
    signatures = find_torch_signatures(strings)
    print(f"Found {len(signatures)} signatures")

    # Python source files are UTF-8 by default, so write the module as such.
    with open("test_torch_functions.py", "w", encoding="utf-8") as f:
        f.write(generate_tests(signatures, module_name))
    print("Generated test_torch_functions.py")
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment