
How does automatic kernel fusion work in burn?


Burn’s Tensor Abstraction:

  1. Generic parameters Tensor<B, D, K>:

    • B: Backend - the execution backend (composable)
    • D: usize - dimensionality (compile-time constant)
    • K: TensorKind - the element kind (Float, Int, Bool)
  2. Backend Composition

Tensor<Fusion<CubeBackend<WgpuRuntime, f32, i32, u32>>, 2, Float>

This shows that backends are composable, with Fusion wrapping CubeBackend to provide kernel fusion capabilities.

  3. Client-Server Architecture: Burn uses dual client-server abstractions.

    Fusion level:
    • MutexFusionClient wraps FusionServer with a mutex
    • FusionServer owns a MultiStream with operation queues

    CubeCL level:
    • ComputeClient communicates with ComputeServer
    • A similar pattern is used for the actual GPU operations
  4. High-Level Tensor Allocation Flow: the allocation chain is as follows:

reserve() → pool.alloc() → create_page() → WgpuStorage.alloc() → wgpu::Device::create_buffer()

Notes:

  • Operation Registration: tensor creation itself is a NoOp for the fusion queue. The tensor allocation happens immediately in the CubeCL layer, but fusion operations are queued.

🎯 Key Points

  1. Layered Abstractions: Fusion and CubeCL each have their own client-server patterns
  2. Lazy Pooling: ring buffer reuse before new allocations
  3. Separation of Concerns: Fusion handles operation batching/fusion, CubeCL handles actual GPU execution

📝 Summary

Burn's architecture relies on:

  • The generic type system
  • Backend composability
  • Dual client-server abstractions
  • Custom memory management

A Closer Look at Creating a Tensor:

  1. The standard Tensor type in burn is generic over

    • a Backend - B:
      • Some examples of backends are CubeBackend, Fusion, etc.
      • CubeBackend is an abstraction built on top of the cubecl crate, which targets multiple runtimes or GPU programming APIs (CUDA, WebGPU, ROCm, etc.)
      • backends are composable - for instance, we can have a tensor that uses the CubeBackend like this:
      // Tensor<B, D, K> - B, D and K are generic type parameters
      
      Tensor<Fusion<CubeBackend<WgpuRuntime, f32, i32, u32>>, 2, Float>
    • i.e. this is a CubeBackend that uses the wgpu runtime or API, wrapped in a Fusion backend - meaning it is also capable of kernel fusion.
    • dimensionality - D
    • and the element type - K
      • K can be anything that implements the TensorKind trait, i.e. Float, Int, or Bool
  2. The simplest API used to create a Tensor is from_data

    // Type alias for the backend to use.
    pub type Wgpu<F = f32, I = i32, B = u32> = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;
    type Backend = Wgpu;
    
    let device = Default::default();
    // Creation of a tensors with explicit values
    let tensor_1 = Tensor::<Backend, 2, Float>::from_data([[2., 3.], [4., 5.]], &device);
    • But this single line of code has a lot packed into it.
    let tensor_1 = Tensor::<Backend, 2, Float>::from_data([[2., 3.], [4., 5.]], &device);
    • Essentially, this means we want to allocate a Tensor of Floats using Wgpu
      • from_data invokes one of the FloatTensorOps methods, float_from_data, on our target backend (Fusion<CubeBackend<Wgpu …>> in our case) to perform the allocation. (A simplified sketch of this dispatch appears right after the struct definitions below.)
  3. The Fusion backend in burn uses a client-server abstraction - i.e. a FusionClient and a FusionServer

    • FusionClient: A concrete example of a fusion client is MutexFusionClient.
    • This simply wraps a FusionServer in an Arc<Mutex<FusionServer<R>>> i.e. it uses a mutex to communicate with the fusion server.
    • The job of the FusionClient (or MutexFusionClient in this case) is to register or add new tensor operations (actually OperationIr - a Burn-specific representation of an operation) to a stream's OperationQueue.
    pub struct MutexFusionClient<R: FusionRuntime> {
    	server: Arc<Mutex<FusionServer<R>>>,
    	device: FusionDevice<R>,
    }
    • FusionServer: this is a data structure that owns what’s called a MultiStream.
    • A MultiStream holds a hashmap of streams. Each stream in this hashmap is a list of queued operations in an OperationQueue plus a Processor to process those operations.

Note: all of these types are generic over a Backend's runtime (i.e. the target GPU API), which is Wgpu in our case. Additionally, allocating tensors in GPU memory is considered a NoOp by the fusion queue.

// In this example, R is FusionCubeRuntime which implements FusionRuntime - a runtime abstraction for the Fusion
// backend. 

pub struct FusionServer<R: FusionRuntime> {
	streams: MultiStream<R>,
	pub(crate) handles: HandleContainer<R::FusionHandle>,
}

/// Keep track of multiple concurrent streams of operations.
pub struct MultiStream<R: FusionRuntime> {
	streams: HashMap<StreamId, Stream<R>>,
	optimizations: ExecutionPlanStore<R::Optimization>,
    device: R::FusionDevice,
}

struct Stream<R: FusionRuntime> {
	queue: OperationQueue<R>,
	processor: Processor<R::Optimization>,
}
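
Before moving on to allocation, here is the simplified dispatch sketch referenced above for the from_data call. This is not Burn's real API; the trait and the Cube/Fusion types below are made-up stand-ins that only illustrate how a wrapper backend forwards float_from_data to the backend it wraps:

// Hypothetical, simplified stand-ins for Burn's Backend / FloatTensorOps traits.
trait Backend {
    type FloatTensor;
    fn float_from_data(data: Vec<f32>, shape: [usize; 2]) -> Self::FloatTensor;
}

// Stand-in for CubeBackend: performs the "real" allocation.
struct Cube;
impl Backend for Cube {
    type FloatTensor = Vec<f32>; // pretend this is a GPU buffer handle
    fn float_from_data(data: Vec<f32>, _shape: [usize; 2]) -> Self::FloatTensor {
        println!("CubeBackend stand-in: allocating a buffer of {} floats", data.len());
        data
    }
}

// Stand-in for Fusion<B>: tensor creation is a NoOp for the fusion queue,
// so the wrapper simply forwards the call to the inner backend.
struct Fusion<B: Backend>(std::marker::PhantomData<B>);
impl<B: Backend> Backend for Fusion<B> {
    type FloatTensor = B::FloatTensor;
    fn float_from_data(data: Vec<f32>, shape: [usize; 2]) -> Self::FloatTensor {
        // Nothing is queued here; the allocation happens eagerly in the inner backend.
        B::float_from_data(data, shape)
    }
}

fn main() {
    // Mirrors the spirit of Tensor::<Fusion<CubeBackend<...>>, 2, Float>::from_data(...)
    let _t = <Fusion<Cube> as Backend>::float_from_data(vec![2., 3., 4., 5.], [2, 2]);
}
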
  4. How does memory/buffer allocation happen for our Tensor:
    • There are two different abstractions here
      • Fusion Backend abstraction: the Fusion backend uses a client-server abstraction to aid with kernel fusion
      • CubeBackend abstraction: the actual allocation for the tensor on the GPU is done via the wrapped backend in Fusion<CubeBackend …>
    • In our case, that's CubeBackend, which in turn relies on a CubeCL runtime (Fusion<CubeBackend<WgpuRuntime …>) to do the real work - WgpuRuntime
    • Note: CubeCL itself has a similar client-server abstraction akin to the Fusion Backend - In CubeCL, they’re called ComputeClient and ComputeServer

The complete allocation flow:

# The Fusion backend invokes the wrapped CubeBackend, which in turn calls on WgpuRuntime to do the actual job, i.e. tensor allocation.
# Here's the approximate flow

Tensor::from_data()
    ↓
Fusion<CubeBackend ... >::float_from_data()
    ↓
CubeBackend::float_from_data() (this simply directs the call to CubeCL runtime)
    ↓
burn_cubecl::ops::base::from_data() (a proxy for instantiating CubeCL types)
    ↓
WgpuRuntime::client() (instantiates a CubeCL compute client, in our case we instantiate a Wgpu compute client)
    ↓
ComputeClient<WgpuServer, MutexComputeChannel>::create() (which given a resource like a buffer, stores it and returns the resource handle.)
    ↓
MutexComputeChannel::create() (if the compute client is using mutexes to communicate with a Compute Server)
    ↓
WgpuServer::create() (assuming we're using Wgpu as our backend)
    ↓
MemoryManager::reserve() 
    ↓
MemoryManagement<WgpuStorage>::reserve()
    ↓
SlicedPool::alloc() (if using sliced pools)
    ↓
SlicedPool::create_page() (if new page needed)
    ↓
WgpuStorage::alloc()
    ↓
wgpu::Device::create_buffer() (actual GPU allocation)
    ↓
StorageHandle created with Storage ID
    ↓
ID pushed to SlicedPool ring buffer
    ↓
ID returned

Key Point:

  • The reserve method is the entry point that triggers this entire allocation chain, but it's designed to be lazy and pooled - it tries to reuse existing allocations from the ring buffer before creating new GPU buffers.
  • All these abstraction layers allow CubeCL to efficiently manage GPU memory while hiding the complexity from higher-level tensor operations.
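
To make the reuse-before-allocate idea concrete, here is a rough, self-contained sketch (not CubeCL's actual SlicedPool or MemoryManagement code; all names and sizes below are invented) of a pool that hands out free handles from a ring buffer before creating new ones:

use std::collections::VecDeque;

// Hypothetical handle to a GPU buffer; in the real flow this would wrap a wgpu::Buffer.
#[derive(Debug, Clone, Copy)]
struct StorageId(u64);

// A toy pool: freed handles sit in a ring buffer and are reused before any
// new "GPU" allocation is made.
struct Pool {
    free: VecDeque<(StorageId, usize)>, // (handle, size in bytes)
    next_id: u64,
}

impl Pool {
    fn new() -> Self {
        Self { free: VecDeque::new(), next_id: 0 }
    }

    // Mimics the spirit of reserve(): reuse first, allocate last.
    fn reserve(&mut self, size: usize) -> StorageId {
        // Try to reuse an existing allocation that is big enough.
        if let Some(pos) = self.free.iter().position(|(_, s)| *s >= size) {
            let (id, _) = self.free.remove(pos).unwrap();
            println!("reusing {id:?}");
            return id;
        }
        // Otherwise fall back to a fresh allocation (wgpu::Device::create_buffer
        // in the real flow).
        let id = StorageId(self.next_id);
        self.next_id += 1;
        println!("allocating new buffer {id:?} ({size} bytes)");
        id
    }

    fn release(&mut self, id: StorageId, size: usize) {
        self.free.push_back((id, size));
    }
}

fn main() {
    let mut pool = Pool::new();
    let a = pool.reserve(1024); // new allocation
    pool.release(a, 1024);
    let _b = pool.reserve(512); // reuses `a` instead of allocating again
}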

Additional notes:

  • My understanding is that memory allocation and data transfer are handled via DMA — either over PCIe (for discrete GPUs) or through a proprietary on-die interconnect (in the case of integrated GPUs or Apple Silicon). This process does not go through the GPU’s compute pipeline or trigger a kernel launch. Instead, the data transfer is typically handled by a dedicated copy engine via a MemCopy command in the GPU’s command queues.
  • Yep, but it is still on the same stream on CUDA, and I'm unsure if it's linked to a command encoder in wgpu
  • For wgpu - I see a call being made to queue.write_buffer(&buffer, 0, bytemuck::cast_slice(&data)); which stages the data into CPU-accessible memory, internally records a copy command, and submits it immediately to the GPU without needing to manually create a CommandEncoder or submit a CommandBuffer.
  • I believe the command encoder method is used for large data transfers.
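
For reference, here is a minimal sketch of the two wgpu upload paths mentioned above. It assumes you already have a wgpu::Device and wgpu::Queue (and the bytemuck crate); the buffer sizes and usage flags are illustrative, not necessarily what Burn/CubeCL picks:

// Path 1: queue.write_buffer stages the data and records/submits the copy
// internally; no explicit CommandEncoder is needed.
fn upload_small(device: &wgpu::Device, queue: &wgpu::Queue, data: &[f32]) -> wgpu::Buffer {
    let buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("tensor"),
        size: (data.len() * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
    buffer
}

// Path 2: an explicit staging buffer plus a CommandEncoder copy, which gives
// more control over when the transfer is submitted (useful for large uploads).
fn upload_via_encoder(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    staging: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    size: u64,
) {
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    encoder.copy_buffer_to_buffer(staging, 0, dst, 0, size);
    queue.submit(std::iter::once(encoder.finish()));
}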

Let's look at Three Operations and the Generated Kernels

Trace through how kernel fusion actually works for our three operations.

RUST_LOG=burn_fusion=trace,cubecl_wgpu::runtime=trace,wgpu_hal::metal=trace cargo run --example burn-test --release --features wgpu

let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device); // NoOp
let y = tensor_1.clone() * 2.0 + 1.0; // ScalarMul + ScalarAdd operations
let z = y.tanh();                     // Tanh operation

The log output shows 4 different kernels being generated during the fusion demonstration, because of how the Burn framework's fusion system works with the Wgpu backend:

  1. Initial Setup Kernels (2): The first two kernels are part of the WGPU backend initialization:

    • First kernel: Handles buffer setup and validation (dispatch workgroup size validation)
    • Second kernel: Handles metadata and indirect buffer operations ("indirect dispatches", where dispatch parameters come from GPU buffers)
  2. Fused Operation Kernel:

    [2025-06-30T07:48:56Z DEBUG wgpu_hal::metal::device] Naga generated shader for entry point 'elemwise_fuse' and stage Compute
    • This is the key kernel that demonstrates fusion. Notice how it combines all three operations (multiply by 2.0, add 1.0, and tanh) into a single kernel. In the code, you can see:
    metal::float2 l_10_ = buffer_0_global[id];
    float _e66 = scalars_f32_.inner[0];
    metal::float2 l_13_ = l_10_ * _e66;           // multiply by 2.0
    float _e70 = scalars_f32_.inner[1];
    metal::float2 l_16_ = l_13_ + metal::float2(_e70);  // add 1.0
    metal::float2 _e73 = safe_tanh(l_16_);        // tanh
    • The log confirms this fusion: 
    [2025-06-30T07:48:56Z TRACE burn_fusion::stream::store::base] New execution plan 1 - Operations: 3 - Triggers 1
  3. Slice Kernel:

    [2025-06-30T07:48:56Z DEBUG wgpu_hal::metal::device] Naga generated shader for entry point 'slice_kernel' and stage Compute
    • This kernel handles the final data extraction and formatting for display when you print the result.
  4. The fusion demonstration works because Burn's fusion system queues operations without executing them immediately. When you call:

    let temp = tensor_1.clone() * 2.0;
    let y = temp + 1.0;
    let z = y.tanh();

These operations are recorded but not executed until you force execution by printing the result. At that point, the system optimizes by fusing all three operations into a single kernel execution, which is more efficient than running three separate kernels.

How Fusion Actually Works

let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device); // NoOp
let y = tensor_1.clone() * 2.0 + 1.0; // ScalarMul + ScalarAdd operations
let z = y.tanh();                     // Tanh operation

==🔴 But before we jump in, let's look at the prerequisites, i.e. what we would need to generate fast, performant GPU kernels==

Understanding Reads and Writes in Fusion Optimization

In the context of fusion optimization, "reads" and "writes" refer to memory access patterns that track how operations interact with tensors.

Reads and Writes in FuseBlockBuilder

pub struct FuseBlockBuilder {
    pub settings: FuseBlockSettings,
    pub ops: Vec<FuseOp>,
    pub reads: BTreeMap<TensorId, (FusePrecision, LayoutInfo)>,
    pub writes: BTreeMap<TensorId, FusePrecision>,
    pub tensor_writes: BTreeMap<TensorId, FusePrecision>,
}

Reads

The reads field is a map that tracks which tensors are read by operations in the block:

  • Key: TensorId - Identifies a specific tensor
  • Value: (FusePrecision, LayoutInfo) - Specifies:
    • The floating point precision of tensor elements (e.g., F32, F16)
    • Layout information about how the tensor is accessed

When an operation needs to read a tensor, it's registered in this map. This helps the optimizer understand which tensors need to be loaded from memory.

Writes

The writes field tracks which tensors are written to by operations in the block:

  • Key: TensorId - Identifies a specific tensor
  • Value: FusePrecision - The precision of the written data

When an operation produces a result that needs to be stored in a tensor, it's registered in this map. This helps the optimizer understand which tensors need to be stored back to memory.

Tensor Writes

The tensor_writes field specifically tracks tensors that need to be written to global memory (as opposed to temporary results that can stay in registers or shared memory).

Why Tracking Reads and Writes is Important

  1. Memory Access Optimization: By knowing which tensors are read and written, the optimizer can:

    • Minimize global memory accesses
    • Keep intermediate results in faster memory (registers or shared memory)
    • Coalesce memory operations for better performance
  2. Data Dependency Analysis: Tracking reads and writes helps identify:

    • Which operations depend on each other
    • Which operations can be executed in parallel
    • Which operations can be fused together
  3. Memory Allocation: It helps determine:

    • How much memory is needed for intermediate results
    • When memory can be reused
    • When memory needs to be allocated

Example: How Reads and Writes are Tracked

Let's look at how reads and writes are tracked for our example:

let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
let y = tensor_1.clone() * 2.0 + 1.0;
let z = y.tanh();

1. Scalar Multiplication: tensor_1.clone() * 2.0

  • Reads:

    • tensor_1 is read from global memory
    • The scalar 2.0 is read from a constant
  • Writes:

    • An intermediate result (let's call it temp1) is written

2. Scalar Addition: temp1 + 1.0

  • Reads:

    • temp1 is read (but this is an intermediate result, not from global memory)
    • The scalar 1.0 is read from a constant
  • Writes:

    • Another intermediate result (let's call it y) is written

3. Tanh: y.tanh()

  • Reads:

    • y is read (again, an intermediate result)
  • Writes:

    • The final result z is written to global memory

Optimization

The optimizer recognizes that temp1 and y are intermediate results that don't need to be written to global memory. Instead, they can be kept in registers or shared memory. Only z needs to be written to global memory.

reads (in FuseBlockBuilder):

  • Tracks all tensor reads - both from global memory AND intermediate results
  • Maps TensorId → Vec<FuseOp> showing which operations read each tensor

writes (in FuseBlock after build()):

  • Contains only the final global memory writes
  • Maps TensorId → FuseOp for tensors that need to be written to global memory

tensor_writes() method:

  • Analyzes the dataflow to determine which intermediate results actually need to be written to global memory
  • Filters out intermediate results that are only used within the kernel

Example Breakdown:

let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
let y = tensor_1.clone() * 2.0 + 1.0;  // temp1 = tensor_1 * 2.0, y = temp1 + 1.0
let z = y.tanh();                       // z = tanh(y)

During Registration (populates reads):

  1. tensor_1 * 2.0:

    • reads[tensor_1.id] gets FuseOp::Assign(Input → Local(0))
    • Creates intermediate temp1 (Local(1))
  2. temp1 + 1.0:

    • reads[temp1.id] gets operation that reads temp1
    • Creates intermediate y (Local(2))
  3. y.tanh():

    • reads[y.id] gets operation that reads y
    • Creates final result z (Local(3))

During build() (creates writes):

The tensor_writes() method analyzes:

  • temp1: Only read by the + 1.0 operation → not written to global memory
  • y: Only read by the tanh() operation → not written to global memory
  • z: Final result, needs to persist → written to global memory

So writes only contains:

writes[z.id] = FuseOp::Assign(Local(3) → Output(0))

The Key Insight:

The tensor_writes() method performs dataflow analysis:

// All output tensors that are never read by a following operation should be written to
// since they are essentially the "logical" output of the shader.
for entry in local_tensor_ids_output {
    let is_read = local_tensor_ids_input.contains(&entry);

    if !is_read
        && !self.local_outputs.contains(&entry.0)
        && !resources.dropped.contains(&entry.0)
    {
        // This tensor is produced but never consumed → write to global memory
        result.insert(precision, tensor.clone());
    }
}

Summary

  1. reads: Tracks ALL tensor reads (global + intermediate)
  2. ops: Tracks computation operations
  3. tensor_writes(): Analyzes which results need global memory writes
  4. writes: Contains only the global memory write operations

The key distinction is between:

  • Intermediate writes (kept in registers/local memory, not explicitly tracked as "writes")
  • Global memory writes (tracked in writes, determined by tensor_writes())

By carefully tracking reads and writes, the fusion optimizer can minimize memory accesses and maximize computational efficiency, resulting in faster execution of neural network operations.
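
Here is a toy, self-contained version of that dataflow rule (not the real FuseBlockBuilder code; Op and TensorId are simplified stand-ins). An output that no later operation in the block reads is a logical output of the kernel and must be written to global memory:

use std::collections::BTreeSet;

type TensorId = u32;

// A simplified fused op: reads some tensors, writes one.
struct Op {
    reads: Vec<TensorId>,
    writes: TensorId,
}

// Outputs that no op in the block reads must go to global memory;
// everything else can stay in registers/local memory.
fn global_writes(ops: &[Op]) -> BTreeSet<TensorId> {
    let all_reads: BTreeSet<TensorId> =
        ops.iter().flat_map(|op| op.reads.iter().copied()).collect();

    ops.iter()
        .map(|op| op.writes)
        .filter(|written| !all_reads.contains(written))
        .collect()
}

fn main() {
    // tensor_1 = 0, temp1 = 1, y = 2, z = 3
    let ops = [
        Op { reads: vec![0], writes: 1 }, // temp1 = tensor_1 * 2.0
        Op { reads: vec![1], writes: 2 }, // y = temp1 + 1.0
        Op { reads: vec![2], writes: 3 }, // z = tanh(y)
    ];
    // Only `z` (id 3) survives: temp1 and y are consumed inside the kernel.
    println!("{:?}", global_writes(&ops)); // {3}
}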

High-Level Fusion Architecture in Burn (approximation only):

[Diagram: operation queue → stream optimizer → blocks optimizer → execution strategy]

Flow Explanation:

  1. Operation Queue:

    • All 3 operations are added to the queue
  2. Stream Optimizer:

    • Creates blocks based on tensor dependencies
    • Block 1: Tensor creation (not fusable with others)
    • Block 2: ScalarMul, ScalarAdd, and Tanh (all fusable)
  3. Blocks Optimizer:

    • Tries to merge blocks (not possible here)
    • Optimizes each block separately
    • No holes found in this simple example
  4. Execution Strategy:

    • Strategy 1: Execute tensor creation individually
    • Strategy 2: Execute fused kernel for the element-wise operations

The final execution will:

  1. Create tensor_1
  2. Execute a single fused kernel that computes z = tanh((tensor_1 * 2.0) + 1.0)

This eliminates the need for intermediate storage of the y tensor, reducing memory traffic and improving performance.
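
The effect is easy to see with a plain-Rust analogy (purely illustrative; no Burn code involved). The unfused version materializes temp and y as full buffers, while the fused version makes a single pass and keeps the intermediates in registers:

fn unfused(input: &[f32]) -> Vec<f32> {
    // Three passes and two intermediate allocations: roughly what three
    // separate kernels would do.
    let temp: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
    let y: Vec<f32> = temp.iter().map(|x| x + 1.0).collect();
    y.iter().map(|x| x.tanh()).collect()
}

fn fused(input: &[f32]) -> Vec<f32> {
    // One pass; intermediates never touch memory: the analogue of the single
    // elemwise_fuse kernel.
    input.iter().map(|x| (x * 2.0 + 1.0).tanh()).collect()
}

fn main() {
    let input = [2.0_f32, 3.0, 4.0, 5.0];
    assert_eq!(unfused(&input), fused(&input));
    println!("{:?}", fused(&input));
}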

Deeper-Dive

Operation Queue:

All operations are queued for execution but aren’t immediately executed i.e. execution is lazy. The OperationQueue in a Stream is effectively the compute graph that contains the ordered sequence of operations waiting to be executed.

/// A growing list of [tensor operation descriptions](OperationIr).
pub struct OperationQueue<R: FusionRuntime> {
    /// List of operation descriptions. These contain the exact tensor IDs
    /// and shapes so that kernels can be run correctly.
    pub(crate) global: Vec<OperationIr>,
    /// List of operation descriptions. The tensor IDs and shapes are relative
    /// because we don't need to know the exact values, but they are sufficient to
    /// determine which operations can be fused.
    pub(crate) relative: Vec<OperationIr>,
    pub(crate) converter: OperationConverter,
    pub(crate) operations: Vec<Box<dyn Operation<R>>>,
    pub(crate) variables: HashMap<TensorId, (StreamId, TensorStatus)>,
}

Key aspects of the compute graph:

  1. It stores both the high-level operation descriptions (OperationIr) and the actual executable operations.
  2. It maintains two representations:
    • global: The exact operations with precise tensor IDs and shapes
    • relative: A representation that helps identify fusion opportunities
  3. It tracks tensor variables and their states through the variables HashMap.
  4. Operations are added in sequence via the add method, building up the computation graph.
  5. When execution is triggered, the system analyzes this queue to identify fusion opportunities before executing the operations.

The MultiStream type manages multiple such queues (one per stream), effectively maintaining multiple compute graphs that can interact when operations span across streams.
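
Here is a minimal sketch of that lazy-queue idea (not the real OperationQueue; boxed closures stand in for OperationIr plus the executable operations):

// Each queued "operation" is just a boxed closure here; in Burn it is an
// OperationIr plus a Box<dyn Operation<R>>.
struct LazyQueue {
    pending: Vec<Box<dyn FnOnce(&mut Vec<f32>)>>,
}

impl LazyQueue {
    fn new() -> Self {
        Self { pending: Vec::new() }
    }

    // Registering an op is cheap: nothing runs yet.
    fn add(&mut self, op: Box<dyn FnOnce(&mut Vec<f32>)>) {
        self.pending.push(op);
    }

    // Draining executes everything that was queued, in order.
    fn drain(&mut self, data: &mut Vec<f32>) {
        for op in self.pending.drain(..) {
            op(data);
        }
    }
}

fn main() {
    let mut queue = LazyQueue::new();
    let mut data = vec![2.0_f32, 3.0, 4.0, 5.0];

    queue.add(Box::new(|d: &mut Vec<f32>| d.iter_mut().for_each(|x| *x *= 2.0)));
    queue.add(Box::new(|d: &mut Vec<f32>| d.iter_mut().for_each(|x| *x += 1.0)));
    queue.add(Box::new(|d: &mut Vec<f32>| d.iter_mut().for_each(|x| *x = x.tanh())));

    // Nothing has executed yet; "reading" the data forces the drain.
    queue.drain(&mut data);
    println!("{data:?}");
}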

How does Burn’s lazy execution system work?

let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device); // NoOp
let y = tensor_1.clone() * 2.0 + 1.0; // ScalarMul + ScalarAdd operations
let z = y.tanh();

println!("Final result: {}", z);

Printing z here invokes the Display implementation for Tensor but doesn't directly trigger kernel fusion and execution. Instead, it attempts to read tensor data for display, which indirectly causes execution of pending operations.

The exact flow:

  1. When fmt is called on a tensor, it calls display_recursive which eventually calls fmt_inner_tensor

  2. In fmt_inner_tensor, the key line that triggers execution is:

let data = burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
  3. This flow:

    • First creates a slice of the tensor
    • Then calls into_data_async() which returns a future
    • Then try_read_sync attempts to synchronously read that future
  4. The into_data_async() call is what triggers the fusion and execution pipeline:

    • It causes any pending operations to be materialized
    • The fusion system analyzes the operation graph
    • Fused kernels are created and executed to compute the actual tensor values
  5. The execution flow goes through:

    • burn-fusion/src/stream/multi.rs - The drain method is called to process pending operations
    • burn-fusion/src/stream/execution/processor.rs - The processor analyzes operations for fusion
    • burn-cubecl-fusion/src/shared/trace/base.rs - The FuseTrace::run method executes the fused operations
    • burn-cubecl-fusion/src/shared/trace/executor.rs - The LaunchPlanExecutor::execute method handles the actual kernel launch
  6. The key steps in the fusion execution are:

    • Input planning (via InputPlanner)
    • Output planning (via OutputPlanner)
    • Vectorization planning (via VectorizationPlanner)
    • Kernel execution (via LaunchPlanExecutor)

This happens because displaying a tensor requires accessing its actual values, which forces the computation of any pending operations in the computation graph.
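
Any other read of the tensor's values has the same effect as printing. For example (a hedged sketch; it assumes Burn's synchronous into_data() method, the counterpart of the into_data_async() call above):

// Reading the data forces the pending stream to drain, just like printing does.
let data = z.clone().into_data();   // fusion + kernel execution happen here
println!("Final result: {}", z);    // the Display impl takes the same path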

What does the drain method do?

The drain method in MultiStream is responsible for executing all pending operations in a specific stream. Here's what it does:

/// Drain a stream
pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
    if let Some(stream) = self.streams.get_mut(&id) {
        let num_executed = stream.queue.global.len();
        stream.processor.process(
            Segment::new(&mut stream.queue, handles),
            &mut self.optimizations,
            ExecutionMode::Sync,
        );
        stream.cursor += num_executed as u64;

        let cleared = self.shared_tensors.on_executed_ops(id, stream);
        self.clear_shared_tensors(&cleared, id);
        let to_drop = self.shared_tensors.clear_tensors(cleared);

        self.drop_shared_tensors(to_drop, handles, id);
    }
}
  1. It finds the stream with the given ID
  2. It processes all operations in the stream's queue in Sync mode (immediate execution)
  3. It updates the stream's cursor to track execution progress
  4. It handles shared tensors that were executed:
    • Clears them from other streams
    • Drops tensors that are no longer needed

How does the processor identify fusable segments?

As you can see in the code above, draining a stream triggers the fusion process. Remember, every stream is composed of a queue and a processor.

pub(crate) struct Stream<R: FusionRuntime> {
    pub(crate) queue: OperationQueue<R>,
    processor: Processor<R::Optimization>,
    pub(crate) cursor: u64,
}

/// Process a [stream segment](StreamSegment) following a [policy](Policy).
pub(crate) struct Processor<O> {
    policy: Policy<O>,
    explorer: Explorer<O>,
}

The workflow is:

  1. Operations are queued in the OperationQueue
  2. The Processor analyzes these operations to find fusion opportunities
  3. The Processor uses the StreamSegment abstraction to access operations in the queue

The key insight is that the processor doesn't directly decide what's fusable. Instead:

  1. The processor coordinates the fusion process
  2. It delegates the actual fusion decisions to optimization builders provided by the runtime
  3. It uses the Policy to decide when to explore, execute, or defer operations
  4. It uses the Explorer to find optimization opportunities

Note: The processor doesn't directly decide what's fusable - it works with the StreamSegment abstraction:

#[derive(new)]
struct Segment<'a, R: FusionRuntime> {
    queue: &'a mut OperationQueue<R>,
    handles: &'a mut HandleContainer<R::FusionHandle>,
}

pub fn process<Segment>(
    &mut self,
    mut segment: Segment,
    store: &mut ExecutionPlanStore<O>,
    mode: ExecutionMode,
) where
    Segment: StreamSegment<O>,
{
    ...
    let action = self.policy.action(store, segment.operations(), mode);

    match action {
        Action::Explore => {
            self.explore(&mut segment, store, mode);

            if self.explorer.is_up_to_date() {
                break;
            }
        }
        Action::Defer => {
            match mode {
                ExecutionMode::Lazy => break,
                ExecutionMode::Sync => panic!("Can't defer while sync"),
            };
        }
        Action::Execute(id) => {
            if let ExecutionMode::Sync = mode {
                store.add_trigger(id, ExecutionTrigger::OnSync);
            }

            segment.execute(id, store);
            self.reset(store, segment.operations());
        }
    };
    ...
}

The Segment type takes exclusive access to the operation queue and exposes its operations through the operations() method. The processor then:

  1. Uses a Policy to decide what action to take (explore, execute, defer)
  2. Uses an Explorer to find optimization opportunities
  3. When an optimization is found, it's stored in the ExecutionPlanStore
/// The policy keeps track of all possible execution plans (ids) for the current operation stream. 
pub(crate) struct Policy<O> {
    /// List of potential execution plans that are compatible with current stream segment
    candidates: Vec<OperationsValidator<ExecutionPlanId>>,
    /// List of candidate execution plans that have been found; we can still keep searching
    /// to potentially find a better one.
    availables: Vec<AvailableItem>,
    /// The found execution plan that should be executed, along with the number of operations
    /// in the plan.
    found: Option<(ExecutionPlanId, usize)>,
    /// The number of operations that have been analyzed
    num_operations: usize,
    _item_type: PhantomData<O>,
}

/// Explore and create new optimization.
pub struct Explorer<O> {
    optimizer: StreamOptimizer<O>,
    num_deferred: usize,
    num_explored: usize,
    is_still_optimizing: bool,
}

Note: Fusability is ultimately determined by the optimization builders provided by the runtime, which analyze operations to see if they can be combined.

==🔴But how exactly does the Explorer find optimization opportunities?==

If querying the policy returns Action::Explore while processing a segment (i.e. a stream or sequence of ordered operations), we enter the exploration branch of the match above. This invokes the Explorer, which contains the StreamOptimizer.

  • The StreamOptimizer's first job is to register (or add) all operations in a stream/segment to a block

How StreamOptimizer registers Ops:

The register_inner method (shown below) attempts to register an operation with existing blocks in the StreamOptimizer. Here's what it does:

impl<O: NumOperations> Explorer<O> {

... 
...
// Explore the provided operations.
    pub(crate) fn explore(
        &mut self,
        operations: &[OperationIr],
        mode: ExecutionMode,
    ) -> ExplorationAction<O> {
        self.update(operations); // This essentially performs Block Op registration via `register_inner`, below

        // Can only continue exploration when not sync.
        if let ExecutionMode::Lazy = mode {
            if self.is_still_optimizing {
                return ExplorationAction::Continue;
            }
        }

        let optimization = self.optimizer.optimize(operations); // post registration, we optimize

        ExplorationAction::Completed(optimization)
    }
...
...
}

impl<O: NumOperations> StreamOptimizer<O> {

...
...

fn register_inner(&mut self, operation: &OperationIr, force: bool) -> usize {
    let mut added_count = 0;
    for block in self.blocks.iter_mut() {
        match block.register(operation, self.length, force) {
            RegistrationResult::Accepted => {
                added_count += 1;
            }
            RegistrationResult::NotPartOfTheGraph => {}
        }
    }
    added_count
}

...
...
}

The process

  1. It iterates through all existing blocks in the StreamOptimizer
  2. For each block, it tries to register the operation by calling block.register()
  3. It passes:
    • The operation to register
    • The current length (position in the stream)
    • A force flag that can override normal registration rules
  4. It counts how many blocks accepted the operation
  5. It returns this count

What "block" means here

A Block is the StreamOptimizer's abstraction for an ordered sequence of operations that can potentially be fused together. Each block:

  1. Contains operations that are related (they use the same tensors)
  2. Tracks the ordering of operations
  3. Maintains a set of optimization builders that analyze the operations
  4. Can determine if operations can be fused

Registering Ops in a Block

The key part is how a block decides whether to accept an operation:

pub fn register(
    &mut self,
    operation: &OperationIr,
    order: usize,
    force: bool,
) -> RegistrationResult {
    if self.ids.is_empty() {
        self.register_op(operation, order);
        return RegistrationResult::Accepted;
    }
    let mut contains = false;
    for node in operation.nodes() {
        contains = self.ids.contains(&node.id);

        if contains {
            break;
        }
    }

    if !contains && !force {
        return RegistrationResult::NotPartOfTheGraph;
    }

    self.register_op(operation, order);
    RegistrationResult::Accepted
}

A block accepts an operation if:

  1. The block is empty (first operation always accepted)
  2. The operation uses tensors that are already in the block
  3. The force flag is true (override normal rules)

Note: calling register_op on a Block<O> ultimately registers the operation with FuseBlockBuilder by adding it to the ops vector in FuseBlockBuilder.

The operation flows through several layers of abstraction, but ultimately ends up in the ops vector of FuseBlockBuilder, which is part of the FuseTraceBuilder.

During the operation registration flow, all FuseBlockBuilder fields get populated simultaneously:

  • self.ops gets the actual operations
  • self.reads gets populated with input read operations
  • resources.inputs/outputs/scalars get populated with tensor/scalar metadata
  • tensor_writes (computed later) is determined by analyzing the populated resources and operations

This is why, by the time block.build() is called, all the information needed to generate the final fused kernel is already available!

Block Operation Registration Flow

Block.register_op
    ↓
OptimizationBuilder.register (trait method)
    ↓
ElementWiseBuilder.register (or other implementation)
    ↓
FuseOptimizationBuilder.register
    ↓
FuseOptimizationBuilder.register_numeric / register_binary_ops / etc.
    ↓
FuseOptimizationBuilder.register_scalar_ops / register_unary_ops / etc.
    ↓
TryFuseBuilder.register
    ↓
FuseTraceBuilder.register_operation
    ↓
FuseBlockBuilder.ops.push

The bigger picture

This method is part of a larger strategy in StreamOptimizer:

  1. When a new operation arrives, it first tries to merge blocks if needed
  2. Then it tries to register the operation with existing blocks using register_inner
  3. If no block accepts it, it creates a new block for the operation
  4. It tracks how many blocks it has and may stop optimizing if it exceeds max_blocks

The goal is to group operations into blocks that can be optimized together, while maintaining the correct execution order and dependencies between operations.

This approach allows the system to:

  1. Find fusion opportunities within each block
  2. Handle complex streams with multiple independent fusion groups
  3. Maintain correct execution semantics
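
Here is a condensed sketch of that grouping strategy (simplified types, not the real StreamOptimizer; unlike register_inner, this version stops at the first block that accepts the op). An operation joins an existing block if it touches a tensor the block already knows about, otherwise it starts a new block:

use std::collections::BTreeSet;

type TensorId = u32;

// A toy operation: just the tensor ids it touches (inputs and outputs).
struct OpIr {
    nodes: Vec<TensorId>,
}

#[derive(Default)]
struct ToyBlock {
    ids: BTreeSet<TensorId>,
    ops: Vec<usize>, // positions in the stream
}

fn register(blocks: &mut Vec<ToyBlock>, op: &OpIr, pos: usize) {
    // Try existing blocks first: accept if the op shares any tensor with the block.
    for block in blocks.iter_mut() {
        if op.nodes.iter().any(|id| block.ids.contains(id)) {
            block.ids.extend(op.nodes.iter().copied());
            block.ops.push(pos);
            return;
        }
    }
    // No block accepted it: start a new one.
    let mut block = ToyBlock::default();
    block.ids.extend(op.nodes.iter().copied());
    block.ops.push(pos);
    blocks.push(block);
}

fn main() {
    let stream = [
        OpIr { nodes: vec![0, 1] },  // temp1 = tensor_1 * 2.0
        OpIr { nodes: vec![1, 2] },  // y = temp1 + 1.0
        OpIr { nodes: vec![2, 3] },  // z = tanh(y)
        OpIr { nodes: vec![9, 10] }, // unrelated op -> its own block
    ];
    let mut blocks = Vec::new();
    for (pos, op) in stream.iter().enumerate() {
        register(&mut blocks, op, pos);
    }
    println!("{} blocks", blocks.len()); // 2
}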

Block Optimizer and Block Optimization Process

To Reiterate:

  • Each Block<O> contains operations that could potentially be fused
  • The optimize() method in Block<O> finds the best optimization strategy:
  • Each block also contains a set of  OptimizationBuilder instances that analyze operations
    • Example: for element-wise operations, the ElemwiseOptimizationBuilder recognizes patterns it can fuse

Here’s the flow from a StreamOptimizer all the way to producing a FuseTrace.

The Full Optimization Flow

Explorer.explore
    ↓
StreamOptimizer.optimize
    ↓
BlocksOptimizer.optimize
    ↓
Block.optimize
    ↓
find_best_optimization_index
    ↓
OptimizationBuilder.build (trait method)
    ↓
FuseOptimizationBuilder.build
    ↓
TryFuseBuilder.build
    ↓
FuseTraceBuilder.build
    ↓
FuseTrace is created

Detailed Explanation

1. StreamOptimizer.optimize

pub fn optimize(&self, operations: &[OperationIr]) -> BlockOptimization<O> {
    let result = BlocksOptimizer::new(self.blocks.clone()).optimize();
    
    match result {
        BlocksOptimizerResult::Full(optimization) => optimization,
        BlocksOptimizerResult::WithHoles { strategies, ordering, holes } => {
            // Handle holes case...
        }
    }
}

The StreamOptimizer creates a BlocksOptimizer with its blocks and calls optimize().

2. BlocksOptimizer.optimize

pub fn optimize(mut self) -> BlocksOptimizerResult<O> {
    self = self.merging_pass();

    let mut strategies = Vec::with_capacity(self.blocks.len());
    let mut ordering = Vec::new();
    let mut blocks = Vec::new();
    core::mem::swap(&mut blocks, &mut self.blocks);

    for block in blocks {
        match self.optimize_block(block, &mut ordering) {
            BlockOptimizationStep::Contiguous { strategy } => {
                strategies.push(Box::new(strategy));
            }
            // Handle other cases...
        }
    }

    // Create and return BlockOptimization
}

/// Optimize a single block.
fn optimize_block(
    &mut self,
    block: Block<O>,
    ordering: &mut Vec<usize>,
) -> BlockOptimizationStep<O> {
    let last_index = block.end_pos;
    let mut block_optimization = block.optimize();
    let opt_size = block_optimization.ordering.len();
    ...
}

The BlocksOptimizer:

  1. Tries to merge blocks that can be combined
  2. Processes each block to create optimization strategies
  3. Combines these strategies into a final BlockOptimization

3. Block.optimize

pub fn optimize(mut self) -> BlockOptimization<O> {
    match find_best_optimization_index(&mut self.builders) {
        Some(index) => {
            let opt = self.builders[index].build();
            let opt_len = opt.len();
            if opt_len < self.operations.len() {
                self.ordering.drain(opt_len..);
            }

            let strategy = ExecutionStrategy::Optimization {
                ordering: Arc::new(self.ordering.clone()),
                opt,
            };
            BlockOptimization::new(strategy, self.ordering)
        }
        None => {
            let strategy = ExecutionStrategy::Operations {
                ordering: Arc::new(self.ordering.clone()),
            };
            BlockOptimization::new(strategy, self.ordering)
        }
    }
}

The Block.optimize method:

  1. Finds the best optimization builder using find_best_optimization_index
  2. Calls build() on that builder to create the optimization
  3. Creates an ExecutionStrategy with the optimization
  4. Returns a BlockOptimization with the strategy and ordering

4. find_best_optimization_index

fn find_best_optimization_index<O>(
    optimizations: &mut [Box<dyn OptimizationBuilder<O>>],
) -> Option<usize> {
    let mut best_index = None;
    let mut best_score = 0;

    for (i, optimization) in optimizations.iter().enumerate() {
        let properties = optimization.properties();

        if properties.ready && properties.score >= best_score {
            best_index = Some(i);
            best_score = properties.score;
        }
    }

    best_index
}

This function:

  1. Examines all optimization builders

  2. Finds the one with the highest score that is ready

  3. Returns its index

How Optimization Builder Scores Are Populated

The scores are populated during the registration process as operations are added to each builder. Here's the complete flow:

1. Registration Flow

When a Block receives operations, it calls:

  • Block.register() → Block.register_op() → builder.register(operation) for each builder

2. Score Calculation During Registration

Each optimization builder calculates its score based on how many operations it successfully accepts:

fn properties(&self) -> OptimizationProperties {
    let ready = self.num_ops > 0;

    OptimizationProperties {
        ready,
        score: self.num_ops as u64,  // Score = number of operations accepted
    }
}
3. Score Updates During Registration

When FuseOptimizationBuilder.register() is called:

fn register(&mut self, operation: &OperationIr) {
    // ... operation type checking ...
    
    if !self.register_numeric::<i32>(ops) {
        self.status = OptimizationStatus::Closed;  // Builder rejects future ops
        return;
    }
    
    self.status = OptimizationStatus::Open;
    self.num_ops += 1;  // This increments the score!
}
4. Different Builder Scoring Strategies

Different builders have different scoring strategies:

FuseOptimizationBuilder: Score = num_ops (number of operations it can fuse)

ReduceBuilder: Score = base_score + 1 if it has a reduce operation

fn properties(&self) -> burn_fusion::OptimizationProperties {
    let mut properties = self.builder.properties();

    if self.reduce.is_some() {
        properties.ready = true;
        properties.score += 1;  // Bonus for having reduce
    } else {
        properties.ready = false;
    };

    properties
}

MatmulBuilder: Score = base_score + 1 (bonus for matmul operations)

fn properties(&self) -> burn_fusion::OptimizationProperties {
    let mut properties = self.builder.properties();
    properties.score += 1;  // Bonus for matmul
    properties
}
5. Best Builder Selection

Finally, find_best_optimization_index picks the builder with the highest score:

fn find_best_optimization_index<O>(
    optimizations: &mut [Box<dyn OptimizationBuilder<O>>],
) -> Option<usize> {
    let mut best_index = None;
    let mut best_score = 0;

    for (i, optimization) in optimizations.iter().enumerate() {
        let properties = optimization.properties();

        if properties.ready && properties.score >= best_score {
            best_index = Some(i);
            best_score = properties.score;
        }
    }

    best_index
}
Summary

The scores are populated incrementally during operation registration:

  1. Each operation gets registered with all builders in the block
  2. Each builder decides if it can handle the operation
  3. If accepted: num_ops++ which increases the score
  4. If rejected: Builder status becomes Closed and stops accepting operations
  5. Specialized builders (Matmul, Reduce) get bonus points for their specific operations
  6. Best builder is selected based on highest score when Block.optimize() is called

This allows the fusion system to automatically choose the most effective optimization strategy based on which builder can fuse the most operations together.
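
The scoring and selection logic can be condensed into a tiny standalone example (a simplified stand-in for OptimizationBuilder, not Burn's actual trait):

// Simplified stand-in for OptimizationProperties.
struct Properties {
    ready: bool,
    score: u64,
}

trait Builder {
    fn properties(&self) -> Properties;
}

// An element-wise builder that accepted 3 ops in this stream.
struct Elemwise { num_ops: u64 }
impl Builder for Elemwise {
    fn properties(&self) -> Properties {
        Properties { ready: self.num_ops > 0, score: self.num_ops }
    }
}

// A matmul-style builder that accepted nothing, but would get a +1 bonus.
struct Matmul { num_ops: u64 }
impl Builder for Matmul {
    fn properties(&self) -> Properties {
        Properties { ready: self.num_ops > 0, score: self.num_ops + 1 }
    }
}

// Same selection rule as find_best_optimization_index: highest ready score wins.
fn best(builders: &[Box<dyn Builder>]) -> Option<usize> {
    let mut best = None;
    let mut best_score = 0;
    for (i, b) in builders.iter().enumerate() {
        let p = b.properties();
        if p.ready && p.score >= best_score {
            best = Some(i);
            best_score = p.score;
        }
    }
    best
}

fn main() {
    let builders: Vec<Box<dyn Builder>> =
        vec![Box::new(Elemwise { num_ops: 3 }), Box::new(Matmul { num_ops: 0 })];
    println!("{:?}", best(&builders)); // Some(0): the element-wise builder wins
}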

5. OptimizationBuilder.build

This is a trait method implemented by various builders. For example, ElementWiseBuilder.build():

fn build(&self) -> CubeOptimization<R> {
    let client = R::client(&self.device);
    let trace = self.builder.build();
    let elementwise =
        ElemwiseOptimization::<R>::new(trace, client, self.device.clone(), self.len());

    CubeOptimization::ElementWise(elementwise)
}

The builder calls build() on its internal builder (typically a FuseOptimizationBuilder).

6. FuseOptimizationBuilder.build

fn build(&self) -> FuseTrace {
    self.builder.build(self.current_output_shape.clone())
}

This simply forwards to TryFuseBuilder.build().

7. TryFuseBuilder.build

fn build(&self, shape: Vec<usize>) -> FuseTrace {
    self.builder.build(shape)
}

This forwards to FuseTraceBuilder.build().

8. FuseTraceBuilder.build

pub fn build(&self, shape_ref: Vec<usize>) -> FuseTrace {
    let mut resources = self.resources.clone();
    let mut outputs = RegisteredTensors::default();
    let mut blocks = Vec::new();

    let mut register_block =
        |block: &FuseBlockBuilder, shape_ref: &Vec<usize>, offset: usize| {
            let (block, block_tensor_writes) =
                block.build(&self.resources, shape_ref.clone(), offset);
            blocks.push(block);

            let num_outputs = block_tensor_writes.len();
            for (ir, precision) in block_tensor_writes.into_iter() {
                outputs.insert(precision, ir);
            }

            num_outputs
        };

    let mut offset = 0;

    for (block, shape_ref) in self.blocks_previous.iter() {
        offset += register_block(block, shape_ref, offset);
    }
    register_block(&self.block_current, &shape_ref, offset);

    resources.outputs = outputs;

    FuseTrace { blocks, resources }
}

This is where the FuseTrace is actually created:

  1. It clones the current resources
  2. It processes each block using FuseBlockBuilder.build()
  3. It collects all blocks and their outputs
  4. It creates a FuseTrace with the blocks and resources

9. FuseBlockBuilder.build

pub fn build(
    &self,
    resources: &FuseResources,
    shape_ref: Vec<usize>,
    offset: usize,
) -> (FuseBlock, RegisteredTensors) {
    let ops = self.ops.clone();
    let reads = self.reads.clone();
    let tensor_writes = self.tensor_writes(resources);

    let mut writes = BTreeMap::new();

    // Process writes...

    (
        FuseBlock {
            settings: self.settings,
            ops,
            shape_ref,
            reads,
            writes,
        },
        tensor_writes,
    )
}

This creates a FuseBlock with:

  1. The operations that have been registered
  2. The reads and writes for each tensor
  3. The shape reference and settings

How FuseTrace is Produced

To summarize how a FuseTrace is produced:

  1. Block Collection: The StreamOptimizer collects blocks of operations
  2. Block Optimization: Each block is optimized using the best available builder
  3. Builder Selection: The best builder is selected based on its score
  4. Trace Building: The selected builder builds a trace by:
    • Processing each block to create a FuseBlock
    • Collecting all blocks and their resources
    • Creating a FuseTrace with the blocks and resources

The key is that by the time build() is called, all operations have already been registered with the builders. The build() method doesn't register new operations - it processes the already-registered operations to create an optimized execution plan.

The resulting FuseTrace contains:

  1. A list of FuseBlocks, each with its operations, reads, and writes
  2. Resources including inputs, outputs, and scalars
  3. Everything needed to execute the fused operations efficiently

This trace can then be executed by a runtime to perform the fused operations efficiently.

How CubeCL Kernels Are Generated

  • WIP
$ cargo expand --manifest-path crates/burn-cubecl-fusion/Cargo.toml --lib elemwise::optimization

From what I can tell, fused kernels aren’t generated on the fly or at runtime—they’re statically defined and precompiled. They are like generic templates that can handle arbitrary sequences of operations.

For example, the elemwise_fuse kernel below appears to handle any sequence of elementwise ops:

#[cube(launch_unchecked)]
fn elemwise_fuse(
    inputs: &GlobalArgs,
    outputs: &mut GlobalArgs,
    #[comptime] config: &FuseBlockConfig,
) {
    let values = Registry::<Arg, Line<f32>>::new();
    let args = comptime![Sequence::<Arg>::new()];
    let pos = ABSOLUTE_POS;

    let mut locals = init_locals(inputs, outputs, config);
    let length = ref_len(inputs, outputs, &locals, config);

    if pos < length {
        fuse_on_write::<f32>(inputs, outputs, &mut locals, pos, values, args, config)
    }
}

As I understand it, the flow looks something like this:

1. A FuseTrace is recorded with a sequence of operations like:

ops: [
    Assign(input -> local_0),
    Mul(local_0, scalar_0 -> local_1),
    Add(local_1, scalar_1 -> local_2),
    Tanh(local_2 -> local_3),
]
2. At kernel launch time, this sequence is passed in as part of the `FuseBlockConfig`.

3. The actual kernel uses a generic fuse function that looks like:
#[cube]
fn fuse(..., #[comptime] config: &FuseBlockConfig) {
    #[unroll]
    for index in 0..config.ops.len() {
        let op = comptime! { config.ops.index(index).clone() };
        match op {
            FuseOp::Mul(op) => mul(...),
            FuseOp::Add(op) => add(...),
            FuseOp::Tanh(op) => tanh(...),
            // etc.
        }
    }
}

So there's no ad-hoc runtime code generation involved; rather, the generic kernel is specialized at JIT-compile time for the recorded op sequence via #[comptime] and #[unroll].
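
A CPU-side analogue of what the specialized kernel ends up doing per element might look like this (plain Rust, purely to illustrate the unrolled op sequence; on the GPU the loop is unrolled at JIT time via #[comptime] and #[unroll]):

// A tiny interpreter for a fused op sequence, mirroring the FuseOp match in
// the generic fuse function above (names simplified).
enum FuseOp {
    MulScalar(f32),
    AddScalar(f32),
    Tanh,
}

fn fuse_one_element(mut value: f32, ops: &[FuseOp]) -> f32 {
    for op in ops {
        value = match op {
            FuseOp::MulScalar(s) => value * *s,
            FuseOp::AddScalar(s) => value + *s,
            FuseOp::Tanh => value.tanh(),
        };
    }
    value
}

fn main() {
    let ops = [FuseOp::MulScalar(2.0), FuseOp::AddScalar(1.0), FuseOp::Tanh];
    // z = tanh(2.0 * 2.0 + 1.0)
    println!("{}", fuse_one_element(2.0, &ops));
}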

Miscellaneous:

Processing non-fusable operations

When operations can't be fused:

  1. The processor still processes them through the same flow
  2. The Explorer will fail to find optimizations
  3. The operations will be executed individually using the ExecutionStrategy::Operations strategy
pub(crate) enum ExecutionStrategy<O> {
    /// An optimization was found, and therefore should be executed.
    Optimization { opt: O, ordering: Arc<Vec<usize>> },
    /// No optimization was found, each operation should be executed individually.
    Operations { ordering: Arc<Vec<usize>> },
    /// A composition of multiple execution strategies.
    Composed(Vec<Box<Self>>),
}

Cross-stream fusion

Operations across multiple streams are not directly fused. The system is designed to:

  1. Keep streams independent for concurrent execution
  2. Handle shared tensors between streams
  3. Synchronize streams when necessary

The MultiStream type manages these interactions through methods like resolve_streams and merge_streams_timelines, which ensure proper ordering when streams share tensors, but it doesn't attempt to fuse operations across different streams.

When a tensor is shared between streams, the system ensures proper execution order by draining dependent streams when needed, but fusion only happens within individual streams.

Search algorithms for:

  • What is the correct tile size?
  • Is it good to unroll this loop or not?