Skip to content

Instantly share code, notes, and snippets.

@geofft
Last active January 26, 2025 21:38
Show Gist options
  • Save geofft/947db6bbf98cba27ba383c95c7de54ac to your computer and use it in GitHub Desktop.
Save geofft/947db6bbf98cba27ba383c95c7de54ac to your computer and use it in GitHub Desktop.
I can't believe it's not numba
  • main.cc: A very abbreviated version of chapter 3 of the LLVM tutorial, hard-coding a single function
  • main.py: A translation into Python using llvmlite, also hard-coding a main function
  • pycomp.py: A compiler for an extremely small subset of Python.
$ ../venv/bin/python3 pycomp.py < sourcecode.py
$ cc -o output output.o
$ ./output
24.000000
#include <iostream>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
using namespace llvm;
// File-local LLVM state shared by the AST code generators and main().
// All three start out null and are assigned at the top of main().

/// Context passed to every LLVM type/constant factory call in this file.
static std::unique_ptr<LLVMContext> TheContext;
/// The single module that main() creates its function in (built on *TheContext).
static std::unique_ptr<Module> TheModule;
/// Instruction builder (built on *TheContext); main() positions it inside a
/// basic block before emitting instructions.
static std::unique_ptr<IRBuilder<>> Builder;
/// ExprAST - Abstract root of the expression-node hierarchy. Concrete node
/// types derive from it and implement codegen().
class ExprAST {
public:
  /// Produce the LLVM value that represents this expression.
  virtual Value *codegen() = 0;

  /// Virtual so that a derived node can be destroyed through an ExprAST
  /// pointer or reference.
  virtual ~ExprAST() = default;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
  /// The literal's numeric value.
  double Val;

public:
  /// Wraps the literal value \p Val.
  ///
  /// Marked explicit so that a bare number is never silently converted into
  /// an AST node; direct initialisation (`NumberExprAST two(2);`) still works.
  explicit NumberExprAST(double Val) : Val(Val) {}

  /// Emits the LLVM value for this literal (defined out of line).
  Value *codegen() override;
};
/// Materialise the stored literal as a floating-point constant in the
/// file-wide context. Nothing is inserted through the Builder here; the
/// result is a constant, not an instruction.
Value *NumberExprAST::codegen() {
  const APFloat Literal(Val);
  return ConstantFP::get(*TheContext, Literal);
}
int main() {
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("my cool jit", *TheContext);
Builder = std::make_unique<IRBuilder<>>(*TheContext);
Type *doublety = Type::getDoubleTy(*TheContext);
FunctionType *FT = FunctionType::get(doublety, std::vector{doublety}, false);
Function *TheFunction = Function::Create(FT, Function::ExternalLinkage, "main", *TheModule);
BasicBlock *bb = BasicBlock::Create(*TheContext, "my basic block", TheFunction);
Builder->SetInsertPoint(bb);
NumberExprAST two(2);
Value *two_val = two.codegen();
Value *arg = TheFunction->args().begin();
arg->setName("x");
Value *b = Builder->CreateFAdd(two_val, arg);
Value *c = Builder->CreateFAdd(arg, two_val);
Value *d = Builder->CreateFAdd(b, c);
Builder->CreateRet(d);
// copied from https://llvm.org/docs/NewPassManager.html
// Create the analysis managers.
// These must be declared in this order so that they are destroyed in the
// correct order due to inter-analysis-manager references.
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
// Create the new pass manager builder.
// Take a look at the PassBuilder constructor parameters for more
// customization, e.g. specifying a TargetMachine or various debugging
// options.
PassBuilder PB;
// Register all the basic analyses with the managers.
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
// Create the pass manager.
// This one corresponds to a typical -O2 optimization pipeline.
ModulePassManager MPM = PB.buildPerModuleDefaultPipeline(OptimizationLevel::O2);
// Optimize the IR!
MPM.run(*TheModule, MAM);
// end copy
if (verifyFunction(*TheFunction)) {
std::cout << "LLVM was unhappy\n";
} else {
TheFunction->print(outs());
}
}
import llvmlite
llvmlite.opaque_pointers_enabled = True
from llvmlite import ir
# Build the LLVM IR module by hand: first the function under test, then a
# C-style `main` that calls it.
module = ir.Module("my cool not a jit")
# A single builder is shared; it is re-positioned for each function body.
builder = ir.IRBuilder()

# --- double two_x_plus_four(double) -------------------------------------
doublety = ir.DoubleType()
ft = ir.FunctionType(doublety, [doublety], False)
func = ir.Function(module, ft, name="two_x_plus_four")
bb = func.append_basic_block(name="my basic block")
builder.position_at_start(bb)
two = ir.Constant(doublety, 2)
arg = func.args[0]
# b = 2 + x, c = x + 2, d = b + c; return d
b = builder.fadd(two, arg)
c = builder.fadd(arg, two)
d = builder.fadd(b, c)
builder.ret(d)

# --- int main(void) -----------------------------------------------------
intty = ir.IntType(32)
maintype = ir.FunctionType(intty, [], False)
main = ir.Function(module, maintype, name="main")
bb2 = main.append_basic_block(name="bb2")
# From here on the builder emits into main's block.
builder.position_at_start(bb2)
# double result = func(10);
result = builder.call(func, [ir.Constant(doublety, 10)])

# Types used for the printf declaration and its format string.
charty = ir.IntType(8)
#char_star = ir.PointerType(charty)
# Opaque pointers are enabled at import time, so the pointer type carries no
# pointee type.
pointerty = ir.PointerType()
def c_string(x: bytes):
    """Create a module-level, NUL-terminated byte array holding ``x``.

    Args:
        x: The string contents, without the trailing NUL (one is appended).

    Returns:
        The ``ir.GlobalVariable`` holding the bytes. With opaque pointers
        enabled it is passed directly where a pointer argument is expected,
        so no bitcast is performed.
    """
    arrayty = ir.ArrayType(charty, len(x) + 1)
    array = ir.Constant(arrayty, bytearray(x + b"\0"))
    # Every global needs a distinct name; a fixed "lol" makes the second call
    # fail with a duplicate-name error. get_unique_name() keeps "lol" for the
    # first string and adds a suffix for later ones.
    gv = ir.GlobalVariable(module, arrayty, module.get_unique_name("lol"))
    gv.initializer = array
    #return gv.bitcast(char_star)
    return gv
# Declare the external variadic `int printf(ptr, ...)`; no body is added here,
# so the definition has to come from outside this module.
printf = ir.Function(module, ir.FunctionType(intty, [pointerty], True), name="printf")
# printf("%f\n", result);
builder.call(printf, [c_string(b"%f\n"), result])
# return 0;
builder.ret(ir.Constant(intty, 0))
# Dump the textual LLVM IR for the whole module to stdout.
print(module)
import ast
import pathlib
import sys
import llvmlite
llvmlite.opaque_pointers_enabled = True
from llvmlite import ir
import llvmlite.binding
# Front end: the program to compile is read from stdin and parsed with
# Python's own parser.
tree = ast.parse(sys.stdin.read())
# The single LLVM IR module that every compiled function is added to.
module = ir.Module("my cool not a jit")
# One shared builder, re-positioned at the start of each function body.
builder = ir.IRBuilder()
# Maps a source-level function name to its ir.Function so that calls can be
# resolved by name.
functions = {}
def generate_code_for(value, argmap):
match value:
case ast.BinOp(left, op, right):
lv = generate_code_for(left, argmap)
rv = generate_code_for(right, argmap)
match op:
case ast.Add():
return builder.fadd(lv, rv)
case _:
raise RuntimeError(f"I don't know how to {op}")
case ast.Name(id):
if id in argmap:
return argmap[id]
if id in functions:
return functions[id]
raise NameError(f"No variable or function {id}")
case ast.Constant(value):
match value:
case int():
return ir.Constant(ir.IntType(32), value)
case float():
return ir.Constant(ir.DoubleType(), value)
case bytes():
arrayty = ir.ArrayType(ir.IntType(8), len(value) + 1)
array = ir.Constant(arrayty, bytearray(value + b"\0"))
gv = ir.GlobalVariable(module, arrayty, "lol")
gv.initializer = array
return gv
case _:
raise RuntimeError(f"Can't handle type of constant {value!r}")
case ast.Call(func, args):
return builder.call(
generate_code_for(func, argmap),
[generate_code_for(arg, argmap) for arg in args],
)
case _:
raise RuntimeError(f"I can't compile a {ast.dump(value)}")
def type_for(annotation):
    """Translate a source-level type annotation into an llvmlite IR type.

    Args:
        annotation: The ``ast`` node of a parameter or return annotation, or
            ``None`` when the annotation is missing.

    Returns:
        ``ir.IntType(32)`` for ``int``, ``ir.DoubleType()`` for ``float`` and
        ``ir.PointerType()`` for ``bytearray``.

    Raises:
        RuntimeError: The annotation is missing, is not a plain name, or
            names a type other than the three above.
    """
    if annotation is None:
        raise RuntimeError("write some type annotations!")
    # Only bare names are understood; subscripts, attributes, string
    # annotations and so on are rejected.
    if not isinstance(annotation, ast.Name):
        raise RuntimeError(
            f"Your type annotation {ast.dump(annotation)} is too complicated"
        )

    type_name = annotation.id
    if type_name == "int":
        return ir.IntType(32)
    if type_name == "float":
        return ir.DoubleType()
    if type_name == "bytearray":
        return ir.PointerType()
    raise RuntimeError(f"Don't know how to handle type {type_name}")
for thing in tree.body:
match thing:
case ast.FunctionDef(name, args, body, returns):
ft = ir.FunctionType(type_for(thing.returns), [type_for(arg.annotation) for arg in args.args], args.vararg is not None)
func = ir.Function(module, ft, name=name)
functions[name] = func
if len(body) == 1 and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant) and body[0].value.value == ...:
# This is an FFI function
pass
else:
# args = [arg(arg='x')]
# func.args = [ir.Value(...)]
#argmap = {"x": ir.Value(...)}
argmap = dict(zip((arg.arg for arg in args.args), func.args))
bb = func.append_basic_block(name="my basic block")
builder.position_at_start(bb)
for statement in body:
match statement:
case ast.Return(value):
llvm_value = generate_code_for(value, argmap)
builder.ret(llvm_value)
case ast.Expr(value):
generate_code_for(value, argmap)
# discard the result
case _:
raise RuntimeError(f"I cannot execute {ast.dump(statement)}")
# Back end: turn the textual IR into a native object file.
llvmlite.binding.initialize()
llvmlite.binding.initialize_native_target()
llvmlite.binding.initialize_native_asmprinter()

# Round-trip through text: the llvmlite.ir module is rendered with str() and
# re-parsed by LLVM itself.
moduleref = llvmlite.binding.parse_assembly(str(module))
# Fail fast on malformed IR rather than optimising and emitting code from a
# module LLVM considers invalid; verify() raises RuntimeError on error.
moduleref.verify()

# Target the machine this script is running on.
target = llvmlite.binding.Target.from_default_triple()
target_machine = target.create_target_machine()

# Run the module pass pipeline with default tuning options.
tuning_options = llvmlite.binding.PipelineTuningOptions()
pb = llvmlite.binding.PassBuilder(target_machine, tuning_options)
pb.getModulePassManager().run(moduleref, pb)

# Emit the object code and write it to ./output.o for the system linker.
obj = target_machine.emit_object(moduleref)
pathlib.Path("output.o").write_bytes(obj)
# Returns (x + 2.0) + (2.0 + x), which is mathematically 2*x + 4.
# Documented with a comment rather than a docstring: this file is input for
# pycomp.py, whose code generator rejects string constants.
def two_x_plus_four(x: float) -> float:
    return (x + 2.0) + (2.0 + x)
# Declaration of an external variadic function. pycomp.py treats a body that
# is exactly one `...` statement as an FFI declaration and emits no code for
# it, so nothing else (not even a docstring) may be added to the body.
def printf(formatstring: bytearray, *args) -> int:
    ...
# Program entry point: prints two_x_plus_four(10.0) using the "%f\n" format
# (the usage transcript shows 24.000000) and returns 0.
def main() -> int:
    printf(b"%f\n", two_x_plus_four(10.0))
    return 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment