// Gist: VictorTaelin/3095032b157cbe79ec368347690fd893 (created April 9, 2024 21:22)
// Fast CUDA block-local prefix sum (scansum) using warp sync primitives (__shfl_up_sync).
// Fast block-local (inclusive) prefix sum on CUDA, using warp shuffle intrinsics.
// The input is an array of u32, which is mutated in place. Example:
//   arr = [1,1,1,1,...]
// becomes:
//   arr = [1,2,3,4,...]
// The number of elements must equal the number of threads per block (TPB).
#include <stdio.h> | |
#include <cuda_runtime.h> | |
typedef unsigned int u32;

// Block size as a power of two: TPB_L2 is log2 of the threads per block, which
// scansum_0 uses as its number of tree levels.
#define TPB_L2 8
// Threads per block; also the number of elements each scan processes.
#define TPB (1 << TPB_L2)

// How many times the benchmark kernel repeats the scan; main() divides the
// measured time by this to report a per-scan average.
#define TIMES (1 << 21) // == 32 * 256 * 256
// OLD SCANSUM ("work-efficient" up-sweep / down-sweep algorithm) - exclusive.
// Replaces arr[0..TPB) with its exclusive prefix sum and returns the total of
// the original TPB elements.
// Must be called by every thread of the block (it contains __syncthreads()),
// and arr must already be fully written, with a barrier, before the call.
// NOTE(review): arr is presumably shared memory (the kernel passes smem).
__device__ u32 scansum_0(u32* arr) {
  const u32 t = threadIdx.x;

  // Up-sweep: at each level, a thread whose id is a multiple of `span`
  // adds the value at the end of the left half into the end of its span.
  for (u32 lvl = 0; lvl < TPB_L2; lvl++) {
    const u32 half = 1u << lvl;
    const u32 span = half << 1;
    if (t % span == 0) {
      arr[t + span - 1] += arr[t + half - 1];
    }
    __syncthreads();
  }

  // The last slot now holds the sum of every element.
  const u32 total = arr[TPB - 1];
  __syncthreads(); // every thread must read `total` before it is cleared

  if (t == 0) {
    arr[TPB - 1] = 0;
  }
  __syncthreads();

  // Down-sweep: walk the levels back from the top; each active thread moves
  // the right value to the left slot and adds the old left value to the right.
  for (int lvl = TPB_L2 - 1; lvl >= 0; lvl--) {
    const u32 half = 1u << lvl;
    const u32 span = half << 1;
    if (t % span == 0) {
      const u32 left = arr[t + half - 1];
      arr[t + half - 1] = arr[t + span - 1];
      arr[t + span - 1] += left;
    }
    __syncthreads();
  }

  return total;
}
// NEW SCANSUM (using warp shuffles) - inclusive.
// Replaces arr[0..TPB) with its inclusive prefix sum and returns the total of
// the original TPB elements, which is the same quantity scansum_0 returns.
// (It used to return this thread's warp-local partial sum, a value that is
// neither its block prefix nor the block total.)
//
// Preconditions:
//   - called by all TPB threads of one block (blockDim.x == TPB, so every warp
//     is full); the function contains __syncthreads() and full-warp shuffles;
//   - arr holds TPB elements and was written, followed by a barrier, before
//     the call (NOTE(review): presumably shared memory; the kernel passes smem);
//   - the caller must barrier before reading arr or calling this again, because
//     the final update of arr is not followed by a barrier here.
// Requires CUDA 9+ for __shfl_up_sync.
__device__ u32 scansum_1(u32* arr) {
  static_assert(TPB % 32 == 0 && TPB >= 32 && TPB <= 1024,
                "scansum_1 requires TPB to be a multiple of 32 in [32, 1024]");
  // Every shuffle below is executed by all 32 lanes of the warp, so the mask is
  // stated explicitly instead of relying on __activemask(), which only reports
  // the lanes that happen to be converged at that instant.
  const u32 FULL_MASK = 0xffffffffu;
  const u32 NWARPS = TPB / 32;

  // One slot per warp: holds warp totals, then their inclusive scan.
  __shared__ u32 wsum[TPB / 32];

  u32 tid = threadIdx.x; // thread id (== element index)
  u32 wid = tid / 32;    // warp id
  u32 lid = tid % 32;    // lane id

  // 1. Inclusive scan inside each warp (log-step shuffle-up).
  u32 sum = arr[tid];
  for (u32 k = 1; k < 32; k *= 2) {
    u32 num = __shfl_up_sync(FULL_MASK, sum, k);
    if (lid >= k) sum += num;
  }
  arr[tid] = sum;

  // 2. The last lane of each warp holds that warp's total.
  if (lid == 31) {
    wsum[wid] = sum;
  }
  __syncthreads();

  // 3. Warp 0 scans the warp totals. All 32 lanes join the shuffles (lanes at
  //    or past NWARPS contribute 0), so the mask is exactly FULL_MASK.
  //    Only the first NWARPS results are written back.
  if (wid == 0) {
    u32 ssum = lid < NWARPS ? wsum[lid] : 0;
    for (u32 k = 1; k < NWARPS; k *= 2) {
      u32 snum = __shfl_up_sync(FULL_MASK, ssum, k);
      if (lid >= k) ssum += snum;
    }
    if (lid < NWARPS) {
      wsum[lid] = ssum;
    }
  }
  __syncthreads();

  // 4. Add the combined total of all preceding warps.
  if (wid > 0) {
    arr[tid] += wsum[wid - 1];
  }

  // Inclusive scan of the warp totals: the last slot is the block total.
  return wsum[NWARPS - 1];
}
// Benchmark kernel: fills a TPB-element shared buffer with 0, 1, ..., TPB-1,
// scans it with scansum_1, repeats this TIMES times, and then copies the final
// buffer to arr.
// Launch as a single block of exactly TPB threads; arr must hold at least TPB
// u32 elements.
__global__ void scansum_kernel(u32* arr) {
  // Only TPB elements are scanned. The previous 2*TPB buffer left its upper
  // half uninitialized and copied it to arr[TPB..2*TPB), past the end of the
  // TPB-element device buffer.
  __shared__ u32 smem[TPB];
  u32 tid = threadIdx.x;
  for (u32 i = 0; i < TIMES; ++i) {
    smem[tid] = tid;
    __syncthreads(); // input visible to every thread before the scan
    scansum_1(smem);
    __syncthreads(); // scan complete before the next overwrite or final copy
  }
  arr[tid] = smem[tid];
}
// Reports a failed CUDA runtime call on stderr. Returns true on success.
static bool cuda_ok(cudaError_t err, const char* what) {
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
    return false;
  }
  return true;
}

// Host driver. It launches scansum_kernel once as a single block of TPB
// threads, reports the average time per scan, prints the scanned array, and
// checks the array against a CPU reference. Exits with 1 on any CUDA error or
// mismatch.
int main() {
  u32 h_arr[TPB] = {0};
  u32* d_arr = NULL;
  if (!cuda_ok(cudaMalloc(&d_arr, TPB * sizeof(u32)), "cudaMalloc")) return 1;
  if (!cuda_ok(cudaMemcpy(d_arr, h_arr, TPB * sizeof(u32), cudaMemcpyHostToDevice),
               "cudaMemcpy (host to device)")) return 1;

  cudaEvent_t start, stop;
  if (!cuda_ok(cudaEventCreate(&start), "cudaEventCreate(start)")) return 1;
  if (!cuda_ok(cudaEventCreate(&stop), "cudaEventCreate(stop)")) return 1;

  if (!cuda_ok(cudaEventRecord(start), "cudaEventRecord(start)")) return 1;
  scansum_kernel<<<1, TPB>>>(d_arr);
  // Launch-configuration errors are only visible through cudaGetLastError().
  if (!cuda_ok(cudaGetLastError(), "scansum_kernel launch")) return 1;
  if (!cuda_ok(cudaEventRecord(stop), "cudaEventRecord(stop)")) return 1;

  // A blocking copy waits for the kernel and surfaces any fault it raised.
  if (!cuda_ok(cudaMemcpy(h_arr, d_arr, TPB * sizeof(u32), cudaMemcpyDeviceToHost),
               "cudaMemcpy (device to host)")) return 1;
  if (!cuda_ok(cudaEventSynchronize(stop), "cudaEventSynchronize")) return 1;

  float milliseconds = 0;
  if (!cuda_ok(cudaEventElapsedTime(&milliseconds, start, stop),
               "cudaEventElapsedTime")) return 1;
  printf("Scansum time: %f us\n", milliseconds * 1000.0 / (float)TIMES);
  for (int i = 0; i < TPB; ++i) {
    printf("%u ", h_arr[i]);
  }
  printf("\n");

  // CPU reference: scansum_kernel writes 0, 1, ..., TPB-1 and scans it
  // inclusively, so element i must equal 0 + 1 + ... + i.
  int mismatches = 0;
  for (int i = 0; i < TPB; ++i) {
    u32 n = (u32)i;
    u32 expect = n * (n + 1) / 2;
    if (h_arr[i] != expect) {
      if (mismatches == 0) {
        fprintf(stderr, "Mismatch at %d: got %u, expected %u\n", i, h_arr[i], expect);
      }
      ++mismatches;
    }
  }
  if (mismatches != 0) {
    fprintf(stderr, "Scansum check FAILED: %d mismatching elements\n", mismatches);
  }

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaFree(d_arr);
  return mismatches == 0 ? 0 : 1;
}
// End of gist. Discussion of this code takes place in the gist's comment thread on GitHub.