@vra
Created February 26, 2025 09:53
diff --git a/modeling_whisper_official0.py b/modeling_whisper0.py
index 1285034..4ab3bb0 100644
--- a/modeling_whisper_official0.py
+++ b/modeling_whisper0.py
@@ -1,4 +1,4 @@
-class WhisperEncoder(WhisperPreTrainedModel):
+class WhisperVQEncoder(WhisperPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`WhisperEncoderLayer`].
@@ -7,8 +7,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
config: WhisperConfig
"""
- def __init__(self, config: WhisperConfig):
+ def __init__(self, config: WhisperVQConfig):
super().__init__(config)
+ self.config = config
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
@@ -17,22 +18,107 @@ class WhisperEncoder(WhisperPreTrainedModel):
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
-
- self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
- self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+ if config.encoder_causal_convolution:
+ conv_class = CausalConv1d
+ else:
+ conv_class = nn.Conv1d
+ self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+ self.conv2 = conv_class(
+ embed_dim, embed_dim, kernel_size=3, stride=2, padding=1
+ )
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.requires_grad_(False)
-
- self.layers = nn.ModuleList(
- [WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
- )
- self.layer_norm = nn.LayerNorm(config.d_model)
+ if config.quantize_encoder_only:
+ self.layers = nn.ModuleList(
+ [
+ WhisperVQEncoderLayer(
+ config,
+ is_causal=config.encoder_causal_attention
+ or config.quantize_causal_encoder,
+ )
+ for _ in range(config.quantize_position)
+ ]
+ )
+ else:
+ self.layers = nn.ModuleList(
+ [
+ WhisperVQEncoderLayer(
+ config,
+ is_causal=config.encoder_causal_attention
+ or (
+ config.quantize_causal_encoder
+ and layer_id < config.quantize_position
+ ),
+ )
+ for layer_id in range(config.encoder_layers)
+ ]
+ )
+ self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
+ # Parameters related to pooling layer
+ self.pooling_layer = None
+ # Parameters related to quantization layer
+ self.codebook = None
+ self.embed_positions2 = None
+ self.quantize_loss = None
+ self.num_active_codes = None
+ self.quantize_ema_count = 0
+ # Save hiddens
+ self.save_hidden_dir = None
+ self.save_hidden_position = None
# Initialize weights and apply final processing
+ self.init_pooling_layer(config)
+ self.init_quantize_layer(config)
self.post_init()
+ def init_pooling_layer(self, config: WhisperVQConfig):
+ if config.pooling_kernel_size is not None:
+ if config.pooling_type == "max":
+ self.pooling_layer = nn.MaxPool1d(
+ kernel_size=config.pooling_kernel_size
+ )
+ elif config.pooling_type == "avg":
+ self.pooling_layer = nn.AvgPool1d(
+ kernel_size=config.pooling_kernel_size
+ )
+ else:
+ raise NotImplementedError(
+ f"Pooling type {config.pooling_type} not implemented"
+ )
+
+ def init_quantize_layer(self, config: WhisperVQConfig, quantize_load_codebook=None):
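+ # Set up the vector-quantization codebook and a second positional embedding
+ # table sized for the (optionally pooled) sequence length. When
+ # quantize_ema_decay is set, the codebook is updated through EMA buffers
+ # rather than gradients.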
+ if config.quantize_vocab_size is not None:
+ if config.pooling_position is not None:
+ assert config.quantize_position >= config.pooling_position
+ self.codebook = nn.Embedding(
+ config.quantize_vocab_size, self.config.d_model
+ )
+ if quantize_load_codebook is not None:
+ init_codes = np.load(quantize_load_codebook)
+ self.codebook.weight.data.copy_(torch.from_numpy(init_codes))
+ max_source_positions = self.max_source_positions
+ if config.pooling_kernel_size is not None:
+ max_source_positions = math.ceil(
+ max_source_positions / self.config.pooling_kernel_size
+ )
+ self.embed_positions2 = nn.Embedding(
+ max_source_positions, self.config.d_model
+ )
+ self.embed_positions2.weight.data.copy_(
+ self.embed_positions.weight.data[:max_source_positions]
+ )
+ if config.quantize_ema_decay is not None:
+ self.codebook.weight.requires_grad = False
+ self.register_buffer(
+ "ema_count",
+ torch.ones(config.quantize_vocab_size, dtype=torch.float),
+ )
+ self.register_buffer(
+ "ema_weight", self.codebook.weight.data.clone().float()
+ )
+
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
@@ -44,6 +130,31 @@ class WhisperEncoder(WhisperPreTrainedModel):
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
+ def get_block_causal_attention_mask(self, attention_mask, block_size=50):
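+ # Build an attention mask that is causal across blocks of `block_size` frames
+ # but unrestricted within each block, intersect it with the padding mask, and
+ # convert it to an additive float mask (0 for visible, dtype-min for masked)
+ # for the attention layers.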
+ dtype = self.dtype
+ batch_size, seq_length = attention_mask.shape
+ causal_mask = torch.tril(
+ torch.ones(
+ 1,
+ seq_length,
+ seq_length,
+ dtype=torch.bool,
+ device=attention_mask.device,
+ )
+ )
+ block_square_mask = []
+ for start in range(0, seq_length, block_size):
+ end = min(start + block_size, seq_length)
+ length = end - start
+ block_square_mask.append(causal_mask.new_ones((length, length)))
+ block_square_mask = torch.block_diag(*block_square_mask)
+ block_causal_mask = causal_mask | block_square_mask
+ block_causal_mask = block_causal_mask & attention_mask[:, None, :]
+ block_causal_mask = block_causal_mask.to(dtype=dtype) # fp16 compatibility
+ block_causal_mask = (1.0 - block_causal_mask) * torch.finfo(dtype).min
+ block_causal_mask = block_causal_mask.unsqueeze(1)
+ return block_causal_mask
+
def forward(
self,
input_features,
@@ -52,6 +163,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
+ quantized_token_ids=None,
):
r"""
Args:
@@ -79,16 +191,26 @@ class WhisperEncoder(WhisperPreTrainedModel):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
- expected_seq_length = (
- self.config.max_source_positions
- * self.conv1.stride[0]
- * self.conv2.stride[0]
- )
- if input_features.shape[-1] != expected_seq_length:
- raise ValueError(
- f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
- )
+ # expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
+ # if input_features.shape[-1] != expected_seq_length:
+ # raise ValueError(
+ # f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
+ # )
+
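+ # The conv front-end downsamples time by conv1.stride * conv2.stride (2x for
+ # Whisper), so derive the post-conv sequence length and subsample the
+ # attention mask before building the extended / block-causal mask.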
+ batch_size, feature_size, seq_length = input_features.shape
+ seq_length = seq_length // (self.conv1.stride[0] * self.conv2.stride[0])
+ attention_mask = attention_mask[
+ :, :: self.conv1.stride[0] * self.conv2.stride[0]
+ ]
+ if self.config.quantize_causal_block_size is not None:
+ extended_attention_mask = self.get_block_causal_attention_mask(
+ attention_mask, block_size=self.config.quantize_causal_block_size
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length)
+ )
output_attentions = (
output_attentions
if output_attentions is not None
@@ -108,7 +230,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight
- hidden_states = inputs_embeds + embed_pos
+ hidden_states = inputs_embeds + embed_pos[:seq_length]
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
@@ -116,12 +238,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
+ assert attention_mask.shape[-1] == hidden_states.shape[1]
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
-
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
@@ -139,14 +261,14 @@ class WhisperEncoder(WhisperPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
- None,
+ extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
- None,
+ extended_attention_mask,
layer_head_mask=(
head_mask[idx] if head_mask is not None else None
),
@@ -157,8 +279,217 @@ class WhisperEncoder(WhisperPreTrainedModel):
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
+ if (
+ idx + 1 == self.config.pooling_position
+ and self.config.pooling_kernel_size is not None
+ ):
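+ # At the pooling layer: pad the time axis to a multiple of the pooling
+ # kernel, pool, and subsample the attention masks to the new, shorter length.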
+ hidden_states = hidden_states.permute(0, 2, 1)
+ if hidden_states.shape[-1] % self.config.pooling_kernel_size != 0:
+ hidden_states = torch.nn.functional.pad(
+ hidden_states,
+ (
+ 0,
+ self.config.pooling_kernel_size
+ - hidden_states.shape[-1] % self.config.pooling_kernel_size,
+ ),
+ )
+ hidden_states = self.pooling_layer(hidden_states).permute(0, 2, 1)
+ attention_mask = attention_mask[:, :: self.config.pooling_kernel_size]
+ if self.config.quantize_causal_block_size is not None:
+ extended_attention_mask = self.get_block_causal_attention_mask(
+ attention_mask,
+ block_size=self.config.quantize_causal_block_size
+ // self.config.pooling_kernel_size,
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask,
+ (batch_size, seq_length // self.config.pooling_kernel_size),
+ )
+
+ if (
+ idx + 1 == self.config.quantize_position
+ and self.config.quantize_vocab_size is not None
+ ):
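+ # At the quantization layer: either look up externally provided token ids in
+ # the codebook, or snap hidden states to their nearest codebook entries.
+ # During training this also tracks codebook usage and the quantization loss,
+ # and passes gradients through with a straight-through estimator.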
+ if quantized_token_ids is not None:
+ hidden_states = self.codebook(quantized_token_ids)
+ else:
+ hidden_quantized, indices_flat, distances = vector_quantize(
+ hidden_states, self.codebook.weight
+ )
+ quantized_token_ids = indices_flat.reshape(
+ batch_size, hidden_quantized.shape[1]
+ )
+ if self.training:
+ encodings = torch.nn.functional.one_hot(
+ indices_flat, self.config.quantize_vocab_size
+ ).float()
+ encodings = encodings * attention_mask.reshape(-1, 1)
+ n = torch.sum(encodings, dim=0)
+ torch.distributed.all_reduce(
+ n, op=torch.distributed.ReduceOp.SUM
+ )
+ self.num_active_codes = n.nonzero().shape[0]
+ if self.config.quantize_ema_decay:
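+ # EMA codebook update: accumulate per-code assignment counts and feature
+ # sums across ranks, smooth the counts, and set each code to its running mean.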
+ hidden_flat = (
+ hidden_states.detach()
+ .float()
+ .reshape(-1, hidden_states.shape[-1])
+ )
+ with torch.autocast(
+ device_type="cuda", dtype=torch.float32
+ ):
+ dw = torch.matmul(encodings.t(), hidden_flat)
+ torch.distributed.all_reduce(
+ dw, op=torch.distributed.ReduceOp.SUM
+ )
+ self.ema_count = (
+ self.ema_count * self.config.quantize_ema_decay
+ + (1 - self.config.quantize_ema_decay) * n
+ )
+ total_count = torch.sum(self.ema_count)
+ self.ema_count = (
+ (self.ema_count + 1e-5)
+ / (total_count + self.config.quantize_vocab_size * 1e-5)
+ * total_count
+ )
+ self.ema_weight = (
+ self.ema_weight * self.config.quantize_ema_decay
+ + (1 - self.config.quantize_ema_decay) * dw
+ )
+ self.codebook.weight.data = (
+ self.ema_weight / self.ema_count.unsqueeze(1)
+ )
+ self.quantize_loss = (
+ self.config.quantize_loss_scale
+ * self.config.quantize_commit_coefficient
+ * mse_loss_with_mask(
+ hidden_states,
+ hidden_quantized.detach(),
+ attention_mask,
+ )
+ )
+ self.quantize_ema_count += 1
+ if (
+ self.config.quantize_restart_interval is not None
+ and self.quantize_ema_count
+ % self.config.quantize_restart_interval
+ == 0
+ ):
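+ # Periodic codebook restart: codes whose EMA usage dropped below the decay
+ # threshold are re-initialized with random encoder features gathered from
+ # all ranks.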
+ rank, world_size = (
+ torch.distributed.get_rank(),
+ torch.distributed.get_world_size(),
+ )
+ segment_vocab_size = (
+ self.config.quantize_vocab_size // world_size
+ )
+ start_idx = segment_vocab_size * rank
+ ema_count_segment = self.ema_count[
+ start_idx : start_idx + segment_vocab_size
+ ]
+ threshold = 1 * (
+ self.config.quantize_ema_decay
+ ** self.config.quantize_restart_interval
+ )
+ update_indices = (
+ ema_count_segment < threshold
+ ).nonzero()[:, 0] + start_idx
+ num_update = update_indices.shape[0]
+ mask_flat = attention_mask.reshape(-1) > 0
+ hidden_selected = hidden_flat[mask_flat]
+ hidden_update = hidden_selected[
+ random.sample(
+ range(len(hidden_selected)), num_update
+ )
+ ]
+ num_update = torch.as_tensor(
+ [num_update],
+ dtype=torch.long,
+ device=hidden_states.device,
+ )
+ num_update_list = [
+ torch.as_tensor(
+ [0],
+ dtype=torch.long,
+ device=hidden_states.device,
+ )
+ for _ in range(world_size)
+ ]
+ torch.distributed.all_gather(
+ num_update_list, num_update
+ )
+ update_indices_list = [
+ torch.zeros(
+ num.item(),
+ dtype=torch.long,
+ device=hidden_states.device,
+ )
+ for num in num_update_list
+ ]
+ torch.distributed.all_gather(
+ update_indices_list, update_indices
+ )
+ update_indices = torch.cat(update_indices_list)
+ hidden_update_list = [
+ torch.zeros(
+ num.item(),
+ hidden_flat.shape[-1],
+ dtype=hidden_update.dtype,
+ device=hidden_states.device,
+ )
+ for num in num_update_list
+ ]
+ torch.distributed.all_gather(
+ hidden_update_list, hidden_update
+ )
+ hidden_update = torch.cat(hidden_update_list)
+ self.codebook.weight.data[update_indices] = (
+ hidden_update
+ )
+ self.ema_count[update_indices] = 1
+ self.ema_weight[update_indices] = hidden_update
+ if torch.distributed.get_rank() == 0:
+ print(f"restart {len(update_indices)} tokens")
+ else:
+ loss = self.config.quantize_loss_scale * (
+ self.config.quantize_commit_coefficient
+ * mse_loss_with_mask(
+ hidden_states,
+ hidden_quantized.detach(),
+ attention_mask,
+ )
+ + mse_loss_with_mask(
+ hidden_quantized,
+ hidden_states.detach(),
+ attention_mask,
+ )
+ )
+ self.quantize_loss = loss
+ hidden_states = (
+ hidden_states + (hidden_quantized - hidden_states).detach()
+ )
+ else:
+ hidden_states = hidden_quantized
+ hidden_states = (
+ hidden_states
+ + self.embed_positions2.weight[: hidden_states.shape[1]]
+ )
+
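+ # Optionally dump the non-padded hidden states at save_hidden_position to a
+ # uniquely named .npy file for offline inspection.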
+ if idx + 1 == self.save_hidden_position:
+ import numpy as np
+ import uuid
+ to_save = []
+ for batch_idx, hidden_state in enumerate(hidden_states):
+ for seq_idx, hidden in enumerate(hidden_state):
+ if attention_mask[batch_idx, seq_idx]:
+ to_save.append(hidden.detach().cpu().numpy())
+ np.save(
+ os.path.join(self.save_hidden_dir, f"{str(uuid.uuid4())}.npy"),
+ to_save,
+ )
- hidden_states = self.layer_norm(hidden_states)
+ if not self.config.quantize_encoder_only:
+ hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
@@ -168,8 +499,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
for v in [hidden_states, encoder_states, all_attentions]
if v is not None
)
- return BaseModelOutput(
+ return QuantizedBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
+ quantized_token_ids=quantized_token_ids,
)
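The patch calls helpers such as vector_quantize and mse_loss_with_mask that are not part of this diff. As a rough, non-authoritative sketch of what the quantization step above does, assuming a nearest-neighbour vector_quantize with this signature and the straight-through trick used at the end of the training branch:

import torch

def vector_quantize(hidden_states, codebook_weight):
    # Hypothetical reconstruction: map each frame to its nearest codebook row
    # by Euclidean distance. hidden_states: (batch, time, dim); codebook: (vocab, dim).
    flat = hidden_states.reshape(-1, hidden_states.shape[-1])
    distances = torch.cdist(flat, codebook_weight)   # (batch*time, vocab)
    indices_flat = distances.argmin(dim=-1)          # (batch*time,)
    quantized = codebook_weight[indices_flat].reshape(hidden_states.shape)
    return quantized, indices_flat, distances

# Toy usage mirroring the forward pass: quantize, then pass gradients straight
# through so the encoder below the quantizer still receives a training signal.
hidden = torch.randn(2, 10, 8, requires_grad=True)
codebook = torch.randn(32, 8)
quantized, indices, _ = vector_quantize(hidden, codebook)
ste_output = hidden + (quantized - hidden).detach()  # straight-through estimator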