Created March 5, 2025 06:07
-
-
Save blepping/e0cc4efd3cb3db0be7f8927325c8f622 to your computer and use it in GitHub Desktop.
Simple script to pretty-print a SafeTensors file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!python3 | |
import argparse | |
import mmap | |
import json | |
import os | |
import struct | |
import sys | |
from pathlib import Path | |
from typing import NamedTuple | |
# Functional form of typing.NamedTuple: one immutable record per tensor
# entry found in the SafeTensors header.
Tensor = NamedTuple(
    "Tensor",
    [
        ("key", str),  # tensor name, i.e. the header's JSON key
        ("dtype", str),  # the entry's "dtype" string from the header
        ("shape", tuple),  # the entry's "shape" from the header
    ],
)
Tensor.__doc__ = "A tensor entry from a SafeTensors header: name, dtype and shape."
def get_args():
    """Build the command-line interface and return the parsed arguments.

    The result has ``FILE`` (a ``Path``) plus two boolean switches,
    ``no_sort`` and ``metadata``.
    """
    parser = argparse.ArgumentParser(description="Pretty print a SafeTensors file")
    parser.add_argument("FILE", type=Path, help="Input SafeTensors file to read")
    # Both optional flags are plain on/off switches, so register them from a table.
    switches = (
        ("--no-sort", "Disable sorting tensors by name"),
        ("--metadata", "Dump metadata section as well"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)
    return parser.parse_args()
def read_metadata(args, *, max_metadata_size=1024 * 1024 * 8):
    """Read and validate the JSON header of the SafeTensors file ``args.FILE``.

    The file is expected to start with an 8-byte little-endian unsigned
    header length followed by that many bytes of JSON beginning with ``{``.
    Only the header is read; the bytes after it are never touched.

    Args:
        args: parsed CLI arguments; only ``args.FILE`` is used.
        max_metadata_size: exclusive upper bound on the declared header
            length, so a non-SafeTensors file cannot make us read a huge
            "header".

    Returns:
        ``(tensors, st_metadata)``: a tuple of ``Tensor`` records, in header
        order, for every entry that is a dict with a ``dtype`` and a list
        ``shape``; and the ``__metadata__`` value, or ``None`` when absent.

    Exits the process with status 1 (message on stderr) when the file is too
    short or does not look like SafeTensors.
    """

    def fail(message):
        print(f"! {message}", file=sys.stderr)
        sys.exit(1)

    # Plain reads instead of mmap: mapping a zero-length file raises
    # ValueError before any size check could run, and the `with` guarantees
    # the handle is closed on every path (including sys.exit).
    with open(args.FILE, "rb") as fp:
        file_size = fp.seek(0, os.SEEK_END)
        if file_size < 10:
            fail("File too short to be SafeTensors")
        fp.seek(0)
        (metadata_size,) = struct.unpack("<Q", fp.read(8))
        # The declared header must also fit inside the file; otherwise the
        # read below would silently return a truncated header.
        if not 1 < metadata_size < max_metadata_size or 8 + metadata_size > file_size:
            fail("File does not appear to be SafeTensors")
        raw_header = fp.read(metadata_size)

    if not raw_header.startswith(b"{"):
        fail("File does not appear to be SafeTensors")
    try:
        header = json.loads(raw_header)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        fail("File does not appear to be SafeTensors")

    st_metadata = header.pop("__metadata__", None)
    tensors = tuple(
        # Store shape as a tuple to match Tensor's annotation.
        Tensor(key, info["dtype"], tuple(info["shape"]))
        for key, info in header.items()
        if isinstance(info, dict)
        and "dtype" in info
        and isinstance(info.get("shape"), list)
    )
    return tensors, st_metadata
def make_format(size):
    """Return a ``str.format`` template that right-aligns argument 0 to *size* columns."""
    template = "{0:>" + str(size) + "}"
    return template
def dump_tensors(args, tensors, metadata):
    """Print the optional metadata section, then one aligned row per tensor.

    Each row reads ``index | dtype | shape | key``. The index, dtype and
    shape columns are right-aligned to their widest entry. Every dimension
    inside a shape is padded to the widest dimension across all tensors so
    that shapes line up from row to row.

    Args:
        args: parsed CLI arguments; reads ``args.no_sort`` and
            ``args.metadata``.
        tensors: sequence of records with ``key``, ``dtype`` and ``shape``
            attributes; may be empty, and shapes may be empty (scalars).
        metadata: the header's ``__metadata__`` value, or ``None``.
    """
    if not args.no_sort:
        tensors = sorted(tensors, key=lambda t: t.key)

    if args.metadata and metadata:
        print("* Metadata:")
        if isinstance(metadata, dict):
            for k, v in metadata.items():
                print(f" {k}:\n {v}")
        else:
            # A non-mapping __metadata__ is shown verbatim instead of
            # crashing on .items().
            print(f" {metadata}")
        print()

    if not tensors:
        # Nothing to tabulate (and the width max() calls need at least one row).
        return

    # default=0: a 0-d (scalar) tensor has an empty shape, which would
    # otherwise make max() raise on an empty sequence.
    dim_width = max((len(str(dim)) for t in tensors for dim in t.shape), default=0)
    shapes = [", ".join(f"{dim:>{dim_width}}" for dim in t.shape) for t in tensors]
    # Measure the rendered shapes directly; n dims joined by ", " take
    # n * dim_width + 2 * (n - 1) characters, with no extra padding column.
    shape_width = max(len(s) for s in shapes)
    dtype_width = max(len(t.dtype) for t in tensors)
    # Indices run from 0 to len - 1, so size the column for the largest index.
    idx_width = len(str(len(tensors) - 1))

    for idx, (tensor, shape) in enumerate(zip(tensors, shapes)):
        print(" | ".join((
            f"{idx:>{idx_width}}",
            f"{tensor.dtype:>{dtype_width}}",
            f"{shape:>{shape_width}}",
            tensor.key,
        )))
def main():
    """Script entry point: parse the CLI, read the file header, print the report."""
    cli_args = get_args()
    tensors, st_metadata = read_metadata(cli_args)
    dump_tensors(cli_args, tensors, st_metadata)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.