Inplace downcasting in PyTorch
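The trick behind the gist: Tensor.view(dtype) reinterprets a tensor's existing storage under a different dtype, so the downcast values can be written back into the front of the same buffer instead of allocating a second full-size tensor. A minimal sketch of just that reinterpretation step (variable names here are illustrative, not from the gist):

import torch

a = torch.ones(4, dtype = torch.float32)  # 4 elements * 4 bytes = 16 bytes
v = a.view(torch.bfloat16)                # same 16 bytes reinterpreted as 8 bfloat16 elements
print(v.numel(), v.nbytes)                # 8 16

The gist itself: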
# https://github.com/pytorch/pytorch/issues/158710
# https://github.com/pytorch/pytorch/issues/158698
# https://github.com/pytorch/pytorch/issues/69431

import torch

def to_(tensor1d, dtype, *, chunks = 0, split_size = 0):
    # TODO: instead of clone(), maybe copy_ into a reusable buffer; clone() does not allow passing a buffer
    # TODO: unclear whether these code paths support autograd, and if so, whether saved_for_backward would retain too much
    assert tensor1d.ndim == 1
    assert tensor1d.dtype.itemsize % dtype.itemsize == 0
    # reinterpret the same storage under the smaller dtype (more elements, same bytes)
    res1d = tensor1d.view(dtype)
    if chunks:
        k = 0
        for chunk in tensor1d.chunk(chunks):
            # clone() is required here: source and destination share storage, and without it PyTorch raises
            # RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor
            # refer to a single memory location. Please clone() the tensor before performing the operation.
            res1d[k : k + len(chunk)] = chunk.detach().clone()
            k += len(chunk)
        return res1d[:k]
    elif split_size:
        k = 0
        for chunk in tensor1d.split(split_size):
            # same overlapping-memory caveat as above: clone() the chunk before writing it back
            res1d[k : k + len(chunk)] = chunk.detach().clone()
            k += len(chunk)
        return res1d[:k]
    else:
        # strided path: write each value into the first narrow slot of the storage it used to occupy
        step = tensor1d.dtype.itemsize // dtype.itemsize
        res1d[::step].copy_(tensor1d)
        return res1d[::step]

if __name__ == '__main__':
    K = 10 * 1024

    a = torch.ones(K, dtype = torch.float32)
    print(a.sum(), a.dtype, a.dtype.itemsize, a.nbytes, a.is_contiguous())
    b = to_(a, dtype = torch.bfloat16)
    print(b.sum(), b.dtype, b.dtype.itemsize, b.nbytes, b.is_contiguous())

    a = torch.ones(K, dtype = torch.float32)
    print(a.sum(), a.dtype, a.dtype.itemsize, a.nbytes, a.is_contiguous())
    b = to_(a, dtype = torch.bfloat16, chunks = 10)
    print(b.sum(), b.dtype, b.dtype.itemsize, b.nbytes, b.is_contiguous())

    a = torch.ones(K, dtype = torch.float32)
    print(a.sum(), a.dtype, a.dtype.itemsize, a.nbytes, a.is_contiguous())
    b = to_(a, dtype = torch.bfloat16, split_size = 1024)
    print(b.sum(), b.dtype, b.dtype.itemsize, b.nbytes, b.is_contiguous())
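
A quick check, not in the original, that the result really aliases the input's storage (a sketch meant to run after the definitions above):

a = torch.ones(10 * 1024, dtype = torch.float32)
b = to_(a, dtype = torch.bfloat16)
# b is a view into a's storage: same base pointer, half the logical bytes
assert b.untyped_storage().data_ptr() == a.untyped_storage().data_ptr()
assert b.nbytes == a.nbytes // 2

After the copy, only the first half of the original float32 buffer holds meaningful bfloat16 data; the allocation itself is not shrunk.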