import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _split_indices(index):
  # Split a rank-1 index tensor into a list of scalar ops, one per element,
  # as expected by the XLA DynamicSlice / DynamicUpdateSlice builders.
  ishape = index.shape()
  assert ishape.rank == 1
  indices = []
  for dim in range(0, ishape.sizes[0]):
    indices.append(index.slice_in_dim(dim, dim + 1, 0).reshape([]))
  return indices


def _dynamic_slice_forward(input, start_indices, slice_sizes=None):
  # Forward lowering: XLA DynamicSlice with per-dimension scalar start indices.
  return input.dynamic_slice(_split_indices(start_indices), slice_sizes)


def _dynamic_slice_backward(grad_output, input, start_indices, slice_sizes=None):
  # Backward lowering: scatter the incoming gradient into a zero tensor of the
  # input's shape at the same start indices (XLA DynamicUpdateSlice).
  return input.zeros_like().dynamic_update_slice(grad_output,
                                                 _split_indices(start_indices))


# Register the two lowerings as named XLA operations.
DYNAMIC_SLICE_FORWARD = xor.register('DynamicSliceForward', _dynamic_slice_forward)
DYNAMIC_SLICE_BACKWARD = xor.register('DynamicSliceBackward', _dynamic_slice_backward)


class DynamicSlice(torch.autograd.Function):
  # Autograd wrapper tying the registered forward/backward XLA ops together.

  @staticmethod
  def forward(ctx, input, start_indices, slice_sizes):
    ctx.slice_sizes = slice_sizes
    output = DYNAMIC_SLICE_FORWARD(input, start_indices, slice_sizes=slice_sizes)
    ctx.save_for_backward(input, start_indices)
    return output

  @staticmethod
  def backward(ctx, grad_output):
    input, start_indices = ctx.saved_tensors
    grad = DYNAMIC_SLICE_BACKWARD(grad_output, input, start_indices,
                                  slice_sizes=ctx.slice_sizes)
    # No gradients flow to the integer start indices or the static slice sizes.
    return grad, None, None


def dynamic_slice(input, start_indices, slice_sizes):
  # Public entry point: extract a window of shape `slice_sizes` from `input`,
  # with start offsets taken from the device tensor `start_indices`.
  return DynamicSlice.apply(input, start_indices, slice_sizes)
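

# For illustration only (not part of the original gist): roughly equivalent
# slicing semantics expressed with plain PyTorch narrow(), assuming the start
# offsets are in-bounds Python ints.  dynamic_slice() above differs in that the
# offsets live in a device tensor and stay inside the compiled XLA graph.
def _reference_slice(input, start_indices, slice_sizes):
  out = input
  for dim, (start, size) in enumerate(zip(start_indices, slice_sizes)):
    out = out.narrow(dim, start, size)
  return out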


def _mp_fn(index):
  device = xm.xla_device()
  x = torch.randn(6, 8, device=device, requires_grad=True)
  # Start offsets live in a device tensor, one entry per dimension of `x`.
  index = torch.tensor([2, 4], dtype=torch.int32, device=device)
  out = dynamic_slice(x, index, (2, 3))
  loss = out.pow(2).sum()
  loss.backward()
  print(x.grad.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, nprocs=None)
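
# Expected behavior (illustrative, not output from the original gist): with
# start indices [2, 4] and slice sizes (2, 3), `out` matches x[2:4, 4:7].
# Since loss = (out ** 2).sum(), the printed x.grad is zero everywhere except
# that 2x3 window, where it equals 2 * x[2:4, 4:7].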