Skip to content

Instantly share code, notes, and snippets.

@hirrolot
Created April 13, 2025 13:40
Show Gist options
  • Save hirrolot/c89baa9b83c7da9b87146be88e560351 to your computer and use it in GitHub Desktop.
Calculus of Constructions in 60 lines of OCaml
(** [discard x kept] returns [kept] and ignores [x].

    Both arguments are evaluated before the call, but OCaml does not specify
    the order in which function arguments are evaluated. *)
let discard = fun _ignored kept -> kept
(** [pp lvl t] renders the term [t] as a string.

    Variables are de Bruijn levels. A binder opened at level [lvl] is printed
    with that number, and its occurrences in the body print as the same
    number, e.g. [(λ0.(λ1.(0 1)))]. [lvl] is the first level not yet used by
    an enclosing binder.

    Non-ASCII symbols are written as escapes so the source stays
    encoding-agnostic: [\u{3bb}] is λ, [\u{3a0}] is Π, [\u{2610}] is ☐. *)
let rec pp lvl = function
  | `Lam f ->
      "(\u{3bb}" ^ string_of_int lvl ^ "." ^ pp (lvl + 1) (f (`Go lvl)) ^ ")"
  | `Pi (a, f) ->
      "(\u{3a0}" ^ string_of_int lvl ^ ":" ^ pp lvl a ^ "."
      ^ pp (lvl + 1) (f (`Go lvl))
      ^ ")"
  | `Appl (m, n) -> "(" ^ pp lvl m ^ " " ^ pp lvl n ^ ")"
  | `Ann (m, a) -> "(" ^ pp lvl m ^ " : " ^ pp lvl a ^ ")"
  | `Go x -> string_of_int x
  | `Star -> "*"
  | `Box -> "\u{2610}"
(** [eval term] normalizes [term].

    - Beta-redexes are reduced whenever the head of an application
      evaluates to a [`Lam].
    - Type annotations are erased.
    - Applications whose head is not a [`Lam] are rebuilt from their
      evaluated parts.
    - Binder bodies are wrapped so that they are evaluated each time the
      binder is opened.

    Evaluation may not terminate on terms that are not well-typed, such as a
    self-application. *)
let rec eval term =
  match term with
  | `Lam body -> `Lam (fun v -> eval (body v))
  | `Pi (dom, cod) -> `Pi (eval dom, fun v -> eval (cod v))
  | `Appl (fn, arg) -> (
      (* Reduce only when the evaluated head is a lambda; otherwise keep
         the application as a neutral term. *)
      match (eval fn, eval arg) with
      | `Lam body, v -> body v
      | head, v -> `Appl (head, v))
  | `Ann (inner, _ty) -> eval inner
  | (`Go _ | `Star | `Box) as atom -> atom
(** [equate lvl (s, t)] decides whether [s] and [t] are equal up to the
    naming of bound variables.

    Binders are compared by applying both bodies to the same fresh variable
    [`Go lvl]. No eta-rule is used, so a [`Lam] never equals a non-[`Lam].
    It is intended for terms already produced by [eval]. *)
let rec equate lvl pair =
  match pair with
  | `Lam body1, `Lam body2 ->
      equate (lvl + 1) (body1 (`Go lvl), body2 (`Go lvl))
  | `Pi (dom1, cod1), `Pi (dom2, cod2) ->
      if equate lvl (dom1, dom2) then
        equate (lvl + 1) (cod1 (`Go lvl), cod2 (`Go lvl))
      else false
  | `Appl (fn1, arg1), `Appl (fn2, arg2) ->
      if equate lvl (fn1, fn2) then equate lvl (arg1, arg2) else false
  | `Ann (t1, ty1), `Ann (t2, ty2) ->
      if equate lvl (t1, t2) then equate lvl (ty1, ty2) else false
  | `Go i, `Go j -> i = j
  | `Star, `Star -> true
  | `Box, `Box -> true
  | _ -> false
(** [panic lvl t fmt ...] formats an error message with [fmt], appends
    [": "] and the rendering of the offending term [t] (via [pp lvl t]), and
    raises [Failure] with the result. It never returns normally.

    [t] is rendered only after the message has been fully formatted. *)
let panic lvl t fmt =
  let fail_with_term msg = failwith (msg ^ ": " ^ pp lvl t) in
  Printf.ksprintf fail_with_term fmt
(** Bidirectional type checker for the Calculus of Constructions.

    Variables are de Bruijn levels. [lvl] is the number of variables in scope
    and [ctx] holds their evaluated types, innermost first, so [`Go x] has
    type [List.nth ctx (lvl - x - 1)].

    - [infer lvl ctx t] synthesizes the type of [t].
    - [infer_sort lvl ctx a] requires the type of [a] to be [`Star] or [`Box]
      and returns that sort.
    - [check lvl ctx (t, ty)] checks [t] against the evaluated type [ty] and
      returns [ty].

    All three raise [Failure] (through [panic]) on a type error.

    Each premise is checked with an explicit [let] before anything that
    depends on it is evaluated. Passing the premise as an ignored function
    argument is not enough: OCaml leaves argument evaluation order
    unspecified (right-to-left in practice), which would normalize a term
    with [eval] before checking it. On an ill-typed term such as a
    self-application that normalization may never terminate. *)
let rec infer lvl ctx = function
  | `Pi (a, f) ->
      (* The domain must be a type before it is evaluated and bound in the
         context used for the codomain. *)
      let _domain_sort = infer_sort lvl ctx a in
      infer_sort (lvl + 1) (eval a :: ctx) (f (`Go lvl))
  | `Appl (m, n) -> (
      match infer lvl ctx m with
      | `Pi (a, f) ->
          (* Check the argument before substituting it into the codomain;
             codomains built by [eval] evaluate their argument. *)
          let _arg_ty = check lvl ctx (n, a) in
          f n
      | m_ty -> panic lvl m "Want Π, got %s" (pp lvl m_ty))
  | `Ann (m, a) ->
      (* The annotation must be a type before it is evaluated. *)
      let _sort = infer_sort lvl ctx a in
      check lvl ctx (m, eval a)
  | `Go x -> List.nth ctx (lvl - x - 1)
  | `Star -> `Box
  | t -> panic lvl t "Not inferrable"

and infer_sort lvl ctx a =
  match infer lvl ctx a with
  | (`Star | `Box) as s -> s
  | ty -> panic lvl a "Want a sort, got %s" (pp lvl ty)

and check lvl ctx = function
  | `Lam f, `Pi (a, g) ->
      (* Open both binders with the same fresh variable and check the body
         under the extended context. *)
      let _body_ty = check (lvl + 1) (a :: ctx) (f (`Go lvl), g (`Go lvl)) in
      `Pi (a, g)
  | `Lam f, ty -> panic lvl (`Lam f) "Want Π, got %s" (pp lvl ty)
  | t, ty ->
      (* Mode switch: synthesize a type and compare it with the expected one. *)
      let got_ty = infer lvl ctx t in
      if equate lvl (ty, got_ty) then ty
      else panic lvl t "Want type %s, got %s" (pp lvl ty) (pp lvl got_ty)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment