Skip to content

Instantly share code, notes, and snippets.

@d4l3k
Created December 12, 2024 17:45
Show Gist options
  • Save d4l3k/b68094d649a076384967788c9b0a5f08 to your computer and use it in GitHub Desktop.
torch.save/load benchmark and streaming implementation
from dataclasses import dataclass
import pickle
from io import BufferedIOBase
from typing import Tuple
import tempfile
import time
import struct
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
@dataclass
class TensorMetadata:
    """Metadata needed to reconstruct a tensor from its raw storage bytes on the wire."""

    # Size in bytes of the tensor's untyped storage — i.e. the length of this
    # tensor's payload in the stream.
    nbytes: int
    # Element dtype used to reinterpret the raw bytes.
    dtype: torch.dtype
    # Offset (in elements) of the tensor's first element within its storage.
    storage_offset: int
    # Tensor shape.
    size: Tuple[int, ...]
    # Per-dimension strides (in elements, as returned by Tensor.stride()).
    stride: Tuple[int, ...]
def write_state_dict(state_dict: object, f: BufferedIOBase) -> None:
    """
    Serialize ``state_dict`` to the file-like object ``f``.

    Optimized to minimize memory copies: the structure of the state_dict is
    separated from the tensor leaves via pytree, the structure (plus per-tensor
    metadata) is pickled, and each tensor's underlying storage is streamed
    directly into ``f``. Loading therefore never needs the whole serialized
    blob in memory at once, which makes this significantly faster than
    torch.save/torch.load.

    Wire Format:
    - pickle length: 8 bytes
    - pickle data: pickle length bytes
    - tensor0 storage: n bytes
    - ...
    - tensorN storage: m bytes
    """
    # Flatten the arbitrary Python structure into leaf values plus a tree
    # spec, so every tensor can be handled uniformly in one pass.
    leaves, spec = tree_flatten(state_dict)

    pending_storages = []
    serializable_leaves = []
    for leaf in leaves:
        if not isinstance(leaf, torch.Tensor):
            serializable_leaves.append(leaf)
            continue
        storage = leaf.untyped_storage()
        pending_storages.append(storage)
        serializable_leaves.append(
            TensorMetadata(
                nbytes=storage.nbytes(),
                dtype=leaf.dtype,
                storage_offset=leaf.storage_offset(),
                size=leaf.size(),
                stride=leaf.stride(),
            )
        )

    header = pickle.dumps((serializable_leaves, spec))
    # Little-endian signed 64-bit length prefix, then the pickled header.
    f.write(struct.pack("<q", len(header)))
    f.write(header)

    for storage in pending_storages:
        # Write the raw storage buffer straight into the file-like object.
        # Signature: _write_file(f, is_real_file, save_size, element_size)
        storage._write_file(f, False, False, 1)
def read_state_dict(f: BufferedIOBase) -> object:
    """
    Deserialize a state_dict previously written by ``write_state_dict``.

    See `write_state_dict` for the wire format.
    """
    (header_len,) = struct.unpack("<q", f.read(8))
    # NOTE: pickle is only safe when the stream comes from a trusted source.
    flat_values, spec = pickle.loads(f.read(header_len))

    def _materialize(meta: TensorMetadata) -> torch.Tensor:
        # Tensor sizes are known up front, so each payload is read directly
        # off the wire with a single read().
        raw = f.read(meta.nbytes)
        # frombuffer takes ownership of the (normally immutable) bytes
        # object; that is fine here since only PyTorch uses it.
        flat = torch.frombuffer(raw, dtype=meta.dtype)
        return torch.as_strided(
            flat,
            size=meta.size,
            stride=meta.stride,
            storage_offset=meta.storage_offset,
        )

    leaves = [
        _materialize(v) if isinstance(v, TensorMetadata) else v
        for v in flat_values
    ]
    return tree_unflatten(leaves, spec)
def main() -> None:
    """
    Benchmark write_state_dict/read_state_dict against torch.save/torch.load
    on a large synthetic state_dict, using temporary files as the backing
    store.
    """
    # Trigger one-time warnings (e.g. torch.frombuffer's ownership warning)
    # before the timed runs so they don't pollute the benchmark output.
    torch.frombuffer(b"1234", dtype=torch.float32)

    print("creating state dict...")
    state_dict = {}
    # Fixed: the original comment said "64MB" but 1024**3 bytes is 1 GiB.
    chunk_size = 1024 * 1024 * 1024  # 1 GiB per tensor
    total_size = 16 * 1000 * 1000 * 1000  # ~16 GB total
    for offset in range(0, total_size, chunk_size):
        # float32 is 4 bytes per element, so chunk_size // 4 elements per tensor.
        state_dict[f"chunk_{offset}"] = torch.zeros(chunk_size // 4, dtype=torch.float32)

    print("starting benchmark...")
    for i in range(10):
        print(f"iteration {i}")
        with tempfile.TemporaryFile() as fp:
            start = time.perf_counter()
            write_state_dict(state_dict, fp)
            print(f"write_state_dict took {time.perf_counter() - start} seconds")
            fp.seek(0)
            start = time.perf_counter()
            read_state_dict(fp)
            print(f"read_state_dict took {time.perf_counter() - start} seconds")
        with tempfile.TemporaryFile() as fp:
            start = time.perf_counter()
            torch.save(state_dict, fp)
            print(f"torch.save took {time.perf_counter() - start} seconds")
            fp.seek(0)
            start = time.perf_counter()
            # weights_only=True avoids arbitrary code execution during unpickling.
            torch.load(fp, weights_only=True)
            print(f"torch.load took {time.perf_counter() - start} seconds")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment