Created
June 1, 2025 08:26
-
-
Save utdemir/63cafcb74f60aebed21d95d5871d6c68 to your computer and use it in GitHub Desktop.
Two ways to implement Trees That Grow in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Trees that Grow
//
// Every AST node carries an "extension" field whose type is chosen by the
// compiler phase. A phase type implements `HasExt<Branch>` to declare which
// extension type it attaches to the given branch (constructor marker).

/// Maps a branch marker `Branch` to the extension type stored by the
/// implementing phase in nodes of that branch.
trait HasExt<Branch> {
    /// Extension payload carried by `Branch` nodes in this phase.
    type Ext;
}

/// Declares `HasExt` impls for a phase.
///
/// Two forms are accepted:
/// - `decl_ext!(Phase, Branch, ExtType)` declares a single mapping;
/// - `decl_ext!(Phase: BranchA => ExtA, BranchB => ExtB, ...)` declares
///   several mappings at once (a trailing comma is allowed).
macro_rules! decl_ext {
    ($phase:ty, $name:ident, $ext:ty) => {
        impl HasExt<$name> for $phase {
            type Ext = $ext;
        }
    };
    ($phase:ty: $($name:ident => $ext:ty),* $(,)?) => {
        $(
            impl HasExt<$name> for $phase {
                type Ext = $ext;
            }
        )*
    };
}

/// Extension type that phase `Phase` attaches to branch `Branch`.
type Ext<Branch, Phase> = <Phase as HasExt<Branch>>::Ext;

/// Uninhabited type: mapping a branch's extension to `Never` makes that
/// branch impossible to construct in the phase.
type Never = std::convert::Infallible;
// Example | |
enum MyAST<P> | |
where | |
P: HasExt<XNumber> | |
+ HasExt<XString> | |
+ HasExt<XAdd> | |
+ HasExt<XSubtract> | |
+ HasExt<XNegate> | |
+ HasExt<XToString>, | |
{ | |
Number(Ext<XNumber, P>, i32), | |
String(Ext<XString, P>, String), | |
Add(Ext<XAdd, P>, Box<MyAST<P>>, Box<MyAST<P>>), | |
Subtract(Ext<XSubtract, P>, Box<MyAST<P>>, Box<MyAST<P>>), | |
Negate(Ext<XNegate, P>, Box<MyAST<P>>), | |
ToString(Ext<XToString, P>, Box<MyAST<P>>), | |
} | |
// Branches: one marker type per `MyAST` constructor. They appear only at the
// type level, as the keys of `HasExt`.

/// Marker for `MyAST::Number`.
struct XNumber;
/// Marker for `MyAST::String`.
struct XString;
/// Marker for `MyAST::Add`.
struct XAdd;
/// Marker for `MyAST::Subtract`.
struct XSubtract;
/// Marker for `MyAST::Negate`.
struct XNegate;
/// Marker for `MyAST::ToString`.
struct XToString;

// Phases: one marker type per pipeline stage.

/// Freshly built AST, before type checking.
struct PInit;
/// Type-checked AST.
struct PTc;
/// Optimized AST.
struct POpt;
// Types

/// Static type of an expression in the example language.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Type {
    /// Integer-valued expression (`MyAST::Number` holds an `i32`).
    Number,
    /// String-valued expression.
    String,
}
/// Error reported by `typecheck` when an operator receives operands of a
/// type it does not accept.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeError {
    /// An operator's operand types do not match its typing rule.
    MismatchedTypes,
}

impl std::fmt::Display for TypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TypeError::MismatchedTypes => f.write_str("mismatched types"),
        }
    }
}

// Lets `TypeError` be propagated as `Box<dyn std::error::Error>`.
impl std::error::Error for TypeError {}
// Mappings | |
decl_ext!(PInit: | |
XNumber => (), | |
XString => (), | |
XAdd => (), | |
XSubtract => (), | |
XNegate => (), | |
XToString => (), | |
); | |
decl_ext!(PTc: | |
XNumber => Type, | |
XString => Type, | |
XAdd => Type, | |
XSubtract => Type, | |
XNegate => Type, | |
XToString => Type, | |
); | |
decl_ext!(POpt: | |
XNumber => Type, | |
XString => Type, | |
XAdd => Type, | |
XSubtract => Never, | |
XNegate => Type, | |
XToString => Type, | |
); | |
// Implementations | |
fn typecheck(ast: MyAST<PInit>) -> Result<(Type, MyAST<PTc>), TypeError> { | |
match ast { | |
MyAST::Number((), value) => Ok((Type::Number, MyAST::Number(Type::Number, value))), | |
MyAST::String((), value) => Ok((Type::String, MyAST::String(Type::String, value))), | |
MyAST::Add((), left, right) => { | |
let (left_ty, lhs) = typecheck(*left)?; | |
let (right_ty, rhs) = typecheck(*right)?; | |
if left_ty != right_ty { | |
return Err(TypeError::MismatchedTypes); | |
} else { | |
Ok((left_ty, MyAST::Add(left_ty, Box::new(lhs), Box::new(rhs)))) | |
} | |
} | |
MyAST::Subtract((), left, right) => { | |
if let ((Type::Number, lhs), (Type::Number, rhs)) = | |
(typecheck(*left)?, typecheck(*right)?) | |
{ | |
Ok(( | |
Type::Number, | |
MyAST::Subtract(Type::Number, Box::new(lhs), Box::new(rhs)), | |
)) | |
} else { | |
Err(TypeError::MismatchedTypes) | |
} | |
} | |
MyAST::Negate((), expr) => { | |
if let (Type::Number, expr_tc) = typecheck(*expr)? { | |
Ok((Type::Number, MyAST::Negate(Type::Number, Box::new(expr_tc)))) | |
} else { | |
Err(TypeError::MismatchedTypes) | |
} | |
} | |
MyAST::ToString((), expr) => { | |
if let (Type::Number, expr_tc) = typecheck(*expr)? { | |
Ok(( | |
Type::String, | |
MyAST::ToString(Type::String, Box::new(expr_tc)), | |
)) | |
} else { | |
Err(TypeError::MismatchedTypes) | |
} | |
} | |
} | |
} | |
fn optimize(ast: MyAST<PTc>) -> MyAST<POpt> { | |
match ast { | |
MyAST::Number(ext, value) => MyAST::Number(ext, value), | |
MyAST::String(ext, value) => MyAST::String(ext, value), | |
MyAST::Add(ext, left, right) => { | |
MyAST::Add(ext, Box::new(optimize(*left)), Box::new(optimize(*right))) | |
} | |
MyAST::Subtract(ext, left, right) => MyAST::Add( | |
ext, | |
Box::new(optimize(*left)), | |
Box::new(MyAST::Negate(ext, Box::new(optimize(*right)))), | |
), | |
MyAST::Negate(ext, expr) => MyAST::Negate(ext, Box::new(optimize(*expr))), | |
MyAST::ToString(ext, expr) => MyAST::ToString(ext, Box::new(optimize(*expr))), | |
} | |
} | |
fn _complexity<T>(ast: &MyAST<T>) -> usize | |
where | |
T: HasExt<XNumber> | |
+ HasExt<XString> | |
+ HasExt<XAdd> | |
+ HasExt<XSubtract> | |
+ HasExt<XNegate> | |
+ HasExt<XToString>, | |
{ | |
match ast { | |
MyAST::Number(_, _) => 1, | |
MyAST::String(_, _) => 1, | |
MyAST::Add(_, left, right) => 1 + _complexity(left) + _complexity(right), | |
MyAST::Subtract(_, left, right) => 1 + _complexity(left) + _complexity(right), | |
MyAST::Negate(_, expr) => 1 + _complexity(expr), | |
MyAST::ToString(_, expr) => 1 + _complexity(expr), | |
} | |
} | |
/// Accumulates generated statements and hands out fresh variable names.
#[derive(Default)]
struct CompileState {
    /// Statements emitted so far, in order, without trailing semicolons.
    statements: Vec<String>,
    /// Number of the most recently allocated variable (0 = none yet).
    last_var: usize,
}

impl CompileState {
    /// Allocates the next variable name: `var1`, `var2`, ...
    fn next_var(&mut self) -> String {
        let id = self.last_var + 1;
        self.last_var = id;
        format!("var{}", id)
    }

    /// Emits `let <fresh var> = <value>` and returns the fresh variable's name.
    fn set_stmt(&mut self, value: String) -> String {
        let name = self.next_var();
        let stmt = format!("let {} = {}", name, value);
        self.add_stmt(stmt);
        name
    }

    /// Appends a raw statement.
    fn add_stmt(&mut self, value: String) {
        self.statements.push(value)
    }
}
fn compile(ast: MyAST<POpt>) -> String { | |
fn infix( | |
state: &mut CompileState, | |
left: Box<MyAST<POpt>>, | |
right: Box<MyAST<POpt>>, | |
fmt: impl FnOnce(String, String) -> String, | |
) -> String { | |
let l = go(state, *left); | |
let r = go(state, *right); | |
state.set_stmt(fmt(l, r)) | |
} | |
fn go(state: &mut CompileState, node: MyAST<POpt>) -> String { | |
match node { | |
MyAST::Add(ty, l, r) => match ty { | |
Type::Number => infix(state, l, r, |a, b| format!("{} + {}", a, b)), | |
Type::String => infix(state, l, r, |a, b| { | |
format!("format!(\"{{}}{{}}\", {}, {})", a, b) | |
}), | |
}, | |
MyAST::Subtract(never, _, _) => match never {}, | |
MyAST::Number(_, value) => state.set_stmt(value.to_string()), | |
MyAST::String(_, value) => state.set_stmt(format!("\"{}\"", value)), | |
MyAST::Negate(_, expr) => { | |
let expr_code = go(state, *expr); | |
state.set_stmt(format!("-{}", expr_code)) | |
} | |
MyAST::ToString(_, expr) => { | |
let var = go(state, *expr); | |
state.set_stmt(format!("{}.to_string()", var)) | |
} | |
} | |
} | |
let mut st = CompileState::default(); | |
let var = go(&mut st, ast); | |
st.statements | |
.iter() | |
.map(|s| format!("{};", s)) | |
.chain(std::iter::once(var)) | |
.collect::<Vec<_>>() | |
.join("\n") | |
} | |
fn main() { | |
let ast = { | |
use MyAST::*; | |
Add( | |
(), | |
Box::new(String((), "Result: ".to_string())), | |
Box::new(ToString( | |
(), | |
Box::new(Subtract( | |
(), | |
Box::new(Number((), 10)), | |
Box::new(Number((), 5)), | |
)), | |
)), | |
) | |
}; | |
let tc = typecheck(ast).expect("Type checking failed"); | |
let opt = optimize(tc.1); | |
let compiled_code = compile(opt); | |
println!("{}", compiled_code); | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Example

/// Chooses, for one compiler phase, the extension type stored in each `MyAST`
/// constructor (one associated type per constructor).
trait XMyAST {
    /// Extension of `MyAST::Number`.
    type Number;
    /// Extension of `MyAST::String`.
    type String;
    /// Extension of `MyAST::Add`.
    type Add;
    /// Extension of `MyAST::Subtract`.
    type Subtract;
    /// Extension of `MyAST::Negate`.
    type Negate;
    /// Extension of `MyAST::ToString`.
    type ToString;
}

/// Expression AST indexed by phase `X`. The first field of every variant is
/// the phase-specific extension taken from `X: XMyAST`.
enum MyAST<X>
where
    X: XMyAST,
{
    /// Integer literal.
    Number(X::Number, i32),
    /// String literal.
    String(X::String, String),
    /// Binary addition of two sub-expressions.
    Add(X::Add, Box<Self>, Box<Self>),
    /// Binary subtraction of two sub-expressions.
    Subtract(X::Subtract, Box<Self>, Box<Self>),
    /// Unary negation of a sub-expression.
    Negate(X::Negate, Box<Self>),
    /// Conversion of a sub-expression to a string.
    ToString(X::ToString, Box<Self>),
}
// Types

/// Static type of an expression in the example language.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Type {
    /// Integer-valued expression (`MyAST::Number` holds an `i32`).
    Number,
    /// String-valued expression.
    String,
}
/// Error reported by `typecheck` when an operator receives operands of a
/// type it does not accept.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeError {
    /// An operator's operand types do not match its typing rule.
    MismatchedTypes,
}

impl std::fmt::Display for TypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TypeError::MismatchedTypes => f.write_str("mismatched types"),
        }
    }
}

// Lets `TypeError` be propagated as `Box<dyn std::error::Error>`.
impl std::error::Error for TypeError {}
/// Uninhabited type: a phase that maps a constructor's extension to `Never`
/// makes that constructor impossible to build in that phase.
type Never = ::std::convert::Infallible;

// Mappings: one marker type per pipeline stage; each implements `XMyAST`.

/// Freshly built AST, before type checking.
struct PInit;
/// Type-checked AST.
struct PTc;
/// Optimized AST.
struct POpt;
impl XMyAST for PInit { | |
type Number = (); | |
type String = (); | |
type Add = (); | |
type Subtract = (); | |
type Negate = (); | |
type ToString = (); | |
} | |
impl XMyAST for PTc { | |
type Number = Type; | |
type String = Type; | |
type Add = Type; | |
type Subtract = Type; | |
type Negate = Type; | |
type ToString = Type; | |
} | |
impl XMyAST for POpt { | |
type Number = Type; | |
type String = Type; | |
type Add = Type; | |
type Subtract = Never; | |
type Negate = Type; | |
type ToString = Type; | |
} | |
// Implementations | |
fn typecheck(ast: MyAST<PInit>) -> Result<(Type, MyAST<PTc>), TypeError> { | |
match ast { | |
MyAST::Number((), value) => Ok((Type::Number, MyAST::Number(Type::Number, value))), | |
MyAST::String((), value) => Ok((Type::String, MyAST::String(Type::String, value))), | |
MyAST::Add((), left, right) => { | |
let (left_ty, lhs) = typecheck(*left)?; | |
let (right_ty, rhs) = typecheck(*right)?; | |
if left_ty != right_ty { | |
return Err(TypeError::MismatchedTypes); | |
} else { | |
Ok((left_ty, MyAST::Add(left_ty, Box::new(lhs), Box::new(rhs)))) | |
} | |
} | |
MyAST::Subtract((), left, right) => { | |
if let ((Type::Number, lhs), (Type::Number, rhs)) = | |
(typecheck(*left)?, typecheck(*right)?) | |
{ | |
Ok(( | |
Type::Number, | |
MyAST::Subtract(Type::Number, Box::new(lhs), Box::new(rhs)), | |
)) | |
} else { | |
Err(TypeError::MismatchedTypes) | |
} | |
} | |
MyAST::Negate((), expr) => { | |
if let (Type::Number, expr_tc) = typecheck(*expr)? { | |
Ok((Type::Number, MyAST::Negate(Type::Number, Box::new(expr_tc)))) | |
} else { | |
Err(TypeError::MismatchedTypes) | |
} | |
} | |
MyAST::ToString((), expr) => { | |
if let (Type::Number, expr_tc) = typecheck(*expr)? { | |
Ok(( | |
Type::String, | |
MyAST::ToString(Type::String, Box::new(expr_tc)), | |
)) | |
} else { | |
Err(TypeError::MismatchedTypes) | |
} | |
} | |
} | |
} | |
fn optimize(ast: MyAST<PTc>) -> MyAST<POpt> { | |
match ast { | |
MyAST::Number(ext, value) => MyAST::Number(ext, value), | |
MyAST::String(ext, value) => MyAST::String(ext, value), | |
MyAST::Add(ext, left, right) => { | |
MyAST::Add(ext, Box::new(optimize(*left)), Box::new(optimize(*right))) | |
} | |
MyAST::Subtract(ext, left, right) => MyAST::Add( | |
ext, | |
Box::new(optimize(*left)), | |
Box::new(MyAST::Negate(ext, Box::new(optimize(*right)))), | |
), | |
MyAST::Negate(ext, expr) => MyAST::Negate(ext, Box::new(optimize(*expr))), | |
MyAST::ToString(ext, expr) => MyAST::ToString(ext, Box::new(optimize(*expr))), | |
} | |
} | |
fn _complexity<T: XMyAST>(ast: &MyAST<T>) -> usize { | |
match ast { | |
MyAST::Number(_, _) => 1, | |
MyAST::String(_, _) => 1, | |
MyAST::Add(_, left, right) => 1 + _complexity(left) + _complexity(right), | |
MyAST::Subtract(_, left, right) => 1 + _complexity(left) + _complexity(right), | |
MyAST::Negate(_, expr) => 1 + _complexity(expr), | |
MyAST::ToString(_, expr) => 1 + _complexity(expr), | |
} | |
} | |
/// Accumulates generated statements and hands out fresh variable names.
#[derive(Default)]
struct CompileState {
    /// Statements emitted so far, in order, without trailing semicolons.
    statements: Vec<String>,
    /// Number of the most recently allocated variable (0 = none yet).
    last_var: usize,
}

impl CompileState {
    /// Allocates the next variable name: `var1`, `var2`, ...
    fn next_var(&mut self) -> String {
        let id = self.last_var + 1;
        self.last_var = id;
        format!("var{}", id)
    }

    /// Emits `let <fresh var> = <value>` and returns the fresh variable's name.
    fn set_stmt(&mut self, value: String) -> String {
        let name = self.next_var();
        let stmt = format!("let {} = {}", name, value);
        self.add_stmt(stmt);
        name
    }

    /// Appends a raw statement.
    fn add_stmt(&mut self, value: String) {
        self.statements.push(value)
    }
}
fn compile(ast: MyAST<POpt>) -> String { | |
fn infix( | |
state: &mut CompileState, | |
left: Box<MyAST<POpt>>, | |
right: Box<MyAST<POpt>>, | |
fmt: impl FnOnce(String, String) -> String, | |
) -> String { | |
let l = go(state, *left); | |
let r = go(state, *right); | |
state.set_stmt(fmt(l, r)) | |
} | |
fn go(state: &mut CompileState, node: MyAST<POpt>) -> String { | |
match node { | |
MyAST::Add(ty, l, r) => match ty { | |
Type::Number => infix(state, l, r, |a, b| format!("{} + {}", a, b)), | |
Type::String => infix(state, l, r, |a, b| { | |
format!("format!(\"{{}}{{}}\", {}, {})", a, b) | |
}), | |
}, | |
MyAST::Subtract(never, _, _) => match never {}, | |
MyAST::Number(_, value) => state.set_stmt(value.to_string()), | |
MyAST::String(_, value) => state.set_stmt(format!("\"{}\"", value)), | |
MyAST::Negate(_, expr) => { | |
let expr_code = go(state, *expr); | |
state.set_stmt(format!("-{}", expr_code)) | |
} | |
MyAST::ToString(_, expr) => { | |
let var = go(state, *expr); | |
state.set_stmt(format!("{}.to_string()", var)) | |
} | |
} | |
} | |
let mut st = CompileState::default(); | |
let var = go(&mut st, ast); | |
st.statements | |
.iter() | |
.map(|s| format!("{};", s)) | |
.chain(std::iter::once(var)) | |
.collect::<Vec<_>>() | |
.join("\n") | |
} | |
fn main() { | |
let ast = { | |
use MyAST::*; | |
Add( | |
(), | |
Box::new(String((), "Result: ".to_string())), | |
Box::new(ToString( | |
(), | |
Box::new(Subtract( | |
(), | |
Box::new(Number((), 10)), | |
Box::new(Number((), 5)), | |
)), | |
)), | |
) | |
}; | |
let tc = typecheck(ast).expect("Type checking failed"); | |
let opt = optimize(tc.1); | |
let compiled_code = compile(opt); | |
println!("{}", compiled_code); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment