|
import os |
|
import torch |
|
import argparse |
|
from safetensors.torch import save_file |
|
|
|
# Refuse to be imported: all of this script's work runs at module level
# (argument parsing, loading, converting, saving), so an import would start it.
# ImportError subclasses Exception, so existing `except Exception` handlers
# still catch it.
if __name__ != "__main__":
    raise ImportError("This script is not meant to be imported")
|
|
|
# Command-line interface. Optional arguments are non-required by default in
# argparse, so only --input is marked as required.
parser = argparse.ArgumentParser(
    description="Convert a model from pickle to safetensors format",
)
parser.add_argument(
    "--input",
    required=True,
    type=str,
    help="Path to input model in torch format (.ckpt)",
)
parser.add_argument(
    "--output",
    default="model",
    type=str,
    help="Path to output model (without extension)",
)
# BooleanOptionalAction provides both --fp16 and --no-fp16.
parser.add_argument(
    "--fp16",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Whether to use half precision",
)
parser.add_argument(
    "--device",
    default="cpu",
    type=str,
    help="Device to use (defaults to 'cpu')",
)

args = parser.parse_args()
|
|
|
print(f"• Loading model from {args.input}...")

# NOTE(security): torch.load unpickles the file and can execute arbitrary code
# embedded in it -- only convert checkpoints from sources you trust.
checkpoint = torch.load(args.input, map_location=args.device)

if not isinstance(checkpoint, dict):
    raise SystemExit(
        f"! Unsupported checkpoint in '{args.input}': expected a dict, "
        f"got {type(checkpoint).__name__}"
    )

# Some checkpoints nest the weights under "state_dict"; plain state-dict files
# keep them at the top level, so fall back to the whole mapping.
weights = checkpoint.get("state_dict", checkpoint)
# Drop our reference to the rest of the checkpoint (anything besides the
# weights) so it can be freed.
del checkpoint

# safetensors can only store tensors: skip any other entries with a warning
# instead of failing later at save time.
skipped = [k for k, v in weights.items() if not isinstance(v, torch.Tensor)]
if skipped:
    print(f"! Skipping {len(skipped)} non-tensor entries: {', '.join(map(str, skipped))}")
    weights = {k: v for k, v in weights.items() if isinstance(v, torch.Tensor)}

if not weights:
    raise SystemExit(f"! No tensors found in '{args.input}'")
|
|
|
if args.fp16:
    print("• Converting to half precision...")

    # Only floating-point tensors are cast. Calling .half() on integer or
    # bool tensors would silently turn them into float16 values.
    weights = {
        k: v.half() if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in weights.items()
    }
|
|
|
# Half-precision exports get an extra ".fp16" marker before the extension.
output_extension = f"{'.fp16' if args.fp16 else ''}.safetensors"
output_file = args.output + output_extension

# Never clobber an existing file silently: either confirm the overwrite or
# ask for a different base name, repeating until the target is free or the
# user agrees to overwrite it.
while os.path.isfile(output_file):
    answer = input(f"! Output file '{output_file}' already exists. Overwrite? [y/N]: ")
    # Tolerate surrounding whitespace and the long form "yes".
    if answer.strip().lower() in ("y", "yes"):
        break

    new_name = input("? Please enter a new output file name (without extension): ").strip()
    # An empty answer keeps the current name, so the overwrite prompt repeats.
    if new_name:
        output_file = new_name + output_extension
|
|
|
print(f"• Saving to {output_file}...")

# safetensors writes raw tensor buffers and rejects non-contiguous tensors.
# torch.load can return such tensors because pickle preserves strides.
# .contiguous() returns the tensor itself when it is already contiguous.
save_file({k: v.contiguous() for k, v in weights.items()}, output_file)

print("✓ Done!")