Highly accurate batchnorm backward reductions.

csrc = """
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCGeneral.h>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"

using namespace at;

#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
#else
constexpr int WARP_SIZE = 32;
#endif

// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}
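
// add_with_lower is an error-compensated "two-sum": it returns the rounded
// sum a + b and adds the rounding error that was lost to rlower, so that
// (returned value) + rlower recovers the sum to much higher precision.
// Branching on the larger-magnitude operand is the classic Kahan/Neumaier
// trick; because rlower is accumulated with +=, the same variable can carry
// the running compensation term across a whole summation loop.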
template <typename scalar_t>
__device__ scalar_t add_with_lower(scalar_t& rlower, const scalar_t& a, const scalar_t& b) {
  scalar_t rupper = a + b;
  if (fabs(a) >= fabs(b)) {
    rlower += (a - rupper) + b;
  } else {
    rlower += (b - rupper) + a;
  }
  return rupper;
}
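
// Block-wide reduction that keeps the compensation term: the shared-memory
// buffer x holds blockDim.x "upper" partial sums followed by blockDim.x
// "lower" (error) terms. The tree reduction in shared memory and the final
// warp-shuffle stage both combine pairs with add_with_lower, so the error
// term is carried all the way down. The result ends up in x[0] (upper) and
// x[1] (lower).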
template<typename T>
__device__ __forceinline__ void reduce_block(T *x, T val, T lower)
{
  int tid = threadIdx.x;
  int blockSize = blockDim.x;  // blockSize is intended to be a multiple of 32.
  if (blockSize >= 64) {
    x[tid] = val;
    x[tid + blockSize] = lower;
    __syncthreads();
  }
  #pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) {
      x[tid] = add_with_lower(x[tid + blockSize], x[tid], x[tid + i]);
      __syncthreads();
      x[tid + blockSize] += x[tid + blockSize + i];
    }
    __syncthreads();
  }
  if (tid < 32) {
    T final;
    T final_lower;
    if (blockSize >= 64) {
      final_lower = x[tid + blockSize] + x[tid + blockSize + 32];
      __syncthreads();
      final = add_with_lower(final_lower, x[tid], x[tid + 32]);
    } else {
      final = val;
      final_lower = lower;
    }
    // __SYNCWARP();
    #pragma unroll
    for (int i = 16; i >= 1; i >>= 1) {
      final_lower += WARP_SHFL_DOWN(final_lower, i, 32);
      final = add_with_lower(final_lower, final, WARP_SHFL_DOWN(final, i, 32));
    }
    if (tid == 0) {
      x[0] = final;  // EpilogueOp
      x[1] = final_lower;
    }
  }
  // Make sure the smem result is visible to all warps.
  __syncthreads();
}
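
// Per-channel compensated sum over a (batch, channel, feature) view of the
// input. Each y-thread handles one channel ("plane"), the threads along x
// stride over the feature dimension for every batch element, and
// reduce_block combines the per-thread partial sums (value + error term)
// using 2 * blockDim.x shared-memory slots per y-row.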
template <typename scalar_t, typename accscalar_t>
__global__ void sum_kernel_kahan_parallel(
    PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> sum,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input) {
  extern __shared__ char buf[];  // aliasing to not have type complaints from nvcc
  accscalar_t* s = (accscalar_t*)buf;
  const int plane = threadIdx.y + blockDim.y * blockIdx.y;
  const int stride = blockDim.x;
  const int offset = threadIdx.x;
  accscalar_t sum_ = 0;
  accscalar_t sum_lower = 0;
  if (plane < input.size(1)) {
    for (int64_t b = 0; b < input.size(0); b++) {
      for (int64_t f = offset; f < input.size(2); f += stride) {
        accscalar_t inp = input[b][plane][f];
        sum_ = add_with_lower(sum_lower, inp, sum_);
      }
    }
    reduce_block(s + threadIdx.y * blockDim.x * 2, sum_, sum_lower);
    if (offset == 0) {
      sum[plane] = s[threadIdx.y * blockDim.x * 2] + s[threadIdx.y * blockDim.x * 2 + 1];
    }
  }
}
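
// Computes, per channel, both reductions the batch norm backward needs:
//   sum[c]         = sum over batch and features of grad_out
//   scalar_prod[c] = sum over batch and features of grad_out * (input - mean[c])
// Both use the compensated accumulation above; in addition, the rounding
// error of each product go * (input - mean) is recovered with fma and fed
// back into the compensated sum. The train flag only selects the dtype of
// the mean accessor (saved batch mean vs. running mean).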
template <typename scalar_t, typename accscalar_t, bool train>
__global__ void grad_sum_and_dot_kernel_parallel(
    PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> sum,
    PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> scalar_prod,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_out,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input,
    const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, at::RestrictPtrTraits> mean_inp) {
  extern __shared__ char buf[];  // aliasing to not have type complaints from nvcc
  accscalar_t* s = (accscalar_t*)buf;
  const int plane = threadIdx.y + blockDim.y * blockIdx.y;
  const int stride = blockDim.x;
  const int offset = threadIdx.x;
  accscalar_t sum_ = 0;
  accscalar_t sum_lower = 0;
  accscalar_t scalar_prod_ = 0;
  accscalar_t scalar_prod_lower = 0;
  if (plane < input.size(1)) {
    accscalar_t mi = mean_inp[plane];
    for (int64_t b = 0; b < input.size(0); b++) {
      for (int64_t f = offset; f < input.size(2); f += stride) {
        accscalar_t go = grad_out[b][plane][f];
        sum_ = add_with_lower(sum_lower, go, sum_);
        accscalar_t demeaned_inp_lower = 0;
        accscalar_t demeaned_inp = add_with_lower(demeaned_inp_lower, static_cast<accscalar_t>(input[b][plane][f]), -mi);
        accscalar_t g_dmil = go * demeaned_inp_lower;
        // we skip computing the lower bits of go * demeaned_inp_lower
        accscalar_t prod = go * demeaned_inp;
        accscalar_t prodl = fma(go, demeaned_inp, -prod);
        scalar_prod_ = add_with_lower(scalar_prod_lower, prod, scalar_prod_);
        scalar_prod_ = add_with_lower(scalar_prod_lower, g_dmil, scalar_prod_);
        scalar_prod_ = add_with_lower(scalar_prod_lower, prodl, scalar_prod_);
      }
    }
    reduce_block(s + threadIdx.y * blockDim.x * 2, sum_, sum_lower);
    if (offset == 0) {
      sum[plane] = s[threadIdx.y * blockDim.x * 2] + s[threadIdx.y * blockDim.x * 2 + 1];
    }
    reduce_block(s + threadIdx.y * blockDim.x * 2, scalar_prod_, scalar_prod_lower);
    if (offset == 0) {
      scalar_prod[plane] = s[threadIdx.y * blockDim.x * 2] + s[threadIdx.y * blockDim.x * 2 + 1];
    }
  }
}
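
// Simple reference kernels: one thread per channel, serial Kahan/Neumaier
// accumulation over all elements of the plane. sum_kernel_kahan2 is a
// blocked variant that accumulates BLOCKS interleaved sub-sequences
// separately and then combines them, again with compensation.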
template <typename scalar_t>
__global__ void sum_kernel_kahan(
    PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> sum,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input) {
  int64_t plane = threadIdx.x + blockDim.x * blockIdx.x;
  scalar_t sum_ = 0;
  scalar_t lower_order = 0;
  if (plane < input.size(1)) {
    for (int64_t b = 0; b < input.size(0); b++) {
      for (int64_t f = 0; f < input.size(2); f++) {
        sum_ = add_with_lower(lower_order, input[b][plane][f], sum_);
      }
    }
    sum[plane] = sum_ + lower_order;
  }
}

template <typename scalar_t>
__global__ void sum_kernel_kahan2(
    PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> sum,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input) {
  int64_t plane = threadIdx.x + blockDim.x * blockIdx.x;
  scalar_t sum_ = 0;
  scalar_t lower_order = 0;
  int64_t BLOCKS = 512;
  if (plane < input.size(1)) {
    for (int64_t b = 0; b < input.size(0); b++) {
      for (int64_t block = 0; block < BLOCKS; block++) {
        scalar_t sum_local = 0;
        scalar_t lower_order_local = 0;
        for (int64_t f = block; f < input.size(2); f += BLOCKS) {
          sum_local = add_with_lower(lower_order_local, input[b][plane][f], sum_local);
        }
        sum_ = add_with_lower(lower_order, sum_local, sum_);
        lower_order += lower_order_local;
      }
    }
    sum[plane] = sum_ + lower_order;
  }
}
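
// Serial per-channel reference for the compensated dot product of grad_out
// with the demeaned input, mirroring the product/error handling of the
// parallel kernel above.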
template <typename scalar_t>
__global__ void scalar_product_kernel(
    PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> sum,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_out,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> inp,
    const PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> mean_inp
    ) {
  int64_t plane = threadIdx.x + blockDim.x * blockIdx.x;
  scalar_t res = 0;
  scalar_t res_lower = 0;
  if (plane < grad_out.size(1)) {
    scalar_t m = mean_inp[plane];
    for (int64_t b = 0; b < grad_out.size(0); b++) {
      for (int64_t f = 0; f < grad_out.size(2); f++) {
        scalar_t rm_lower = 0;
        scalar_t rm = add_with_lower(rm_lower, inp[b][plane][f], -m);
        scalar_t l = grad_out[b][plane][f];
        scalar_t lrml = l * rm_lower;
        // we skip computing the lower bits of l * rm_lower
        scalar_t prod = l * rm;
        scalar_t prodl = fma(l, rm, -prod);
        res = add_with_lower(res_lower, prod, res);
        res = add_with_lower(res_lower, lrml, res);
        res = add_with_lower(res_lower, prodl, res);
      }
    }
    sum[plane] = res + res_lower;
  }
}
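
// Host-side launchers for the parallel reduction kernels: feature_threads
// is the feature extent rounded up to a multiple of the warp size (capped
// at 512), plane_threads packs as many channels into the same block as
// still fit into 512 threads, and the shared-memory size provides two
// accumulator slots (value + error term) per thread.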
template<typename scalar_t>
Tensor sum_template(const Tensor& input) {
  Tensor sum = empty({input.size(1)}, input.options());
  using accscalar_t = acc_type<scalar_t, true>;
  constexpr int MAX_THREADS = 512;
  int feature_threads = ((std::min<int>(input.size(2), MAX_THREADS) + 31) / 32) * 32;  // round to multiples of 32
  int plane_threads = std::max<int>(1, MAX_THREADS / feature_threads);
  int smem_size = sizeof(accscalar_t) * plane_threads * feature_threads * 2;
  dim3 threads(feature_threads, plane_threads);
  dim3 blocks(1, (input.size(1) + plane_threads - 1) / plane_threads);
  sum_kernel_kahan_parallel<scalar_t, accscalar_t><<<blocks, threads, smem_size>>>(
      sum.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>(),
      input.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>());
  return sum;
}

template<typename scalar_t>
std::tuple<Tensor, Tensor> sum_and_scalar_prod_template(const Tensor& grad_out, const Tensor& input, const Tensor& mean_inp) {
  Tensor sum = at::zeros({input.size(1)}, input.options());
  Tensor scalar_prod = at::empty({input.size(1)}, input.options());
  using accscalar_t = acc_type<scalar_t, true>;
  constexpr int MAX_THREADS = 512;
  int feature_threads = ((std::min<int>(input.size(2), MAX_THREADS) + 31) / 32) * 32;  // round to multiples of 32
  int plane_threads = std::max<int>(1, MAX_THREADS / feature_threads);
  int smem_size = sizeof(accscalar_t) * plane_threads * feature_threads * 2;
  dim3 threads(feature_threads, plane_threads);
  dim3 blocks(1, (input.size(1) + plane_threads - 1) / plane_threads);
  grad_sum_and_dot_kernel_parallel<scalar_t, accscalar_t, true><<<blocks, threads, smem_size>>>(
      sum.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>(),
      scalar_prod.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>(),
      grad_out.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
      input.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
      mean_inp.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>());
  return std::make_tuple(sum, scalar_prod);
}

template<typename scalar_t>
Tensor scalar_product_template(const Tensor& l, const Tensor& r, const Tensor& mean) {
  Tensor res = empty({l.size(1)}, l.options());
  dim3 threads(std::min<int>(l.size(1), 512));
  dim3 blocks(std::max<int>(1, (l.size(1) + 511) / 512));
  scalar_product_kernel<scalar_t><<<blocks, threads>>>(
      res.packed_accessor<scalar_t, 1, at::RestrictPtrTraits>(),
      l.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
      r.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
      mean.packed_accessor<scalar_t, 1, at::RestrictPtrTraits>());
  return res;
}

// TensorAccessor in which the last dimensions are collapsed or expanded as needed
template <typename scalar_t, int64_t dim>
static PackedTensorAccessor<scalar_t, dim, at::RestrictPtrTraits> reshaped_packed_accessor(const Tensor& t) {
  // undefined...
  if (! t.defined()) {
    const std::vector<int64_t> zeros(dim);
    return PackedTensorAccessor<scalar_t, dim, at::RestrictPtrTraits>(nullptr, zeros.data(), zeros.data());
  }
  int64_t in_dim = t.dim();
  if (in_dim == dim) {
    return t.packed_accessor<scalar_t, dim, at::RestrictPtrTraits>();
  }
  AT_CHECK(in_dim < dim || t.is_contiguous(), "need contiguous or <= 3d tensor");
  std::vector<int64_t> sizes(dim);
  std::vector<int64_t> strides(dim);
  for (int i = 0; i < in_dim || i < dim; ++i) {
    if (i < dim && i < in_dim) {
      sizes[i] = t.size(i);
      strides[i] = t.stride(i);
    } else if (i < dim) {
      sizes[i] = 1;
      strides[i] = 0;
    } else {
      sizes[dim - 1] *= t.size(i);
      strides[dim - 1] = 1;
    }
  }
  // evil trick to get adjusted 2d tensors to have large dimension last
  if (dim == 3 && sizes[0] > sizes[2]) {
    std::swap(sizes[0], sizes[2]);
    std::swap(strides[0], strides[2]);
  }
  return PackedTensorAccessor<scalar_t, dim, at::RestrictPtrTraits>(t.data<scalar_t>(), sizes.data(), strides.data());
}
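
// Elementwise batch norm backward, given the per-channel reductions:
//   grad_mean  = sum(grad_out) / N
//   proj_scale = dot(grad_out, x - mean) / N * invstd^2
//   grad_input = (grad_out - (x - mean) * proj_scale - grad_mean) * invstd * weight   (training)
//   grad_input = grad_out * invstd * weight                                           (eval)
//   grad_weight = dot(grad_out, x - mean) * invstd
//   grad_bias   = sum(grad_out)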
template <typename scalar_t, typename accscalar_t, bool train>
__global__ void batch_norm_backward_gradient_kernel(
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input,
    const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_output,
    PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_input,
    PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> grad_weight,
    PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> grad_bias,
    const PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> grad_out_sum,
    const PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> grad_out_dot_demeaned_input,
    const PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> weight,
    const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, at::RestrictPtrTraits> mean_,
    const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, at::RestrictPtrTraits> var_or_invstd,
    accscalar_t epsilon) {
  int plane = blockIdx.y * blockDim.y + threadIdx.y;
  if (plane >= input.size(1)) {
    return;
  }
  int N = grad_output.size(0) * grad_output.size(2);
  accscalar_t gamma = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : static_cast<accscalar_t>(1);
  //accscalar_t beta = bias.size(0) > 0 ? static_cast<accscalar_t>(bias[plane]) : static_cast<accscalar_t>(0);
  accscalar_t mean = static_cast<accscalar_t>(mean_[plane]);
  accscalar_t invstd;
  if (train) {
    invstd = var_or_invstd[plane];
  } else {
    invstd = static_cast<accscalar_t>(1) / std::sqrt(static_cast<accscalar_t>(var_or_invstd[plane]) + epsilon);
  }
  accscalar_t weight_val = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : accscalar_t(1);
  accscalar_t norm = accscalar_t(1) / N;
  accscalar_t grad_output_sum = grad_out_sum[plane];
  accscalar_t dot_p = grad_out_dot_demeaned_input[plane];
  accscalar_t grad_mean = grad_output_sum * norm;
  accscalar_t proj_scale = dot_p * norm * invstd * invstd;
  accscalar_t grad_scale = invstd * weight_val;
  if (grad_input.data() != NULL) {
    for (int64_t batch = blockIdx.x; batch < input.size(0); batch += gridDim.x) {
      for (int64_t feature = blockIdx.z; feature < input.size(2); feature += gridDim.z) {
        scalar_t go = grad_output[batch][plane][feature];
        if (train) {
          scalar_t inp = input[batch][plane][feature];
          accscalar_t proj = (inp - mean) * proj_scale;
          grad_input[batch][plane][feature] = static_cast<scalar_t>((go - proj - grad_mean) * grad_scale);
        } else {
          grad_input[batch][plane][feature] = static_cast<scalar_t>(go * grad_scale);
        }
      }
    }
  }
  if (grad_weight.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_weight[plane] = static_cast<scalar_t>(dot_p * invstd);
    }
  }
  if (grad_bias.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_bias[plane] = static_cast<scalar_t>(grad_output_sum);
    }
  }
}
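
// Host side of the backward: run the compensated per-channel reductions
// (sum of grad_out and grad_out . (input - mean)) first, then feed them to
// the elementwise gradient kernel above. The reduction buffers are meant to
// live in float when the input is half precision, i.e. in the accumulation
// dtype.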
template<typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(
    const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
    const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
    bool train, double epsilon, std::array<bool,3> grad_input_mask) {
  using accscalar_t = at::acc_type<scalar_t, true>;
  Tensor grad_input_;
  Tensor grad_weight_;
  Tensor grad_bias_;
  auto input_options = input_.options();
  if (grad_input_mask[0]) {
    grad_input_ = at::empty_like(input_);
  }
  if (grad_input_mask[1]) {
    grad_weight_ = at::empty(input_.size(1), input_options);
  }
  if (grad_input_mask[2]) {
    grad_bias_ = at::empty(input_.size(1), input_options);
  }
  if (input_options.dtype() == ScalarType::Half) {
    input_options.dtype(ScalarType::Float);
  }
  Tensor grad_out_sum_ = at::empty(input_.size(1), input_options);
  Tensor grad_out_dot_demeaned_input_ = at::empty(input_.size(1), input_options);
  auto grad_output = reshaped_packed_accessor<scalar_t, 3>(grad_out_);
  auto input = reshaped_packed_accessor<scalar_t, 3>(input_);
  auto grad_input = reshaped_packed_accessor<scalar_t, 3>(grad_input_);
  auto weight = reshaped_packed_accessor<scalar_t, 1>(weight_);
  auto grad_weight = reshaped_packed_accessor<scalar_t, 1>(grad_weight_);
  auto grad_bias = reshaped_packed_accessor<scalar_t, 1>(grad_bias_);
  auto running_mean = reshaped_packed_accessor<scalar_t, 1>(running_mean_);
  auto running_var = reshaped_packed_accessor<scalar_t, 1>(running_var_);
  auto save_mean = reshaped_packed_accessor<accscalar_t, 1>(save_mean_);
  auto save_invstd = reshaped_packed_accessor<accscalar_t, 1>(save_invstd_);
  auto grad_out_sum = reshaped_packed_accessor<accscalar_t, 1>(grad_out_sum_);
  auto grad_out_dot_demeaned_input = reshaped_packed_accessor<accscalar_t, 1>(grad_out_dot_demeaned_input_);
  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input.size(1));
  dim3 threads(getNumThreads(input.size(2)));
  constexpr int MAX_THREADS = 512;
  int feature_threads = ((std::min<int>(input.size(2), MAX_THREADS) + 31) / 32) * 32;  // round to multiples of 32
  int plane_threads = std::max<int>(1, MAX_THREADS / feature_threads);
  int smem_size = sizeof(accscalar_t) * plane_threads * feature_threads * 2;
  dim3 threads_red(feature_threads, plane_threads);
  dim3 blocks_red(1, (input.size(1) + plane_threads - 1) / plane_threads);
  if (train) {
    grad_sum_and_dot_kernel_parallel<scalar_t, accscalar_t, true><<<blocks_red, threads_red, smem_size>>>(
        grad_out_sum, grad_out_dot_demeaned_input,
        grad_output, input, save_mean);
  } else {
    grad_sum_and_dot_kernel_parallel<scalar_t, accscalar_t, false><<<blocks_red, threads_red, smem_size>>>(
        grad_out_sum, grad_out_dot_demeaned_input,
        grad_output, input, running_mean);
  }
  {
    constexpr int max_blocks_per_input = 60000;
    int feature_blocks = std::min<int>(input.size(2), max_blocks_per_input);
    int batch_blocks = std::min<int>(input.size(0), max_blocks_per_input / feature_blocks);
    dim3 blocks(batch_blocks, (input.size(1) + 127) / 128, feature_blocks);
    dim3 threads(1, 128);
    if (train) {
      batch_norm_backward_gradient_kernel<scalar_t, accscalar_t, true><<<blocks, threads, 0, stream>>>(
          input, grad_output, grad_input, grad_weight, grad_bias, grad_out_sum, grad_out_dot_demeaned_input,
          weight, save_mean, save_invstd, epsilon);
    } else {
      batch_norm_backward_gradient_kernel<scalar_t, accscalar_t, false><<<blocks, threads, 0, stream>>>(
          input, grad_output, grad_input, grad_weight, grad_bias, grad_out_sum, grad_out_dot_demeaned_input,
          weight, running_mean, running_var, epsilon);
    }
  }
  THCudaCheck(cudaGetLastError());
  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}

Tensor sum_cuda(const Tensor& input) {
  return AT_DISPATCH_FLOATING_TYPES(input.type(), "sum_cuda", [&] {
    return sum_template<scalar_t>(input);
  });
}

std::tuple<Tensor, Tensor> sum_and_scalar_prod_cuda(const Tensor& grad_out, const Tensor& input, const Tensor& mean_inp) {
  return AT_DISPATCH_FLOATING_TYPES(input.type(), "sum_and_scalar_prod", [&] {
    return sum_and_scalar_prod_template<scalar_t>(grad_out, input, mean_inp);
  });
}

Tensor scalar_product_cuda(const Tensor& l, const Tensor& r, const Tensor& mean) {
  return AT_DISPATCH_FLOATING_TYPES(l.type(), "scalar_product_cuda", [&] {
    return scalar_product_template<scalar_t>(l, r, mean);
  });
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(
    const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var,
    const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
    return batch_norm_backward_cuda_template<scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
  });
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sum_cuda", &sum_cuda, "per-channel compensated sum");
  m.def("scalar_product", &scalar_product_cuda, "per-channel compensated dot product with demeaned input");
  m.def("sum_and_scalar_prod", &sum_and_scalar_prod_cuda, "per-channel grad_out sum and grad_out . (input - mean)");
  m.def("batch_norm_backward", &batch_norm_backward_cuda, "batch norm backward with compensated reductions");
}
"""

import torch
import torch.utils.cpp_extension

name = "test"
sum_ext = torch.utils.cpp_extension.load_inline(name, [], cuda_sources=[csrc], verbose=True)
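
# The tensors used below (grad_o, inp, weight, running_mean, running_var, sm3,
# sis3, seed) as well as grads2 come from the interactive session this gist was
# taken from and are not defined here. The block below is a minimal, purely
# illustrative stand-in (shapes and values are assumptions, not the original
# setup); grads2 is intentionally left undefined.
seed = 0
torch.manual_seed(seed)
inp = torch.randn(8, 16, 32 * 32, device="cuda")       # (N, C, features)
grad_o = torch.randn_like(inp)
weight = torch.randn(16, device="cuda")
running_mean = torch.zeros(16, device="cuda")
running_var = torch.ones(16, device="cuda")
flat = inp.transpose(0, 1).contiguous().view(16, -1)
sm3 = flat.mean(1)                                      # per-channel batch mean
sis3 = (flat.var(1, unbiased=False) + 1e-5).rsqrt()     # per-channel invstd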
grads3 = sum_ext.batch_norm_backward(grad_o, inp, weight, running_mean, running_var, sm3, sis3, True, 1e-5, [True, True, True])
grads4 = sum_ext.batch_norm_backward(grad_o.double(), inp.double(), weight.double(), running_mean.double(), running_var.double(), sm3.double(), sis3.double(), True, 1e-5, [True, True, True])
# grads2 is not defined in this snippet; it comes from earlier in the session
# (presumably the gradients of the implementation being compared against).
for g1, g2 in zip(grads3, grads2):
    print(seed, (g1 - g2.float()).abs().max().item())