Skip to content

Instantly share code, notes, and snippets.

@blepping
Created March 5, 2025 06:07
Show Gist options
  • Save blepping/e0cc4efd3cb3db0be7f8927325c8f622 to your computer and use it in GitHub Desktop.
Save blepping/e0cc4efd3cb3db0be7f8927325c8f622 to your computer and use it in GitHub Desktop.
Simple script to pretty print a SafeTensors file
#!python3
import argparse
import mmap
import json
import os
import struct
import sys
from pathlib import Path
from typing import NamedTuple
class Tensor(NamedTuple):
    """One tensor entry taken from a SafeTensors JSON header."""

    # Tensor name, i.e. the JSON key it was stored under in the header.
    key: str
    # Dtype string exactly as written in the header's "dtype" field.
    dtype: str
    # Sequence of dimension sizes from the header's "shape" field.
    shape: tuple
def get_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: list of argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse fall back to
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        argparse.Namespace with ``FILE`` (a Path), ``no_sort`` and
        ``metadata`` (booleans from the store_true flags).
    """
    parser = argparse.ArgumentParser(description="Pretty print a SafeTensors file")
    parser.add_argument("FILE", type=Path, help="Input SafeTensors file to read")
    parser.add_argument("--no-sort", action="store_true", help="Disable sorting tensors by name")
    parser.add_argument("--metadata", action="store_true", help="Dump metadata section as well")
    return parser.parse_args(argv)
def read_metadata(args, *, max_metadata_size=1024 * 1024 * 8):
    """Read and validate the JSON header of the SafeTensors file ``args.FILE``.

    The file is expected to start with a little-endian unsigned 64-bit
    header length followed by that many bytes of JSON.  Only the header is
    copied out of the mapping; tensor data is never read.

    Args:
        args: parsed command-line namespace; only ``args.FILE`` is used.
        max_metadata_size: headers declaring this many bytes or more are
            rejected as implausible.

    Returns:
        ``(tensors, st_metadata)``: ``tensors`` is a tuple of ``Tensor`` in
        header order (shape converted to a tuple); ``st_metadata`` is the
        header's ``__metadata__`` value, or ``None`` when absent.

    On a file that is too short or does not look like SafeTensors, prints a
    "! ..." message to stderr and exits with status 1.
    """

    def fail(message):
        # Single exit path so every rejection is reported the same way.
        print(f"! {message}", file=sys.stderr)
        sys.exit(1)

    with open(args.FILE, "rb") as fp:
        # seek() returns the new absolute offset, i.e. the file size.
        file_size = fp.seek(0, os.SEEK_END)
        # Must be checked before mmap(): mapping an empty file raises ValueError.
        if file_size < 10:
            fail("File too short to be SafeTensors")
        with mmap.mmap(fp.fileno(), file_size, access=mmap.ACCESS_READ) as mm:
            (metadata_size,) = struct.unpack("<Q", mm[:8])
            if (
                not 1 < metadata_size < max_metadata_size
                # Header must lie entirely inside the file, not be truncated.
                or 8 + metadata_size > file_size
                or mm[8] != ord("{")
            ):
                fail("File does not appear to be SafeTensors")
            # Slicing copies the bytes, so they stay valid after the mmap closes.
            raw_header = mm[8 : 8 + metadata_size]

    try:
        metadata = json.loads(raw_header)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        metadata = None
    if not isinstance(metadata, dict):
        fail("File does not appear to be SafeTensors")

    st_metadata = metadata.pop("__metadata__", None)
    tensors = tuple(
        Tensor(key, info["dtype"], tuple(info["shape"]))
        for key, info in metadata.items()
        # Skip entries that are not well-formed tensor descriptions.
        if isinstance(info, dict)
        and isinstance(info.get("dtype"), str)
        and isinstance(info.get("shape"), list)
    )
    return tensors, st_metadata
def make_format(size):
    """Build a ``str.format`` template that right-aligns one value in *size* columns.

    >>> make_format(5).format("ab")
    '   ab'
    """
    return "{{0:>{}}}".format(size)
def dump_tensors(args, tensors, metadata):
    """Print an aligned ``index | dtype | shape | key`` table of tensors.

    The index, dtype and shape columns are right-aligned to the widest value
    in each column.  When requested, the metadata entries are printed first.

    Args:
        args: parsed command-line namespace; reads ``args.no_sort`` (keep
            header order instead of sorting by key) and ``args.metadata``.
        tensors: sequence of ``Tensor`` entries; may be empty, and shapes may
            be empty (scalar tensors).
        metadata: the header's ``__metadata__`` value (or ``None``); printed
            only when ``args.metadata`` is set and it is non-empty.
    """
    # default= keeps max() from raising ValueError when there are no tensors
    # at all, or when every tensor is a scalar with an empty shape.
    longest_dim = max(
        (len(str(dim)) for t in tensors for dim in t.shape),
        default=0,
    )
    # Each dim is padded to longest_dim and joined with ", " (2 chars), so a
    # rank-n shape occupies n * longest_dim + 2 * (n - 1) characters.
    longest_shape = max(
        (
            len(t.shape) * longest_dim + 2 * max(0, len(t.shape) - 1)
            for t in tensors
        ),
        default=0,
    )
    longest_dtype = max((len(t.dtype) for t in tensors), default=0)
    # The widest index actually printed is len(tensors) - 1.
    longest_idx = len(str(max(len(tensors) - 1, 0)))
    dim_format, shape_format, dtype_format, idx_format = (
        make_format(size)
        for size in (longest_dim, longest_shape, longest_dtype, longest_idx)
    )
    if not args.no_sort:
        tensors = sorted(tensors, key=lambda t: t.key)
    if args.metadata and metadata:
        print("* Metadata:")
        for k, v in metadata.items():
            print(f" {k}:\n {v}")
        print()
    for idx, tensor in enumerate(tensors):
        shape = shape_format.format(", ".join(dim_format.format(d) for d in tensor.shape))
        dtype = dtype_format.format(tensor.dtype)
        print(" | ".join((idx_format.format(idx), dtype, shape, tensor.key)))
def main():
    """Command-line entry point: parse options, read the header, print it."""
    options = get_args()
    tensors, st_metadata = read_metadata(options)
    dump_tensors(options, tensors, st_metadata)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment