Stripped AudioCodecModel from NeMo @ bde672e
from typing import Tuple | |
import torch | |
from torch import nn, Tensor | |
import torch.nn.functional as F | |
from einops import rearrange | |
from .modules import HiFiGANEncoder, HiFiGANDecoder, GroupFiniteScalarQuantizer | |
class AudioCodecModel(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.sample_rate = 22050 | |
self.samples_per_frame = 1024 | |
self.audio_encoder = HiFiGANEncoder( | |
down_sample_rates=(2, 2, 4, 8, 8), | |
encoded_dim=32, | |
base_channels=48, | |
resblock_dilation_sizes=(1,), | |
) | |
self.audio_decoder = HiFiGANDecoder( | |
up_sample_rates=(8, 8, 4, 2, 2), | |
input_dim=32, | |
base_channels=1024, | |
) | |
self.vector_quantizer = GroupFiniteScalarQuantizer(8, [8, 7, 6, 6]) | |
def encode_audio(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]: | |
"""Apply encoder on the input audio signal. Input will be padded with zeros so | |
the last frame has full `self.samples_per_frame` samples. | |
Args: | |
audio (Tensor): [B, T_audio]. | |
audio_len (LongTensor): [B]. | |
Returns: | |
encoded (Tensor): [B, D, T_encoded]. | |
encoded_len (LongTensor): [B]. | |
""" | |
audio, audio_len = self.pad_audio(audio, audio_len) | |
encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) | |
return encoded, encoded_len | |
def decode_audio(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor]: | |
"""Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation. | |
Args: | |
inputs (Tensor): [B, D, T_encoded]. | |
input_len (LongTensor): [B]. Valid length for each example in the batch | |
Returns: | |
audio (Tensor): [B, T_audio]. | |
Decoded output `audio` in the time domain and its length in number of samples `audio_len`. | |
audio_len (LongTensor): [B]. | |
Note that `audio_len` will be a multiple of `self.samples_per_frame`. | |
""" | |
audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len) | |
return audio, audio_len | |
def quantize(self, encoded: Tensor, encoded_len: Tensor) -> Tensor: | |
"""Quantize the continuous encoded representation into a discrete | |
representation for each frame. | |
Args: | |
encoded (Tensor): [B, D, T_encoded]. Encoded signal representation. | |
encoded_len (Tensor): [B]. Valid length of the encoded representation in frames. | |
Returns: | |
tokens (Tensor): [B, C, T_encoded]. A tensor of tokens for each codebook for each frame. | |
""" | |
if not self.vector_quantizer: | |
raise ValueError("Cannot quantize without quantizer") | |
# vector quantizer is returning [C, B, T], where C is the number of codebooks | |
tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) | |
# use batch first for the output | |
tokens = rearrange(tokens, "C B T -> B C T") | |
return tokens | |
def dequantize(self, tokens: Tensor, tokens_len: Tensor) -> Tensor: | |
"""Convert the discrete tokens into a continuous encoded representation. | |
Args: | |
tokens (Tensor): [B, C, T_encoded]. Discrete tokens for each codebook for each time frame. | |
tokens_len (Tensor): [B]. Valid length of each example in the batch. | |
Returns: | |
dequantized (Tensor): [B, D, T_encoded]. Continuous encoded representation of the discrete input representation. | |
""" | |
if not self.vector_quantizer: | |
raise ValueError("Cannot dequantize without quantizer") | |
# vector quantizer is using [C, B, T], where C is the number of codebooks | |
tokens = rearrange(tokens, "B C T -> C B T") | |
dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) | |
return dequantized | |
def encode(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]: | |
"""Convert input time-domain audio signal into a discrete representation (tokens). | |
Args: | |
audio (Tensor): input time-domain signal, shape `(B, T_audio)` | |
audio_len (Tensor): valid length for each example in the batch, shape `(B,)` | |
Returns: | |
tokens (Tensor): Tokens for each codebook for each frame, shape `(B, C, T_encoded)` | |
encoded_len (Tensor): Corresponding valid lengths, shape `(B,)` | |
""" | |
# Apply encoder to obtain a continuous vector for each frame | |
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) | |
# Apply quantizer to obtain discrete representation per frame | |
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len) | |
return tokens, encoded_len | |
def decode(self, tokens: Tensor, tokens_len: Tensor) -> Tuple[Tensor, Tensor]: | |
"""Convert discrete tokens into a continuous time-domain signal. | |
Args: | |
tokens (Tensor): [B, C, T_encoded]. Discrete tokens for each codebook for each time frame. | |
tokens_len (Tensor): [B]. Valid lengths for each example in the batch. | |
Returns: | |
audio (Tensor): [B, T_audio]. Decoded output `audio` in the time domain. | |
audio_len (Tensor): [B]. Length of the decoded audio in number of samples. | |
Note that `audio_len` will be a multiple of `self.samples_per_frame`. | |
""" | |
# Convert a discrete representation to a dequantized vector for each frame | |
dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len) | |
# Apply decoder to obtain time-domain audio for each frame | |
audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len) | |
return audio, audio_len | |
def forward(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]: | |
"""Apply encoder, quantizer, decoder on the input time-domain signal. | |
Args: | |
audio (Tensor): input time-domain signal, shape `(B, T_audio)` | |
audio_len (Tensor): valid length for each example in the batch, shape `(B,)` | |
Returns: | |
output_audio (Tensor): Reconstructed time-domain signal, shape `(B, T_audio)` | |
output_audio_len (Tensor): Length of the reconstructed audio in number of samples, shape `(B,)` | |
""" | |
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) | |
if self.vector_quantizer: | |
# quantize to discrete tokens | |
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len) | |
# decode tokens to audio | |
output_audio, output_audio_len = self.decode( | |
tokens=tokens, tokens_len=encoded_len | |
) | |
else: | |
# no quantization, directly decode to audio | |
output_audio, output_audio_len = self.decode_audio( | |
inputs=encoded, input_len=encoded_len | |
) | |
return output_audio, output_audio_len | |
def pad_audio(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]: | |
"""Zero pad the end of the audio so that we do not have a partial end frame. | |
The output will be zero-padded to have an integer number of frames of | |
length `self.samples_per_frame`. | |
Args: | |
audio (Tensor): input time-domain signal, shape `(B, T_audio)` | |
audio_len (Tensor): valid length for each example in the batch, shape `(B,)` | |
Returns: | |
padded_audio (Tensor): Padded time-domain signal, shape `(B, T_padded)` | |
padded_len (Tensor): Length of the padded audio, shape `(B,)` | |
""" | |
padded_len = ( | |
self.samples_per_frame | |
* torch.ceil(audio_len / self.samples_per_frame).int() | |
) | |
max_len = padded_len.max().item() | |
num_padding = max_len - audio.shape[1] | |
padded_audio = F.pad(audio, (0, num_padding)) | |
return padded_audio, padded_len |
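For reference, here is a minimal usage sketch of the model above (not part of the original gist; it assumes the supporting modules below are importable as the `.modules` package). With `sample_rate = 22050` and `samples_per_frame = 1024`, a one-second clip is padded to 22 frames of 1024 samples, and each frame is encoded as 8 codebook tokens:

import torch
model = AudioCodecModel().eval()
audio = torch.randn(1, 22050)                 # [B, T_audio], ~1 s at 22.05 kHz
audio_len = torch.tensor([22050])
with torch.no_grad():
    tokens, encoded_len = model.encode(audio=audio, audio_len=audio_len)
    # tokens: [B, C, T_encoded] = [1, 8, 22], since ceil(22050 / 1024) = 22
    recon_audio, recon_len = model.decode(tokens=tokens, tokens_len=encoded_len)
    # recon_audio: [B, 22 * 1024] -- the length is a multiple of samples_per_frame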
Supporting modules for the model above (HiFi-GAN encoder/decoder, finite scalar quantizers, discriminators, and mel-spectrogram features):
from abc import ABC, abstractmethod | |
from typing import Iterable, List, Optional, Tuple, Union | |
import numpy as np | |
import math | |
import torch | |
from torch import Tensor | |
import random | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
import logging | |
from .utils import ClampActivation, HalfSnake, Snake, mask_sequence_tensor | |
CONSTANT = 1e-5 | |
def get_padding(kernel_size: int, dilation: int = 1) -> int: | |
return (kernel_size * dilation - dilation) // 2 | |
def get_padding_2d( | |
kernel_size: Tuple[int, int], dilation: Tuple[int, int] | |
) -> Tuple[int, int]: | |
paddings = ( | |
get_padding(kernel_size[0], dilation[0]), | |
get_padding(kernel_size[1], dilation[1]), | |
) | |
return paddings | |
def get_down_sample_padding(kernel_size: int, stride: int) -> int: | |
return (kernel_size - stride + 1) // 2 | |
def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]: | |
output_padding = (kernel_size - stride) % 2 | |
padding = (kernel_size - stride + 1) // 2 | |
return padding, output_padding | |
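# Worked example (not part of the original code): for a stride-8 layer the encoder
# uses kernel_size = 2 * stride = 16, so get_down_sample_padding(16, 8) = (16 - 8 + 1) // 2 = 4
# and the convolution maps a length L to L // 8; the matching decoder layer uses
# get_up_sample_padding(16, 8) = (padding=4, output_padding=0), so the transposed
# convolution maps L back to L * 8.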
class CodecActivation(nn.Module): | |
""" | |
Choose between activation based on the input parameter. | |
Args: | |
activation: Name of activation to use. Valid options are "elu" (default), "lrelu", and "snake". | |
channels: Input dimension. | |
""" | |
def __init__(self, activation: str = "elu", channels: int = 1): | |
super().__init__() | |
activation = activation.lower() | |
if activation == "elu": | |
self.activation = nn.ELU() | |
elif activation == "lrelu": | |
self.activation = torch.nn.LeakyReLU() | |
elif activation == "snake": | |
self.activation = Snake(channels) | |
elif activation == "half_snake": | |
self.activation = HalfSnake(channels) | |
else: | |
raise ValueError(f"Unknown activation {activation}") | |
def forward(self, x): | |
return self.activation(x) | |
class Conv1dNorm(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: int, | |
stride: int = 1, | |
dilation: int = 1, | |
padding: Optional[int] = None, | |
): | |
super().__init__() | |
if not padding: | |
padding = get_padding(kernel_size=kernel_size, dilation=dilation) | |
conv = nn.Conv1d( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=stride, | |
padding=padding, | |
dilation=dilation, | |
padding_mode="reflect", | |
) | |
self.conv = nn.utils.weight_norm(conv) | |
def remove_weight_norm(self): | |
nn.utils.remove_weight_norm(self.conv) | |
def forward(self, inputs, input_len): | |
out = self.conv(inputs) | |
out = mask_sequence_tensor(out, input_len) | |
return out | |
class ConvTranspose1dNorm(nn.Module): | |
def __init__( | |
self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1 | |
): | |
super().__init__() | |
padding, output_padding = get_up_sample_padding(kernel_size, stride) | |
conv = nn.ConvTranspose1d( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=stride, | |
padding=padding, | |
output_padding=output_padding, | |
padding_mode="zeros", | |
) | |
self.conv = nn.utils.weight_norm(conv) | |
def remove_weight_norm(self): | |
nn.utils.remove_weight_norm(self.conv) | |
def forward(self, inputs, input_len): | |
out = self.conv(inputs) | |
out = mask_sequence_tensor(out, input_len) | |
return out | |
class Conv2dNorm(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: Tuple[int, int], | |
stride: Tuple[int, int] = (1, 1), | |
dilation: Tuple[int, int] = (1, 1), | |
): | |
super().__init__() | |
assert len(kernel_size) == len(dilation) | |
padding = get_padding_2d(kernel_size, dilation) | |
conv = nn.Conv2d( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=stride, | |
dilation=dilation, | |
padding=padding, | |
padding_mode="reflect", | |
) | |
self.conv = nn.utils.weight_norm(conv) | |
def remove_weight_norm(self): | |
nn.utils.remove_weight_norm(self.conv) | |
def forward(self, inputs): | |
return self.conv(inputs) | |
class PeriodDiscriminator(nn.Module): | |
""" | |
Period discriminator introduced in HiFi-GAN https://arxiv.org/abs/2010.05646 which attempts to | |
discriminate phase information by looking at equally spaced audio samples. | |
Args: | |
period: Spacing between audio sample inputs. | |
lrelu_slope: Slope to use for activation. Leaky relu with slope of 0.1 or 0.2 is recommended for the | |
stability of the feature matching loss. | |
""" | |
def __init__(self, period, lrelu_slope=0.1): | |
super().__init__() | |
self.period = period | |
self.activation = nn.LeakyReLU(lrelu_slope) | |
self.conv_layers = nn.ModuleList( | |
[ | |
Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)), | |
Conv2dNorm(32, 128, kernel_size=(5, 1), stride=(3, 1)), | |
Conv2dNorm(128, 512, kernel_size=(5, 1), stride=(3, 1)), | |
Conv2dNorm(512, 1024, kernel_size=(5, 1), stride=(3, 1)), | |
Conv2dNorm(1024, 1024, kernel_size=(5, 1), stride=(1, 1)), | |
] | |
) | |
self.conv_post = Conv2dNorm(1024, 1, kernel_size=(3, 1)) | |
def forward(self, audio): | |
batch_size, time = audio.shape | |
out = rearrange(audio, "B T -> B 1 T") | |
# Pad audio so that it is divisible by the period | |
if time % self.period != 0: | |
n_pad = self.period - (time % self.period) | |
out = F.pad(out, (0, n_pad), "reflect") | |
time = time + n_pad | |
# [batch, 1, (time / period), period] | |
out = out.view(batch_size, 1, time // self.period, self.period) | |
fmap = [] | |
for conv in self.conv_layers: | |
# [batch, filters, (time / period / stride), period] | |
out = conv(inputs=out) | |
out = self.activation(out) | |
fmap.append(out) | |
# [batch, 1, (time / period / strides), period] | |
score = self.conv_post(inputs=out) | |
fmap.append(score) | |
score = rearrange(score, "B 1 T C -> B C T") | |
return score, fmap | |
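# Worked example (not part of the original code): with period = 3 and a 7-sample input,
# the forward pass reflect-pads to 9 samples and views the signal as [B, 1, 9 // 3, 3],
# so each column contains samples spaced `period` apart and the (5, 1) kernels
# convolve along that strided time axis.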
class MultiPeriodDiscriminator(nn.Module): | |
""" | |
Wrapper class to aggregate results of multiple period discriminators. | |
The periods are expected to be increasing prime numbers in order to maximize coverage and minimize overlap | |
""" | |
def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11), lrelu_slope=0.1): | |
super().__init__() | |
self.discriminators = nn.ModuleList( | |
[ | |
PeriodDiscriminator(period=period, lrelu_slope=lrelu_slope) | |
for period in periods | |
] | |
) | |
def forward(self, audio_real, audio_gen): | |
scores_real = [] | |
scores_gen = [] | |
fmaps_real = [] | |
fmaps_gen = [] | |
for discriminator in self.discriminators: | |
score_real, fmap_real = discriminator(audio=audio_real) | |
score_gen, fmap_gen = discriminator(audio=audio_gen) | |
scores_real.append(score_real) | |
fmaps_real.append(fmap_real) | |
scores_gen.append(score_gen) | |
fmaps_gen.append(fmap_gen) | |
return scores_real, scores_gen, fmaps_real, fmaps_gen | |
class DiscriminatorSTFT(nn.Module): | |
""" | |
Discriminator network from EnCodec for Complex STFT input, but without dilations. | |
Args: | |
filters: number of filters to use in Conv2d layers | |
lrelu_slope: Slope to use for activations. Leaky relu with slope of 0.1 or 0.2 is recommended for the | |
stability of the feature matching loss | |
""" | |
def __init__(self, filters: int = 32, lrelu_slope: float = 0.1): | |
super().__init__() | |
self.activation = nn.LeakyReLU(lrelu_slope) | |
self.conv_layers = nn.ModuleList( | |
[ | |
Conv2dNorm(2, filters, kernel_size=(3, 9)), | |
Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), | |
Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), | |
Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), | |
Conv2dNorm(filters, filters, kernel_size=(3, 3)), | |
] | |
) | |
self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3)) | |
def forward(self, spec): | |
fmap = [] | |
# [batch, 2, T_spec, fft] | |
out = spec | |
for conv in self.conv_layers: | |
# [batch, filters, T_spec, fft // strides] | |
out = conv(inputs=out) | |
out = self.activation(out) | |
fmap.append(out) | |
# [batch, 1, T_spec, fft // 8] | |
scores = self.conv_post(inputs=out) | |
fmap.append(scores) | |
scores = rearrange(scores, "B 1 T C -> B C T") | |
return scores, fmap | |
class MultiBandDiscriminatorSTFT(nn.Module): | |
""" | |
Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546). | |
Computes the complex STFT for a given resolution and splits it into sub-bands, | |
which are given to separate discriminator networks. | |
Args: | |
resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length) | |
stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). | |
The floats are in the range [0, 1] representing the fraction of all stft bands. | |
For example for n_fft=1024, the stft output has 513 dimensions. | |
For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. | |
""" | |
def __init__( | |
self, resolution: Tuple[int, ...], stft_bands: Iterable[Tuple[int, int]] | |
): | |
super().__init__() | |
self.n_fft, self.hop_length, self.win_length = resolution | |
self.register_buffer( | |
"window", torch.hann_window(self.win_length, periodic=False) | |
) | |
self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands]) | |
n_stft = self.n_fft // 2 + 1 | |
self.stft_bands = [ | |
(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands | |
] | |
def compute_stft(self, audio): | |
# [B, fft, T_spec] | |
fft = torch.stft( | |
audio, | |
n_fft=self.n_fft, | |
hop_length=self.hop_length, | |
win_length=self.win_length, | |
window=self.window, | |
normalized=True, | |
center=True, | |
return_complex=True, | |
) | |
fft = rearrange(fft, "B fft T -> B T fft") | |
# [batch, 2, T_spec, fft] | |
out = torch.stack([fft.real, fft.imag], dim=1) | |
return out | |
def forward(self, audio): | |
scores_list = [] | |
fmap_list = [] | |
spec = self.compute_stft(audio) | |
for band, disc in zip(self.stft_bands, self.discriminators): | |
spec_band = spec[:, :, :, band[0] : band[1]] | |
score, fmap = disc(spec=spec_band) | |
scores_list.append(score) | |
fmap_list.append(fmap) | |
return scores_list, fmap_list | |
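# Worked example (not part of the original code): band fractions are mapped to STFT bin
# ranges in __init__. For an illustrative resolution of (1024, 256, 1024) there are
# n_stft = 1024 // 2 + 1 = 513 bins, so stft_bands = [(0.0, 0.25), (0.25, 1.0)]
# becomes [(0, 128), (128, 513)], i.e. bins 0-127 and 128-512, matching the docstring.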
class MultiResolutionDiscriminatorSTFT(nn.Module): | |
""" | |
Multi-resolution discriminator which creates a multi-band discriminator for each input resolution. | |
Args: | |
resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered | |
(num_fft, hop_length, window_length) | |
stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). | |
The floats are in the range [0, 1] representing the fraction of all stft bands. | |
For example for n_fft=1024, the stft output has 513 dimensions. | |
For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. | |
""" | |
def __init__( | |
self, resolutions: Iterable[Tuple[int]], stft_bands: Iterable[Tuple[int, int]] | |
): | |
super().__init__() | |
self.discriminators = nn.ModuleList( | |
[ | |
MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands) | |
for resolution in resolutions | |
] | |
) | |
def forward(self, audio_real, audio_gen): | |
scores_real = [] | |
scores_gen = [] | |
fmaps_real = [] | |
fmaps_gen = [] | |
for disc in self.discriminators: | |
score_real_i, fmap_real_i = disc(audio=audio_real) | |
scores_real = scores_real + score_real_i | |
fmaps_real = fmaps_real + fmap_real_i | |
score_gen_i, fmap_gen_i = disc(audio=audio_gen) | |
scores_gen = scores_gen + score_gen_i | |
fmaps_gen = fmaps_gen + fmap_gen_i | |
return scores_real, scores_gen, fmaps_real, fmaps_gen | |
class Discriminator(nn.Module): | |
""" | |
Wrapper class which takes a list of discriminators and aggregates the results across them. | |
""" | |
def __init__(self, discriminators: Iterable[nn.Module]): | |
super().__init__() | |
self.discriminators = nn.ModuleList(discriminators) | |
def forward(self, audio_real, audio_gen): | |
scores_real = [] | |
scores_gen = [] | |
fmaps_real = [] | |
fmaps_gen = [] | |
for discriminator in self.discriminators: | |
score_real, score_gen, fmap_real, fmap_gen = discriminator( | |
audio_real=audio_real, audio_gen=audio_gen | |
) | |
scores_real += score_real | |
fmaps_real += fmap_real | |
scores_gen += score_gen | |
fmaps_gen += fmap_gen | |
return scores_real, scores_gen, fmaps_real, fmaps_gen | |
class VectorQuantizerBase(nn.Module, ABC): | |
@abstractmethod | |
def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor]: | |
pass | |
@abstractmethod | |
def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: | |
pass | |
@abstractmethod | |
def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: | |
pass | |
class FiniteScalarQuantizer(VectorQuantizerBase): | |
"""This quantizer is based on the Finite Scalar Quantization (FSQ) method. | |
It quantizes each element of the input vector independently into a number of levels. | |
Args: | |
num_levels: number of levels for each dimension/element of the input vector | |
eps: small regularization constant for scaling | |
References: | |
Mentzer et al., Finite Scalar Quantization: VQ-VAE Made Simple (https://arxiv.org/abs/2309.15505v1) | |
""" | |
def __init__(self, num_levels: List[int], eps: float = 1e-3): | |
super().__init__() | |
# index base per dimension of the input vector | |
# this is used to convert between per-dimension indices and a codebook token index | |
dim_base_index = torch.cumprod( | |
torch.tensor([1] + num_levels[:-1]), dim=0, dtype=torch.int32 | |
) | |
dim_base_index = rearrange(dim_base_index, "D -> 1 D 1") | |
self.register_buffer("dim_base_index", dim_base_index) | |
# Register the number of levels for each dimension | |
num_levels = torch.tensor(num_levels, dtype=torch.int32) | |
num_levels = rearrange(num_levels, "D -> 1 D 1") | |
self.register_buffer("num_levels", num_levels) | |
# Regularization | |
self.eps = eps | |
logging.debug("Initializing %s with", self.__class__.__name__) | |
logging.debug("\tdim: %s", self.dim) | |
logging.debug("\tnum_levels: %s", self.num_levels) | |
logging.debug("\tcodebook_size: %s", self.codebook_size) | |
logging.debug("\teps: %s", self.eps) | |
@property | |
def codebook_size(self): | |
"""Returns the size of the corresponding codebook.""" | |
return self.num_levels.prod().item() | |
@property | |
def dim(self): | |
"""Returns the dimension of the input vector.""" | |
return self.num_levels.numel() | |
@property | |
def codebook_dim(self): | |
"""Returns the dimension of the input vector. | |
Kept for compatibility with the original RVQ implementation.
""" | |
return self.dim | |
@property | |
def codes(self): | |
"""Returns the codebooks entries. | |
Note that the codebook entries are implicitly defined by the number of levels. | |
""" | |
indices = torch.arange(self.codebook_size) | |
# [D, B, T] | |
indices = rearrange(indices, "B -> 1 B 1") | |
# [B, D, T] | |
codes = self.decode(indices=indices, input_len=None) | |
# Remove the time dimension | |
codes = codes.squeeze(-1) | |
return codes | |
@property | |
def codebook(self): | |
"""Returns the codebooks entries. | |
See self.codes for more details. | |
""" | |
return self.codes | |
@staticmethod | |
def round(inputs: Tensor, input_len: Tensor) -> Tensor: | |
"""Round the input tensor to nearest integer | |
and use a straight-through estimator for the gradient. | |
""" | |
inputs_rounded = torch.round(inputs) | |
return inputs + (inputs_rounded - inputs).detach() | |
def compress(self, inputs: Tensor, input_len: Tensor) -> Tensor: | |
"""Apply compression to the input, to limit to values.""" | |
output_scale = (self.num_levels - 1) / 2 | |
# scale down a bit to avoid rounding issues | |
output_scale = output_scale * (1 - self.eps) | |
# offset for even number of levels | |
output_offset = torch.where(self.num_levels % 2 == 0, 0.5, 0) | |
# shift for even number of levels | |
input_shift = (output_offset / output_scale).tan() | |
# compressed output | |
output = output_scale * (inputs + input_shift).tanh() - output_offset | |
return output | |
def inputs_to_codes(self, inputs: Tensor, input_len: Tensor) -> Tensor: | |
# apply compression | |
compressed = self.compress(inputs=inputs, input_len=input_len) | |
# apply rounding to nearest integer | |
codes = self.round(inputs=compressed, input_len=input_len) | |
# normalize to [-1, 1] | |
scale = self.num_levels // 2 | |
codes = codes / scale | |
return codes | |
def codes_to_nonnegative(self, codes: Tensor) -> Tensor: | |
"""Convert values centered arouund zero to nonnegative values.""" | |
scale = offset = self.num_levels // 2 | |
return scale * codes + offset | |
def nonnegative_to_codes(self, codes_nonnegative: Tensor) -> Tensor: | |
"""Convert nonnegative values to values centered arouund zero.""" | |
scale = offset = self.num_levels // 2 | |
return (codes_nonnegative - offset) / scale | |
def codes_to_indices(self, codes: Tensor) -> Tensor: | |
"""Converts a code vector to a single index.""" | |
if codes.size(1) != self.dim: | |
raise RuntimeError( | |
f"Input code dimension {codes.size(1)} not matching the expected dimension {self.dim}, input codes shape {codes.shape}" | |
) | |
# convert code vectors to nonnegative values | |
indices = self.codes_to_nonnegative(codes) | |
# convert one nonnegative index per dimension to a single index per code vector | |
indices = torch.sum(indices * self.dim_base_index, dim=1) | |
return indices.to(torch.int32) | |
# Implementation of VectorQuantiserBase API | |
def forward( | |
self, inputs: Tensor, input_len: Optional[Tensor] = None | |
) -> Tuple[Tensor, Tensor]: | |
if inputs.size(1) != self.dim: | |
raise RuntimeError( | |
f"Input dimension {inputs.size(1)} not matching the expected dimension {self.dim}, inputs shape {inputs.shape}" | |
) | |
dequantized = self.inputs_to_codes(inputs=inputs, input_len=input_len) | |
indices = self.codes_to_indices(codes=dequantized) | |
if input_len is not None: | |
# apply masking | |
dequantized = mask_sequence_tensor(dequantized, input_len) | |
indices = mask_sequence_tensor(indices, input_len) | |
# only 1 codebook, but return in [D, B, T] format to match RVQ API | |
indices = indices.unsqueeze(0) | |
return dequantized, indices | |
def encode(self, inputs: Tensor, input_len: Optional[Tensor] = None) -> Tensor: | |
"""Convert a continuous code vector to a single index.""" | |
_, indices = self(inputs=inputs, input_len=input_len) | |
return indices | |
def decode(self, indices: Tensor, input_len: Optional[Tensor] = None) -> Tensor: | |
"""Convert a single index to a continuous code vector.""" | |
if indices.size(0) > 1: | |
# codebook dimension used for compatibility with RVQ | |
raise ValueError( | |
f"Expected a single codebook, got {indices.size(0)} codebooks for indices with shape {indices.shape}." | |
) | |
indices = rearrange(indices, "D B T -> B D T") | |
# convert a single index to nonnegative index per-dimension | |
codes_nonnegative = (indices // self.dim_base_index) % self.num_levels | |
# convert nonnegative codes to codes (centered around zero) | |
dequantized = self.nonnegative_to_codes(codes_nonnegative) | |
if input_len is not None: | |
# apply masking | |
dequantized = mask_sequence_tensor(dequantized, input_len) | |
return dequantized | |
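# Minimal sketch (not part of the original code) of how the quantizer above maps a
# 4-dimensional code vector to a single token index and back. With num_levels = [8, 7, 6, 6],
# the implicit codebook has 8 * 7 * 6 * 6 = 2016 entries and dim_base_index = [1, 8, 56, 336].
def _fsq_roundtrip_example():  # hypothetical helper, for illustration only
    fsq = FiniteScalarQuantizer(num_levels=[8, 7, 6, 6])
    inputs = torch.randn(2, 4, 10)              # [B, D, T]
    input_len = torch.tensor([10, 7])
    dequantized, indices = fsq(inputs=inputs, input_len=input_len)
    # dequantized: [B, D, T] values snapped to the quantization grid in [-1, 1]
    # indices: [1, B, T] integer token ids in [0, 2016)
    recovered = fsq.decode(indices=indices, input_len=input_len)
    assert torch.allclose(recovered, dequantized)
    return indices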
class GroupFiniteScalarQuantizer(VectorQuantizerBase): | |
"""Split the input vector into groups and apply FSQ on each group separately. | |
This class is for convenience. Since FSQ is applied on each group separately, | |
groups can be defined arbitrarily by splitting the input vector. However, this | |
class makes it easy to construct several groups with the same quantization num_levels. | |
Args: | |
num_groups: number of groups to split the input into; each group is quantized separately with its own FSQ
num_levels_per_group: number of quantization levels for each dimension within a group (passed to FiniteScalarQuantizer)
**kwargs: parameters of FiniteScalarQuantizer | |
References: | |
Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765). | |
""" | |
def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs): | |
super().__init__() | |
self.num_groups = num_groups | |
self.codebook_dim_per_group = len(num_levels_per_group) | |
# Initialize FSQ for each group | |
self.fsqs = torch.nn.ModuleList( | |
[ | |
FiniteScalarQuantizer(num_levels=num_levels_per_group, **kwargs) | |
for _ in range(self.num_groups) | |
] | |
) | |
logging.debug("Initialized %s with", self.__class__.__name__) | |
logging.debug("\tnum_groups: %d", self.num_groups) | |
logging.debug("\tcodebook_dim: %d", self.codebook_dim) | |
logging.debug("\tnum_levels_per_group: %s", num_levels_per_group) | |
logging.debug("\tcodebook_dim_per_group: %d", self.codebook_dim_per_group) | |
@property | |
def codebook_dim(self): | |
"""Input vector dimension.""" | |
return self.codebook_dim_per_group * self.num_groups | |
@property | |
def codebook_size_per_group(self): | |
"""Returns the size of the implicit codebook for each group.""" | |
return self.fsqs[0].codebook_size | |
@property | |
def codebook_size(self): | |
"""Returns the size of the implicit codebook.""" | |
return self.codebook_size_per_group**self.num_groups | |
def forward(self, inputs, input_len): | |
"""Quantize each group separately, then concatenate the results.""" | |
inputs_grouped = inputs.chunk(self.num_groups, dim=1) | |
dequantized, indices = [], [] | |
for in_group, fsq_group in zip(inputs_grouped, self.fsqs): | |
dequantized_group, indices_group = fsq_group( | |
inputs=in_group, input_len=input_len | |
) | |
dequantized.append(dequantized_group) | |
indices.append(indices_group) | |
# concatenate along the feature dimension | |
dequantized = torch.cat(dequantized, dim=1) | |
# concatenate along the codebook dimension
indices = torch.cat(indices, dim=0) | |
return dequantized, indices | |
def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: | |
"""Input is split into groups, each group is encoded separately, then the results are concatenated.""" | |
inputs_grouped = inputs.chunk(self.num_groups, dim=1) | |
indices = [] | |
for in_group, fsq_group in zip(inputs_grouped, self.fsqs): | |
indices_group = fsq_group.encode(inputs=in_group, input_len=input_len) | |
indices.append(indices_group) | |
# concatenate along the codebook dimension | |
indices = torch.cat(indices, dim=0) | |
return indices | |
def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: | |
"""Input indices are split into groups, each group is decoded separately, then the results are concatenated.""" | |
indices_grouped = indices.chunk(self.num_groups, dim=0) | |
dequantized = [] | |
for indices_group, fsq_group in zip(indices_grouped, self.fsqs): | |
dequantized_group = fsq_group.decode( | |
indices=indices_group, input_len=input_len | |
) | |
dequantized.append(dequantized_group) | |
# concatenate along the feature dimension | |
dequantized = torch.cat(dequantized, dim=1) | |
return dequantized | |
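# Worked example (not part of the original code): the AudioCodecModel above constructs
# GroupFiniteScalarQuantizer(8, [8, 7, 6, 6]), i.e. 8 groups of 4 levels each, so
# codebook_dim = 8 * 4 = 32 (matching the encoder's encoded_dim), each frame is
# represented by 8 token indices (one per group), and each group's implicit codebook
# has 8 * 7 * 6 * 6 = 2016 entries.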
class ResidualBlock(nn.Module): | |
""" | |
The residual block structure defined by the HiFi-GAN V1 and V2 configurations. | |
Args: | |
channels: Input dimension. | |
filters: Number of channels in the residual convolutions. | |
kernel_size: Kernel size of the residual convolutions. | |
dilation: Dilation of the residual convolutions. | |
dropout_rate: Dropout to apply to residuals. | |
activation: Activation to apply in between residual convolutions. | |
""" | |
def __init__( | |
self, | |
channels: int, | |
filters: int, | |
kernel_size: int = 3, | |
dilation: int = 1, | |
dropout_rate: float = 0.0, | |
activation: str = "lrelu", | |
): | |
super(ResidualBlock, self).__init__() | |
self.input_activation = CodecActivation( | |
activation=activation, channels=channels | |
) | |
self.skip_activation = CodecActivation(activation=activation, channels=filters) | |
self.dropout = torch.nn.Dropout(dropout_rate) | |
self.input_conv = Conv1dNorm( | |
in_channels=channels, | |
out_channels=filters, | |
kernel_size=kernel_size, | |
dilation=dilation, | |
) | |
self.skip_conv = Conv1dNorm( | |
in_channels=filters, out_channels=channels, kernel_size=kernel_size | |
) | |
def remove_weight_norm(self): | |
self.input_conv.remove_weight_norm() | |
self.skip_conv.remove_weight_norm() | |
def forward(self, inputs, input_len): | |
conv_input = self.input_activation(inputs) | |
skip_input = self.input_conv(inputs=conv_input, input_len=input_len) | |
skip_input = self.skip_activation(skip_input) | |
res = self.skip_conv(inputs=skip_input, input_len=input_len) | |
res = self.dropout(res) | |
out = inputs + res | |
return out | |
class HiFiGANResBlock(nn.Module): | |
""" | |
Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. | |
Args: | |
channels: Input dimension. | |
kernel_size: Kernel size of the residual blocks. | |
dilations: List of dilations. One residual block will be created for each dilation in the list. | |
activation: Activation for the residual blocks. | |
""" | |
def __init__( | |
self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str | |
): | |
super().__init__() | |
self.res_blocks = nn.ModuleList( | |
[ | |
ResidualBlock( | |
channels=channels, | |
filters=channels, | |
kernel_size=kernel_size, | |
dilation=dilation, | |
activation=activation, | |
) | |
for dilation in dilations | |
] | |
) | |
def remove_weight_norm(self): | |
for res_block in self.res_blocks: | |
res_block.remove_weight_norm() | |
def forward(self, inputs, input_len): | |
out = inputs | |
for res_block in self.res_blocks: | |
out = res_block(inputs=out, input_len=input_len) | |
return out | |
class HiFiGANResLayer(nn.Module): | |
""" | |
Residual block wrapper for HiFi-GAN which creates a block for multiple kernel sizes and dilations. | |
One residual block is created for each combination of kernel size and dilation. | |
Args: | |
channels: Input dimension. | |
kernel_sizes: List of kernel sizes. | |
dilations: List of dilations. | |
activation: Activation for the residual layers. | |
""" | |
def __init__( | |
self, | |
channels: int, | |
kernel_sizes: Iterable[int], | |
dilations: Iterable[int], | |
activation: str, | |
): | |
super().__init__() | |
self.res_blocks = nn.ModuleList( | |
[ | |
HiFiGANResBlock( | |
channels=channels, | |
kernel_size=kernel_size, | |
dilations=dilations, | |
activation=activation, | |
) | |
for kernel_size in kernel_sizes | |
] | |
) | |
def remove_weight_norm(self): | |
for res_block in self.res_blocks: | |
res_block.remove_weight_norm() | |
def forward(self, inputs, input_len): | |
residuals = [ | |
res_block(inputs=inputs, input_len=input_len) | |
for res_block in self.res_blocks | |
] | |
out = sum(residuals) / len(residuals) | |
return out | |
class HiFiGANEncoder(nn.Module): | |
""" | |
Audio encoder created by inverting the HiFi-GAN decoder. | |
Args: | |
encoded_dim: Dimension of encoder output. | |
down_sample_rates: Rate to downsample for each encoder block. The product of the downsample rates will
determine the output token rate. For example 2 * 2 * 8 * 8 = 256 samples per token. | |
base_channels: Number of filters in the first convolution. The number of channels will be doubled after each | |
downsample layer. | |
in_kernel_size: Kernel size of the input convolution. | |
out_kernel_size: Kernel size of the output convolution. | |
resblock_kernel_sizes: List of kernel sizes to use in each residual block. | |
resblock_dilation_sizes: List of dilations to use in each residual block. | |
activation: Activation to use in residual and downsample layers, defaults to leaky relu. | |
""" | |
def __init__( | |
self, | |
encoded_dim: int, | |
down_sample_rates: Iterable[int] = (2, 2, 8, 8), | |
base_channels: int = 32, | |
in_kernel_size: int = 7, | |
out_kernel_size: int = 7, | |
resblock_kernel_sizes: Iterable[int] = (3, 7, 11), | |
resblock_dilation_sizes: Iterable[int] = (1, 3, 5), | |
activation: str = "lrelu", | |
): | |
assert in_kernel_size > 0 | |
assert out_kernel_size > 0 | |
super().__init__() | |
self.down_sample_rates = down_sample_rates | |
self.pre_conv = Conv1dNorm( | |
in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size | |
) | |
in_channels = base_channels | |
self.activations = nn.ModuleList([]) | |
self.down_sample_conv_layers = nn.ModuleList([]) | |
self.res_layers = nn.ModuleList([]) | |
for i, down_sample_rate in enumerate(self.down_sample_rates): | |
res_layer = HiFiGANResLayer( | |
channels=in_channels, | |
kernel_sizes=resblock_kernel_sizes, | |
dilations=resblock_dilation_sizes, | |
activation=activation, | |
) | |
self.res_layers.append(res_layer) | |
act = CodecActivation(activation, channels=in_channels) | |
self.activations.append(act) | |
out_channels = 2 * in_channels | |
kernel_size = 2 * down_sample_rate | |
padding = get_down_sample_padding( | |
kernel_size=kernel_size, stride=down_sample_rate | |
) | |
down_sample_conv = Conv1dNorm( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=down_sample_rate, | |
padding=padding, | |
) | |
in_channels = out_channels | |
self.down_sample_conv_layers.append(down_sample_conv) | |
self.post_activation = CodecActivation(activation, channels=in_channels) | |
self.post_conv = Conv1dNorm( | |
in_channels=in_channels, | |
out_channels=encoded_dim, | |
kernel_size=out_kernel_size, | |
) | |
def remove_weight_norm(self): | |
self.pre_conv.remove_weight_norm() | |
self.post_conv.remove_weight_norm() | |
for res_layer in self.res_layers: | |
res_layer.remove_weight_norm() | |
for down_sample_conv in self.down_sample_conv_layers: | |
down_sample_conv.remove_weight_norm() | |
def forward(self, audio, audio_len): | |
encoded_len = audio_len | |
audio = rearrange(audio, "B T -> B 1 T") | |
# [B, C, T_audio] | |
out = self.pre_conv(inputs=audio, input_len=encoded_len) | |
for act, res_layer, down_sample_conv, down_sample_rate in zip( | |
self.activations, | |
self.res_layers, | |
self.down_sample_conv_layers, | |
self.down_sample_rates, | |
): | |
# [B, C, T] | |
out = res_layer(inputs=out, input_len=encoded_len) | |
out = act(out) | |
encoded_len = encoded_len // down_sample_rate | |
# [B, 2 * C, T / down_sample_rate] | |
out = down_sample_conv(inputs=out, input_len=encoded_len) | |
out = self.post_activation(out) | |
# [B, encoded_dim, T_encoded] | |
encoded = self.post_conv(inputs=out, input_len=encoded_len) | |
return encoded, encoded_len | |
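# Worked example (not part of the original code): with the configuration used by
# AudioCodecModel above, down_sample_rates = (2, 2, 4, 8, 8) gives
# 2 * 2 * 4 * 8 * 8 = 1024 samples per encoded frame (samples_per_frame), and
# base_channels = 48 doubles after each of the five down-sample layers
# (48 -> 96 -> 192 -> 384 -> 768 -> 1536) before post_conv projects to encoded_dim = 32.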
class HiFiGANDecoder(nn.Module): | |
""" | |
Codec decoder using the HiFi-GAN generator architecture. | |
Default parameters match the HiFi-GAN V1 configuration for 22.05khz. | |
Args: | |
input_dim: Input dimension. | |
up_sample_rates: Rate to upsample for each decoder block. The product of the upsample rates should be the same | |
as the overall downsample rate for your encoder. For example, a symmetric encoder/decoder can be created | |
with encoder downsample rates [2, 2, 8, 8] and decoder upsample rates [8, 8, 2, 2]. | |
base_channels: Number of filters in the first convolution. The number of channels will be cut in | |
half after each upsample layer. | |
in_kernel_size: Kernel size of the input convolution. | |
out_kernel_size: Kernel size of the output convolution. | |
resblock_kernel_sizes: List of kernel sizes to use in each residual block. | |
resblock_dilation_sizes: List of dilations to use in each residual block. | |
activation: Activation to use in residual and upsample layers, defaults to leaky relu. | |
output_activation: Activation to apply to output. To produce a valid audio signal, it should output values in | |
the range [-1.0, 1.0]. Supports "tanh" and "clamp". | |
""" | |
def __init__( | |
self, | |
input_dim: int, | |
up_sample_rates: Iterable[int] = (8, 8, 2, 2), | |
base_channels: int = 512, | |
in_kernel_size: int = 7, | |
out_kernel_size: int = 3, | |
resblock_kernel_sizes: Iterable[int] = (3, 7, 11), | |
resblock_dilation_sizes: Iterable[int] = (1, 3, 5), | |
activation: str = "lrelu", | |
output_activation: str = "tanh", | |
): | |
assert in_kernel_size > 0 | |
assert out_kernel_size > 0 | |
super().__init__() | |
self.up_sample_rates = up_sample_rates | |
self.pre_conv = Conv1dNorm( | |
in_channels=input_dim, | |
out_channels=base_channels, | |
kernel_size=in_kernel_size, | |
) | |
in_channels = base_channels | |
self.activations = nn.ModuleList([]) | |
self.up_sample_conv_layers = nn.ModuleList([]) | |
self.res_layers = nn.ModuleList([]) | |
for i, up_sample_rate in enumerate(self.up_sample_rates): | |
out_channels = in_channels // 2 | |
kernel_size = 2 * up_sample_rate | |
act = CodecActivation(activation, channels=in_channels) | |
self.activations.append(act) | |
up_sample_conv = ConvTranspose1dNorm( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=up_sample_rate, | |
) | |
in_channels = out_channels | |
self.up_sample_conv_layers.append(up_sample_conv) | |
res_layer = HiFiGANResLayer( | |
channels=in_channels, | |
kernel_sizes=resblock_kernel_sizes, | |
dilations=resblock_dilation_sizes, | |
activation=activation, | |
) | |
self.res_layers.append(res_layer) | |
self.post_activation = CodecActivation(activation, channels=in_channels) | |
self.post_conv = Conv1dNorm( | |
in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size | |
) | |
if output_activation == "tanh": | |
self.out_activation = nn.Tanh() | |
elif output_activation == "clamp": | |
self.out_activation = ClampActivation() | |
else: | |
raise ValueError(f"Invalid audio output activation {output_activation}") | |
def remove_weight_norm(self): | |
self.pre_conv.remove_weight_norm() | |
for up_sample_conv in self.up_sample_conv_layers: | |
up_sample_conv.remove_weight_norm() | |
for res_layer in self.res_layers: | |
res_layer.remove_weight_norm() | |
def forward(self, inputs, input_len): | |
audio_len = input_len | |
# [B, C, T_encoded] | |
out = self.pre_conv(inputs=inputs, input_len=audio_len) | |
for act, res_layer, up_sample_conv, up_sample_rate in zip( | |
self.activations, | |
self.res_layers, | |
self.up_sample_conv_layers, | |
self.up_sample_rates, | |
): | |
audio_len = audio_len * up_sample_rate | |
out = act(out) | |
# [B, C / 2, T * up_sample_rate] | |
out = up_sample_conv(inputs=out, input_len=audio_len) | |
out = res_layer(inputs=out, input_len=audio_len) | |
out = self.post_activation(out) | |
# [B, 1, T_audio] | |
out = self.post_conv(inputs=out, input_len=audio_len) | |
audio = self.out_activation(out) | |
audio = rearrange(audio, "B 1 T -> B T") | |
return audio, audio_len | |
@torch.jit.script_if_tracing | |
def make_seq_mask_like( | |
lengths: Tensor, | |
like: Tensor, | |
time_dim: int = -1, | |
valid_ones: bool = True, | |
) -> Tensor: | |
""" | |
Args: | |
lengths: Tensor with shape [B] containing the sequence length of each batch element | |
like: The mask will contain the same number of dimensions as this Tensor, and will have the same max | |
length in the time dimension of this Tensor. | |
time_dim: Time dimension of the `like` tensor and the resulting mask. Zero-based.
valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert. | |
Returns: | |
A :class:`Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else | |
vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match | |
the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and | |
`time_dim == -1`, mask will have shape `[3, 1, 5]`.
""" | |
# Mask with shape [B, T] | |
mask = ( | |
torch.arange(like.shape[time_dim], device=like.device) | |
.repeat(lengths.shape[0], 1) | |
.lt(lengths.view(-1, 1)) | |
) | |
# [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor | |
for _ in range(like.dim() - mask.dim()): | |
mask = mask.unsqueeze(1) | |
# If needed, transpose time dim | |
if time_dim != -1 and time_dim != mask.dim() - 1: | |
mask = mask.transpose(-1, time_dim) | |
# Maybe invert the padded vs. valid token values | |
if not valid_ones: | |
mask = ~mask | |
return mask | |
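# Minimal sketch (not part of the original code) of the mask shapes produced above:
def _mask_example():  # hypothetical helper, for illustration only
    lengths = torch.tensor([3, 5])
    like = torch.zeros(2, 4, 5)                 # [B, C, T]
    mask = make_seq_mask_like(lengths=lengths, like=like)
    # mask: [2, 1, 5] bool; mask[0, 0] == [True, True, True, False, False]
    return like * mask                          # broadcasts across the channel dimension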
def normalize_batch(x, seq_len, normalize_type): | |
x_mean = None | |
x_std = None | |
if normalize_type == "per_feature": | |
batch_size = x.shape[0] | |
max_time = x.shape[2] | |
# When doing stream capture to a graph, item() is not allowed | |
# because it calls cudaStreamSynchronize(). Therefore, we are
# sacrificing some error checking when running with cuda graphs. | |
if ( | |
torch.cuda.is_available() | |
and not torch.cuda.is_current_stream_capturing() | |
and torch.any(seq_len == 1).item() | |
): | |
raise ValueError( | |
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result " | |
"in torch.std() returning nan. Make sure your audio length has enough samples for a single " | |
"feature (ex. at least `hop_length` for Mel Spectrograms)." | |
) | |
time_steps = ( | |
torch.arange(max_time, device=x.device) | |
.unsqueeze(0) | |
.expand(batch_size, max_time) | |
) | |
valid_mask = time_steps < seq_len.unsqueeze(1) | |
x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2) | |
x_mean_denominator = valid_mask.sum(axis=1) | |
x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1) | |
# Subtract 1 in the denominator to correct for the bias. | |
x_std = torch.sqrt( | |
torch.sum( | |
torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, | |
axis=2, | |
) | |
/ (x_mean_denominator.unsqueeze(1) - 1.0) | |
) | |
# make sure x_std is not zero | |
x_std += CONSTANT | |
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std | |
elif normalize_type == "all_features": | |
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) | |
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) | |
for i in range(x.shape[0]): | |
x_mean[i] = x[i, :, : seq_len[i].item()].mean() | |
x_std[i] = x[i, :, : seq_len[i].item()].std() | |
# make sure x_std is not zero | |
x_std += CONSTANT | |
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std | |
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type: | |
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device) | |
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device) | |
return ( | |
(x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) | |
/ x_std.view(x.shape[0], x.shape[1]).unsqueeze(2), | |
x_mean, | |
x_std, | |
) | |
else: | |
return x, x_mean, x_std | |
def splice_frames(x, frame_splicing): | |
"""Stacks frames together across feature dim | |
input is batch_size, feature_dim, num_frames | |
output is batch_size, feature_dim*frame_splicing, num_frames | |
""" | |
seq = [x] | |
for n in range(1, frame_splicing): | |
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2)) | |
return torch.cat(seq, dim=1) | |
class FilterbankFeatures(nn.Module): | |
"""Featurizer that converts wavs to Mel Spectrograms. | |
See AudioToMelSpectrogramPreprocessor for args. | |
""" | |
def __init__( | |
self, | |
sample_rate=16000, | |
n_window_size=320, | |
n_window_stride=160, | |
window="hann", | |
normalize="per_feature", | |
n_fft=None, | |
preemph=0.97, | |
nfilt=64, | |
lowfreq=0, | |
highfreq=None, | |
log=True, | |
log_zero_guard_type="add", | |
log_zero_guard_value=2**-24, | |
dither=CONSTANT, | |
pad_to=16, | |
max_duration=16.7, | |
frame_splicing=1, | |
exact_pad=False, | |
pad_value=0, | |
mag_power=2.0, | |
use_grads=False, | |
rng=None, | |
nb_augmentation_prob=0.0, | |
nb_max_freq=4000, | |
mel_norm="slaney", | |
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility | |
stft_conv=False, # Deprecated arguments; kept for config compatibility | |
): | |
super().__init__() | |
if stft_conv or stft_exact_pad: | |
logging.warning( | |
"Using torch_stft is deprecated and has been removed. The values have been forcibly set to False " | |
"for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True " | |
"as needed." | |
) | |
if exact_pad and n_window_stride % 2 == 1: | |
raise NotImplementedError( | |
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the " | |
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size." | |
) | |
self.log_zero_guard_value = log_zero_guard_value | |
if ( | |
n_window_size is None | |
or n_window_stride is None | |
or not isinstance(n_window_size, int) | |
or not isinstance(n_window_stride, int) | |
or n_window_size <= 0 | |
or n_window_stride <= 0 | |
): | |
raise ValueError( | |
f"{self} got an invalid value for either n_window_size or " | |
f"n_window_stride. Both must be positive ints." | |
) | |
logging.info(f"PADDING: {pad_to}") | |
self.win_length = n_window_size | |
self.hop_length = n_window_stride | |
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) | |
self.stft_pad_amount = ( | |
(self.n_fft - self.hop_length) // 2 if exact_pad else None | |
) | |
self.exact_pad = exact_pad | |
if exact_pad: | |
logging.info("STFT using exact pad") | |
torch_windows = { | |
"hann": torch.hann_window, | |
"hamming": torch.hamming_window, | |
"blackman": torch.blackman_window, | |
"bartlett": torch.bartlett_window, | |
"none": None, | |
} | |
window_fn = torch_windows.get(window, None) | |
window_tensor = ( | |
window_fn(self.win_length, periodic=False) if window_fn else None | |
) | |
self.register_buffer("window", window_tensor) | |
self.normalize = normalize | |
self.log = log | |
self.dither = dither | |
self.frame_splicing = frame_splicing | |
self.nfilt = nfilt | |
self.preemph = preemph | |
self.pad_to = pad_to | |
highfreq = highfreq or sample_rate / 2 | |
import librosa | |
filterbanks = torch.tensor( | |
librosa.filters.mel( | |
sr=sample_rate, | |
n_fft=self.n_fft, | |
n_mels=nfilt, | |
fmin=lowfreq, | |
fmax=highfreq, | |
norm=mel_norm, | |
), | |
dtype=torch.float, | |
).unsqueeze(0) | |
self.register_buffer("fb", filterbanks) | |
# Calculate maximum sequence length | |
max_length = self.get_seq_len( | |
torch.tensor(max_duration * sample_rate, dtype=torch.float) | |
) | |
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 | |
self.max_length = max_length + max_pad | |
self.pad_value = pad_value | |
self.mag_power = mag_power | |
# We want to avoid taking the log of zero | |
# There are two options: either adding or clamping to a small value | |
if log_zero_guard_type not in ["add", "clamp"]: | |
raise ValueError( | |
f"{self} received {log_zero_guard_type} for the " | |
f"log_zero_guard_type parameter. It must be either 'add' or " | |
f"'clamp'." | |
) | |
self.use_grads = use_grads | |
if not use_grads: | |
self.forward = torch.no_grad()(self.forward) | |
self._rng = random.Random() if rng is None else rng | |
self.nb_augmentation_prob = nb_augmentation_prob | |
if self.nb_augmentation_prob > 0.0: | |
if nb_max_freq >= sample_rate / 2: | |
self.nb_augmentation_prob = 0.0 | |
else: | |
self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft) | |
# log_zero_guard_value is the small value we want to use; we support
# an actual number, or "tiny", or "eps" | |
self.log_zero_guard_type = log_zero_guard_type | |
logging.debug(f"sr: {sample_rate}") | |
logging.debug(f"n_fft: {self.n_fft}") | |
logging.debug(f"win_length: {self.win_length}") | |
logging.debug(f"hop_length: {self.hop_length}") | |
logging.debug(f"n_mels: {nfilt}") | |
logging.debug(f"fmin: {lowfreq}") | |
logging.debug(f"fmax: {highfreq}") | |
logging.debug(f"using grads: {use_grads}") | |
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}") | |
def stft(self, x): | |
return torch.stft( | |
x, | |
n_fft=self.n_fft, | |
hop_length=self.hop_length, | |
win_length=self.win_length, | |
center=False if self.exact_pad else True, | |
window=self.window.to(dtype=torch.float), | |
return_complex=True, | |
) | |
def log_zero_guard_value_fn(self, x): | |
if isinstance(self.log_zero_guard_value, str): | |
if self.log_zero_guard_value == "tiny": | |
return torch.finfo(x.dtype).tiny | |
elif self.log_zero_guard_value == "eps": | |
return torch.finfo(x.dtype).eps | |
else: | |
raise ValueError( | |
f"{self} received {self.log_zero_guard_value} for the " | |
f"log_zero_guard_type parameter. It must be either a " | |
f"number, 'tiny', or 'eps'" | |
) | |
else: | |
return self.log_zero_guard_value | |
def get_seq_len(self, seq_len): | |
# Assuming center=True padding of n_fft // 2 on each side when stft_pad_amount is None
pad_amount = ( | |
self.stft_pad_amount * 2 | |
if self.stft_pad_amount is not None | |
else self.n_fft // 2 * 2 | |
) | |
seq_len = ( | |
torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1 | |
) | |
return seq_len.to(dtype=torch.long) | |
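# Worked example (not part of the original code): with the defaults n_window_size = 320
# and n_window_stride = 160, n_fft = 2 ** ceil(log2(320)) = 512 and center padding adds
# n_fft // 2 = 256 samples on each side, so for a 16000-sample (1 s) input
# get_seq_len(16000) = (16000 + 512 - 512) // 160 + 1 = 101 frames.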
@property | |
def filter_banks(self): | |
return self.fb | |
def forward(self, x, seq_len, linear_spec=False): | |
seq_len = self.get_seq_len(seq_len) | |
if self.stft_pad_amount is not None: | |
x = torch.nn.functional.pad( | |
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect" | |
).squeeze(1) | |
# dither (only in training mode for eval determinism) | |
if self.training and self.dither > 0: | |
x += self.dither * torch.randn_like(x) | |
# do preemphasis | |
if self.preemph is not None: | |
x = torch.cat( | |
(x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1 | |
) | |
# disable autocast to get full range of stft values | |
with torch.amp.autocast(x.device.type, enabled=False): | |
x = self.stft(x) | |
# torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude | |
# guard is needed for sqrt if grads are passed through | |
guard = 0 if not self.use_grads else CONSTANT | |
x = torch.view_as_real(x) | |
x = torch.sqrt(x.pow(2).sum(-1) + guard) | |
if self.training and self.nb_augmentation_prob > 0.0: | |
for idx in range(x.shape[0]): | |
if self._rng.random() < self.nb_augmentation_prob: | |
x[idx, self._nb_max_fft_bin :, :] = 0.0 | |
# get power spectrum | |
if self.mag_power != 1.0: | |
x = x.pow(self.mag_power) | |
# return plain spectrogram if required | |
if linear_spec: | |
return x, seq_len | |
# dot with filterbank energies | |
x = torch.matmul(self.fb.to(x.dtype), x) | |
# log features if required | |
if self.log: | |
if self.log_zero_guard_type == "add": | |
x = torch.log(x + self.log_zero_guard_value_fn(x)) | |
elif self.log_zero_guard_type == "clamp": | |
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) | |
else: | |
raise ValueError("log_zero_guard_type was not understood") | |
# frame splicing if required | |
if self.frame_splicing > 1: | |
x = splice_frames(x, self.frame_splicing) | |
# normalize if required | |
if self.normalize: | |
x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize) | |
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency) | |
max_len = x.size(-1) | |
mask = torch.arange(max_len, device=x.device) | |
mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1) | |
x = x.masked_fill( | |
mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value | |
) | |
del mask | |
pad_to = self.pad_to | |
if pad_to == "max": | |
x = nn.functional.pad( | |
x, (0, self.max_length - x.size(-1)), value=self.pad_value | |
) | |
elif pad_to > 0: | |
pad_amt = x.size(-1) % pad_to | |
if pad_amt != 0: | |
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) | |
return x, seq_len | |
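# Minimal usage sketch (not part of the original code); computing the mel basis
# requires librosa, which is imported inside FilterbankFeatures.__init__:
def _filterbank_example():  # hypothetical helper, for illustration only
    featurizer = FilterbankFeatures(
        sample_rate=22050, n_window_size=1024, n_window_stride=256, nfilt=80
    )
    audio = torch.randn(1, 22050)               # [B, T_audio]
    audio_len = torch.tensor([22050])
    mel, mel_len = featurizer(audio, audio_len)
    # mel: [B, 80, T_mel] log-mel features, zero-padded to a multiple of pad_to = 16 frames
    # mel_len: valid number of frames per example
    return mel, mel_len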
class AudioToMelSpectrogramPreprocessor(nn.Module): | |
"""Featurizer module that converts wavs to mel spectrograms. | |
Args: | |
sample_rate (int): Sample rate of the input audio data. | |
Defaults to 16000 | |
window_size (float): Size of window for fft in seconds | |
Defaults to 0.02 | |
window_stride (float): Stride of window for fft in seconds | |
Defaults to 0.01 | |
n_window_size (int): Size of window for fft in samples | |
Defaults to None. Use one of window_size or n_window_size. | |
n_window_stride (int): Stride of window for fft in samples | |
Defaults to None. Use one of window_stride or n_window_stride. | |
window (str): Windowing function for fft. can be one of ['hann', | |
'hamming', 'blackman', 'bartlett'] | |
Defaults to "hann" | |
normalize (str): Can be one of ['per_feature', 'all_features']; all | |
other options disable feature normalization. 'all_features' | |
normalizes the entire spectrogram to be mean 0 with std 1. | |
'per_feature' normalizes per channel / freq instead.
Defaults to "per_feature" | |
n_fft (int): Length of FT window. If None, it uses the smallest power | |
of 2 that is larger than n_window_size. | |
Defaults to None | |
preemph (float): Amount of pre-emphasis to apply to the audio. Can be | |
disabled by passing None. | |
Defaults to 0.97 | |
features (int): Number of mel spectrogram freq bins to output. | |
Defaults to 64 | |
lowfreq (int): Lower bound on mel basis in Hz. | |
Defaults to 0 | |
highfreq (int): Upper bound on mel basis in Hz. | |
Defaults to None | |
log (bool): Log features. | |
Defaults to True | |
log_zero_guard_type (str): Strategy for avoiding log of zero; either | |
"add" (add a small value before the log) or "clamp" (clamp to a | |
minimum value before the log). | |
Defaults to "add". | |
log_zero_guard_value (float, or str): The value to add or clamp to. | |
Can be a float, or "tiny" / "eps", in which case the corresponding | |
torch.finfo value is used. | |
Defaults to 2**-24. | |
dither (float): Amount of white-noise dithering. | |
Defaults to 1e-5 | |
pad_to (int): Ensures that the output size of the time dimension is | |
a multiple of pad_to. | |
Defaults to 16 | |
frame_splicing (int): Defaults to 1 | |
exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length | |
// hop_length. Defaults to False. | |
pad_value (float): The value that shorter mels are padded with. | |
Defaults to 0 | |
mag_power (float): The power that the linear spectrogram is raised to | |
prior to multiplication with mel basis. | |
Defaults to 2 for a power spec | |
rng : Random number generator | |
nb_augmentation_prob (float): Probability with which narrowband augmentation is applied to | |
samples in the batch. | |
Defaults to 0.0 | |
nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation. | |
Defaults to 4000 | |
use_torchaudio: Whether to use the `torchaudio` implementation. In this | |
stripped copy the argument is accepted but ignored. | |
mel_norm: Normalization used for mel filterbank weights. | |
Defaults to 'slaney' (area normalization) | |
stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints. | |
stft_conv: Deprecated argument, kept for compatibility with older checkpoints. | |
""" | |
def __init__( | |
self, | |
sample_rate=16000, | |
window_size=0.02, | |
window_stride=0.01, | |
n_window_size=None, | |
n_window_stride=None, | |
window="hann", | |
normalize="per_feature", | |
n_fft=None, | |
preemph=0.97, | |
features=64, | |
lowfreq=0, | |
highfreq=None, | |
log=True, | |
log_zero_guard_type="add", | |
log_zero_guard_value=2**-24, | |
dither=1e-5, | |
pad_to=16, | |
frame_splicing=1, | |
exact_pad=False, | |
pad_value=0, | |
mag_power=2.0, | |
rng=None, | |
nb_augmentation_prob=0.0, | |
nb_max_freq=4000, | |
use_torchaudio: bool = False, | |
mel_norm="slaney", | |
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility | |
stft_conv=False, # Deprecated arguments; kept for config compatibility | |
): | |
super().__init__()  # nn.Module base class takes no arguments | |
self._sample_rate = sample_rate | |
if window_size and n_window_size: | |
raise ValueError( | |
f"{self} received both window_size and " | |
f"n_window_size. Only one should be specified." | |
) | |
if window_stride and n_window_stride: | |
raise ValueError( | |
f"{self} received both window_stride and " | |
f"n_window_stride. Only one should be specified." | |
) | |
if window_size: | |
n_window_size = int(window_size * self._sample_rate) | |
if window_stride: | |
n_window_stride = int(window_stride * self._sample_rate) | |
# This stripped copy always instantiates FilterbankFeatures; the alternative | |
# featurizer selected by `use_torchaudio` is not included here. | |
featurizer_class = FilterbankFeatures | |
self.featurizer = featurizer_class( | |
sample_rate=self._sample_rate, | |
n_window_size=n_window_size, | |
n_window_stride=n_window_stride, | |
window=window, | |
normalize=normalize, | |
n_fft=n_fft, | |
preemph=preemph, | |
nfilt=features, | |
lowfreq=lowfreq, | |
highfreq=highfreq, | |
log=log, | |
log_zero_guard_type=log_zero_guard_type, | |
log_zero_guard_value=log_zero_guard_value, | |
dither=dither, | |
pad_to=pad_to, | |
frame_splicing=frame_splicing, | |
exact_pad=exact_pad, | |
pad_value=pad_value, | |
mag_power=mag_power, | |
rng=rng, | |
nb_augmentation_prob=nb_augmentation_prob, | |
nb_max_freq=nb_max_freq, | |
mel_norm=mel_norm, | |
stft_exact_pad=stft_exact_pad, # Deprecated arguments; kept for config compatibility | |
stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility | |
) | |
def input_example( | |
self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200 | |
): | |
batch_size = torch.randint(low=1, high=max_batch, size=[1]).item() | |
max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item() | |
signals = torch.rand(size=[batch_size, max_length]) * 2 - 1 | |
lengths = torch.randint(low=min_length, high=max_length, size=[batch_size])  # valid lengths must not exceed the signal length | |
lengths[0] = max_length | |
return signals, lengths | |
def forward(self, input_signal, length): | |
# nn.Module has no default forward; define one so the preprocessor can be | |
# called directly (MelSpectrogramProcessor below calls it this way). | |
return self.get_features(input_signal, length) | |
def get_features(self, input_signal, length): | |
return self.featurizer(input_signal, length) | |
@property | |
def filter_banks(self): | |
return self.featurizer.filter_banks | |
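# --- Usage sketch (not part of the original file) ---
# Running the preprocessor on a batch of raw 16 kHz audio with the defaults above
# (64 mel bins, 10 ms hop = 160 samples, frame count padded to a multiple of pad_to=16).
# The tensors below are random placeholders.
#
# preprocessor = AudioToMelSpectrogramPreprocessor(sample_rate=16000, features=64)
# audio = torch.rand(2, 16000) * 2 - 1               # [B, T_audio], roughly 1 s per example
# audio_len = torch.tensor([16000, 12000])           # valid samples per example
# mel, mel_len = preprocessor(input_signal=audio, length=audio_len)
# # mel: [B, 64, T_frames] log-mel features; mel_len: [B] valid frame counts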
class MelSpectrogramProcessor(nn.Module): | |
""" | |
Wrapper interface for computing mel spectrogram for codec training. | |
""" | |
def __init__( | |
self, | |
sample_rate: int, | |
win_length: int, | |
hop_length: int, | |
mel_dim: int = 80, | |
log_guard: float = 1.0, | |
): | |
super(MelSpectrogramProcessor, self).__init__() | |
self.mel_dim = mel_dim | |
self.hop_length = hop_length | |
self.preprocessor = AudioToMelSpectrogramPreprocessor( | |
sample_rate=sample_rate, | |
highfreq=None, | |
features=mel_dim, | |
pad_to=1, | |
exact_pad=True, | |
n_window_size=win_length, | |
n_window_stride=hop_length, | |
window_size=False, | |
window_stride=False, | |
n_fft=win_length, | |
mag_power=1.0, | |
log=True, | |
log_zero_guard_type="add", | |
log_zero_guard_value=log_guard, | |
mel_norm=None, | |
normalize=None, | |
preemph=None, | |
dither=0.0, | |
) | |
def forward(self, audio, audio_len): | |
spec, spec_len = self.preprocessor(input_signal=audio, length=audio_len) | |
return spec, spec_len | |
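# Sketch (not part of the original file): codec-style mel extraction. The
# win_length / hop_length values below are example assumptions, not defaults.
# mel_proc = MelSpectrogramProcessor(sample_rate=22050, win_length=1024, hop_length=256, mel_dim=80)
# spec, spec_len = mel_proc(audio=torch.randn(2, 22050), audio_len=torch.tensor([22050, 11025]))
# # spec: [2, 80, T_frames]; with exact_pad=True the docstring above gives
# # T_frames = audio_length // hop_length = 22050 // 256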
class ResNetEncoder(nn.Module): | |
""" | |
Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing | |
the time dimension. | |
Args: | |
in_channels: input dimension | |
out_channels: output dimension | |
num_layers: number of residual blocks to use | |
hidden_channels: encoder hidden dimension | |
filters: number of filters in residual block layers | |
kernel_size: kernel size in residual block convolutions | |
dropout_rate: Optional dropout rate to apply to residuals. | |
activation: Activation to use, defaults to leaky relu. | |
""" | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 6, | |
hidden_channels: int = 256, | |
filters: int = 768, | |
kernel_size: int = 3, | |
dropout_rate: float = 0.1, | |
activation: str = "lrelu", | |
): | |
super(ResNetEncoder, self).__init__() | |
self.pre_conv = Conv1dNorm( | |
in_channels=in_channels, | |
out_channels=hidden_channels, | |
kernel_size=kernel_size, | |
) | |
self.res_layers = nn.ModuleList( | |
[ | |
ResidualBlock( | |
channels=hidden_channels, | |
filters=filters, | |
kernel_size=kernel_size, | |
dropout_rate=dropout_rate, | |
activation=activation, | |
) | |
for _ in range(num_layers) | |
] | |
) | |
self.post_activation = CodecActivation(activation, channels=hidden_channels) | |
self.post_conv = Conv1dNorm( | |
in_channels=hidden_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
) | |
def remove_weight_norm(self): | |
self.pre_conv.remove_weight_norm() | |
self.post_conv.remove_weight_norm() | |
for res_layer in self.res_layers: | |
res_layer.remove_weight_norm() | |
def forward(self, inputs, input_len): | |
encoded = self.pre_conv(inputs=inputs, input_len=input_len) | |
for res_layer in self.res_layers: | |
encoded = res_layer(inputs=encoded, input_len=input_len) | |
encoded = self.post_activation(encoded) | |
encoded = self.post_conv(inputs=encoded, input_len=input_len) | |
return encoded | |
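# Sketch (not part of the original file): the encoder preserves the time axis,
# mapping [B, in_channels, T] -> [B, out_channels, T]. Sizes below are placeholders.
# enc = ResNetEncoder(in_channels=80, out_channels=32)
# spec = torch.randn(3, 80, 200)
# spec_len = torch.tensor([200, 150, 120])
# encoded = enc(inputs=spec, input_len=spec_len)     # [3, 32, 200]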
class FullBandMelEncoder(nn.Module): | |
""" | |
Encoder which encodes the entire mel spectrogram with a single encoder network. | |
Args: | |
mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from | |
input audio. | |
encoder: ResNetEncoder or equivalent class for encoding the mel spectrogram. | |
""" | |
def __init__(self, mel_processor: nn.Module, encoder: nn.Module): | |
super(FullBandMelEncoder, self).__init__() | |
self.mel_processor = mel_processor | |
self.encoder = encoder | |
def remove_weight_norm(self): | |
self.encoder.remove_weight_norm() | |
def forward(self, audio, audio_len): | |
out, spec_len = self.mel_processor(audio=audio, audio_len=audio_len) | |
encoded = self.encoder(inputs=out, input_len=spec_len) | |
return encoded, spec_len | |
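# Sketch (not part of the original file): composing the mel processor and ResNet
# encoder from the sketches above into a single full-band encoder.
# full_band = FullBandMelEncoder(mel_processor=mel_proc, encoder=enc)
# encoded, encoded_len = full_band(audio=torch.randn(2, 22050), audio_len=torch.tensor([22050, 16000]))
# # encoded: [2, 32, T_frames], one frame per hop of the mel processor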
class MultiBandMelEncoder(nn.Module): | |
""" | |
Encoder which splits mel spectrogram into bands and encodes each using separate residual networks. | |
Args: | |
mel_bands: List of mel spectrogram bands to encode. | |
Each list element is tuple of 2 elements with the start and end index of the mel features to use. | |
mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from | |
input audio. | |
encoder_kwargs: Arguments for constructing encoder for each mel band. | |
""" | |
def __init__( | |
self, | |
mel_bands: Iterable[Tuple[int, int]], | |
mel_processor: nn.Module, | |
**encoder_kwargs, | |
): | |
super(MultiBandMelEncoder, self).__init__() | |
self.validate_mel_bands(mel_dim=mel_processor.mel_dim, mel_bands=mel_bands) | |
self.mel_bands = mel_bands | |
self.mel_processor = mel_processor | |
band_dims = [band[1] - band[0] for band in self.mel_bands] | |
self.encoders = nn.ModuleList( | |
[ | |
ResNetEncoder(in_channels=band_dim, **encoder_kwargs) | |
for band_dim in band_dims | |
] | |
) | |
@staticmethod | |
def validate_mel_bands(mel_dim: int, mel_bands: Iterable[Tuple[int, int]]): | |
mel_dims_used = np.zeros([mel_dim], dtype=bool) | |
for band in mel_bands: | |
mel_dims_used[band[0] : band[1]] = True | |
if not all(mel_dims_used): | |
missing_dims = np.where(~mel_dims_used) | |
raise ValueError( | |
f"Mel bands must cover all {mel_dim} dimensions. Missing {missing_dims}." | |
) | |
return | |
def remove_weight_norm(self): | |
for encoder in self.encoders: | |
encoder.remove_weight_norm() | |
def forward(self, audio, audio_len): | |
spec, spec_len = self.mel_processor(audio=audio, audio_len=audio_len) | |
outputs = [] | |
for (band_start, band_end), encoder in zip(self.mel_bands, self.encoders): | |
# [B, D_band, T] | |
spec_band = spec[:, band_start:band_end, :] | |
band_out = encoder(inputs=spec_band, input_len=spec_len) | |
outputs.append(band_out) | |
# [B, C, T] | |
encoded = torch.cat(outputs, dim=1) | |
return encoded, spec_len |
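# Sketch (not part of the original file): splitting an 80-bin mel spectrogram into
# four contiguous bands. The band boundaries and encoder sizes are example
# assumptions, not NeMo defaults; every mel bin must be covered or
# validate_mel_bands raises a ValueError.
# multi_band = MultiBandMelEncoder(
#     mel_bands=[(0, 20), (20, 40), (40, 60), (60, 80)],
#     mel_processor=mel_proc,
#     out_channels=8, hidden_channels=128, filters=256,
# )
# encoded, encoded_len = multi_band(audio=torch.randn(2, 22050), audio_len=torch.tensor([22050, 16000]))
# # encoded: [2, 32, T_frames], the four per-band outputs (8 channels each) concatenated on dim 1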
import einops | |
import torch | |
import torch.nn as nn | |
activation_registry = { | |
"identity": nn.Identity, | |
"hardtanh": nn.Hardtanh, | |
"relu": nn.ReLU, | |
"selu": nn.SELU, | |
"swish": nn.SiLU, | |
"silu": nn.SiLU, | |
"gelu": nn.GELU, | |
} | |
def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): | |
""" | |
For tensors containing sequences, zero out out-of-bound elements given lengths of every element in the batch. | |
tensor: tensor of shape (B, L), (B, D, L) or (B, D1, D2, L), | |
lengths: LongTensor of shape (B,) | |
""" | |
batch_size, *_, max_lengths = tensor.shape | |
if len(tensor.shape) == 2: | |
mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths) | |
mask = mask <= einops.rearrange(lengths, "B -> B 1") | |
elif len(tensor.shape) == 3: | |
mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths) | |
mask = mask <= einops.rearrange(lengths, "B -> B 1 1") | |
elif len(tensor.shape) == 4: | |
mask = torch.ones(batch_size, 1, 1, max_lengths).cumsum(dim=-1).type_as(lengths) | |
mask = mask <= einops.rearrange(lengths, "B -> B 1 1 1") | |
else: | |
raise ValueError( | |
"Can only mask tensors of shape B x L, B x D x L and B x D1 x D2 x L" | |
) | |
return tensor * mask | |
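# Sketch (not part of the original file): masking a [B, D, L] tensor.
# x = torch.ones(2, 4, 5)
# masked = mask_sequence_tensor(x, torch.tensor([3, 5]))
# # masked[0, :, 3:] is all zeros; masked[1] is unchanged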
class ClampActivation(nn.Module): | |
def __init__(self, min_value: float = -1.0, max_value: float = 1.0): | |
super().__init__() | |
self.min_value = min_value | |
self.max_value = max_value | |
def forward(self, input: torch.Tensor) -> torch.Tensor: | |
return torch.clamp(input, min=self.min_value, max=self.max_value) | |
@torch.jit.script | |
def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: | |
""" | |
equation for snake activation function: x + (alpha + eps)^-1 * sin(alpha * x)^2 | |
""" | |
shape = x.shape | |
x = x.reshape(shape[0], shape[1], -1) | |
x = x + (alpha + eps).reciprocal() * torch.sin(alpha * x).pow(2) | |
x = x.reshape(shape) | |
return x | |
class Snake(nn.Module): | |
""" | |
Snake activation function introduced in 'https://arxiv.org/abs/2006.08195' | |
""" | |
def __init__(self, channels: int): | |
super().__init__() | |
self.alpha = nn.Parameter(torch.ones(1, channels, 1)) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
return snake(x, self.alpha) | |
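# Sketch (not part of the original file): Snake acts per channel with a learnable
# alpha (initialised to 1): y = x + (alpha + eps)^-1 * sin(alpha * x)^2.
# act = Snake(channels=16)
# y = act(torch.randn(4, 16, 100))   # [4, 16, 100], same shape as the input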
class HalfSnake(nn.Module): | |
""" | |
Activation which applies snake to the first half of input elements and leaky relu to the second half. | |
""" | |
def __init__(self, channels: int): | |
super().__init__() | |
self.snake_channels = channels // 2 | |
self.snake_act = Snake(self.snake_channels) | |
self.lrelu = torch.nn.LeakyReLU() | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
snake_out = self.snake_act(x[:, : self.snake_channels, :]) | |
lrelu_out = self.lrelu(x[:, self.snake_channels :, :]) | |
out = torch.cat([snake_out, lrelu_out], dim=1) | |
return out |
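# Sketch (not part of the original file): with channels=7 the snake half gets
# channels // 2 = 3 channels and leaky ReLU the remaining 4.
# act = HalfSnake(channels=7)
# y = act(torch.randn(2, 7, 50))   # [2, 7, 50]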