Last active
May 2, 2023 03:25
-
-
Save sourabh2k15/80adbf1c5861e727f7698fd66e51be39 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#@title FFT Layer
"""Flax layer to perform preprocessing on librispeech audio inputs.

This layer computes a windowed short-time Fourier transform over audio signals,
converts the result to the mel scale, takes the logarithm of the resulting
mel spectrograms, and normalizes them for use in speech recognition models.

This code is based on lingvo's librispeech preprocessing code here:
https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/asr/frontend.py
"""
from typing import Any, Optional, Sequence, Union

from flax import linen as nn
from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
# mel spectrum constants. | |
_MEL_BREAK_FREQUENCY_HERTZ = 700.0 | |
_MEL_HIGH_FREQUENCY_Q = 1127.0 | |
# 80 precomputed mean values, one per feature bin.
# NOTE(review): presumably the per-mel-bin mean of LibriSpeech log-mel
# features, meant to be passed as `MelFilterbankFrontend.per_bin_mean`
# (80 entries matches the default `num_bins`) - confirm against the caller.
LIBRISPEECH_MEAN_VECTOR = [
    -7.6047816276550293, -7.1206226348876953,
    -6.8864245414733887, -6.8705768585205078,
    -6.9667720794677734, -7.1084094047546387,
    -6.9528026580810547, -6.783994197845459,
    -6.6195521354675293, -6.4876265525817871,
    -6.4120659828186035, -6.394047737121582,
    -6.4244871139526367, -6.3993711471557617,
    -6.5158271789550781, -6.7137999534606934,
    -6.8476877212524414, -6.9885001182556152,
    -6.9221386909484863, -7.146148681640625,
    -7.2040400505065918, -7.0537552833557129,
    -7.3140382766723633, -7.1223249435424805,
    -7.30251407623291, -7.1212143898010254,
    -7.2425732612609863, -7.1730537414550781,
    -7.0979413986206055, -7.088747501373291,
    -6.9849910736083984, -6.8787732124328613,
    -6.7602753639221191, -6.6300945281982422,
    -6.5145769119262695, -6.4245057106018066,
    -6.356513500213623, -6.31787633895874,
    -6.2660770416259766, -6.2468328475952148,
    -6.2821526527404785, -6.1908388137817383,
    -6.2484354972839355, -6.1472640037536621,
    -6.0924725532531738, -6.0171003341674805,
    -5.9250402450561523, -5.8535833358764648,
    -5.8209109306335449, -5.8118929862976074,
    -5.80783748626709, -5.7714629173278809,
    -5.7453732490539551, -5.7705655097961426,
    -5.7765641212463379, -5.7831673622131348,
    -5.7954087257385254, -5.7994823455810547,
    -5.8023476600646973, -5.8047118186950684,
    -5.8168182373046875, -5.8844799995422363,
    -5.9727106094360352, -6.0444660186767578,
    -6.1284866333007812, -6.2257585525512695,
    -6.3157496452331543, -6.39061164855957,
    -6.4928598403930664, -6.5498456954956055,
    -6.6054320335388184, -6.6508378982543945,
    -6.66917610168457, -6.6726889610290527,
    -6.684234619140625, -6.6974577903747559,
    -6.75471830368042, -6.7949142456054688,
    -6.8634209632873535, -6.94186544418335,
]
# 80 precomputed standard-deviation values, one per feature bin.
# NOTE(review): presumably the per-mel-bin stddev of LibriSpeech log-mel
# features, meant to be passed as `MelFilterbankFrontend.per_bin_stddev`
# (80 entries matches the default `num_bins`) - confirm against the caller.
LIBRISPEECH_STD_VECTOR = [
    3.4353282451629639, 3.5962932109832764,
    3.7012472152709961, 3.7369205951690674,
    3.7535104751586914, 3.693629264831543,
    3.6922497749328613, 3.7641522884368896,
    3.8419716358184814, 3.8999848365783691,
    3.9294240474700928, 3.9317409992218018,
    3.9139585494995117, 3.9031598567962646,
    3.8691999912261963, 3.8155081272125244,
    3.7644970417022705, 3.7099106311798096,
    3.6965086460113525, 3.6003766059875488,
    3.5493226051330566, 3.5465121269226074,
    3.45003604888916, 3.4712812900543213,
    3.4084610939025879, 3.4408135414123535,
    3.4104881286621094, 3.4217638969421387,
    3.4312851428985596, 3.4199209213256836,
    3.4305806159973145, 3.4382665157318115,
    3.4580366611480713, 3.4817991256713867,
    3.4958710670471191, 3.5036792755126953,
    3.5047574043273926, 3.4988734722137451,
    3.493056058883667, 3.4822943210601807,
    3.459430456161499, 3.4612770080566406,
    3.4559063911437988, 3.4755423069000244,
    3.4971549510955811, 3.5326557159423828,
    3.5705199241638184, 3.5920312404632568,
    3.596907377243042, 3.5913500785827637,
    3.5865931510925293, 3.5826809406280518,
    3.5837743282318115, 3.5895791053771973,
    3.5819313526153564, 3.5837869644165039,
    3.5861184597015381, 3.5889589786529541,
    3.592214822769165, 3.5939455032348633,
    3.5856630802154541, 3.5884113311767578,
    3.5921022891998291, 3.5870490074157715,
    3.5806570053100586, 3.5731067657470703,
    3.5617532730102539, 3.54980731010437,
    3.5527374744415283, 3.5475366115570068,
    3.5387849807739258, 3.5256178379058838,
    3.5031836032867432, 3.4922726154327393,
    3.4879646301269531, 3.4725594520568848,
    3.4558389186859131, 3.4351828098297119,
    3.4284293651580811, 3.4299170970916748,
]
@struct.dataclass
class LibrispeechPreprocessingConfig:
  """Bundle of options controlling librispeech audio preprocessing.

  The fields are read by `SpectrogramFrontend` and `MelFilterbankFrontend`.
  """
  # --- Framing -------------------------------------------------------------
  # Samples per second of the input audio.
  sample_rate: float = 16000.0
  # Frame length and hop in milliseconds; converted to sample counts in
  # `SpectrogramFrontend.setup`.
  frame_size_ms: float = 25.0
  frame_step_ms: float = 10.0
  # --- Spectrogram ---------------------------------------------------------
  # If True the magnitude spectrum is squared.
  compute_energy: bool = True
  # One of 'HANNING', 'HANNING_GRECO' (case-insensitive) or None for no window.
  window_fn: str = 'HANNING'
  # Floor applied before the log when `SpectrogramFrontend.output_log` is set.
  output_log_floor: float = 1.0
  # If True the signal (and its paddings) are extended at the end before
  # framing.
  pad_end: bool = False
  # Pre-emphasis coefficient; 0.0 disables pre-emphasis.
  preemph: float = 0.97
  # Selects which of the two pre-emphasis formulas in
  # `SpectrogramFrontend._apply_preemphasis` is used.
  preemph_htk_flavor: bool = True
  # Scale of gaussian noise added to each frame; values <= 0.0 add no noise.
  noise_scale: float = 0.0
  # --- Mel filterbank ------------------------------------------------------
  num_bins: int = 80
  lower_edge_hertz: float = 125.0
  upper_edge_hertz: float = 7600.0
  # NOTE(review): not read by any code visible in this file.
  fft_overdrive: bool = False
  # Floor applied to the mel spectrogram before taking its log.
  output_floor: float = 0.000010
def _hertz_to_mel(frequencies_hertz):
  """Maps frequencies in hertz onto the mel scale.

  Args:
    frequencies_hertz: A scalar or array of frequencies in hertz.

  Returns:
    `_MEL_HIGH_FREQUENCY_Q * log(1 + f / _MEL_BREAK_FREQUENCY_HERTZ)`,
    evaluated elementwise with `jnp.log`.
  """
  relative_frequency = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
  return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + relative_frequency)
def _pad_end_length(num_timesteps, frame_step, frame_size): | |
"""Returns how many sample needed to be padded for pad_end feature.""" | |
# The number of frames that can be extracted from the signal. | |
num_frames = int(np.ceil(num_timesteps / frame_step)) | |
# Signal length required for computing `num_frames` frames. | |
padded_length = frame_step * (num_frames - 1) + frame_size | |
return padded_length - num_timesteps | |
def frame(x,
          frame_length: int,
          frame_step: int,
          pad_end: bool = False,
          pad_value: Union[int, float] = 0.0):
  """Cuts a batched signal into overlapping fixed-length frames.

  Extracts `x[:, n:n+frame_length, :]` for window starts `n` that advance by
  `frame_step`. Unlike `tf.signal.frame` there is no `axis` argument; the
  input must be laid out as `(batch_size, timesteps, channels)`.

  Args:
    x: An input array with `(batch_size, timesteps, channels)`-shape.
    frame_length: The frame length.
    frame_step: The frame hop size.
    pad_end: If True, the end of the signal is padded so the window can keep
      sliding while its starting point is still inside the original signal.
    pad_value: A scalar used as the padding value when `pad_end` is True.

  Returns:
    An array with shape `(batch_size, num_frames, frame_length, num_channels)`.
  """
  _, total_steps, channel_count = x.shape

  if pad_end:
    tail = _pad_end_length(total_steps, frame_step, frame_length)
    time_padding = ((0, 0), (0, tail), (0, 0))
    x = jnp.pad(x, time_padding, 'constant', constant_values=pad_value)

  # Patch extraction flattens each window into the last axis.
  patches = jax.lax.conv_general_dilated_patches(
      x, (frame_length,), (frame_step,),
      'VALID',
      dimension_numbers=('NTC', 'OIT', 'NTC'))

  # Un-flatten to (..., channels, frame_length), then move channels last.
  unflattened_shape = patches.shape[:-1] + (channel_count, frame_length)
  per_channel = patches.reshape(unflattened_shape)
  return jnp.transpose(per_channel, (0, 1, 3, 2))
def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
                                num_spectrogram_bins: int = 129,
                                sample_rate: Union[int, float] = 8000,
                                lower_edge_hertz: Union[int, float] = 125.0,
                                upper_edge_hertz: Union[int, float] = 3800.0,
                                dtype: Any = jnp.float32):
  r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`.

  Builds a matrix of triangular filters, linear in the mel domain, that maps
  a linear-frequency spectrogram to mel bands via a matrix product.

  Args:
    num_mel_bins: Python int. How many bands in the resulting mel spectrum.
    num_spectrogram_bins: Python int. How many bins there are in the source
      spectrogram data, which is understood to be `fft_size // 2 + 1`, i.e.
      the spectrogram only contains the nonredundant FFT bins.
    sample_rate: Python int or float. Samples per second of the input signal
      used to create the spectrogram. Used to figure out the frequencies
      corresponding to each spectrogram bin, which dictates how they are
      mapped into the mel scale.
    lower_edge_hertz: Python float. Lower bound on the frequencies to be
      included in the mel spectrum. This corresponds to the lower edge of the
      lowest triangular band.
    upper_edge_hertz: Python float. The desired top edge of the highest
      frequency band.
    dtype: The dtype of the result matrix. Must be a floating point type.

  Returns:
    An array of shape `[num_spectrogram_bins, num_mel_bins]`.

  Raises:
    ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
      positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
      ordered, `upper_edge_hertz` is larger than the Nyquist frequency.

  [mel]: https://en.wikipedia.org/wiki/Mel_scale
  """
  # Input validator from tensorflow/python/ops/signal/mel_ops.py#L71
  if num_mel_bins <= 0:
    raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
  # The docstring promises this check; without it a zero bin count silently
  # produced a degenerate matrix.
  if num_spectrogram_bins <= 0:
    raise ValueError('num_spectrogram_bins must be positive. Got: %s' %
                     num_spectrogram_bins)
  if lower_edge_hertz < 0.0:
    raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
                     lower_edge_hertz)
  if lower_edge_hertz >= upper_edge_hertz:
    raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
                     (lower_edge_hertz, upper_edge_hertz))
  if sample_rate <= 0.0:
    raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
  if upper_edge_hertz > sample_rate / 2:
    raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
                     'frequency (sample_rate / 2). Got %s for sample_rate: %s' %
                     (upper_edge_hertz, sample_rate))

  # HTK excludes the spectrogram DC bin.
  bands_to_zero = 1
  nyquist_hertz = sample_rate / 2.0
  linear_frequencies = jnp.linspace(
      0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:]
  # Column vector so it broadcasts against the [1, num_mel_bins] edges below.
  spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis]

  # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
  # center of each band is the lower and upper edge of the adjacent bands.
  # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
  # num_mel_bins + 2 pieces.
  edges = jnp.linspace(
      _hertz_to_mel(lower_edge_hertz),
      _hertz_to_mel(upper_edge_hertz),
      num_mel_bins + 2,
      dtype=dtype)

  # Split the triples up and reshape them into [1, num_mel_bins] tensors.
  lower_edge_mel = edges[:-2][jnp.newaxis, :]
  center_mel = edges[1:-1][jnp.newaxis, :]
  upper_edge_mel = edges[2:][jnp.newaxis, :]

  # Calculate lower and upper slopes for every spectrogram bin.
  # Line segments are linear in the mel domain, not Hertz.
  lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
      center_mel - lower_edge_mel)
  upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
      upper_edge_mel - center_mel)

  # Intersect the line segments with each other and zero.
  mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes))

  # Re-add the zeroed lower bins we sliced out above.
  return jnp.pad(mel_weights_matrix, [[bands_to_zero, 0], [0, 0]])
def _hanning_greco(win_support, frame_size, dtype): | |
"""Add a greco-style hanning window to the graph. | |
Note that the Hanning window in Wikipedia is not the same as the Hanning | |
window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia | |
page would indicate. Talkin's explanation was that it was like wasting two | |
samples to have the values at the edge of the window to be 0.0 exactly. | |
Args: | |
win_support: Number of samples for non-zero support in the window | |
frame_size: Total size of the window (frame_size >= win_support) | |
dtype: TF data type | |
Returns: | |
Tensor of size frame_size with the window to apply. | |
""" | |
if frame_size < win_support: | |
raise ValueError( | |
'Provided frame_size = {} is lower than win_support = {}'.format( | |
frame_size, win_support)) | |
arg = jnp.pi * 2.0 / (win_support) | |
hann = 0.5 - (0.5 * jnp.cos(arg * | |
(jnp.arange(win_support, dtype=dtype) + 0.5))) | |
zero_size = frame_size - win_support | |
return jnp.pad(hann, [(0, zero_size)]) | |
def _next_pow_of_two(x: Union[int, float]) -> int: | |
return int(2**np.ceil(np.log2(x))) | |
class SpectrogramFrontend(nn.Module):
  """Turns time-domain audio into a framed, windowed magnitude spectrum.

  Attributes:
    config: Preprocessing options (frame sizes, window, pre-emphasis, ...).
    input_scale_factor: Multiplier applied to the audio before framing.
    output_log: If True, returns the log of the spectrum, floored at
      `config.output_log_floor`.
  """
  config: LibrispeechPreprocessingConfig = None
  input_scale_factor: float = 1.0
  output_log: bool = False

  def setup(self) -> None:
    """Derives frame sizes in samples and selects the window function."""
    cfg = self.config
    self._frame_step = int(round(cfg.sample_rate * cfg.frame_step_ms / 1000.0))
    # One extra sample per frame is consumed by the pre-emphasis step.
    self._frame_size = 1 + int(
        round(cfg.sample_rate * cfg.frame_size_ms / 1000.0))
    # TF-version has maximum of 512, but it's not always necessary
    self.fft_size = _next_pow_of_two(self._frame_size)

    if cfg.window_fn is None:
      self._window_fn = None
      return

    window_name = cfg.window_fn.upper()
    if window_name == 'HANNING':

      def _hanning_window(frame_size, dtype):
        # Odd sizes use the symmetric window directly.
        if frame_size % 2 != 0:
          return jnp.hanning(frame_size).astype(dtype)
        # Even sizes: build a 1-point longer window and drop the last point
        # to simulate periodic=True in tf.signal.hann_window.
        return jnp.hanning(frame_size + 1).astype(dtype)[:-1]

      self._window_fn = _hanning_window
    elif window_name == 'HANNING_GRECO':

      # Greco-compatible hanning window; the support is looked up when the
      # window is requested, not here.
      def _greco_window(frame_size, dtype):
        return _hanning_greco(self._frame_size - 1, frame_size, dtype)

      self._window_fn = _greco_window
    else:
      raise ValueError('Illegal value %r for window_fn param' % cfg.window_fn)

  def _apply_preemphasis(self, framed_signal):
    """Applies pre-emphasis along the frame axis, dropping one sample.

    Args:
      framed_signal: Array indexed as `[batch, frame, sample, channel]`.

    Returns:
      An array whose sample axis is one shorter than the input's.
    """
    coeff = self.config.preemph
    if not self.config.preemph_htk_flavor:
      return (framed_signal[:, :, 1:, :] -
              coeff * framed_signal[:, :, :-1, :])

    # HTK flavor: the first sample is only scaled; the last input sample is
    # never used.
    first_sample = framed_signal[:, :, :1, :] * (1. - coeff)
    remaining = (framed_signal[:, :, 1:-1, :] -
                 coeff * framed_signal[:, :, :-2, :])
    return jnp.concatenate([first_sample, remaining], axis=2)

  def fprop_paddings(self, input_paddings):
    """Converts per-sample paddings into per-frame paddings.

    A frame's padding is the minimum over the samples it covers.

    Args:
      input_paddings: Array indexed as `[batch, timesteps]`.

    Returns:
      Array indexed as `[batch, num_frames]`.
    """
    if self.config.pad_end:
      extra = _pad_end_length(input_paddings.shape[1],
                              self._frame_step,
                              self._frame_size)
      # Samples appended by `pad_end` count as padding.
      input_paddings = jnp.pad(
          input_paddings, ((0, 0), (0, extra)), constant_values=1.0)

    return jax.lax.reduce_window(
        input_paddings,
        init_value=1.0,
        computation=jax.lax.min,
        window_dimensions=[1, self._frame_size],
        window_strides=[1, self._frame_step],
        padding='valid')

  def next_prng_key(self, name='dropout'):
    """Returns a fresh PRNG key from the module's `name` rng stream."""
    return self.make_rng(name)

  @nn.compact
  def __call__(self, inputs, input_paddings):
    """Computes the spectrum of `inputs`.

    Args:
      inputs: Audio as `[batch, timesteps]` or `[batch, timesteps, channels]`.
      input_paddings: `[batch, timesteps]` paddings, or None.

    Returns:
      A `(spectrum, output_paddings)` pair; `output_paddings` is None when
      `input_paddings` is None.
    """
    cfg = self.config
    signal = inputs.astype(jnp.float32)

    # Expand to have a channel axis
    if signal.ndim == 2:
      signal = jnp.expand_dims(signal, -1)

    if input_paddings is None:
      output_paddings = None
    else:
      # Zero out padded samples before any processing.
      signal = signal * jnp.expand_dims(1.0 - input_paddings, -1)
      output_paddings = self.fprop_paddings(input_paddings)

    pcm_audio_chunk = signal.astype(jnp.float32) * self.input_scale_factor
    framed_signal = frame(
        pcm_audio_chunk,
        self._frame_size,
        self._frame_step,
        pad_end=cfg.pad_end)

    if cfg.preemph != 0.0:
      preemphasized = self._apply_preemphasis(framed_signal)
    else:
      # Still drop the extra sample reserved for pre-emphasis.
      preemphasized = framed_signal[..., :-1, :]

    if cfg.noise_scale > 0.0:
      noise_signal = cfg.noise_scale * jax.random.normal(
          self.next_prng_key(), preemphasized.shape)
    else:
      noise_signal = jnp.zeros(preemphasized.shape)
    windowed_signal = preemphasized + noise_signal

    # Window here
    if self._window_fn is not None:
      window_length = self._frame_size - 1
      window = self._window_fn(window_length, framed_signal.dtype)
      windowed_signal = windowed_signal * window.reshape(
          (1, 1, window_length, 1))

    spectrum = jnp.abs(jnp.fft.rfft(windowed_signal, self.fft_size, axis=2))
    if cfg.compute_energy:
      spectrum = spectrum**2.0

    if not self.output_log:
      return spectrum, output_paddings
    log_spectrum = jnp.log(jnp.maximum(spectrum, cfg.output_log_floor))
    return log_spectrum, output_paddings
class MelFilterbankFrontend(nn.Module):
  """Layer to compute normalized log mel spectrograms from audio signals.

  Attributes:
    config: Preprocessing options shared with the inner `SpectrogramFrontend`.
    use_divide_stream: If True, the audio is scaled by `2**-15` before the
      spectrogram is computed; otherwise it is used unscaled.
    per_bin_mean: Optional sequence with one mean per mel bin, subtracted from
      the log mel spectrogram. Defaults to all zeros.
    per_bin_stddev: Optional sequence with one standard deviation per mel bin,
      used to divide the centered log mel spectrogram. Defaults to all ones.
  """
  config: LibrispeechPreprocessingConfig = None
  use_divide_stream: bool = True
  # These are indexed as per-bin vectors in `setup`, so they are sequences
  # rather than the scalar floats the annotation previously claimed.
  per_bin_mean: Optional[Sequence[float]] = None
  per_bin_stddev: Optional[Sequence[float]] = None

  def setup(self):
    """Creates the spectrogram sub-layer and the normalization constants."""
    p = self.config
    input_scale_factor = 2**-15 if self.use_divide_stream else 1.0
    self.stft = SpectrogramFrontend(
        p, input_scale_factor=input_scale_factor, output_log=False)

    if self.per_bin_mean is None:
      per_bin_mean = [0.0] * p.num_bins
    else:
      per_bin_mean = self.per_bin_mean

    if self.per_bin_stddev is None:
      per_bin_stddev = [1.0] * p.num_bins
    else:
      per_bin_stddev = self.per_bin_stddev

    # Reshape to [1, 1, num_bins, 1] so the statistics broadcast against the
    # 'btnc'-indexed mel spectrogram computed in `__call__`.
    self._normalizer_mean = jnp.array(per_bin_mean)[
        jnp.newaxis, jnp.newaxis, :, jnp.newaxis]
    self._normalizer_stddev = jnp.array(per_bin_stddev)[
        jnp.newaxis, jnp.newaxis, :, jnp.newaxis]

  @nn.compact
  def __call__(self, inputs, input_paddings):
    """Computes normalized log mel features.

    Args:
      inputs: Audio signal passed through to the inner `SpectrogramFrontend`.
      input_paddings: Per-sample paddings for `inputs`, or None.

    Returns:
      A `(features, paddings)` pair: the normalized log mel spectrogram with
      its trailing axis squeezed away, and the per-frame paddings returned by
      the spectrogram layer.
    """
    p = self.config
    spect, spect_paddings = self.stft(inputs, input_paddings)

    mel_weights = linear_to_mel_weight_matrix(
        num_mel_bins=p.num_bins,
        num_spectrogram_bins=spect.shape[2],
        sample_rate=p.sample_rate,
        lower_edge_hertz=p.lower_edge_hertz,
        upper_edge_hertz=p.upper_edge_hertz)

    # Project the frequency axis (f) onto the mel bins (n).
    mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect)
    # Floor before the log so that empty bins do not produce -inf.
    logmel_spectrogram = jnp.log(jnp.maximum(mel_spectrogram, p.output_floor))

    normalized_logmel_spectrogram = (
        (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev)

    normalized_logmel_spectrogram = jnp.squeeze(normalized_logmel_spectrogram,
                                                -1)
    return normalized_logmel_spectrogram, spect_paddings
#@title SpecAug Layer
"""A flax layer to do data augmentation for audio signals as
described in https://arxiv.org/abs/1904.08779.

Code based on:
github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/spectrum_augmenter.py
"""
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
class SpecAug(nn.Module):
  """Layer performs masking procedure along time and frequency axis.

  The procedure is detailed in https://arxiv.org/abs/1904.08779.
  This is an essential component in speech recognition models that helps
  achieve better word error rates.

  Attributes:
    freq_mask_count: Number of frequency masks per example.
    freq_mask_max_bins: Upper bound on the width of one frequency mask.
    time_mask_count: Number of time masks per example.
    time_mask_max_frames: Upper bound on the length of one time mask; ignored
      when `use_dynamic_time_mask_max_frames` is True.
    time_mask_max_ratio: Upper bound on a time mask's length as a fraction of
      the example's unpadded length.
    time_masks_per_frame: If positive, only the first
      `time_masks_per_frame * length` of the sampled time masks are applied.
    use_dynamic_time_mask_max_frames: If True, the time mask length is bounded
      by `time_mask_max_ratio` only.
  """
  freq_mask_count: int = 2
  freq_mask_max_bins: int = 27
  time_mask_count: int = 10
  time_mask_max_frames: int = 40
  time_mask_max_ratio: float = 0.05
  time_masks_per_frame: float = 0.0
  use_dynamic_time_mask_max_frames: bool = True

  def next_prng_key(self, name='dropout'):
    """Returns a fresh PRNG key from the module's `name` rng stream."""
    return self.make_rng(name)

  def _get_mask(self,
                batch_size,
                choose_range,
                mask_size,
                max_length=None,
                masks_per_frame=0.0,
                multiplicity=1,
                max_ratio=1.0):
    """Samples a multiplicative mask along one axis.

    Args:
      batch_size: Number of examples.
      choose_range: `[batch_size]` array with the valid extent of each example
        along the masked axis.
      mask_size: Full size of the masked axis.
      max_length: Optional fixed upper bound on a single mask's length. When
        falsy, `choose_range * max_ratio` is used instead.
      masks_per_frame: If positive, the number of masks applied to an example
        scales with its `choose_range`.
      multiplicity: Number of masks sampled per example.
      max_ratio: Upper bound on a mask's length relative to `choose_range`.

    Returns:
      A `[batch_size, mask_size]` array that is 0.0 inside masked spans and
      1.0 elsewhere.
    """
    per_mask_shape = (batch_size, multiplicity)

    # Sample lengths for multiple masks.
    if max_length and max_length > 0:
      length_cap = jnp.tile(max_length, (batch_size,))
    else:
      length_cap = choose_range * max_ratio
    masked_portion = jax.random.uniform(
        key=self.next_prng_key(),
        shape=per_mask_shape,
        minval=0.0,
        maxval=1.0)
    masked_frame_size = jnp.einsum('b,bm->bm', length_cap,
                                   masked_portion).astype(jnp.int32)

    # Make sure the sampled length was smaller than max_ratio * length_bound.
    # Note that sampling in this way was biased
    # (shorter sequence may over-masked.)
    choose_range = jnp.tile(choose_range[:, None], [1, multiplicity])
    length_bound = (max_ratio * choose_range).astype(jnp.int32)
    length = jnp.minimum(masked_frame_size, jnp.maximum(length_bound, 1))

    # Choose starting point.
    random_start = jax.random.uniform(
        key=self.next_prng_key(), shape=per_mask_shape, maxval=1.0)
    start = (random_start * (choose_range - length + 1)).astype(jnp.int32)
    end = start + length - 1

    # Shift starting and end point by small value so that the strict
    # comparisons below include both integer end points.
    delta = 0.1
    lower = jnp.tile(jnp.expand_dims(start - delta, -1), [1, 1, mask_size])
    upper = jnp.tile(jnp.expand_dims(end + delta, -1), [1, 1, mask_size])

    # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
    positions = jnp.expand_dims(jnp.expand_dims(jnp.arange(mask_size), 0), 0)
    positions = jnp.tile(positions, [batch_size, multiplicity, 1])
    pre_mask = jnp.minimum(positions < upper, positions > lower)

    # Sum masks with appropriate multiplicity.
    if masks_per_frame > 0:
      mask_index = jnp.tile(
          jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
          [batch_size, 1])
      masks_to_apply = masks_per_frame * choose_range
      multiplicity_weights = (mask_index < masks_to_apply).astype(jnp.int32)
      pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
    else:
      pre_mask = jnp.einsum('bmt->bt', pre_mask)

    return 1.0 - (pre_mask > 0).astype(jnp.int32)

  def _time_mask(self, inputs, length):
    """Masks random spans of `inputs` along the time axis (axis 1).

    Args:
      inputs: Array indexed as `[batch, time, feature]`.
      length: `[batch]` array with the unpadded length of each example.

    Returns:
      `inputs` with the sampled time spans zeroed, or `inputs` unchanged when
      time masking is disabled by the attributes.
    """
    dynamic_max = self.use_dynamic_time_mask_max_frames
    max_frames = self.time_mask_max_frames
    mask_count = self.time_mask_count
    max_ratio = self.time_mask_max_ratio

    # If maximum mask length or the mask count is zero, do nothing.
    fixed_budget_is_zero = max_frames == 0 and not dynamic_max
    if fixed_budget_is_zero or max_ratio <= 0.0 or mask_count == 0:
      return inputs

    batch_size, time_length, _ = inputs.shape

    # When using dynamic time mask size, discard upper-bound on
    # maximum allowed frames for time mask.
    max_length = None if dynamic_max else max_frames

    # Create masks in time direction and apply.
    time_mask = self._get_mask(
        batch_size,
        choose_range=length,
        mask_size=time_length,
        max_length=max_length,
        masks_per_frame=self.time_masks_per_frame,
        multiplicity=mask_count,
        max_ratio=max_ratio)
    return jnp.einsum('bxy,bx->bxy', inputs, time_mask)

  def _frequency_mask(self, inputs):
    """Masks random bands of `inputs` along the feature axis (axis 2).

    Args:
      inputs: Array indexed as `[batch, time, feature]`.

    Returns:
      `inputs` with the sampled frequency bands zeroed, or `inputs` unchanged
      when frequency masking is disabled by the attributes.
    """
    max_bins = self.freq_mask_max_bins
    mask_count = self.freq_mask_count

    # If masking length or count is zero, do nothing.
    if max_bins == 0 or mask_count == 0:
      return inputs

    # Every example may be masked anywhere along the full feature axis.
    batch_size, _, num_freq = inputs.shape
    choose_range = jnp.tile(num_freq, (batch_size,))

    # Create masks in frequency direction and apply.
    freq_mask = self._get_mask(
        batch_size,
        choose_range=choose_range,
        mask_size=num_freq,
        max_length=max_bins,
        masks_per_frame=0.0,
        multiplicity=mask_count,
        max_ratio=1.0)
    return jnp.einsum('bxy,by->bxy', inputs, freq_mask)

  @nn.compact
  def __call__(self, inputs, paddings):
    """Applies time masking followed by frequency masking.

    Args:
      inputs: Array indexed as `[batch, time, feature]`.
      paddings: `[batch, time]` array; `1 - paddings` is summed over time to
        obtain each example's length.

    Returns:
      A `(masked_inputs, paddings)` pair; `paddings` is returned unchanged.
    """
    lengths = jnp.einsum('bh->b', 1 - paddings).astype(jnp.int32)
    masked = self._time_mask(inputs, lengths)
    masked = self._frequency_mask(masked)
    return masked, paddings
#@title CudnnLSTM Layer | |
from typing import Any, Optional, Sequence, Tuple, Union | |
from flax import linen as nn | |
import jax | |
from jax.experimental import rnn | |
import jax.numpy as jnp | |
import numpy as np | |
# Type aliases used by the annotations in the code that follows.
Array = jnp.ndarray
# A recurrent state is either a single array or a tuple of arrays.
StateType = Union[Array, Tuple[Array, ...]]
PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
class CudnnLSTM(nn.Module):
  """Multi-layer LSTM that delegates to `jax.experimental.rnn.lstm`.

  NOTE(review): the class name suggests the underlying kernel is
  cuDNN-backed; confirm device requirements in the JAX documentation.

  Attributes:
    input_size: Feature dimension of the inputs.
    hidden_size: Hidden state dimension of each LSTM layer.
    num_layers: Number of stacked LSTM layers.
    dropout_rate: Dropout rate passed to the LSTM when not deterministic.
    bidirectional: If True, runs the LSTM in both directions.
  """
  input_size: int
  hidden_size: int
  num_layers: int
  dropout_rate: float = 0.0
  bidirectional: bool = False

  def setup(self):
    """Creates the single weight parameter consumed by `rnn.lstm`."""
    self.w = self.param(
        'lstm_weights',
        rnn.init_lstm_weight,
        self.input_size,
        self.hidden_size,
        self.num_layers,
        self.bidirectional,
    )

  def __call__(
      self,
      inputs: Array,
      input_paddings: Array,
      initial_states: Optional[Sequence[StateType]] = None,
      deterministic: bool = False,
  ) -> Array:
    """Runs the LSTM over a padded batch of sequences.

    Args:
      inputs: Input sequences; the leading axis is the batch axis.
      input_paddings: Paddings for `inputs`; `1 - input_paddings` is summed
        over the last axis to obtain each sequence's length.
      initial_states: Not supported yet; must be None.
      deterministic: If True, dropout is disabled.

    Returns:
      The LSTM outputs. The final hidden and cell states returned by
      `rnn.lstm` are discarded.
    """
    # TODO(zhangqiaorjc): initial_states
    assert initial_states is None, 'initial_states is not supported yet.'

    num_directions = 2 if self.bidirectional else 1
    batch_size = inputs.shape[0]
    dropout = 0.0 if deterministic else self.dropout_rate

    # Zero initial hidden/cell state for every (direction, layer) pair.
    state_shape = (
        num_directions * self.num_layers, batch_size, self.hidden_size)
    h_0 = jnp.zeros(state_shape, jnp.float32)
    c_0 = jnp.zeros(state_shape, jnp.float32)
    seq_lengths = jnp.sum(1.0 - input_paddings, axis=-1, dtype=jnp.int32)

    # Positional arguments, in order: x, h_0, c_0, weights, seq_lengths,
    # input_size, hidden_size, num_layers, dropout, bidirectional.
    y, _, _ = rnn.lstm(
        inputs,
        h_0,
        c_0,
        self.w,
        seq_lengths,
        self.input_size,
        self.hidden_size,
        self.num_layers,
        dropout,
        self.bidirectional,
    )

    return y
#@title Deepspeech Model
r"""Deepspeech.

This model uses a deepspeech2 network to convert speech to text.
paper : https://arxiv.org/abs/1512.02595

# BiLSTM code contributed by bastings@
# github : https://github.com/bastings
# webpage : https://bastings.github.io/
"""
import functools | |
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Type, Union | |
import flax | |
from flax import linen as nn | |
from flax import struct | |
import jax | |
import jax.numpy as jnp | |
# Type aliases used by the annotations in the model code below.
Array = jnp.ndarray
# A recurrent state is either a single array or a tuple of arrays.
StateType = Union[Array, Tuple[Array, ...]]
PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
@struct.dataclass
class DeepspeechConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  # --- Model size ----------------------------------------------------------
  vocab_size: int = 1024
  dtype: Any = jnp.float32
  encoder_dim: int = 512
  num_lstm_layers: int = 6
  num_ffn_layers: int = 3
  # --- Convolutional subsampling -------------------------------------------
  conv_subsampling_factor: int = 2
  conv_subsampling_layers: int = 2
  # --- SpecAugment (field names mirror the attributes of `SpecAug`) --------
  use_specaug: bool = True
  freq_mask_count: int = 2
  freq_mask_max_bins: int = 27
  time_mask_count: int = 10
  time_mask_max_frames: int = 40
  time_mask_max_ratio: float = 0.05
  time_masks_per_frame: float = 0.0
  use_dynamic_time_mask_max_frames: bool = True
  # --- Normalization -------------------------------------------------------
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001
  # --- Dropout -------------------------------------------------------------
  # If None, defaults to 0.1.
  input_dropout_rate: Optional[float] = 0.1
  # If None, defaults to 0.1.
  feed_forward_dropout_rate: Optional[float] = 0.1
  # --- Architecture switches -----------------------------------------------
  enable_residual_connections: bool = True
  enable_decoder_layer_norm: bool = True
  bidirectional: bool = True
  use_cudnn_lstm: bool = False
class Subsample(nn.Module):
  """Module to perform strided convolution in order to subsample inputs.

  Two `Conv2dSubsampling` layers are applied back to back, followed by a
  dense projection to `config.encoder_dim` and dropout.

  Attributes:
    config: Model hyperparameters; `encoder_dim`, `dtype`, the batch norm
      settings and `input_dropout_rate` are read here.
  """
  config: DeepspeechConfig

  @nn.compact
  def __call__(self, inputs, output_paddings, train):
    """Subsamples `inputs` and their paddings.

    Args:
      inputs: Input features; a trailing channel axis of size 1 is added
        before the convolutions.
      output_paddings: Paddings matching `inputs`; subsampled alongside them.
      train: If True, dropout is active.

    Returns:
      A `(outputs, output_paddings)` pair after subsampling.
    """
    config = self.config
    model_dim = config.encoder_dim

    outputs = jnp.expand_dims(inputs, axis=-1)

    # The first layer sees the single channel added above, the second one the
    # `model_dim` channels produced by the first.
    for in_channels in (1, model_dim):
      outputs, output_paddings = Conv2dSubsampling(
          encoder_dim=model_dim,
          dtype=config.dtype,
          batch_norm_momentum=config.batch_norm_momentum,
          batch_norm_epsilon=config.batch_norm_epsilon,
          input_channels=in_channels,
          output_channels=model_dim)(outputs, output_paddings, train)

    # Merge the subsampled feature axis with the channel axis.
    batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
    flat_shape = (batch_size, subsampled_lengths, subsampled_dims * channels)
    outputs = jnp.reshape(outputs, flat_shape)

    projection = nn.Dense(
        model_dim,
        use_bias=True,
        kernel_init=nn.initializers.xavier_uniform())
    outputs = projection(outputs)

    input_dropout_rate = config.input_dropout_rate
    if input_dropout_rate is None:
      input_dropout_rate = 0.1
    dropout = nn.Dropout(rate=input_dropout_rate, deterministic=not train)
    outputs = dropout(outputs)

    return outputs, output_paddings
class Conv2dSubsampling(nn.Module):
  """Helper module used in Subsample layer.

  1) Performs strided convolution over inputs and then applies non-linearity.
  2) Also performs strided convolution over input_paddings to return the
  correct paddings for downstream layers.

  Attributes:
    input_channels: Input-channel size of the 3x3 kernel.
    output_channels: Output-channel size of the 3x3 kernel and of the bias.
    filter_stride: Convolution strides; the first entry is also used to
      subsample the paddings.
    padding: Padding mode of the input convolution.
    encoder_dim, dtype, batch_norm_momentum, batch_norm_epsilon: Accepted but
      not read by the code in this class.
  """
  input_channels: int = 0
  output_channels: int = 0
  filter_stride: List[int] = (2, 2)
  padding: str = 'SAME'
  encoder_dim: int = 0
  dtype: Any = jnp.float32
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001

  def setup(self):
    """Creates the 3x3 convolution kernel and a zero-initialized bias."""
    self.filter_shape = (3, 3, self.input_channels, self.output_channels)
    self.kernel = self.param(
        'kernel', nn.initializers.xavier_uniform(), self.filter_shape)
    self.bias = self.param(
        'bias',
        lambda rng, shape: jnp.zeros(shape, jnp.float32),
        self.output_channels)

  @nn.compact
  def __call__(self, inputs, paddings, train):
    """Applies the strided convolution and subsamples the paddings.

    Args:
      inputs: Array in 'NHWC' layout.
      paddings: Array whose axis 1 is the time axis being subsampled.
      train: Unused in this method.

    Returns:
      A `(outputs, out_padding)` pair; positions of `outputs` that map to
      padding are zeroed.
    """
    # Computing strided convolution to subsample inputs.
    in_channels = self.filter_shape[2]
    conv_out = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=self.kernel,
        window_strides=self.filter_stride,
        padding=self.padding,
        rhs_dilation=(1, 1),
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        feature_group_count=inputs.shape[3] // in_channels)

    # Broadcast the bias over every axis but the last one.
    bias_shape = (1,) * (conv_out.ndim - 1) + (-1,)
    conv_out = conv_out + jnp.reshape(self.bias, bias_shape)
    conv_out = nn.relu(conv_out)

    # Computing correct paddings post input convolution: pad the time axis up
    # to a multiple of the stride, then keep every `time_stride`-th value.
    time_stride = self.filter_stride[0]
    input_length = paddings.shape[1]
    rounded_length = (input_length + time_stride - 1) // time_stride * (
        time_stride)
    pad_len = rounded_length - input_length
    out_padding = jax.lax.conv_general_dilated(
        lhs=paddings[:, :, None],
        rhs=jnp.ones([1, 1, 1]),
        window_strides=self.filter_stride[:1],
        padding=[(0, pad_len)],
        dimension_numbers=('NHC', 'HIO', 'NHC'))
    out_padding = jnp.squeeze(out_padding, axis=-1)

    # Mask outputs by correct paddings to ensure padded elements in inputs map
    # to padded value in outputs.
    keep = 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)
    return conv_out * keep, out_padding
class FeedForwardModule(nn.Module):
  """Feedforward block: BatchNorm -> Dense -> ReLU -> padding mask -> Dropout.

  Attributes:
    config: Model hyperparameters; `encoder_dim`, `dtype`, the batch norm
      settings and `feed_forward_dropout_rate` are read here.
  """
  config: DeepspeechConfig

  @nn.compact
  def __call__(self, inputs, input_paddings=None, train=False):
    """Applies the feedforward block.

    Args:
      inputs: Input activations; the last axis is the feature axis.
      input_paddings: Paddings with the shape of `inputs` minus its last axis,
        or None to treat every position as valid.
      train: If True, dropout is active; also forwarded to `BatchNorm`.

    Returns:
      The transformed activations, zeroed at padded positions.
    """
    config = self.config

    # `None` is the declared default, but the arithmetic below needs an
    # array: substitute all-zero paddings, i.e. nothing is padded.
    if input_paddings is None:
      input_paddings = jnp.zeros(inputs.shape[:-1], dtype=inputs.dtype)
    padding_mask = jnp.expand_dims(1 - input_paddings, -1)

    inputs = BatchNorm(config.encoder_dim,
                       config.dtype,
                       config.batch_norm_momentum,
                       config.batch_norm_epsilon)(inputs, input_paddings, train)
    inputs = nn.Dense(
        config.encoder_dim,
        use_bias=True,
        kernel_init=nn.initializers.xavier_uniform())(
            inputs)
    inputs = nn.relu(inputs)
    # Zero the activations at padded positions.
    inputs *= padding_mask

    if config.feed_forward_dropout_rate is None:
      feed_forward_dropout_rate = 0.1
    else:
      feed_forward_dropout_rate = config.feed_forward_dropout_rate

    inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
        inputs, deterministic=not train)

    return inputs
class LayerNorm(nn.Module):
  """Layer normalization over the last axis.

  Follows https://arxiv.org/pdf/1607.06450.pdf.

  Unlike the default flax layer, the normalized values are multiplied by
  (1 + scale) with `scale` initialized to zeros, rather than by `scale`
  initialized to ones.

  Attributes:
    dim: Size of the last (normalized) axis; shape of `scale` and `bias`.
    epsilon: Constant added to the variance before the reciprocal square root.
  """
  dim: int = 0
  epsilon: float = 1e-6

  def setup(self):
    self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
    self.bias = self.param('bias', nn.initializers.zeros, [self.dim])

  @nn.compact
  def __call__(self, inputs):
    """Normalizes `inputs` along the last axis, then scales and shifts."""
    mean = jnp.mean(inputs, axis=-1, keepdims=True)
    centered = inputs - mean
    var = jnp.mean(jnp.square(centered), axis=-1, keepdims=True)
    normed_inputs = centered * jax.lax.rsqrt(var + self.epsilon)
    return normed_inputs * (1 + self.scale) + self.bias
class BatchNorm(nn.Module):
  """Batch normalization that ignores padded positions.

  Mean and variance are computed only over positions whose padding value is 0,
  by masking the inputs before reducing.

  Inspired by the lingvo jax BatchNorm:
  https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92
  The momentum and epsilon defaults are copied from lingvo.

  Attributes:
    encoder_dim: Size of the feature (last) axis.
    dtype: dtype of the parameters and running statistics.
    batch_norm_momentum: Decay applied to the running mean / variance.
    batch_norm_epsilon: Constant added to the variance before the square root.
  """
  encoder_dim: int = 0
  dtype: Any = jnp.float32
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001

  def setup(self):
    dim = self.encoder_dim
    dtype = self.dtype

    # Running statistics live in the 'batch_stats' collection.
    self.ra_mean = self.variable('batch_stats',
                                 'mean',
                                 lambda s: jnp.zeros(s, dtype),
                                 dim)
    self.ra_var = self.variable('batch_stats',
                                'var',
                                lambda s: jnp.ones(s, dtype),
                                dim)

    # gamma starts at zero because the effective scale is (1 + gamma).
    self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
    self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

  def _get_default_paddings(self, inputs):
    """Returns all-zero paddings shaped like `inputs` with a last axis of 1."""
    padding_shape = list(inputs.shape)
    padding_shape[-1] = 1
    return jnp.zeros(padding_shape, dtype=inputs.dtype)

  @nn.compact
  def __call__(self, inputs, input_paddings=None, train=False):
    """Normalizes `inputs` over every axis except the last.

    Args:
      inputs: Array whose last axis is the feature axis.
      input_paddings: Optional array with one fewer axis than `inputs`; 1.0
        marks padded positions. None means nothing is padded.
      train: If True, use batch statistics (summed across the 'batch' named
        axis with `psum`, so the call must run under a mapped axis of that
        name) and update the running averages. If False, use the running
        averages.

    Returns:
      The normalized array, zeroed at padded positions.
    """
    reduce_over_dims = list(range(0, inputs.ndim - 1))

    if input_paddings is None:
      padding = self._get_default_paddings(inputs)
    else:
      padding = jnp.expand_dims(input_paddings, -1)

    momentum = self.batch_norm_momentum
    epsilon = self.batch_norm_epsilon

    if not train:
      mean = self.ra_mean.value
      var = self.ra_var.value
    else:
      mask = 1.0 - padding

      sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
      count_v = jnp.sum(
          jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
      sum_v = jax.lax.psum(sum_v, axis_name='batch')
      count_v = jax.lax.psum(count_v, axis_name='batch')

      # Avoid dividing by zero when every position is padding.
      count_v = jnp.maximum(count_v, 1.0)
      mean = sum_v / count_v

      centered = inputs - mean
      sum_vv = jnp.sum(
          centered * centered * mask, axis=reduce_over_dims, keepdims=True)
      sum_vv = jax.lax.psum(sum_vv, axis_name='batch')
      var = sum_vv / count_v

      # NOTE: `mean` / `var` keep the reduced axes (keepdims=True), so after
      # the first update the stored statistics broadcast from shape (dim,) to
      # (1, ..., 1, dim).
      self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean
      self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var

    inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
    bn_output = (inputs - mean) * inv + self.beta
    # Keep padded positions at exactly zero.
    bn_output *= 1.0 - padding
    return bn_output
@jax.vmap
def flip_sequences(inputs: Array, lengths: Array) -> Array:
  """Reverses each padded sequence in time while leaving its padding at the end.

  Naively flipping a padded batch would move the padding to the front of every
  short sequence. This function reverses only the first `lengths[i]` elements
  of sequence i, which is what the backward direction of a bidirectional RNN
  needs.

  Example:
  ```python
  inputs = [[1, 0, 0],
            [2, 3, 0],
            [4, 5, 6]]
  lengths = [1, 2, 3]
  flip_sequences(inputs, lengths) = [[1, 0, 0],
                                     [3, 2, 0],
                                     [6, 5, 4]]
  ```

  Args:
    inputs: Batch of sequences, [batch_size, seq_length, ...]. Because of
      `jax.vmap` the body below sees one sequence at a time.
    lengths: Number of valid elements per sequence, <int>[batch_size].

  Returns:
    An array shaped like `inputs` with the valid prefix of each sequence
    reversed.
  """
  # Position p reads from (max_length - 1 - p + length) mod max_length: the
  # valid prefix comes out reversed and the padding slots map onto themselves.
  max_length = inputs.shape[0]
  reversed_positions = jnp.arange(max_length - 1, -1, -1)
  gather_idxs = (reversed_positions + lengths) % max_length
  return inputs[gather_idxs]
class GenericRNNSequenceEncoder(nn.Module):
  """Runs one RNN cell (e.g. `nn.LSTMCell`) over a batch of padded sequences.

  Sequences are processed left-to-right by default, or right-to-left when the
  module is called with `reverse=True`. Either way outputs[i, j, ...] is the
  representation of inputs[i, j, ...].

  Attributes:
    hidden_size: Hidden size of the RNN cell. Not read inside this class; the
      cell is built from `cell_kwargs` alone.
    cell_type: The RNN cell module class to instantiate.
    cell_kwargs: Optional keyword arguments passed to `cell_type`.
    recurrent_dropout_rate: Dropout rate across time steps. Not read inside
      this class; no recurrent dropout mask is built here.
  """
  hidden_size: int
  cell_type: Type[nn.recurrent.RNNCellBase]
  cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict()
  recurrent_dropout_rate: float = 0.0

  def setup(self):
    self.cell = self.cell_type(**self.cell_kwargs)

  @functools.partial(  # Scans the method below over the time axis (axis 1).
      nn.transforms.scan,
      variable_broadcast='params',
      in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast),
      out_axes=1,
      split_rngs={'params': False})
  def unroll_cell(self,
                  cell_state: StateType,
                  inputs: Array,
                  recurrent_dropout_mask: Optional[Array],
                  deterministic: bool):
    """Applies the cell to one time step; `scan` repeats it over the sequence.

    Args:
      cell_state: The carried cell state, <float32>[batch_size, hidden_size]
        (or an n-tuple thereof).
      inputs: The input sequence, <float32>[batch_size, seq_len, input_dim];
        scanned along axis 1.
      recurrent_dropout_mask: Broadcast to every step; unused by this body.
      deterministic: Broadcast to every step; unused by this body.

    Returns:
      The next carry, and a (state, output) pair that `scan` stacks along
      axis 1. The state is emitted alongside the output so per-step states
      remain available to the caller.
    """
    next_state, step_output = self.cell(cell_state, inputs)
    return next_state, (next_state, step_output)

  def __call__(self,
               inputs: Array,
               input_paddings: Array,
               initial_state: StateType,
               reverse: bool = False,
               deterministic: bool = False):
    """Unrolls the RNN cell over the inputs.

    Args:
      inputs: A batch of sequences, <float32>[batch_size, seq_len, input_dim].
      input_paddings: Padding indicators summed along the last axis as
        `1 - input_paddings` to get per-sequence lengths, i.e. 1 marks a
        padded step.
      initial_state: Initial state for the RNN cell, [batch_size, hidden_size].
      reverse: Process the inputs in reverse order and reverse the outputs
        back, so outputs still line up with inputs but their context comes
        from the right.
      deterministic: Forwarded to `unroll_cell`.

    Returns:
      The per-step cell outputs (scan output stacked along axis 1). Final
      states are not returned.
    """
    lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32)

    sequence = flip_sequences(inputs, lengths) if reverse else inputs

    # No recurrent dropout mask is ever constructed, so None is passed.
    _, (_, outputs) = self.unroll_cell(initial_state,
                                       sequence,
                                       None,
                                       deterministic)

    # Undo the flip so outputs align with the original time order.
    return flip_sequences(outputs, lengths) if reverse else outputs
class GenericRNN(nn.Module):
  """Stack of unidirectional or bidirectional RNN layers for any RNN cell.

  Each layer unrolls a forward encoder and, when `bidirectional` is set, a
  backward encoder whose outputs are concatenated on the last axis. The
  output of one layer is the input of the next. This class can be used to
  create a specific RNN class such as LSTM or GRU.

  Attributes:
    cell_type: An RNN cell class to use, e.g., `flax.linen.LSTMCell`.
    hidden_size: The size of each recurrent cell.
    num_layers: The number of stacked recurrent layers.
    dropout_rate: Dropout rate between layers. Not applied anywhere in this
      class.
    recurrent_dropout_rate: Forwarded to each `GenericRNNSequenceEncoder`.
    bidirectional: Process the sequence left-to-right and right-to-left and
      concatenate the outputs from the two directions.
    cell_kwargs: Optional keyword arguments to instantiate the cell with.
  """
  cell_type: Type[nn.recurrent.RNNCellBase]
  hidden_size: int
  num_layers: int = 1
  dropout_rate: float = 0.
  recurrent_dropout_rate: float = 0.
  bidirectional: bool = False
  cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict()

  @nn.compact
  def __call__(
      self,
      inputs: Array,
      input_paddings: Array,
      initial_states: Optional[Sequence[StateType]] = None,
      deterministic: bool = False) -> Array:
    """Processes the input sequence using the recurrent cell.

    Args:
      inputs: The input sequence <float32>[batch_size, sequence_length, ...].
      input_paddings: Padding indicators for `inputs`, forwarded unchanged to
        every sequence encoder.
      initial_states: The initial states for the cells. You must provide
        `num_layers` initial states (`num_layers * 2` when bidirectional),
        ordered (layer_0_forward, layer_0_backward, layer_1_forward,
        layer_1_backward, ...). If None, states are created with
        `cell_type.initialize_carry`.
      deterministic: Forwarded to every sequence encoder.

    Returns:
      The outputs of the final layer. With `num_layers == 0` the inputs are
      returned unchanged. Final cell states are not returned.

    Raises:
      ValueError: If the number of `initial_states` does not match the number
        of cells.
    """
    batch_size = inputs.shape[0]
    num_directions = 2 if self.bidirectional else 1
    num_cells = self.num_layers * num_directions

    # Construct initial states.
    if initial_states is None:
      # NOTE(review): relies on `initialize_carry(rng, batch_dims, size)`
      # being callable on the cell class -- confirm against the installed
      # Flax version.
      rng = jax.random.PRNGKey(0)
      initial_states = [
          self.cell_type.initialize_carry(rng, (batch_size,), self.hidden_size)
          for _ in range(num_cells)
      ]
    if len(initial_states) != num_cells:
      # `num_cells` is a local; there is no `self.num_cells` attribute.
      raise ValueError(
          f'Please provide {num_cells} (`num_layers`, *2 if bidirectional) '
          f'initial states.')

    # With zero layers the loop body never runs; fall back to the inputs.
    outputs = inputs

    # For each layer, apply the forward and optionally the backward RNN cell.
    cell_idx = 0
    for _ in range(self.num_layers):
      # Unroll an RNN cell (forward direction) for this layer.
      outputs = GenericRNNSequenceEncoder(
          cell_type=self.cell_type,
          cell_kwargs=self.cell_kwargs,
          hidden_size=self.hidden_size,
          recurrent_dropout_rate=self.recurrent_dropout_rate,
          name=f'{self.name}SequenceEncoder_{cell_idx}')(
              inputs,
              input_paddings,
              initial_state=initial_states[cell_idx],
              deterministic=deterministic)
      cell_idx += 1

      # Unroll an RNN cell (backward direction) for this layer.
      if self.bidirectional:
        backward_outputs = GenericRNNSequenceEncoder(
            cell_type=self.cell_type,
            cell_kwargs=self.cell_kwargs,
            hidden_size=self.hidden_size,
            recurrent_dropout_rate=self.recurrent_dropout_rate,
            name=f'{self.name}SequenceEncoder_{cell_idx}')(
                inputs,
                input_paddings,
                initial_state=initial_states[cell_idx],
                reverse=True,
                deterministic=deterministic)
        outputs = jnp.concatenate([outputs, backward_outputs], axis=-1)
        cell_idx += 1

      # This layer's outputs feed the next layer.
      inputs = outputs

    return outputs
class LSTM(nn.Module):
  """LSTM built on top of `GenericRNN`.

  Attributes:
    hidden_size: The size of each recurrent cell.
    num_layers: The number of stacked recurrent layers.
    dropout_rate: Forwarded to `GenericRNN`.
    recurrent_dropout_rate: Forwarded to `GenericRNN`.
    bidirectional: Process the sequence left-to-right and right-to-left and
      concatenate the outputs from the two directions.
    cell_type: The LSTM cell class to use. Default:
      `flax.linen.OptimizedLSTMCell`. If you use hidden_size of >2048,
      consider using `flax.linen.LSTMCell` instead, since the optimized LSTM
      cell works best for hidden sizes up to 2048.
    cell_kwargs: Optional keyword arguments to instantiate the cell with.
  """
  hidden_size: int
  num_layers: int = 1
  dropout_rate: float = 0.
  recurrent_dropout_rate: float = 0.
  bidirectional: bool = False
  cell_type: Any = nn.OptimizedLSTMCell
  cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict()

  @nn.compact
  def __call__(
      self,
      inputs: Array,
      input_paddings: Array,
      initial_states: Optional[Sequence[StateType]] = None,
      deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]:
    """Processes an input sequence with an LSTM cell.

    Args:
      inputs: The input sequence <float32>[batch_size, sequence_length, ...].
      input_paddings: Padding indicators for `inputs`, forwarded unchanged.
      initial_states: The initial states for the cells. You must provide
        `num_layers` initial states (`num_layers * 2` when bidirectional),
        ordered (layer_0_forward, layer_0_backward, layer_1_forward,
        layer_1_backward, ...). If None, `GenericRNN` creates them.
      deterministic: Forwarded to `GenericRNN`.

    Returns:
      Whatever `GenericRNN.__call__` returns for these arguments.
      # NOTE(review): the annotation advertises (outputs, final_states);
      # confirm it matches what GenericRNN actually returns.
    """
    rnn = GenericRNN(
        cell_type=self.cell_type,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        dropout_rate=self.dropout_rate,
        recurrent_dropout_rate=self.recurrent_dropout_rate,
        bidirectional=self.bidirectional,
        cell_kwargs=self.cell_kwargs,
        name='LSTM')
    return rnn(
        inputs,
        input_paddings,
        initial_states=initial_states,
        deterministic=deterministic)
class BatchRNN(nn.Module):
  """A single Deepspeech encoder layer: padding-aware batch norm, then an LSTM.

  Attributes:
    config: Model hyperparameters. This layer reads `encoder_dim`, `dtype`,
      `batch_norm_momentum`, `batch_norm_epsilon`, `bidirectional` and
      `use_cudnn_lstm`.
  """
  config: DeepspeechConfig

  @nn.compact
  def __call__(self, inputs, input_paddings, train):
    """Normalizes `inputs` and runs them through one recurrent layer.

    `train` only reaches BatchNorm; the recurrent layer is called with
    `(inputs, input_paddings)` alone.
    """
    config = self.config

    normed = BatchNorm(config.encoder_dim,
                       config.dtype,
                       config.batch_norm_momentum,
                       config.batch_norm_epsilon)(inputs, input_paddings, train)

    # A bidirectional layer concatenates two directions, so each direction
    # gets half the width and the layer output is encoder_dim wide again.
    if config.bidirectional:
      hidden_size = config.encoder_dim // 2
    else:
      hidden_size = config.encoder_dim

    if config.use_cudnn_lstm:
      rnn = CudnnLSTM(
          input_size=config.encoder_dim,
          hidden_size=hidden_size,
          bidirectional=config.bidirectional,
          num_layers=1)
    else:
      rnn = LSTM(
          hidden_size=hidden_size,
          bidirectional=config.bidirectional,
          num_layers=1)

    return rnn(normed, input_paddings)
class Deepspeech(nn.Module):
  """Deepspeech model: audio frontend, recurrent encoder, linear decoder.

  Takes audio input signals and produces, for each output time step, a vector
  of `config.vocab_size` unnormalized scores. Those scores are meant to feed a
  CTC loss, which removes the need for an alignment with the targets.

  Attributes:
    config: Model hyperparameters.
  """
  config: DeepspeechConfig

  def setup(self):
    config = self.config
    self.specaug = SpecAug(
        freq_mask_count=config.freq_mask_count,
        freq_mask_max_bins=config.freq_mask_max_bins,
        time_mask_count=config.time_mask_count,
        time_mask_max_frames=config.time_mask_max_frames,
        time_mask_max_ratio=config.time_mask_max_ratio,
        time_masks_per_frame=config.time_masks_per_frame,
        use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
    )

  @nn.compact
  def __call__(self, inputs, input_paddings, train):
    """Runs the full model.

    Args:
      inputs: Raw audio signals.
      input_paddings: Padding indicators matching `inputs`.
      train: Enables SpecAug (when `config.use_specaug`), batch-statistics
        updates and dropout.

    Returns:
      A pair (outputs, output_paddings): the per-step scores over the
      vocabulary and the paddings after subsampling.
    """
    config = self.config
    use_residual = config.enable_residual_connections

    # Normalized log mel spectrograms from the input audio signal.
    preprocessing_config = LibrispeechPreprocessingConfig()
    frontend = MelFilterbankFrontend(
        preprocessing_config,
        per_bin_mean=LIBRISPEECH_MEAN_VECTOR,
        per_bin_stddev=LIBRISPEECH_STD_VECTOR)
    outputs, output_paddings = frontend(inputs, input_paddings)

    # SpecAugment (https://arxiv.org/abs/1904.08779): ablate random parts of
    # the input along the time and frequency axes, training only.
    if config.use_specaug and train:
      outputs, output_paddings = self.specaug(outputs, output_paddings)

    # Strided-convolution subsampling of the time axis.
    outputs, output_paddings = Subsample(
        config=config)(outputs, output_paddings, train)

    # Recurrent layers.
    for _ in range(config.num_lstm_layers):
      layer_out = BatchRNN(config)(outputs, output_paddings, train)
      outputs = (outputs + layer_out) if use_residual else layer_out

    # Feed-forward layers.
    for _ in range(config.num_ffn_layers):
      layer_out = FeedForwardModule(config=self.config)(
          outputs, output_paddings, train)
      outputs = (outputs + layer_out) if use_residual else layer_out

    # The decoder is a single projection to the vocabulary.
    if config.enable_decoder_layer_norm:
      outputs = LayerNorm(config.encoder_dim)(outputs)

    outputs = nn.Dense(
        config.vocab_size,
        use_bias=True,
        kernel_init=nn.initializers.xavier_uniform())(
            outputs)

    return outputs, output_paddings
# Global batch size of the dummy batch; it is split across local devices.
BATCH_SIZE = 128
# Passed to DeepspeechConfig(use_cudnn_lstm=...) when the model is built.
USE_CUDNN_LSTM = False
#@title Pmapped Train Loop 1 step | |
import jax | |
import numpy as np | |
import functools | |
import jax.numpy as jnp | |
import flax | |
import flax.linen as nn | |
from flax import jax_utils | |
import optax | |
from absl import logging | |
import jax.lax as lax | |
import time | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import os | |
# Added to the gradient norm before dividing, so the clipping factor stays
# finite when the norm is zero.
_GRAD_CLIP_EPS = 1e-6
def shard(batch, n_devices=None):
  """Prepares the batch for pmap by adding a leading n_devices dimension.

  If all the entries are lists, assume they are already divided into n_devices
  smaller arrays and stack them for pmapping. Otherwise every array leaf is
  assumed to have a leading dimension divisible by n_devices and is reshaped
  to (n_devices, leading // n_devices, ...).

  Args:
    batch: A dict whose values are arrays (or pytrees of arrays), or lists of
      per-device arrays.
    n_devices: If None, this will be set to jax.local_device_count().

  Returns:
    Sharded data with the same tree structure as `batch`.
  """
  if n_devices is None:
    n_devices = jax.local_device_count()

  # TODO(mbadura): Specify a sharding function per dataset instead
  # If entries in the batch dict are lists, then the data is already divided
  # into n_devices chunks, so we need to stack them.
  if all(isinstance(v, list) for v in batch.values()):
    assert all(len(v) == n_devices for v in batch.values())
    # Transpose a dict of lists to a list of dicts.
    shards = [{k: v[i] for (k, v) in batch.items()} for i in range(n_devices)]
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(lambda *vals: np.stack(vals, axis=0),
                                  shards[0],
                                  *shards[1:])

  # Otherwise, the entries are arrays, so just reshape them.
  def _shard_array(array):
    return array.reshape((n_devices, -1) + array.shape[1:])

  return jax.tree_util.tree_map(_shard_array, batch)
def load_dummy_batch():
  """Builds an all-zeros batch and shards it across the local devices.

  The batch holds `BATCH_SIZE` examples: inputs and input paddings of length
  320000, and targets and target paddings of length 256.

  Returns:
    A dict with keys 'inputs' and 'targets', each an (array, paddings) pair
    carrying the leading device dimension added by `shard`.
  """
  audio_shape = (BATCH_SIZE, 320000)
  target_shape = (BATCH_SIZE, 256)

  padded_batch = {
      'inputs': (jnp.array(np.zeros(audio_shape)),
                 jnp.array(np.zeros(audio_shape))),
      'targets': (jnp.array(np.zeros(target_shape)),
                  jnp.array(np.zeros(target_shape))),
  }
  sharded_padded_batch = shard(padded_batch)

  # Show the per-device shapes that were produced.
  inputs, input_paddings = sharded_padded_batch['inputs']
  print(inputs.shape, input_paddings.shape)

  return sharded_padded_batch
# Initing optimizer and LR schedule | |
def jax_cosine_warmup(base_lr=0.02, warmup_steps=5000, total_steps=60000):
  """Builds a linear-warmup followed by cosine-decay learning rate schedule.

  Args:
    base_lr: Peak learning rate, reached at the end of warmup.
    warmup_steps: Number of steps of linear warmup from 0 to `base_lr`.
    total_steps: Total number of steps covered by warmup plus cosine decay.

  Returns:
    An optax schedule function mapping a step count to a learning rate.
  """
  warmup_fn = optax.linear_schedule(
      init_value=0.,
      end_value=base_lr,
      transition_steps=warmup_steps)

  cosine_steps = max(total_steps - warmup_steps, 1)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=base_lr, decay_steps=cosine_steps)

  # Hand over to the cosine schedule exactly when warmup finishes. Switching
  # earlier would cut the ramp short and make the learning rate jump straight
  # to `base_lr`.
  schedule_fn = optax.join_schedules(
      schedules=[warmup_fn, cosine_fn],
      boundaries=[warmup_steps])
  return schedule_fn
def init_optimizer_state(params):
  """Creates an AdamW optimizer driven by the warmup + cosine schedule.

  Args:
    params: Model parameters used to initialize the optimizer state.

  Returns:
    A pair (replicated optimizer state, optimizer update function).
  """
  optimizer = optax.adamw(
      learning_rate=jax_cosine_warmup(),
      b1=0.98,
      b2=0.99,
      eps=1e-8,
      weight_decay=0.0)
  optimizer_state = optimizer.init(params)
  # The state is replicated so it can be passed to a pmapped step.
  return jax_utils.replicate(optimizer_state), optimizer.update
def train_step(model_class,
               opt_update_fn,
               params,
               batch_stats,
               optimizer_state,
               batch,
               rng,
               grad_clip):
  """Runs one optimization step; meant to be called under a 'batch' named axis.

  Args:
    model_class: The flax module instance; its `apply` is called with
      `train=True`.
    opt_update_fn: Optimizer update function, called as
      `opt_update_fn(grad, optimizer_state, params)`.
    params: Model parameters.
    batch_stats: The 'batch_stats' variable collection.
    optimizer_state: Current optimizer state.
    batch: Dict with 'inputs' -> (inputs, input_paddings) and
      'targets' -> (targets, target_paddings).
    rng: PRNG key used for the 'dropout' stream.
    grad_clip: Maximum global gradient norm, or None to disable clipping.

  Returns:
    A tuple (updated_params, new_batch_stats, new_optimizer_state, loss,
    grad_norm). `grad_norm` is the norm before clipping.
  """

  def _loss_fn(params):
    """Loss function used for training."""
    inputs, input_paddings = batch['inputs']
    targets, target_paddings = batch['targets']

    (logits, logit_paddings), updated_vars = model_class.apply(
        {'params': params, 'batch_stats': batch_stats},
        inputs,
        input_paddings,
        train=True,
        rngs={'dropout': rng},
        mutable=['batch_stats'])
    new_batch_stats = updated_vars['batch_stats']

    logprobs = nn.log_softmax(logits)
    per_seq_loss = optax.ctc_loss(logprobs,
                                  logit_paddings,
                                  targets,
                                  target_paddings)
    # Normalize by the number of non-padded target tokens; the maximum guards
    # against an all-padding batch.
    normalizer = jnp.sum(1 - target_paddings)
    normalized_loss = jnp.sum(per_seq_loss) / jnp.maximum(normalizer, 1)
    return normalized_loss, new_batch_stats

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
  (loss, new_batch_stats), grad = grad_fn(params)
  # Average loss and gradients over the 'batch' axis so every replica applies
  # the same update.
  (loss, grad) = lax.pmean((loss, grad), axis_name='batch')

  grad_norm = jnp.sqrt(
      sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

  if grad_clip is not None:
    # Scale all gradients by min(1, grad_clip / norm): global-norm clipping.
    grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
    grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
    # `jax.tree_util.tree_map` matches the `tree_leaves` call above; the
    # top-level `jax.tree_map` alias is deprecated and absent from recent JAX.
    grad = jax.tree_util.tree_map(lambda x: x * grad_scaling_factor, grad)

  updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
                                               params)
  updated_params = optax.apply_updates(params, updates)

  return (updated_params, new_batch_stats, new_optimizer_state,
          jnp.mean(loss), jnp.mean(grad_norm))
def main():
  """Builds the model, then times and profiles a few pmapped training steps.

  Step 0 is left out of the timing and the profiler trace because it includes
  compilation; the timer and the trace start at step 1.
  """
  sharded_padded_batch = load_dummy_batch()

  # Initing model
  config = DeepspeechConfig(use_cudnn_lstm=USE_CUDNN_LSTM)
  model_class = Deepspeech(config)
  rng = jax.random.PRNGKey(0)
  params_rng, dropout_rng = jax.random.split(rng, 2)
  model_init_fn = jax.jit(functools.partial(model_class.init, train=False))
  input_shape = [(320000,), (320000,)]
  fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]

  print('Initializing model.')
  variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                            *fake_input_batch)
  # Index the collections by name. Tuple-unpacking `pop('params')` never
  # yields the bare 'batch_stats' collection: on a FrozenDict it returns
  # (remaining_variables, params), and on a plain dict it returns only params.
  params = variables['params']
  batch_stats = variables['batch_stats']

  print('Initializing optimizer')
  replicated_optimizer_state, opt_update_fn = init_optimizer_state(params)
  replicated_params = jax_utils.replicate(params)
  replicated_batch_stats = jax_utils.replicate(batch_stats)

  # Starting Training to measure time:
  num_training_steps = 10
  grad_clip = 5.0

  # Defining pmapped update step. params, batch_stats, optimizer state and
  # batch are split over devices; rng and grad_clip are broadcast.
  bound_train_step = functools.partial(train_step, model_class, opt_update_fn)
  pmapped_train_step = jax.pmap(bound_train_step,
                                axis_name='batch',
                                in_axes=(0, 0, 0, 0, None, None))

  print('Starting training')
  print('JAX local device count = ', jax.local_device_count())
  for step in range(num_training_steps):
    if step == 1:
      start_time = time.time()
      jax.profiler.start_trace("/experiment_runs/traces/old_layer_bs128_jax044_10steps_new/", create_perfetto_trace=True)

    # Derive a fresh dropout key per step instead of reusing the already-split
    # root key, which would repeat the same dropout masks on every step.
    step_rng = jax.random.fold_in(dropout_rng, step)
    (
        replicated_params,
        replicated_batch_stats,
        replicated_optimizer_state,
        loss,
        grad_norm) = pmapped_train_step(
            replicated_params,
            replicated_batch_stats,
            replicated_optimizer_state,
            sharded_padded_batch,
            step_rng,
            grad_clip)
    # Every replica holds the same averaged values; report device 0's copy.
    print('{}) loss = {} grad_norm = {}'.format(step, loss[0], grad_norm[0]))

  jax.profiler.stop_trace()
  end_time = time.time()
  print('JAX program execution took %s seconds' % (end_time - start_time))


# Run only when executed as a script or notebook cell, not on import.
if __name__ == '__main__':
  main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment