Last active
October 5, 2020 16:16
-
-
Save r-barnes/b8d76be3c7430e450ebe2e2dd95c3ddd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cub/cub.cuh>
#include <thrust/device_vector.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>
//Host-side result bundle for the group-summing routines: entry i holds the
//sum of weights and the element count for class label i (0..num_classes-1).
struct SumCountRet {
  std::vector<double> sums;      //per-class sum of weights
  std::vector<uint32_t> counts;  //per-class number of elements
};
//Per-class weighted sums and counts using per-thread private shared-memory
//accumulators, reduced per block with CUB and merged globally with atomics.
//
//Launch expectations: 1-D grid, blockDim.x == 128; dynamic shared memory of
//num_threads*num_classes*(sizeof(double)+sizeof(uint32_t)) bytes.
//`sums` and `counts` must be zero-initialized by the caller.
//atomicAdd on double requires SM60+.
__global__ void group_summer_shmem(
  const int32_t *const labels,
  const float *const weights,
  const int num_elements,
  const int num_classes,
  double *const sums,
  uint32_t *const counts
){
  constexpr int num_threads = 128;
  assert(num_threads==blockDim.x);
  //Carve dynamic shared memory into per-thread accumulator slices. Doubles
  //come first so the uint32_t region cannot misalign the doubles.
  extern __shared__ int s[];
  double *const sums_shmem = (double*)s;
  uint32_t *const counts_shmem = (uint32_t*)&sums_shmem[num_threads*num_classes];
  double *const my_sums = &sums_shmem [num_classes*threadIdx.x];
  uint32_t *const my_counts = &counts_shmem[num_classes*threadIdx.x];
  //Shared memory is uninitialized: zero it cooperatively before use.
  for(int i=threadIdx.x;i<num_threads*num_classes;i+=num_threads){
    sums_shmem[i] = 0;
    counts_shmem[i] = 0;
  }
  __syncthreads();
  //Grid-stride loop; each thread writes only its own slice, so no atomics.
  for(int i=blockIdx.x * blockDim.x + threadIdx.x;i<num_elements;i+=gridDim.x*blockDim.x){
    const auto l = labels[i];
    my_sums[l] += weights[i];
    my_counts[l]++;
  }
  __syncthreads();
  __shared__ cub::BlockReduce<double, num_threads>::TempStorage double_temp_storage;
  __shared__ cub::BlockReduce<uint32_t, num_threads>::TempStorage uint32_t_temp_storage;
  for(int l=0;l<num_classes;l++){
    const auto sums_total = cub::BlockReduce<double,num_threads>(double_temp_storage).Reduce(my_sums[l], cub::Sum());
    const auto counts_total = cub::BlockReduce<uint32_t,num_threads>(uint32_t_temp_storage).Reduce(my_counts[l], cub::Sum());
    if(threadIdx.x==0){
      atomicAdd(&sums[l], sums_total);
      atomicAdd(&counts[l], counts_total);
    }
    //BUGFIX: CUB requires a barrier before TempStorage is reused; without
    //this, the next iteration's Reduce races with the previous one.
    __syncthreads();
  }
}
//Per-class weighted sums and counts using one block-shared accumulator per
//class, updated with shared-memory atomics and merged globally at the end.
//
//Launch expectations: 1-D grid, blockDim.x == 128; dynamic shared memory of
//num_classes*(sizeof(double)+sizeof(uint32_t)) bytes. `sums` and `counts`
//must be zero-initialized by the caller. atomicAdd on double requires SM60+.
__global__ void group_summer_shatomic(
  const int32_t *const labels,
  const float *const weights,
  const int num_elements,
  const int num_classes,
  double *const sums,
  uint32_t *const counts
){
  constexpr int num_threads = 128;
  assert(num_threads==blockDim.x);
  //One shared accumulator per class, shared by the whole block. Doubles
  //come first so the uint32_t region cannot misalign the doubles.
  extern __shared__ int s[];
  double *const block_sums = (double*)s;
  uint32_t *const block_counts = (uint32_t*)&block_sums[num_classes];
  //Shared memory is uninitialized: zero it cooperatively before use.
  for(int c=threadIdx.x; c<num_classes; c+=num_threads){
    block_sums[c] = 0;
    block_counts[c] = 0;
  }
  __syncthreads();
  //Grid-stride loop; atomic contention is confined to this block's copy.
  const int stride = gridDim.x*blockDim.x;
  for(int idx=blockIdx.x*blockDim.x + threadIdx.x; idx<num_elements; idx+=stride){
    const auto cls = labels[idx];
    atomicAdd(&block_sums[cls], (double)weights[idx]);
    atomicAdd(&block_counts[cls], 1);
  }
  __syncthreads();
  //Merge this block's totals into the global results.
  for(int c=threadIdx.x; c<num_classes; c+=num_threads){
    atomicAdd(&sums[c], block_sums[c]);
    atomicAdd(&counts[c], block_counts[c]);
  }
}
//Per-class weighted sums and counts accumulated directly into global memory
//with atomics (baseline variant; heaviest contention of the three kernels).
//
//Launch expectations: 1-D grid/block; `sums` and `counts` must be
//zero-initialized by the caller. `num_classes` is accepted for signature
//parity with the other kernels but is not used here.
//atomicAdd on double requires SM60+.
__global__ void group_summer_global(
  const int32_t *const labels,
  const float *const weights,
  const int num_elements,
  const int num_classes,
  double *const sums,
  uint32_t *const counts
){
  const int stride = gridDim.x*blockDim.x;
  //Grid-stride loop: every element lands one atomic on its class's slot.
  for(int idx=blockIdx.x*blockDim.x + threadIdx.x; idx<num_elements; idx+=stride){
    const auto cls = labels[idx];
    atomicAdd(&sums[cls], (double)weights[idx]);
    atomicAdd(&counts[cls], 1);
  }
}
//CPU reference implementation: per-class sums of `weights` and element
//counts, where element i belongs to class labels[i].
//
//`labels` and `weights` must be the same length; labels are assumed
//non-negative (class count is 1 + max label). Returns empty vectors for
//empty input.
SumCountRet group_summer_cpu(
  const std::vector<int32_t> &labels,
  const std::vector<float> &weights
){
  assert(labels.size()==weights.size());
  //BUGFIX: max_element on an empty range returns end(); dereferencing it
  //was undefined behavior. Bail out early instead.
  if(labels.empty()){
    return {};
  }
  const int num_classes = 1 + *std::max_element(labels.begin(), labels.end());
  std::vector<double> sums(num_classes);
  std::vector<uint32_t> counts(num_classes);
  for(size_t i=0;i<labels.size();i++){
    const auto l = labels[i];
    sums[l] += weights[i];
    counts[l]++;
  }
  return {sums, counts};
}
//Element-wise approximate equality: true iff the vectors have the same
//length and every pair of elements differs by at most 1e-4.
//
//BUGFIX: the original computed std::abs(a[i]-b[i]), which underflows for
//unsigned T (e.g. uint32_t) before std::abs ever runs. Computing the
//difference in subtraction-safe order is correct for all arithmetic T and
//identical to the original for floating-point T.
template<class T>
bool vec_nearly_equal(const std::vector<T> &a, const std::vector<T> &b){
  if(a.size()!=b.size())
    return false;
  for(size_t i=0;i<a.size();i++){
    const T diff = (a[i]>b[i]) ? a[i]-b[i] : b[i]-a[i];
    if(diff>1e-4)
      return false;
  }
  return true;
}
//Marshals host inputs to the device, invokes `func` (which launches a
//kernel) on the device vectors, and copies the per-class results back.
//Output device vectors are value-initialized to zero by device_vector,
//which the kernels rely on.
template<typename Func>
SumCountRet cuda_call(const std::vector<int> &labels, const std::vector<float> &weights, Func func){
  const int num_classes = 1 + *std::max_element(labels.begin(), labels.end());

  //Range constructors upload the host data in one step.
  thrust::device_vector<int32_t> d_labels (labels.begin(), labels.end());
  thrust::device_vector<float> d_weights(weights.begin(), weights.end());
  thrust::device_vector<double> d_sums (num_classes);
  thrust::device_vector<uint32_t> d_counts(num_classes);

  func(d_labels, d_weights, d_sums, d_counts);

  //Download results into host vectors for the caller.
  std::vector<double> host_sums (num_classes);
  std::vector<uint32_t> host_counts(num_classes);
  thrust::copy(d_sums.begin(), d_sums.end(), host_sums.begin());
  thrust::copy(d_counts.begin(), d_counts.end(), host_counts.begin());
  return {host_sums, host_counts};
}
//Reports (without aborting) kernel launch-configuration errors and
//asynchronous execution errors from the most recent launch.
//Extracted here because the identical check was triplicated below.
static void check_kernel_status(){
  if(cudaGetLastError()!=cudaSuccess){
    std::cout<<"Kernel failed to launch!"<<std::endl;
  }
  //cudaDeviceSynchronize also surfaces in-kernel faults (e.g. bad access).
  if(cudaDeviceSynchronize()!=cudaSuccess){
    std::cout<<"Error in kernel!"<<std::endl;
  }
}

//Generates N random (label, weight) pairs with labels in [0, label_max],
//runs all three GPU kernels plus the CPU reference, and prints whether each
//GPU result matches the reference (1 = match, 0 = mismatch).
void TestGroupSummer(std::mt19937 &gen, const int N, const int label_max){
  std::vector<int32_t> labels(N);
  std::vector<float> weights(N);
  std::uniform_int_distribution<int> label_dist(0, label_max);
  std::uniform_real_distribution<float> weight_dist(0, 5000);
  for(int i=0;i<N;i++){
    labels[i] = label_dist(gen);
    weights[i] = weight_dist(gen);
  }

  //Shared memory kernel: per-thread shared accumulators, so the dynamic
  //shared-memory request scales with num_threads * num_classes.
  const auto shmem_ret = cuda_call(labels, weights, [](
    thrust::device_vector<int32_t> &d_labels,
    thrust::device_vector<float> &d_weights,
    thrust::device_vector<double> &d_sums,
    thrust::device_vector<uint32_t> &d_counts
  ){
    constexpr int num_threads = 128;
    const int num_blocks = (d_labels.size() + num_threads - 1)/num_threads;
    const int shmem = num_threads * d_sums.size() * (sizeof(double)+sizeof(uint32_t));
    const int num_classes = d_sums.size();
    group_summer_shmem<<<num_blocks,num_threads,shmem>>>(
      thrust::raw_pointer_cast(d_labels.data()),
      thrust::raw_pointer_cast(d_weights.data()),
      d_labels.size(),
      num_classes,
      thrust::raw_pointer_cast(d_sums.data()),
      thrust::raw_pointer_cast(d_counts.data())
    );
    check_kernel_status();
  });

  //Shared atomic memory kernel: one shared accumulator per class.
  const auto shatomic_ret = cuda_call(labels, weights, [](
    thrust::device_vector<int32_t> &d_labels,
    thrust::device_vector<float> &d_weights,
    thrust::device_vector<double> &d_sums,
    thrust::device_vector<uint32_t> &d_counts
  ){
    constexpr int num_threads = 128;
    const int num_blocks = (d_labels.size() + num_threads - 1)/num_threads;
    const int shmem = d_sums.size() * (sizeof(double)+sizeof(uint32_t));
    const int num_classes = d_sums.size();
    group_summer_shatomic<<<num_blocks,num_threads,shmem>>>(
      thrust::raw_pointer_cast(d_labels.data()),
      thrust::raw_pointer_cast(d_weights.data()),
      d_labels.size(),
      num_classes,
      thrust::raw_pointer_cast(d_sums.data()),
      thrust::raw_pointer_cast(d_counts.data())
    );
    check_kernel_status();
  });

  //Global memory kernel: no shared memory needed.
  const auto global_ret = cuda_call(labels, weights, [](
    thrust::device_vector<int32_t> &d_labels,
    thrust::device_vector<float> &d_weights,
    thrust::device_vector<double> &d_sums,
    thrust::device_vector<uint32_t> &d_counts
  ){
    constexpr int num_threads = 128;
    const int num_blocks = (d_labels.size() + num_threads - 1)/num_threads;
    const int shmem = 0;
    const int num_classes = d_sums.size();
    group_summer_global<<<num_blocks,num_threads,shmem>>>(
      thrust::raw_pointer_cast(d_labels.data()),
      thrust::raw_pointer_cast(d_weights.data()),
      d_labels.size(),
      num_classes,
      thrust::raw_pointer_cast(d_sums.data()),
      thrust::raw_pointer_cast(d_counts.data())
    );
    check_kernel_status();
  });

  //Sums are compared with a tolerance (float-vs-double accumulation order
  //differs on the GPU); integer counts must match exactly.
  const auto correct_ret = group_summer_cpu(labels, weights);
  std::cout<<"shmem sums good? " <<vec_nearly_equal(shmem_ret.sums,correct_ret.sums)<<std::endl;
  std::cout<<"shmem counts good? "<<(shmem_ret.counts==correct_ret.counts)<<std::endl;
  std::cout<<"shatomic sums good? " <<vec_nearly_equal(shatomic_ret.sums,correct_ret.sums)<<std::endl;
  std::cout<<"shatomic counts good? "<<(shatomic_ret.counts==correct_ret.counts)<<std::endl;
  std::cout<<"global sums good? " <<vec_nearly_equal(global_ret.sums,correct_ret.sums)<<std::endl;
  std::cout<<"global counts good? "<<(global_ret.counts==correct_ret.counts)<<std::endl;
}
//Run the benchmark/correctness suite several times; the first iteration
//also absorbs one-time CUDA context initialization cost.
int main(){
  std::mt19937 gen;
  for(int run=0; run<4; run++){
    TestGroupSummer(gen, 10000000, 10);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment