Last active
June 5, 2024 14:36
-
-
Save abetlen/db9f3015e6d5bcc7d00493fa7b368655 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import json
import os
import pathlib
import typing
from enum import IntEnum

import gguf
import numpy as np
import numpy.typing as npt
from gguf import (
    KEY_ATTENTION_HEAD_COUNT,
    KEY_ATTENTION_LAYERNORM_EPS,
    KEY_BLOCK_COUNT,
    KEY_EMBEDDING_LENGTH,
    KEY_FEED_FORWARD_LENGTH,
    GGUFWriter,
    SpecialVocab,
    TokenType,
)
from safetensors import safe_open
class SafetensorsIndexFile(typing.TypedDict):
    """Typed view of the JSON index file that accompanies a sharded checkpoint.

    Only the ``weight_map`` entry is declared because it is the only key this
    script reads.
    """

    # Maps a tensor name to the shard file that stores it (the values are
    # joined with the index file's directory when the shards are opened).
    weight_map: typing.Dict[str, str]
class SafetensorsIndex:
    """Tensor lookup over a sharded safetensors checkpoint.

    The index JSON is parsed once; every distinct shard file named in its
    ``weight_map`` is opened with ``safe_open(..., framework="np")`` and kept
    in ``self.tensors`` so tensors can be fetched by name later.
    """

    def __init__(self, index_file_path: str):
        """Load the index at *index_file_path* and open all referenced shards.

        Shard file names are resolved relative to the directory that contains
        the index file.
        """
        directory = os.path.dirname(index_file_path)
        # Use a context manager so the index file handle is closed right away
        # (the previous bare ``open`` left it to the garbage collector).
        with open(index_file_path, "r", encoding="utf-8") as f:
            self.index = typing.cast(SafetensorsIndexFile, json.load(f))
        self.weight_map = self.index["weight_map"]
        # Several tensors share a shard, so de-duplicate before opening.
        files = set(self.weight_map.values())
        self.tensors = {
            file: safe_open(os.path.join(directory, file), framework="np")
            for file in files
        }

    def get_tensor(self, key: str) -> npt.NDArray[np.float32]:
        """Return the tensor named *key* from whichever shard holds it.

        Raises:
            KeyError: if *key* is not listed in the index ``weight_map``.
        """
        shard = self.tensors[self.weight_map[key]]
        return typing.cast(npt.NDArray[np.float32], shard.get_tensor(key))  # type: ignore
def extract_key(raw_key: str, arch: str) -> str:
    """Return *raw_key* with its ``{arch}`` replacement field filled by *arch*.

    A key without an ``{arch}`` field is returned unchanged.
    """
    return raw_key.format(arch=arch)
class SentencePieceTokenTypes(IntEnum):
    """Token-type ids stored in the GGUF vocabulary.

    Per the original comment, taken from ``_set_vocab_sentencepiece``.
    """

    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


# Text-model defaults (text config is based on mistral v0.1); entries found in
# config["text_config"] override these.
_DEFAULT_TEXT_CONFIG: typing.Dict[str, typing.Any] = {
    "vocab_size": 32000,
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "hidden_act": "silu",
    "max_position_embeddings": 4096 * 32,
    "rms_norm_eps": 1e-05,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "tie_word_embeddings": False,
    "rope_theta": 10000.0,
    "sliding_window": 4096,
}

# Used only when config.json has no "perceiver_config" entry.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/configuration_idefics2.py#L129
_DEFAULT_PERCEIVER_CONFIG: typing.Dict[str, typing.Any] = {
    "hidden_act": "silu",
    "resampler_n_latents": 64,
    "resampler_depth": 3,
    "resampler_n_heads": 16,
    "resampler_head_dim": 96,
    "num_key_value_heads": 4,
    "attention_dropout": 0.0,
}

# (gguf suffix, checkpoint suffix, dtype) for one perceiver-resampler layer,
# in the order the tensors are emitted.
_RESAMPLER_LAYER_TENSORS = (
    ("ln0.weight", "input_context_norm.weight", np.float32),
    ("ln1.weight", "input_latents_norm.weight", np.float32),
    ("ffn_down.weight", "mlp.down_proj.weight", np.float16),
    ("ffn_gate.weight", "mlp.gate_proj.weight", np.float16),
    ("ffn_up.weight", "mlp.up_proj.weight", np.float16),
    ("ln2.weight", "post_attention_layernorm.weight", np.float32),
    ("attn_k.weight", "self_attn.k_proj.weight", np.float16),
    ("attn_o.weight", "self_attn.o_proj.weight", np.float16),
    ("attn_q.weight", "self_attn.q_proj.weight", np.float16),
    ("attn_v.weight", "self_attn.v_proj.weight", np.float16),
)

# (gguf module, checkpoint module) for one vision encoder layer; each module
# contributes a ".weight" followed by a ".bias", all stored as float32.
# NOTE(review): fc1 -> ffn_down and fc2 -> ffn_up is kept exactly as in the
# original script; presumably it mirrors the naming the loader expects - confirm.
_VISION_BLOCK_MODULES = (
    ("attn_q", "self_attn.q_proj"),
    ("attn_k", "self_attn.k_proj"),
    ("attn_v", "self_attn.v_proj"),
    ("attn_out", "self_attn.out_proj"),
    ("ln1", "layer_norm1"),
    ("ffn_down", "mlp.fc1"),
    ("ffn_up", "mlp.fc2"),
    ("ln2", "layer_norm2"),
)


def permute(weights: npt.NDArray[np.float16], n_head: int, n_head_kv: typing.Optional[int]):
    """Reorder the rows of an attention projection matrix, head by head.

    Rows are grouped into heads, each head is split into two halves, and the
    half axis is swapped with the within-half axis, which interleaves the two
    halves. The result has the same shape as *weights*.

    Args:
        weights: projection matrix whose first axis is split across heads.
        n_head: number of attention heads.
        n_head_kv: number of key/value heads; when given and different from
            *n_head* it is used as the head count instead.
    """
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))


def _copy_tensor(fout: "GGUFWriter", tensors: "SafetensorsIndex", gguf_name: str, hf_name: str, dtype) -> None:
    """Queue checkpoint tensor *hf_name*, cast to *dtype*, under *gguf_name*."""
    fout.add_tensor(gguf_name, tensors.get_tensor(hf_name).astype(dtype))


def _add_vision_block(fout: "GGUFWriter", tensors: "SafetensorsIndex", blk_id: int,
                      gguf_id: typing.Optional[int] = None) -> None:
    """Queue vision encoder layer *blk_id* as GGUF block *gguf_id* (default: same id)."""
    if gguf_id is None:
        gguf_id = blk_id
    src_prefix = f"model.vision_model.encoder.layers.{blk_id}."
    for gguf_module, hf_module in _VISION_BLOCK_MODULES:
        for part in ("weight", "bias"):
            _copy_tensor(
                fout,
                tensors,
                f"v.blk.{gguf_id}.{gguf_module}.{part}",
                f"{src_prefix}{hf_module}.{part}",
                np.float32,
            )


def _write_vision_encoder(name: str, tensors: "SafetensorsIndex",
                          vision_config: typing.Dict[str, typing.Any],
                          perceiver_config: typing.Dict[str, typing.Any]) -> None:
    """Write ``{name}-vision-model-f16.gguf``: vision tower plus connector."""
    ftype = 1  # fp16
    fout = GGUFWriter(f"{name}-vision-model-f16.gguf", arch="clip")

    fout.add_bool("clip.has_text_encoder", False)
    fout.add_bool("clip.has_vision_encoder", True)
    fout.add_bool("clip.has_llava_projector", True)
    fout.add_file_type(ftype)

    model_name = "idefics2"
    fout.add_name(model_name)
    fout.add_description("Vision encoder for " + model_name)
    fout.add_string("clip.projector_type", "idefics2")

    n_layers_clip = vision_config["num_hidden_layers"]

    # vision model hparams
    VISION = "clip.vision"
    fout.add_uint32("clip.vision.image_size", vision_config["image_size"])  # Update as necessary
    fout.add_uint32("clip.vision.patch_size", vision_config["patch_size"])  # Update as necessary
    fout.add_uint32(extract_key(KEY_EMBEDDING_LENGTH, VISION), vision_config["hidden_size"])
    fout.add_uint32(extract_key(KEY_FEED_FORWARD_LENGTH, VISION), vision_config["intermediate_size"])
    fout.add_uint32("clip.vision.projection_dim", 4096)  # Update as necessary
    fout.add_uint32(extract_key(KEY_ATTENTION_HEAD_COUNT, VISION), vision_config["num_attention_heads"])
    fout.add_float32(extract_key(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
    # One more than the real layer count: the last layer is written twice below.
    fout.add_uint32(extract_key(KEY_BLOCK_COUNT, VISION), n_layers_clip + 1)
    fout.add_array("clip.vision.image_mean", [0.5, 0.5, 0.5])
    fout.add_array("clip.vision.image_std", [0.5, 0.5, 0.5])
    fout.add_bool("clip.use_gelu", True)  # using regular GELU instead of quick

    # connector: modality projection
    for proj in ("down", "gate", "up"):
        _copy_tensor(
            fout,
            tensors,
            f"mm.mp.ffn_{proj}.weight",
            f"model.connector.modality_projection.{proj}_proj.weight",
            np.float16,
        )

    # connector: perceiver resampler
    _copy_tensor(fout, tensors, "mm.pr.latents.weight",
                 "model.connector.perceiver_resampler.latents", np.float32)
    _copy_tensor(fout, tensors, "mm.pr.ln0.weight",
                 "model.connector.perceiver_resampler.norm.weight", np.float32)

    # Layer count comes from the config instead of a hard-coded 3; 3 remains
    # the fallback when the config does not state it.
    for i in range(perceiver_config.get("resampler_depth", 3)):
        for gguf_suffix, hf_suffix, dtype in _RESAMPLER_LAYER_TENSORS:
            _copy_tensor(
                fout,
                tensors,
                f"mm.pr.blk.{i}.{gguf_suffix}",
                f"model.connector.perceiver_resampler.layers.{i}.{hf_suffix}",
                dtype,
            )

    # vision_model
    _copy_tensor(fout, tensors, "v.position_embd.weight",
                 "model.vision_model.embeddings.position_embedding.weight", np.float16)
    fout.add_tensor(
        "v.patch_embd.weight",
        tensors.get_tensor("model.vision_model.embeddings.patch_embedding.weight")
        .reshape(vision_config["hidden_size"], 3, vision_config["patch_size"], vision_config["patch_size"])
        .astype(np.float16),
    )
    _copy_tensor(fout, tensors, "v.patch_embd.bias",
                 "model.vision_model.embeddings.patch_embedding.bias", np.float32)
    _copy_tensor(fout, tensors, "v.post_ln.weight",
                 "model.vision_model.post_layernorm.weight", np.float32)
    _copy_tensor(fout, tensors, "v.post_ln.bias",
                 "model.vision_model.post_layernorm.bias", np.float32)

    for i in range(n_layers_clip):
        _add_vision_block(fout, tensors, i)

    # Duplicate the last block (llava-cli skips over this)
    _add_vision_block(fout, tensors, n_layers_clip - 1, n_layers_clip)

    fout.write_header_to_file()
    fout.write_kv_data_to_file()
    fout.write_tensors_to_file()
    fout.close()


def _build_vocab(dir_model: pathlib.Path, vocab_size: int):
    """Build the token, score and token-type lists for the text model.

    Slots not covered by ``tokenizer.model`` or ``added_tokens.json`` keep the
    ``[PAD{i}]`` placeholder they are initialised with.

    Raises:
        FileNotFoundError: if ``tokenizer.model`` is missing from *dir_model*.
        ValueError: if the SentencePiece model has more pieces than *vocab_size*.
    """
    # Imported here, as in the original script, so sentencepiece is only
    # required once the text model is actually being written.
    from sentencepiece import SentencePieceProcessor

    tokenizer_path = dir_model / 'tokenizer.model'
    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"File not found: {tokenizer_path}")

    tokenizer = SentencePieceProcessor()
    tokenizer.LoadFromFile(str(tokenizer_path))

    n_pieces = tokenizer.vocab_size()
    if n_pieces > vocab_size:
        # Previously this surfaced as a bare IndexError inside the loop below.
        raise ValueError(
            f"tokenizer.model has {n_pieces} pieces but the configured vocab_size is {vocab_size}"
        )

    tokens: typing.List[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
    scores: typing.List[float] = [-10000.0] * vocab_size
    toktypes: typing.List[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size

    for token_id in range(n_pieces):
        toktype = SentencePieceTokenTypes.NORMAL
        if tokenizer.IsUnknown(token_id):
            toktype = SentencePieceTokenTypes.UNKNOWN
        elif tokenizer.IsControl(token_id):
            toktype = SentencePieceTokenTypes.CONTROL
        elif tokenizer.IsUnused(token_id):
            toktype = SentencePieceTokenTypes.UNUSED
        elif tokenizer.IsByte(token_id):
            toktype = SentencePieceTokenTypes.BYTE

        tokens[token_id] = tokenizer.IdToPiece(token_id).encode("utf-8")
        scores[token_id] = tokenizer.GetScore(token_id)
        toktypes[token_id] = toktype

    added_tokens_file = dir_model / 'added_tokens.json'
    if added_tokens_file.is_file():
        with open(added_tokens_file, "r", encoding="utf-8") as f:
            added_tokens_json = json.load(f)
        for key, token_id in added_tokens_json.items():
            if token_id >= vocab_size:
                print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
                continue
            tokens[token_id] = key.encode("utf-8")
            scores[token_id] = -1000.0
            toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

    # The lists already hold exactly vocab_size entries, so the old
    # "pad up to vocab_size" step could never run and has been dropped.
    return tokens, scores, toktypes


def _add_text_block(fout: "GGUFWriter", tensors: "SafetensorsIndex", i: int,
                    n_head: int, n_kv_head: int) -> None:
    """Queue the tensors of text decoder layer *i*."""
    src = f"model.text_model.layers.{i}."
    _copy_tensor(fout, tensors, f"blk.{i}.attn_norm.weight", src + "input_layernorm.weight", np.float32)
    _copy_tensor(fout, tensors, f"blk.{i}.ffn_down.weight", src + "mlp.down_proj.weight", np.float16)
    _copy_tensor(fout, tensors, f"blk.{i}.ffn_gate.weight", src + "mlp.gate_proj.weight", np.float16)
    _copy_tensor(fout, tensors, f"blk.{i}.ffn_up.weight", src + "mlp.up_proj.weight", np.float16)
    _copy_tensor(fout, tensors, f"blk.{i}.ffn_norm.weight", src + "post_attention_layernorm.weight", np.float32)
    # K is permuted using the key/value head count, Q using the full head count.
    fout.add_tensor(
        f"blk.{i}.attn_k.weight",
        permute(
            tensors.get_tensor(src + "self_attn.k_proj.weight").astype(np.float16),
            n_head,
            n_kv_head,
        ),
    )
    _copy_tensor(fout, tensors, f"blk.{i}.attn_output.weight", src + "self_attn.o_proj.weight", np.float16)
    fout.add_tensor(
        f"blk.{i}.attn_q.weight",
        permute(
            tensors.get_tensor(src + "self_attn.q_proj.weight").astype(np.float16),
            n_head,
            n_head,
        ),
    )
    _copy_tensor(fout, tensors, f"blk.{i}.attn_v.weight", src + "self_attn.v_proj.weight", np.float16)


def _write_text_model(name: str, dir_model: pathlib.Path, tensors: "SafetensorsIndex",
                      text_config: typing.Dict[str, typing.Any]) -> None:
    """Write ``{name}-text-model-f16.gguf``: hyper-parameters, vocab and weights."""
    # general GGUF init
    fout = GGUFWriter(f"{name}-text-model-f16.gguf", arch="llama")
    ftype = 1
    block_count = text_config["num_hidden_layers"]

    fout.add_name(name)
    fout.add_block_count(block_count)
    fout.add_context_length(text_config["max_position_embeddings"])
    fout.add_embedding_length(text_config["hidden_size"])
    fout.add_feed_forward_length(text_config["intermediate_size"])
    fout.add_head_count(text_config["num_attention_heads"])
    fout.add_head_count_kv(text_config["num_key_value_heads"])
    fout.add_rope_freq_base(text_config["rope_theta"])
    fout.add_layer_norm_rms_eps(text_config["rms_norm_eps"])
    fout.add_file_type(ftype)
    fout.add_vocab_size(text_config["vocab_size"])
    fout.add_rope_dimension_count(
        text_config["hidden_size"] // text_config["num_attention_heads"]
    )

    tokenizer_config_file = dir_model / 'tokenizer_config.json'
    if tokenizer_config_file.is_file():
        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
            tokenizer_config_json = json.load(f)
        if "add_prefix_space" in tokenizer_config_json:
            fout.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])

    ### Tokenizer
    tokens, scores, toktypes = _build_vocab(dir_model, text_config["vocab_size"])

    fout.add_tokenizer_model("llama")
    fout.add_tokenizer_pre("default")
    fout.add_token_list(tokens)
    fout.add_token_scores(scores)
    fout.add_token_types(toktypes)

    special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
    special_vocab.add_to_gguf(fout)

    n_head = typing.cast(int, text_config["num_attention_heads"])
    n_kv_head = typing.cast(int, text_config["num_key_value_heads"])

    _copy_tensor(fout, tensors, "token_embd.weight", "model.text_model.embed_tokens.weight", np.float32)

    # Iterate over the configured layer count rather than a hard-coded 32 so
    # the tensors written always match the block count stored in the header.
    for i in range(block_count):
        _add_text_block(fout, tensors, i, n_head, n_kv_head)

    _copy_tensor(fout, tensors, "output_norm.weight", "model.text_model.norm.weight", np.float32)
    _copy_tensor(fout, tensors, "output.weight", "lm_head.weight", np.float32)

    fout.write_header_to_file()
    fout.write_kv_data_to_file()
    fout.write_tensors_to_file()
    fout.close()


def main() -> None:
    """Convert the checkpoint in ``--dir-model`` into two GGUF files.

    Both files are written to the current working directory and are named
    after the model directory: ``<name>-vision-model-f16.gguf`` and
    ``<name>-text-model-f16.gguf``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dir-model",
        required=True,
        help="path to the model directory (config.json, tokenizer.model, "
             "model.safetensors.index.json and its shards)",
    )
    args = parser.parse_args()

    dir_model = pathlib.Path(args.dir_model)
    # set model name to folder name
    name = dir_model.name

    tensors = SafetensorsIndex((dir_model / "model.safetensors.index.json").as_posix())

    # Load the model config
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    text_config = {**_DEFAULT_TEXT_CONFIG, **config["text_config"]}
    vision_config = config["vision_config"]
    perceiver_config = config.get("perceiver_config", _DEFAULT_PERCEIVER_CONFIG)

    ### Vision encoder
    _write_vision_encoder(name, tensors, vision_config, perceiver_config)

    ### Text Model
    _write_text_model(name, dir_model, tensors, text_config)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment