Last active
February 12, 2024 22:34
-
-
Save algebraic-dev/bf79be8eee70ea224b23ec5820e3624b to your computer and use it in GitHub Desktop.
Dependent type checker with substitution for lambda calculus.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{collections::HashSet, fmt::Display, rc::Rc};
/// The AST: describes the syntactic tree of a program.
///
/// Terms and types share this single representation (a `Pi` is both a term and a
/// type, and `Typ` is the type of types). Subterms are held in `Rc` so they can be
/// shared between terms instead of copied.
#[derive(Debug)]
pub enum Syntax {
    /// Lambda abstraction `λparam. body`.
    Lambda {
        param: String,
        body: Rc<Syntax>,
    },
    /// Function application `fun arg`.
    App {
        fun: Rc<Syntax>,
        arg: Rc<Syntax>,
    },
    /// A reference to a variable by name.
    Var {
        name: String,
    },
    /// Dependent function type `(param : typ) -> body`; `param` may occur in `body`.
    Pi {
        param: String,
        typ: Rc<Syntax>,
        body: Rc<Syntax>,
    },
    /// Type annotation `(expr : typ)`.
    Ann {
        expr: Rc<Syntax>,
        typ: Rc<Syntax>,
    },
    /// The universe of types, printed as `Type`.
    Typ,
}
/// Pretty printing of the code. | |
impl Display for Syntax { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
match self { | |
Syntax::Lambda { param, body } => write!(f, "(λ{param}. {body})"), | |
Syntax::App { fun, arg } => write!(f, "({fun} {arg})"), | |
Syntax::Var { name } => write!(f, "{name}"), | |
Syntax::Pi { param, typ, body } => write!(f, "(({param} : {typ}) -> {body})"), | |
Syntax::Ann { expr: body, typ } => write!(f, "({body} : {typ})"), | |
Syntax::Typ => write!(f, "Type"), | |
} | |
} | |
} | |
impl Syntax { | |
/// Collects "Free variables". Free variables (FV) are variables that | |
/// would give us a error saying that "this variable is not defined" is any language. | |
/// BUT here we need them because "free variables" may be used by the context | |
/// or outside of the expression we are analysing right now. | |
/// | |
/// E.g: If we are analysing the BODY of "λx. x" ("x") them "x" is a free variable locally | |
/// but if we analyse the entire expression then "x" is bound. The something similar happens | |
/// with the Pi type "(x : T) -> G x". "x" is free on "G x" but bound on the rest. | |
pub fn free_vars(&self) -> HashSet<String> { | |
let mut fv = HashSet::new(); | |
// Here we are using the "im" package that provides "immutable" HashSets for efficiency | |
// So clones are almost O(1). | |
fn collect(expr: &Syntax, ctx: im::HashSet<String>, fv: &mut HashSet<String>) { | |
match expr { | |
Syntax::Lambda { param, body } => { | |
let mut new_ctx = ctx.clone(); | |
new_ctx.insert(param.clone()); | |
collect(&body.clone(), new_ctx, fv) | |
} | |
Syntax::App { fun, arg } => { | |
collect(&fun.clone(), ctx.clone(), fv); | |
collect(&arg.clone(), ctx, fv); | |
} | |
Syntax::Var { name } => { | |
if !ctx.contains(name) { | |
fv.insert(name.clone()); | |
} | |
} | |
Syntax::Pi { param, typ, body } => { | |
collect(&typ.clone(), ctx.clone(), fv); | |
let mut new_ctx = ctx; | |
new_ctx.insert(param.clone()); | |
collect(&body.clone(), new_ctx, fv); | |
} | |
Syntax::Ann { expr: body, typ } => { | |
collect(&body.clone(), ctx.clone(), fv); | |
collect(&typ.clone(), ctx, fv); | |
} | |
Syntax::Typ => (), | |
} | |
} | |
collect(&self, Default::default(), &mut fv); | |
fv | |
} | |
} | |
/// State shared by evaluation and type checking.
///
/// Its only job is supplying fresh variable names: `name_counter` is the number
/// that the next call to `new_name` will encode.
pub struct TyCtx { pub name_counter: u64 }
impl TyCtx { | |
/// Generates a new name based on the counter. | |
pub fn new_name(&mut self) -> String { | |
let mut str = String::new(); | |
let mut count = self.name_counter; | |
loop { | |
let chr = count % 26; | |
count = count / 26; | |
str.push((chr + 96) as u8 as char); | |
if count <= 0 { | |
break; | |
} | |
} | |
self.name_counter += 1; | |
str.push('\''); | |
str.chars().rev().collect() | |
} | |
/// Substitutes (expr[from = to]) the variable "from" to the variable "to" in the expression | |
/// "expr". | |
pub fn subst(&mut self, expr: Rc<Syntax>, from: &String, to: Rc<Syntax>) -> Rc<Syntax> { | |
match &*expr { | |
// If the variable has the same name of "from" then we just returns "to". | |
Syntax::Var { name } if name == from => to.clone(), | |
// If the "param" is equal to "from" then we are facing a "Shadowing". | |
Syntax::Lambda { param, body } if param != from => { | |
// Here we have a special case that is uur really bad to treat btw, we have to treat it. | |
// It's the case when we have to substitute (\x. E)[y = x] and `x` happens on the `to` | |
// parameter. in this case we have to change all of the `x` in `E` to something new. | |
let (param, body) = if to.free_vars().contains(param) { | |
let new_var = self.new_name(); | |
( | |
new_var.clone(), | |
self.subst(body.clone(), param, Rc::new(Syntax::Var { name: new_var })), | |
) | |
} else { | |
(param.clone(), body.clone()) | |
}; | |
Rc::new(Syntax::Lambda { | |
param, | |
body: self.subst(body, from, to), | |
}) | |
} | |
// We substitute both sides. | |
Syntax::App { fun, arg } => Rc::new(Syntax::App { | |
fun: self.subst(fun.clone(), from, to.clone()), | |
arg: self.subst(arg.clone(), from, to.clone()), | |
}), | |
Syntax::Pi { param, typ, body } => { | |
let (param, body) = if to.free_vars().contains(param) { | |
let new_var = self.new_name(); | |
( | |
new_var.clone(), | |
self.subst(body.clone(), param, Rc::new(Syntax::Var { name: new_var })), | |
) | |
} else { | |
(param.clone(), body.clone()) | |
}; | |
Rc::new(Syntax::Pi { | |
param: param.clone(), | |
typ: self.subst(typ.clone(), from, to.clone()), | |
body: if param == *from { | |
body.clone() | |
} else { | |
self.subst(body.clone(), from, to.clone()) | |
}, | |
}) | |
} | |
Syntax::Ann { expr: body, typ } => Rc::new(Syntax::Ann { | |
expr: self.subst(body.clone(), from, to.clone()), | |
typ: self.subst(typ.clone(), from, to.clone()), | |
}), | |
_ => expr.clone(), | |
} | |
} | |
/// Gets a lambda expression and evalutes it to it's "Weak head normal form" | |
pub fn eval(&mut self, expr: Rc<Syntax>) -> Rc<Syntax> { | |
match &*expr { | |
Syntax::App { fun, arg } => match &*self.eval(fun.clone()) { | |
Syntax::Lambda { param, body } => { | |
let arg = self.eval(arg.clone()); | |
let res = self.subst(body.clone(), param, arg); | |
self.eval(res) | |
} | |
_ => expr, | |
}, | |
Syntax::Ann { expr, typ: _ } => { | |
expr.clone() | |
} | |
_ => expr, | |
} | |
} | |
/// Strong normalize the ENTIRE expression it's fucked up I think? | |
pub fn reduce(&mut self, expr: Rc<Syntax>) -> Rc<Syntax> { | |
match &*expr { | |
Syntax::App { fun, arg } => match &*self.reduce(fun.clone()) { | |
Syntax::Lambda { param, body } => { | |
let arg = self.reduce(arg.clone()); | |
let res = self.subst(body.clone(), param, arg); | |
self.reduce(res) | |
} | |
_ => { | |
app(self.reduce(fun.clone()), self.reduce(arg.clone())) | |
} | |
}, | |
Syntax::Lambda { param, body } => { | |
lam(param, self.reduce(body.clone())) | |
}, | |
Syntax::Pi { param, typ, body } => { | |
pi(param, self.reduce(typ.clone()), self.reduce(body.clone())) | |
} | |
Syntax::Ann { expr, typ: _ } => { | |
self.reduce(expr.clone()) | |
}, | |
_ => expr.clone() | |
} | |
} | |
} | |
// Some helper functions | |
pub fn var(name: &str) -> Rc<Syntax> { | |
Rc::new(Syntax::Var { | |
name: name.to_string(), | |
}) | |
} | |
pub fn typ() -> Rc<Syntax> { | |
Rc::new(Syntax::Typ) | |
} | |
pub fn app(fun: Rc<Syntax>, arg: Rc<Syntax>) -> Rc<Syntax> { | |
Rc::new(Syntax::App { fun, arg }) | |
} | |
pub fn ann(expr: Rc<Syntax>, typ: Rc<Syntax>) -> Rc<Syntax> { | |
Rc::new(Syntax::Ann { expr, typ }) | |
} | |
pub fn lam(param: &str, body: Rc<Syntax>) -> Rc<Syntax> { | |
Rc::new(Syntax::Lambda { | |
param: param.to_string(), | |
body, | |
}) | |
} | |
pub fn pi(param: &str, typ: Rc<Syntax>, body: Rc<Syntax>) -> Rc<Syntax> { | |
Rc::new(Syntax::Pi { | |
param: param.to_string(), | |
typ, | |
body, | |
}) | |
} | |
// An immutable environment for the type checking phase | |
type Env = im::HashMap<String, Rc<Syntax>>; | |
// Type checking functions | |
impl TyCtx { | |
// The equal function is more commonly defined as "conv" (convergence) | |
// it checks if two expressions are equal | |
pub fn conv(&mut self, left: Rc<Syntax>, right: Rc<Syntax>) -> bool { | |
match (&*self.eval(left), &*self.eval(right)) { | |
(Syntax::Var { name: name_a }, Syntax::Var { name: name_b }) => name_a == name_b, | |
( | |
Syntax::Lambda { | |
param: pa, | |
body: ba, | |
}, | |
Syntax::Lambda { | |
param: pb, | |
body: bb, | |
}, | |
) => { | |
let n = self.new_name(); | |
// Here we rename two expression so in the end they become "alpha equivalebnt". | |
// e.g: (\x.x) = (\y.y) but they have different names so we change the names of the | |
// inside to 'a and we end up with (\'a. 'a) = (\'a. 'a) but we can discard the \'a and | |
// compare the inside part. 'a = 'a | |
let ba_subst = self.subst(ba.clone(), pa, var(&n)); | |
let bb_subst = self.subst(bb.clone(), pb, var(&n)); | |
self.conv(ba_subst, bb_subst) | |
} | |
( | |
Syntax::Pi { | |
param: pa, | |
typ: ta, | |
body: ba, | |
}, | |
Syntax::Pi { | |
param: pb, | |
typ: tb, | |
body: bb, | |
}, | |
) => { | |
let n = self.new_name(); | |
let ba_subst = self.subst(ba.clone(), pa, var(&n)); | |
let bb_subst = self.subst(bb.clone(), pb, var(&n)); | |
self.conv(ta.clone(), tb.clone()) && self.conv(ba_subst, bb_subst) | |
} | |
(Syntax::Ann { expr, typ }, Syntax::Ann { expr: eb, typ: tb }) => { | |
self.conv(expr.clone(), eb.clone()) && self.conv(typ.clone(), tb.clone()) | |
} | |
(Syntax::App { fun, arg }, Syntax::App { fun: fb, arg: ab }) => { | |
self.conv(fun.clone(), fb.clone()) && self.conv(arg.clone(), ab.clone()) | |
} | |
(Syntax::Typ, Syntax::Typ) => true, | |
(_, _) => { | |
false | |
}, | |
} | |
} | |
pub fn check(&mut self, ctx: Env, expr: Rc<Syntax>, typ: Rc<Syntax>) { | |
let expected = self.eval(typ); | |
match (&*expr, &*expected) { | |
// Γ ⊢ λx. e ⇐ (y: A) -> B | |
(Syntax::Lambda { param, body }, Syntax::Pi { param: pb, typ, body: tb }) => { | |
// Γ, x : A | |
let mut new_ctx = ctx.clone(); | |
new_ctx.insert(param.clone(), typ.clone()); | |
// B[y = x] | |
let ret_type = self.subst(tb.clone(), pb, var(¶m)); | |
// e ⇐ B[y = x] | |
self.check(new_ctx, body.clone(), ret_type); | |
}, | |
// Γ ⊢ x ⇐ A | |
(_, _) => { | |
// Γ ⊢ x => B | |
let infered = self.infer(ctx, expr.clone()); | |
// A = B | |
if !self.conv(expected.clone(), infered.clone()) { | |
panic!("Type '{}' does not match with '{}'", expected, infered) | |
} | |
} | |
} | |
} | |
pub fn infer(&mut self, ctx: Env, expr: Rc<Syntax>) -> Rc<Syntax> { | |
match &*expr { | |
Syntax::Lambda { .. } => panic!("Cannot infer lambda"), | |
// Γ ⊢ a b => B[x = b] | |
Syntax::App { fun, arg } => { | |
// Γ ⊢ a => (x: A) -> B | |
let fun_ty = self.infer(ctx.clone(), fun.clone()); | |
if let Syntax::Pi { param, typ, body } = &*fun_ty { | |
// Γ ⊢ b ⇐ A | |
self.check(ctx, arg.clone(), typ.clone()); | |
// B[x = b] | |
self.subst(body.clone(), param, arg.clone()) | |
} else { | |
panic!("Not a function to apply") | |
} | |
}, | |
// Γ ⊢ x => A | |
Syntax::Var { name } => { | |
// x : A ∈ Γ | |
if let Some(ty) = ctx.get(name) { | |
ty.clone() | |
} else { | |
panic!("Cannot find variable '{name}'") | |
} | |
}, | |
// Γ ⊢ (x: A) -> B => Type | |
Syntax::Pi { param, typ: tipo, body } => { | |
// Γ ⊢ A ⇐ Type | |
self.check(ctx.clone(), tipo.clone(), typ()); | |
// Γ, x : A ⊢ B ⇐ Type | |
let mut new_ctx = ctx.clone(); | |
new_ctx.insert(param.clone(), tipo.clone()); | |
self.check(new_ctx, body.clone(), typ()); | |
typ() | |
}, | |
// Γ ⊢ e : A => A | |
Syntax::Ann { expr, typ: tipo } => { | |
// A <= Type | |
self.check(ctx.clone(), tipo.clone(), typ()); | |
// e <= A | |
self.check(ctx, expr.clone(), tipo.clone()); | |
tipo.clone() | |
}, | |
// Γ ⊢ Type => Type | |
Syntax::Typ => { | |
typ() | |
}, | |
} | |
} | |
} | |
fn main() { | |
let mut tyctx = TyCtx { name_counter: 1 }; | |
// Encoding Nat as pi type with church encoding | |
// type nat : Type { | |
// zero : nat | |
// succ : nat -> nat | |
// } | |
let nat = | |
pi("nat", typ(), | |
pi("zero", var("nat"), | |
pi("succ", pi("_", var("nat"), var("nat")), | |
var("nat")))); | |
// Nat is a type | |
tyctx.check(Default::default(), nat.clone(), typ()); | |
// \_ -> \z -> \s -> \z | |
let zero = | |
lam("_", | |
lam("z", | |
lam("s", | |
var("z")))); | |
// Zero is a natural | |
let zero = ann(zero.clone(), nat.clone()); | |
tyctx.infer(Default::default(), zero.clone()); | |
// \m -> \ty -> \z -> \s -> (s (m ty z s)) | |
let succ = | |
lam("m", | |
lam("ty", | |
lam("z", | |
lam("s", | |
app(var("s"), app(app(app(var("m"), var("ty")), var("z")), var("s"))))))); | |
// Succ is a (natural -> natural) | |
let succ = ann(succ.clone(), pi("_", nat.clone(), nat.clone())); | |
tyctx.infer(Default::default(), succ.clone()); | |
// \m -> \n -> \ty -> \z -> \s -> ((m ty) (n ty z s)) s | |
let add = | |
lam("m", | |
lam("n", | |
lam("ty", | |
lam("z", | |
lam("s", | |
app(app(app(var("m"), var("ty")), app(app(app(var("n"), var("ty")), var("z")), var("s"))), var("s"))))))); | |
// We always have to anotate these things \:P | |
// Add is a (natural -> -> nat natural) | |
let add = ann(add.clone(), pi("_", nat.clone(), pi("_", nat.clone(), nat.clone()))); | |
tyctx.infer(Default::default(), add.clone()); | |
// Remember it's in Weak head normal form so it does not reduce until the end | |
let one = app(succ.clone(), zero.clone()); | |
let two = app(succ.clone(), one.clone()); | |
let three = app(succ.clone(), two.clone()); | |
let four = app(succ.clone(), three.clone()); | |
let five = app(succ.clone(), four.clone()); | |
let added_five = app(app(add.clone(), three.clone()), two.clone()); | |
tyctx.check(Default::default(), added_five.clone(), nat.clone()); | |
let redc_added = tyctx.reduce(added_five.clone()); | |
let redc_five = tyctx.reduce(five.clone()); | |
let added_five = app(app(add.clone(), three.clone()), two.clone()); | |
let added_inv_five = app(app(add.clone(), two.clone()), three.clone()); | |
// Testing if the strong is equal to the non evaluated | |
println!("{}", tyctx.conv(added_five.clone(), added_inv_five)); | |
println!("{}", tyctx.conv(redc_added, redc_five)); | |
println!("{}", tyctx.conv(added_five, five)); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment