diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 2aa81b68ed..3650ce3205 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -102,11 +102,9 @@ for idx, fields in enumerate(tqdm(samples)): class_name = fields[class_name_key] audio_file = fields["audio_file"] - dataset_name = fields["dataset_name"] + embedding_key = fields["audio_unique_name"] root_path = fields["root_path"] - relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0] - embedding_key = f"{dataset_name}#{relfilepath}" if args.old_file is not None and embedding_key in encoder_manager.clip_ids: # get the embedding from the old file embedd = encoder_manager.get_embedding_by_clip(embedding_key) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 6e41889012..eeadf7d331 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -59,6 +59,18 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): return items[:eval_split_size], items[eval_split_size:] +def add_extra_keys(metadata, language, dataset_name): + for item in metadata: + # add language name + item["language"] = language + # add unique audio name + relfilepath = os.path.splitext(item["audio_file"].replace(item["root_path"], ""))[0] + audio_unique_name = f"{dataset_name}#{relfilepath}" + item["audio_unique_name"] = audio_unique_name + + return metadata + + def load_tts_samples( datasets: Union[List[Dict], Dict], eval_split=True, @@ -111,15 +123,15 @@ def load_tts_samples( # load train set meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) assert len(meta_data_train) > 0, f" [!] No training samples found in {root_path}/{meta_file_train}" - meta_data_train = [{**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_train] + + meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") # load evaluation split if set if eval_split: if meta_file_val: meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) - meta_data_eval = [ - {**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_eval - ] + meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name) else: meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size) meta_data_eval_all += meta_data_eval diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index d8f16e4efe..eec493ecf3 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -256,6 +256,7 @@ def load_data(self, idx): "speaker_name": item["speaker_name"], "language_name": item["language"], "wav_file_name": os.path.basename(item["audio_file"]), + "audio_unique_name": item["audio_unique_name"], } return sample @@ -397,8 +398,8 @@ def collate_fn(self, batch): language_ids = None # get pre-computed d-vectors if self.d_vector_mapping is not None: - wav_files_names = list(batch["wav_file_name"]) - d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names] + embedding_keys = list(batch["audio_unique_name"]) + d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys] else: d_vectors = None diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index cba3749285..e7eebebaab 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -284,6 +284,7 @@ def __getitem__(self, idx): "wav_file": wav_filename, "speaker_name": item["speaker_name"], "language_name": item["language"], + "audio_unique_name": item["audio_unique_name"], } @property @@ -308,6 +309,7 @@ def collate_fn(self, batch): - language_names: :math:`[B]` - audiofile_paths: :math:`[B]` - raw_texts: :math:`[B]` + - audio_unique_names: :math:`[B]` """ # convert list of dicts to dict of lists B = len(batch) @@ -348,6 +350,7 @@ def collate_fn(self, batch): "language_names": batch["language_name"], "audio_files": batch["wav_file"], "raw_text": batch["raw_text"], + "audio_unique_names": batch["audio_unique_name"], } @@ -1470,7 +1473,7 @@ def format_batch(self, batch: Dict) -> Dict: # get d_vectors from audio file names if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: d_vector_mapping = self.speaker_manager.embeddings - d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] d_vectors = torch.FloatTensor(d_vectors) # get language ids from language names diff --git a/tests/data/ljspeech/speakers.json b/tests/data/ljspeech/speakers.json index 915cff7360..2790a30ea1 100644 --- a/tests/data/ljspeech/speakers.json +++ b/tests/data/ljspeech/speakers.json @@ -1,5 +1,5 @@ { - "LJ001-0001.wav": { + "#/wavs/LJ001-0001": { "name": "ljspeech-0", "embedding": [ 0.05539746582508087, @@ -260,7 +260,7 @@ -0.09469571709632874 ] }, - "LJ001-0002.wav": { + "#/wavs/LJ001-0002": { "name": "ljspeech-1", "embedding": [ 0.05539746582508087, @@ -521,7 +521,7 @@ -0.09469571709632874 ] }, - "LJ001-0003.wav": { + "#/wavs/LJ001-0003": { "name": "ljspeech-2", "embedding": [ 0.05539746582508087, @@ -782,7 +782,7 @@ -0.09469571709632874 ] }, - "LJ001-0004.wav": { + "#/wavs/LJ001-0004": { "name": "ljspeech-3", "embedding": [ 0.05539746582508087, @@ -1043,7 +1043,7 @@ -0.09469571709632874 ] }, - "LJ001-0005.wav": { + "#/wavs/LJ001-0005": { "name": "ljspeech-4", "embedding": [ 0.05539746582508087, @@ -1304,7 +1304,7 @@ -0.09469571709632874 ] }, - "LJ001-0006.wav": { + "#/wavs/LJ001-0006": { "name": "ljspeech-5", "embedding": [ 0.05539746582508087, @@ -1565,7 +1565,7 @@ -0.09469571709632874 ] }, - "LJ001-0007.wav": { + "#/wavs/LJ001-0007": { "name": "ljspeech-6", "embedding": [ 0.05539746582508087, @@ -1826,7 +1826,7 @@ -0.09469571709632874 ] }, - "LJ001-0008.wav": { + "#/wavs/LJ001-0008": { "name": "ljspeech-7", "embedding": [ 0.05539746582508087, @@ -2087,7 +2087,7 @@ -0.09469571709632874 ] }, - "LJ001-0009.wav": { + "#/wavs/LJ001-0009": { "name": "ljspeech-8", "embedding": [ 0.05539746582508087, @@ -2348,7 +2348,7 @@ -0.09469571709632874 ] }, - "LJ001-0010.wav": { + "#/wavs/LJ001-0010": { "name": "ljspeech-9", "embedding": [ 0.05539746582508087,