Last active: November 27, 2024, 21:20
Save pzelasko/f63820a5823ece287a2212ebaf56cf1c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from io import BytesIO

import soundfile as sf
import torch

from lhotse import AudioSource, CutSet, Recording, SupervisionSegment
from lhotse.shar import AudioTarWriter
from lhotse.shar.utils import to_shar_placeholder
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
def determine_shard_name(shard: str) -> str:
    """Derive a shard name from a manifest path or URL.

    E.g. ``"s3://path/to/manifest_0.json"`` -> ``"manifest_0"``.

    Takes the final path component and strips everything from the first
    dot onward, so double extensions like ``.jsonl.gz`` are handled too.
    """
    basename = shard.rstrip("/").rsplit("/", 1)[-1]
    return basename.split(".", 1)[0]


def main_text_to_audio(input_shards: list[str], output_prefix: str, model_path: str):
    """Synthesize audio for every input manifest shard and write Lhotse shar-style outputs.

    For each input shard, writes a ``cuts.00000.jsonl.gz`` manifest and a
    ``recording.00000.tar`` audio tarball under ``{output_prefix}/{shard_name}/``.

    PATHS OR URLS:
        input_shards = ["path1", "path2", "path3"]
        input_shards = ["s3://path1", "s3://path2", "s3://path3"]
    URL:
        output_prefix = "s3://..."
    LOCAL DISK:
        output_prefix = "/path/to/"
    """
    model = ...  # TODO: load the TTS model from model_path
    for shard in input_shards:
        # shard = "/path/to/manifest_0.json" or "s3://path/to/manifest_0.json"
        dl = build_t5_dataloader_for_text_to_speech(shard, batch_size=4, **opts)
        # "s3://path/to/manifest_0.json" -> "manifest_0"
        shard_name = determine_shard_name(shard)
        cuts_output_path = f"{output_prefix}/{shard_name}/cuts.00000.jsonl.gz"
        audio_output_path = f"{output_prefix}/{shard_name}/recording.00000.tar"
        # shard_size=None: keep one tar per input shard instead of re-sharding.
        with CutSet.open_writer(cuts_output_path) as cuts_writer, AudioTarWriter(
            audio_output_path, shard_size=None, format="flac"
        ) as audio_writer:
            # TODO: ensure that the order of examples in mini-batch and across
            # mini-batches is the same as in the input shard
            for batch in dl:
                # TODO: return both audios and text from predict_step,
                # avoid writing directly in predict_step
                outputs = model.predict_step(batch)
                text: str
                audio: torch.Tensor
                for example_id, text, audio in zip(outputs["ids"], outputs["texts"], outputs["audios"]):
                    recording = Recording(
                        id=example_id,
                        # "shar" source with an empty path: the actual samples live
                        # in the tar written by audio_writer; this is a placeholder.
                        sources=[AudioSource(type="shar", channels=[0], source="")],
                        sampling_rate=model.sample_rate,
                        num_samples=audio.shape[1],  # assumes audio is (1, N_samples) -- TODO confirm
                        duration=audio.shape[1] / model.sample_rate,
                    )
                    audio_writer.write(
                        key=example_id,
                        value=audio,
                        sampling_rate=model.sample_rate,
                        manifest=recording,
                    )
                    cut = recording.to_cut()
                    cut.supervisions = [
                        SupervisionSegment(
                            id=example_id,
                            recording_id=example_id,
                            start=0,
                            duration=cut.duration,
                            text=text,
                            language=...,  # TODO: fill in
                            speaker=...,  # TODO: fill in
                        )
                    ]
                    cuts_writer.write(cut)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.