Loading and saving safetensors files without consuming main memory
mem_eff_safeopen.py (memory-efficient reader):
# License: Apache 2.0
from typing import Dict, Optional
import struct
import json

import numpy as np
import torch


class MemoryEfficientSafeOpen:
    """Memory-efficient reader for safetensors files.

    This class provides a memory-efficient way to read tensors from safetensors files
    by using memory mapping for large tensors and avoiding unnecessary copies.
    """

    def __init__(self, filename):
        """Initialize the SafeTensor reader.

        Args:
            filename (str): Path to the safetensors file to read.
        """
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager and close file."""
        self.file.close()

    def keys(self):
        """Get all tensor keys in the file.

        Returns:
            list: List of tensor names (excludes metadata).
        """
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """Get metadata from the file.

        Returns:
            Dict[str, str]: Metadata dictionary.
        """
        return self.header.get("__metadata__", {})

    def _read_header(self):
        """Read and parse the header from the safetensors file.

        Returns:
            tuple: (header_dict, header_size) containing the parsed header and its size.
        """
        # Read header size (8 bytes, little-endian unsigned long long)
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        # Read and decode header JSON
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def get_tensor(self, key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Load a tensor from the file with memory-efficient strategies.

        **Note:**
            If device is 'cuda', the transfer to the GPU is done efficiently using pinned memory and a
            non-blocking transfer, so you must ensure that the transfer has completed before using the
            tensor (e.g., by calling `torch.cuda.synchronize()`).
            If the tensor is large (>10MB) and the target device is CUDA, memory mapping with
            numpy.memmap is used to avoid intermediate copies.

        Args:
            key (str): Name of the tensor to load.
            device (Optional[torch.device]): Target device for the tensor.
            dtype (Optional[torch.dtype]): Target dtype for the tensor.

        Returns:
            torch.Tensor: The loaded tensor.

        Raises:
            KeyError: If the tensor key is not found in the file.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        num_bytes = offset_end - offset_start
        original_dtype = self._get_torch_dtype(metadata["dtype"])
        target_dtype = dtype if dtype is not None else original_dtype

        # Handle empty tensors
        if num_bytes == 0:
            return torch.empty(metadata["shape"], dtype=target_dtype, device=device)

        # Determine if we should use pinned memory for GPU transfer
        pin_to_gpu = device is not None and device.type == "cuda"  # *** Set to False here to avoid using shared GPU memory ***
        non_blocking = device is not None and device.type == "cuda"

        # Calculate absolute file offset
        tensor_offset = self.header_size + 8 + offset_start  # adjust offset by header size

        # Memory mapping strategy for large tensors to GPU
        # Use memmap for large tensors to avoid intermediate copies.
        # If device is cpu, the tensor is not copied to the gpu, so using memmap locks the file, which is not desired.
        # So we only use memmap if device is not cpu.
        if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
            # Create memory map for zero-copy reading
            mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
            byte_tensor = torch.from_numpy(mm)  # zero copy
            del mm

            # Deserialize tensor (view and reshape)
            cpu_tensor = self._deserialize_tensor(byte_tensor, metadata)  # view and reshape
            del byte_tensor

            # Pin memory for faster GPU transfer
            if pin_to_gpu:
                cpu_tensor = cpu_tensor.pin_memory()

            # Transfer to target device and dtype
            gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
            del cpu_tensor
            return gpu_tensor

        # Standard file reading strategy for smaller tensors or CPU target
        # seek to the specified position
        self.file.seek(tensor_offset)

        # read directly into a numpy array with numpy.fromfile, without an intermediate copy
        numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes)
        byte_tensor = torch.from_numpy(numpy_array)
        del numpy_array

        # deserialize (view and reshape)
        deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata)
        del byte_tensor

        # Pin memory for GPU transfer if needed
        if pin_to_gpu:
            deserialized_tensor = deserialized_tensor.pin_memory()

        # cast to target dtype and move to device
        return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)

    def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict):
        """Deserialize byte tensor to the correct shape and dtype.

        Args:
            byte_tensor (torch.Tensor): Raw byte tensor from file.
            metadata (Dict): Tensor metadata containing dtype and shape info.

        Returns:
            torch.Tensor: Deserialized tensor with correct shape and dtype.
        """
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        # Handle special float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # Standard conversion: view as target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Convert string dtype to PyTorch dtype.

        Args:
            dtype_str (str): String representation of the dtype.

        Returns:
            torch.dtype: Corresponding PyTorch dtype.
        """
        # Standard dtype mappings
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # Add float8 types if available in this PyTorch version
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Convert byte tensor to float8 format if supported.

        Args:
            byte_tensor (torch.Tensor): Raw byte tensor.
            dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3").
            shape (tuple): Target tensor shape.

        Returns:
            torch.Tensor: Tensor with float8 dtype.

        Raises:
            ValueError: If the float8 type is not supported in the current PyTorch version.
        """
        # Convert to the specific float8 types if available
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # Float8 not supported in this PyTorch version
            # # convert to float16 if float8 is not supported
            # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
            # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
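
A minimal usage sketch for the reader above (the file name "model.safetensors" is a placeholder, not part of the gist). It reads on the CPU and casts on load, so no synchronize call is needed:

# Hypothetical usage: list keys, read metadata, and load tensors cast to float16 on the CPU
with MemoryEfficientSafeOpen("model.safetensors") as f:
    print(f.metadata())  # returns {} if the file has no __metadata__ entry
    for key in f.keys():
        t = f.get_tensor(key, dtype=torch.float16)  # dtype-only cast, no device transfer
        print(key, tuple(t.shape), t.dtype)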
Test for MemoryEfficientSafeOpen, comparing its output against the official safetensors implementation:
import unittest
import torch
import tempfile
import os
from safetensors import safe_open
from safetensors.torch import save_file
from mem_eff_safeopen import MemoryEfficientSafeOpen


class TestMemoryEfficientSafeOpen(unittest.TestCase):
    def setUp(self):
        self.test_tensors = {
            "float32": torch.randn(10, 20).float(),
            "float16": torch.randn(5, 15).half(),
            "int64": torch.randint(-100, 100, (8, 12)).long(),
            "bool": torch.randint(0, 2, (6, 6)).bool(),
            "empty": torch.empty(0, 10),
            "scalar": torch.tensor(3.14),
        }
        if hasattr(torch, "bfloat16"):
            self.test_tensors["bfloat16"] = torch.randn(7, 9).to(torch.bfloat16)
        if hasattr(torch, "float8_e5m2"):
            self.test_tensors["float8_e5m2"] = torch.randn(4, 8).to(torch.float8_e5m2)
        if hasattr(torch, "float8_e4m3fn"):
            self.test_tensors["float8_e4m3fn"] = torch.randn(3, 7).to(torch.float8_e4m3fn)

    def test_tensor_loading(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        try:
            # 1. Create a .safetensors file for the test
            save_file(self.test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key) for key in f.keys()}

            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key) for key in f.keys()}

            # 3. Compare each tensor
            for key in self.test_tensors.keys():
                dtype = self.test_tensors[key].dtype
                if "float8" in str(dtype):
                    # torch.allclose does not support float8, so compare element by element
                    for a, b in zip(official_tensors[key].view(-1), efficient_tensors[key].view(-1)):
                        self.assertAlmostEqual(a.item(), b.item(), delta=1e-2)
                else:
                    self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
        finally:
            os.unlink(tmp_filename)

    def test_tensor_loading_dtype(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        dtype = torch.float16
        try:
            # 1. Create a .safetensors file for the test
            save_file(self.test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key).to(dtype) for key in f.keys()}

            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key, dtype=dtype) for key in f.keys()}

            # 3. Compare each tensor
            for key in self.test_tensors.keys():
                dtype = self.test_tensors[key].dtype
                self.assertEqual(efficient_tensors[key].dtype, torch.float16)
                self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
        finally:
            os.unlink(tmp_filename)

    def test_memory_efficiency(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        try:
            # Create large tensors
            num_tensors = 10
            large_tensors = {f"large_{i}": torch.randn(10000, 1000) for i in range(num_tensors)}
            save_file(large_tensors, tmp_filename)

            # Measure memory usage (a rough approach)
            import psutil
            import gc

            process = psutil.Process()

            def get_memory_usage():
                return process.memory_info().rss / 1024 / 1024  # in MB

            # Load with the official safetensors
            gc.collect()
            mem_before = get_memory_usage()
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # do some work so the tensor is actually read into memory
                    del t
            gc.collect()
            mem_after_official = get_memory_usage()

            # Load with MemoryEfficientSafeOpen
            gc.collect()
            mem_before = get_memory_usage()
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # already read into memory
                    del t
            gc.collect()
            mem_after_efficient = get_memory_usage()

            # Compare memory usage
            self.assertLess(mem_after_efficient - mem_before, mem_after_official - mem_before)
        finally:
            os.unlink(tmp_filename)

    def test_cuda_device(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        if not torch.cuda.is_available():
            self.skipTest("Skipped because CUDA is not available")

        device = torch.device("cuda")
        try:
            # 1. Create large tensors: MemoryEfficientSafeOpen's CUDA path is only taken for large tensors
            test_tensors = {}
            for i, (key, tensor) in enumerate(self.test_tensors.items()):
                test_tensors[f"large_{i}"] = (
                    torch.randn(10000, 1000, dtype=tensor.dtype)
                    if tensor.dtype.is_floating_point and tensor.dtype.itemsize >= 2
                    else torch.randint(-100, 100, (10000, 1000)).to(tensor.dtype)  # supports int, fp8 and bool
                )

            # Also add a few small tensors
            test_tensors.update({f"small_{i}": torch.randn(10, 10) for i in range(5)})

            save_file(test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key).to(device) for key in f.keys()}

            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key, device=device) for key in f.keys()}

            # 3. Compare each tensor
            for key in test_tensors.keys():
                dtype = test_tensors[key].dtype
                if "float8" in str(dtype):
                    # # torch.allclose does not support float8, so compare element by element
                    # for a, b in zip(official_tensors[key].view(-1), efficient_tensors[key].view(-1)):
                    #     self.assertAlmostEqual(a.item(), b.item(), delta=1e-2)
                    # element-wise comparison is slow for large tensors, so convert to float16 and compare
                    official_fp16 = official_tensors[key].to(torch.float16)
                    efficient_fp16 = efficient_tensors[key].to(torch.float16)
                    self.assertTrue(torch.allclose(official_fp16, efficient_fp16, atol=1e-2, rtol=1e-2))
                else:
                    self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
                self.assertEqual(official_tensors[key].device, efficient_tensors[key].device)
                self.assertEqual(official_tensors[key].device.type, "cuda")
        finally:
            os.unlink(tmp_filename)


if __name__ == "__main__":
    unittest.main()
mem_eff_save_file.py (memory-efficient writer):
# License: Apache 2.0
import torch
import json
import struct
from typing import Dict, Any


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU to disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)


# Usage example
if __name__ == "__main__":
    # Create some example tensors on GPU
    tensors = {"weight": torch.randn(1000, 1000, device="cuda"), "bias": torch.randn(1000, device="cuda")}
    metadata = {"model_type": "example", "version": "1.0"}

    mem_eff_save_file(tensors, "model.safetensors", metadata)
Compatibility tests for mem_eff_save_file against the official safetensors implementation:
# License: Apache 2.0
import unittest
import torch
import os
import tempfile
from safetensors.torch import load_file as official_load_file
from safetensors import safe_open
from mem_eff_save_file import mem_eff_save_file  # your implementation


class TestCompatibilityWithOfficialSafetensors(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        for file in os.listdir(self.temp_dir):
            os.remove(os.path.join(self.temp_dir, file))
        os.rmdir(self.temp_dir)

    def assert_tensors_equal(self, tensor1, tensor2):
        self.assertTrue(torch.allclose(tensor1, tensor2, rtol=1e-5, atol=1e-8), f"Tensors are not equal: {tensor1} vs {tensor2}")

    def test_compatibility_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_not_contiguous_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensor = tensor[:, ::2]
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"

        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_not_contiguous_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensor = tensor[:, ::2]
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"

        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    def test_compatibility_multiple_tensors(self):
        tensors = {"weight": torch.randn(100, 100), "bias": torch.randn(100)}
        file_path = os.path.join(self.temp_dir, "custom_multiple.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_with_empty_tensors(self):
        tensors = {"empty": torch.tensor([]), "zero_dim": torch.tensor(1)}
        file_path = os.path.join(self.temp_dir, "custom_empty.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_different_dtypes(self):
        tensors = {
            "float32": torch.randn(10, 10, dtype=torch.float32),
            "float16": torch.randn(10, 10, dtype=torch.float16),
            "int32": torch.randint(0, 10, (10, 10), dtype=torch.int32),
        }
        file_path = os.path.join(self.temp_dir, "custom_dtypes.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
            self.assertEqual(tensors[key].dtype, loaded_tensors[key].dtype)

    def test_compatibility_with_metadata(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": "1.0"}
        file_path = os.path.join(self.temp_dir, "custom_metadata.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

        # load metadata from the .safetensors file with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual(metadata, official_metadata)

    def test_compatibility_with_metadata_not_str_to_str(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": 1.0}
        file_path = os.path.join(self.temp_dir, "custom_metadata_not_str_to_str.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

        # load metadata from the .safetensors file with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual({"model_type": "test", "version": "1.0"}, official_metadata)

    def test_large_model_compatibility(self):
        # Simulate a large model
        large_tensors = {f"layer_{i}": torch.randn(1000, 1000) for i in range(10)}
        file_path = os.path.join(self.temp_dir, "large_model.safetensors")
        mem_eff_save_file(large_tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(large_tensors.keys()), set(loaded_tensors.keys()))
        for key in large_tensors:
            self.assert_tensors_equal(large_tensors[key], loaded_tensors[key])


if __name__ == "__main__":
    unittest.main()
Pinning memory (pin_memory) consumes shared memory. If you want to avoid that, change the line

pin_to_gpu = device is not None and device.type == "cuda"  # *** Set to False here to avoid using shared GPU memory ***

to

pin_to_gpu = False

(The speedup will be more limited in that case.)
Metadata reading has been added to MemoryEfficientSafeOpen, and loading has been sped up by using NumPy.
If you want to move tensors to the GPU immediately, pass device to get_tensor. The transfer is then performed asynchronously, so call torch.cuda.synchronize() once loading is finished. If you want to cast tensors while loading, pass dtype (if you only pass dtype, no synchronize is needed).
I expect roughly a 5x speedup when transferring to a device, and about 1.5x when not transferring.
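
A minimal sketch of the GPU loading pattern described above (assuming CUDA is available; "model.safetensors" is a placeholder path, not part of the gist):

# Hypothetical usage: load every tensor straight to the GPU with non-blocking transfers
device = torch.device("cuda")
with MemoryEfficientSafeOpen("model.safetensors") as f:
    state_dict = {key: f.get_tensor(key, device=device) for key in f.keys()}
torch.cuda.synchronize()  # wait for the asynchronous transfers to finish before using the tensors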