diff --git a/README.md b/README.md index 024a4f0e..003db710 100644 --- a/README.md +++ b/README.md @@ -25,15 +25,11 @@ - **TTM**: Text to Music (👨‍💻 developing) - more… -In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. - -Here is the Amphion v0.1 demo, whose voice, audio effects, and singing voice are generated by our models. Just enjoy it! - -[amphion-v0.1-en](https://github.com/open-mmlab/Amphion/assets/24860155/7fcdcea5-3d95-4b31-bd93-4b4da734ef9b -) +In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis. ## 🚀 News -- **2024/6/17**: Amphion has a new release for its VALL-E models, it uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md) +- **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md) +- **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md) - **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md) - **2024/02/22**: The first Amphion visualization tool, **SingVisio**, release. 
[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://github.com/open-mmlab/Amphion/assets/33707885/0a6e39e8-d5f1-4288-b0f8-32da5a2d6e96) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/visualization/SingVisio/README.md)
- **2023/12/18**: Amphion v0.1 release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2312.09911) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink)](https://huggingface.co/amphion) [![youtube](https://img.shields.io/badge/YouTube-Demo-red)](https://www.youtube.com/watch?v=1aw0HhcggvQ) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/39)
@@ -79,7 +75,8 @@ Amphion provides a comprehensive objective evaluation of the generated audio. Th
### Datasets
-Amphion unifies the data preprocess of the open-source datasets including [AudioCaps](https://audiocaps.github.io/), [LibriTTS](https://www.openslr.org/60/), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [M4Singer](https://github.com/M4Singer/M4Singer), [Opencpop](https://wenet.org.cn/opencpop/), [OpenSinger](https://github.com/Multi-Singer/Multi-Singer.github.io), [SVCC](http://vc-challenge.org/), [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), and more. The supported dataset list can be seen [here](egs/datasets/README.md) (updating).
+- Amphion unifies the data preprocess of the open-source datasets including [AudioCaps](https://audiocaps.github.io/), [LibriTTS](https://www.openslr.org/60/), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [M4Singer](https://github.com/M4Singer/M4Singer), [Opencpop](https://wenet.org.cn/opencpop/), [OpenSinger](https://github.com/Multi-Singer/Multi-Singer.github.io), [SVCC](http://vc-challenge.org/), [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), and more. The supported dataset list can be seen [here](egs/datasets/README.md) (updating).
+- Amphion (exclusively) supports the [**Emilia**](preprocessors/Emilia/README.md) dataset and its preprocessing pipeline **Emilia-Pipe** for in-the-wild speech data!
### Visualization
diff --git a/preprocessors/Emilia/README.md b/preprocessors/Emilia/README.md
new file mode 100644
index 00000000..63b7975d
--- /dev/null
+++ b/preprocessors/Emilia/README.md
@@ -0,0 +1,165 @@
+## Emilia: An Extensive, Multilingual, and Diverse Speech Dataset for Large-Scale Speech Generation
+[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361)
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia)
+[![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/)
+
+This is the official repository 👑 for the **Emilia** dataset and the source code of the **Emilia-Pipe** speech data preprocessing pipeline.
+
+## News 🔥
+- **2024/07/08**: Our preprint [paper](https://arxiv.org/abs/2407.05361) is now available! 🔥🔥🔥
+- **2024/07/03**: We welcome everyone to check our [homepage](https://emilia-dataset.github.io/Emilia-Demo-Page/) for a brief introduction to the Emilia dataset and our demos!
+- **2024/07/01**: We release Emilia and Emilia-Pipe! We welcome everyone to explore them! 🎉🎉🎉
+
+## About ⭐️
+🎤 **Emilia** is a comprehensive, multilingual dataset with the following features:
+- containing over *101k* hours of speech data;
+- covering six different languages: *English (En), Chinese (Zh), German (De), French (Fr), Japanese (Ja), and Korean (Ko)*;
+- containing diverse speech data with *various speaking styles*.
+
+A detailed description of the dataset can be found in our paper.
+
+🛠️ **Emilia-Pipe** is the first open-source preprocessing pipeline designed to transform raw, in-the-wild speech data into high-quality training data with annotations for speech generation. This pipeline can process one hour of raw audio into model-ready data in just a few minutes, requiring only the URLs of the audio or video sources.
+
+*To use the Emilia dataset, you can download the raw audio files from the [provided URL list](https://huggingface.co/datasets/amphion/Emilia) and use our open-source [Emilia-Pipe](https://github.com/open-mmlab/Amphion/tree/main/preprocessors/Emilia) preprocessing pipeline to preprocess the raw data and rebuild the dataset. Please note that Emilia doesn't own the copyright to the audio; the copyright remains with the original owners of the video or audio. Additionally, users can easily use Emilia-Pipe to preprocess their own raw speech data for custom needs.*
+
+By open-sourcing the Emilia-Pipe code, we aim to enable the speech community to collaborate on large-scale speech generation research.
+
+The following sections introduce the installation and usage of Emilia-Pipe.
+
+## Pipeline Overview 👀
+
+The Emilia-Pipe includes the following major steps:
+
+0. Standardization: Audio normalization
+1. Source Separation: Long audio -> Long audio without BGM
+2. Speaker Diarization: Get medium-length single-speaker speech data
+3. Fine-grained Segmentation by VAD: Get 3-30s single-speaker speech segments
+4. ASR: Get transcriptions of the speech segments
+5. Filtering: Obtain the final processed dataset
+
+## Setup Steps 👨‍💻
+
+### 0. Prepare Environment
+
+1. Install Python and CUDA.
+2. Run the following commands to install the required packages:
+
+    ```bash
+    conda create -y -n AudioPipeline python=3.9
+    conda activate AudioPipeline
+
+    bash env.sh
+    ```
+
+3. Download the model files from the third-party repositories.
+    - Manually download the checkpoints of UVR-MDX-NET-Inst_HQ_3 ([UVR-MDX-NET-Inst_HQ_3.onnx](https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Inst_HQ_3.onnx)) and DNSMOS P.835 ([sig_bak_ovr.onnx](https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx)), then save their paths for the next step's configuration (i.e. TODOs #2 and #3).
+    - Create an access token for pyannote/speaker-diarization-3.1 following [the guide](https://huggingface.co/pyannote/speaker-diarization-3.1#requirements), then save it for the next step's configuration (i.e. TODO #4).
+    - Make sure you have a stable connection to GitHub and Hugging Face. The Silero and WhisperX-medium checkpoints will be downloaded automatically on the pipeline's first run.
+
+
+### 1. Modify Config File
+
+Edit the `config.json` file according to the following TODOs.
+
+```json
+{
+    "language": {
+        "multilingual": true,
+        "supported": [
+            "zh",
+            "en",
+            "fr",
+            "ja",
+            "ko",
+            "de"
+        ]
+    },
+    "entrypoint": {
+        // TODO: Fill in the input_folder_path.
+        "input_folder_path": "examples", // #1: Data input folder for processing
+        "SAMPLE_RATE": 24000
+    },
+    "separate": {
+        "step1": {
+            // TODO: Fill in the source separation model's path.
+            "model_path": "/path/to/model/separate_model/UVR-MDX-NET-Inst_HQ_3.onnx", // #2: Model path
+            "denoise": true,
+            "margin": 44100,
+            "chunks": 15,
+            "n_fft": 6144,
+            "dim_t": 8,
+            "dim_f": 3072
+        }
+    },
+    "mos_model": {
+        // TODO: Fill in the DNSMOS prediction model's path.
+        "primary_model_path": "/path/to/model/mos_model/DNSMOS/sig_bak_ovr.onnx" // #3: Model path
+    },
+    // TODO: Fill in your Hugging Face access token for pyannote.
+    "huggingface_token": "" // #4: Hugging Face access token for pyannote
+}
+```
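+
+After filling in the TODOs above, it can help to sanity-check the edited config before a long run. The snippet below is a minimal sketch (not part of the pipeline, and the script name is only illustrative) that reuses the pipeline's own `load_cfg` helper from `utils/tool.py`, verifies that the two ONNX checkpoint paths exist, and mirrors the token check that `main.py` performs at startup. Run it from the `preprocessors/Emilia` directory.
+
+```python
+# sanity_check_config.py -- hypothetical helper, run from preprocessors/Emilia
+import os
+
+from utils.tool import load_cfg  # the same config loader used by main.py
+
+cfg = load_cfg("config.json")
+
+# TODOs #2 and #3: the downloaded ONNX checkpoints must exist at the configured paths.
+for path in (
+    cfg["separate"]["step1"]["model_path"],
+    cfg["mos_model"]["primary_model_path"],
+):
+    assert os.path.isfile(path), f"Checkpoint not found: {path}"
+
+# TODO #4: main.py refuses to start unless the pyannote token begins with "hf".
+assert cfg["huggingface_token"].startswith("hf"), "Invalid Hugging Face access token"
+
+print("config.json looks good.")
+```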
+ "input_folder_path": "examples", // #1: Data input folder for processing + "SAMPLE_RATE": 24000 + }, + "separate": { + "step1": { + // TODO: Fill in the source separation model's path. + "model_path": "/path/to/model/separate_model/UVR-MDX-NET-Inst_HQ_3.onnx", // #2: Model path + "denoise": true, + "margin": 44100, + "chunks": 15, + "n_fft": 6144, + "dim_t": 8, + "dim_f": 3072 + } + }, + "mos_model": { + // TODO: Fill in the DNSMOS prediction model's path. + "primary_model_path": "/path/to/model/mos_model/DNSMOS/sig_bak_ovr.onnx" // #3: Model path + }, + // TODO: Fill in your huggingface access token for pynannote. + "huggingface_token": "" // #4: Huggingface access token for pyannote +} +``` + +### 2. Run Script + +1. Change the `input_folder_path` in `config.json` to the folder path where the downloaded audio files are stored (i.e. #1 TODO). +2. Run the following command to process the audio files: + +```bash +conda activate AudioPipeline +export CUDA_VISIBLE_DEVICES=0 # Setting the GPU to run the pipeline, separate by comma + +python main.py +``` + +3. Processed audio will be saved into `input_folder_path`_processed folder. + + +### 3. Check the Results + +The processed audio (default 24k sample rate) files will be saved into `input_folder_path`_processed folder. The results for a single audio will be saved in a same folder with its original name and include the following information: + +1. **MP3 file**: `_.mp3` where `idx` is corresponding to the index in the JSON-encoded array. +2. **JSON file**: `.json` + +```json +[ + { + "text": "So, don't worry about that. But, like for instance, like yesterday was very hard for me to say, you know what, I should go to bed.", // Transcription + "start": 67.18, // Start timestamp, in second unit + "end": 74.41, // End timestamp, in second unit + "language": "en", // Language + "dnsmos": 3.44 // DNSMOS P.835 score + } +] +``` + +## Acknowledgement 🔔 +We acknowledge the wonderful work by these excellent developers! 
+
+## Acknowledgement 🔔
+We acknowledge the wonderful work by these excellent developers!
+- Source Separation: [UVR-MDX-NET-Inst_HQ_3](https://github.com/TRvlvr/model_repo/releases/tag/all_public_uvr_models)
+- VAD: [snakers4/silero-vad](https://github.com/snakers4/silero-vad)
+- Speaker Diarization: [pyannote/pyannote-audio](https://github.com/pyannote/pyannote-audio)
+- ASR: [m-bain/whisperX](https://github.com/m-bain/whisperX)
+- DNSMOS Prediction: [DNSMOS P.835](https://github.com/microsoft/DNS-Challenge)
+
+
+## Reference 📖
+If you use the Emilia dataset or the Emilia-Pipe pipeline, please cite the following papers:
+```bibtex
+@article{emilia,
+    title={Emilia: An Extensive, Multilingual, and Diverse Speech Dataset for Large-Scale Speech Generation},
+    author={He, Haorui and Shang, Zengqiang and Wang, Chaoren and Li, Xuyuan and Gu, Yicheng and Hua, Hua and Liu, Liwei and Yang, Chen and Li, Jiaqi and Shi, Peiyang and Wang, Yuancheng and Chen, Kai and Zhang, Pengyuan and Wu, Zhizheng},
+    journal={arXiv},
+    volume={abs/2407.05361},
+    year={2024}
+}
+```
+```bibtex
+@article{amphion,
+    title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
+    author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and He, Haorui and Wang, Chaoren and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
+    journal={arXiv},
+    volume={abs/2312.09911},
+    year={2024}
+}
+```
diff --git a/preprocessors/Emilia/config.json b/preprocessors/Emilia/config.json
new file mode 100755
index 00000000..bf5da332
--- /dev/null
+++ b/preprocessors/Emilia/config.json
@@ -0,0 +1,35 @@
+{
+    "language": {
+        "multilingual": true,
+        "supported": [
+            "zh",
+            "en",
+            "fr",
+            "ja",
+            "ko",
+            "de"
+        ]
+    },
+    "entrypoint": {
+        // TODO: Fill in the input_folder_path.
+        "input_folder_path": "examples",
+        "SAMPLE_RATE": 24000
+    },
+    "separate": {
+        "step1": {
+            // TODO: Fill in the source separation model's path.
+            "model_path": "/path/to/model/separate_model/UVR-MDX-NET-Inst_HQ_3.onnx",
+            "denoise": true,
+            "margin": 44100,
+            "chunks": 15,
+            "n_fft": 6144,
+            "dim_t": 8,
+            "dim_f": 3072
+        }
+    },
+    "mos_model": {
+        // TODO: Fill in the DNSMOS prediction model's path.
+        "primary_model_path": "/path/to/model/mos_model/DNSMOS/sig_bak_ovr.onnx"
+    },
+    "huggingface_token": ""
+}
\ No newline at end of file
diff --git a/preprocessors/Emilia/env.sh b/preprocessors/Emilia/env.sh
new file mode 100644
index 00000000..bbc4b1d2
--- /dev/null
+++ b/preprocessors/Emilia/env.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+conda install ffmpeg -y
+conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
+pip install -r requirements.txt
+pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
diff --git a/preprocessors/Emilia/main.py b/preprocessors/Emilia/main.py
new file mode 100755
index 00000000..a1663c31
--- /dev/null
+++ b/preprocessors/Emilia/main.py
@@ -0,0 +1,571 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import argparse +import json +import librosa +import numpy as np +import sys +import os +import tqdm +import warnings +import torch +from pydub import AudioSegment +from pyannote.audio import Pipeline +import pandas as pd + +from utils.tool import ( + export_to_mp3, + load_cfg, + get_audio_files, + detect_gpu, + check_env, + calculate_audio_stats, +) +from utils.logger import Logger, time_logger +from models import separate_fast, dnsmos, whisper_asr, silero_vad + +warnings.filterwarnings("ignore") +audio_count = 0 + + +@time_logger +def standardization(audio): + """ + Preprocess the audio file, including setting sample rate, bit depth, channels, and volume normalization. + + Args: + audio (str or AudioSegment): Audio file path or AudioSegment object, the audio to be preprocessed. + + Returns: + dict: A dictionary containing the preprocessed audio waveform, audio file name, and sample rate, formatted as: + { + "waveform": np.ndarray, the preprocessed audio waveform, dtype is np.float32, shape is (num_samples,) + "name": str, the audio file name + "sample_rate": int, the audio sample rate + } + + Raises: + ValueError: If the audio parameter is neither a str nor an AudioSegment. + """ + global audio_count + name = "audio" + + if isinstance(audio, str): + name = os.path.basename(audio) + audio = AudioSegment.from_file(audio) + elif isinstance(audio, AudioSegment): + name = f"audio_{audio_count}" + audio_count += 1 + else: + raise ValueError("Invalid audio type") + + logger.debug("Entering the preprocessing of audio") + + # Convert the audio file to WAV format + audio = audio.set_frame_rate(cfg["entrypoint"]["SAMPLE_RATE"]) + audio = audio.set_sample_width(2) # Set bit depth to 16bit + audio = audio.set_channels(1) # Set to mono + + logger.debug("Audio file converted to WAV format") + + # Calculate the gain to be applied + target_dBFS = -20 + gain = target_dBFS - audio.dBFS + logger.info(f"Calculating the gain needed for the audio: {gain} dB") + + # Normalize volume and limit gain range to between -3 and 3 + normalized_audio = audio.apply_gain(min(max(gain, -3), 3)) + + waveform = np.array(normalized_audio.get_array_of_samples(), dtype=np.float32) + max_amplitude = np.max(np.abs(waveform)) + waveform /= max_amplitude # Normalize + + logger.debug(f"waveform shape: {waveform.shape}") + logger.debug("waveform in np ndarray, dtype=" + str(waveform.dtype)) + + return { + "waveform": waveform, + "name": name, + "sample_rate": cfg["entrypoint"]["SAMPLE_RATE"], + } + + +@time_logger +def source_separation(predictor, audio): + """ + Separate the audio into vocals and non-vocals using the given predictor. + + Args: + predictor: The separation model predictor. + audio (str or dict): The audio file path or a dictionary containing audio waveform and sample rate. + + Returns: + dict: A dictionary containing the separated vocals and updated audio waveform. 
+ """ + + mix, rate = None, None + + if isinstance(audio, str): + mix, rate = librosa.load(audio, mono=False, sr=44100) + else: + # resample to 44100 + rate = audio["sample_rate"] + mix = librosa.resample(audio["waveform"], orig_sr=rate, target_sr=44100) + + vocals, no_vocals = predictor.predict(mix) + + # convert vocals back to previous sample rate + logger.debug(f"vocals shape before resample: {vocals.shape}") + vocals = librosa.resample(vocals.T, orig_sr=44100, target_sr=rate).T + logger.debug(f"vocals shape after resample: {vocals.shape}") + audio["waveform"] = vocals[:, 0] # vocals is stereo, only use one channel + + return audio + + +# Step 2: Speaker Diarization +@time_logger +def speaker_diarization(audio): + """ + Perform speaker diarization on the given audio. + + Args: + audio (dict): A dictionary containing the audio waveform and sample rate. + + Returns: + pd.DataFrame: A dataframe containing segments with speaker labels. + """ + logger.debug(f"Start speaker diarization") + logger.debug(f"audio waveform shape: {audio['waveform'].shape}") + + waveform = torch.tensor(audio["waveform"]).to(device) + waveform = torch.unsqueeze(waveform, 0) + + segments = dia_pipeline( + { + "waveform": waveform, + "sample_rate": audio["sample_rate"], + "channel": 0, + } + ) + + diarize_df = pd.DataFrame( + segments.itertracks(yield_label=True), + columns=["segment", "label", "speaker"], + ) + diarize_df["start"] = diarize_df["segment"].apply(lambda x: x.start) + diarize_df["end"] = diarize_df["segment"].apply(lambda x: x.end) + + logger.debug(f"diarize_df: {diarize_df}") + + return diarize_df + + +@time_logger +def cut_by_speaker_label(vad_list): + """ + Merge and trim VAD segments by speaker labels, enforcing constraints on segment length and merge gaps. + + Args: + vad_list (list): List of VAD segments with start, end, and speaker labels. + + Returns: + list: A list of updated VAD segments after merging and trimming. 
+ """ + MERGE_GAP = 2 # merge gap in seconds, if smaller than this, merge + MIN_SEGMENT_LENGTH = 3 # min segment length in seconds + MAX_SEGMENT_LENGTH = 30 # max segment length in seconds + + updated_list = [] + + for idx, vad in enumerate(vad_list): + last_start_time = updated_list[-1]["start"] if updated_list else None + last_end_time = updated_list[-1]["end"] if updated_list else None + last_speaker = updated_list[-1]["speaker"] if updated_list else None + + if vad["end"] - vad["start"] >= MAX_SEGMENT_LENGTH: + current_start = vad["start"] + segment_end = vad["end"] + logger.warning( + f"cut_by_speaker_label > segment longer than 30s, force trimming to 30s smaller segments" + ) + while segment_end - current_start >= MAX_SEGMENT_LENGTH: + vad["end"] = current_start + MAX_SEGMENT_LENGTH # update end time + updated_list.append(vad) + vad = vad.copy() + current_start += MAX_SEGMENT_LENGTH + vad["start"] = current_start # update start time + vad["end"] = segment_end + updated_list.append(vad) + continue + + if ( + last_speaker is None + or last_speaker != vad["speaker"] + or vad["end"] - vad["start"] >= MIN_SEGMENT_LENGTH + ): + updated_list.append(vad) + continue + + if ( + vad["start"] - last_end_time >= MERGE_GAP + or vad["end"] - last_start_time >= MAX_SEGMENT_LENGTH + ): + updated_list.append(vad) + else: + updated_list[-1]["end"] = vad["end"] # merge the time + + logger.debug( + f"cut_by_speaker_label > merged {len(vad_list) - len(updated_list)} segments" + ) + + filter_list = [ + vad for vad in updated_list if vad["end"] - vad["start"] >= MIN_SEGMENT_LENGTH + ] + + logger.debug( + f"cut_by_speaker_label > removed: {len(updated_list) - len(filter_list)} segments by length" + ) + + return filter_list + + +@time_logger +def asr(vad_segments, audio): + """ + Perform Automatic Speech Recognition (ASR) on the VAD segments of the given audio. + + Args: + vad_segments (list): List of VAD segments with start and end times. + audio (dict): A dictionary containing the audio waveform and sample rate. + + Returns: + list: A list of ASR results with transcriptions and language details. + """ + if len(vad_segments) == 0: + return [] + + temp_audio = audio["waveform"] + start_time = vad_segments[0]["start"] + end_time = vad_segments[-1]["end"] + start_frame = int(start_time * audio["sample_rate"]) + end_frame = int(end_time * audio["sample_rate"]) + temp_audio = temp_audio[start_frame:end_frame] # remove silent start and end + + # update vad_segments start and end time (this is a little trick for batched asr:) + for idx, segment in enumerate(vad_segments): + vad_segments[idx]["start"] -= start_time + vad_segments[idx]["end"] -= start_time + + # resample to 16k + temp_audio = librosa.resample( + temp_audio, orig_sr=audio["sample_rate"], target_sr=16000 + ) + + if multilingual_flag: + logger.debug("Multilingual flag is on") + valid_vad_segments, valid_vad_segments_language = [], [] + # get valid segments to be transcripted + for idx, segment in enumerate(vad_segments): + start_frame = int(segment["start"] * 16000) + end_frame = int(segment["end"] * 16000) + segment_audio = temp_audio[start_frame:end_frame] + language, prob = asr_model.detect_language(segment_audio) + # 1. if language is in supported list, 2. 
if prob > 0.8 + if language in supported_languages and prob > 0.8: + valid_vad_segments.append(vad_segments[idx]) + valid_vad_segments_language.append(language) + + # if no valid segment, return empty + if len(valid_vad_segments) == 0: + return [] + all_transcribe_result = [] + logger.debug(f"valid_vad_segments_language: {valid_vad_segments_language}") + unique_languages = list(set(valid_vad_segments_language)) + logger.debug(f"unique_languages: {unique_languages}") + # process each language one by one + for language_token in unique_languages: + language = language_token + # filter out segments with different language + vad_segments = [ + valid_vad_segments[i] + for i, x in enumerate(valid_vad_segments_language) + if x == language + ] + # bacthed trascription + transcribe_result_temp = asr_model.transcribe( + temp_audio, + vad_segments, + batch_size=batch_size, + language=language, + print_progress=True, + ) + result = transcribe_result_temp["segments"] + # restore the segment annotation + for idx, segment in enumerate(result): + result[idx]["start"] += start_time + result[idx]["end"] += start_time + result[idx]["language"] = transcribe_result_temp["language"] + all_transcribe_result.extend(result) + # sort by start time + all_transcribe_result = sorted(all_transcribe_result, key=lambda x: x["start"]) + return all_transcribe_result + else: + logger.debug("Multilingual flag is off") + language, prob = asr_model.detect_language(temp_audio) + if language in supported_languages and prob > 0.8: + transcribe_result = asr_model.transcribe( + temp_audio, + vad_segments, + batch_size=batch_size, + language=language, + print_progress=True, + ) + result = transcribe_result["segments"] + for idx, segment in enumerate(result): + result[idx]["start"] += start_time + result[idx]["end"] += start_time + result[idx]["language"] = transcribe_result["language"] + return result + else: + return [] + + +@time_logger +def mos_prediction(audio, vad_list): + """ + Predict the Mean Opinion Score (MOS) for the given audio and VAD segments. + + Args: + audio (dict): A dictionary containing the audio waveform and sample rate. + vad_list (list): List of VAD segments with start and end times. + + Returns: + tuple: A tuple containing the average MOS and the updated VAD segments with MOS scores. + """ + audio = audio["waveform"] + sample_rate = 16000 + + audio = librosa.resample( + audio, orig_sr=cfg["entrypoint"]["SAMPLE_RATE"], target_sr=sample_rate + ) + + for index, vad in enumerate(tqdm.tqdm(vad_list, desc="DNSMOS")): + start, end = int(vad["start"] * sample_rate), int(vad["end"] * sample_rate) + segment = audio[start:end] + + dnsmos = dnsmos_compute_score(segment, sample_rate, False)["OVRL"] + + vad_list[index]["dnsmos"] = dnsmos + + predict_dnsmos = np.mean([vad["dnsmos"] for vad in vad_list]) + + logger.debug(f"avg predict_dnsmos for whole audio: {predict_dnsmos}") + + return predict_dnsmos, vad_list + + +def filter(mos_list): + """ + Filter out the segments with MOS scores, wrong char duration, and total duration. + + Args: + mos_list (list): List of VAD segments with MOS scores. + + Returns: + list: A list of VAD segments with MOS scores above the average MOS. + """ + filtered_audio_stats, all_audio_stats = calculate_audio_stats(mos_list) + filtered_segment = len(filtered_audio_stats) + all_segment = len(all_audio_stats) + logger.debug( + f"> {all_segment - filtered_segment}/{all_segment} {(all_segment - filtered_segment) / all_segment:.2%} segments filtered." 
+ ) + filtered_list = [mos_list[idx] for idx, _ in filtered_audio_stats] + return filtered_list + + +def main_process(audio_path, save_path=None, audio_name=None): + """ + Process the audio file, including standardization, source separation, speaker segmentation, VAD, ASR, export to MP3, and MOS prediction. + + Args: + audio_path (str): Audio file path. + save_path (str, optional): Save path, defaults to None, which means saving in the "_processed" folder in the audio file's directory. + audio_name (str, optional): Audio file name, defaults to None, which means using the file name from the audio file path. + + Returns: + tuple: Contains the save path and the MOS list. + """ + if not audio_path.endswith((".mp3", ".wav", ".flac", ".m4a", ".aac")): + logger.warning(f"Unsupported file type: {audio_path}") + + # for a single audio from path Ïaaa/bbb/ccc.wav ---> save to aaa/bbb_processed/ccc/ccc_0.wav + audio_name = audio_name or os.path.splitext(os.path.basename(audio_path))[0] + save_path = save_path or os.path.join( + os.path.dirname(audio_path) + "_processed", audio_name + ) + os.makedirs(save_path, exist_ok=True) + logger.debug( + f"Processing audio: {audio_name}, from {audio_path}, save to: {save_path}" + ) + + logger.info( + "Step 0: Preprocess all audio files --> 24k sample rate + wave format + loudnorm + bit depth 16" + ) + audio = standardization(audio_path) + + logger.info("Step 1: Source Separation") + audio = source_separation(separate_predictor1, audio) + + logger.info("Step 2: Speaker Diarization") + speakerdia = speaker_diarization(audio) + + logger.info("Step 3: Fine-grained Segmentation by VAD") + vad_list = vad.vad(speakerdia, audio) + segment_list = cut_by_speaker_label(vad_list) # post process after vad + + logger.info("Step 4: ASR") + asr_result = asr(segment_list, audio) + + logger.info("Step 5: Filter") + logger.info("Step 5.1: calculate mos_prediction") + avg_mos, mos_list = mos_prediction(audio, asr_result) + + logger.info(f"Step 5.1: done, average MOS: {avg_mos}") + + logger.info("Step 5.2: Filter out files with less than average MOS") + filtered_list = filter(mos_list) + + logger.info("Step 6: write result into MP3 and JSON file") + export_to_mp3(audio, filtered_list, save_path, audio_name) + + final_path = os.path.join(save_path, audio_name + ".json") + with open(final_path, "w") as f: + json.dump(filtered_list, f, ensure_ascii=False) + + logger.info(f"All done, Saved to: {final_path}") + return final_path, filtered_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_folder_path", + type=str, + default="", + help="input folder path, this will override config if set", + ) + parser.add_argument( + "--config_path", type=str, default="config.json", help="config path" + ) + parser.add_argument("--batch_size", type=int, default=16, help="batch size") + parser.add_argument( + "--compute_type", + type=str, + default="float16", + help="The compute type to use for the model", + ) + parser.add_argument( + "--whisper_arch", + type=str, + default="medium", + help="The name of the Whisper model to load.", + ) + parser.add_argument( + "--threads", + type=int, + default=4, + help="The number of CPU threads to use per worker, e.g. 
will be multiplied by num workers.", + ) + parser.add_argument( + "--exit_pipeline", + type=bool, + default=False, + help="Exit pipeline when task done.", + ) + args = parser.parse_args() + + batch_size = args.batch_size + cfg = load_cfg(args.config_path) + + logger = Logger.get_logger() + + if args.input_folder_path: + logger.info(f"Using input folder path: {args.input_folder_path}") + cfg["entrypoint"]["input_folder_path"] = args.input_folder_path + + logger.debug("Loading models...") + + # Load models + if detect_gpu(): + logger.info("Using GPU") + device_name = "cuda" + device = torch.device(device_name) + else: + logger.info("Using CPU") + device_name = "cpu" + device = torch.device(device_name) + + check_env(logger) + + # Speaker Diarization + logger.debug(" * Loading Speaker Diarization Model") + if not cfg["huggingface_token"].startswith("hf"): + raise ValueError( + "huggingface_token must start with 'hf', check the config file. " + "You can get the token at https://huggingface.co/settings/tokens. " + "Remeber grant access following https://github.com/pyannote/pyannote-audio?tab=readme-ov-file#tldr" + ) + dia_pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=cfg["huggingface_token"], + ) + dia_pipeline.to(device) + + # ASR + logger.debug(" * Loading ASR Model") + asr_model = whisper_asr.load_asr_model( + args.whisper_arch, + device_name, + compute_type=args.compute_type, + threads=args.threads, + asr_options={ + "initial_prompt": "Um, Uh, Ah. Like, you know. I mean, right. Actually. Basically, and right? okay. Alright. Emm. So. Oh. 生于忧患,死于安乐。岂不快哉?当然,嗯,呃,就,这样,那个,哪个,啊,呀,哎呀,哎哟,唉哇,啧,唷,哟,噫!微斯人,吾谁与归?ええと、あの、ま、そう、ええ。äh, hm, so, tja, halt, eigentlich. euh, quoi, bah, ben, tu vois, tu sais, t'sais, eh bien, du coup. genre, comme, style. 응,어,그,음." 
+ }, + ) + + # VAD + logger.debug(" * Loading VAD Model") + vad = silero_vad.SileroVAD(device=device) + + # Background Noise Separation + logger.debug(" * Loading Background Noise Model") + separate_predictor1 = separate_fast.Predictor( + args=cfg["separate"]["step1"], device=device_name + ) + + # DNSMOS Scoring + logger.debug(" * Loading DNSMOS Model") + primary_model_path = cfg["mos_model"]["primary_model_path"] + dnsmos_compute_score = dnsmos.ComputeScore(primary_model_path, device_name) + logger.debug("All models loaded") + + supported_languages = cfg["language"]["supported"] + multilingual_flag = cfg["language"]["multilingual"] + logger.debug(f"supported languages multilingual {supported_languages}") + logger.debug(f"using multilingual asr {multilingual_flag}") + + input_folder_path = cfg["entrypoint"]["input_folder_path"] + + if not os.path.exists(input_folder_path): + raise FileNotFoundError(f"input_folder_path: {input_folder_path} not found") + + audio_paths = get_audio_files(input_folder_path) # Get all audio files + logger.debug(f"Scanning {len(audio_paths)} audio files in {input_folder_path}") + + for path in audio_paths: + main_process(path) diff --git a/preprocessors/Emilia/models/__init__.py b/preprocessors/Emilia/models/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/preprocessors/Emilia/models/dnsmos.py b/preprocessors/Emilia/models/dnsmos.py new file mode 100755 index 00000000..7b17f196 --- /dev/null +++ b/preprocessors/Emilia/models/dnsmos.py @@ -0,0 +1,174 @@ +# Source: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS +# +# Copyright (c) 2022 Microsoft +# +# This code is licensed under the Creative Commons Attribution 4.0 International (CC BY 4.0) license. +# The full license text is available at the root of the source repository. +# +# Note: This code has been modified to fit the context of this repository. +# This code is included in an MIT-licensed repository. +# The repository's MIT license does not apply to this code. + +import os +import librosa +import numpy as np +import onnxruntime as ort +import pandas as pd +import tqdm +import warnings + + +warnings.filterwarnings("ignore") + +SAMPLING_RATE = 16000 +INPUT_LENGTH = 9.01 + + +class ComputeScore: + """ + ComputeScore class for evaluating DNSMOS. + """ + + def __init__(self, primary_model_path, device="cpu") -> None: + """ + Initialize the ComputeScore object. + + Args: + primary_model_path (str): Path to the primary model. + device (str): Device to run the models on ('cpu' or 'cuda'). + + Returns: + None + + Raises: + RuntimeError: If the device is not supported. + """ + if device == "cuda": + self.onnx_sess = ort.InferenceSession( + primary_model_path, providers=["CUDAExecutionProvider"] + ) + print("Using CUDA:", self.onnx_sess.get_providers()) + else: + self.onnx_sess = ort.InferenceSession(primary_model_path) + + def audio_melspec( + self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True + ): + """ + Compute the mel spectrogram of an audio signal. + + Args: + audio (np.ndarray): Input audio signal. + n_mels (int): Number of mel bands. + frame_size (int): Size of the FFT window. + hop_length (int): Number of samples between successive frames. + sr (int): Sampling rate. + to_db (bool): Whether to convert the power spectrogram to decibel units. + + Returns: + np.ndarray: Mel spectrogram. 
+ """ + mel_spec = librosa.feature.melspectrogram( + y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels + ) + if to_db: + mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40 + return mel_spec.T + + def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS): + """ + Apply polynomial fitting to MOS scores. + + Args: + sig (float): Signal MOS score. + bak (float): Background MOS score. + ovr (float): Overall MOS score. + is_personalized_MOS (bool): Flag for personalized MOS. + + Returns: + tuple: Tuple containing the adjusted signal, background, and overall MOS scores. + """ + if is_personalized_MOS: + p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) + p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) + p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) + else: + p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) + p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) + p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) + + sig_poly = p_sig(sig) + bak_poly = p_bak(bak) + ovr_poly = p_ovr(ovr) + + return sig_poly, bak_poly, ovr_poly + + def __call__(self, audio, sampling_rate, is_personalized_MOS): + """ + Compute DNSMOS scores for an audio signal. + + Args: + audio (np.ndarray or str): Input audio signal or path to audio file. + sampling_rate (int): Sampling rate of the input audio. + is_personalized_MOS (bool): Flag for personalized MOS. + + Returns: + dict: Dictionary containing MOS scores. + + Raises: + ValueError: If the input audio is not valid. + """ + fs = SAMPLING_RATE + if isinstance(audio, str): + audio, _ = librosa.load(audio, sr=fs) + elif sampling_rate != fs: + # resample audio + audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=fs) + + actual_audio_len = len(audio) + + len_samples = int(INPUT_LENGTH * fs) + while len(audio) < len_samples: + audio = np.append(audio, audio) + + num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1 + hop_len_samples = fs + predicted_mos_sig_seg_raw = [] + predicted_mos_bak_seg_raw = [] + predicted_mos_ovr_seg_raw = [] + predicted_mos_sig_seg = [] + predicted_mos_bak_seg = [] + predicted_mos_ovr_seg = [] + + for idx in range(num_hops): + audio_seg = audio[ + int(idx * hop_len_samples) : int((idx + INPUT_LENGTH) * hop_len_samples) + ] + if len(audio_seg) < len_samples: + continue + input_features = np.array(audio_seg).astype("float32")[np.newaxis, :] + oi = {"input_1": input_features} + mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0] + mos_sig, mos_bak, mos_ovr = self.get_polyfit_val( + mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS + ) + predicted_mos_sig_seg_raw.append(mos_sig_raw) + predicted_mos_bak_seg_raw.append(mos_bak_raw) + predicted_mos_ovr_seg_raw.append(mos_ovr_raw) + predicted_mos_sig_seg.append(mos_sig) + predicted_mos_bak_seg.append(mos_bak) + predicted_mos_ovr_seg.append(mos_ovr) + + clip_dict = { + "filename": "audio_clip", + "len_in_sec": actual_audio_len / fs, + "sr": fs, + "num_hops": num_hops, + "OVRL_raw": np.mean(predicted_mos_ovr_seg_raw), + "SIG_raw": np.mean(predicted_mos_sig_seg_raw), + "BAK_raw": np.mean(predicted_mos_bak_seg_raw), + "OVRL": np.mean(predicted_mos_ovr_seg), + "SIG": np.mean(predicted_mos_sig_seg), + "BAK": np.mean(predicted_mos_bak_seg), + } + return clip_dict diff --git a/preprocessors/Emilia/models/separate_fast.py b/preprocessors/Emilia/models/separate_fast.py new file mode 100755 index 00000000..d761dd8d --- /dev/null +++ 
b/preprocessors/Emilia/models/separate_fast.py @@ -0,0 +1,293 @@ +# Copyright (c) 2023 seanghay +# +# This code is from an unliscensed repository. +# +# Note: This code has been modified to fit the context of this repository. +# This code is included in an MIT-licensed repository. +# The repository's MIT license does not apply to this code. + +# This code is modified from https://github.com/seanghay/uvr-mdx-infer/blob/main/separate.py + +import torch +import numpy as np +import onnxruntime as ort +from tqdm import tqdm + + +class ConvTDFNet: + """ + ConvTDFNet - Convolutional Temporal Frequency Domain Network. + """ + + def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024): + """ + Initialize ConvTDFNet. + + Args: + target_name (str): The target name for separation. + L (int): Number of layers. + dim_f (int): Dimension in the frequency domain. + dim_t (int): Dimension in the time domain (log2). + n_fft (int): FFT size. + hop (int, optional): Hop size. Defaults to 1024. + + Returns: + None + """ + super(ConvTDFNet, self).__init__() + self.dim_c = 4 + self.dim_f = dim_f + self.dim_t = 2**dim_t + self.n_fft = n_fft + self.hop = hop + self.n_bins = self.n_fft // 2 + 1 + self.chunk_size = hop * (self.dim_t - 1) + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.target_name = target_name + + out_c = self.dim_c * 4 if target_name == "*" else self.dim_c + + self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]) + self.n = L // 2 + + def stft(self, x): + """ + Perform Short-Time Fourier Transform (STFT). + + Args: + x (torch.Tensor): Input waveform. + + Returns: + torch.Tensor: STFT of the input waveform. + """ + x = x.reshape([-1, self.chunk_size]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop, + window=self.window, + center=True, + return_complex=True, + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape( + [-1, self.dim_c, self.n_bins, self.dim_t] + ) + return x[:, :, : self.dim_f] + + def istft(self, x, freq_pad=None): + """ + Perform Inverse Short-Time Fourier Transform (ISTFT). + + Args: + x (torch.Tensor): Input STFT. + freq_pad (torch.Tensor, optional): Frequency padding. Defaults to None. + + Returns: + torch.Tensor: Inverse STFT of the input. + """ + freq_pad = ( + self.freq_pad.repeat([x.shape[0], 1, 1, 1]) + if freq_pad is None + else freq_pad + ) + x = torch.cat([x, freq_pad], -2) + c = 4 * 2 if self.target_name == "*" else 2 + x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape( + [-1, 2, self.n_bins, self.dim_t] + ) + x = x.permute([0, 2, 3, 1]) + x = x.contiguous() + x = torch.view_as_complex(x) + x = torch.istft( + x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True + ) + return x.reshape([-1, c, self.chunk_size]) + + +class Predictor: + """ + Predictor class for source separation using ConvTDFNet and ONNX Runtime. + """ + + def __init__(self, args, device): + """ + Initialize the Predictor. + + Args: + args (dict): Configuration arguments. + device (str): Device to run the model ('cuda' or 'cpu'). + + Returns: + None + + Raises: + ValueError: If the provided device is not 'cuda' or 'cpu'. 
+ """ + self.args = args + self.model_ = ConvTDFNet( + target_name="vocals", + L=11, + dim_f=args["dim_f"], + dim_t=args["dim_t"], + n_fft=args["n_fft"], + ) + + if device == "cuda": + self.model = ort.InferenceSession( + args["model_path"], providers=["CUDAExecutionProvider"] + ) + elif device == "cpu": + self.model = ort.InferenceSession( + args["model_path"], providers=["CPUExecutionProvider"] + ) + else: + raise ValueError("Device must be either 'cuda' or 'cpu'") + + def demix(self, mix): + """ + Separate the sources from the input mix. + + Args: + mix (np.ndarray): Input mixture signal. + + Returns: + np.ndarray: Separated sources. + + Raises: + AssertionError: If margin is zero. + """ + samples = mix.shape[-1] + margin = self.args["margin"] + chunk_size = self.args["chunks"] * 44100 + + assert margin != 0, "Margin cannot be zero!" + + if margin > chunk_size: + margin = chunk_size + + segmented_mix = {} + + if self.args["chunks"] == 0 or samples < chunk_size: + chunk_size = samples + + counter = -1 + for skip in range(0, samples, chunk_size): + counter += 1 + s_margin = 0 if counter == 0 else margin + end = min(skip + chunk_size + margin, samples) + start = skip - s_margin + segmented_mix[skip] = mix[:, start:end].copy() + if end == samples: + break + + sources = self.demix_base(segmented_mix, margin_size=margin) + return sources + + def demix_base(self, mixes, margin_size): + """ + Base function for source separation. + + Args: + mixes (dict): Dictionary of segmented mixtures. + margin_size (int): Size of the margin. + + Returns: + np.ndarray: Separated sources. + """ + chunked_sources = [] + progress_bar = tqdm(total=len(mixes)) + progress_bar.set_description("Source separation") + + for mix in mixes: + cmix = mixes[mix] + sources = [] + n_sample = cmix.shape[1] + model = self.model_ + trim = model.n_fft // 2 + gen_size = model.chunk_size - 2 * trim + pad = gen_size - n_sample % gen_size + mix_p = np.concatenate( + (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1 + ) + mix_waves = [] + i = 0 + while i < n_sample + pad: + waves = np.array(mix_p[:, i : i + model.chunk_size]) + mix_waves.append(waves) + i += gen_size + + mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32) + + with torch.no_grad(): + _ort = self.model + spek = model.stft(mix_waves) + if self.args["denoise"]: + spec_pred = ( + -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5 + + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5 + ) + tar_waves = model.istft(torch.tensor(spec_pred)) + else: + tar_waves = model.istft( + torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0]) + ) + tar_signal = ( + tar_waves[:, :, trim:-trim] + .transpose(0, 1) + .reshape(2, -1) + .numpy()[:, :-pad] + ) + + start = 0 if mix == 0 else margin_size + end = None if mix == list(mixes.keys())[::-1][0] else -margin_size + + if margin_size == 0: + end = None + + sources.append(tar_signal[:, start:end]) + + progress_bar.update(1) + + chunked_sources.append(sources) + _sources = np.concatenate(chunked_sources, axis=-1) + + progress_bar.close() + return _sources + + def predict(self, mix): + """ + Predict the separated sources from the input mix. + + Args: + mix (np.ndarray): Input mixture signal. + + Returns: + tuple: Tuple containing the mixture minus the separated sources and the separated sources. 
+ """ + if mix.ndim == 1: + mix = np.asfortranarray([mix, mix]) + + tail = mix.shape[1] % (self.args["chunks"] * 44100) + if mix.shape[1] % (self.args["chunks"] * 44100) != 0: + mix = np.pad( + mix, + ( + (0, 0), + ( + 0, + self.args["chunks"] * 44100 + - mix.shape[1] % (self.args["chunks"] * 44100), + ), + ), + ) + + mix = mix.T + sources = self.demix(mix.T) + opt = sources[0].T + + if tail != 0: + return ((mix - opt)[: -(self.args["chunks"] * 44100 - tail), :], opt) + else: + return ((mix - opt), opt) diff --git a/preprocessors/Emilia/models/silero_vad.py b/preprocessors/Emilia/models/silero_vad.py new file mode 100755 index 00000000..ca9390c4 --- /dev/null +++ b/preprocessors/Emilia/models/silero_vad.py @@ -0,0 +1,181 @@ +# Source: https://github.com/snakers4/silero-vad +# +# Copyright (c) 2024 snakers4 +# +# This code is from a MIT-licensed repository. The full license text is available at the root of the source repository. +# +# Note: This code has been modified to fit the context of this repository. + +import librosa +import torch +import numpy as np + +VAD_THRESHOLD = 20 +SAMPLING_RATE = 16000 + + +class SileroVAD: + """ + Voice Activity Detection (VAD) using Silero-VAD. + """ + + def __init__(self, local=False, model="silero_vad", device=torch.device("cpu")): + """ + Initialize the VAD object. + + Args: + local (bool, optional): Whether to load the model locally. Defaults to False. + model (str, optional): The VAD model name to load. Defaults to "silero_vad". + device (torch.device, optional): The device to run the model on. Defaults to 'cpu'. + + Returns: + None + + Raises: + RuntimeError: If loading the model fails. + """ + try: + vad_model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad" if not local else "vad/silero-vad", + model=model, + force_reload=False, + onnx=True, + source="github" if not local else "local", + ) + self.vad_model = vad_model + (get_speech_timestamps, _, _, _, _) = utils + self.get_speech_timestamps = get_speech_timestamps + except Exception as e: + raise RuntimeError(f"Failed to load VAD model: {e}") + + def segment_speech(self, audio_segment, start_time, end_time, sampling_rate): + """ + Segment speech from an audio segment and return a list of timestamps. + + Args: + audio_segment (np.ndarray): The audio segment to be segmented. + start_time (int): The start time of the audio segment in frames. + end_time (int): The end time of the audio segment in frames. + sampling_rate (int): The sampling rate of the audio segment. + + Returns: + list: A list of timestamps, each containing the start and end times of speech segments in frames. + + Raises: + ValueError: If the audio segment is invalid. 
+ """ + if audio_segment is None or not isinstance(audio_segment, (np.ndarray, list)): + raise ValueError("Invalid audio segment") + + speech_timestamps = self.get_speech_timestamps( + audio_segment, self.vad_model, sampling_rate=sampling_rate + ) + + adjusted_timestamps = [ + (ts["start"] + start_time, ts["end"] + start_time) + for ts in speech_timestamps + ] + if not adjusted_timestamps: + return [] + + intervals = [ + end[0] - start[1] + for start, end in zip(adjusted_timestamps[:-1], adjusted_timestamps[1:]) + ] + + segments = [] + + def split_timestamps(start_index, end_index): + if ( + start_index == end_index + or adjusted_timestamps[end_index][1] + - adjusted_timestamps[start_index][0] + < 20 * sampling_rate + ): + segments.append([start_index, end_index]) + else: + if not intervals[start_index:end_index]: + return + max_interval_index = intervals[start_index:end_index].index( + max(intervals[start_index:end_index]) + ) + split_index = start_index + max_interval_index + split_timestamps(start_index, split_index) + split_timestamps(split_index + 1, end_index) + + split_timestamps(0, len(adjusted_timestamps) - 1) + + merged_timestamps = [ + [adjusted_timestamps[start][0], adjusted_timestamps[end][1]] + for start, end in segments + ] + return merged_timestamps + + def vad(self, speakerdia, audio): + """ + Process the audio based on the given speaker diarization dataframe. + + Args: + speakerdia (pd.DataFrame): The diarization dataframe containing start, end, and speaker info. + audio (dict): A dictionary containing the audio waveform and sample rate. + + Returns: + list: A list of dictionaries containing processed audio segments with start, end, and speaker. + """ + sampling_rate = audio["sample_rate"] + audio_data = audio["waveform"] + + out = [] + last_end = 0 + speakers_seen = set() + count_id = 0 + + for index, row in speakerdia.iterrows(): + start = float(row["start"]) + end = float(row["end"]) + + if end <= last_end: + continue + last_end = end + + start_frame = int(start * sampling_rate) + end_frame = int(end * sampling_rate) + if row["speaker"] not in speakers_seen: + speakers_seen.add(row["speaker"]) + + if end - start <= VAD_THRESHOLD: + out.append( + { + "index": str(count_id).zfill(5), + "start": start, # in seconds + "end": end, + "speaker": row["speaker"], # same for all + } + ) + count_id += 1 + continue + + temp_audio = audio_data[start_frame:end_frame] + + # resample from 24k to 16k + temp_audio_resampled = librosa.resample( + temp_audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE + ) + + for start_frame_sub, end_frame_sub in self.segment_speech( + temp_audio_resampled, + int(start * SAMPLING_RATE), + int(end * SAMPLING_RATE), + SAMPLING_RATE, + ): + out.append( + { + "index": str(count_id).zfill(5), + "start": start_frame_sub / SAMPLING_RATE, # in seconds + "end": end_frame_sub / SAMPLING_RATE, + "speaker": row["speaker"], # same for all + } + ) + count_id += 1 + + return out diff --git a/preprocessors/Emilia/models/whisper_asr.py b/preprocessors/Emilia/models/whisper_asr.py new file mode 100755 index 00000000..dd062b80 --- /dev/null +++ b/preprocessors/Emilia/models/whisper_asr.py @@ -0,0 +1,299 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import faster_whisper +from typing import List, Union, Optional, NamedTuple +import torch +import numpy as np +import tqdm +from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram +from whisperx.types import TranscriptionResult, SingleSegment +from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens + + +class VadFreeFasterWhisperPipeline(FasterWhisperPipeline): + """ + FasterWhisperModel without VAD + """ + + def __init__( + self, + model, + options: NamedTuple, + tokenizer=None, + device: Union[int, str, "torch.device"] = -1, + framework="pt", + language: Optional[str] = None, + suppress_numerals: bool = False, + **kwargs, + ): + """ + Initialize the VadFreeFasterWhisperPipeline. + + Args: + model: The Whisper model instance. + options: Transcription options. + tokenizer: The tokenizer instance. + device: Device to run the model on. + framework: The framework to use ('pt' for PyTorch). + language: The language for transcription. + suppress_numerals: Whether to suppress numeral tokens. + **kwargs: Additional keyword arguments. + + Returns: + None + """ + super().__init__( + model=model, + vad=None, + vad_params={}, + options=options, + tokenizer=tokenizer, + device=device, + framework=framework, + language=language, + suppress_numerals=suppress_numerals, + **kwargs, + ) + + def detect_language(self, audio: np.ndarray): + """ + Detect the language of the audio. + + Args: + audio (np.ndarray): The input audio signal. + + Returns: + tuple: Detected language and its probability. + """ + model_n_mels = self.model.feat_kwargs.get("feature_size") + if audio.shape[0] > N_SAMPLES: + # Randomly sample N_SAMPLES from the audio array + start_index = np.random.randint(0, audio.shape[0] - N_SAMPLES) + audio_sample = audio[start_index : start_index + N_SAMPLES] + else: + audio_sample = audio[:N_SAMPLES] + padding = 0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0] + segment = log_mel_spectrogram( + audio_sample, + n_mels=model_n_mels if model_n_mels is not None else 80, + padding=padding, + ) + encoder_output = self.model.encode(segment) + results = self.model.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + return language, language_probability + + def transcribe( + self, + audio: Union[str, np.ndarray], + vad_segments: List[dict], + batch_size=None, + num_workers=0, + language=None, + task=None, + chunk_size=30, + print_progress=False, + combined_progress=False, + ) -> TranscriptionResult: + """ + Transcribe the audio into text. + + Args: + audio (Union[str, np.ndarray]): The input audio signal or path to audio file. + vad_segments (List[dict]): List of VAD segments. + batch_size (int, optional): Batch size for transcription. Defaults to None. + num_workers (int, optional): Number of workers for loading data. Defaults to 0. + language (str, optional): Language for transcription. Defaults to None. + task (str, optional): Task type ('transcribe' or 'translate'). Defaults to None. + chunk_size (int, optional): Size of chunks for processing. Defaults to 30. + print_progress (bool, optional): Whether to print progress. Defaults to False. + combined_progress (bool, optional): Whether to combine progress. Defaults to False. + + Returns: + TranscriptionResult: The transcription result containing segments and language. 
+ """ + if isinstance(audio, str): + audio = load_audio(audio) + + def data(audio, segments): + for seg in segments: + f1 = int(seg["start"] * SAMPLE_RATE) + f2 = int(seg["end"] * SAMPLE_RATE) + yield {"inputs": audio[f1:f2]} + + if self.tokenizer is None: + language = language or self.detect_language(audio) + task = task or "transcribe" + self.tokenizer = faster_whisper.tokenizer.Tokenizer( + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task=task, + language=language, + ) + else: + language = language or self.tokenizer.language_code + task = task or self.tokenizer.task + if task != self.tokenizer.task or language != self.tokenizer.language_code: + self.tokenizer = faster_whisper.tokenizer.Tokenizer( + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task=task, + language=language, + ) + + if self.suppress_numerals: + previous_suppress_tokens = self.options.suppress_tokens + numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer) + new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens + new_suppressed_tokens = list(set(new_suppressed_tokens)) + self.options = self.options._replace(suppress_tokens=new_suppressed_tokens) + + segments: List[SingleSegment] = [] + batch_size = batch_size or self._batch_size + total_segments = len(vad_segments) + progress = tqdm.tqdm(total=total_segments, desc="Transcribing") + for idx, out in enumerate( + self.__call__( + data(audio, vad_segments), + batch_size=batch_size, + num_workers=num_workers, + ) + ): + if print_progress: + progress.update(1) + text = out["text"] + if batch_size in [0, 1, None]: + text = text[0] + segments.append( + { + "text": text, + "start": round(vad_segments[idx]["start"], 3), + "end": round(vad_segments[idx]["end"], 3), + "speaker": vad_segments[idx].get("speaker", None), + } + ) + + # revert the tokenizer if multilingual inference is enabled + if self.preset_language is None: + self.tokenizer = None + + # revert suppressed tokens if suppress_numerals is enabled + if self.suppress_numerals: + self.options = self.options._replace( + suppress_tokens=previous_suppress_tokens + ) + + return {"segments": segments, "language": language} + + +def load_asr_model( + whisper_arch: str, + device: str, + device_index: int = 0, + compute_type: str = "float16", + asr_options: Optional[dict] = None, + language: Optional[str] = None, + vad_model=None, + vad_options=None, + model: Optional[WhisperModel] = None, + task: str = "transcribe", + download_root: Optional[str] = None, + threads: int = 4, +) -> VadFreeFasterWhisperPipeline: + """ + Load a Whisper model for inference. + + Args: + whisper_arch (str): The name of the Whisper model to load. + device (str): The device to load the model on. + device_index (int, optional): The device index. Defaults to 0. + compute_type (str, optional): The compute type to use for the model. Defaults to "float16". + asr_options (Optional[dict], optional): Options for ASR. Defaults to None. + language (Optional[str], optional): The language of the model. Defaults to None. + vad_model: The VAD model instance. Defaults to None. + vad_options: Options for VAD. Defaults to None. + model (Optional[WhisperModel], optional): The WhisperModel instance to use. Defaults to None. + task (str, optional): The task type ('transcribe' or 'translate'). Defaults to "transcribe". + download_root (Optional[str], optional): The root directory to download the model to. Defaults to None. + threads (int, optional): The number of CPU threads to use per worker. 
Defaults to 4. + + Returns: + VadFreeFasterWhisperPipeline: The loaded Whisper pipeline. + + Raises: + ValueError: If the whisper architecture is not recognized. + """ + + if whisper_arch.endswith(".en"): + language = "en" + + model = model or WhisperModel( + whisper_arch, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root, + cpu_threads=threads, + ) + if language is not None: + tokenizer = faster_whisper.tokenizer.Tokenizer( + model.hf_tokenizer, + model.model.is_multilingual, + task=task, + language=language, + ) + else: + print( + "No language specified, language will be detected for each audio file (increases inference time)." + ) + tokenizer = None + + default_asr_options = { + "beam_size": 5, + "best_of": 5, + "patience": 1, + "length_penalty": 1, + "repetition_penalty": 1, + "no_repeat_ngram_size": 0, + "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": False, + "prompt_reset_on_temperature": 0.5, + "initial_prompt": None, + "prefix": None, + "suppress_blank": True, + "suppress_tokens": [-1], + "without_timestamps": True, + "max_initial_timestamp": 0.0, + "word_timestamps": False, + "prepend_punctuations": "\"'“¿([{-", + "append_punctuations": "\"'.。,,!!??::”)]}、", + "suppress_numerals": False, + "max_new_tokens": None, + "clip_timestamps": None, + "hallucination_silence_threshold": None, + } + + if asr_options is not None: + default_asr_options.update(asr_options) + + suppress_numerals = default_asr_options["suppress_numerals"] + del default_asr_options["suppress_numerals"] + + default_asr_options = faster_whisper.transcribe.TranscriptionOptions( + **default_asr_options + ) + + return VadFreeFasterWhisperPipeline( + model=model, + options=default_asr_options, + tokenizer=tokenizer, + language=language, + suppress_numerals=suppress_numerals, + ) diff --git a/preprocessors/Emilia/requirements.txt b/preprocessors/Emilia/requirements.txt new file mode 100755 index 00000000..8657331f --- /dev/null +++ b/preprocessors/Emilia/requirements.txt @@ -0,0 +1,7 @@ +librosa +numpy +tqdm +pydub +pyannote.audio +pandas +git+https://github.com/m-bain/whisperx.git # needs torch >= 2.0.0 diff --git a/preprocessors/Emilia/utils/__init__.py b/preprocessors/Emilia/utils/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/preprocessors/Emilia/utils/logger.py b/preprocessors/Emilia/utils/logger.py new file mode 100755 index 00000000..dd6864f2 --- /dev/null +++ b/preprocessors/Emilia/utils/logger.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import time +import os + + +class Logger: + """ + Logger class for managing logging operations. + """ + + _logger = None + + @classmethod + def get_logger(cls, name=None): + """ + Get the logger instance with the specified name. If it doesn't exist, create and cache it. + + Args: + cls (type): The class type. + name (str, optional): The name of the logger. Defaults to None, which uses the class name. + + Returns: + logging.Logger: The logger instance. + """ + if cls._logger is None: + cls._logger = cls.init_logger(name) + return cls._logger + + @classmethod + def init_logger(cls, name=None): + """ + Initialize the logger, including file and console logging. + + Args: + cls (type): The class type. 
+ name (str, optional): The name of the logger. Defaults to None. + + Returns: + logging.Logger: The initialized logger instance. + """ + if name is None: + name = "main" + if "SELF_ID" in os.environ: + name = name + "_ID" + os.environ["SELF_ID"] + if "CUDA_VISIBLE_DEVICES" in os.environ: + name = name + "_GPU" + os.environ["CUDA_VISIBLE_DEVICES"] + print(f"Initialize logger for {name}") + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + # Add file handler to save logs to a file + log_date = time.strftime("%Y-%m-%d", time.localtime()) + log_time = time.strftime("%H-%M-%S", time.localtime()) + os.makedirs(f"logs/{log_date}", exist_ok=True) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + fh = logging.FileHandler(f"logs/{log_date}/{name}-{log_time}.log") + fh.setFormatter(formatter) + logger.addHandler(fh) + + # Create a custom log formatter to set specific log levels to color + class ColorFormatter(logging.Formatter): + """ + Custom log formatter to add color to specific log levels. + """ + + def format(self, record): + """ + Format the log record with color based on log level. + + Args: + record (logging.LogRecord): The log record to format. + + Returns: + str: The formatted log message. + """ + if record.levelno >= logging.ERROR: + record.msg = "\033[1;31m" + str(record.msg) + "\033[0m" + elif record.levelno >= logging.WARNING: + record.msg = "\033[1;33m" + str(record.msg) + "\033[0m" + elif record.levelno >= logging.INFO: + record.msg = "\033[1;34m" + str(record.msg) + "\033[0m" + elif record.levelno >= logging.DEBUG: + record.msg = "\033[1;32m" + str(record.msg) + "\033[0m" + return super().format(record) + + color_formatter = ColorFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch = logging.StreamHandler() + ch.setFormatter(color_formatter) + logger.addHandler(ch) + + return logger + + +def time_logger(func): + """ + Decorator to log the execution time of a function. + + Args: + func (callable): The function whose execution time is to be logged. + + Returns: + callable: The wrapper function that logs the execution time of the original function. + """ + + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + Logger.get_logger().debug( + f"Function {func.__name__} took {end_time - start_time} seconds to execute" + ) + return result + + return wrapper diff --git a/preprocessors/Emilia/utils/tool.py b/preprocessors/Emilia/utils/tool.py new file mode 100755 index 00000000..2d3a278a --- /dev/null +++ b/preprocessors/Emilia/utils/tool.py @@ -0,0 +1,323 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ThreadPoolExecutor +import json +import os +import librosa +import numpy as np +import time +import torch +from pydub import AudioSegment +import soundfile as sf +import onnxruntime as ort +import tqdm +import subprocess +import re + +from utils.logger import Logger, time_logger + + +def load_cfg(cfg_path): + """ + Load configuration from a JSON file. + + Args: + cfg_path (str): Path to the configuration file. + + Returns: + dict: Configuration dictionary. + """ + if not os.path.exists(cfg_path): + raise FileNotFoundError( + f"{cfg_path} not found. Please copy, configure, and rename `config.json.example` to `{cfg_path}`." 
+ ) + with open(cfg_path, "r") as f: + try: + cfg = json.load(f) + except json.decoder.JSONDecodeError as e: + raise TypeError( + "Please finish the `// TODO:` in the `config.json` file before running the script. Check README.md for details." + ) + return cfg + + +def write_wav(path, sr, x): + """Write numpy array to WAV file.""" + sf.write(path, x, sr) + + +def write_mp3(path, sr, x): + """Convert numpy array to MP3.""" + try: + # Ensure x is in the correct format and normalize if necessary + if x.dtype != np.int16: + # Normalize the array to fit in int16 range if it's not already int16 + x = np.int16(x / np.max(np.abs(x)) * 32767) + + # Create audio segment from numpy array + audio = AudioSegment( + x.tobytes(), frame_rate=sr, sample_width=x.dtype.itemsize, channels=1 + ) + # Export as MP3 file + audio.export(path, format="mp3") + except Exception as e: + print(e) + print("Error: Failed to write MP3 file.") + + +def get_audio_files(folder_path): + """Get all audio files in a folder.""" + audio_files = [] + for root, _, files in os.walk(folder_path): + if "_processed" in root: + continue + for file in files: + if ".temp" in file: + continue + if file.endswith((".mp3", ".wav", ".flac", ".m4a", ".aac")): + audio_files.append(os.path.join(root, file)) + return audio_files + + +def get_specific_files(folder_path, ext): + """Get specific files with a given extension in a folder.""" + audio_files = [] + for root, _, files in os.walk(folder_path): + if "_processed" in root: + continue + for file in files: + if ".temp" in file: + continue + if file.endswith(ext): + audio_files.append(os.path.join(root, file)) + return audio_files + + +def export_to_srt(asr_result, file_path): + """Export ASR result to SRT file.""" + with open(file_path, "w") as f: + + def format_time(seconds): + return ( + time.strftime("%H:%M:%S", time.gmtime(seconds)) + + f",{int(seconds * 1000 % 1000):03d}" + ) + + for idx, segment in enumerate(asr_result): + f.write(f"{idx + 1}\n") + f.write( + f"{format_time(segment['start'])} --> {format_time(segment['end'])}\n" + ) + f.write(f"{segment['speaker']}: {segment['text']}\n\n") + + +def detect_gpu(): + """Detect if GPU is available and print related information.""" + logger = Logger.get_logger() + + if "CUDA_VISIBLE_DEVICES" not in os.environ: + logger.info("ENV: CUDA_VISIBLE_DEVICES not set, use default setting") + else: + gpu_id = os.environ["CUDA_VISIBLE_DEVICES"] + logger.info(f"ENV: CUDA_VISIBLE_DEVICES = {gpu_id}") + + if not torch.cuda.is_available(): + logger.error("Torch CUDA: No GPU detected. torch.cuda.is_available() = False.") + return False + + num_gpus = torch.cuda.device_count() + logger.debug(f"Torch CUDA: Detected {num_gpus} GPUs.") + for i in range(num_gpus): + gpu_name = torch.cuda.get_device_name(i) + logger.debug(f" * GPU {i}: {gpu_name}") + + logger.debug("Torch: CUDNN version = " + str(torch.backends.cudnn.version())) + if not torch.backends.cudnn.is_available(): + logger.error("Torch: CUDNN is not available.") + return False + logger.debug("Torch: CUDNN is available.") + + ort_providers = ort.get_available_providers() + logger.debug(f"ORT: Available providers: {ort_providers}") + if "CUDAExecutionProvider" not in ort_providers: + logger.warning( + "ORT: CUDAExecutionProvider is not available. " + "Please install a compatible version of ONNX Runtime. 
" + "See https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html" + ) + + return True + + +def get_gpu_nums(): + """Get GPU nums by nvidia-smi.""" + logger = Logger.get_logger() + try: + result = subprocess.check_output("nvidia-smi -L | wc -l", shell=True) + gpus_count = int(result.decode().strip()) + except Exception as e: + logger.error("Error occurred while getting GPU count: " + str(e)) + gpus_count = 8 # Default to 8 if GPU count retrieval fails + return gpus_count + + +def check_env(logger): + """Check environment variables.""" + if "http_proxy" in os.environ: + logger.info(f"ENV: http_proxy = {os.environ['http_proxy']}") + else: + logger.info("ENV: http_proxy not set") + + if "https_proxy" in os.environ: + logger.info(f"ENV: https_proxy = {os.environ['https_proxy']}") + else: + logger.info("ENV: https_proxy not set") + + if "HF_ENDPOINT" in os.environ: + logger.info( + f"ENV: HF_ENDPOINT = {os.environ['HF_ENDPOINT']}, if downloading slow, try `unset HF_ENDPOINT`" + ) + else: + logger.info("ENV: HF_ENDPOINT not set") + + hostname = os.popen("hostname").read().strip() + logger.debug(f"HOSTNAME: {hostname}") + + environ_path = os.environ["PATH"] + environ_ld_library = os.environ.get("LD_LIBRARY_PATH", "") + logger.debug(f"ENV: PATH = {environ_path}, LD_LIBRARY_PATH = {environ_ld_library}") + + +@time_logger +def export_to_mp3(audio, asr_result, folder_path, file_name): + """Export segmented audio to MP3 files.""" + sr = audio["sample_rate"] + audio = audio["waveform"] + + os.makedirs(folder_path, exist_ok=True) + + # Function to process each segment in a separate thread + def process_segment(idx, segment): + start, end = int(segment["start"] * sr), int(segment["end"] * sr) + split_audio = audio[start:end] + split_audio = librosa.to_mono(split_audio) + out_file = f"{file_name}_{idx}.mp3" + out_path = os.path.join(folder_path, out_file) + write_mp3(out_path, sr, split_audio) + + # Use ThreadPoolExecutor for concurrent execution + with ThreadPoolExecutor(max_workers=72) as executor: + # Submit each segment processing as a separate thread + futures = [ + executor.submit(process_segment, idx, segment) + for idx, segment in enumerate(asr_result) + ] + + # Wait for all threads to complete + for future in tqdm.tqdm( + futures, total=len(asr_result), desc="Exporting to MP3" + ): + future.result() + + +@time_logger +def export_to_wav(audio, asr_result, folder_path, file_name): + """Export segmented audio to WAV files.""" + sr = audio["sample_rate"] + audio = audio["waveform"] + + os.makedirs(folder_path, exist_ok=True) + + for idx, segment in enumerate(tqdm.tqdm(asr_result, desc="Exporting to WAV")): + start, end = int(segment["start"] * sr), int(segment["end"] * sr) + split_audio = audio[start:end] + split_audio = librosa.to_mono(split_audio) + out_file = f"{file_name}_{idx}.wav" + out_path = os.path.join(folder_path, out_file) + write_wav(out_path, sr, split_audio) + + +def get_char_count(text): + """ + Get the number of characters in the text. + + Args: + text (str): Input text. + + Returns: + int: Number of characters in the text. + """ + # Using regular expression to remove punctuation and spaces + cleaned_text = re.sub(r"[,.!?\"',。!?“”‘’ ]", "", text) + char_count = len(cleaned_text) + return char_count + + +def calculate_audio_stats( + data, min_duration=3, max_duration=30, min_dnsmos=3, min_char_count=2 +): + """ + Reading the proviced json, calculate and return the audio ID and their duration that meet the given filtering criteria. + + Args: + data: JSON. 
+ min_duration: Minimum duration of the audio in seconds. + max_duration: Maximum duration of the audio in seconds. + min_dnsmos: Minimum DNSMOS value. + min_char_count: Minimum number of characters. + + Returns: + valid_audio_stats: A list containing tuples of audio ID and their duration. + """ + all_audio_stats = [] + valid_audio_stats = [] + avg_durations = [] + + # iterate over each entry in the JSON to collect the average duration of the phonemes + for entry in data: + # remove punctuation and spaces + char_count = get_char_count(entry["text"]) + duration = entry["end"] - entry["start"] + if char_count > 0: + avg_durations.append(duration / char_count) + + # calculate the bounds for the average character duration + if len(avg_durations) > 0: + q1 = np.percentile(avg_durations, 25) + q3 = np.percentile(avg_durations, 75) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + else: + # if no valid character data, use default values + lower_bound, upper_bound = 0, np.inf + + # iterate over each entry in the JSON to apply all filtering criteria + for idx, entry in enumerate(data): + duration = entry["end"] - entry["start"] + dnsmos = entry["dnsmos"] + # remove punctuation and spaces + char_count = get_char_count(entry["text"]) + if char_count > 0: + avg_char_duration = duration / char_count + else: + avg_char_duration = 0 + + # collect the duration of all audios + all_audio_stats.append((idx, duration)) + + # apply filtering criteria + if ( + (min_duration <= duration <= max_duration) # withing duration range + and (dnsmos >= min_dnsmos) + and (char_count >= min_char_count) + and ( + lower_bound <= avg_char_duration <= upper_bound + ) # average character duration within bounds + ): + valid_audio_stats.append((idx, duration)) + + return valid_audio_stats, all_audio_stats
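
The pipeline class and `load_asr_model` above are easiest to follow from a call site. The sketch below is illustrative only: the import path, audio file, and segment boundaries are assumptions, and in practice the `vad_segments` (including the optional `speaker` field) come from the upstream VAD/diarization stage.

```python
# Illustrative sketch; import path and inputs are placeholders, not part of the diff.
from models.whisper_asr import load_asr_model  # assumed module name for the ASR file above

# Build the VAD-free pipeline on GPU 0 with fp16 weights; a fixed `language`
# skips the per-file language detection mentioned in load_asr_model.
pipeline = load_asr_model(
    "large-v3",
    device="cuda",
    device_index=0,
    compute_type="float16",
    language="en",
)

# Precomputed segments; `speaker` is optional and is echoed back per segment.
vad_segments = [
    {"start": 0.00, "end": 4.20, "speaker": "SPEAKER_00"},
    {"start": 4.50, "end": 9.80, "speaker": "SPEAKER_01"},
]

result = pipeline.transcribe("example.wav", vad_segments, batch_size=16)
print(result["language"])
for seg in result["segments"]:
    print(f'[{seg["start"]:.2f}-{seg["end"]:.2f}] {seg["speaker"]}: {seg["text"]}')
```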
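`export_to_mp3` and `export_to_wav` expect the audio as a dict with `waveform` and `sample_rate` keys plus an ASR result whose entries carry `start`/`end` times in seconds. A minimal sketch with placeholder paths, an assumed sampling rate, and a fabricated single-segment result (pydub needs `ffmpeg` on the PATH to write MP3):

```python
# Minimal sketch of the export helpers in utils/tool.py; all inputs are placeholders.
import librosa
from utils.tool import get_audio_files, export_to_mp3

files = get_audio_files("raw_audio")  # skips "_processed" folders and ".temp" files

# Load the first file as mono at an assumed sampling rate.
waveform, sr = librosa.load(files[0], sr=24000)
audio = {"waveform": waveform, "sample_rate": sr}

# Fabricated single-segment result; real ones come from the transcription stage.
asr_result = [{"start": 0.0, "end": 3.2, "text": "hello world", "speaker": "SPEAKER_00"}]

export_to_mp3(audio, asr_result, "raw_audio_processed/clips", "example")
# -> raw_audio_processed/clips/example_0.mp3
```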
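`calculate_audio_stats` combines four filters: duration range, DNSMOS floor, character count, and an IQR bound on the average per-character duration. A toy run on fabricated metadata (values chosen only to show which entries pass):

```python
# Toy demonstration of calculate_audio_stats; the metadata below is fabricated.
from utils.tool import calculate_audio_stats

segments = [
    {"text": "Hello there, how are you today?", "start": 0.0, "end": 4.0, "dnsmos": 3.4},
    {"text": "Hi!", "start": 4.2, "end": 4.6, "dnsmos": 3.8},                 # too short
    {"text": "This one is noisy.", "start": 5.0, "end": 9.0, "dnsmos": 2.1},  # low DNSMOS
]

valid, all_stats = calculate_audio_stats(
    segments, min_duration=3, max_duration=30, min_dnsmos=3, min_char_count=2
)
print(valid)      # expected: [(0, 4.0)] -- only the first entry passes every filter
print(all_stats)  # (index, duration) for every entry, regardless of filtering
```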
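Finally, the `Logger` class and `time_logger` decorator in `utils/logger.py` are shared across pipeline stages: `get_logger` caches a single logger that writes colourised console output plus a file under `logs/<date>/`, and `time_logger` reports a function's wall-clock time at DEBUG level. A small sketch with a stand-in function:

```python
# Sketch of the logging utilities; `slow_step` is a stand-in, not pipeline code.
from utils.logger import Logger, time_logger

logger = Logger.get_logger()  # cached; also writes to logs/YYYY-MM-DD/<name>-<time>.log

@time_logger  # logs "Function slow_step took ... seconds to execute" at DEBUG level
def slow_step(n: int) -> int:
    total = sum(i * i for i in range(n))
    logger.info(f"Computed the sum of squares up to {n}")
    return total

slow_step(1_000_000)
```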