diff --git a/modeling_whisper_official0.py b/modeling_whisper0.py
index 1285034..4ab3bb0 100644
--- a/modeling_whisper_official0.py
+++ b/modeling_whisper0.py
@@ -1,4 +1,4 @@
-class WhisperEncoder(WhisperPreTrainedModel):
+class WhisperVQEncoder(WhisperPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`WhisperEncoderLayer`].
@@ -7,8 +7,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
        config: WhisperConfig
    """
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: WhisperVQConfig):
        super().__init__(config)
+        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
@@ -17,22 +18,107 @@ class WhisperEncoder(WhisperPreTrainedModel):
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
-
-        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
-        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+        if config.encoder_causal_convolution:
+            conv_class = CausalConv1d
+        else:
+            conv_class = nn.Conv1d
+        self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+        self.conv2 = conv_class(
+            embed_dim, embed_dim, kernel_size=3, stride=2, padding=1
+        )
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)
-
-        self.layers = nn.ModuleList(
-            [WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
-        )
-        self.layer_norm = nn.LayerNorm(config.d_model)
+        if config.quantize_encoder_only:
+            self.layers = nn.ModuleList(
+                [
+                    WhisperVQEncoderLayer(
+                        config,
+                        is_causal=config.encoder_causal_attention
+                        or config.quantize_causal_encoder,
+                    )
+                    for _ in range(config.quantize_position)
+                ]
+            )
+        else:
+            self.layers = nn.ModuleList(
+                [
+                    WhisperVQEncoderLayer(
+                        config,
+                        is_causal=config.encoder_causal_attention
+                        or (
+                            config.quantize_causal_encoder
+                            and layer_id < config.quantize_position
+                        ),
+                    )
+                    for layer_id in range(config.encoder_layers)
+                ]
+            )
+        self.layer_norm = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
+        # Parameters related to pooling layer
+        self.pooling_layer = None
+        # Parameters related to quantization layer
+        self.codebook = None
+        self.embed_positions2 = None
+        self.quantize_loss = None
+        self.num_active_codes = None
+        self.quantize_ema_count = 0
+        # Save hiddens
+        self.save_hidden_dir = None
+        self.save_hidden_position = None
        # Initialize weights and apply final processing
+        self.init_pooling_layer(config)
+        self.init_quantize_layer(config)
        self.post_init()
+    def init_pooling_layer(self, config: WhisperVQConfig):
+        if config.pooling_kernel_size is not None:
+            if config.pooling_type == "max":
+                self.pooling_layer = nn.MaxPool1d(
+                    kernel_size=config.pooling_kernel_size
+                )
+            elif config.pooling_type == "avg":
+                self.pooling_layer = nn.AvgPool1d(
+                    kernel_size=config.pooling_kernel_size
+                )
+            else:
+                raise NotImplementedError(
+                    f"Pooling type {config.pooling_type} not implemented"
+                )
+
+    def init_quantize_layer(self, config: WhisperVQConfig, quantize_load_codebook=None):
+        if config.quantize_vocab_size is not None:
+            if config.pooling_position is not None:
+                assert config.quantize_position >= config.pooling_position
+            self.codebook = nn.Embedding(
+                config.quantize_vocab_size, self.config.d_model
+            )
+            if quantize_load_codebook is not None:
+                init_codes = np.load(quantize_load_codebook)
+                self.codebook.weight.data.copy_(torch.from_numpy(init_codes))
+            max_source_positions = self.max_source_positions
+            if config.pooling_kernel_size is not None:
+                max_source_positions = math.ceil(
+                    max_source_positions / self.config.pooling_kernel_size
+                )
+            self.embed_positions2 = nn.Embedding(
+                max_source_positions, self.config.d_model
+            )
+            self.embed_positions2.weight.data.copy_(
+                self.embed_positions.weight.data[:max_source_positions]
+            )
+            if config.quantize_ema_decay is not None:
+                self.codebook.weight.requires_grad = False
+                self.register_buffer(
+                    "ema_count",
+                    torch.ones(config.quantize_vocab_size, dtype=torch.float),
+                )
+                self.register_buffer(
+                    "ema_weight", self.codebook.weight.data.clone().float()
+                )
+
    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
@@ -44,6 +130,31 @@ class WhisperEncoder(WhisperPreTrainedModel):
    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value
+    def get_block_causal_attention_mask(self, attention_mask, block_size=50):
+        dtype = self.dtype
+        batch_size, seq_length = attention_mask.shape
+        causal_mask = torch.torch.tril(
+            torch.ones(
+                1,
+                seq_length,
+                seq_length,
+                dtype=torch.bool,
+                device=attention_mask.device,
+            )
+        )
+        block_square_mask = []
+        for start in range(0, seq_length, block_size):
+            end = min(start + block_size, seq_length)
+            length = end - start
+            block_square_mask.append(causal_mask.new_ones((length, length)))
+        block_square_mask = torch.block_diag(*block_square_mask)
+        block_causal_mask = causal_mask | block_square_mask
+        block_causal_mask = block_causal_mask & attention_mask[:, None, :]
+        block_causal_mask = block_causal_mask.to(dtype=dtype)  # fp16 compatibility
+        block_causal_mask = (1.0 - block_causal_mask) * torch.finfo(dtype).min
+        block_causal_mask = block_causal_mask.unsqueeze(1)
+        return block_causal_mask
+
    def forward(
        self,
        input_features,
@@ -52,6 +163,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        quantized_token_ids=None,
    ):
        r"""
        Args:
@@ -79,16 +191,26 @@ class WhisperEncoder(WhisperPreTrainedModel):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
-        expected_seq_length = (
-            self.config.max_source_positions
-            * self.conv1.stride[0]
-            * self.conv2.stride[0]
-        )
-        if input_features.shape[-1] != expected_seq_length:
-            raise ValueError(
-                f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
-            )
+        # expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
+        # if input_features.shape[-1] != expected_seq_length:
+        #     raise ValueError(
+        #         f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
+        #     )
+
+        batch_size, feature_size, seq_length = input_features.shape
+        seq_length = seq_length // (self.conv1.stride[0] * self.conv2.stride[0])
+        attention_mask = attention_mask[
+            :, :: self.conv1.stride[0] * self.conv2.stride[0]
+        ]
+        if self.config.quantize_causal_block_size is not None:
+            extended_attention_mask = self.get_block_causal_attention_mask(
+                attention_mask, block_size=self.config.quantize_causal_block_size
+            )
+        else:
+            extended_attention_mask = self.get_extended_attention_mask(
+                attention_mask, (batch_size, seq_length)
+            )
        output_attentions = (
            output_attentions
            if output_attentions is not None
@@ -108,7 +230,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight
-        hidden_states = inputs_embeds + embed_pos
+        hidden_states = inputs_embeds + embed_pos[:seq_length]
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
@@ -116,12 +238,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
+        assert attention_mask.shape[-1] == hidden_states.shape[1]
        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
-
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
@@ -139,14 +261,14 @@ class WhisperEncoder(WhisperPreTrainedModel):
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
-                    None,
+                    extended_attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
-                    None,
+                    extended_attention_mask,
                    layer_head_mask=(
                        head_mask[idx] if head_mask is not None else None
                    ),
@@ -157,8 +279,217 @@ class WhisperEncoder(WhisperPreTrainedModel):
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
+            if (
+                idx + 1 == self.config.pooling_position
+                and self.config.pooling_kernel_size is not None
+            ):
+                hidden_states = hidden_states.permute(0, 2, 1)
+                if hidden_states.shape[-1] % self.config.pooling_kernel_size != 0:
+                    hidden_states = torch.nn.functional.pad(
+                        hidden_states,
+                        (
+                            0,
+                            self.config.pooling_kernel_size
+                            - hidden_states.shape[-1] % self.config.pooling_kernel_size,
+                        ),
+                    )
+                hidden_states = self.pooling_layer(hidden_states).permute(0, 2, 1)
+                attention_mask = attention_mask[:, :: self.config.pooling_kernel_size]
+                if self.config.quantize_causal_block_size is not None:
+                    extended_attention_mask = self.get_block_causal_attention_mask(
+                        attention_mask,
+                        block_size=self.config.quantize_causal_block_size
+                        // self.config.pooling_kernel_size,
+                    )
+                else:
+                    extended_attention_mask = self.get_extended_attention_mask(
+                        attention_mask,
+                        (batch_size, seq_length // self.config.pooling_kernel_size),
+                    )
+
+            if (
+                idx + 1 == self.config.quantize_position
+                and self.config.quantize_vocab_size is not None
+            ):
+                if quantized_token_ids is not None:
+                    hidden_states = self.codebook(quantized_token_ids)
+                else:
+                    hidden_quantized, indices_flat, distances = vector_quantize(
+                        hidden_states, self.codebook.weight
+                    )
+                    quantized_token_ids = indices_flat.reshape(
+                        batch_size, hidden_quantized.shape[1]
+                    )
+                    if self.training:
+                        encodings = torch.nn.functional.one_hot(
+                            indices_flat, self.config.quantize_vocab_size
+                        ).float()
+                        encodings = encodings * attention_mask.reshape(-1, 1)
+                        n = torch.sum(encodings, dim=0)
+                        torch.distributed.all_reduce(
+                            n, op=torch.distributed.ReduceOp.SUM
+                        )
+                        self.num_active_codes = n.nonzero().shape[0]
+                        if self.config.quantize_ema_decay:
+                            hidden_flat = (
+                                hidden_states.detach()
+                                .float()
+                                .reshape(-1, hidden_states.shape[-1])
+                            )
+                            with torch.autocast(
+                                device_type="cuda", dtype=torch.float32
+                            ):
+                                dw = torch.matmul(encodings.t(), hidden_flat)
+                            torch.distributed.all_reduce(
+                                dw, op=torch.distributed.ReduceOp.SUM
+                            )
+                            self.ema_count = (
+                                self.ema_count * self.config.quantize_ema_decay
+                                + (1 - self.config.quantize_ema_decay) * n
+                            )
+                            total_count = torch.sum(self.ema_count)
+                            self.ema_count = (
+                                (self.ema_count + 1e-5)
+                                / (total_count + self.config.quantize_vocab_size * 1e-5)
+                                * total_count
+                            )
+                            self.ema_weight = (
+                                self.ema_weight * self.config.quantize_ema_decay
+                                + (1 - self.config.quantize_ema_decay) * dw
+                            )
+                            self.codebook.weight.data = (
+                                self.ema_weight / self.ema_count.unsqueeze(1)
+                            )
+                            self.quantize_loss = (
+                                self.config.quantize_loss_scale
+                                * self.config.quantize_commit_coefficient
+                                * mse_loss_with_mask(
+                                    hidden_states,
+                                    hidden_quantized.detach(),
+                                    attention_mask,
+                                )
+                            )
+                            self.quantize_ema_count += 1
+                            if (
+                                self.config.quantize_restart_interval is not None
+                                and self.quantize_ema_count
+                                % self.config.quantize_restart_interval
+                                == 0
+                            ):
+                                rank, world_size = (
+                                    torch.distributed.get_rank(),
+                                    torch.distributed.get_world_size(),
+                                )
+                                segment_vocab_size = (
+                                    self.config.quantize_vocab_size // world_size
+                                )
+                                start_idx = segment_vocab_size * rank
+                                ema_count_segment = self.ema_count[
+                                    start_idx : start_idx + segment_vocab_size
+                                ]
+                                threshold = 1 * (
+                                    self.config.quantize_ema_decay
+                                    ** self.config.quantize_restart_interval
+                                )
+                                update_indices = (
+                                    ema_count_segment < threshold
+                                ).nonzero()[:, 0] + start_idx
+                                num_update = update_indices.shape[0]
+                                mask_flat = attention_mask.reshape(-1) > 0
+                                hidden_selected = hidden_flat[mask_flat]
+                                hidden_update = hidden_selected[
+                                    random.sample(
+                                        range(len(hidden_selected)), num_update
+                                    )
+                                ]
+                                num_update = torch.as_tensor(
+                                    [num_update],
+                                    dtype=torch.long,
+                                    device=hidden_states.device,
+                                )
+                                num_update_list = [
+                                    torch.as_tensor(
+                                        [0],
+                                        dtype=torch.long,
+                                        device=hidden_states.device,
+                                    )
+                                    for _ in range(world_size)
+                                ]
+                                torch.distributed.all_gather(
+                                    num_update_list, num_update
+                                )
+                                update_indices_list = [
+                                    torch.zeros(
+                                        num.item(),
+                                        dtype=torch.long,
+                                        device=hidden_states.device,
+                                    )
+                                    for num in num_update_list
+                                ]
+                                torch.distributed.all_gather(
+                                    update_indices_list, update_indices
+                                )
+                                update_indices = torch.cat(update_indices_list)
+                                hidden_update_list = [
+                                    torch.zeros(
+                                        num.item(),
+                                        hidden_flat.shape[-1],
+                                        dtype=hidden_update.dtype,
+                                        device=hidden_states.device,
+                                    )
+                                    for num in num_update_list
+                                ]
+                                torch.distributed.all_gather(
+                                    hidden_update_list, hidden_update
+                                )
+                                hidden_update = torch.cat(hidden_update_list)
+                                self.codebook.weight.data[update_indices] = (
+                                    hidden_update
+                                )
+                                self.ema_count[update_indices] = 1
+                                self.ema_weight[update_indices] = hidden_update
+                                if torch.distributed.get_rank() == 0:
+                                    print(f"restart {len(update_indices)} tokens")
+                        else:
+                            loss = self.config.quantize_loss_scale * (
+                                self.config.quantize_commit_coefficient
+                                * mse_loss_with_mask(
+                                    hidden_states,
+                                    hidden_quantized.detach(),
+                                    attention_mask,
+                                )
+                                + mse_loss_with_mask(
+                                    hidden_quantized,
+                                    hidden_states.detach(),
+                                    attention_mask,
+                                )
+                            )
+                            self.quantize_loss = loss
+                        hidden_states = (
+                            hidden_states + (hidden_quantized - hidden_states).detach()
+                        )
+                    else:
+                        hidden_states = hidden_quantized
+                hidden_states = (
+                    hidden_states
+                    + self.embed_positions2.weight[: hidden_states.shape[1]]
+                )
+
+            if idx + 1 == self.save_hidden_position:
+                import numpy as np
+                import uuid
-        hidden_states = self.layer_norm(hidden_states)
+                to_save = []
+                for batch_idx, hidden_state in enumerate(hidden_states):
+                    for seq_idx, hidden in enumerate(hidden_state):
+                        if attention_mask[batch_idx, seq_idx]:
+                            to_save.append(hidden.detach().cpu().numpy())
+                np.save(
+                    os.path.join(self.save_hidden_dir, f"{str(uuid.uuid4())}.npy"),
+                    to_save,
+                )
+        if not self.config.quantize_encoder_only:
+            hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
@@ -168,8 +499,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
            for v in [hidden_states, encoder_states, all_attentions]
            if v is not None
        )
-        return BaseModelOutput(
+        return QuantizedBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
+            quantized_token_ids=quantized_token_ids,
        )
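
The diff calls a vector_quantize helper that is not included in this gist. As a rough sketch only, assuming a standard VQ-VAE nearest-neighbour lookup that matches the (quantized hiddens, flat indices, distances) return signature used in the forward pass:

import torch

def vector_quantize(inputs, codebook):
    # inputs: (batch, seq_len, d_model); codebook: (vocab_size, d_model)
    embedding_size = codebook.size(1)
    inputs_flat = inputs.reshape(-1, embedding_size)
    # Squared L2 distance from every frame to every codebook entry.
    distances = (
        inputs_flat.pow(2).sum(dim=1, keepdim=True)
        - 2 * inputs_flat @ codebook.t()
        + codebook.pow(2).sum(dim=1)
    )
    indices_flat = distances.argmin(dim=1)
    quantized = torch.nn.functional.embedding(indices_flat, codebook).view_as(inputs)
    return quantized, indices_flat, distances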
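
CausalConv1d is likewise defined outside this file. A minimal sketch, assuming it replaces the symmetric padding of nn.Conv1d with left-only padding so the convolution never sees future frames (output lengths then match the original nn.Conv1d calls with kernel_size=3, padding=1):

import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Conv1d):
    """Conv1d that only sees the current and past frames."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, **kwargs):
        # Ignore the symmetric `padding` argument and pad causally in forward instead.
        # Assumes kernel_size is given as an int, as in the calls above.
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0, **kwargs)
        self.causal_padding = (kernel_size - 1) * self.dilation[0]

    def forward(self, x):
        # Pad only on the left of the time axis.
        return super().forward(F.pad(x, (self.causal_padding, 0)))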
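
mse_loss_with_mask, used for the commitment and codebook losses, is also external to this file; a plausible masked mean-squared error, assuming the mask marks valid frames with 1:

import torch

def mse_loss_with_mask(input, target, mask):
    # input/target: (batch, seq_len, d_model); mask: (batch, seq_len)
    loss = torch.nn.functional.mse_loss(input, target, reduction="none").mean(dim=-1)
    return (loss * mask).sum() / mask.sum().clamp(min=1)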
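
QuantizedBaseModelOutput is presumably BaseModelOutput extended with the quantized token ids returned above, e.g.:

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import BaseModelOutput

@dataclass
class QuantizedBaseModelOutput(BaseModelOutput):
    # Discrete codebook indices produced by the VQ layer, shape (batch, quantized_seq_len).
    quantized_token_ids: Optional[torch.LongTensor] = None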