Created
December 20, 2022 21:19
-
-
Save Narsil/d5b0d747e5c8c299eb6d82709e480e3d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings
from collections.abc import Mapping

import numpy as np
from flax.serialization import msgpack_restore
from huggingface_hub import hf_hub_download
from safetensors.flax import save_file
# Fetch the GPT-2 Flax checkpoint from the Hugging Face Hub. The returned
# value is used as a local filesystem path (it is opened just below).
filename = hf_hub_download("gpt2", filename="flax_model.msgpack")

# The checkpoint is a msgpack blob, so read it as raw bytes.
with open(filename, "rb") as f:
    data = f.read()

# Decode the msgpack payload back into Python objects; presumably a nested
# dict with array leaves, which is what `flatten` is written to consume.
flax_weights = msgpack_restore(data)
def flatten(weights, prefix="", *, verbose=True):
    """Flatten a nested mapping of arrays into a single-level dict.

    Nested keys are joined with ``"."``, so ``{"h": {"0": {"w": arr}}}``
    becomes ``{"h.0.w": arr}`` -- a flat ``name -> array`` layout suitable
    for ``safetensors.flax.save_file``.

    Args:
        weights: Nested mapping whose leaves are NumPy arrays, e.g. the
            object returned by ``msgpack_restore``. Any
            ``collections.abc.Mapping`` is accepted, not only ``dict``.
        prefix: Dotted path prepended to every key. Callers normally leave
            it empty; the recursion fills it in.
        verbose: If true (the default), print every dotted key as it is
            visited, including keys of intermediate mappings.

    Returns:
        A new ``dict`` mapping dotted key paths to ``np.ndarray`` values,
        in the iteration order of ``weights``.

    Warns:
        UserWarning: once for each leaf that is neither a mapping, an
            ``np.ndarray`` nor a NumPy scalar. Such leaves cannot be stored
            as tensors and are left out of the result.
    """
    values = {}
    for k, v in weights.items():
        newprefix = f"{prefix}.{k}" if prefix else f"{k}"
        if verbose:
            print(newprefix)
        if isinstance(v, Mapping):
            values.update(flatten(v, prefix=newprefix, verbose=verbose))
        elif isinstance(v, np.ndarray):
            values[newprefix] = v
        elif isinstance(v, np.generic):
            # NumPy scalars are not ndarray instances; wrap them as 0-d
            # arrays so they are kept instead of being dropped.
            values[newprefix] = np.asarray(v)
        else:
            # Not representable as a tensor: skip, but make the loss visible.
            warnings.warn(
                f"flatten: skipping {newprefix!r}; unsupported leaf type "
                f"{type(v).__name__}",
                stacklevel=2,
            )
    return values
# Flatten the nested checkpoint into dotted keys ("a.b.c") and write it out
# as a single safetensors file.
weights = flatten(flax_weights)
# Record the tensor origin in the safetensors header metadata.
# NOTE(review): transformers' own Flax save path writes {"format": "flax"},
# and its loaders appear to check this key -- confirm for the version in use.
save_file(weights, "model.safetensors", metadata={"format": "flax"})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment