diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 39ae70c41f..1826cbe0f0 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -14,6 +14,8 @@
 from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
 
 from mutagen.mp3 import MP3
+from mutagen.flac import FLAC
+from mutagen.wave import WAVE
 
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@@ -48,12 +50,14 @@ def get_audio_size(audiopath):
     extension = audiopath.rpartition(".")[-1].lower()
     if extension == "mp3":
         audio_info = MP3(audiopath).info
-        return int(audio_info.length * audio_info.sample_rate)
-    if extension in ("wav", "flac"):
-        compress_factor = 8
-        bitrate = 16  # assuming 16bit audio
-        return int(os.path.getsize(audiopath) / bitrate * compress_factor)
-    raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files for mp3, flac or wav format!")
+    elif extension == "wav":
+        audio_info = WAVE(audiopath).info
+    elif extension == "flac":
+        audio_info = FLAC(audiopath).info
+    else:
+        raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files for mp3, flac or wav format!")
+
+    return int(audio_info.length * audio_info.sample_rate)
 
 
 class TTSDataset(Dataset):