High-level overview focusing on graph structure, search-based compilation, and visualization:
Luminal is built around a directed acyclic graph (DAG) representation where:
/// A Luminal compute graph.
///
/// All computation is represented as a directed acyclic graph.
/// All data is stored inside this object as well.
#[derive(Debug, Default)]
pub struct Graph {
/// The store of tensors in the graph. Indexed by node index and output index.
pub tensors: FxHashMap<(NodeIndex, u8), Tensor>,
/// A map of dynamic dimensions to concrete dimension sizes
pub dyn_map: FxHashMap<char, usize>,
/// Edge weights: (Input index, Output index, Input shape)
pub graph: StorageGraph,
/// Tensors marked in this set will not get deleted when the graph is ran
    pub no_delete: FxHashSet<NodeIndex>,

Key Components:
- Nodes: Represent operations (ops) like matrix multiplication, unary functions, etc.
- Edges: Carry tensor data between operations with shape information
- GraphTensor: High-level interface for building computation graphs.
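For instance, a few GraphTensor operations are enough to build a small DAG. This is a minimal sketch; the cx.tensor(3), set(vec![...]), and data() calls are assumptions based on the matmul example later in this document, and the values are illustrative:

use luminal::prelude::*;

fn main() {
    let mut cx = Graph::new();
    let a = cx.tensor(3).set(vec![1.0, 2.0, 3.0]); // node: data load
    let b = cx.tensor(3).set(vec![4.0, 5.0, 6.0]); // node: data load
    let c = (a + b).retrieve();                    // node: Add op; edges carry shape info
    cx.execute();
    println!("{:?}", c.data());                    // [5.0, 7.0, 9.0]
}

The GraphTensor handle behind these operations is just a node id, a raw pointer to the owning graph, and a shape tracker: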
/// A tensor on the graph.
///
/// Graphs can be built by performing operations on these tensors.
#[derive(Clone, Copy)]
pub struct GraphTensor {
pub id: NodeIndex,
pub graph_ref: *mut Graph,
pub shape: ShapeTracker,
}

Luminal's key innovation is search-based compilation that automatically discovers optimal implementations:
The high-level graph is translated into a lower-level representation with loop structures:
pub fn translate_graph(
graph: &Graph,
) -> (StableGraph<GraphTerm, (), Directed>, Vec<(String, f32)>) {
let mut new_graph = StableGraph::new();
let mut node_mapping = FxHashMap::default();
let mut accumulators = vec![];
for node in toposort(&graph.graph, None).unwrap() {
let node_weight = graph.node_weight(node).unwrap();
let op_name_full = format!("{node_weight:?}");
let op = op_name_full
.split('|')
.next()
.unwrap_or(&op_name_full)
.trim();
let mut sources = graph.get_sources(node);
match op {
"Sqrt" | "Exp2" | "Sin" | "Contiguous" => {The system builds an e-graph (equality graph) that represents all possible equivalent transformations:
pub fn build_search_space(
graph: &StableGraph<GraphTerm, (), Directed>,
iters: usize,
remove_tiling: bool,
) -> egraph_serialize::EGraph {
let (rendered, root) = render_egglog(graph);
if option_env!("PRINT_KERNELS").is_some() {
println!("{rendered}");
}
let code = include_str!("code.lisp");
let mut final_code = code
.replace("{code}", &rendered)
.replace("{iters}", &iters.to_string());The system searches through the e-graph to find optimal implementations:
pub fn search(
egraph: &EGraph,
inputs: &[(NodeIndex, Vec<f32>)],
arch: GPUArch,
) -> Option<StableGraph<Kernel, (u8, u8)>> {
fn recurse<'a>(
egraph: &'a EGraph,
current_class: &'a ClassId,
seen: &mut FxHashMap<&'a NodeId, usize>,
) -> (Vec<Vec<&'a NodeId>>, usize) {
let mut trajectories = vec![];
let mut total_completed = 0;
'enode_loop: for enode in &egraph.classes()[current_class].nodes {
if total_completed >= MAX_SEARCHED_GRAPHS {
break;
            }

Search Process:
- DFS Traversal: Explores different paths through the e-graph
- Code Generation: Each path is compiled to GPU kernels
- Performance Evaluation: Kernels are executed and timed
- Best Selection: The fastest implementation is chosen
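Conceptually, the selection step boils down to: compile each candidate, time it, keep the fastest. The sketch below is illustrative only; Candidate, compile_candidate, and time_kernel are hypothetical stand-ins for the real e-graph extraction, code generation, and profiling machinery:

#[derive(Debug)]
struct Candidate { id: usize }

// Stand-in for generating GPU kernel source from an extracted e-graph path
fn compile_candidate(c: &Candidate) -> String {
    format!("kernel_{}", c.id)
}

// Stand-in for launching the kernel and measuring its runtime
fn time_kernel(kernel: &str) -> f32 {
    kernel.len() as f32
}

// Pick the candidate whose compiled kernel runs fastest
fn pick_best(candidates: &[Candidate]) -> Option<&Candidate> {
    candidates
        .iter()
        .map(|c| (c, time_kernel(&compile_candidate(c))))
        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(c, _)| c)
}

fn main() {
    let candidates = vec![Candidate { id: 0 }, Candidate { id: 1 }];
    println!("{:?}", pick_best(&candidates));
}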
The system applies various optimizations defined in egglog rules:
; Loop Fusion
(rewrite (LoopIn (LoopOut ?x ?loop ?st) ?loop ?st) ?x
;:when ((!= ?st (MAccum ?y))) ; don't fuse if we're accumulating that loop
) ; this is causing infinite loops in the e-graph!
; Loop Fission
; Loop tiling
(rewrite
(LoopOut ?body (Loop ?loop (MNum ?range)) ?stride)
(LoopOut
(LoopOut
(TileLoop ?body ?loop)
(Loop (+ ?loop "_tile") (MNum 8))
?stride
)
(Loop ?loop (MNum (/ ?range 8)))
(MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum 8)))
)
:when ((> ?range 8) (= (% ?range 8) 0))
)

Key Optimizations:
- Loop Fusion: Combines adjacent loops to reduce memory traffic
- Loop Tiling: Breaks large loops into smaller tiles for better cache usage
- Loop Swapping: Reorders nested loops for better memory access patterns
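To make the tiling rule above concrete, here is what it does to an ordinary loop, written as plain Rust rather than Luminal IR (illustrative only; the factor 8 and the index rewrite mirror the rule's (MNum 8) and MReplace parts):

fn untiled(data: &mut [f32]) {
    for z in 0..data.len() {
        data[z] *= 2.0;
    }
}

fn tiled(data: &mut [f32]) {
    assert_eq!(data.len() % 8, 0); // matches the :when ((= (% ?range 8) 0)) guard
    for outer in 0..data.len() / 8 {       // outer loop: range / 8 iterations
        for tile in 0..8 {                 // inner "_tile" loop: 8 iterations
            data[outer * 8 + tile] *= 2.0; // index rewritten from z to z * 8 + tile
        }
    }
}

fn main() {
    let mut a = vec![1.0f32; 16];
    let mut b = a.clone();
    untiled(&mut a);
    tiled(&mut b);
    assert_eq!(a, b); // same result, different loop structure
}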
Luminal provides several visualization tools for debugging and understanding:
/// View a debug graph in the browser
pub fn display_graph<G: TermToString, E: EdgeToString>(
graph: &StableGraph<G, E, Directed, u32>,
mark_nodes: &[(NodeIndex, String)],
) {
let mut new_graph = StableGraph::new();
let mut map = HashMap::new();
for node in graph.node_indices() {
map.insert(
node,
new_graph.add_node(graph.node_weight(node).unwrap().term_to_string()),
);
    }

Visualization Features:
- Browser-based: Opens graphs in GraphViz Online
- Node Highlighting: Can mark specific nodes with colors
- Multiple Graph Types: Supports computation graphs, kernel graphs, and e-graphs
- Debug Execution: Shows timing information for each operation
/// Execute the graph with debug prints
pub fn execute_debug(&mut self) {
fn format_duration(duration: &Duration) -> String {
if duration.as_secs() > 0 {
format!("{:.2}s", duration.as_secs_f32())
} else if duration.as_millis() > 0 {
format!("{}ms", duration.as_millis())
} else {
format!("{}µs", duration.as_micros())
}
    }

Luminal supports multiple execution backends:
- CPU: Basic CPU execution
- Metal: Apple's Metal for GPU acceleration
- CUDA: NVIDIA GPU support (in development)
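For example, targeting the Metal backend looks roughly like this. This is a sketch: the luminal_metal import path and the data() call are assumptions, while the compile call mirrors the one shown later in this document and requires a Metal-capable macOS device:

use luminal::prelude::*;
use luminal_metal::MetalCompiler; // assumed crate/module path

fn main() {
    let mut cx = Graph::new();
    let a = cx.tensor(4).set(vec![1.0, 2.0, 3.0, 4.0]);
    let mut b = (a * a).retrieve();
    // Lower the generic graph to Metal ops before executing on the GPU
    cx.compile(MetalCompiler::<f32>::default(), &mut b);
    cx.execute();
    println!("{:?}", b.data()); // [1.0, 4.0, 9.0, 16.0]
}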
To visualize graphs in your code:
- Basic Graph Visualization:
use luminal::prelude::*;
let mut cx = Graph::new();
let a = cx.tensor((3, 1)).set([[1.0], [2.0], [3.0]]);
let b = cx.tensor((1, 4)).set([[1.0, 2.0, 3.0, 4.0]]);
let c = a.matmul(b).retrieve();
// Uncomment to visualize the computation graph
// luminal::compiler_utils::display_graph(&cx.graph, &[], &[]);

- Search-Based Compilation Visualization:
// Translate to lower-level representation
let (new_graph, _) = luminal::search::translate_graph(&cx);
// Visualize the translated graph
luminal::search::display_graph_2(&new_graph, &[]);

- Debug Execution:
cx.execute_debug(); // Shows timing for each operation

The search-based compilation is Luminal's core innovation: it automatically discovers optimal GPU kernel implementations by exploring a space of equivalent program transformations and empirically measuring their performance.
Now let's walk through exactly what's happening in the linear layer example, step by step.
In this example:
- Input size: 4 (4 input features)
- Output size: 5 (5 output features)
- Total weights: 20 (4 × 5 = 20)
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (inp, out)),
bias: if bias {
Some(cx.named_tensor("Bias", out))
} else {
None
},
permute: false,
}
}

The weight matrix has shape (4, 5), i.e. 4 rows and 5 columns:
Weight Matrix (4×5):
[w₀₀ w₀₁ w₀₂ w₀₃ w₀₄]
[w₁₀ w₁₁ w₁₂ w₁₃ w₁₄]
[w₂₀ w₂₁ w₂₂ w₂₃ w₂₄]
[w₃₀ w₃₁ w₃₂ w₃₃ w₃₄]
20 weights are stored row-major as: [w₀₀, w₀₁, w₀₂, w₀₃, w₀₄, w₁₀, w₁₁, ...]
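As a quick sanity check on that layout (plain Rust; the 0.1..2.0 values are the illustrative weights used later in this walkthrough), element wᵢⱼ of the (4, 5) matrix sits at flat index i * 5 + j:

fn main() {
    // Flat, row-major weight buffer: [0.1, 0.2, ..., 2.0]
    let flat: Vec<f32> = (1..=20).map(|v| v as f32 / 10.0).collect();
    let (i, j) = (2, 3); // w₂₃: row 2, column 3
    assert_eq!(flat[i * 5 + j], 1.4);
    println!("w[{i}][{j}] = {}", flat[i * 5 + j]);
}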
fn forward(&self, input: GraphTensor) -> Self::Output {
let mut output = input.matmul(if self.permute {
self.weight.permute((1, 0))
} else {
self.weight
});
if let Some(bias) = self.bias {
output += bias.expand(output.shape);
}
output
}

The forward pass performs: output = input @ weights (matrix multiplication).
With our data:
- Input: [1.0, 2.0, 3.0, 4.0] (shape (4,))
- Weights: 20 values in a (4, 5) matrix
- Output: 5 values (shape (5,))
pub fn matmul(mut self, mut rhs: GraphTensor) -> Self {
if (self.shape.len() == 1 || self.shape.len() == 2) && rhs.shape.len() == 2 {
let vec = self.shape.len() == 1;
if vec {
self = self.expand_dim(0, 1);
}
let (m, _) = self.dims2();
let (_, n) = rhs.dims2();
// Broadcasted Multiply
let mul = self.expand_dim(1, n) * rhs.permute((1, 0)).expand_dim(0, m);
// Sum Reduce
let mut ret = mul.sum(2);
if vec {
ret = ret.reshape(ret.dims().last().unwrap());
}
        ret

Step-by-step computation:
- Input vector: [1.0, 2.0, 3.0, 4.0] is treated as a (1, 4) matrix
- Weight matrix: (4, 5)
- Matrix multiplication: (1, 4) @ (4, 5) = (1, 5)
- Result: reshaped back to a (5,) vector
Mathematical formula for each output element:
output[0] = 1.0*w₀₀ + 2.0*w₁₀ + 3.0*w₂₀ + 4.0*w₃₀
output[1] = 1.0*w₀₁ + 2.0*w₁₁ + 3.0*w₂₁ + 4.0*w₃₁
output[2] = 1.0*w₀₂ + 2.0*w₁₂ + 3.0*w₂₂ + 4.0*w₃₂
output[3] = 1.0*w₀₃ + 2.0*w₁₃ + 3.0*w₂₃ + 4.0*w₃₃
output[4] = 1.0*w₀₄ + 2.0*w₁₄ + 3.0*w₂₄ + 4.0*w₃₄
Let's say your 20 random weights are:
[0.1, 0.2, 0.3, 0.4, 0.5, // Row 0: weights for input[0]
0.6, 0.7, 0.8, 0.9, 1.0, // Row 1: weights for input[1]
1.1, 1.2, 1.3, 1.4, 1.5, // Row 2: weights for input[2]
1.6, 1.7, 1.8, 1.9, 2.0] // Row 3: weights for input[3]
Then:
output[0] = 1.0*0.1 + 2.0*0.6 + 3.0*1.1 + 4.0*1.6 = 0.1 + 1.2 + 3.3 + 6.4 = 11.0
output[1] = 1.0*0.2 + 2.0*0.7 + 3.0*1.2 + 4.0*1.7 = 0.2 + 1.4 + 3.6 + 6.8 = 12.0
output[2] = 1.0*0.3 + 2.0*0.8 + 3.0*1.3 + 4.0*1.8 = 0.3 + 1.6 + 3.9 + 7.2 = 13.0
output[3] = 1.0*0.4 + 2.0*0.9 + 3.0*1.4 + 4.0*1.9 = 0.4 + 1.8 + 4.2 + 7.6 = 14.0
output[4] = 1.0*0.5 + 2.0*1.0 + 3.0*1.5 + 4.0*2.0 = 0.5 + 2.0 + 4.5 + 8.0 = 15.0
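Putting it together, here is a minimal end-to-end sketch of this example. The luminal_nn import path, the public weight field, the set()-with-Vec call, the Module trait being in the prelude, and data() are all assumptions; the constructor signature matches the snippet above:

use luminal::prelude::*;
use luminal_nn::Linear; // assumed module path for the Linear layer shown above

fn main() {
    let mut cx = Graph::new();
    // 4 input features -> 5 output features, no bias
    let model = Linear::new(4, 5, false, &mut cx);
    // The 20 weights, row-major, as in the example above
    model.weight.set(vec![
        0.1, 0.2, 0.3, 0.4, 0.5,
        0.6, 0.7, 0.8, 0.9, 1.0,
        1.1, 1.2, 1.3, 1.4, 1.5,
        1.6, 1.7, 1.8, 1.9, 2.0,
    ]);
    let input = cx.tensor(4).set(vec![1.0, 2.0, 3.0, 4.0]);
    let output = model.forward(input).retrieve();
    cx.execute();
    println!("{:?}", output.data()); // approximately [11.0, 12.0, 13.0, 14.0, 15.0]
}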
When you call cx.display(), you'll see nodes representing:
- Input tensor loading (your [1, 2, 3, 4])
- Weight tensor loading (your 20 weights)
- Matrix multiplication operation
- Output tensor
The linear layer is essentially a learned transformation that maps 4-dimensional input space to 5-dimensional output space using the weight matrix as the transformation parameters.
Next, the internal implementation of matrix multiplication in Luminal, which is where the Mul and SumReduce operations come from:
Luminal implements matrix multiplication as a two-step process:
- Broadcasted Element-wise Multiplication (Mul)
- Sum Reduction (SumReduce)
pub fn matmul(mut self, mut rhs: GraphTensor) -> Self {
if (self.shape.len() == 1 || self.shape.len() == 2) && rhs.shape.len() == 2 {
let vec = self.shape.len() == 1;
if vec {
self = self.expand_dim(0, 1);
}
let (m, _) = self.dims2();
let (_, n) = rhs.dims2();
// Broadcasted Multiply
let mul = self.expand_dim(1, n) * rhs.permute((1, 0)).expand_dim(0, m);
// Sum Reduce
        let mut ret = mul.sum(2);

- Input: [1, 2, 3, 4] shape (4,) → expanded to (1, 4)
- Weights: (4, 5) matrix → transposed to (5, 4)
Input (1, 4): [[1, 2, 3, 4]]
↓ expand_dim(1, 5)
Input (1, 5, 4): [[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]]
Weights (5, 4): [[w₀₀, w₁₀, w₂₀, w₃₀],
                 [w₀₁, w₁₁, w₂₁, w₃₁],
                 [w₀₂, w₁₂, w₂₂, w₃₂],
                 [w₀₃, w₁₃, w₂₃, w₃₃],
                 [w₀₄, w₁₄, w₂₄, w₃₄]]
↓ expand_dim(0, 1)
Weights (1, 5, 4): Same as above but with batch dimension
This creates a (1, 5, 4) tensor where each element is:
result[0, i, j] = input[0, i, j] * weights[0, i, j]
Output shape: [1, 5, 4] - This is what you see in the debug output!
The SumReduce(2) means "sum along dimension 2" (the last dimension):
For each output[i]:
output[i] = sum(result[0, i, :])
= result[0, i, 0] + result[0, i, 1] + result[0, i, 2] + result[0, i, 3]
This gives you the final (1, 5) result, which gets reshaped to (5,).
This broadcast + reduce pattern is used because:
- GPU Efficiency: Modern GPUs are optimized for element-wise operations and reductions
- Parallelization: Each element-wise multiplication can happen in parallel
- Memory Access: Better memory coalescing patterns
- Compiler Optimization: Easier for the search-based compiler to optimize
Input: [1, 2, 3, 4]
Weights: [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.6, 0.7, 0.8, 0.9, 1.0],
[1.1, 1.2, 1.3, 1.4, 1.5],
[1.6, 1.7, 1.8, 1.9, 2.0]]
After broadcasting and Mul:
[[[1*0.1, 2*0.6, 3*1.1, 4*1.6], # For output[0]
[1*0.2, 2*0.7, 3*1.2, 4*1.7], # For output[1]
[1*0.3, 2*0.8, 3*1.3, 4*1.8], # For output[2]
[1*0.4, 2*0.9, 3*1.4, 4*1.9], # For output[3]
[1*0.5, 2*1.0, 3*1.5, 4*2.0]]] # For output[4]
After SumReduce(2):
[0.1+1.2+3.3+6.4, # = 11.0
0.2+1.4+3.6+6.8, # = 12.0
0.3+1.6+3.9+7.2, # = 13.0
0.4+1.8+4.2+7.6, # = 14.0
0.5+2.0+4.5+8.0] # = 15.0
This is exactly the mathematical definition of matrix multiplication, just implemented as separate broadcast and reduction operations for better GPU performance.
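For completeness, here is the same two-step computation written as plain Rust (not Luminal code), mirroring the Mul and SumReduce nodes directly:

fn main() {
    let input = [1.0f32, 2.0, 3.0, 4.0]; // shape (4,)
    let weights = [ // shape (4, 5), row-major
        [0.1f32, 0.2, 0.3, 0.4, 0.5],
        [0.6, 0.7, 0.8, 0.9, 1.0],
        [1.1, 1.2, 1.3, 1.4, 1.5],
        [1.6, 1.7, 1.8, 1.9, 2.0],
    ];
    // Mul: broadcasted element-wise product, shape (1, 5, 4)
    // mul[0][i][j] = input[j] * weights[j][i] (weights viewed as the transposed (5, 4))
    let mut mul = [[[0.0f32; 4]; 5]; 1];
    for i in 0..5 {
        for j in 0..4 {
            mul[0][i][j] = input[j] * weights[j][i];
        }
    }
    // SumReduce(2): sum along the last dimension, giving shape (1, 5)
    let output: Vec<f32> = (0..5).map(|i| mul[0][i].iter().sum()).collect();
    println!("{:?}", output); // approximately [11.0, 12.0, 13.0, 14.0, 15.0]
}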
The MetalCompiler is actually a type alias that combines multiple individual compilers, and the Compiler trait is implemented through tuple implementations.
/// Compile graphs to run on Metal-supported macOS devices in supported data formats
pub type MetalCompiler<T> = (MetalCompilerPreBuffer<T>, BufferCompilers);
/// All metal compilers coming before buffer compilers
pub type MetalCompilerPreBuffer<T> = (
prim::PrimitiveCompiler<T>,
SpecialOpsCompiler<T>,
other::CopyCompiler<T>,
elementwise_fusion::ElementwiseFusionCompiler<T>,
);
/// Compilers to share command and storage buffers
pub type BufferCompilers = (
command_buffer::CommandBufferCompiler,
storage_buffer::StorageBufferCompiler,
);
/// Compiler to replace metal ops with specialized variants
pub type SpecialOpsCompiler<T> = (
binary::MetalSubtractionCompiler<T>,
binary::MetalEqualCompiler<T>,
other::ARangeCompiler<T>,
binary::MetalGatherCompiler<T>,
unary::MetalExpCompiler<T>,
unary::MetalCosCompiler<T>,
unary::MeanReduceCompiler<T>,
unary::StdNormCompiler<T>,
matmul::MetalMatMulCompiler<T>,
);

The Compiler trait is implemented for tuples in the core library.
pub trait Compiler {
type Output;
/// Run a compilation pass
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, ids: T) -> Self::Output;
}

The magic happens with macro-generated tuple implementations:
macro_rules! tuple_impls {
([$($name:ident),+] , [$($idx:tt),+]) => {
        impl<$($name: Compiler, )+> Compiler for ($($name,)+) {
type Output = ( $($name::Output, )+ );
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut remap: T) -> Self::Output {
( $(self.$idx.compile(graph, &mut remap), )+ )
}
}
};
}
tuple_impls!([M1], [0]);
tuple_impls!([M1, M2], [0, 1]);
tuple_impls!([M1, M2, M3], [0, 1, 2]);
// ... up to 20 elements

So when we call:
cx.compile(MetalCompiler::<f32>::default(), &mut b);

Here's what happens:
- MetalCompiler expands to:
(
(
prim::PrimitiveCompiler<f32>,
SpecialOpsCompiler<f32>,
other::CopyCompiler<f32>,
elementwise_fusion::ElementwiseFusionCompiler<f32>,
),
(
command_buffer::CommandBufferCompiler,
storage_buffer::StorageBufferCompiler,
)
)

- The tuple implementation calls each compiler in sequence:
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut remap: T) -> Self::Output {
(
self.0.compile(graph, &mut remap), // MetalCompilerPreBuffer
self.1.compile(graph, &mut remap), // BufferCompilers
)
}

- Each nested tuple also gets compiled in sequence, so the full order is:
1. PrimitiveCompiler (converts basic ops to Metal ops)
2. SpecialOpsCompiler (handles matmul, subtraction, etc.)
3. CopyCompiler (manages device memory transfers)
4. ElementwiseFusionCompiler (fuses element-wise operations)
5. CommandBufferCompiler (optimizes command buffer usage)
6. StorageBufferCompiler (optimizes storage buffer usage)
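The same pattern is easy to see in miniature. The sketch below uses a toy Pass trait (not Luminal's Compiler trait) to show how implementing a trait for tuples turns a tuple of passes into a single pass that runs its elements in order:

trait Pass {
    fn run(&self, log: &mut Vec<&'static str>);
}

struct A;
struct B;

impl Pass for A {
    fn run(&self, log: &mut Vec<&'static str>) { log.push("A"); }
}
impl Pass for B {
    fn run(&self, log: &mut Vec<&'static str>) { log.push("B"); }
}

// A tuple of passes is itself a pass that runs its elements in order,
// just like the macro-generated impls above (shown here for 2-tuples only).
impl<P1: Pass, P2: Pass> Pass for (P1, P2) {
    fn run(&self, log: &mut Vec<&'static str>) {
        self.0.run(log);
        self.1.run(log);
    }
}

fn main() {
    // Nested tuples compose, mirroring (MetalCompilerPreBuffer, BufferCompilers)
    let pipeline = ((A, B), (A, B));
    let mut log = Vec::new();
    pipeline.run(&mut log);
    println!("{:?}", log); // ["A", "B", "A", "B"]
}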
Each individual compiler implements the trait. For example:
impl<T: MetalFloat> Compiler for MetalMatMulCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Look for the matmul pattern
        // Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]

This design allows composable compilation passes: we can mix and match different compilers, and they all run in sequence to progressively optimize the graph.
SelectGraph is a pattern matching tool that helps compilers find specific operation sequences in the computation graph. Think of it as a "template" that describes what operations to look for.
- Define the Pattern: The compiler creates SelectGraph objects that describe the matrix multiplication pattern:

  let mut mul2d = op::<MetalMul<T>>();       // Look for Metal multiplication
  let mut sr2d = op::<MetalSumReduce<T>>();  // Look for sum reduction
- Set Constraints: It adds constraints like shapes and which dimensions are "fake" (broadcasted):

  mul2d.shapes([['M', 'N', 'K'], ['M', 'N', 'K']]);    // Expected shapes
  mul2d.fakes([[None, Some(true), Some(false)], ...]); // Which dims are broadcasted
- Connect the Pattern: It connects the operations to form the complete pattern:
let mut s2d = mul2d.clone().connect(sr2d.clone()).search(graph);
  This creates the pattern Mul -> SumReduce (which is how matrix multiplication is implemented).
- Search for Matches: The search() method returns a GraphSearch object that can find all instances of this pattern in the graph:

  while s2d.next_match() {
      // Find next occurrence of the pattern
      let (mul, sum_reduce) = (s2d.get(&mul2d), s2d.get(&sr2d)); // Get the actual nodes
      // Replace with optimized matmul kernel
  }
- Pattern Recognition: Automatically finds Mul + SumReduce sequences that represent matrix multiplication
- Optimization: Replaces these inefficient sequences with a single optimized GEMM kernel
- Multiple Dimensions: Handles 2D, 3D, 4D, and 5D cases with different patterns
SelectGraph is like a "find and replace" tool for computation graphs. It says: "Find all places where you see Mul -> SumReduce with these specific shape patterns, and replace them with a single optimized matrix multiplication operation."
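To give a rough sense of that find-and-replace flow, here is a simplified version written against plain petgraph rather than the actual SelectGraph API (node names and the rewiring are simplified; the real compiler also checks shapes and fake dimensions and preserves output edges):

use petgraph::stable_graph::{NodeIndex, StableGraph};
use petgraph::{Directed, Direction};

fn main() {
    let mut g: StableGraph<&'static str, (), Directed> = StableGraph::new();
    let a = g.add_node("Input");
    let b = g.add_node("Weight");
    let mul = g.add_node("Mul");
    let sr = g.add_node("SumReduce");
    g.add_edge(a, mul, ());
    g.add_edge(b, mul, ());
    g.add_edge(mul, sr, ());

    // Find every Mul whose only consumer is a SumReduce ...
    let matches: Vec<(NodeIndex, NodeIndex)> = g
        .node_indices()
        .filter(|&n| g[n] == "Mul")
        .filter_map(|m| {
            let mut consumers = g.neighbors_directed(m, Direction::Outgoing);
            match (consumers.next(), consumers.next()) {
                (Some(s), None) if g[s] == "SumReduce" => Some((m, s)),
                _ => None,
            }
        })
        .collect();

    // ... and replace the pair with a single fused node.
    for (m, s) in matches {
        let fused = g.add_node("Matmul");
        let srcs: Vec<_> = g.neighbors_directed(m, Direction::Incoming).collect();
        for src in srcs {
            g.add_edge(src, fused, ());
        }
        g.remove_node(s);
        g.remove_node(m);
    }
    // Remaining nodes: ["Input", "Weight", "Matmul"]
    println!("{:?}", g.node_indices().map(|n| g[n]).collect::<Vec<_>>());
}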
This is how Luminal automatically converts your high-level a.matmul(b) into optimized GPU kernels!