Skip to content

Instantly share code, notes, and snippets.

@ncthbrt
Created October 10, 2024 11:56
Show Gist options
  • Save ncthbrt/b4c7206cff273acbeddb3afeea54e5ad to your computer and use it in GitHub Desktop.
Save ncthbrt/b4c7206cff273acbeddb3afeea54e5ad to your computer and use it in GitHub Desktop.
// Entry point: instantiate the generic ReduceBuffer module with concrete
// arguments and expose its `main` function under the top-level name `main`.
alias main = ReduceBuffer with {
// Reduction operation: SumBinaryOp specialised with the F32 module.
alias Op = SumBinaryOp<F32>;
// Number of source elements each invocation loads in fetchSrcBuffer.
const block_area: u32 = 4u;
// Length of the workgroup array declared by ReduceWorkgroup.
const work_size: u32 = 18u;
// Loop bound used by reduceWorkgroup.
// NOTE(review): ReduceBuffer's @workgroup_size uses its own workgroup_threads
// constant (4u), not this value - confirm the two are meant to differ.
const threads: u32 = 10u;
}::main;
// Element type used by SumBinaryOp: wraps a single value of the type N::T
// supplied by the module argument N.
struct Sum<N> {
sum: N::T
}
// Sum reduction operation over Sum<N>.
//
// N is a module argument expected to provide a value type `T`, a binary
// `add(T, T) -> T` and an `identity() -> T` (Intrinsic has this shape).
mod SumBinaryOp<N> {
    // Element type the reduction combines.
    alias OpElem = Sum<N>;
    // Element type accepted by loadOp.
    alias LoadElem = Sum<N>;

    // Returns the starting/padding element of the reduction.
    // Delegates to N::identity() so the neutral value is defined next to
    // N::add instead of relying on zero-value construction of OpElem.
    fn identityOp() -> OpElem {
        return OpElem(N::identity());
    }

    // Converts a loaded element into a reduction element (copies `sum`).
    fn loadOp(a: LoadElem) -> OpElem {
        return OpElem(a.sum);
    }

    // Combines two elements by adding their `sum` fields with N::add.
    fn binaryOp(a: OpElem, b: OpElem) -> OpElem {
        return OpElem(N::add(a.sum, b.sum));
    }
}
// Adapter exposing a built-in type N through a small module interface:
// a value type `T`, an `add` function and an `identity` function.
mod Intrinsic<N> {
    alias T = N;

    // Returns a + b using the `+` operator of T.
    fn add(a: T, b: T) -> T {
        let total = a + b;
        return total;
    }

    // Returns the zero-value-constructed T.
    fn identity() -> T {
        return T();
    }
}
// Intrinsic specialised for the built-in f32 type; passed to SumBinaryOp in the `main` alias.
alias F32 = Intrinsic<f32>;
// Workgroup-level reduction: folds the entries of `work` into work[0]
// using Op::binaryOp.
//
// Module arguments:
//   Op        - module providing `OpElem` and `binaryOp(OpElem, OpElem) -> OpElem`
//   work_size - length of the `work` array
//   threads   - loop bound of the reduction; treated here as the number of
//               leading `work` entries that hold partial results
//               (assumed one per invocation - TODO confirm with callers)
mod ReduceWorkgroup<Op, work_size, threads> {
    var<workgroup> work: array<Op::OpElem, work_size>;

    // Reduces `work` in place; after the call work[0] holds the combined value.
    // localId: index of the calling invocation inside its workgroup.
    // Must be called by every invocation of the workgroup because it
    // executes workgroupBarrier().
    fn reduceWorkgroup(localId: u32) {
        let workDex = localId << 1u;
        // Indices at or beyond this limit are never read or written.
        // Without the guard, upper-half invocations index past the array
        // (2 * localId + step), and an out-of-bounds store may land on any
        // element of `work`, including work[0].
        let limit = min(threads, work_size);
        for (var step = 1u; step < threads; step <<= 1u) {
            // Make the previous round's writes visible before combining.
            workgroupBarrier();
            let partner = workDex + step;
            if localId % step == 0u && partner < limit {
                work[workDex] = Op::binaryOp(work[workDex], work[partner]);
            }
        }
    }
}
// Same here
// Buffer reduction: each workgroup reduces a slice of `src` to one element of `out`.
//
// Module arguments:
//   Op         - reduction operation module (LoadElem, OpElem, loadOp, identityOp, binaryOp)
//   block_area - number of source elements loaded by each invocation
//   work_size  - forwarded to ReduceWorkgroup (length of the workgroup `work` array)
//   threads    - forwarded to ReduceWorkgroup (loop bound of reduceWorkgroup)
mod ReduceBuffer<Op, block_area, work_size, threads> {
    // extend brings the module members into the namespace
    // (this is where `work` and `reduceWorkgroup` come from)
    extend ReduceWorkgroup<Op, work_size, threads>;

    alias Input = Op::LoadElem;
    alias Output = Op::OpElem;

    struct Uniforms {
        sourceOffset: u32, // offset in Input elements to start reading in the source
        resultOffset: u32, // offset in Output elements to start writing in the results
    }

    @group(0) @binding(0) var<uniform> u: Uniforms;
    @group(0) @binding(1) var<storage, read> src: array<Input>;
    @group(0) @binding(2) var<storage, read_write> out: array<Output>;
    // buffer to hold debug values (not referenced by the code in this module)
    @group(0) @binding(11) var<storage, read_write> debug: array<f32>;

    // Invocations per workgroup.
    // NOTE(review): independent of the `threads` module argument - confirm
    // whether the two are meant to match.
    const workgroup_threads = 4u;

    // `work` is intentionally NOT declared here: the extended ReduceWorkgroup
    // module already provides it, and a second declaration would either clash
    // with it or give reduceWorkgroup a different array than the one filled
    // by reduceBufferToWork. Requires work_size >= workgroup_threads.

    // reduce a buffer of values to a single value
    //
    // each dispatch does two reductions:
    //   . each invocation reduces from the src buffer to the workgroup buffer
    //   . one invocation per workgroup reduces from the workgroup buffer to the out buffer
    // the driver issues multiple dispatches until the output is 1 element long
    // (subsequent passes use the output of the previous pass as the src);
    // per the original notes the final value ends up as the last element of
    // the out array - this depends on the driver, which is not shown here
    // the same output buffer can be used as input and output in subsequent passes
    //   . sourceOffset and resultOffset in the uniforms indicate input and output
    //     positions in the buffer
    //
    @compute
    @workgroup_size(workgroup_threads, 1, 1)
    fn main(
        @builtin(global_invocation_id) grid: vec3<u32>, // coords in the global compute grid
        @builtin(local_invocation_index) localIndex: u32, // index inside this workgroup
        @builtin(num_workgroups) numWorkgroups: vec3<u32>, // number of workgroups in this dispatch
        @builtin(workgroup_id) workgroupId: vec3<u32> // workgroup id in the dispatch
    ) {
        // Phase 1: every invocation stores its partial result in work[localIndex].
        reduceBufferToWork(grid.xy, localIndex);
        // Phase 2: fold the partial results into work[0].
        reduceWorkgroup(localIndex);
        // One invocation per workgroup publishes the workgroup's result.
        if localIndex == 0u {
            let outDex = workgroupId.x + u.resultOffset;
            out[outDex] = work[0];
        }
    }

    // Loads this invocation's block of source elements, reduces it and
    // stores the partial result in work[localId]. Only grid.x is used.
    fn reduceBufferToWork(grid: vec2<u32>, localId: u32) {
        let values = fetchSrcBuffer(grid.x);
        work[localId] = reduceSrcBlock(values);
    }

    // Returns block_area elements starting at u.sourceOffset + gridX * block_area.
    // Positions at or beyond arrayLength(&src) are filled with Op::identityOp().
    fn fetchSrcBuffer(gridX: u32) -> array<Output, block_area> {
        let start = u.sourceOffset + (gridX * block_area);
        let end = arrayLength(&src);
        var a = array<Output, block_area>();
        for (var i = 0u; i < block_area; i++) {
            let idx = i + start;
            if idx < end {
                a[i] = Op::loadOp(src[idx]);
            } else {
                a[i] = Op::identityOp();
            }
        }
        return a;
    }

    // Folds the elements of `a` left to right with Op::binaryOp, starting from a[0].
    fn reduceSrcBlock(a: array<Output, block_area>) -> Output {
        var v = a[0];
        for (var i = 1u; i < block_area; i++) {
            v = Op::binaryOp(v, a[i]);
        }
        return v;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment