Created
April 22, 2026 19:06
-
-
Save llandsmeer/94f9f542d0b876a71e0f7734b0741c6c to your computer and use it in GitHub Desktop.
MiniJAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # pip install keystone-engine | |
| import inspect | |
| import struct | |
| from keystone import Ks, KS_ARCH_X86, KS_MODE_64 | |
| import ctypes | |
| import mmap | |
def map_type(x):
    """Return the tracer type tag for a Python scalar.

    Args:
        x: An example argument value.

    Returns:
        ``'i64'`` for ``int`` instances (``bool`` included, since it is an
        ``int`` subclass) and ``'f64'`` for ``float`` instances.

    Raises:
        TypeError: If ``x`` is neither an ``int`` nor a ``float``.
    """
    if isinstance(x, int):
        return 'i64'
    if isinstance(x, float):
        return 'f64'
    # An explicit raise instead of ``assert False``: assert statements are
    # stripped under ``python -O``, which would make this return None.
    raise TypeError(f'unsupported argument type: {type(x).__name__}')
def peephole_optimize(instructions):
    """Shrink an instruction stream with local pattern rewrites.

    Five passes are run over the list. In every pass the list is scanned
    from the last index to the first, and at each index every rule that is
    active in that pass is tried in table order; a matching window is
    replaced in place.

    Args:
        instructions: Iterable of assembly instruction strings. It is
            copied; the caller's object is not modified.

    Returns:
        A new list of instruction strings.
    """
    instructions = list(instructions)
    # Each rule is (first_pass, pattern, replacement): the rule is active in
    # pass ``first_pass`` and in every later pass.
    # NOTE(review): several rules drop a write to rax or rdx; they assume the
    # dropped value is not read afterwards - confirm against the emitters.
    replace_list = [
        (1, ['movq rax, xmm8', 'movq xmm0, rax'], ['movq xmm0, xmm8']),
        (1, ['movq rax, xmm0', 'movq xmm8, rax'], ['movq xmm8, xmm0']),
        (0, ['pop rax', 'push rax'], []),
        (1, ['push rax', 'pop rax'], []),
        (1, ['movq rax, xmm8', 'movq xmm8, rax'], []),
        (1, ['movq rax, xmm9', 'movq xmm9, rax'], []),
        (2, ['mov rax, rdi', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rdi']),
        (2, ['mov rax, rsi', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rsi']),
        (2, ['mov rax, rdx', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rdx']),
        (2, ['mov rax, rcx', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rcx']),
        (2, ['push rax', 'pop rdx'], ['mov rdx, rax']),
        (4, ['pop rdx', 'movq xmm9, rdx'], ['movq xmm9, [rsp]', 'add rsp, 8']),
        (3, ['movq rax, xmm8', 'mov rdx, rax'], ['movq rdx, xmm8']),
        # No rule folds 'addsd/mulsd xmm8, xmm9' + 'movq xmm0, xmm8' into a
        # single operation on xmm0: such a rewrite computes with whatever
        # xmm0 currently holds instead of xmm8, which is only correct when
        # the two registers happen to contain the same value.
    ]
    for npass in range(5):
        # The set of active rules depends only on the pass number.
        active = [(a, b) for order, a, b in replace_list if order <= npass]
        for i in range(len(instructions) - 1, -1, -1):
            for pattern, replacement in active:
                if instructions[i:i + len(pattern)] == pattern:
                    instructions[i:i + len(pattern)] = replacement
    return instructions
class CompiledFunction:
    """Callable wrapper around machine code assembled from an expression graph.

    Attributes:
        instructions: List of instruction strings after peephole optimization.
        text: The same instructions joined with ';\\n', as passed to keystone.
        func: ctypes function object pointing at the assembled code.
        mem: The mmap object that holds the code bytes.
    """

    def __init__(self, graph, trace_args):
        """Assemble ``graph`` and expose it as a ctypes function.

        Args:
            graph: Expression node providing ``compile()`` and ``type``.
            trace_args: Traced arguments; each one's ``type`` tag selects the
                ctypes type of the corresponding parameter.
        """
        def ctype_for(tag):
            # Translate a tracer type tag into the matching ctypes type.
            if tag == 'i64':
                return ctypes.c_int64
            if tag == 'f64':
                return ctypes.c_double
            assert False

        program = list(graph.compile())

        # Epilogue: pop the result off the stack into the return register
        # (rax for i64, xmm0 for f64), then return.
        if graph.type == 'i64':
            program.append('pop rax')
        if graph.type == 'f64':
            program.extend(['pop rax', 'movq xmm0, rax'])
        program.append('ret')

        program = peephole_optimize(program)
        self.instructions = program

        listing = ';\n'.join(program)
        assembler = Ks(KS_ARCH_X86, KS_MODE_64)
        encoding, _ = assembler.asm(listing)
        machine_code = bytes(encoding)

        # Anonymous mapping that is readable, writable and executable.
        protection = mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC
        region = mmap.mmap(-1, len(machine_code), prot=protection)
        region.write(machine_code)

        result_ctype = ctype_for(graph.type)
        param_ctypes = [ctype_for(arg.type) for arg in trace_args]
        prototype = ctypes.CFUNCTYPE(result_ctype, *param_ctypes)
        entry_point = ctypes.addressof(ctypes.c_int.from_buffer(region))

        self.func = prototype(entry_point)
        self.mem = region  # keep a reference, else it gets released
        self.text = listing

    def __call__(self, *args):
        """Invoke the compiled function with ``args`` and return its result."""
        return self.func(*args)
class ExprBase:
    """Base class of traced expression nodes.

    The ``+`` and ``*`` operators (in both operand orders) build ``Add`` and
    ``Mul`` nodes; the other operand is first passed through ``parse``.

    Attributes:
        type: Type tag of the value the node produces ('i64' or 'f64').
    """

    type: str

    def __add__(self, other):
        return Add(self, self.parse(other))

    def __radd__(self, other):
        return Add(self.parse(other), self)

    def __mul__(self, other):
        return Mul(self, self.parse(other))

    def __rmul__(self, other):
        return Mul(self.parse(other), self)

    @staticmethod
    def type_coerce(a, b):
        """Return the common type tag of two operand type tags.

        Equal tags are returned unchanged; otherwise the result is 'f64' as
        soon as either tag is 'f64'.

        Raises:
            TypeError: If the tags differ and neither is 'f64'.
        """
        if a == b:
            return a
        if a == 'f64' or b == 'f64':
            return 'f64'
        # Explicit raise: an ``assert False`` here would vanish under -O.
        raise TypeError(f'cannot coerce {a!r} and {b!r} to a common type')

    @staticmethod
    def type_ensure(e, t):
        """Return ``e`` itself if it already has type ``t``, else a Convert node."""
        if e.type == t:
            return e
        return Convert(e, e.type, t)

    @staticmethod
    def parse(x):
        """Turn an operand into an expression node.

        Expression nodes are returned as-is; ``int`` values become 'i64'
        constants and ``float`` values become 'f64' constants.

        Raises:
            TypeError: If ``x`` is none of the above.
        """
        if isinstance(x, ExprBase):
            return x
        if isinstance(x, int):
            return Constant(x, 'i64')
        if isinstance(x, float):
            return Constant(x, 'f64')
        # Explicit raise: an ``assert False`` here would vanish under -O and
        # the method would hand None to the node constructors.
        raise TypeError(f'unsupported operand type: {type(x).__name__}')
class Constant(ExprBase):
    """Literal value embedded in the expression graph.

    Attributes:
        value: The Python value of the literal.
        type: Type tag, 'i64' or 'f64'.
    """

    def __init__(self, value, type):
        self.value = value
        self.type = type

    def __repr__(self):
        return f'{self.value}'

    def compile(self):
        """Yield instructions that push the constant onto the stack."""
        if self.type == 'i64':
            # int() normalizes bool (an int subclass) to 1/0; formatting the
            # raw value would emit the text 'True'/'False' as the operand.
            yield f'mov rax, {int(self.value)}'
            yield 'push rax'
        if self.type == 'f64':
            # Reinterpret the double's 8 bytes as an unsigned 64-bit integer
            # so it can be loaded as an integer immediate.
            bits = struct.unpack('Q', struct.pack('d', self.value))[0]
            yield f'mov rax, {bits}'
            yield 'push rax'
class Convert(ExprBase):
    """Node that converts the value of another node to a different type.

    Attributes:
        value: The operand node.
        tfrom: Type tag of the operand.
        type: Type tag of the result.
    """

    def __init__(self, x, tfrom, tto):
        self.value, self.tfrom, self.type = x, tfrom, tto

    def __repr__(self):
        return 'Convert({} -> {})'.format(self.value, self.type)

    def compile(self):
        """Yield the operand's code followed by the i64 -> f64 conversion."""
        # Only integer-to-double conversion is implemented.
        is_int_to_float = (self.tfrom, self.type) == ('i64', 'f64')
        assert is_int_to_float
        yield from self.value.compile()
        yield from (
            'pop rax',
            'cvtsi2sd xmm8, rax',
            'movq rax, xmm8',
            'push rax',
        )
class Add(ExprBase):
    """Sum of two expression nodes.

    The result type is the coerced type of both operands; an operand whose
    type differs from it is wrapped by ``type_ensure``.

    Attributes:
        type: Result type tag.
        lhs: Left operand, of type ``type``.
        rhs: Right operand, of type ``type``.
    """

    def __init__(self, left: ExprBase, right: ExprBase):
        self.type = self.type_coerce(left.type, right.type)
        self.lhs = self.type_ensure(left, self.type)
        self.rhs = self.type_ensure(right, self.type)

    def __repr__(self):
        return f'Add({self.lhs}, {self.rhs})'

    def compile(self):
        """Yield code that pushes ``lhs + rhs`` onto the stack.

        Both operands are evaluated onto the stack first (lhs, then rhs),
        then popped, added and the sum pushed back.
        """
        yield from self.lhs.compile()
        yield from self.rhs.compile()
        # r11 is the scratch register for the second pop. rdx must not be
        # used: it carries the third integer argument, which an operand
        # compiled later may still need to read.
        if self.type == 'i64':
            yield 'pop rax'
            yield 'pop r11'
            yield 'add rax, r11'
            yield 'push rax'
        if self.type == 'f64':
            yield 'pop rax'
            yield 'movq xmm8, rax'
            yield 'pop r11'
            yield 'movq xmm9, r11'
            yield 'addsd xmm8, xmm9'
            yield 'movq rax, xmm8'
            yield 'push rax'
class Mul(ExprBase):
    """Product of two expression nodes.

    The result type is the coerced type of both operands; an operand whose
    type differs from it is wrapped by ``type_ensure``.

    Attributes:
        type: Result type tag.
        lhs: Left operand, of type ``type``.
        rhs: Right operand, of type ``type``.
    """

    def __init__(self, left: ExprBase, right: ExprBase):
        self.type = self.type_coerce(left.type, right.type)
        self.lhs = self.type_ensure(left, self.type)
        self.rhs = self.type_ensure(right, self.type)

    def __repr__(self):
        return f'Mul({self.lhs}, {self.rhs})'

    def compile(self):
        """Yield code that pushes ``lhs * rhs`` onto the stack.

        Both operands are evaluated onto the stack first (lhs, then rhs),
        then popped, multiplied and the product pushed back.
        """
        yield from self.lhs.compile()
        yield from self.rhs.compile()
        # r11 is the scratch register for the second pop. rdx must not be
        # used: it carries the third integer argument, which an operand
        # compiled later may still need to read.
        if self.type == 'i64':
            yield 'pop rax'
            yield 'pop r11'
            # Two-operand imul leaves rdx untouched, whereas one-operand
            # 'mul' stores the upper half of the product in rdx. The low
            # 64 bits of the product are the same for both forms.
            yield 'imul rax, r11'
            yield 'push rax'
        if self.type == 'f64':
            yield 'pop rax'
            yield 'movq xmm8, rax'
            yield 'pop r11'
            yield 'movq xmm9, r11'
            yield 'mulsd xmm8, xmm9'
            yield 'movq rax, xmm8'
            yield 'push rax'
class Argument(ExprBase):
    """Placeholder node for one parameter of the traced function.

    Attributes:
        name: Parameter name.
        loc: Index into the register list used for this argument's type.
        type: Type tag, 'i64' or 'f64'.
        typeloc: Initialised to None; not used by this class.
    """

    def __init__(self, name: str, loc: int, type: str):
        self.name = name
        self.loc = loc
        self.type = type
        self.typeloc = None  # float/int specific location

    def __repr__(self):
        return '<{}@({}.{})>'.format(self.name, self.type, self.loc)

    def compile(self):
        """Yield code that pushes the argument's register onto the stack."""
        if self.type == 'i64':
            int_regs = ('rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9')
            yield 'mov rax, ' + int_regs[self.loc]
            yield 'push rax'
        if self.type == 'f64':
            float_regs = ('xmm0', 'xmm1', 'xmm2', 'xmm3',
                          'xmm4', 'xmm5', 'xmm6', 'xmm7')
            yield 'movq rax, ' + float_regs[self.loc]
            yield 'push rax'
def trace_function(f, example_args):
    """Trace ``f`` by calling it with symbolic arguments.

    Args:
        f: The Python function to trace.
        example_args: One example value per named parameter of ``f``; only
            the type of each value is used.

    Returns:
        A ``(graph, trace_args)`` tuple: whatever ``f`` returned when called
        with the symbolic arguments, and the list of those arguments.
    """
    arg_names = inspect.getfullargspec(f).args
    assert len(example_args) == len(arg_names)

    # Integer and float arguments are numbered independently:
    # used[0] counts 'i64' arguments, used[1] counts all others.
    used = [0, 0]
    trace_args = []
    for name, example in zip(arg_names, example_args):
        tag = map_type(example)
        bank = 0 if tag == 'i64' else 1
        trace_args.append(Argument(name, used[bank], tag))
        used[bank] += 1

    graph = f(*trace_args)
    return graph, trace_args
class jit:
    """Decorator that compiles a function once per argument-type signature.

    Attributes:
        f_python: The wrapped Python function.
        compile_cache: Maps a tuple of argument type tags to the compiled
            function for that signature.
    """

    def __init__(self, f):
        self.compile_cache = {}
        self.f_python = f

    def __call__(self, *args):
        """Run the compiled version matching the types of ``args``."""
        return self.get_compiled(args)(*args)

    def get_compiled(self, args):
        """Return the compiled function for ``args``, compiling on a miss."""
        signature = tuple(map_type(value) for value in args)  # type x shape
        try:
            return self.compile_cache[signature]
        except KeyError:
            pass
        print("WARNING, RECOMPILING")
        graph, trace_args = trace_function(self.f_python, args)
        compiled = CompiledFunction(graph, trace_args)
        self.compile_cache[signature] = compiled
        return compiled

    def get_asm(self, args):
        """Return the assembly text of the function compiled for ``args``."""
        return self.get_compiled(args).text
# --- Demo 1: ahead-of-time use: trace with example arguments, then compile.
def f(a, b, c):
    return a + b * c


# The example types are (int, float, int), so both int operands get wrapped
# in Convert nodes when they meet the float operand.
graph, args = trace_function(f, (1, 3., 1)) # Add(Convert(<a>), Mul(<b>, Convert(<c>)))
f_aot = CompiledFunction(graph, args)
print(f_aot(1, 3., 1))


# --- Demo 2: the jit decorator compiles lazily, once per tuple of argument
# type tags, and prints "WARNING, RECOMPILING" on every cache miss.
@jit
def fma(a, b, c):
    return a + b * c


print(fma(1, 1, 1))   # first (i64, i64, i64) call: compiles
print(fma(10, 1, 1))  # same signature: served from the cache
print(fma(1, 10, 1))  # same signature: served from the cache
print(fma(1, 1., 1))  # new signature (i64, f64, i64): compiles again
print()
# (i64, i64, f64) has not been seen yet, so this compiles once more and
# prints the assembly text of that variant.
print(fma.get_asm((1, 1, 1.)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment