Created
August 21, 2021 02:28
-
-
Save AnirudhDagar/ec18f7cdd5bcdd2bef2d2d2ebf16a8ba to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <ATen/ATen.h> | |
#include <ATen/Dispatch.h> | |
#include <ATen/NativeFunctions.h> | |
#include <ATen/native/Resize.h> | |
#include <ATen/native/Cross.h> | |
namespace at { namespace native { | |
// Defines the dispatch stub for the cross-product kernel. linalg_cross_out
// invokes it with the input tensor's device type. (Its declaration is
// presumably in ATen/native/Cross.h — confirm.)
DEFINE_DISPATCH(cross_stub); | |
// Functional (non-out=) form of the legacy cross product.
// Allocates a zero-element tensor with `input`'s options. That tensor is filled
// through the dispatcher-level at::cross_out, which is expected to resize it
// to the result shape, and is then returned.
Tensor cross(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
  auto result = at::empty({0}, input.options());
  at::cross_out(result, input, other, dimension);
  return result;
}
// Legacy out= variant of the cross product.
//
// If `dimension` is given, it is forwarded unchanged. Otherwise the first
// dimension of `input` whose size is 3 is used, and an error is raised if no
// such dimension exists. All remaining validation and the computation itself
// are delegated to at::linalg_cross_out, whose result is returned.
Tensor & cross_out(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension, Tensor & out) {
  if (dimension.has_value()) {
    // Negative-index wrapping and the size-3 check are left to linalg_cross_out.
    return at::linalg_cross_out(out, input, other, dimension.value());
  }

  // Only `input` is inspected when choosing the default dimension.
  const int64_t ndim = input.dim();
  int64_t first_size3 = 0;
  while (first_size3 < ndim && input.size(first_size3) != 3) {
    ++first_size3;
  }
  TORCH_CHECK(first_size3 < ndim, "no dimension of size 3 in input");
  return at::linalg_cross_out(out, input, other, first_size3);
}
// Functional (non-out=) form of linalg_cross_out.
// Allocates a zero-element tensor with `input`'s options and passes it to
// the native out= implementation, which validates the arguments and resizes
// the tensor to input.sizes(). That tensor is then returned.
Tensor linalg_cross(const Tensor & input, const Tensor & other, const int64_t dimension) {
  auto result = at::empty({0}, input.options());
  // Calls the native implementation directly (no re-dispatch).
  at::native::linalg_cross_out(input, other, dimension, result);
  return result;
}
// Computes the cross product of `input` and `other` along `dimension` and
// writes the result into `out`.
//
// Requirements (all enforced with TORCH_CHECK):
//  - `out`, `input` and `other` are on the same device;
//  - `input` and `other` have the same dtype;
//  - `input` and `other` have identical sizes (no broadcasting is done here);
//  - `dimension`, after wrapping negative values against input.dim(), has size 3.
// `out` is resized to input.sizes() via resize_output. A warning is raised if
// a non-empty `out` has to be resized. Returns `out`.
Tensor & linalg_cross_out(const Tensor & input, const Tensor & other, const int64_t dimension, Tensor & out) {
  // cross_stub is dispatched on input's device type only, so reject device and
  // dtype mismatches here with a clear message instead of deferring the
  // failure to the device-specific kernel.
  TORCH_CHECK(out.device() == input.device() && input.device() == other.device(),
      "Expected all tensors to be on the same device, but found input on ", input.device(),
      ", other on ", other.device(), " and out on ", out.device());
  TORCH_CHECK(input.scalar_type() == other.scalar_type(),
      "inconsistent tensors dtypes input: ", input.scalar_type(), " other: ", other.scalar_type());
  TORCH_CHECK(input.dim() == other.dim(), "inconsistent tensors dimensions input: ", input.dim(), " other: ", other.dim());
  TORCH_CHECK(input.sizes() == other.sizes(), "inconsistent tensors sizes input: ", input.sizes(), " other: ", other.sizes());

  // Default dimension (from the Python API) is -1; wrap negative values.
  const int64_t dim = maybe_wrap_dim(dimension, input.dim());
  TORCH_CHECK(input.size(dim) == 3, "dimension ", dimension, " does not have size 3");

  // Resize `out` only if required; this warns when a non-empty `out` is resized.
  at::native::resize_output(out, input.sizes());

  // NOTE(review): there is no overlap check between `out` and `input`/`other`.
  // Confirm that cross_stub tolerates aliasing (e.g. out=input).
  const auto device_type = input.device().type();
  cross_stub(device_type, out, input, other, dim);
  return out;
}
}} // namespace at::native |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment