@ddh0
Last active August 26, 2024 08:14
Python code to convert SDXL fp16 safetensors to 8-bit safetensors
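To run it, install the two dependencies and invoke the script; it prompts interactively for the path of the fp16 model file (a minimal sketch, assuming pip and the filename from the header comment below):

pip install torch safetensors
python safetensors_convert_fp16_to_8bit.py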
# safetensors_convert_fp16_to_8bit.py
# Python 3.11.7
import safetensors.torch
import safetensors
import torch
import os
# blacklist takes priority over whitelist
# a tensor will only be cast if it matches the whitelist but not the blacklist
# list of substrings to search for within tensor names
tensor_name_blacklist = [
    'emb'  # embedding tensors
]
# list of substrings to search for within tensor names
tensor_name_whitelist = [
    'diffusion_model'
]
fn = input('Enter the path of a safetensors FP16 model file: ')
if not os.path.exists(fn):
    raise FileNotFoundError(
        f"'{fn}' does not exist"
    )
if not os.path.isfile(fn):
    raise FileNotFoundError(
        f"'{fn}' is not a file"
    )
if not fn.endswith('.safetensors'):
    raise ValueError(
        f"filename '{fn}' does not end in '.safetensors'. only safetensors files are supported"
    )
if 'fp16' in fn:
    output_fn = fn.replace('fp16', '8-bit')
elif 'f16' in fn:
    output_fn = fn.replace('f16', '8-bit')
elif 'FP16' in fn:
    output_fn = fn.replace('FP16', '8-bit')
elif 'F16' in fn:
    output_fn = fn.replace('F16', '8-bit')
else:
    output_fn = fn.replace('.safetensors', '-8-bit.safetensors')
if os.path.exists(output_fn):
    raise FileExistsError(
        f"destination file '{output_fn}' already exists"
    )
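# for example (illustrative filenames):
#   'sdxl_base_fp16.safetensors' -> 'sdxl_base_8-bit.safetensors'
#   'sdxl_base.safetensors'      -> 'sdxl_base-8-bit.safetensors'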
def maybe_reduce_precision_tensor(tensor: torch.Tensor, tensor_name: str) -> torch.Tensor:
    """
    Convert the given tensor to 8-bit if it is float16, otherwise
    return the tensor unchanged
    """
    # do not cast tensors that are not fp16 (torch.half is an alias of torch.float16)
    if tensor.dtype != torch.float16:
        print(f"SKIP: tensor {tensor_name}: {tensor.dtype} unchanged")
        return tensor
    # fp16 tensor -> 8-bit tensor; .char() casts to torch.int8,
    # truncating fractional values toward zero
    print(f"CAST: tensor {tensor_name}: {tensor.dtype} -> torch.int8")
    return tensor.char()
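# quick check (illustrative):
#   maybe_reduce_precision_tensor(torch.ones(2, dtype=torch.float16), 'demo')
#   prints "CAST: tensor demo: torch.float16 -> torch.int8" and returns
#   tensor([1, 1], dtype=torch.int8)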
fp16_tensors: dict[str, torch.Tensor] = {}
# change `device='mps'` to `device='cpu'` if you are not using Apple Metal (MPS)
with safetensors.safe_open(fn, framework="pt", device='mps') as f:
    for tensor_name in f.keys():
        print(f"LOAD: tensor {tensor_name}")
        fp16_tensors[tensor_name] = f.get_tensor(tensor_name)
_8bit_tensors: dict[str, torch.Tensor] = {}
for tensor_name in fp16_tensors.keys():
    if any(string in tensor_name for string in tensor_name_blacklist):
        print(f'COPY: tensor {tensor_name} is blacklisted')
        _8bit_tensors[tensor_name] = fp16_tensors[tensor_name]
    elif any(string in tensor_name for string in tensor_name_whitelist):
        _8bit_tensors[tensor_name] = maybe_reduce_precision_tensor(
            tensor=fp16_tensors[tensor_name],
            tensor_name=tensor_name
        )
    else:
        print(f'COPY: tensor {tensor_name} is not whitelisted')
        _8bit_tensors[tensor_name] = fp16_tensors[tensor_name]
safetensors.torch.save_file(
    tensors=_8bit_tensors,
    filename=output_fn
)
print(f'saved 8-bit file to {output_fn}')
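To sanity-check the result, the converted file can be reopened and the tensor dtypes inspected; a minimal sketch, assuming the output file is named 'model-8-bit.safetensors':

import safetensors

with safetensors.safe_open('model-8-bit.safetensors', framework='pt', device='cpu') as f:
    for name in f.keys():
        # whitelisted, non-blacklisted tensors should now be torch.int8
        print(name, f.get_tensor(name).dtype)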