Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created April 22, 2026 19:06
Show Gist options
  • Select an option

  • Save llandsmeer/94f9f542d0b876a71e0f7734b0741c6c to your computer and use it in GitHub Desktop.

Select an option

Save llandsmeer/94f9f542d0b876a71e0f7734b0741c6c to your computer and use it in GitHub Desktop.
MiniJAX
# pip install keystone-engine
import inspect
import struct
from keystone import Ks, KS_ARCH_X86, KS_MODE_64
import ctypes
import mmap
def map_type(x):
    """Return the MiniJAX scalar type tag for an example argument value.

    The tag decides which register class a traced argument lives in and is
    also what the jit cache is keyed on.

    Args:
        x: An example argument value.

    Returns:
        'i64' for ``int`` values, 'f64' for ``float`` values.

    Raises:
        TypeError: if ``x`` is neither an ``int`` nor a ``float``. This is an
            explicit raise rather than ``assert False`` so the check survives
            ``python -O``, where the assert would vanish and ``None`` would
            silently be returned.
    """
    # bool is a subclass of int, so True/False are treated as 'i64'.
    if isinstance(x, int):
        return 'i64'
    if isinstance(x, float):
        return 'f64'
    raise TypeError(f'unsupported argument type: {type(x).__name__}')
def peephole_optimize(instructions):
    """Apply pattern-based rewrites to a list of assembly instruction strings.

    Each rule is ``(min_pass, pattern, replacement)``. A rule is only active
    from pass number ``min_pass`` onwards, which lets cheap clean-ups (such
    as cancelling ``push rax`` / ``pop rax``) run before rules that depend on
    them having exposed adjacent instruction pairs.

    Several rules drop a write to ``rax`` or to a scratch register. They rely
    on the emitted stack-machine code always reloading that register before
    reading it again. NOTE(review): this holds for the nodes in this file;
    re-check it when adding new node kinds.

    Two earlier rules that fused
    ``addsd/mulsd xmm8, xmm9; movq xmm0, xmm8`` into one operation on
    ``xmm0`` were removed. They used the stale contents of ``xmm0`` (the
    first float argument) instead of ``xmm8``, so e.g. ``a + b*c`` over three
    floats evaluated to ``a + a``.

    Args:
        instructions: Iterable of instruction strings. It is not modified.

    Returns:
        A new list of instruction strings.
    """
    instructions = list(instructions)
    replace_list = [
        (1, ['movq rax, xmm8', 'movq xmm0, rax'], ['movq xmm0, xmm8']),
        (1, ['movq rax, xmm0', 'movq xmm8, rax'], ['movq xmm8, xmm0']),
        (0, ['pop rax', 'push rax'], []),
        (1, ['push rax', 'pop rax'], []),
        (1, ['movq rax, xmm8', 'movq xmm8, rax'], []),
        (1, ['movq rax, xmm9', 'movq xmm9, rax'], []),
        # Convert an integer argument straight from its SysV argument register.
        (2, ['mov rax, rdi', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rdi']),
        (2, ['mov rax, rsi', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rsi']),
        (2, ['mov rax, rdx', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rdx']),
        (2, ['mov rax, rcx', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, rcx']),
        (2, ['mov rax, r8', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, r8']),
        (2, ['mov rax, r9', 'cvtsi2sd xmm8, rax'], ['cvtsi2sd xmm8, r9']),
        # Scratch-register rules, for both the rdx and the r11 code-generation styles.
        (2, ['push rax', 'pop rdx'], ['mov rdx, rax']),
        (2, ['push rax', 'pop r11'], ['mov r11, rax']),
        (4, ['pop rdx', 'movq xmm9, rdx'], ['movq xmm9, [rsp]', 'add rsp, 8']),
        (4, ['pop r11', 'movq xmm9, r11'], ['movq xmm9, [rsp]', 'add rsp, 8']),
        (3, ['movq rax, xmm8', 'mov rdx, rax'], ['movq rdx, xmm8']),
        (3, ['movq rax, xmm8', 'mov r11, rax'], ['movq r11, xmm8']),
    ]
    n_passes = max(order for order, _, _ in replace_list) + 1
    for npass in range(n_passes):
        # Walk backwards: a rewrite at index i only touches indices >= i,
        # so every index still to be visited stays valid.
        for i in range(len(instructions) - 1, -1, -1):
            for order, pattern, replacement in replace_list:
                if order > npass:
                    continue
                if instructions[i:i + len(pattern)] == pattern:
                    instructions[i:i + len(pattern)] = replacement
    return instructions
class CompiledFunction:
    """Native x86-64 code for one traced expression graph, callable from Python.

    The graph's stack-machine assembly gets a return epilogue, is
    peephole-optimised and assembled with Keystone. The machine code is then
    copied into an anonymous read/write/execute mmap and wrapped in a ctypes
    function pointer whose signature follows the traced argument types.

    Attributes:
        instructions: optimised instruction list.
        text: the same instructions joined into the assembler source string.
        func: the ctypes function pointer.
        mem: the mmap holding the code (kept referenced so it is not released).
    """

    def __init__(self, graph, trace_args):
        ctype_by_tag = {'i64': ctypes.c_int64, 'f64': ctypes.c_double}

        def to_ctype(tag):
            # Translate a MiniJAX type tag into the ctypes scalar for the signature.
            ctype = ctype_by_tag.get(tag)
            assert ctype is not None
            return ctype

        listing = list(graph.compile())
        # The graph leaves its result on top of the machine stack; move it into
        # the return register (rax for integers, xmm0 for doubles).
        if graph.type == 'i64':
            listing.append('pop rax')
        elif graph.type == 'f64':
            listing.extend(['pop rax', 'movq xmm0, rax'])
        listing.append('ret')

        self.instructions = peephole_optimize(listing)
        source = ';\n'.join(self.instructions)

        encoding, _stmt_count = Ks(KS_ARCH_X86, KS_MODE_64).asm(source)
        machine_code = bytes(encoding)

        code_buffer = mmap.mmap(
            -1,
            len(machine_code),
            prot=mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC,
        )
        code_buffer.write(machine_code)

        signature = ctypes.CFUNCTYPE(
            to_ctype(graph.type),
            *[to_ctype(arg.type) for arg in trace_args],
        )
        entry_point = ctypes.addressof(ctypes.c_int.from_buffer(code_buffer))
        self.func = signature(entry_point)
        self.mem = code_buffer  # must stay referenced, else it gets released
        self.text = source

    def __call__(self, *args):
        """Invoke the native code; ctypes converts *args* per the traced signature."""
        return self.func(*args)
class ExprBase:
    """Common base of all traced expression-graph nodes.

    Each concrete node sets ``type`` to a type tag ('i64' or 'f64') and
    provides ``compile()``, a generator of assembly lines that leaves the
    node's value on top of the machine stack. The arithmetic dunders let
    ordinary Python code (``a + b * c``) build the graph while tracing.
    """

    type: str

    def __add__(self, other):
        return Add(self, self.parse(other))

    def __radd__(self, other):
        # Keep operand order: ``other + self``.
        return Add(self.parse(other), self)

    def __mul__(self, other):
        return Mul(self, self.parse(other))

    def __rmul__(self, other):
        # Keep operand order: ``other * self``.
        return Mul(self.parse(other), self)

    @staticmethod
    def type_coerce(a, b):
        """Return the common type tag of *a* and *b*; mixing with 'f64' widens to 'f64'."""
        if a == b:
            return a
        if 'f64' in (a, b):
            return 'f64'
        assert False

    @staticmethod
    def type_ensure(e, t):
        """Return node *e* unchanged if it already has type *t*, else wrap it in a Convert."""
        return e if e.type == t else Convert(e, e.type, t)

    @staticmethod
    def parse(x):
        """Lift *x* into the graph: nodes pass through, Python ints/floats become Constants."""
        if isinstance(x, ExprBase):
            return x
        # int is tried before float, mirroring map_type.
        for py_type, tag in ((int, 'i64'), (float, 'f64')):
            if isinstance(x, py_type):
                return Constant(x, tag)
        assert False
class Constant(ExprBase):
    """Literal scalar node, materialised as a 64-bit immediate loaded into rax."""

    def __init__(self, value, type):
        self.value = value
        self.type = type

    def __repr__(self):
        return f'{self.value}'

    def compile(self):
        """Yield instructions that push the constant's 64-bit pattern onto the stack."""
        if self.type == 'i64':
            immediate = self.value
        elif self.type == 'f64':
            # Reinterpret the double's bytes as an unsigned 64-bit integer so it
            # can be loaded with a plain integer ``mov``.
            (immediate,) = struct.unpack('Q', struct.pack('d', self.value))
        else:
            return
        yield f'mov rax, {immediate}'
        yield 'push rax'
class Convert(ExprBase):
    """Type-conversion node; only the i64 -> f64 direction can be compiled."""

    def __init__(self, x, tfrom, tto):
        self.value = x
        self.tfrom = tfrom
        self.type = tto

    def __repr__(self):
        return f'Convert({self.value} -> {self.type})'

    def compile(self):
        """Yield code that evaluates the operand and replaces it on the stack with its double."""
        assert self.tfrom == 'i64' and self.type == 'f64'
        yield from self.value.compile()
        # Pop the integer, convert it in xmm8, then push the double's raw bits.
        yield from ('pop rax', 'cvtsi2sd xmm8, rax', 'movq rax, xmm8', 'push rax')
class Add(ExprBase):
    """Addition node; both operands are coerced to their common type first."""

    def __init__(self, left: ExprBase, right: ExprBase):
        self.type = self.type_coerce(left.type, right.type)
        self.lhs = self.type_ensure(left, self.type)
        self.rhs = self.type_ensure(right, self.type)

    def __repr__(self):
        return f'Add({self.lhs}, {self.rhs})'

    def compile(self):
        """Yield stack-machine code that replaces the two operands with lhs + rhs.

        r11 is the second scratch register, not rdx. rdx carries the third
        integer argument, and Argument nodes compiled after this one still
        read it from rdx. Popping into rdx made e.g. ``a + b + c`` on three
        ints evaluate to ``(a + b) + a``. r11 is a caller-saved register that
        is never used for argument passing.
        """
        yield from self.lhs.compile()
        yield from self.rhs.compile()
        if self.type == 'i64':
            yield 'pop rax'      # rhs (pushed last)
            yield 'pop r11'      # lhs
            yield 'add rax, r11'
            yield 'push rax'
        if self.type == 'f64':
            yield 'pop rax'
            yield 'movq xmm8, rax'
            yield 'pop r11'
            yield 'movq xmm9, r11'
            yield 'addsd xmm8, xmm9'
            yield 'movq rax, xmm8'
            yield 'push rax'
class Mul(ExprBase):
    """Multiplication node; both operands are coerced to their common type first."""

    def __init__(self, left: ExprBase, right: ExprBase):
        self.type = self.type_coerce(left.type, right.type)
        self.lhs = self.type_ensure(left, self.type)
        self.rhs = self.type_ensure(right, self.type)

    def __repr__(self):
        return f'Mul({self.lhs}, {self.rhs})'

    def compile(self):
        """Yield stack-machine code that replaces the two operands with lhs * rhs.

        rdx must not be written here: it carries the third integer argument,
        which Argument nodes compiled later still read. The previous
        ``pop rdx`` / one-operand ``mul rdx`` clobbered it, so e.g. ``a * b + c``
        on three ints read a wrong ``c``. r11 is a caller-saved scratch
        register that is not used for argument passing.
        """
        yield from self.lhs.compile()
        yield from self.rhs.compile()
        if self.type == 'i64':
            yield 'pop rax'
            yield 'pop r11'
            # Two-operand IMUL gives the same low 64 bits as MUL, but it does
            # not write the high half of the product into rdx.
            yield 'imul rax, r11'
            yield 'push rax'
        if self.type == 'f64':
            yield 'pop rax'
            yield 'movq xmm8, rax'
            yield 'pop r11'
            yield 'movq xmm9, r11'
            yield 'mulsd xmm8, xmm9'
            yield 'movq rax, xmm8'
            yield 'push rax'
class Argument(ExprBase):
    """Leaf node for one parameter of the traced function.

    ``loc`` is the argument's index among parameters of the same type. It
    selects the integer register (rdi, rsi, rdx, rcx, r8, r9) or the SSE
    register (xmm0..xmm7) that holds the value on entry.
    """

    def __init__(self, name: str, loc: int, type: str):
        self.type = type
        self.name = name
        self.loc = loc
        self.typeloc = None  # float/int specific location; not read in the code shown here

    def __repr__(self):
        return f'<{self.name}@({self.type}.{self.loc})>'

    def compile(self):
        """Yield code that copies the incoming argument register to rax and pushes it."""
        if self.type == 'i64':
            gp_regs = ('rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9')
            yield f'mov rax, {gp_regs[self.loc]}'
            yield 'push rax'
        if self.type == 'f64':
            sse_regs = tuple(f'xmm{n}' for n in range(8))
            yield f'movq rax, {sse_regs[self.loc]}'
            yield 'push rax'
def trace_function(f, example_args):
    """Trace *f* symbolically into an expression graph.

    Each positional parameter of *f* is replaced by an Argument node. The
    node's register slot is assigned by the type of the matching example
    value, with integer and float slots counted independently. *f* is then
    called once with those nodes.

    Args:
        f: Python function using only operations the graph nodes support
            (``+`` and ``*``).
        example_args: One example value per positional parameter of *f*.
            Only their types are used.

    Returns:
        ``(graph, trace_args)``: the root node of the traced expression and
        the list of Argument nodes in parameter order.

    Raises:
        TypeError: if the number of example values does not match *f*'s
            positional parameters, or an example value has an unsupported type.
    """
    def build_trace_args(arg_names, example_args):
        # Integer and float parameters get separate running indices, matching
        # the separate integer/SSE register lists used by Argument.compile.
        int_loc, float_loc = 0, 0
        trace_args = []
        for name, arg in zip(arg_names, example_args):
            t = map_type(arg)
            loc = int_loc if t == 'i64' else float_loc
            trace_args.append(Argument(name, loc, t))
            if t == 'i64':
                int_loc += 1
            else:
                float_loc += 1
        return trace_args

    arg_names = inspect.getfullargspec(f).args
    # Explicit check instead of an assert so it still runs under ``python -O``.
    if len(example_args) != len(arg_names):
        raise TypeError(
            f'expected {len(arg_names)} example arguments, got {len(example_args)}'
        )
    trace_args = build_trace_args(arg_names, example_args)
    # parse() lifts a plain Python number (e.g. ``return 3``) into a Constant
    # node, so the graph always has a .compile() method.
    graph = ExprBase.parse(f(*trace_args))
    return graph, trace_args
class jit:
    """Decorator that compiles the wrapped function once per argument-type signature.

    Calls are dispatched on the tuple of type tags of the actual arguments.
    The first call with a new signature traces and compiles the function (and
    prints a warning); later calls reuse the cached native code.
    """

    def __init__(self, f):
        self.f_python = f
        self.compile_cache = {}

    def __call__(self, *args):
        return self.get_compiled(args)(*args)

    def get_compiled(self, args):
        """Return the CompiledFunction for *args*' type signature, compiling on a cache miss."""
        signature = tuple(map(map_type, args))  # type only for now; shape is not part of the key
        if signature in self.compile_cache:
            return self.compile_cache[signature]
        print("WARNING, RECOMPILING")
        graph, trace_args = trace_function(self.f_python, args)
        compiled = CompiledFunction(graph, trace_args)
        self.compile_cache[signature] = compiled
        return compiled

    def get_asm(self, args):
        """Return the assembler source generated for *args*' type signature."""
        return self.get_compiled(args).text
def f(a, b, c):
    return a + b * c


# Ahead-of-time path: trace with example values, then compile the graph directly.
# With (int, float, int) examples the ints are widened, giving roughly
# Add(Convert(a), Mul(b, Convert(c))).
graph, args = trace_function(f, (1, 3., 1))
f_aot = CompiledFunction(graph, args)
print(f_aot(1, 3., 1))


# JIT path: compilation happens lazily, once per distinct argument-type signature.
@jit
def fma(a, b, c):
    return a + b * c


for example in ((1, 1, 1), (10, 1, 1), (1, 10, 1), (1, 1., 1)):
    print(fma(*example))
print()
print(fma.get_asm((1, 1, 1.)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment