Last active
December 18, 2023 09:11
-
-
Save usbuild/b2d0fccf11afcbe8f6878007626865c2 to your computer and use it in GitHub Desktop.
A naive calculator with a bytecode interpreter and an x86-64 JIT compiler, intended only for teaching and demonstration.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <ctype.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
/* Fixed capacity (in elements) used for fixed-size evaluation stacks. */
enum { STACK_MAX = 100 };
/* Size in bytes of the buffer mapped for JIT-generated machine code (2 MiB). */
enum { CODE_SIZE = 2 * 1024 * 1024 };
/* Byte type used for emitted machine code. */
typedef unsigned char uchar;
/*
 * Return the binding strength of a character on the shunting-yard operator
 * stack: '(' = 0 (lowest, so incoming operators never pop it), '+'/'-' = 1,
 * '*'/'/' = 2. Any other character yields -1, so the function always
 * returns a defined value.
 */
int precedent(char op) {
    switch (op) {
    case '(':
        return 0;
    case '-':
    case '+':
        return 1;
    case '*':
    case '/':
        return 2;
    default:
        return -1;
    }
}
/* Return 1 if op is one of the binary operators + - * /, otherwise 0. */
int isoperator(char op) {
    switch (op) {
    case '+':
    case '-':
    case '*':
    case '/':
        return 1;
    default:
        return 0;
    }
}
/*
 * Convert an infix expression to a space-separated postfix string using the
 * shunting-yard algorithm.
 *
 * Operands are maximal runs of alphanumeric characters. The operators are
 * + - * /, and parentheses group sub-expressions. Spaces and any other
 * characters are ignored. An unmatched ')' only flushes the operator stack,
 * and an unmatched '(' is dropped.
 *
 * Returns a heap-allocated, NUL-terminated string that the caller must
 * free(), or NULL if allocation fails.
 */
char *postfixify(const char *buf) {
    size_t len = strlen(buf);
    /* Every input character contributes at most two output characters
     * (itself plus a separating space), so 2 * len + 1 bytes always suffice. */
    char *output = calloc(2 * len + 1, 1);
    /* The operator stack holds '(' and operators: at most one per input char. */
    char *stack = malloc(len + 1);
    if (!output || !stack) {
        free(output);
        free(stack);
        return NULL;
    }

    char *sp = stack;  /* one past the top of the operator stack */
    char *op = output; /* next free byte of the output */
    const char *p = buf;
    while (*p) {
        if (isalnum((unsigned char)*p)) {
            /* Copy a whole operand, then terminate it with a space. */
            while (isalnum((unsigned char)*p)) {
                *op++ = *p++;
            }
            *op++ = ' ';
            continue;
        }

        char c = *p++;
        if (c == '(') {
            *sp++ = c;
        } else if (c == ')') {
            /* Pop operators back to the matching '(' and discard it; never
             * read below the bottom of the stack. */
            while (sp > stack && sp[-1] != '(') {
                *op++ = *--sp;
                *op++ = ' ';
            }
            if (sp > stack) {
                sp--;
            }
        } else if (isoperator(c)) {
            /* Pop operators that bind at least as tightly. Popping on equal
             * precedence makes the operators left-associative. */
            int prec = precedent(c);
            while (sp > stack && precedent(sp[-1]) >= prec) {
                *op++ = *--sp;
                *op++ = ' ';
            }
            *sp++ = c;
        }
        /* Spaces and unrecognized characters are skipped. */
    }

    /* Flush the remaining operators; an unmatched '(' is not emitted. */
    while (sp > stack) {
        char top = *--sp;
        if (top != '(') {
            *op++ = top;
            *op++ = ' ';
        }
    }
    free(stack);
    return output;
}
/*
 * Bytecode opcodes. PUSHV and PUSHC are each followed by one operand word:
 * a variable index for PUSHV, a constant value for PUSHC.
 * NOP must remain 0, because the execution loops stop at a zero word.
 */
enum OP {
    NOP = 0,
    RET = 1,
    PUSHV = 2,
    PUSHC = 3,
    ADD = 4,
    SUB = 5,
    MUL = 6,
    DIV = 7,
};
int *bccompile(const char *buf) { | |
int *bc = calloc(STACK_MAX, sizeof(int)); | |
int *bcp = bc; | |
char *postfix = postfixify(buf); | |
char *p = postfix; | |
while (*p) { | |
char c = *p++; | |
if (isalnum(c)) { | |
if (islower(c)) { | |
*(bcp++) = PUSHV; | |
*(bcp++) = c - 'a'; | |
} else { | |
int val = 0; | |
while (isdigit(c)) { | |
val = val * 10 + c - '0'; | |
c = *p++; | |
} | |
*(bcp++) = PUSHC; | |
*(bcp++) = val; | |
} | |
} | |
if (c == ' ') continue; | |
switch (c) { | |
case '+': | |
*(bcp++) = ADD; | |
break; | |
case '-': | |
*(bcp++) = SUB; | |
break; | |
case '*': | |
*(bcp++) = MUL; | |
break; | |
case '/': | |
*(bcp++) = DIV; | |
break; | |
} | |
} | |
*(bcp++) = RET; | |
free(postfix); | |
return bc; | |
} | |
int interpret(const int *bcp, const int *args) { | |
int stack[STACK_MAX] = {0}; | |
int *sp = stack; | |
bcp--; | |
while (*++bcp) { | |
switch (*bcp) { | |
case NOP: | |
break; | |
case RET: | |
return sp[-1]; | |
case PUSHV: | |
*(sp++) = args[*++bcp]; | |
break; | |
case PUSHC: | |
*(sp++) = *(++bcp); | |
break; | |
case ADD: | |
sp[-2] = sp[-2] + sp[-1]; | |
sp--; | |
break; | |
case SUB: | |
sp[-2] = sp[-2] - sp[-1]; | |
sp--; | |
break; | |
case MUL: | |
sp[-2] = sp[-2] * sp[-1]; | |
sp--; | |
break; | |
case DIV: | |
sp[-2] = sp[-2] / sp[-1]; | |
sp--; | |
break; | |
} | |
} | |
} | |
/*
 * X-macro listing the x86-64 general-purpose registers in encoding order:
 * RAX = 0 ... RDI = 7, R8 = 8 ... R15 = 15. Ids 8-15 need a REX extension bit.
 */
#define GPRDEF(X) \
    X(RAX) X(RCX) X(RDX) X(RBX) \
    X(RSP) X(RBP) X(RSI) X(RDI) \
    X(R8)  X(R9)  X(R10) X(R11) \
    X(R12) X(R13) X(R14) X(R15)
/* Expands one register name to its RID_<name> enumerator. */
#define RIDENUM(reg) RID_##reg,
/* RID_RAX .. RID_R15, followed by the count RID_MAX. */
enum { GPRDEF(RIDENUM) RID_MAX };
/* REX prefix with W = 1 (64-bit operand size); the R/X/B bits are ORed in as needed. */
#define REX_64 0x48
/* Opcode for MOV r/m64, r64: copies the ModRM reg operand into the rm operand. */
#define MOV_RR 0x89
/* Register-direct ModRM byte (mod = 11). Arguments are parenthesized so
 * expressions such as MODRM(a | b, c) expand correctly. */
#define MODRM(reg, rm) ((0x3 << 6) | ((reg) << 3) | (rm))
/*
 * Emit a register-direct instruction at `code` and return the address just
 * past it.
 *
 * Byte layout:
 *   - a REX.W prefix, with REX.R set when src >= 8 and REX.B set when
 *     target >= 8;
 *   - a 0x0F escape byte, only when ext is nonzero;
 *   - the opcode `op`;
 *   - a ModRM byte with mod = 11, reg = src and rm = target.
 */
uchar *bin_rr(uchar *code, uchar op, uchar src, uchar target, int ext) {
    uchar prefix = REX_64;
    uchar reg = src;
    uchar rm = target;

    if (reg & 0x8) {
        prefix |= 0x4; /* REX.R extends the ModRM reg field */
        reg = (uchar)(reg & ~0x8);
    }
    if (rm & 0x8) {
        prefix |= 0x1; /* REX.B extends the ModRM rm field */
        rm = (uchar)(rm & ~0x8);
    }

    uchar *out = code;
    *out++ = prefix;
    if (ext) {
        *out++ = 0x0F;
    }
    *out++ = op;
    *out++ = (uchar)(0xC0 | (reg << 3) | rm);
    return out;
}
/*
 * Emit an instruction whose memory operand is [target + offset], using an
 * 8-bit signed displacement, and return the address just past it.
 *
 * Byte layout:
 *   - a REX.W prefix (REX.R for src >= 8, REX.B for target >= 8);
 *   - a 0x0F escape byte, only when ext is nonzero;
 *   - the opcode `op`;
 *   - a ModRM byte with mod = 01, reg = src and rm = target;
 *   - an SIB byte when the base is RSP/R12 (rm = 100 requires one);
 *   - the displacement byte.
 * With op 0x89 this stores src to memory; with 0x8B it loads src from memory.
 */
uchar *bin_mr(uchar *code, uchar op, uchar src, uchar target, char offset,
              int ext) {
    uchar rex = REX_64;
    if (src & 0x8) {
        rex |= 1 << 2;
        src &= ~(uchar)(0x8);
    }
    if (target & 0x8) {
        rex |= 1;
        target &= ~(uchar)(0x8);
    }
    *(code++) = rex;
    if (ext) {
        *(code++) = 0x0F;
    }
    *(code++) = op;
    *(code++) = (uchar)((0x1 << 6) | (src << 3) | target);
    if ((target & 0x7) == 0x4) {
        /* SIB 0x24: scale 1, no index, base = RSP (or R12 with REX.B). */
        *(code++) = 0x24;
    }
    *(code++) = (uchar)offset;
    return code;
}
/*
 * Emit `movabs $val, %target` (REX.W + B8+r, 64-bit immediate) and return
 * the address just past the 10 emitted bytes.
 *
 * The immediate is written little-endian from val's 64-bit two's-complement
 * value, so the encoding does not depend on the host's byte order or on
 * sizeof(long). The original read 8 bytes from a `long` directly, which is
 * out of bounds where long is 4 bytes.
 */
uchar *movq_ir(uchar *code, long val, uchar target) {
    /* Conversion to unsigned is defined modulo 2^64, giving two's-complement bytes. */
    unsigned long long imm = (unsigned long long)(long long)val;
    uchar rex = REX_64;
    if (target & 0x8) {
        rex |= 1;
        target &= ~(uchar)(0x8);
    }
    *(code++) = rex;
    *(code++) = (uchar)(0xB8 + target);
    for (int i = 0; i < 8; i++) {
        *(code++) = (uchar)(imm >> (8 * i));
    }
    return code;
}
void jitcompile(const int *bcp, uchar *code) { | |
uchar *cp = code; | |
uchar *cu; | |
const char *reg_stack[] = {"rbx", "rcx", "rdi", "rsi", "r8", "r9", | |
"r10", "r11", "r12", "r13", "r14", "r15"}; | |
uchar reg_stacki[] = {RID_RBX, RID_RCX, RID_RDI, RID_RSI, RID_R8, | |
RID_R9, RID_R10, RID_R11, RID_R11, RID_R12, | |
RID_R13, RID_R14, RID_R15}; | |
bcp--; | |
int sp = 0; | |
FILE *f = fopen("/tmp/test_jit.s", "w"); | |
fprintf(f, ".globl myfunc\n"); | |
fprintf(f, ".type myfunc, @function\n"); | |
fprintf(f, "myfunc:\n"); | |
fprintf(f, "push %rbp\n"); | |
*(cp++) = 0x50 + RID_RBP; | |
fprintf(f, "movq %rsp, %rbp\n"); | |
cp = bin_rr(cp, MOV_RR, RID_RSP, RID_RBP, 0); | |
fprintf(f, "movq %rdi, -8(%rbp)\n"); | |
cp = bin_mr(cp, MOV_RR, RID_RDI, RID_RBP, -8, 0); | |
fprintf(f, "movq %rdi, -16(%rbp)\n"); | |
cp = bin_mr(cp, MOV_RR, RID_RSI, RID_RBP, -16, 0); | |
fprintf(f, "movq %rdi, -24(%rbp)\n"); | |
cp = bin_mr(cp, MOV_RR, RID_RDX, RID_RBP, -24, 0); | |
fprintf(f, "movq %rdi, -32(%rbp)\n"); | |
cp = bin_mr(cp, MOV_RR, RID_R10, RID_RBP, -32, 0); | |
fprintf(f, "movq %rdi, -40(%rbp)\n"); | |
cp = bin_mr(cp, MOV_RR, RID_R8, RID_RBP, -40, 0); | |
fprintf(f, "movq %rdi, -48(%rbp)\n"); | |
cp = bin_mr(cp, MOV_RR, RID_R9, RID_RBP, -48, 0); | |
while (*++bcp) { | |
switch (*bcp) { | |
case NOP: | |
break; | |
case RET: | |
fprintf(f, "movq %rbx, %rax\n"); | |
cp = bin_rr(cp, MOV_RR, RID_RBX, RID_RAX, 0); | |
fprintf(f, "popq %rbp\n"); | |
*(cp++) = 0x58 + RID_RBP; | |
fprintf(f, "retq\n"); | |
*(cp++) = 0xc3; | |
for (cu = code; cu < cp; ++cu) { | |
fprintf(f, "%02x ", *cu); | |
} | |
fclose(f); | |
return; | |
case PUSHV: | |
++bcp; | |
int offset = (*bcp + 1) * -8; | |
fprintf(f, "movq %d(%rbp), %%%s\n", offset, reg_stack[sp]); | |
cp = bin_mr(cp, 0x8B, reg_stacki[sp], RID_RBP, offset, 0); | |
sp++; | |
break; | |
case PUSHC: | |
++bcp; | |
fprintf(f, "movq $%d, %%%s\n", *bcp, reg_stack[sp]); | |
cp = movq_ir(cp, *bcp, reg_stacki[sp]); | |
sp++; | |
break; | |
case ADD: | |
fprintf(f, "addq %%%s, %%%s\n", reg_stack[sp - 1], | |
reg_stack[sp - 2]); | |
cp = | |
bin_rr(cp, 0x01, reg_stacki[sp - 1], reg_stacki[sp - 2], 0); | |
sp--; | |
break; | |
case SUB: | |
fprintf(f, "subq %%%s, %%%s\n", reg_stack[sp - 1], | |
reg_stack[sp - 2]); | |
cp = | |
bin_rr(cp, 0x29, reg_stacki[sp - 1], reg_stacki[sp - 2], 0); | |
sp--; | |
break; | |
case MUL: | |
fprintf(f, "imulq %%%s, %%%s\n", reg_stack[sp - 1], | |
reg_stack[sp - 2]); | |
cp = | |
bin_rr(cp, 0xAF, reg_stacki[sp - 2], reg_stacki[sp - 1], 1); | |
sp--; | |
break; | |
case DIV: | |
fprintf(f, "movq %rax, %%%s\n", reg_stack[sp]); | |
fprintf(f, "movq %rdx, %%%s\n", reg_stack[sp + 1]); | |
fprintf(f, "xor %rdx, %rdx\n"); | |
fprintf(f, "movq %%%s, %rax\n", reg_stack[sp - 2]); | |
fprintf(f, "movq %%%s, -56(%rbp)\n", reg_stack[sp - 1]); | |
fprintf(f, "cqto\n"); | |
fprintf(f, "idivq -56(%rbp)\n"); | |
fprintf(f, "movq %%rax, %%%s\n", reg_stack[sp - 2]); | |
fprintf(f, "movq %%%s, %rax\n", reg_stack[sp]); | |
fprintf(f, "movq %%%s, %rdx\n", reg_stack[sp + 1]); | |
cp = bin_rr(cp, MOV_RR, RID_RAX, reg_stacki[sp], 0); | |
cp = bin_rr(cp, MOV_RR, RID_RDX, reg_stacki[sp + 1], 0); | |
cp = bin_rr(cp, 0x31, RID_RDX, RID_RDX, 0); | |
cp = bin_rr(cp, MOV_RR, reg_stacki[sp - 2], RID_RAX, 0); | |
// addr | |
cp = bin_mr(cp, MOV_RR, reg_stacki[sp - 1], RID_RBP, -56, 0); | |
*(cp++) = REX_64; | |
*(cp++) = 0x99; // cqto | |
*(cp++) = REX_64; | |
*(cp++) = 0xf7; | |
*(cp++) = ((0b01 << 6) | (7 << 3) | (RID_RBP)); | |
*(cp++) = -56; | |
cp = bin_rr(cp, MOV_RR, RID_RAX, reg_stacki[sp - 2], 0); | |
cp = bin_rr(cp, MOV_RR, reg_stacki[sp], RID_RAX, 0); | |
cp = bin_rr(cp, MOV_RR, reg_stacki[sp + 1], RID_RDX, 0); | |
sp--; | |
break; | |
} | |
} | |
} | |
typedef long (*MyFunc)(); | |
int calc(const char *val) { | |
int *bc = bccompile(val); | |
void *ptr = mmap(0, CODE_SIZE, PROT_READ | PROT_WRITE | PROT_EXEC, | |
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); | |
jitcompile(bc, ptr); | |
MyFunc f = ptr; | |
int ret = f(); | |
free(bc); | |
munmap(ptr, CODE_SIZE); | |
return ret; | |
} | |
/*
 * Assert that the C value of expression x equals calc() applied to the
 * source text of x (#x). No trailing semicolon, so each use reads as one
 * ordinary statement.
 */
#define SIMPLE_TEST_ASSERT(x) assert((x) == calc(#x))

/* Built-in self-tests. They have no effect when NDEBUG is defined. */
void test_all(void) {
    SIMPLE_TEST_ASSERT(1+2+3);
    SIMPLE_TEST_ASSERT(2/2);
    SIMPLE_TEST_ASSERT((2 + 2)/2);
    SIMPLE_TEST_ASSERT((2 + 2)*2);
    SIMPLE_TEST_ASSERT(((2 + 2)*2)/4+1);
    SIMPLE_TEST_ASSERT(1+2+3*4-5+6);
    /* Operators of equal precedence must associate to the left. */
    SIMPLE_TEST_ASSERT(7-2-1);
    SIMPLE_TEST_ASSERT(8/2/2);
    SIMPLE_TEST_ASSERT(2*(3+4)-10/5);
}
/*
 * Run the built-in self-tests. Command-line arguments are not used, so the
 * standard no-argument form of main replaces the non-standard
 * `const char *argv[]` signature.
 */
int main(void) {
    test_all();
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment