Python code to convert SDXL fp16 safetensors to 8-bit safetensors
# safetensors_convert_fp16_to_8bit.py
# Python 3.11.7

import os

import safetensors
import safetensors.torch
import torch

# blacklist takes priority over whitelist:
# a tensor will only be cast if its name matches the whitelist but not the blacklist

# list of substrings to search for within tensor names
tensor_name_blacklist = [
    'emb'  # embedding tensors
]

# list of substrings to search for within tensor names
tensor_name_whitelist = [
    'diffusion_model'
]
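# for example (illustrative names, assuming a standard SDXL checkpoint layout):
#   'model.diffusion_model.input_blocks.0.0.weight' -> whitelisted, will be cast
#   'model.diffusion_model.time_embed.0.weight'     -> contains 'emb', blacklisted, copied as-is
#   'first_stage_model.decoder.conv_in.weight'      -> not whitelisted, copied as-is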
fn = input('Enter the path of a safetensors FP16 model file: ')

if not os.path.exists(fn):
    raise FileNotFoundError(f"'{fn}' does not exist")

if not os.path.isfile(fn):
    raise FileNotFoundError(f"'{fn}' is not a file")

if not fn.endswith('.safetensors'):
    raise ValueError(
        f"filename '{fn}' does not end in '.safetensors'. "
        "only safetensors files are supported"
    )

# derive the output filename: replace an fp16 tag if one is present,
# otherwise append '-8-bit' before the extension
for fp16_tag in ('fp16', 'f16', 'FP16', 'F16'):
    if fp16_tag in fn:
        output_fn = fn.replace(fp16_tag, '8-bit')
        break
else:
    output_fn = fn.replace('.safetensors', '-8-bit.safetensors')
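# e.g. 'sd_xl_base_fp16.safetensors' -> 'sd_xl_base_8-bit.safetensors' (illustrative filename)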
if os.path.exists(output_fn):
    raise FileExistsError(f"destination file '{output_fn}' already exists")
def maybe_reduce_precision_tensor(tensor: torch.Tensor, tensor_name: str) -> torch.Tensor:
    """
    Cast the given tensor to 8-bit (torch.int8) if it is float16,
    otherwise return the tensor unchanged
    """
    # do not cast tensors that are not fp16
    # (torch.half is an alias for torch.float16, so one check suffices)
    if tensor.dtype != torch.float16:
        print(f"SKIP: tensor {tensor_name}: {tensor.dtype} unchanged")
        return tensor
    # fp16 tensor -> 8-bit tensor
    # note: .char() is equivalent to .to(torch.int8); fractional values are
    # truncated toward zero, so this is a raw dtype cast, not a scaled quantization
    print(f"CAST: tensor {tensor_name}: {tensor.dtype} -> torch.int8")
    return tensor.char()
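
# A minimal alternative sketch, not part of the original script: if the goal is
# value-preserving 8-bit quantization rather than a raw dtype cast, a common
# approach is per-tensor absmax scaling. The helper below (its name and return
# convention are assumptions for illustration) maps the tensor onto [-127, 127];
# the returned scale would need to be stored separately to dequantize later.
def quantize_tensor_absmax(tensor: torch.Tensor) -> tuple[torch.Tensor, float]:
    """
    Quantize a float tensor to int8 via per-tensor absmax scaling.
    Dequantize with: quantized.float() * scale
    """
    scale = tensor.abs().max().item() / 127.0
    if scale == 0.0:
        # all-zero tensor: nothing to scale, avoid division by zero
        return tensor.to(torch.int8), 1.0
    quantized = torch.round(tensor.float() / scale).clamp(-127, 127).to(torch.int8)
    return quantized, scale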
fp16_tensors: dict[str, torch.Tensor] = {}

# change `device='mps'` to `device='cpu'` if you are not using Metal
with safetensors.safe_open(fn, framework="pt", device='mps') as f:
    for tensor_name in f.keys():
        print(f"LOAD: tensor {tensor_name}")
        fp16_tensors[tensor_name] = f.get_tensor(tensor_name)
_8bit_tensors: dict[str, torch.Tensor] = {}

for tensor_name in fp16_tensors.keys():
    if any(string in tensor_name for string in tensor_name_blacklist):
        print(f'COPY: tensor {tensor_name} is blacklisted')
        _8bit_tensors[tensor_name] = fp16_tensors[tensor_name]
    elif any(string in tensor_name for string in tensor_name_whitelist):
        _8bit_tensors[tensor_name] = maybe_reduce_precision_tensor(
            tensor=fp16_tensors[tensor_name],
            tensor_name=tensor_name
        )
    else:
        print(f'COPY: tensor {tensor_name} is not whitelisted')
        _8bit_tensors[tensor_name] = fp16_tensors[tensor_name]
safetensors.torch.save_file(
    tensors=_8bit_tensors,
    filename=output_fn
)

print(f'saved 8-bit file to {output_fn}')
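
# Optional sanity check (an illustrative addition, not part of the original
# script): re-open the saved file and confirm the stored dtypes
with safetensors.safe_open(output_fn, framework="pt", device='cpu') as f:
    for tensor_name in f.keys():
        print(f"CHECK: tensor {tensor_name}: {f.get_tensor(tensor_name).dtype}")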