Skip to content

Instantly share code, notes, and snippets.

@utdemir
Created June 1, 2025 08:26
Show Gist options
  • Save utdemir/63cafcb74f60aebed21d95d5871d6c68 to your computer and use it in GitHub Desktop.
Save utdemir/63cafcb74f60aebed21d95d5871d6c68 to your computer and use it in GitHub Desktop.
Two ways to implement Trees That Grow in Rust
// Trees That Grow, variant 1: a `HasExt<Branch>` trait implemented once per (phase, constructor) pair.
/// Maps a branch tag to the extension payload that a phase attaches to that branch.
///
/// The trait is implemented on a *phase* type, once per `Branch` tag. `Ext` is the type
/// stored in the first field of the `MyAST` constructor identified by `Branch`.
trait HasExt<Branch> {
    /// Extension payload carried by the `Branch` constructor in this phase.
    type Ext;
}
/// Generates `HasExt` implementations for a phase.
///
/// Two call forms are accepted:
///
/// * `decl_ext!(Phase, Branch, ExtType)` declares a single mapping.
/// * `decl_ext!(Phase: BranchA => ExtA, BranchB => ExtB, ...)` declares several mappings at
///   once; a trailing comma is allowed.
macro_rules! decl_ext {
    // Single mapping.
    ($phase:ty, $name:ident, $ext:ty) => {
        impl HasExt<$name> for $phase {
            type Ext = $ext;
        }
    };
    // List of mappings: one `impl` per `Branch => ExtType` entry.
    ($phase:ty: $($name:ident => $ext:ty),* $(,)?) => {
        $(
            impl HasExt<$name> for $phase {
                type Ext = $ext;
            }
        )*
    };
}
/// Shorthand for the extension payload that `Phase` attaches to the `Branch` constructor.
///
/// Note the argument order: the branch tag comes first, the phase second.
type Ext<Branch, Phase> = <Phase as HasExt<Branch>>::Ext;
/// Uninhabited type. A phase that maps a branch to `Never` can never build that constructor.
type Never = std::convert::Infallible;
// Example
/// Expression tree whose per-constructor extension payload depends on the phase.
///
/// The first field of every variant is the payload that `Phase` assigns to that variant's
/// branch tag through `HasExt`; the remaining fields are the node's own data and children.
enum MyAST<Phase>
where
    Phase: HasExt<XNumber>
        + HasExt<XString>
        + HasExt<XAdd>
        + HasExt<XSubtract>
        + HasExt<XNegate>
        + HasExt<XToString>,
{
    /// Integer literal.
    Number(<Phase as HasExt<XNumber>>::Ext, i32),
    /// String literal.
    String(<Phase as HasExt<XString>>::Ext, String),
    /// Binary addition of two sub-expressions.
    Add(<Phase as HasExt<XAdd>>::Ext, Box<Self>, Box<Self>),
    /// Binary subtraction of two sub-expressions.
    Subtract(<Phase as HasExt<XSubtract>>::Ext, Box<Self>, Box<Self>),
    /// Unary negation.
    Negate(<Phase as HasExt<XNegate>>::Ext, Box<Self>),
    /// Conversion of a sub-expression to a string.
    ToString(<Phase as HasExt<XToString>>::Ext, Box<Self>),
}
// Branches
// One zero-sized tag type per `MyAST` constructor. A tag is used as the `HasExt` parameter
// that selects the extension payload stored in that constructor's first field.
/// Tag for `MyAST::Number`.
struct XNumber;
/// Tag for `MyAST::String`.
struct XString;
/// Tag for `MyAST::Add`.
struct XAdd;
/// Tag for `MyAST::Subtract`.
struct XSubtract;
/// Tag for `MyAST::Negate`.
struct XNegate;
/// Tag for `MyAST::ToString`.
struct XToString;
// Phases
/// Phase of the tree accepted by `typecheck`; every extension payload is `()`.
struct PInit;
/// Phase of the tree produced by `typecheck`; every extension payload is a `Type`.
struct PTc;
/// Phase of the tree produced by `optimize`; like `PTc`, except the `Subtract` payload is `Never`.
struct POpt;
// Types
/// Type assigned to an expression by the type checker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Type {
    /// Integer-valued expression.
    Number,
    /// String-valued expression.
    String,
}
/// Error reported by the type checker.
///
/// Implements `Display` and `std::error::Error` so it can be shown to users and propagated
/// with `?` into `Box<dyn std::error::Error>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeError {
    /// An operand's type is not the one its operator requires.
    MismatchedTypes,
}

impl std::fmt::Display for TypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TypeError::MismatchedTypes => f.write_str("mismatched types"),
        }
    }
}

impl std::error::Error for TypeError {}
// Mappings
// `PInit` (the phase consumed by `typecheck`): no constructor carries extra data.
decl_ext!(PInit:
XNumber => (),
XString => (),
XAdd => (),
XSubtract => (),
XNegate => (),
XToString => (),
);
// `PTc` (the phase produced by `typecheck`): every constructor is annotated with a `Type`.
decl_ext!(PTc:
XNumber => Type,
XString => Type,
XAdd => Type,
XSubtract => Type,
XNegate => Type,
XToString => Type,
);
// `POpt` (the phase produced by `optimize`): same annotations as `PTc`, except that
// `Subtract` carries the uninhabited `Never`, so no `Subtract` node can exist in this phase.
decl_ext!(POpt:
XNumber => Type,
XString => Type,
XAdd => Type,
XSubtract => Never,
XNegate => Type,
XToString => Type,
);
// Implementations
/// Type-checks a `PInit` tree and rebuilds it as a `PTc` tree annotated with types.
///
/// On success returns the type of the root expression together with the annotated tree.
///
/// # Errors
///
/// Returns `TypeError::MismatchedTypes` when
/// * the two operands of `Add` have different types,
/// * either operand of `Subtract` is not a `Type::Number`,
/// * the operand of `Negate` or `ToString` is not a `Type::Number`.
///
/// Operands are checked left to right; the first error found is returned.
fn typecheck(ast: MyAST<PInit>) -> Result<(Type, MyAST<PTc>), TypeError> {
    let typed: (Type, MyAST<PTc>) = match ast {
        MyAST::Number((), n) => (Type::Number, MyAST::Number(Type::Number, n)),
        MyAST::String((), s) => (Type::String, MyAST::String(Type::String, s)),
        MyAST::Add((), left, right) => {
            let (lhs_ty, lhs) = typecheck(*left)?;
            let (rhs_ty, rhs) = typecheck(*right)?;
            if lhs_ty != rhs_ty {
                return Err(TypeError::MismatchedTypes);
            }
            // Both sides agree, so the sum has the operands' type.
            (lhs_ty, MyAST::Add(lhs_ty, Box::new(lhs), Box::new(rhs)))
        }
        MyAST::Subtract((), left, right) => {
            let (lhs_ty, lhs) = typecheck(*left)?;
            let (rhs_ty, rhs) = typecheck(*right)?;
            match (lhs_ty, rhs_ty) {
                (Type::Number, Type::Number) => (
                    Type::Number,
                    MyAST::Subtract(Type::Number, Box::new(lhs), Box::new(rhs)),
                ),
                _ => return Err(TypeError::MismatchedTypes),
            }
        }
        MyAST::Negate((), operand) => match typecheck(*operand)? {
            (Type::Number, inner) => (
                Type::Number,
                MyAST::Negate(Type::Number, Box::new(inner)),
            ),
            (Type::String, _) => return Err(TypeError::MismatchedTypes),
        },
        MyAST::ToString((), operand) => match typecheck(*operand)? {
            (Type::Number, inner) => (
                Type::String,
                MyAST::ToString(Type::String, Box::new(inner)),
            ),
            (Type::String, _) => return Err(TypeError::MismatchedTypes),
        },
    };
    Ok(typed)
}
/// Rewrites a type-checked tree into the `POpt` phase, eliminating `Subtract`.
///
/// `a - b` becomes `a + (-b)`; the type annotation of the `Subtract` node is reused for
/// both the new `Add` and the new `Negate`. Every other node is rebuilt unchanged with its
/// children optimized recursively.
fn optimize(ast: MyAST<PTc>) -> MyAST<POpt> {
    /// Optimizes a boxed child and re-boxes the result.
    fn lower(child: Box<MyAST<PTc>>) -> Box<MyAST<POpt>> {
        Box::new(optimize(*child))
    }

    match ast {
        MyAST::Number(ty, n) => MyAST::Number(ty, n),
        MyAST::String(ty, s) => MyAST::String(ty, s),
        MyAST::Add(ty, lhs, rhs) => {
            let lhs = lower(lhs);
            let rhs = lower(rhs);
            MyAST::Add(ty, lhs, rhs)
        }
        MyAST::Subtract(ty, lhs, rhs) => {
            let minuend = lower(lhs);
            let negated = Box::new(MyAST::Negate(ty, lower(rhs)));
            MyAST::Add(ty, minuend, negated)
        }
        MyAST::Negate(ty, operand) => MyAST::Negate(ty, lower(operand)),
        MyAST::ToString(ty, operand) => MyAST::ToString(ty, lower(operand)),
    }
}
/// Counts the nodes of a tree in any phase.
///
/// Each node contributes 1, plus the counts of its children. The extension payloads are
/// ignored, which is why the function is generic over the phase `T`.
fn _complexity<T>(ast: &MyAST<T>) -> usize
where
    T: HasExt<XNumber>
        + HasExt<XString>
        + HasExt<XAdd>
        + HasExt<XSubtract>
        + HasExt<XNegate>
        + HasExt<XToString>,
{
    match ast {
        // Leaves.
        MyAST::Number(..) | MyAST::String(..) => 1,
        // Binary nodes.
        MyAST::Add(_, lhs, rhs) | MyAST::Subtract(_, lhs, rhs) => {
            1 + _complexity(lhs) + _complexity(rhs)
        }
        // Unary nodes.
        MyAST::Negate(_, operand) | MyAST::ToString(_, operand) => 1 + _complexity(operand),
    }
}
/// Mutable state threaded through code generation.
#[derive(Default)]
struct CompileState {
    /// Statements emitted so far, in emission order, stored exactly as given.
    statements: Vec<String>,
    /// Number of the most recently allocated variable; `0` means none has been allocated.
    last_var: usize,
}

impl CompileState {
    /// Allocates a fresh variable name of the form `var<N>`, numbering from 1.
    fn next_var(&mut self) -> String {
        let index = self.last_var + 1;
        self.last_var = index;
        format!("var{}", index)
    }

    /// Binds `value` to a fresh variable with a `let` statement and returns the variable name.
    fn set_stmt(&mut self, value: String) -> String {
        let name = self.next_var();
        let binding = format!("let {} = {}", name, value);
        self.statements.push(binding);
        name
    }

    /// Appends `value` to the statement list unchanged.
    fn add_stmt(&mut self, value: String) {
        self.statements.push(value);
    }
}
/// Lowers an optimized tree to straight-line Rust source text.
///
/// Every node becomes one `let varN = <expr>;` line, children before parents and left
/// before right. The final line of the result is the bare name of the variable that holds
/// the value of the root expression. Lines are separated by `\n`.
///
/// String literals are written with `{:?}` so that quotes, backslashes and control
/// characters in the value are escaped and the generated literal is always well-formed.
fn compile(ast: MyAST<POpt>) -> String {
    /// Emits the statements for `node` and returns the name of the variable holding its value.
    fn emit(state: &mut CompileState, node: MyAST<POpt>) -> String {
        let expr = match node {
            MyAST::Number(_, value) => value.to_string(),
            // Debug formatting of a string yields an escaped, double-quoted literal.
            MyAST::String(_, value) => format!("{:?}", value),
            MyAST::Add(ty, left, right) => {
                let lhs = emit(state, *left);
                let rhs = emit(state, *right);
                match ty {
                    Type::Number => format!("{} + {}", lhs, rhs),
                    // Addition on strings is concatenation.
                    Type::String => format!("format!(\"{{}}{{}}\", {}, {})", lhs, rhs),
                }
            }
            // The payload is uninhabited in `POpt`, so this arm can never run.
            MyAST::Subtract(never, _, _) => match never {},
            MyAST::Negate(_, operand) => format!("-{}", emit(state, *operand)),
            MyAST::ToString(_, operand) => format!("{}.to_string()", emit(state, *operand)),
        };
        state.set_stmt(expr)
    }

    let mut state = CompileState::default();
    let result = emit(&mut state, ast);
    let mut lines: Vec<String> = state
        .statements
        .into_iter()
        .map(|stmt| format!("{};", stmt))
        .collect();
    lines.push(result);
    lines.join("\n")
}
fn main() {
let ast = {
use MyAST::*;
Add(
(),
Box::new(String((), "Result: ".to_string())),
Box::new(ToString(
(),
Box::new(Subtract(
(),
Box::new(Number((), 10)),
Box::new(Number((), 5)),
)),
)),
)
};
let tc = typecheck(ast).expect("Type checking failed");
let opt = optimize(tc.1);
let compiled_code = compile(opt);
println!("{}", compiled_code);
}
// Trees That Grow, variant 2: a single trait with one associated type per constructor.
/// Extension-type family for `MyAST`.
///
/// A phase implements this trait and picks, for each constructor, the type of the extra
/// payload stored in that constructor's first field.
trait XMyAST {
/// Payload of `MyAST::Number`.
type Number;
/// Payload of `MyAST::String`.
type String;
/// Payload of `MyAST::Add`.
type Add;
/// Payload of `MyAST::Subtract`.
type Subtract;
/// Payload of `MyAST::Negate`.
type Negate;
/// Payload of `MyAST::ToString`.
type ToString;
}
/// Expression tree whose per-constructor extension payload is chosen by the phase `X`.
///
/// The first field of every variant is the payload type that `X` declares for that variant
/// through `XMyAST`; the remaining fields are the node's own data and children.
enum MyAST<X>
where
    X: XMyAST,
{
    /// Integer literal.
    Number(X::Number, i32),
    /// String literal.
    String(X::String, String),
    /// Binary addition of two sub-expressions.
    Add(X::Add, Box<Self>, Box<Self>),
    /// Binary subtraction of two sub-expressions.
    Subtract(X::Subtract, Box<Self>, Box<Self>),
    /// Unary negation.
    Negate(X::Negate, Box<Self>),
    /// Conversion of a sub-expression to a string.
    ToString(X::ToString, Box<Self>),
}
// Types
/// Type assigned to an expression by the type checker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Type {
    /// Integer-valued expression.
    Number,
    /// String-valued expression.
    String,
}
/// Error reported by the type checker.
///
/// Implements `Display` and `std::error::Error` so it can be shown to users and propagated
/// with `?` into `Box<dyn std::error::Error>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeError {
    /// An operand's type is not the one its operator requires.
    MismatchedTypes,
}

impl std::fmt::Display for TypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TypeError::MismatchedTypes => f.write_str("mismatched types"),
        }
    }
}

impl std::error::Error for TypeError {}
/// Uninhabited type. `POpt` uses it as the `Subtract` payload, so a `Subtract` node cannot
/// be built in that phase.
type Never = std::convert::Infallible;
// Mappings
/// Phase of the tree accepted by `typecheck`; every extension payload is `()`.
struct PInit;
/// Phase of the tree produced by `typecheck`; every extension payload is a `Type`.
struct PTc;
/// Phase of the tree produced by `optimize`; like `PTc`, except the `Subtract` payload is `Never`.
struct POpt;
// `PInit` (the phase consumed by `typecheck`): no constructor carries extra data.
impl XMyAST for PInit {
type Number = ();
type String = ();
type Add = ();
type Subtract = ();
type Negate = ();
type ToString = ();
}
// `PTc` (the phase produced by `typecheck`): every constructor is annotated with a `Type`.
impl XMyAST for PTc {
type Number = Type;
type String = Type;
type Add = Type;
type Subtract = Type;
type Negate = Type;
type ToString = Type;
}
// `POpt` (the phase produced by `optimize`): same annotations as `PTc`, except that
// `Subtract` carries the uninhabited `Never`, so no `Subtract` node can exist in this phase.
impl XMyAST for POpt {
type Number = Type;
type String = Type;
type Add = Type;
type Subtract = Never;
type Negate = Type;
type ToString = Type;
}
// Implementations
/// Type-checks a `PInit` tree and rebuilds it as a `PTc` tree annotated with types.
///
/// On success returns the type of the root expression together with the annotated tree.
///
/// # Errors
///
/// Returns `TypeError::MismatchedTypes` when
/// * the two operands of `Add` have different types,
/// * either operand of `Subtract` is not a `Type::Number`,
/// * the operand of `Negate` or `ToString` is not a `Type::Number`.
///
/// Operands are checked left to right; the first error found is returned.
fn typecheck(ast: MyAST<PInit>) -> Result<(Type, MyAST<PTc>), TypeError> {
    let typed: (Type, MyAST<PTc>) = match ast {
        MyAST::Number((), n) => (Type::Number, MyAST::Number(Type::Number, n)),
        MyAST::String((), s) => (Type::String, MyAST::String(Type::String, s)),
        MyAST::Add((), left, right) => {
            let (lhs_ty, lhs) = typecheck(*left)?;
            let (rhs_ty, rhs) = typecheck(*right)?;
            if lhs_ty != rhs_ty {
                return Err(TypeError::MismatchedTypes);
            }
            // Both sides agree, so the sum has the operands' type.
            (lhs_ty, MyAST::Add(lhs_ty, Box::new(lhs), Box::new(rhs)))
        }
        MyAST::Subtract((), left, right) => {
            let (lhs_ty, lhs) = typecheck(*left)?;
            let (rhs_ty, rhs) = typecheck(*right)?;
            match (lhs_ty, rhs_ty) {
                (Type::Number, Type::Number) => (
                    Type::Number,
                    MyAST::Subtract(Type::Number, Box::new(lhs), Box::new(rhs)),
                ),
                _ => return Err(TypeError::MismatchedTypes),
            }
        }
        MyAST::Negate((), operand) => match typecheck(*operand)? {
            (Type::Number, inner) => (
                Type::Number,
                MyAST::Negate(Type::Number, Box::new(inner)),
            ),
            (Type::String, _) => return Err(TypeError::MismatchedTypes),
        },
        MyAST::ToString((), operand) => match typecheck(*operand)? {
            (Type::Number, inner) => (
                Type::String,
                MyAST::ToString(Type::String, Box::new(inner)),
            ),
            (Type::String, _) => return Err(TypeError::MismatchedTypes),
        },
    };
    Ok(typed)
}
/// Rewrites a type-checked tree into the `POpt` phase, eliminating `Subtract`.
///
/// `a - b` becomes `a + (-b)`; the type annotation of the `Subtract` node is reused for
/// both the new `Add` and the new `Negate`. Every other node is rebuilt unchanged with its
/// children optimized recursively.
fn optimize(ast: MyAST<PTc>) -> MyAST<POpt> {
    /// Optimizes a boxed child and re-boxes the result.
    fn lower(child: Box<MyAST<PTc>>) -> Box<MyAST<POpt>> {
        Box::new(optimize(*child))
    }

    match ast {
        MyAST::Number(ty, n) => MyAST::Number(ty, n),
        MyAST::String(ty, s) => MyAST::String(ty, s),
        MyAST::Add(ty, lhs, rhs) => {
            let lhs = lower(lhs);
            let rhs = lower(rhs);
            MyAST::Add(ty, lhs, rhs)
        }
        MyAST::Subtract(ty, lhs, rhs) => {
            let minuend = lower(lhs);
            let negated = Box::new(MyAST::Negate(ty, lower(rhs)));
            MyAST::Add(ty, minuend, negated)
        }
        MyAST::Negate(ty, operand) => MyAST::Negate(ty, lower(operand)),
        MyAST::ToString(ty, operand) => MyAST::ToString(ty, lower(operand)),
    }
}
/// Counts the nodes of a tree in any phase.
///
/// Each node contributes 1, plus the counts of its children. The extension payloads are
/// ignored, which is why the function is generic over the phase `T`.
fn _complexity<T: XMyAST>(ast: &MyAST<T>) -> usize {
    match ast {
        // Leaves.
        MyAST::Number(..) | MyAST::String(..) => 1,
        // Binary nodes.
        MyAST::Add(_, lhs, rhs) | MyAST::Subtract(_, lhs, rhs) => {
            1 + _complexity(lhs) + _complexity(rhs)
        }
        // Unary nodes.
        MyAST::Negate(_, operand) | MyAST::ToString(_, operand) => 1 + _complexity(operand),
    }
}
/// Mutable state threaded through code generation.
#[derive(Default)]
struct CompileState {
    /// Statements emitted so far, in emission order, stored exactly as given.
    statements: Vec<String>,
    /// Number of the most recently allocated variable; `0` means none has been allocated.
    last_var: usize,
}

impl CompileState {
    /// Allocates a fresh variable name of the form `var<N>`, numbering from 1.
    fn next_var(&mut self) -> String {
        let index = self.last_var + 1;
        self.last_var = index;
        format!("var{}", index)
    }

    /// Binds `value` to a fresh variable with a `let` statement and returns the variable name.
    fn set_stmt(&mut self, value: String) -> String {
        let name = self.next_var();
        let binding = format!("let {} = {}", name, value);
        self.statements.push(binding);
        name
    }

    /// Appends `value` to the statement list unchanged.
    fn add_stmt(&mut self, value: String) {
        self.statements.push(value);
    }
}
/// Lowers an optimized tree to straight-line Rust source text.
///
/// Every node becomes one `let varN = <expr>;` line, children before parents and left
/// before right. The final line of the result is the bare name of the variable that holds
/// the value of the root expression. Lines are separated by `\n`.
///
/// String literals are written with `{:?}` so that quotes, backslashes and control
/// characters in the value are escaped and the generated literal is always well-formed.
fn compile(ast: MyAST<POpt>) -> String {
    /// Emits the statements for `node` and returns the name of the variable holding its value.
    fn emit(state: &mut CompileState, node: MyAST<POpt>) -> String {
        let expr = match node {
            MyAST::Number(_, value) => value.to_string(),
            // Debug formatting of a string yields an escaped, double-quoted literal.
            MyAST::String(_, value) => format!("{:?}", value),
            MyAST::Add(ty, left, right) => {
                let lhs = emit(state, *left);
                let rhs = emit(state, *right);
                match ty {
                    Type::Number => format!("{} + {}", lhs, rhs),
                    // Addition on strings is concatenation.
                    Type::String => format!("format!(\"{{}}{{}}\", {}, {})", lhs, rhs),
                }
            }
            // The payload is uninhabited in `POpt`, so this arm can never run.
            MyAST::Subtract(never, _, _) => match never {},
            MyAST::Negate(_, operand) => format!("-{}", emit(state, *operand)),
            MyAST::ToString(_, operand) => format!("{}.to_string()", emit(state, *operand)),
        };
        state.set_stmt(expr)
    }

    let mut state = CompileState::default();
    let result = emit(&mut state, ast);
    let mut lines: Vec<String> = state
        .statements
        .into_iter()
        .map(|stmt| format!("{};", stmt))
        .collect();
    lines.push(result);
    lines.join("\n")
}
fn main() {
let ast = {
use MyAST::*;
Add(
(),
Box::new(String((), "Result: ".to_string())),
Box::new(ToString(
(),
Box::new(Subtract(
(),
Box::new(Number((), 10)),
Box::new(Number((), 5)),
)),
)),
)
};
let tc = typecheck(ast).expect("Type checking failed");
let opt = optimize(tc.1);
let compiled_code = compile(opt);
println!("{}", compiled_code);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment