
@vadimkantorov
Last active July 26, 2025 09:31
In-place downcasting in PyTorch
# https://github.com/pytorch/pytorch/issues/158710
# https://github.com/pytorch/pytorch/issues/158698
# https://github.com/pytorch/pytorch/issues/69431

import torch

def to_(tensor1d, dtype, *, chunks = 0, split_size = 0):
    # Downcasts a 1D tensor in place: reinterprets its storage as the smaller dtype
    # and writes the converted values back into that same buffer.
    # TODO: instead of clone(), maybe copy_ into a preallocated scratch buffer; clone() does not allow reusing a buffer
    # TODO: unclear whether this supports autograd, and if so, whether it would keep too much in saved_for_backward
    assert tensor1d.ndim == 1
    assert tensor1d.dtype.itemsize % dtype.itemsize == 0
    res1d = tensor1d.view(dtype)  # same bytes, reinterpreted as the target dtype
    if chunks:
        k = 0
        for chunk in tensor1d.chunk(chunks):
            # clone() is required: without it, the assignment raises
            # RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
            res1d[k : k + len(chunk)] = chunk.detach().clone()
            k += len(chunk)
        return res1d[:k]
    elif split_size:
        k = 0
        for chunk in tensor1d.split(split_size):
            # same overlap issue as above, hence the clone()
            res1d[k : k + len(chunk)] = chunk.detach().clone()
            k += len(chunk)
        return res1d[:k]
    else:
        # strided in-place cast: each source element is converted and written back at its own byte offset
        step = tensor1d.dtype.itemsize // dtype.itemsize
        res1d[::step].copy_(tensor1d)
        return res1d[::step]
if __name__ == '__main__':
    K = 10 * 1024

    # strided variant (returns a non-contiguous view)
    a = torch.ones(K, dtype = torch.float32)
    print(a.sum(), a.dtype, a.dtype.itemsize, a.nbytes, a.is_contiguous())
    b = to_(a, dtype = torch.bfloat16)
    print(b.sum(), b.dtype, b.dtype.itemsize, b.nbytes, b.is_contiguous())

    # chunked variant
    a = torch.ones(K, dtype = torch.float32)
    print(a.sum(), a.dtype, a.dtype.itemsize, a.nbytes, a.is_contiguous())
    b = to_(a, dtype = torch.bfloat16, chunks = 10)
    print(b.sum(), b.dtype, b.dtype.itemsize, b.nbytes, b.is_contiguous())

    # split_size variant
    a = torch.ones(K, dtype = torch.float32)
    print(a.sum(), a.dtype, a.dtype.itemsize, a.nbytes, a.is_contiguous())
    b = to_(a, dtype = torch.bfloat16, split_size = 1024)
    print(b.sum(), b.dtype, b.dtype.itemsize, b.nbytes, b.is_contiguous())
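
A minimal sanity check (not part of the original gist): it assumes to_ from above is in scope, and that the overlapping strided copy_ rounds float32 to bfloat16 the same way an out-of-place .to() does, so the in-place result should match the reference cast while still reusing the source storage.

import torch

a = torch.randn(1024, dtype = torch.float32)
expected = a.to(torch.bfloat16)     # out-of-place reference cast (separate buffer)
b = to_(a, dtype = torch.bfloat16)  # in-place downcast, mutates and reuses a's storage
print(torch.equal(b, expected))     # True if both casts round identically
print(b.untyped_storage().data_ptr() == a.untyped_storage().data_ptr())  # True: same underlying buffer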