"""Prepare acoustic/linguistic/duration features. | |
usage: | |
prepare_features.py [options] <DATA_ROOT> <DST_ROOT> | |
options: | |
--overwrite Overwrite files | |
--max_num_files=<N> Max num files to be collected. [default: -1] | |
-h, --help show this help message and exit | |
""" | |

from __future__ import division, print_function, absolute_import

import os
import re
import warnings

import numpy as np
import pysptk
import pyworld
import soundfile
from docopt import docopt
from nnmnkwii.datasets import FileSourceDataset
from nnmnkwii.datasets.jsut import WavFileDataSource
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.io import hts
from nnmnkwii.preprocessing import inv_scale, scale
from nnmnkwii.preprocessing.f0 import interp1d
# Private sklearn helper implementing Chan et al.'s incremental mean/variance.
from sklearn.utils.extmath import _incremental_mean_and_var
from tqdm import tqdm

# Overridden from the command line in __main__.
max_num_files = None
# Regex matching silence phones in the HTS-style labels.
silence = re.compile('sil*')
# Mel-cepstrum analysis order (order + 1 = 60 coefficients).
order = 59
# Phoneme inventory used in the label files.
phonemes = [':', 'a', 'a:', 'b', 'by', 'ch', 'd', 'dy', 'e', 'e:', 'f', 'g',
            'gy', 'h', 'hy', 'i', 'i:', 'j', 'k', 'ky', 'm', 'my', 'n', 'N',
            'ny', 'o', 'o:', 'p', 'py', 'q', 'r', 'ry', 's', 'sh', 't', 'ts',
            'ty', 'u', 'u:', 'w', 'y', 'z', 'zy', 'sp']
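
# Hedged note on dimensionality: each acoustic frame produced below stacks
# (order + 1) mel-cepstral coefficients, 1 log-F0 value, 1 voiced/unvoiced
# flag and pyworld.get_num_aperiodicities(fs) band aperiodicities, i.e.
# 60 + 1 + 1 + 5 = 67 dimensions for 48 kHz JSUT audio.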


class LinguisticSource(WavFileDataSource):
    """Maps each utterance's label file to a sequence of phoneme IDs."""

    def __init__(self, data_root, *args, **kwargs):
        super(LinguisticSource, self).__init__(data_root, *args, **kwargs)
        self.phoneme_dict = {p: i for i, p in enumerate(phonemes)}

    def collect_files(self):
        # Derive label paths from wav paths; this replaces every 'wav'
        # substring, i.e. both the directory name and the file extension.
        files = [f.replace('wav', 'lab')
                 for f in super(LinguisticSource, self).collect_files()]
        if max_num_files is not None and max_num_files > 0:
            return files[:max_num_files]
        else:
            return files

    def collect_features(self, path):
        labels = hts.load(path)
        # Drop silence phones, then map phoneme symbols to integer IDs.
        features = np.delete(labels.contexts,
                             labels.silence_phone_indices(silence), axis=0)
        return np.vectorize(self.phoneme_dict.get)(features).astype(np.float32)


class DurationSource(WavFileDataSource):
    """Maps each utterance's label file to per-phoneme durations in frames."""

    def collect_files(self):
        files = [f.replace('wav', 'lab')
                 for f in super(DurationSource, self).collect_files()]
        if max_num_files is not None and max_num_files > 0:
            return files[:max_num_files]
        else:
            return files

    def collect_features(self, path):
        labels = hts.load(path)
        # Duration of each phoneme in frames, with silence phones removed.
        return np.delete(fe.duration_features(labels),
                         labels.silence_phone_indices(silence),
                         axis=0).astype(np.float32)


class AcousticSource(WavFileDataSource):
    """Extracts WORLD acoustic features (mgc/lf0/vuv/bap) per utterance."""

    def collect_files(self):
        wav_files = super(AcousticSource, self).collect_files()
        lab_files = [f.replace('wav', 'lab') for f in wav_files]
        if max_num_files is not None and max_num_files > 0:
            return wav_files[:max_num_files], lab_files[:max_num_files]
        else:
            return wav_files, lab_files

    def collect_features(self, wav_path, lab_path):
        x, fs = soundfile.read(wav_path)
        # WORLD analysis: F0, spectral envelope and aperiodicity.
        f0, sp, ap = pyworld.wav2world(x, fs)
        # Band aperiodicity and mel-generalized cepstrum.
        bap = pyworld.code_aperiodicity(ap, fs)
        mgc = pysptk.sp2mc(sp, order=order, alpha=pysptk.util.mcepalpha(fs))
        # Log F0 with a voiced/unvoiced flag; unvoiced gaps are filled by
        # linear interpolation so lf0 is continuous.
        f0 = f0[:, None]
        lf0 = f0.copy()
        nonzero_indices = np.nonzero(f0)
        lf0[nonzero_indices] = np.log(f0[nonzero_indices])
        vuv = (lf0 != 0).astype(np.float32)
        lf0 = interp1d(lf0, kind="slinear")
        features = np.hstack((mgc, lf0, vuv, bap))
        # Trim to the label length and drop silence frames.
        labels = hts.load(lab_path)
        return np.delete(features[:labels.num_frames()],
                         labels.silence_frame_indices(silence),
                         axis=0).astype(np.float32)
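

# A minimal sanity-check sketch (not used by the pipeline above): roughly
# invert AcousticSource.collect_features back to a waveform with WORLD.
# It assumes the same order/alpha settings as above and pyworld's default
# 5 ms frame period; since silence frames were removed and lf0 was
# interpolated, the result is only an approximation of the original audio.
def resynthesize(features, fs, frame_period=5.0):
    # Split the stacked frame back into its components.
    mgc = features[:, :order + 1]
    lf0 = features[:, order + 1]
    vuv = features[:, order + 2]
    bap = features[:, order + 3:]
    fft_size = pyworld.get_cheaptrick_fft_size(fs)
    sp = pysptk.mc2sp(mgc.astype(np.float64), pysptk.util.mcepalpha(fs), fft_size)
    ap = pyworld.decode_aperiodicity(bap.astype(np.float64), fs, fft_size)
    f0 = np.exp(lf0) * vuv  # restore zeros on unvoiced frames
    return pyworld.synthesize(f0.astype(np.float64), sp, ap, fs, frame_period)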
if __name__ == "__main__": | |
args = docopt(__doc__) | |
data_root = args['<DATA_ROOT>'] | |
dst_root = args['<DST_ROOT>'] | |
max_num_files = int(args['--max_num_files']) | |
overwrite = args['--overwrite'] | |
linguistic_source = FileSourceDataset(LinguisticSource(data_root, subsets='all')) | |
duration_source = FileSourceDataset(DurationSource(data_root, subsets='all')) | |
acoustic_source = FileSourceDataset(AcousticSource(data_root, subsets='all')) | |
get_name = lambda idx: os.path.join(dst_root, os.path.splitext(os.path.basename(linguistic_source.collected_files[idx][0]))[0] + '.npz') | |
process_indices, rescale_indices = [], [] | |
for idx, lin in tqdm(enumerate(linguistic_source)): | |
if not overwrite and os.path.exists(get_name(idx)): | |
rescale_indices.append(idx) | |
else: | |
process_indices.append(idx) | |
if len(process_indices) != len(linguistic_source): | |
warnings.warn('{}/{} wav files are processed.'.format(len(process_indices), len(linguistic_source))) | |
mean, var, count = 0, 0, 0 | |
norm_path = os.path.join(dst_root, 'norm.npz') | |
if not overwrite and os.path.exists(norm_path): | |
norm = np.load(norm_path) | |
mean, var, count = norm['mean'], norm['var'], norm['count'] | |
if len(rescale_indices): | |
init_mean, init_std = mean.copy(), np.sqrt(var) | |
for idx in tqdm(process_indices): | |
acoustic = acoustic_source[idx] | |
np.savez_compressed(get_name(idx), audio_features=acoustic) | |
mean, var, count = _incremental_mean_and_var(acoustic, mean, var, count) | |
std = np.sqrt(var) | |
np.savez_compressed(norm_path, mean=mean, var=var, count=count) | |
for idx in tqdm(rescale_indices): | |
data = dict(np.load(get_name(idx))) | |
data['audio_features'] = scale(inv_scale(data['audio_features'], init_mean, init_std), mean, std) | |
np.savez_compressed(get_name(idx), **data) | |
for idx in tqdm(process_indices): | |
name = get_name(idx) | |
acoustic = np.load(name)['audio_features'] | |
np.savez_compressed(name, file_id=os.path.splitext(os.path.basename(name))[0], phonemes=linguistic_source[idx], durations=duration_source[idx], audio_features=scale(acoustic, mean, std)) |
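
# A hedged sketch of consuming the generated archives (key names follow the
# savez_compressed calls above; 'BASIC5000_0001' is just an example file ID):
#
#   data = np.load(os.path.join(dst_root, 'BASIC5000_0001.npz'))
#   norm = np.load(os.path.join(dst_root, 'norm.npz'))
#   phonemes = data['phonemes']      # integer phoneme IDs, silences removed
#   durations = data['durations']    # per-phoneme durations in frames
#   audio = inv_scale(data['audio_features'], norm['mean'], np.sqrt(norm['var']))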