Skip to content

Instantly share code, notes, and snippets.

@AnirudhDagar
Created August 21, 2021 02:28
Show Gist options
  • Save AnirudhDagar/ec18f7cdd5bcdd2bef2d2d2ebf16a8ba to your computer and use it in GitHub Desktop.
Save AnirudhDagar/ec18f7cdd5bcdd2bef2d2d2ebf16a8ba to your computer and use it in GitHub Desktop.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/Resize.h>
#include <ATen/native/Cross.h>
namespace at { namespace native {
DEFINE_DISPATCH(cross_stub);
Tensor cross(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
Tensor out = at::empty({0}, input.options());
at::cross_out(out, input, other, dimension);
return out;
}
Tensor & cross_out(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension, Tensor & out) {
int64_t dim = -1;
if(!dimension.has_value()) {
for(int64_t i = 0; i < input.dim(); i++) {
if(input.size(i) == 3) {
dim = i;
break;
}
}
TORCH_CHECK(dim >= 0, "no dimension of size 3 in input");
} else {
dim = dimension.value();
}
return at::linalg_cross_out(out, input, other, dim);
}
Tensor linalg_cross(const Tensor & input, const Tensor & other, const int64_t dimension) {
Tensor out = at::empty({0}, input.options());
native::linalg_cross_out(input, other, dimension, out);
return out;
}
Tensor & linalg_cross_out(const Tensor & input, const Tensor & other, const int64_t dimension, Tensor & out) {
auto device1 = input.device().type();
TORCH_CHECK(input.dim() == other.dim(), "inconsistent tensors dimensions input: ", input.dim(), " other: ", other.dim());
TORCH_CHECK(input.sizes() == other.sizes(), "inconsistent tensors sizes input: ", input.sizes(), " other: ", other.sizes());
// default dimension=-1
int64_t dim = maybe_wrap_dim(dimension, input.dim());
TORCH_CHECK(input.size(dim) == 3, "dimension ", dimension, " does not have size 3");
// check if resizing output is required
// raise a warning while resizing if output has one or more elements
at::native::resize_output(out, input.sizes());
cross_stub(device1, out, input, other, dim);
return out;
}
}} // namespace at::native
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment