@nihalpasham
Created August 24, 2025 06:15
Luminal - a deep learning framework in Rust.

High-level overview focusing on graph structure, search-based compilation, and visualization:

Luminal Architecture Overview

1. Core Graph Structure

Luminal is built around a directed acyclic graph (DAG) representation. All computation and all tensor data live in a single Graph object:

/// A Luminal compute graph.
///
/// All computation is represented as a directed acyclic graph.
/// All data is stored inside this object as well.
#[derive(Debug, Default)]
pub struct Graph {
    /// The store of tensors in the graph. Indexed by node index and output index.
    pub tensors: FxHashMap<(NodeIndex, u8), Tensor>,
    /// A map of dynamic dimensions to concrete dimension sizes
    pub dyn_map: FxHashMap<char, usize>,
    /// Edge weights: (Input index, Output index, Input shape)
    pub graph: StorageGraph,
    /// Tensors marked in this set will not get deleted when the graph is run
    pub no_delete: FxHashSet<NodeIndex>,
    // ... (remaining fields omitted in this excerpt)
}

Key Components:

  • Nodes: Represent operations (ops), such as element-wise Add and Mul, reductions like SumReduce, and unary functions like Exp2 and Sqrt
  • Edges: Connect a producer to a consumer and carry the input index, output index, and input shape
  • GraphTensor: High-level interface for building computation graphs.
/// A tensor on the graph.
///
/// Graphs can be built by performing operations on these tensors.
#[derive(Clone, Copy)]
pub struct GraphTensor {
    pub id: NodeIndex,
    pub graph_ref: *mut Graph,
    pub shape: ShapeTracker,
}
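To make the design concrete, here is a minimal, hypothetical sketch of the same idea in safe Rust (ToyGraph, Op and Handle are illustrative names, not Luminal types). Ops are appended to an arena in creation order, so node ids already form a valid topological order, and tensors are cheap Copy handles holding a node id. Unlike GraphTensor, which stores a raw *mut Graph so that operators like a + b can reach the graph, this sketch passes the graph explicitly.

```rust
/// Hypothetical primitive ops for a toy compute graph (not Luminal's types).
#[derive(Debug, Clone, Copy)]
enum Op {
    Input(f32),
    Add(usize, usize),
    Mul(usize, usize),
}

/// Arena of nodes; a node's position in `nodes` is its id.
#[derive(Default)]
struct ToyGraph {
    nodes: Vec<Op>,
}

/// A cheap, Copy handle to a node, similar in spirit to GraphTensor.
#[derive(Debug, Clone, Copy)]
struct Handle {
    id: usize,
}

impl ToyGraph {
    fn push(&mut self, op: Op) -> Handle {
        self.nodes.push(op);
        Handle { id: self.nodes.len() - 1 }
    }
    fn input(&mut self, v: f32) -> Handle {
        self.push(Op::Input(v))
    }
    fn add(&mut self, a: Handle, b: Handle) -> Handle {
        self.push(Op::Add(a.id, b.id))
    }
    fn mul(&mut self, a: Handle, b: Handle) -> Handle {
        self.push(Op::Mul(a.id, b.id))
    }
    /// Evaluate every node in id order. An op can only reference nodes that
    /// already exist, so id order is a valid topological order of the DAG.
    fn execute(&self) -> Vec<f32> {
        let mut values: Vec<f32> = Vec::with_capacity(self.nodes.len());
        for op in &self.nodes {
            let v = match *op {
                Op::Input(x) => x,
                Op::Add(a, b) => values[a] + values[b],
                Op::Mul(a, b) => values[a] * values[b],
            };
            values.push(v);
        }
        values
    }
}
```

Building (a + b) * a with a = 2 and b = 3 records four nodes, and executing yields 5 for the Add node and 10 for the Mul node.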

2. Search-Based Compilation Architecture

Luminal's key innovation is search-based compilation, which automatically discovers fast implementations by exploring many equivalent versions of a program:

Phase 1: Graph Translation

The high-level graph is translated into a lower-level representation with loop structures:

pub fn translate_graph(
    graph: &Graph,
) -> (StableGraph<GraphTerm, (), Directed>, Vec<(String, f32)>) {
    let mut new_graph = StableGraph::new();
    let mut node_mapping = FxHashMap::default();
    let mut accumulators = vec![];
    for node in toposort(&graph.graph, None).unwrap() {
        let node_weight = graph.node_weight(node).unwrap();
        let op_name_full = format!("{node_weight:?}");
        let op = op_name_full
            .split('|')
            .next()
            .unwrap_or(&op_name_full)
            .trim();
        let mut sources = graph.get_sources(node);
        match op {
            "Sqrt" | "Exp2" | "Sin" | "Contiguous" => {
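translate_graph walks the nodes in topological order (via toposort) so that every op is lowered after its inputs. As a dependency-free illustration of what that ordering guarantees, here is a sketch of Kahn's algorithm over a plain adjacency list. It stands in for the graph library's toposort and is not Luminal code.

```rust
use std::collections::VecDeque;

/// Kahn's algorithm: returns the nodes so that every edge u -> v has u before v,
/// or None if the graph contains a cycle (i.e. it is not a DAG).
fn toposort(num_nodes: usize, edges: &[(usize, usize)]) -> Option<Vec<usize>> {
    let mut indegree = vec![0usize; num_nodes];
    let mut successors = vec![Vec::new(); num_nodes];
    for &(u, v) in edges {
        successors[u].push(v);
        indegree[v] += 1;
    }
    // Start from nodes with no inputs (e.g. loaded tensors).
    let mut ready: VecDeque<usize> = (0..num_nodes).filter(|&n| indegree[n] == 0).collect();
    let mut order = Vec::with_capacity(num_nodes);
    while let Some(n) = ready.pop_front() {
        order.push(n);
        for &s in &successors[n] {
            indegree[s] -= 1;
            if indegree[s] == 0 {
                ready.push_back(s);
            }
        }
    }
    // If some node never reached indegree 0, it sits on a cycle.
    if order.len() == num_nodes {
        Some(order)
    } else {
        None
    }
}
```

For a graph a, b -> Mul -> SumReduce, every input appears before the Mul and the Mul before the SumReduce; a two-node cycle returns None.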

Phase 2: E-graph Construction

The system builds an e-graph (equality graph) that compactly represents the equivalent programs reachable by applying the rewrite rules for a given number of iterations (iters):

pub fn build_search_space(
    graph: &StableGraph<GraphTerm, (), Directed>,
    iters: usize,
    remove_tiling: bool,
) -> egraph_serialize::EGraph {
    let (rendered, root) = render_egglog(graph);
    if option_env!("PRINT_KERNELS").is_some() {
        println!("{rendered}");
    }
    let code = include_str!("code.lisp");

    let mut final_code = code
        .replace("{code}", &rendered)
        .replace("{iters}", &iters.to_string());
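The payoff of an e-graph is compactness. Each e-class holds alternative e-nodes that compute the same value, and e-nodes point at child e-classes rather than concrete subterms, so a few classes can stand for many equivalent programs. Here is a minimal, hypothetical sketch of that structure (EClass and ENode are illustrative, not the egglog or egraph_serialize types), with a function counting how many distinct programs an acyclic e-graph represents.

```rust
use std::collections::HashMap;

/// One operator whose children are e-class ids (illustrative type).
struct ENode {
    op: &'static str,
    children: Vec<usize>,
}

/// An e-class: a set of e-nodes known to be equivalent.
struct EClass {
    nodes: Vec<ENode>,
}

/// Number of distinct programs rooted at `class`, assuming no cycles:
/// the sum over alternative e-nodes of the product of their children's counts.
fn count_programs(classes: &[EClass], class: usize, memo: &mut HashMap<usize, u64>) -> u64 {
    if let Some(&c) = memo.get(&class) {
        return c;
    }
    let mut total = 0;
    for node in &classes[class].nodes {
        let mut ways = 1;
        for &child in &node.children {
            ways *= count_programs(classes, child, memo);
        }
        total += ways;
    }
    memo.insert(class, total);
    total
}
```

With a class holding two ways to double x ("mul2" and "add x x") and a root class offering add(c, c) or mul2(c) on top of it, three classes encode six programs; the count grows multiplicatively with depth while the e-graph grows only additively.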

Phase 3: Search and Evaluation

The system then searches through the e-graph, compiling and timing candidate programs to find the fastest one:

pub fn search(
    egraph: &EGraph,
    inputs: &[(NodeIndex, Vec<f32>)],
    arch: GPUArch,
) -> Option<StableGraph<Kernel, (u8, u8)>> {
    fn recurse<'a>(
        egraph: &'a EGraph,
        current_class: &'a ClassId,
        seen: &mut FxHashMap<&'a NodeId, usize>,
    ) -> (Vec<Vec<&'a NodeId>>, usize) {
        let mut trajectories = vec![];
        let mut total_completed = 0;
        'enode_loop: for enode in &egraph.classes()[current_class].nodes {
            if total_completed >= MAX_SEARCHED_GRAPHS {
                break;
            }

Search Process:

  1. DFS Traversal: Explores different paths through the e-graph
  2. Code Generation: Each path is compiled to GPU kernels
  3. Performance Evaluation: Kernels are executed and timed
  4. Best Selection: The fastest implementation is chosen
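The four steps can be sketched, under simplifying assumptions, as a capped DFS that picks one e-node per visited e-class and keeps the cheapest candidate. In this sketch measure is a hypothetical stand-in for "compile this candidate to kernels and time it", and MAX_CANDIDATES plays the role MAX_SEARCHED_GRAPHS plays above. For simplicity, a shared e-class may be chosen differently at each use. This is not Luminal's implementation.

```rust
/// Each e-class is a list of alternative e-nodes; each e-node lists its child classes.
type EGraph = Vec<Vec<Vec<usize>>>;

/// Cap on enumerated candidates (analogous to MAX_SEARCHED_GRAPHS).
const MAX_CANDIDATES: usize = 1000;

/// DFS: every way to pick one e-node per class reachable from `class`.
/// A candidate is the list of (class, chosen e-node) pairs in DFS order.
fn enumerate(egraph: &EGraph, class: usize) -> Vec<Vec<(usize, usize)>> {
    let mut out = Vec::new();
    for (alt, children) in egraph[class].iter().enumerate() {
        // Start with this choice, then extend with every option for each child.
        let mut partial = vec![vec![(class, alt)]];
        for &child in children {
            let child_options = enumerate(egraph, child);
            let mut next = Vec::new();
            for p in &partial {
                for c in &child_options {
                    if next.len() >= MAX_CANDIDATES {
                        break;
                    }
                    let mut joined = p.clone();
                    joined.extend_from_slice(c);
                    next.push(joined);
                }
            }
            partial = next;
        }
        for p in partial {
            if out.len() >= MAX_CANDIDATES {
                return out;
            }
            out.push(p);
        }
    }
    out
}

/// Evaluate every candidate with `measure` (e.g. measured runtime) and keep the cheapest.
fn search<F: Fn(&[(usize, usize)]) -> f64>(
    egraph: &EGraph,
    root: usize,
    measure: F,
) -> Option<(Vec<(usize, usize)>, f64)> {
    let mut best: Option<(Vec<(usize, usize)>, f64)> = None;
    for cand in enumerate(egraph, root) {
        let cost = measure(&cand);
        if best.as_ref().map_or(true, |(_, b)| cost < *b) {
            best = Some((cand, cost));
        }
    }
    best
}
```

Passing the cost function in keeps the search logic separate from the backend: the same loop works whether measure times a GPU kernel or, as in a test, returns a fixed number.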

3. Optimization Transformations

The system applies various optimizations defined in egglog rules:

; Loop Fusion
(rewrite (LoopIn (LoopOut ?x ?loop ?st) ?loop ?st) ?x
	;:when ((!= ?st (MAccum ?y))) ; don't fuse if we're accumulating that loop
) ; this is causing infinite loops in the e-graph!

; Loop Fission


; Loop tiling
(rewrite
	(LoopOut ?body (Loop ?loop (MNum ?range)) ?stride)
	(LoopOut
		(LoopOut
			(TileLoop ?body ?loop)
			(Loop (+ ?loop "_tile") (MNum 8))
			?stride
		)
		(Loop ?loop (MNum (/ ?range 8)))
		(MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum 8)))
	)
	:when ((> ?range 8) (= (% ?range 8) 0))
)

Key Optimizations:

  • Loop Fusion: Combines adjacent loops to reduce memory traffic
  • Loop Tiling: Breaks large loops into smaller tiles for better cache usage
  • Loop Swapping: Reorders nested loops for better memory access patterns
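As a plain-Rust illustration (not code generated by Luminal) of what the fusion and tiling rules preserve: fusing removes the intermediate buffer between two loops, and tiling by 8, the factor the rule above uses when the range exceeds 8 and is divisible by it, splits one loop into an outer loop over range / 8 tiles and an inner loop of 8 that visits the same indices in the same order.

```rust
/// Unfused: the first loop materializes an intermediate buffer.
fn unfused(x: &[f32]) -> Vec<f32> {
    let tmp: Vec<f32> = x.iter().map(|v| v * 2.0).collect();
    tmp.iter().map(|v| v + 1.0).collect()
}

/// Fused: a single loop, no intermediate buffer.
fn fused(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| v * 2.0 + 1.0).collect()
}

/// Tile size used by the tiling rule above.
const TILE: usize = 8;

/// Visit order of a plain loop over 0..range.
fn plain_indices(range: usize) -> Vec<usize> {
    (0..range).collect()
}

/// Tiled loop: outer over range / TILE tiles, inner over TILE elements.
/// Like the rule's :when guard, this requires range % TILE == 0.
fn tiled_indices(range: usize) -> Vec<usize> {
    assert!(range % TILE == 0, "tiling requires range divisible by TILE");
    let mut out = Vec::with_capacity(range);
    for tile in 0..range / TILE {
        for inner in 0..TILE {
            out.push(tile * TILE + inner);
        }
    }
    out
}
```

Both transformations change only the loop structure, not the result, which is exactly why they can be expressed as equalities in an e-graph.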

4. Visualization Capabilities

Luminal provides several visualization tools for debugging and understanding:

Graph Visualization

/// View a debug graph in the browser
pub fn display_graph<G: TermToString, E: EdgeToString>(
    graph: &StableGraph<G, E, Directed, u32>,
    mark_nodes: &[(NodeIndex, String)],
) {
    let mut new_graph = StableGraph::new();
    let mut map = HashMap::new();
    for node in graph.node_indices() {
        map.insert(
            node,
            new_graph.add_node(graph.node_weight(node).unwrap().term_to_string()),
        );
    }

Visualization Features:

  • Browser-based: Opens graphs in GraphViz Online
  • Node Highlighting: Can mark specific nodes with colors
  • Multiple Graph Types: Supports computation graphs, kernel graphs, and e-graphs
  • Debug Execution: Shows timing information for each operation
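Viewing a graph in GraphViz Online means serializing it to GraphViz DOT text. Here is a minimal, hypothetical sketch of such a serializer (to_dot and its highlighting style are assumptions, not Luminal's exact output), where marked nodes get a fill color.

```rust
use std::collections::HashMap;
use std::fmt::Write;

/// Render node labels and edges as GraphViz DOT; nodes listed in `marked`
/// are drawn filled with the given color.
fn to_dot(labels: &[&str], edges: &[(usize, usize)], marked: &HashMap<usize, &str>) -> String {
    let mut dot = String::from("digraph G {\n");
    for (i, label) in labels.iter().enumerate() {
        match marked.get(&i) {
            Some(color) => {
                writeln!(dot, "  n{i} [label=\"{label}\", style=filled, fillcolor=\"{color}\"];").unwrap()
            }
            None => writeln!(dot, "  n{i} [label=\"{label}\"];").unwrap(),
        }
    }
    for (a, b) in edges {
        writeln!(dot, "  n{a} -> n{b};").unwrap();
    }
    dot.push_str("}\n");
    dot
}
```

The resulting string can be pasted into any DOT viewer; highlighting a node is just an extra attribute on its line.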

Debug Execution

/// Execute the graph with debug prints
pub fn execute_debug(&mut self) {
    fn format_duration(duration: &Duration) -> String {
        if duration.as_secs() > 0 {
            format!("{:.2}s", duration.as_secs_f32())
        } else if duration.as_millis() > 0 {
            format!("{}ms", duration.as_millis())
        } else {
            format!("{}µs", duration.as_micros())
        }
    }
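The format_duration helper above is self-contained. Here it is as a standalone function, paired with a hypothetical time_ops wrapper (not Luminal's API) that times named closures the way a debug run reports per-op durations.

```rust
use std::time::{Duration, Instant};

/// Same rules as the excerpt: seconds if >= 1s, else milliseconds if >= 1ms, else microseconds.
fn format_duration(duration: &Duration) -> String {
    if duration.as_secs() > 0 {
        format!("{:.2}s", duration.as_secs_f32())
    } else if duration.as_millis() > 0 {
        format!("{}ms", duration.as_millis())
    } else {
        format!("{}µs", duration.as_micros())
    }
}

/// Run each (name, op) in order and return one "name: time" report line per op.
fn time_ops(ops: Vec<(&str, Box<dyn FnMut()>)>) -> Vec<String> {
    let mut report = Vec::new();
    for (name, mut op) in ops {
        let start = Instant::now();
        op();
        report.push(format!("{name}: {}", format_duration(&start.elapsed())));
    }
    report
}
```

For example, 1500 ms formats as "1.50s", 42 ms as "42ms", and 250 µs as "250µs".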

5. Multi-Backend Support

Luminal supports multiple execution backends:

  • CPU: Basic CPU execution
  • Metal: Apple's Metal for GPU acceleration
  • CUDA: NVIDIA GPU support (in development)

6. Getting Started with Visualization

To visualize graphs in your code:

  1. Basic Graph Visualization:
use luminal::prelude::*;

let mut cx = Graph::new();
let a = cx.tensor((3, 1)).set([[1.0], [2.0], [3.0]]);
let b = cx.tensor((1, 4)).set([[1.0, 2.0, 3.0, 4.0]]);
let c = a.matmul(b).retrieve();

// Uncomment to visualize the computation graph
// luminal::compiler_utils::display_graph(&cx.graph, &[], &[]);
  2. Search-Based Compilation Visualization:
// Translate to lower-level representation
let (new_graph, _) = luminal::search::translate_graph(&cx);
// Visualize the translated graph
luminal::search::display_graph_2(&new_graph, &[]);
  3. Debug Execution:
cx.execute_debug(); // Shows timing for each operation

Search-based compilation is Luminal's core innovation: it automatically discovers fast GPU kernel implementations by exploring a space of equivalent program transformations and empirically measuring their performance.

The following walks through a concrete linear layer example step by step.

Linear Layer Breakdown

1. Weight Matrix Organization

In this example:

  • Input size: 4 (4 input features)
  • Output size: 5 (5 output features)
  • Total weights: 20 (4 × 5 = 20)
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
    Self {
        weight: cx.named_tensor("Weight", (inp, out)),
        bias: if bias {
            Some(cx.named_tensor("Bias", out))
        } else {
            None
        },
        permute: false,
    }
}

The weight matrix has shape (4, 5) - 4 rows, 5 columns:

Weight Matrix (4×5):
[w₀₀  w₀₁  w₀₂  w₀₃  w₀₄]
[w₁₀  w₁₁  w₁₂  w₁₃  w₁₄]  
[w₂₀  w₂₁  w₂₂  w₂₃  w₂₄]
[w₃₀  w₃₁  w₃₂  w₃₃  w₃₄]

20 weights are stored row-major as: [w₀₀, w₀₁, w₀₂, w₀₃, w₀₄, w₁₀, w₁₁, ...]
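Row-major storage means element (row, col) of a matrix with cols columns lives at flat index row * cols + col. A small sketch for the 4×5 layout above (the helper names are illustrative):

```rust
/// Flat index of (row, col) in a row-major matrix with `cols` columns.
fn row_major_index(row: usize, col: usize, cols: usize) -> usize {
    row * cols + col
}

/// Read element (row, col) from a flat row-major buffer.
fn get(weights: &[f32], row: usize, col: usize, cols: usize) -> f32 {
    weights[row_major_index(row, col, cols)]
}
```

So w₁₂ is the 8th stored value (index 7), and w₃₄, the last element, sits at index 19.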

2. Forward Pass Math

fn forward(&self, input: GraphTensor) -> Self::Output {
    let mut output = input.matmul(if self.permute {
        self.weight.permute((1, 0))
    } else {
        self.weight
    });
    if let Some(bias) = self.bias {
        output += bias.expand(output.shape);
    }
    output
}

The forward pass performs output = input @ weight (matrix multiplication), plus the broadcast bias when one is configured. The numbers below assume no bias.

With our data:

  • Input: [1.0, 2.0, 3.0, 4.0] (shape: (4,))
  • Weights: 20 values in (4, 5) matrix
  • Output: 5 values (shape: (5,))

3. Matrix Multiplication Details

pub fn matmul(mut self, mut rhs: GraphTensor) -> Self {
    if (self.shape.len() == 1 || self.shape.len() == 2) && rhs.shape.len() == 2 {
        let vec = self.shape.len() == 1;
        if vec {
            self = self.expand_dim(0, 1);
        }
        let (m, _) = self.dims2();
        let (_, n) = rhs.dims2();
        // Broadcasted Multiply
        let mul = self.expand_dim(1, n) * rhs.permute((1, 0)).expand_dim(0, m);

        // Sum Reduce
        let mut ret = mul.sum(2);
        if vec {
            ret = ret.reshape(ret.dims().last().unwrap());
        }
        ret

Step-by-step computation:

  1. Input vector: [1.0, 2.0, 3.0, 4.0] is treated as (1, 4) matrix
  2. Weight matrix: (4, 5)
  3. Matrix multiplication: (1, 4) @ (4, 5) = (1, 5)
  4. Result: Reshaped back to (5,) vector

Mathematical formula for each output element:

output[0] = 1.0*w₀₀ + 2.0*w₁₀ + 3.0*w₂₀ + 4.0*w₃₀
output[1] = 1.0*w₀₁ + 2.0*w₁₁ + 3.0*w₂₁ + 4.0*w₃₁  
output[2] = 1.0*w₀₂ + 2.0*w₁₂ + 3.0*w₂₂ + 4.0*w₃₂
output[3] = 1.0*w₀₃ + 2.0*w₁₃ + 3.0*w₂₃ + 4.0*w₃₃
output[4] = 1.0*w₀₄ + 2.0*w₁₄ + 3.0*w₂₄ + 4.0*w₃₄

4. Concrete Example

Suppose the 20 weights are:

[0.1, 0.2, 0.3, 0.4, 0.5,    // Row 0: weights for input[0]
 0.6, 0.7, 0.8, 0.9, 1.0,    // Row 1: weights for input[1]  
 1.1, 1.2, 1.3, 1.4, 1.5,    // Row 2: weights for input[2]
 1.6, 1.7, 1.8, 1.9, 2.0]    // Row 3: weights for input[3]

Then:

output[0] = 1.0*0.1 + 2.0*0.6 + 3.0*1.1 + 4.0*1.6 = 0.1 + 1.2 + 3.3 + 6.4 = 11.0
output[1] = 1.0*0.2 + 2.0*0.7 + 3.0*1.2 + 4.0*1.7 = 0.2 + 1.4 + 3.6 + 6.8 = 12.0
output[2] = 1.0*0.3 + 2.0*0.8 + 3.0*1.3 + 4.0*1.8 = 0.3 + 1.6 + 3.9 + 7.2 = 13.0
output[3] = 1.0*0.4 + 2.0*0.9 + 3.0*1.4 + 4.0*1.9 = 0.4 + 1.8 + 4.2 + 7.6 = 14.0
output[4] = 1.0*0.5 + 2.0*1.0 + 3.0*1.5 + 4.0*2.0 = 0.5 + 2.0 + 4.5 + 8.0 = 15.0
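The same arithmetic as a plain-Rust reference computation (not Luminal code), with the weights stored row-major as (inp, out) and no bias:

```rust
/// y[j] = sum over i of x[i] * w[i][j], with w stored row-major as (inp, out).
fn linear_forward(x: &[f32], w: &[f32], out: usize) -> Vec<f32> {
    assert_eq!(w.len(), x.len() * out, "weight must have inp * out entries");
    (0..out)
        .map(|j| x.iter().enumerate().map(|(i, xi)| xi * w[i * out + j]).sum::<f32>())
        .collect()
}
```

With input [1, 2, 3, 4] and the weights 0.1 through 2.0 above, this returns [11, 12, 13, 14, 15] up to floating-point rounding.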

5. What the Graph Shows

When you call cx.display(), you'll see nodes representing:

  1. Input tensor loading (your [1,2,3,4])
  2. Weight tensor loading (your 20 weights)
  3. The matrix multiplication, which appears as a Mul node followed by a SumReduce node (explained below)
  4. Output tensor

The linear layer is essentially a learned transformation that maps 4-dimensional input space to 5-dimensional output space using the weight matrix as the transformation parameters.

Internal implementation:

The internal implementation of matrix multiplication explains why the graph shows both Mul and SumReduce operations instead of a single matmul node:

Matrix Multiplication Implementation

Luminal implements matrix multiplication as a two-step process:

  1. Broadcasted Element-wise Multiplication (Mul)
  2. Sum Reduction (SumReduce)
pub fn matmul(mut self, mut rhs: GraphTensor) -> Self {
    if (self.shape.len() == 1 || self.shape.len() == 2) && rhs.shape.len() == 2 {
        let vec = self.shape.len() == 1;
        if vec {
            self = self.expand_dim(0, 1);
        }
        let (m, _) = self.dims2();
        let (_, n) = rhs.dims2();
        // Broadcasted Multiply
        let mul = self.expand_dim(1, n) * rhs.permute((1, 0)).expand_dim(0, m);

        // Sum Reduce
        let mut ret = mul.sum(2);

Step-by-Step Breakdown

Our Case: Vector × Matrix

  • Input: [1, 2, 3, 4] shape (4,) → expanded to (1, 4)
  • Weights: (4, 5) matrix → transposed to (5, 4)

Step 1: Broadcasting and Expansion

Input (1, 4):     [[1, 2, 3, 4]]
                  ↓ expand_dim(1, 5)
Input (1, 5, 4):  [[[1, 2, 3, 4],
                    [1, 2, 3, 4],
                    [1, 2, 3, 4],
                    [1, 2, 3, 4],
                    [1, 2, 3, 4]]]

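Weights (5, 4):   [[w₀₀, w₁₀, w₂₀, w₃₀],
                   [w₀₁, w₁₁, w₂₁, w₃₁],
                   [w₀₂, w₁₂, w₂₂, w₃₂],
                   [w₀₃, w₁₃, w₂₃, w₃₃],
                   [w₀₄, w₁₄, w₂₄, w₃₄]]
                  ↓ expand_dim(0, 1)
Weights (1, 5, 4): Same as above but with a leading batch dimension of size 1

(This is the transpose of the original 4×5 matrix, so row i holds the four weights that feed output[i].)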

Step 2: Element-wise Multiplication (Mul)

This creates a (1, 5, 4) tensor where each element is:

result[0, i, j] = input_expanded[0, i, j] * weights_expanded[0, i, j] = input[j] * wⱼᵢ

Output shape: [1, 5, 4]. This is the shape shown for the Mul node in the debug output.

Step 3: Sum Reduction (SumReduce(2))

The SumReduce(2) means "sum along dimension 2" (the last dimension):

For each output[i]:
  output[i] = sum(result[0, i, :])
            = result[0, i, 0] + result[0, i, 1] + result[0, i, 2] + result[0, i, 3]

This gives you the final (1, 5) result, which gets reshaped to (5,).
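The two steps can be reproduced on flat arrays to check the shapes and indices. This sketch (plain Rust, not Luminal) materializes the broadcast (m, n, k) product tensor explicitly, which Luminal's graph does not need to do, since expand_dim only changes the shape. It then sums over the last axis, mirroring Mul followed by SumReduce(2).

```rust
/// a: (m, k) row-major, b: (k, n) row-major. Returns the (m, n, k) tensor
/// prod[r][c][i] = a[r][i] * b[i][c]: a broadcast over n, times bᵀ broadcast over m.
fn broadcast_mul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut prod = vec![0.0; m * n * k];
    for r in 0..m {
        for c in 0..n {
            for i in 0..k {
                prod[(r * n + c) * k + i] = a[r * k + i] * b[i * n + c];
            }
        }
    }
    prod
}

/// Sum over the last axis of a (.., k) tensor: mirrors SumReduce(2) on (m, n, k).
fn sum_reduce_last(t: &[f32], k: usize) -> Vec<f32> {
    t.chunks(k).map(|row| row.iter().sum::<f32>()).collect()
}
```

For the running example (m = 1, k = 4, n = 5), the product tensor has 20 entries and the reduction leaves the 5 outputs.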

Why This Approach?

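This broadcast + reduce pattern is used because:

  1. Small Primitive Op Set: Luminal expresses higher-level ops in terms of a handful of primitive ops (here Mul and SumReduce), so there is no dedicated matmul primitive that every backend must implement
  2. Cheap Broadcasting: expand_dim only records a broadcast ("fake") dimension in the ShapeTracker rather than copying data
  3. Parallelization: Each element-wise multiplication is independent, and the reduction is a standard parallel pattern
  4. Compiler Optimization: Backend compilers can recognize the Mul -> SumReduce pattern and replace it with a dedicated matmul kernel (see MetalMatMulCompiler below)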

Visual Representation

Input:    [1, 2, 3, 4]
Weights:  [[0.1, 0.2, 0.3, 0.4, 0.5],
           [0.6, 0.7, 0.8, 0.9, 1.0],
           [1.1, 1.2, 1.3, 1.4, 1.5],
           [1.6, 1.7, 1.8, 1.9, 2.0]]

After broadcasting and Mul:
[[[1*0.1, 2*0.6, 3*1.1, 4*1.6],    # For output[0]
  [1*0.2, 2*0.7, 3*1.2, 4*1.7],    # For output[1]
  [1*0.3, 2*0.8, 3*1.3, 4*1.8],    # For output[2]
  [1*0.4, 2*0.9, 3*1.4, 4*1.9],    # For output[3]
  [1*0.5, 2*1.0, 3*1.5, 4*2.0]]]   # For output[4]

After SumReduce(2):
[0.1+1.2+3.3+6.4,    # = 11.0
 0.2+1.4+3.6+6.8,    # = 12.0
 0.3+1.6+3.9+7.2,    # = 13.0
 0.4+1.8+4.2+7.6,    # = 14.0
 0.5+2.0+4.5+8.0]    # = 15.0

This is exactly the mathematical definition of matrix multiplication, expressed as separate broadcast-multiply and sum-reduce primitives; backend compilers can then replace the pair with an optimized matmul kernel.

MetalCompiler is a type alias for a nested tuple of individual compilers, and the Compiler trait is implemented for tuples of compilers.

MetalCompiler Definition

/// Compile graphs to run on Metal-supported macOS devices in supported data formats
pub type MetalCompiler<T> = (MetalCompilerPreBuffer<T>, BufferCompilers);

/// All metal compilers coming before buffer compilers
pub type MetalCompilerPreBuffer<T> = (
    prim::PrimitiveCompiler<T>,
    SpecialOpsCompiler<T>,
    other::CopyCompiler<T>,
    elementwise_fusion::ElementwiseFusionCompiler<T>,
);

/// Compilers to share command and storage buffers
pub type BufferCompilers = (
    command_buffer::CommandBufferCompiler,
    storage_buffer::StorageBufferCompiler,
);

/// Compiler to replace metal ops with specialized variants
pub type SpecialOpsCompiler<T> = (
    binary::MetalSubtractionCompiler<T>,
    binary::MetalEqualCompiler<T>,
    other::ARangeCompiler<T>,
    binary::MetalGatherCompiler<T>,
    unary::MetalExpCompiler<T>,
    unary::MetalCosCompiler<T>,
    unary::MeanReduceCompiler<T>,
    unary::StdNormCompiler<T>,
    matmul::MetalMatMulCompiler<T>,
);

Tuple Compilation

The Compiler trait is implemented for tuples in the core library.

Compiler Trait Definition

pub trait Compiler {
    type Output;
    /// Run a compilation pass
    fn compile<T: ToIdsMut>(&self, graph: &mut Graph, ids: T) -> Self::Output;
}

Tuple Implementation of Compiler

The magic happens with macro-generated tuple implementations:

macro_rules! tuple_impls {
    ([$($name:ident),+] , [$($idx:tt),+]) => {
        impl<
        $($name:
            Compiler, )+
        > Compiler for ($($name,)+) {
            type Output = ( $($name::Output, )+ );
            fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut remap: T) -> Self::Output {
                ( $(self.$idx.compile(graph, &mut remap), )+ )
            }
        }
    };
}

tuple_impls!([M1], [0]);
tuple_impls!([M1, M2], [0, 1]);
tuple_impls!([M1, M2, M3], [0, 1, 2]);
// ... up to 20 elements
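The pattern is easy to reproduce outside Luminal. Here is a minimal sketch with a hypothetical Pass trait (standing in for Compiler) and a log vector standing in for the graph. Each tuple runs its elements left to right, and because a tuple of passes is itself a pass, tuples nest, just like MetalCompiler's (MetalCompilerPreBuffer, BufferCompilers).

```rust
/// Hypothetical stand-in for Luminal's Compiler trait.
trait Pass {
    type Output;
    fn run(&self, log: &mut Vec<String>) -> Self::Output;
}

/// A named pass that only records that it ran.
struct Named(&'static str);

impl Pass for Named {
    type Output = ();
    fn run(&self, log: &mut Vec<String>) {
        log.push(self.0.to_string());
    }
}

/// Same shape as tuple_impls!: run each element in order and collect the outputs.
macro_rules! tuple_impls {
    ([$($name:ident),+], [$($idx:tt),+]) => {
        impl<$($name: Pass),+> Pass for ($($name,)+) {
            type Output = ($($name::Output,)+);
            fn run(&self, log: &mut Vec<String>) -> Self::Output {
                ($(self.$idx.run(log),)+)
            }
        }
    };
}

tuple_impls!([A], [0]);
tuple_impls!([A, B], [0, 1]);
tuple_impls!([A, B, C], [0, 1, 2]);
```

Running ((prim, special), (command_buffer, storage_buffer)) logs the four names in exactly that order, which is the sequencing the MetalCompiler tuple relies on.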

How MetalCompiler Works

So when we call:

cx.compile(MetalCompiler::<f32>::default(), &mut b);

Here's what happens:

  1. MetalCompiler expands to:
(
    (
        prim::PrimitiveCompiler<f32>,
        SpecialOpsCompiler<f32>,
        other::CopyCompiler<f32>,
        elementwise_fusion::ElementwiseFusionCompiler<f32>,
    ),
    (
        command_buffer::CommandBufferCompiler,
        storage_buffer::StorageBufferCompiler,
    )
)
  2. The tuple implementation calls each compiler in sequence:
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut remap: T) -> Self::Output {
    (
        self.0.compile(graph, &mut remap),  // MetalCompilerPreBuffer
        self.1.compile(graph, &mut remap),  // BufferCompilers
    )
}
  3. Each nested tuple also gets compiled in sequence, so the full order is:
    • PrimitiveCompiler (converts basic ops to Metal ops)
    • SpecialOpsCompiler (handles matmul, subtraction, etc.)
    • CopyCompiler (manages device memory transfers)
    • ElementwiseFusionCompiler (fuses element-wise operations)
    • CommandBufferCompiler (optimizes command buffer usage)
    • StorageBufferCompiler (optimizes storage buffer usage)

Individual Compiler Examples

Each individual compiler implements the trait. For example:

impl<T: MetalFloat> Compiler for MetalMatMulCompiler<T> {
    type Output = ();
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
        let dev = Device::system_default().unwrap();
        let queue = dev.new_command_queue();

        // Look for the matmul pattern
        // Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]

This design makes compilation passes composable: different compilers can be mixed and matched, and they run in sequence, each progressively optimizing the graph.

SelectGraph and Pattern Matching in Luminal’s Composable Compilers

What is SelectGraph?

SelectGraph is a pattern matching tool that helps compilers find specific operation sequences in the computation graph. Think of it as a "template" that describes what operations to look for.

Example: How it works in MetalMatMulCompiler:

  1. Define the Pattern: The compiler creates SelectGraph objects that describe the matrix multiplication pattern:

    let mut mul2d = op::<MetalMul<T>>();  // Look for Metal multiplication
    let mut sr2d = op::<MetalSumReduce<T>>(); // Look for sum reduction
  2. Set Constraints: It adds constraints like shapes and which dimensions are "fake" (broadcasted):

    mul2d.shapes([['M', 'N', 'K'], ['M', 'N', 'K']]);  // Expected shapes
    mul2d.fakes([[None, Some(true), Some(false)], ...]);  // Which dims are broadcasted
  3. Connect the Pattern: It connects the operations to form the complete pattern:

    let mut s2d = mul2d.clone().connect(sr2d.clone()).search(graph);

    This creates a pattern: Mul -> SumReduce (which is how matrix multiplication is implemented)

  4. Search for Matches: The search() method returns a GraphSearch object that can find all instances of this pattern in the graph:

    while s2d.next_match() {  // Find next occurrence of the pattern
        let (mul, sum_reduce) = (s2d.get(&mul2d), s2d.get(&sr2d));  // Get the actual nodes
        // Replace with optimized matmul kernel
    }

Why This is Useful:

  • Pattern Recognition: Automatically finds Mul + SumReduce sequences that represent matrix multiplication
  • Optimization: Replaces these inefficient sequences with a single optimized GEMM kernel
  • Multiple Dimensions: Handles 2D, 3D, 4D, and 5D cases with different patterns

In Simple Terms:

SelectGraph is like a "find and replace" tool for computation graphs. It says: "Find all places where you see Mul -> SumReduce with these specific shape patterns, and replace them with a single optimized matrix multiplication operation."

This is how Luminal automatically converts your high-level a.matmul(b) into optimized GPU kernels!
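A toy version of that find-and-replace idea, assuming a hypothetical graph of op nodes with input lists (this is not Luminal's SelectGraph API, and unlike MetalMatMulCompiler it ignores shapes, fake dimensions, and the reduce axis): find every SumReduce whose single input is a Mul used by nothing else, and rewrite the pair into one MatMul node that takes the Mul's inputs.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Input(&'static str),
    Mul,
    SumReduce,
    MatMul,
    Removed,
}

#[derive(Debug, Clone)]
struct Node {
    op: Op,
    inputs: Vec<usize>,
}

/// Rewrite every Mul -> SumReduce pair into one MatMul node, in place.
/// Returns how many matches were replaced.
fn fuse_matmul(nodes: &mut [Node]) -> usize {
    let mut replaced = 0;
    for sr in 0..nodes.len() {
        if nodes[sr].op != Op::SumReduce || nodes[sr].inputs.len() != 1 {
            continue;
        }
        let mul = nodes[sr].inputs[0];
        if nodes[mul].op != Op::Mul {
            continue;
        }
        // Only fuse if the Mul result is not consumed anywhere else.
        let uses = nodes.iter().filter(|n| n.inputs.contains(&mul)).count();
        if uses != 1 {
            continue;
        }
        // The SumReduce node becomes the MatMul, so downstream users stay valid.
        nodes[sr].op = Op::MatMul;
        nodes[sr].inputs = nodes[mul].inputs.clone();
        nodes[mul].op = Op::Removed;
        nodes[mul].inputs.clear();
        replaced += 1;
    }
    replaced
}
```

Reusing the SumReduce node's slot for the MatMul means any node that consumed the old output keeps pointing at the right place, which is the same concern the ids/remap argument addresses in the real compilers.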
