-
Generic Parameters: Tensor<B, D, K>
- `B: Backend` - execution backend (composable)
- `D: usize` - dimensionality (compile-time constant)
- `K: TensorKind` - element type (`Float`, `Int`, `Bool`)
-
Backend Composition
Tensor<Fusion<CubeBackend<WgpuRuntime, f32, i32, u32>>, 2, Float>

This shows that backends are composable, with `Fusion` wrapping `CubeBackend` for kernel fusion capabilities.
-
Client-Server Architecture: Burn has dual client-server abstractions.
- Fusion level: `MutexFusionClient` wraps `FusionServer` with a mutex; `FusionServer` owns a `MultiStream` with operation queues
- CubeCL level: `ComputeClient` communicates with `ComputeServer` - a similar pattern for the actual GPU operations
-
High-Level Tensor Allocation Flow: the allocation chain is as follows:
reserve() → pool.alloc() → create_page() → WgpuStorage.alloc() → wgpu::Device::create_buffer()
- Operation Registration: tensor creation itself is a `NoOp` for the fusion queue. The tensor allocation happens immediately in the CubeCL layer, but fusion operations are queued.
🎯 Key Points
- Layered Abstractions: Fusion and CubeCL each have their own client-server patterns
- Lazy Pooling: ring buffer reuse before new allocations
- Separation of Concerns: Fusion handles operation batching/fusion, CubeCL handles actual GPU execution
📝 Summary
Burn's architecture relies on:
- The generic type system
- Backend composability
- Dual client-server abstractions
- Custom memory management
-
The standard `Tensor` type in Burn is generic over:
- a `Backend` - `B`:
  - Some examples of backends are `CubeBackend`, `Fusion`, etc. `CubeBackend` is an abstraction built on top of the `cubecl` crate, which targets multiple runtimes or GPU programming APIs (CUDA, WebGPU, ROCm, etc.).
  - Backends are composable - for instance, I can have a tensor that uses the `CubeBackend` like this:
// Tensor<B, D, K> - B, D and K are generic type parameters
Tensor<Fusion<CubeBackend<WgpuRuntime, f32, i32, u32>>, 2, Float>
  - i.e. this is a `CubeBackend` that uses the `wgpu` runtime or API, wrapped in a `Fusion` backend, meaning it is also capable of kernel fusion.
- the dimensionality - `D`
- the element type - `K`. `K` can be anything that implements the `TensorKind` trait, i.e. `Float`, `Int`, `Bool`.
-
The simplest API used to create a `Tensor` is `from_data`:
// Type alias for the backend to use.
pub type Wgpu<F = f32, I = i32, B = u32> =
    burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;
type Backend = Wgpu;

let device = Default::default();

// Creation of a tensor with explicit values.
let tensor_1 = Tensor::<Backend, 2, Float>::from_data([[2., 3.], [4., 5.]], &device);
- But this single line of code has a lot packed into it.
let tensor_1 = Tensor::<Backend, 2, Float>::from_data([[2., 3.], [4., 5.]], &device);
- Essentially, this means we want to allocate a `Tensor` of `Float`s using `Wgpu`. `from_data` invokes one of the `FloatTensorOps` methods, called `float_from_data`, on our target `Fusion` backend (`Fusion<CubeBackend<Wgpu … >` in our case) to perform the allocation.
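To make the dispatch pattern concrete, here is a minimal, self-contained sketch of the same idea using hypothetical names (`MyTensor`, `FloatOps`, `CpuBackend`) rather than Burn's actual traits: the user-facing type is generic over a backend trait, and the constructor simply forwards to the backend's associated op.

```rust
use std::marker::PhantomData;

// Hypothetical trait playing the role of a backend's tensor ops.
trait FloatOps {
    type Primitive;
    fn float_from_data(data: Vec<f32>) -> Self::Primitive;
}

struct MyTensor<B: FloatOps> {
    primitive: B::Primitive,
    _backend: PhantomData<B>,
}

impl<B: FloatOps> MyTensor<B> {
    fn from_data(data: Vec<f32>) -> Self {
        // The front-end does no work itself; the backend decides how
        // (and where) the allocation actually happens.
        Self { primitive: B::float_from_data(data), _backend: PhantomData }
    }
}

// A toy "backend" that just keeps the data on the CPU.
struct CpuBackend;
impl FloatOps for CpuBackend {
    type Primitive = Vec<f32>;
    fn float_from_data(data: Vec<f32>) -> Self::Primitive {
        data
    }
}

fn main() {
    let t = MyTensor::<CpuBackend>::from_data(vec![2., 3., 4., 5.]);
    println!("{:?}", t.primitive);
}
```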
-
The `Fusion` backend in Burn uses a client-server abstraction, i.e. a `FusionClient` and a `FusionServer`.
- `FusionClient`: a concrete example of a fusion client is a `MutexFusionClient`.
  - This simply wraps a `FusionServer` in an `Arc<Mutex<FusionServer<R>>>`, i.e. it uses a mutex to communicate with the fusion server.
  - The job of the `FusionClient` (or `MutexFusionClient` in this case) is to register or add new tensor operations (actually `OperationIr`, a Burn-specific representation of an operation) to a stream's `OperationQueue`.
pub struct MutexFusionClient<R: FusionRuntime> {
    server: Arc<Mutex<FusionServer<R>>>,
    device: FusionDevice<R>,
}
- `FusionServer`: this is a data structure that owns what's called a `MultiStream`.
  - A multistream is the thing that holds a hashmap of streams. Each stream in this hashmap is a list of queued operations in an `OperationQueue` and a `Processor` to process those operations.

Note: all of these types are generic over a backend's runtime (i.e. the target GPU API), which is `Wgpu` in our case. Additionally, allocating tensors in GPU memory is considered a `NoOp` for the fusion queue.
// In this example, R is FusionCubeRuntime which implements FusionRuntime - a runtime abstraction for the Fusion
// backend.
pub struct FusionServer<R: FusionRuntime> {
streams: MultiStream<R>,
pub(crate) handles: HandleContainer<R::FusionHandle>,
}
↓
/// Keep track of multiple concurrent streams of operations.
pub struct MultiStream<R: FusionRuntime> {
streams: HashMap<StreamId, Stream<R>>,
optimizations: ExecutionPlanStore<R::Optimization>,
device: R::FusionDevice,
}
↓
struct Stream<R: FusionRuntime> {
queue: OperationQueue<R>,
processor: Processor<R::Optimization>,
}

- How does memory/buffer allocation happen for our Tensor?
- There are two different abstractions here:
  - Fusion backend abstraction: the `Fusion` backend uses a client-server abstraction to aid with kernel fusion.
  - CubeBackend abstraction: the actual allocation for the tensor on the GPU is done via the wrapped backend, `Fusion<CubeBackend … >`.
- In our case, that's `CubeBackend`, which in turn relies on a CubeCL runtime (`Fusion<CubeBackend<WgpuRuntime … >`) to do the real work - `WgpuRuntime`.
- Note: CubeCL itself has a similar client-server abstraction akin to the Fusion backend. In CubeCL, they're called `ComputeClient` and `ComputeServer`.
# The Fusion backend invokes the wrapped CubeBackend, which in turn calls on WgpuRuntime to do the actual job, i.e. tensor allocation.
# Here's the approx flow
Tensor::from_data()
↓
Fusion<CubeBackend ... >::float_from_data()
↓
CubeBackend::float_from_data() (this simply directs the call to CubeCL runtime)
↓
burn_cubecl::ops::base::from_data() (a proxy for instantiating CubeCL types)
↓
WgpuRuntime::client() (instantiates a CubeCL compute client, in our case we instantiate a Wgpu compute client)
↓
ComputeClient<WgpuServer, MutexComputeChannel>::create() (which given a resource like a buffer, stores it and returns the resource handle.)
↓
MutexComputeChannel::create() (if the compute client is using mutexes to communicate with a Compute Server)
↓
WgpuServer::create() (assuming we're using Wgpu as our backend)
↓
MemoryManager::reserve()
↓
MemoryManagement<WgpuStorage>::reserve()
↓
SlicedPool::alloc() (if using sliced pools)
↓
SlicedPool::create_page() (if new page needed)
↓
WgpuStorage::alloc()
↓
wgpu::Device::create_buffer() (actual GPU allocation)
↓
StorageHandle created with Storage ID
↓
ID pushed to SlicedPool ring buffer
↓
ID returned

- The `reserve` method is the entry point that triggers this entire allocation chain, but it's designed to be lazy and pooled - it tries to reuse existing allocations from the ring buffer before creating new GPU buffers (see the toy sketch below).
- All these abstraction layers allow CubeCL to efficiently manage GPU memory while hiding the complexity from higher-level tensor operations.
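To illustrate just the "reuse before allocate" idea, here is a toy free-list sketch (hypothetical `ToyPool`, not CubeCL's actual `SlicedPool` or memory manager): a reservation first looks for a previously released page of sufficient size, and only falls back to a fresh allocation when none is available.

```rust
use std::collections::VecDeque;

struct ToyPool {
    // Released pages available for reuse: (id, size).
    free: VecDeque<(usize, usize)>,
    next_id: usize,
}

impl ToyPool {
    fn new() -> Self {
        Self { free: VecDeque::new(), next_id: 0 }
    }

    fn reserve(&mut self, size: usize) -> usize {
        // Try to reuse an existing page first.
        if let Some(pos) = self.free.iter().position(|(_, s)| *s >= size) {
            let (id, _) = self.free.remove(pos).unwrap();
            return id;
        }
        // Otherwise "allocate" a new page (this is where the real stack would
        // reach WgpuStorage::alloc / wgpu::Device::create_buffer).
        let id = self.next_id;
        self.next_id += 1;
        id
    }

    fn release(&mut self, id: usize, size: usize) {
        self.free.push_back((id, size));
    }
}

fn main() {
    let mut pool = ToyPool::new();
    let a = pool.reserve(1024); // fresh allocation
    pool.release(a, 1024);
    let b = pool.reserve(512);  // reuses the released page
    assert_eq!(a, b);
}
```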
- My understanding is that memory allocation and data transfer are handled via DMA, either over PCIe (for discrete GPUs) or through a proprietary on-die interconnect (in the case of integrated GPUs or Apple Silicon). This process does not go through the GPU's compute pipeline or trigger a kernel launch. Instead, the data transfer is typically handled by a dedicated copy engine via a `MemCopy` command in the GPU's command queues.
  - Yep, but it is still on the same stream on CUDA, and I'm unsure if it's linked to a command encoder in wgpu.
- For `wgpu`, I see a call being made to `queue.write_buffer(&buffer, 0, bytemuck::cast_slice(&data));`, which stages the data into CPU-accessible memory, internally records a copy command, and submits it immediately to the GPU without needing to manually create a `CommandEncoder` or submit a `CommandBuffer` (sketched below).
  - I believe the command encoder method is used for large data transfers.
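A minimal sketch of that small-upload path, assuming a `wgpu` `device` and `queue` already exist; the function name and the buffer usage flags here are illustrative choices, not anything Burn-specific.

```rust
fn upload_f32(device: &wgpu::Device, queue: &wgpu::Queue, data: &[f32]) -> wgpu::Buffer {
    let buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("tensor-data"),
        size: (data.len() * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Stages the bytes into CPU-visible memory and lets wgpu record/submit the
    // copy internally - no explicit CommandEncoder on our side.
    queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
    buffer
}
```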
Trace through how kernel fusion actually works for our three operations.
RUST_LOG=burn_fusion=trace,cubecl_wgpu::runtime=trace,wgpu_hal::metal=trace cargo run --example burn-test --release --features wgpu
let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device); // NoOp
let y = tensor_1.clone() * 2.0 + 1.0; // ScalarMul + ScalarAdd operations
let z = y.tanh(); // Tanh operation

The log output shows 4 different kernels being generated during the fusion demonstration because of how the Burn framework's fusion system works with the Wgpu backend:
-
Initial Setup Kernels (2): The first two kernels are part of the WGPU backend initialization:
- First kernel: handles buffer setup and validation (dispatch workgroup size validation)
- Second kernel: handles metadata and indirect buffer operations ("indirect dispatches", where dispatch parameters come from GPU buffers)
-
Fused Operation Kernel:
[2025-06-30T07:48:56Z DEBUG wgpu_hal::metal::device] Naga generated shader for entry point 'elemwise_fuse' and stage Compute
- This is the key kernel that demonstrates fusion. Notice how it combines all three operations (multiply by 2.0, add 1.0, and tanh) into a single kernel. In the code, you can see:
metal::float2 l_10_ = buffer_0_global[id];
float _e66 = scalars_f32_.inner[0];
metal::float2 l_13_ = l_10_ * _e66;                 // multiply by 2.0
float _e70 = scalars_f32_.inner[1];
metal::float2 l_16_ = l_13_ + metal::float2(_e70);  // add 1.0
metal::float2 _e73 = safe_tanh(l_16_);              // tanh
- The log confirms this fusion:
[2025-06-30T07:48:56Z TRACE burn_fusion::stream::store::base] New execution plan 1 - Operations: 3 - Triggers 1
-
Slice Kernel:
[2025-06-30T07:48:56Z DEBUG wgpu_hal::metal::device] Naga generated shader for entry point 'slice_kernel' and stage Compute
- This kernel handles the final data extraction and formatting for display when you print the result.
-
The fusion demonstration works because Burn's fusion system queues operations without executing them immediately. When you call:
let temp = tensor_1.clone() * 2.0;
let y = temp + 1.0;
let z = y.tanh();
These operations are recorded but not executed until you force execution by printing the result. At that point, the system optimizes by fusing all three operations into a single kernel execution, which is more efficient than running three separate kernels.
let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device); // NoOp
let y = tensor_1.clone() * 2.0 + 1.0; // ScalarMul + ScalarAdd operations
let z = y.tanh(); // Tanh operation

==🔴 But before we jump in, let's look at prerequisites, i.e. what we would need to generate fast, performant GPU kernels==
In the context of fusion optimization, "reads" and "writes" refer to memory access patterns that track how operations interact with tensors.
pub struct FuseBlockBuilder {
pub settings: FuseBlockSettings,
pub ops: Vec<FuseOp>,
pub reads: BTreeMap<TensorId, (FusePrecision, LayoutInfo)>,
pub writes: BTreeMap<TensorId, FusePrecision>,
pub tensor_writes: BTreeMap<TensorId, FusePrecision>,
}

The `reads` field is a map that tracks which tensors are read by operations in the block:
- Key: `TensorId` - identifies a specific tensor
- Value: `(FusePrecision, LayoutInfo)` - specifies:
  - The floating-point precision of tensor elements (e.g., `F32`, `F16`)
  - Layout information about how the tensor is accessed
When an operation needs to read a tensor, it's registered in this map. This helps the optimizer understand which tensors need to be loaded from memory.
The writes field tracks which tensors are written to by operations in the block:
- Key: `TensorId` - identifies a specific tensor
- Value: `FusePrecision` - the precision of the written data
When an operation produces a result that needs to be stored in a tensor, it's registered in this map. This helps the optimizer understand which tensors need to be stored back to memory.
The tensor_writes field specifically tracks tensors that need to be written to global memory (as opposed to temporary results that can stay in registers or shared memory).
-
Memory Access Optimization: By knowing which tensors are read and written, the optimizer can:
- Minimize global memory accesses
- Keep intermediate results in faster memory (registers or shared memory)
- Coalesce memory operations for better performance
-
Data Dependency Analysis: Tracking reads and writes helps identify:
- Which operations depend on each other
- Which operations can be executed in parallel
- Which operations can be fused together
-
Memory Allocation: It helps determine:
- How much memory is needed for intermediate results
- When memory can be reused
- When memory needs to be allocated
Let's look at how reads and writes are tracked for our example:
let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
let y = tensor_1.clone() * 2.0 + 1.0;
let z = y.tanh();

For `tensor_1 * 2.0`:
- Reads:
  - `tensor_1` is read from global memory
  - The scalar `2.0` is read from a constant
- Writes:
  - An intermediate result (let's call it `temp1`) is written

For `temp1 + 1.0`:
- Reads:
  - `temp1` is read (but this is an intermediate result, not from global memory)
  - The scalar `1.0` is read from a constant
- Writes:
  - Another intermediate result (let's call it `y`) is written

For `y.tanh()`:
- Reads:
  - `y` is read (again, an intermediate result)
- Writes:
  - The final result `z` is written to global memory
The optimizer recognizes that temp1 and y are intermediate results that don't need to be written to global memory. Instead, they can be kept in registers or shared memory. Only z needs to be written to global memory.
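A toy version of that dataflow rule (the `TensorId`, `Op`, and `global_writes` names here are illustrative, not Burn's types): any tensor that is produced inside the block but never consumed by a later operation is a logical output and must be written to global memory.

```rust
use std::collections::BTreeSet;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct TensorId(u32);

struct Op {
    reads: Vec<TensorId>,
    writes: TensorId,
}

fn global_writes(ops: &[Op]) -> BTreeSet<TensorId> {
    let written: BTreeSet<TensorId> = ops.iter().map(|op| op.writes).collect();
    let read: BTreeSet<TensorId> =
        ops.iter().flat_map(|op| op.reads.iter().copied()).collect();
    // Produced but never consumed inside the block => must go to global memory.
    written.difference(&read).copied().collect()
}

fn main() {
    let (t1, temp1, y, z) = (TensorId(0), TensorId(1), TensorId(2), TensorId(3));
    let ops = [
        Op { reads: vec![t1], writes: temp1 },  // temp1 = t1 * 2.0
        Op { reads: vec![temp1], writes: y },   // y = temp1 + 1.0
        Op { reads: vec![y], writes: z },       // z = tanh(y)
    ];
    // Only `z` survives the analysis; temp1 and y stay in registers.
    assert_eq!(global_writes(&ops), BTreeSet::from([z]));
}
```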
The `reads` field:
- Tracks all tensor reads - both from global memory AND intermediate results
- Maps `TensorId → Vec<FuseOp>`, showing which operations read each tensor

The `writes` field:
- Contains only the final global memory writes
- Maps `TensorId → FuseOp` for tensors that need to be written to global memory

The `tensor_writes()` method:
- Analyzes the dataflow to determine which intermediate results actually need to be written to global memory
- Filters out intermediate results that are only used within the kernel
let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
let y = tensor_1.clone() * 2.0 + 1.0; // temp1 = tensor_1 * 2.0, y = temp1 + 1.0
let z = y.tanh();                  // z = tanh(y)

- `tensor_1 * 2.0`: `reads[tensor_1.id]` gets `FuseOp::Assign(Input → Local(0))` - creates intermediate `temp1` (`Local(1)`)
- `temp1 + 1.0`: `reads[temp1.id]` gets the operation that reads `temp1` - creates intermediate `y` (`Local(2)`)
- `y.tanh()`: `reads[y.id]` gets the operation that reads `y` - creates final result `z` (`Local(3)`)
The tensor_writes() method analyzes:
- `temp1`: only read by the `+ 1.0` operation → not written to global memory
- `y`: only read by the `tanh()` operation → not written to global memory
- `z`: final result, needs to persist → written to global memory
So writes only contains:
writes[z.id] = FuseOp::Assign(Local(3) → Output(0))

The tensor_writes() method performs dataflow analysis:
// All output tensors that are never read by a following operation should be written to
// since they are essentially the "logical" output of the shader.
for entry in local_tensor_ids_output {
let is_read = local_tensor_ids_input.contains(&entry);
if !is_read
&& !self.local_outputs.contains(&entry.0)
&& !resources.dropped.contains(&entry.0)
{
// This tensor is produced but never consumed → write to global memory
result.insert(precision, tensor.clone());
}
}

- `reads`: tracks ALL tensor reads (global + intermediate)
- `ops`: tracks computation operations
- `tensor_writes()`: analyzes which results need global memory writes
- `writes`: contains only the global memory write operations
The distinction is between:
- Intermediate writes (kept in registers/local memory, not explicitly tracked as "writes")
- Global memory writes (tracked in `writes`, determined by `tensor_writes()`)
By carefully tracking reads and writes, the fusion optimizer can minimize memory accesses and maximize computational efficiency, resulting in faster execution of neural network operations.
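As a rough CPU-side analogy of what this buys us (plain Rust, not GPU code, with made-up `unfused`/`fused` helpers): the unfused version materializes two intermediate buffers, while the fused version streams each element through all three ops in a single pass.

```rust
fn unfused(x: &[f32]) -> Vec<f32> {
    let temp: Vec<f32> = x.iter().map(|v| v * 2.0).collect(); // extra buffer
    let y: Vec<f32> = temp.iter().map(|v| v + 1.0).collect(); // extra buffer
    y.iter().map(|v| v.tanh()).collect()
}

fn fused(x: &[f32]) -> Vec<f32> {
    // One read and one write per element, no intermediates.
    x.iter().map(|v| (v * 2.0 + 1.0).tanh()).collect()
}

fn main() {
    let x = [2.0_f32, 3.0, 4.0, 5.0];
    assert_eq!(unfused(&x), fused(&x));
}
```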
-
Operation Queue:
- All 3 operations are added to the queue
-
Stream Optimizer:
- Creates blocks based on tensor dependencies
- Block 1: Tensor creation (not fusable with others)
- Block 2: `ScalarMul`, `ScalarAdd`, and `Tanh` (all fusable)
-
Blocks Optimizer:
- Tries to merge blocks (not possible here)
- Optimizes each block separately
- No holes found in this simple example
-
Execution Strategy:
- Strategy 1: Execute tensor creation individually
- Strategy 2: Execute a fused kernel for the element-wise operations
The final execution will:
- Create
tensor_1 - Execute a single fused kernel that computes
z = tanh((tensor_1 * 2.0) + 1.0)
This eliminates the need for intermediate storage of the y tensor, reducing memory traffic and improving performance.
All operations are queued for execution but aren’t immediately executed i.e. execution is lazy. The OperationQueue in a Stream is effectively the compute graph that contains the ordered sequence of operations waiting to be executed.
/// A growing list of [tensor operation descriptions](OperationIr).
pub struct OperationQueue<R: FusionRuntime> {
/// List of operation descriptions. These contain the exact tensor IDs
/// and shapes so that kernels can be run correctly.
pub(crate) global: Vec<OperationIr>,
/// List of operation descriptions. The tensor IDs and shapes are relative
/// because we don't need to know the exact values, but they are sufficient to
/// determine which operations can be fused.
pub(crate) relative: Vec<OperationIr>,
pub(crate) converter: OperationConverter,
pub(crate) operations: Vec<Box<dyn Operation<R>>>,
pub(crate) variables: HashMap<TensorId, (StreamId, TensorStatus)>,
}

- It stores both the high-level operation descriptions (`OperationIr`) and the actual executable operations.
- It maintains two representations:
  - global: the exact operations with precise tensor IDs and shapes
  - relative: a representation that helps identify fusion opportunities
variablesand their states through the variables HashMap. - Operations are added in sequence via the add method, building up the computation graph.
- When execution is triggered, the system analyzes this queue to identify fusion opportunities before executing the operations.
The MultiStream type manages multiple such queues (one per stream), effectively maintaining multiple compute graphs that can interact when operations span across streams.
let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device); // NoOp
let y = tensor_1.clone() * 2.0 + 1.0; // ScalarMul + ScalarAdd operations
let z = y.tanh();
println!("Final result: {}", z);

Printing z here invokes the Display implementation for Tensor but doesn't directly trigger kernel fusion and execution. Instead, it attempts to read tensor data for display, which indirectly causes execution of pending operations.
The exact flow:
-
When
fmtis called on a tensor, it callsdisplay_recursivewhich eventually callsfmt_inner_tensor -
In
fmt_inner_tensor, the key line that triggers execution is:
let data = burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());-
This flow:
- First creates a slice of the tensor
- Then calls
into_data_async()which returns a future - Then
try_read_syncattempts to synchronously read that future
-
The
into_data_async()call is what triggers the fusion and execution pipeline:- It causes any pending operations to be materialized
- The fusion system analyzes the operation graph
- Fused kernels are created and executed to compute the actual tensor values
-
The execution flow goes through:
- burn-fusion/src/stream/multi.rs - the `drain` method is called to process pending operations
- burn-fusion/src/stream/execution/processor.rs - the processor analyzes operations for fusion
- burn-cubecl-fusion/src/shared/trace/base.rs - the `FuseTrace::run` method executes the fused operations
- burn-cubecl-fusion/src/shared/trace/executor.rs - the `LaunchPlanExecutor::execute` method handles the actual kernel launch
-
The key steps in the fusion execution are:
- Input planning (via
InputPlanner) - Output planning (via
OutputPlanner) - Vectorization planning (via
VectorizationPlanner) - Kernel execution (via
LaunchPlanExecutor)
- Input planning (via
This happens because displaying a tensor requires accessing its actual values, which forces the computation of any pending operations in the computation graph.
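In other words, anything that needs the values forces materialization. A sketch (the `force_execution` helper is mine; it reuses the `Backend` alias from earlier and assumes Burn's `into_data`, which consumes the tensor and returns its data):

```rust
use burn::tensor::Tensor;

// Reading the data out explicitly has the same effect as printing:
// it drains the stream and runs the fused kernel for the queued ops.
fn force_execution(tensor: Tensor<Backend, 2>) {
    let z = (tensor * 2.0 + 1.0).tanh(); // still only queued at this point
    let data = z.into_data();            // pending operations execute here
    println!("{data:?}");
}
```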
The drain method in MultiStream is responsible for executing all pending operations in a specific stream. Here's what it does:
/// Drain a stream
pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
if let Some(stream) = self.streams.get_mut(&id) {
let num_executed = stream.queue.global.len();
stream.processor.process(
Segment::new(&mut stream.queue, handles),
&mut self.optimizations,
ExecutionMode::Sync,
);
stream.cursor += num_executed as u64;
let cleared = self.shared_tensors.on_executed_ops(id, stream);
self.clear_shared_tensors(&cleared, id);
let to_drop = self.shared_tensors.clear_tensors(cleared);
self.drop_shared_tensors(to_drop, handles, id);
}
}

- It finds the stream with the given ID
- It processes all operations in the stream's queue in `Sync` mode (immediate execution)
- It handles shared tensors that were executed:
- Clears them from other streams
- Drops tensors that are no longer needed
As you can see in the code above, draining a stream triggers the fusion process. Remember, every stream is composed of a queue and a processor.
pub(crate) struct Stream<R: FusionRuntime> {
pub(crate) queue: OperationQueue<R>,
processor: Processor<R::Optimization>,
pub(crate) cursor: u64,
}
/// Process a [stream segment](StreamSegment) following a [policy](Policy).
pub(crate) struct Processor<O> {
policy: Policy<O>,
explorer: Explorer<O>,
}

The workflow is:
- Operations are queued in the
OperationQueue - The
Processoranalyzes these operations to find fusion opportunities - The
Processoruses theStreamSegmentabstraction to access operations in the queue
The key insight is that the processor doesn't directly decide what's fusable. Instead:
- The processor coordinates the fusion process
- It delegates the actual fusion decisions to optimization builders provided by the runtime
- It uses the
Policyto decide when to explore, execute, or defer operations - It uses the
Explorerto find optimization opportunities
Note: The processor doesn't directly decide what's fusable - it works with the StreamSegment abstraction:
#[derive(new)]
struct Segment<'a, R: FusionRuntime> {
queue: &'a mut OperationQueue<R>,
handles: &'a mut HandleContainer<R::FusionHandle>,
}
pub fn process<Segment>(
&mut self,
mut segment: Segment,
store: &mut ExecutionPlanStore<O>,
mode: ExecutionMode,
) where
Segment: StreamSegment<O>,
{
...
...
...
let action = self.policy.action(store, segment.operations(), mode);
match action {
Action::Explore => {
self.explore(&mut segment, store, mode);
if self.explorer.is_up_to_date() {
break;
}
}
Action::Defer => {
match mode {
ExecutionMode::Lazy => break,
ExecutionMode::Sync => panic!("Can't defer while sync"),
};
}
Action::Execute(id) => {
if let ExecutionMode::Sync = mode {
store.add_trigger(id, ExecutionTrigger::OnSync);
}
segment.execute(id, store);
self.reset(store, segment.operations());
}
};
}

The Segment type takes exclusive access of the operation queue and provides access to operations through the operations() method. The processor then:
- Uses a
Policyto decide what action to take (explore, execute, defer) - Uses an
Explorerto find optimization opportunities - When an optimization is found, it's stored in the
ExecutionPlanStore
/// The policy keeps track of all possible execution plans (ids) for the current operation stream.
pub(crate) struct Policy<O> {
/// List of potential execution plans that are compatible with current stream segment
candidates: Vec<OperationsValidator<ExecutionPlanId>>,
/// List of candidate execution plans that have been found; we can still keep searching
/// to potentially find a better one.
availables: Vec<AvailableItem>,
/// The found execution plan that should be executed, along with the number of operations
/// in the plan.
found: Option<(ExecutionPlanId, usize)>,
/// The number of operations that have been analyzed
num_operations: usize,
_item_type: PhantomData<O>,
}
/// Explore and create new optimization.
pub struct Explorer<O> {
optimizer: StreamOptimizer<O>,
num_deferred: usize,
num_explored: usize,
is_still_optimizing: bool,
}

Note: Fusability is ultimately determined by the optimization builders provided by the runtime, which analyze operations to see if they can be combined.
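The builder contract looks roughly like the stripped-down, hypothetical version below (`ToyOptimizationBuilder`, `Status`, `Properties` are my names; the real `OptimizationBuilder` trait in burn_fusion has more methods and different signatures). The shape of the interaction is what matters: operations are fed in one at a time, the builder reports how attractive its candidate optimization is, and only at the end is the optimization built.

```rust
enum Status {
    Open,   // still accepting operations
    Closed, // hit something it cannot fuse
}

struct Properties {
    ready: bool,
    score: u64,
}

trait ToyOptimizationBuilder<Op, Out> {
    /// Try to absorb one more operation into the candidate optimization.
    fn register(&mut self, op: &Op);
    /// How attractive is the optimization built so far?
    fn properties(&self) -> Properties;
    /// Materialize the optimization (e.g. a fused kernel description).
    fn build(&self) -> Out;
    fn status(&self) -> Status;
}
```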
If querying the policy returns Action::Explore while processing a segment (i.e. a stream or sequence of ordered operations), we enter the exploration branch of the match above. This invokes the Explorer which contains the StreamOptimizer
- The `StreamOptimizer`'s first job is to register (or add) all operations in a stream/segment to a block.
This method attempts to register an operation with existing blocks in the StreamOptimizer. Here's what it does:
impl<O: NumOperations> Explorer<O> {
...
...
// Explore the provided operations.
pub(crate) fn explore(
&mut self,
operations: &[OperationIr],
mode: ExecutionMode,
) -> ExplorationAction<O> {
self.update(operations); // This essentially performs Block Op registration via `register_inner`, below
// Can only continue exploration when not sync.
if let ExecutionMode::Lazy = mode {
if self.is_still_optimizing {
return ExplorationAction::Continue;
}
}
let optimization = self.optimizer.optimize(operations); // post registration, we optimize
ExplorationAction::Completed(optimization)
}
...
...
}
impl<O: NumOperations> StreamOptimizer<O> {
...
...
fn register_inner(&mut self, operation: &OperationIr, force: bool) -> usize {
let mut added_count = 0;
for block in self.blocks.iter_mut() {
match block.register(operation, self.length, force) {
RegistrationResult::Accepted => {
added_count += 1;
}
RegistrationResult::NotPartOfTheGraph => {}
}
}
added_count
}
...
...
}

- It iterates through all existing blocks in the StreamOptimizer
block.register() - It passes:
- The operation to register
- The current length (position in the stream)
- A force flag that can override normal registration rules
- It counts how many blocks accepted the operation
- It returns this count
A Block is the StreamOptimizer's abstraction for an ordered sequence of operations that can potentially be fused together. Each block:
- Contains operations that are related (they use the same tensors)
- Tracks the ordering of operations
- Maintains a set of optimization builders that analyze the operations
- Can determine if operations can be fused
The key part is how a block decides whether to accept an operation:
pub fn register(
&mut self,
operation: &OperationIr,
order: usize,
force: bool,
) -> RegistrationResult {
if self.ids.is_empty() {
self.register_op(operation, order);
return RegistrationResult::Accepted;
}
let mut contains = false;
for node in operation.nodes() {
contains = self.ids.contains(&node.id);
if contains {
break;
}
}
if !contains && !force {
return RegistrationResult::NotPartOfTheGraph;
}
self.register_op(operation, order);
RegistrationResult::Accepted
}

A block accepts an operation if (see the toy example below):
- The block is empty (the first operation is always accepted)
- The operation uses tensors that are already in the block
- The `force` flag is true (overriding the normal rules)
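A toy version of that connectivity rule (`ToyBlock` and plain `u32` tensor ids are illustrative, not Burn's types): an operation joins a block only if it touches at least one tensor the block has already seen, or the block is empty.

```rust
use std::collections::HashSet;

struct ToyBlock {
    ids: HashSet<u32>,
}

impl ToyBlock {
    fn register(&mut self, op_tensor_ids: &[u32], force: bool) -> bool {
        let connected = self.ids.is_empty()
            || op_tensor_ids.iter().any(|id| self.ids.contains(id));
        if !connected && !force {
            return false; // NotPartOfTheGraph
        }
        self.ids.extend(op_tensor_ids.iter().copied());
        true // Accepted
    }
}

fn main() {
    let mut block = ToyBlock { ids: HashSet::new() };
    assert!(block.register(&[0, 1], false));  // first op: always accepted
    assert!(block.register(&[1, 2], false));  // shares tensor 1: accepted
    assert!(!block.register(&[7, 8], false)); // unrelated: rejected
}
```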
Note: calling `register_op` on a `Block<O>` ultimately registers the operation with `FuseBlockBuilder` by adding it to the `ops` vector in `FuseBlockBuilder`. The operation flows through several layers of abstraction, but ultimately ends up in the `ops` vector of `FuseBlockBuilder`, which is part of the `FuseTraceBuilder`. During the operation registration flow, all `FuseBlockBuilder` fields get populated simultaneously:
- self.ops gets the actual operations
- self.reads gets populated with input read operations
- resources.inputs/outputs/scalars get populated with tensor/scalar metadata
- tensor_writes (computed later) is determined by analyzing the populated resources and operations
This is why, by the time block.build() is called, all the information needed to generate the final fused kernel is already available!
Block.register_op
↓
OptimizationBuilder.register (trait method)
↓
ElementWiseBuilder.register (or other implementation)
↓
FuseOptimizationBuilder.register
↓
FuseOptimizationBuilder.register_numeric/register_binary_ops/etc.
↓
FuseOptimizationBuilder.register_scalar_ops/register_unary_ops/etc.
↓
TryFuseBuilder.register
↓
FuseTraceBuilder.register_operation
↓
FuseBlockBuilder.ops.push

This method is part of a larger strategy in StreamOptimizer:
- When a new operation arrives, it first tries to merge blocks if needed
- Then it tries to register the operation with existing blocks using
register_inner - If no block accepts it, it creates a new block for the operation
- It tracks how many blocks it has and may stop optimizing if it exceeds
max_blocks
The goal is to group operations into blocks that can be optimized together, while maintaining the correct execution order and dependencies between operations.
This approach allows the system to:
- Find fusion opportunities within each block
- Handle complex streams with multiple independent fusion groups
- Maintain correct execution semantics
To Reiterate:
- Each `Block<O>` contains operations that could potentially be fused
- The `optimize()` method in `Block<O>` finds the best optimization strategy
- Each block also contains a set of `OptimizationBuilder` instances that analyze operations
  - Example: for element-wise operations, the `ElemwiseOptimizationBuilder` recognizes patterns it can fuse
Here’s the flow from a StreamOptimizer all the way to producing a FuseTrace.
Explorer.explore
↓
StreamOptimizer.optimize
↓
BlocksOptimizer.optimize
↓
Block.optimize
↓
find_best_optimization_index
↓
OptimizationBuilder.build (trait method)
↓
FuseOptimizationBuilder.build
↓
TryFuseBuilder.build
↓
FuseTraceBuilder.build
↓
FuseTrace is created

pub fn optimize(&self, operations: &[OperationIr]) -> BlockOptimization<O> {
let result = BlocksOptimizer::new(self.blocks.clone()).optimize();
match result {
BlocksOptimizerResult::Full(optimization) => optimization,
BlocksOptimizerResult::WithHoles { strategies, ordering, holes } => {
// Handle holes case...
}
}
}

The StreamOptimizer creates a BlocksOptimizer with its blocks and calls optimize().
pub fn optimize(mut self) -> BlocksOptimizerResult<O> {
self = self.merging_pass();
let mut strategies = Vec::with_capacity(self.blocks.len());
let mut ordering = Vec::new();
let mut blocks = Vec::new();
core::mem::swap(&mut blocks, &mut self.blocks);
for block in blocks {
match self.optimize_block(block, &mut ordering) {
BlockOptimizationStep::Contiguous { strategy } => {
strategies.push(Box::new(strategy));
}
// Handle other cases...
}
}
// Create and return BlockOptimization
}
/// Optimize a single block.
fn optimize_block(
&mut self,
block: Block<O>,
ordering: &mut Vec<usize>,
) -> BlockOptimizationStep<O> {
let last_index = block.end_pos;
let mut block_optimization = block.optimize();
let opt_size = block_optimization.ordering.len();
...
...
}

The BlocksOptimizer:
- Tries to merge blocks that can be combined
- Processes each block to create optimization strategies
- Combines these strategies into a final
BlockOptimization
pub fn optimize(mut self) -> BlockOptimization<O> {
match find_best_optimization_index(&mut self.builders) {
Some(index) => {
let opt = self.builders[index].build();
let opt_len = opt.len();
if opt_len < self.operations.len() {
self.ordering.drain(opt_len..);
}
let strategy = ExecutionStrategy::Optimization {
ordering: Arc::new(self.ordering.clone()),
opt,
};
BlockOptimization::new(strategy, self.ordering)
}
None => {
let strategy = ExecutionStrategy::Operations {
ordering: Arc::new(self.ordering.clone()),
};
BlockOptimization::new(strategy, self.ordering)
}
}
}

The Block.optimize method:
- Finds the best optimization builder using
find_best_optimization_index - Calls
build()on that builder to create the optimization - Creates an
ExecutionStrategywith the optimization - Returns a
BlockOptimizationwith the strategy and ordering
fn find_best_optimization_index<O>(
optimizations: &mut [Box<dyn OptimizationBuilder<O>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;
for (i, optimization) in optimizations.iter().enumerate() {
let properties = optimization.properties();
if properties.ready && properties.score >= best_score {
best_index = Some(i);
best_score = properties.score;
}
}
best_index
}

This function:
- Examines all optimization builders
- Finds the one with the highest score that is ready
- Returns its index
The scores are populated during the registration process as operations are added to each builder. Here's the complete flow:
When a Block receives operations, it calls:
`Block.register()` → `Block.register_op()` → `builder.register(operation)` for each builder
Each optimization builder calculates its score based on how many operations it successfully accepts:
fn properties(&self) -> OptimizationProperties {
let ready = self.num_ops > 0;
OptimizationProperties {
ready,
score: self.num_ops as u64, // Score = number of operations accepted
}
}

When FuseOptimizationBuilder.register() is called:
fn register(&mut self, operation: &OperationIr) {
// ... operation type checking ...
if !self.register_numeric::<i32>(ops) {
self.status = OptimizationStatus::Closed; // Builder rejects future ops
return;
}
self.status = OptimizationStatus::Open;
self.num_ops += 1; // This increments the score!
}

Different builders have different scoring strategies:
FuseOptimizationBuilder: Score = num_ops (number of operations it can fuse)
ReduceBuilder: Score = base_score + 1 if it has a reduce operation
fn properties(&self) -> burn_fusion::OptimizationProperties {
let mut properties = self.builder.properties();
if self.reduce.is_some() {
properties.ready = true;
properties.score += 1; // Bonus for having reduce
} else {
properties.ready = false;
};
properties
}

MatmulBuilder: Score = base_score + 1 (bonus for matmul operations)
fn properties(&self) -> burn_fusion::OptimizationProperties {
let mut properties = self.builder.properties();
properties.score += 1; // Bonus for matmul
properties
}

Finally, find_best_optimization_index picks the builder with the highest score:
fn find_best_optimization_index<O>(
optimizations: &mut [Box<dyn OptimizationBuilder<O>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;
for (i, optimization) in optimizations.iter().enumerate() {
let properties = optimization.properties();
if properties.ready && properties.score >= best_score {
best_index = Some(i);
best_score = properties.score;
}
}
best_index
}

The scores are populated incrementally during operation registration:
- Each operation gets registered with all builders in the block
- Each builder decides if it can handle the operation
- If accepted:
num_ops++which increases the score - If rejected: Builder status becomes
Closedand stops accepting operations - Specialized builders (Matmul, Reduce) get bonus points for their specific operations
- Best builder is selected based on highest score when
Block.optimize()is called
This allows the fusion system to automatically choose the most effective optimization strategy based on which builder can fuse the most operations together.
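A worked example of that selection rule for our three ops (toy `Props` struct and `best` function of my own, mirroring the `ready`/`score` logic shown above): an element-wise builder that absorbed all three operations beats a matmul builder that never became ready.

```rust
struct Props {
    ready: bool,
    score: u64,
}

fn best(builders: &[Props]) -> Option<usize> {
    let mut best = None;
    let mut best_score = 0;
    for (i, p) in builders.iter().enumerate() {
        if p.ready && p.score >= best_score {
            best = Some(i);
            best_score = p.score;
        }
    }
    best
}

fn main() {
    let builders = [
        Props { ready: true, score: 3 },  // element-wise: accepted Mul, Add, Tanh
        Props { ready: false, score: 1 }, // matmul: never saw a matmul, not ready
    ];
    assert_eq!(best(&builders), Some(0)); // the fused element-wise kernel wins
}
```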
This is a trait method implemented by various builders. For example, ElementWiseBuilder.build():
fn build(&self) -> CubeOptimization<R> {
let client = R::client(&self.device);
let trace = self.builder.build();
let elementwise =
ElemwiseOptimization::<R>::new(trace, client, self.device.clone(), self.len());
CubeOptimization::ElementWise(elementwise)
}

The builder calls build() on its internal builder (typically a FuseOptimizationBuilder).
fn build(&self) -> FuseTrace {
self.builder.build(self.current_output_shape.clone())
}

This simply forwards to TryFuseBuilder.build().
fn build(&self, shape: Vec<usize>) -> FuseTrace {
self.builder.build(shape)
}

This forwards to FuseTraceBuilder.build().
pub fn build(&self, shape_ref: Vec<usize>) -> FuseTrace {
let mut resources = self.resources.clone();
let mut outputs = RegisteredTensors::default();
let mut blocks = Vec::new();
let mut register_block =
|block: &FuseBlockBuilder, shape_ref: &Vec<usize>, offset: usize| {
let (block, block_tensor_writes) =
block.build(&self.resources, shape_ref.clone(), offset);
blocks.push(block);
let num_outputs = block_tensor_writes.len();
for (ir, precision) in block_tensor_writes.into_iter() {
outputs.insert(precision, ir);
}
num_outputs
};
let mut offset = 0;
for (block, shape_ref) in self.blocks_previous.iter() {
offset += register_block(block, shape_ref, offset);
}
register_block(&self.block_current, &shape_ref, offset);
resources.outputs = outputs;
FuseTrace { blocks, resources }
}

This is where the FuseTrace is actually created:
- It clones the current resources
- It processes each block using
FuseBlockBuilder.build() - It collects all blocks and their outputs
- It creates a
FuseTracewith the blocks and resources
pub fn build(
&self,
resources: &FuseResources,
shape_ref: Vec<usize>,
offset: usize,
) -> (FuseBlock, RegisteredTensors) {
let ops = self.ops.clone();
let reads = self.reads.clone();
let tensor_writes = self.tensor_writes(resources);
let mut writes = BTreeMap::new();
// Process writes...
(
FuseBlock {
settings: self.settings,
ops,
shape_ref,
reads,
writes,
},
tensor_writes,
)
}

This creates a FuseBlock with:
- The operations that have been registered
- The reads and writes for each tensor
- The shape reference and settings
To summarize how a FuseTrace is produced:
- Block Collection: The
StreamOptimizercollects blocks of operations - Block Optimization: Each block is optimized using the best available builder
- Builder Selection: The best builder is selected based on its score
- Trace Building: The selected builder builds a trace by:
- Processing each block to create a
FuseBlock - Collecting all blocks and their resources
- Creating a
FuseTracewith the blocks and resources
- Processing each block to create a
The key is that by the time build() is called, all operations have already been registered with the builders. The build() method doesn't register new operations - it processes the already-registered operations to create an optimized execution plan.
The resulting FuseTrace contains:
- A list of
FuseBlocks, each with its operations, reads, and writes - Resources including inputs, outputs, and scalars
- Everything needed to execute the fused operations efficiently
This trace can then be executed by a runtime to perform the fused operations efficiently.
- WIP
$ cargo expand --manifest-path crates/burn-cubecl-fusion/Cargo.toml --lib elemwise::optimization

From what I can tell, fused kernels aren't generated on the fly at runtime; they're statically defined and precompiled. They are like generic templates that can handle arbitrary sequences of operations.
For example, the elemwise_fuse kernel below appears to handle any sequence of elementwise ops:
#[cube(launch_unchecked)]
fn elemwise_fuse(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
#[comptime] config: &FuseBlockConfig,
) {
let values = Registry::<Arg, Line<f32>>::new();
let args = comptime![Sequence::<Arg>::new()];
let pos = ABSOLUTE_POS;
let mut locals = init_locals(inputs, outputs, config);
let length = ref_len(inputs, outputs, &locals, config);
if pos < length {
fuse_on_write::<f32>(inputs, outputs, &mut locals, pos, values, args, config)
}
}As I understand it, the flow looks something like this:
1. A FuseTrace is recorded with a sequence of operations like:
ops: [
Assign(input -> local_0),
Mul(local_0, scalar_0 -> local_1),
Add(local_1, scalar_1 -> local_2),
Tanh(local_2 -> local_3),
]

2. At kernel launch time, this sequence is passed in as part of the `FuseBlockConfig`.
3. The actual kernel uses a generic fuse function that looks like:
#[cube]
fn fuse(..., #[comptime] config: &FuseBlockConfig) {
#[unroll]
for index in 0..config.ops.len() {
let op = comptime! { config.ops.index(index).clone() };
match op {
FuseOp::Mul(op) => mul(...),
FuseOp::Add(op) => add(...),
FuseOp::Tanh(op) => tanh(...),
// etc.
}
}
So there's no runtime code generation involved, just compile-time (or rather, JIT-compile-time) specialization via #[comptime] and #[unroll].
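A rough CPU-side analogy in plain Rust (not CubeCL, and the `FuseOp`/`apply` names here are toy stand-ins): the op sequence is a value the compiler can see at the call site, so the per-op dispatch can in principle be resolved and unrolled rather than interpreted per element at run time.

```rust
#[derive(Clone, Copy)]
enum FuseOp {
    MulScalar(f32),
    AddScalar(f32),
    Tanh,
}

// Plays the role of fuse_on_write: walk the op list and thread one value through it.
fn apply(ops: &[FuseOp], mut x: f32) -> f32 {
    for &op in ops {
        x = match op {
            FuseOp::MulScalar(s) => x * s,
            FuseOp::AddScalar(s) => x + s,
            FuseOp::Tanh => x.tanh(),
        };
    }
    x
}

fn main() {
    // The "config": a fixed op sequence known when this call site is compiled.
    const OPS: &[FuseOp] = &[FuseOp::MulScalar(2.0), FuseOp::AddScalar(1.0), FuseOp::Tanh];
    println!("{}", apply(OPS, 2.0)); // tanh(2.0 * 2.0 + 1.0)
}
```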
When operations can't be fused:
- The processor still processes them through the same flow
- The `Explorer` will fail to find optimizations
- The operations will be executed individually using the `ExecutionStrategy::Operations` strategy
pub(crate) enum ExecutionStrategy<O> {
/// An optimization was found, and therefore should be executed.
Optimization { opt: O, ordering: Arc<Vec<usize>> },
/// No optimization was found, each operation should be executed individually.
Operations { ordering: Arc<Vec<usize>> },
/// A composition of multiple execution strategies.
Composed(Vec<Box<Self>>),
}

Operations across multiple streams are not directly fused. The system is designed to:
- Keep streams independent for concurrent execution
- Handle shared tensors between streams
- Synchronize streams when necessary
The MultiStream type manages these interactions through methods like resolve_streams and merge_streams_timelines, which ensure proper ordering when streams share tensors, but it doesn't attempt to fuse operations across different streams.
When a tensor is shared between streams, the system ensures proper execution order by draining dependent streams when needed, but fusion only happens within individual streams.
- What is the correct tile size?
- Is it good to unroll this loop or not?