Created
October 10, 2024 11:56
-
-
Save ncthbrt/b4c7206cff273acbeddb3afeea54e5ad to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Shader entry point: instantiate ReduceBuffer with the module arguments
// below and expose its `main` function under the top-level name `main`.
alias main = ReduceBuffer with {
    // Operator module used for the reduction: a sum over the F32 numeric module.
    alias Op = SumBinaryOp<F32>;

    // Number of source elements each invocation loads and folds together.
    const block_area: u32 = 4u;

    // Element count of the workgroup-shared `work` array in ReduceWorkgroup.
    const work_size: u32 = 18u;

    // Exclusive upper bound of the stride loop in reduceWorkgroup.
    const threads: u32 = 10u;
}::main;
// Accumulator record holding a single running total.
// `N` is a numeric module; the field takes its element type, `N::T`.
struct Sum<N> {
    sum: N::T,
}
// Operator module implementing summation for the reduce modules.
//
// `N` is a numeric module that must provide `T`, `add(a, b)` and `identity()`
// (Intrinsic is one such module).
//
// Exposes the operator interface the reduce modules consume:
//   LoadElem / OpElem - element types read from the source / used while reducing
//   identityOp()      - neutral element of the operator
//   loadOp(a)         - convert a loaded element into an operator element
//   binaryOp(a, b)    - combine two operator elements
mod SumBinaryOp<N> {
    alias OpElem = Sum<N>;
    alias LoadElem = Sum<N>;

    // Neutral element for the sum. Delegates to the numeric module so the
    // identity is defined next to `add` rather than assumed to be the
    // zero value of `OpElem`.
    fn identityOp() -> OpElem {
        return OpElem(N::identity());
    }

    // Source and operator elements have the same layout, so loading simply
    // copies the running total across.
    fn loadOp(a: LoadElem) -> OpElem {
        return OpElem(a.sum);
    }

    // Adds the two running totals using the numeric module's `add`.
    fn binaryOp(a: OpElem, b: OpElem) -> OpElem {
        return OpElem(N::add(a.sum, b.sum));
    }
}
// Adapter exposing a built-in type `N` as a numeric module with an element
// type `T`, an `add` function and an `identity` function.
mod Intrinsic<N> {
    alias T = N;

    // Sum of the two operands, computed with the built-in `+` operator.
    fn add(a: T, b: T) -> T {
        let total = a + b;
        return total;
    }

    // Value produced by constructing the element type with no arguments.
    fn identity() -> T {
        return T();
    }
}

// Numeric module for the built-in f32 type.
alias F32 = Intrinsic<f32>;
// Workgroup-level reduction over a shared array.
//
//   Op        - operator module providing `OpElem` and `binaryOp`
//   work_size - element count of the shared `work` array
//   threads   - exclusive upper bound of the reduction stride
mod ReduceWorkgroup<Op, work_size, threads> {
    var<workgroup> work: array<Op::OpElem, work_size>;

    // Combines entries of `work` pairwise, doubling the stride on every pass.
    // Partial results accumulate towards work[0].
    //
    //   localId - index of the calling invocation; it owns slot 2 * localId
    fn reduceWorkgroup(localId: u32) {
        let workDex = localId << 1u;
        for (var step = 1u; step < threads; step <<= 1u) {
            // Make the previous pass's writes visible before reading them.
            workgroupBarrier();

            let partner = workDex + step;
            // Only combine when both slots lie inside `work`. Invocations in
            // the upper half own slots past the end of the array; an
            // out-of-bounds access there may be redirected onto a valid
            // element and clobber a partial result another invocation is
            // still reading.
            if localId % step == 0u && partner < work_size {
                work[workDex] = Op::binaryOp(work[workDex], work[partner]);
            }
        }
    }
}
// Buffer reduction shader, parameterized like ReduceWorkgroup:
//
//   Op         - operator module (LoadElem, OpElem, loadOp, identityOp, binaryOp)
//   block_area - number of source elements each invocation loads and folds
//   work_size  - element count of the shared `work` array; must be >= threads
//   threads    - invocations per workgroup
mod ReduceBuffer<Op, block_area, work_size, threads> {
    // extend brings the module members into the namespace
    // (`work` and `reduceWorkgroup` are used below and are not redeclared here)
    extend ReduceWorkgroup<Op, work_size, threads>;

    alias Input = Op::LoadElem;
    alias Output = Op::OpElem;

    struct Uniforms {
        sourceOffset: u32, // offset in Input elements to start reading in the source
        resultOffset: u32, // offset in Output elements to start writing in the results
    }

    @group(0) @binding(0) var<uniform> u: Uniforms;
    @group(0) @binding(1) var<storage, read> src: array<Input>;
    @group(0) @binding(2) var<storage, read_write> out: array<Output>;
    @group(0) @binding(11) var<storage, read_write> debug: array<f32>; // buffer to hold debug values

    // Workgroup width. Tied to the `threads` module parameter so the dispatch
    // width and the stride loop in reduceWorkgroup always agree.
    const workgroup_threads = threads;

    // reduce a buffer of values to a single value, returned as the last element of the out array
    //
    // each dispatch does two reductions:
    //   . each invocation reduces from a src buffer to the workgroup buffer
    //   . one invocation per workgroup reduces from the workgroup buffer to the out buffer
    // the driver issues multiple dispatches until the output is 1 element long
    //   (subsequent passes use the output of the previous pass as the src)
    // the same output buffer can be used as input and output in subsequent passes
    //   . source and result offsets in the uniforms indicate input and output positions in the buffer
    //
    @compute
    @workgroup_size(workgroup_threads, 1, 1)
    fn main(
        @builtin(global_invocation_id) grid: vec3<u32>,    // coords in the global compute grid
        @builtin(local_invocation_index) localIndex: u32,  // index inside this workgroup
        @builtin(num_workgroups) numWorkgroups: vec3<u32>, // number of workgroups in this dispatch
        @builtin(workgroup_id) workgroupId: vec3<u32>      // workgroup id in the dispatch
    ) {
        // Phase 1: every invocation folds its block of src into one `work` slot.
        reduceBufferToWork(grid.xy, localIndex);
        let outDex = workgroupId.x + u.resultOffset;

        // Phase 2: fold the `work` slots together.
        reduceWorkgroup(localIndex);

        // One invocation per workgroup publishes the workgroup's result.
        if localIndex == 0u {
            out[outDex] = work[0];
        }
    }

    // Loads this invocation's block from src, folds it to a single element and
    // stores that element in the invocation's slot of the shared `work` array.
    fn reduceBufferToWork(grid: vec2<u32>, localId: u32) {
        let values = fetchSrcBuffer(grid.x);
        work[localId] = reduceSrcBlock(values);
    }

    // Reads `block_area` consecutive elements of src starting at
    // sourceOffset + gridX * block_area. Positions at or past the end of src
    // are filled with the operator's identity element.
    fn fetchSrcBuffer(gridX: u32) -> array<Output, block_area> {
        let start = u.sourceOffset + (gridX * block_area);
        let end = arrayLength(&src);
        var a = array<Output, block_area>();
        for (var i = 0u; i < block_area; i = i + 1u) {
            let idx = i + start;
            if idx < end {
                a[i] = Op::loadOp(src[idx]);
            } else {
                a[i] = Op::identityOp();
            }
        }
        return a;
    }

    // Folds a block left to right with the operator and returns the result.
    fn reduceSrcBlock(a: array<Output, block_area>) -> Output {
        var v = a[0];
        for (var i = 1u; i < block_area; i = i + 1u) {
            v = Op::binaryOp(v, a[i]);
        }
        return v;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment