diff --git a/TTS/api.py b/TTS/api.py index ed82825007..83189482cb 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,12 +1,14 @@ +"""Coqui TTS Python API.""" + import logging import tempfile import warnings from pathlib import Path +from typing import Optional from torch import nn from TTS.config import load_config -from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer @@ -19,13 +21,19 @@ class TTS(nn.Module): def __init__( self, model_name: str = "", - model_path: str = None, - config_path: str = None, - vocoder_path: str = None, - vocoder_config_path: str = None, + *, + model_path: Optional[str] = None, + config_path: Optional[str] = None, + vocoder_name: Optional[str] = None, + vocoder_path: Optional[str] = None, + vocoder_config_path: Optional[str] = None, + encoder_path: Optional[str] = None, + encoder_config_path: Optional[str] = None, + speakers_file_path: Optional[str] = None, + language_ids_file_path: Optional[str] = None, progress_bar: bool = True, - gpu=False, - ): + gpu: bool = False, + ) -> None: """🐸TTS python interface that allows to load and use the released models. Example with a multi-speaker model: @@ -35,31 +43,36 @@ def __init__( >>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav") Example with a single-speaker model: - >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False) + >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False) >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") Example loading a model from a path: - >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False) + >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False) >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") Example voice cloning with YourTTS in English, French and Portuguese: - >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) + >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False).to("cuda") >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav") >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav") >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav") Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html): - >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True) + >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False).to("cuda") >>> tts.tts_to_file("This is a test.", file_path="output.wav") Args: model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. model_path (str, optional): Path to the model checkpoint. Defaults to None. config_path (str, optional): Path to the model config. Defaults to None. + vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder. vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. 
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. - progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. - gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + encoder_path: Path to speaker encoder checkpoint. Defaults to None. + encoder_config_path: Path to speaker encoder config file. Defaults to None. + speakers_file_path: JSON file for multi-speaker model. Defaults to None. + language_ids_file_path: JSON file for multilingual model. Defaults to None. + progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True. + gpu (bool, optional): Enable/disable GPU. Defaults to False. DEPRECATED, use TTS(...).to("cuda") """ super().__init__() self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar) @@ -67,34 +80,45 @@ def __init__( self.synthesizer = None self.voice_converter = None self.model_name = "" + + self.vocoder_path = vocoder_path + self.vocoder_config_path = vocoder_config_path + self.encoder_path = encoder_path + self.encoder_config_path = encoder_config_path + self.speakers_file_path = speakers_file_path + self.language_ids_file_path = language_ids_file_path + if gpu: warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") if model_name is not None and len(model_name) > 0: if "tts_models" in model_name: - self.load_tts_model_by_name(model_name, gpu) + self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu) elif "voice_conversion_models" in model_name: - self.load_vc_model_by_name(model_name, gpu) + self.load_vc_model_by_name(model_name, gpu=gpu) + # To allow just TTS("xtts") else: - self.load_model_by_name(model_name, gpu) + self.load_model_by_name(model_name, vocoder_name, gpu=gpu) if model_path: - self.load_tts_model_by_path( - model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu - ) + self.load_tts_model_by_path(model_path, config_path, gpu=gpu) @property - def models(self): + def models(self) -> list[str]: return self.manager.list_tts_models() @property - def is_multi_speaker(self): - if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager: + def is_multi_speaker(self) -> bool: + if ( + self.synthesizer is not None + and hasattr(self.synthesizer.tts_model, "speaker_manager") + and self.synthesizer.tts_model.speaker_manager + ): return self.synthesizer.tts_model.speaker_manager.num_speakers > 1 return False @property - def is_multi_lingual(self): + def is_multi_lingual(self) -> bool: # Not sure what sets this to None, but applied a fix to prevent crashing. 
if ( isinstance(self.model_name, str) @@ -103,51 +127,63 @@ def is_multi_lingual(self): and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1) ): return True - if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: + if ( + self.synthesizer is not None + and hasattr(self.synthesizer.tts_model, "language_manager") + and self.synthesizer.tts_model.language_manager + ): return self.synthesizer.tts_model.language_manager.num_languages > 1 return False @property - def speakers(self): + def speakers(self) -> list[str]: if not self.is_multi_speaker: return None return self.synthesizer.tts_model.speaker_manager.speaker_names @property - def languages(self): + def languages(self) -> list[str]: if not self.is_multi_lingual: return None return self.synthesizer.tts_model.language_manager.language_names @staticmethod - def get_models_file_path(): + def get_models_file_path() -> Path: return Path(__file__).parent / ".models.json" @staticmethod - def list_models(): + def list_models() -> list[str]: return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() - def download_model_by_name(self, model_name: str): + def download_model_by_name( + self, model_name: str, vocoder_name: Optional[str] = None + ) -> tuple[Optional[str], Optional[str], Optional[str]]: model_path, config_path, model_item = self.manager.download_model(model_name) if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): # return model directory if there are multiple files # we assume that the model knows how to load itself - return None, None, None, None, model_path + return None, None, model_path if model_item.get("default_vocoder") is None: - return model_path, config_path, None, None, None - vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) - return model_path, config_path, vocoder_path, vocoder_config_path, None - - def load_model_by_name(self, model_name: str, gpu: bool = False): + return model_path, config_path, None + if vocoder_name is None: + vocoder_name = model_item["default_vocoder"] + vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name) + # A local vocoder model will take precedence if specified via vocoder_path + if self.vocoder_path is None or self.vocoder_config_path is None: + self.vocoder_path = vocoder_path + self.vocoder_config_path = vocoder_config_path + return model_path, config_path, None + + def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None: """Load one of the 🐸TTS models by name. Args: model_name (str): Model name to load. You can list models by ```tts.models```. gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ - self.load_tts_model_by_name(model_name, gpu) + self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu) - def load_vc_model_by_name(self, model_name: str, gpu: bool = False): + def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None: """Load one of the voice conversion models by name. Args: @@ -155,12 +191,12 @@ def load_vc_model_by_name(self, model_name: str, gpu: bool = False): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. 
""" self.model_name = model_name - model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name) + model_path, config_path, model_dir = self.download_model_by_name(model_name) self.voice_converter = Synthesizer( vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu ) - def load_tts_model_by_name(self, model_name: str, gpu: bool = False): + def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None: """Load one of 🐸TTS models by name. Args: @@ -172,7 +208,7 @@ def load_tts_model_by_name(self, model_name: str, gpu: bool = False): self.synthesizer = None self.model_name = model_name - model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name) + model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name) # init synthesizer # None values are fetch from the model @@ -181,17 +217,15 @@ def load_tts_model_by_name(self, model_name: str, gpu: bool = False): tts_config_path=config_path, tts_speakers_file=None, tts_languages_file=None, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config_path, - encoder_checkpoint=None, - encoder_config=None, + vocoder_checkpoint=self.vocoder_path, + vocoder_config=self.vocoder_config_path, + encoder_checkpoint=self.encoder_path, + encoder_config=self.encoder_config_path, model_dir=model_dir, use_cuda=gpu, ) - def load_tts_model_by_path( - self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False - ): + def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool = False) -> None: """Load a model from a path. Args: @@ -205,22 +239,22 @@ def load_tts_model_by_path( self.synthesizer = Synthesizer( tts_checkpoint=model_path, tts_config_path=config_path, - tts_speakers_file=None, - tts_languages_file=None, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config, - encoder_checkpoint=None, - encoder_config=None, + tts_speakers_file=self.speakers_file_path, + tts_languages_file=self.language_ids_file_path, + vocoder_checkpoint=self.vocoder_path, + vocoder_config=self.vocoder_config_path, + encoder_checkpoint=self.encoder_path, + encoder_config=self.encoder_config_path, use_cuda=gpu, ) def _check_arguments( self, - speaker: str = None, - language: str = None, - speaker_wav: str = None, - emotion: str = None, - speed: float = None, + speaker: Optional[str] = None, + language: Optional[str] = None, + speaker_wav: Optional[str] = None, + emotion: Optional[str] = None, + speed: Optional[float] = None, **kwargs, ) -> None: """Check if the arguments are valid for the model.""" @@ -280,10 +314,6 @@ def tts( speaker_name=speaker, language_name=language, speaker_wav=speaker_wav, - reference_wav=None, - style_wav=None, - style_text=None, - reference_speaker_name=None, split_sentences=split_sentences, **kwargs, ) @@ -301,7 +331,7 @@ def tts_to_file( file_path: str = "output.wav", split_sentences: bool = True, **kwargs, - ): + ) -> str: """Convert text to speech. Args: @@ -367,6 +397,7 @@ def voice_conversion_to_file( source_wav: str, target_wav: str, file_path: str = "output.wav", + pipe_out=None, ) -> str: """Voice conversion with FreeVC. Convert source wav to target speaker. @@ -377,9 +408,11 @@ def voice_conversion_to_file( Path to the target wav file. file_path (str, optional): Output file path. Defaults to "output.wav". 
+ pipe_out (BytesIO, optional): + Flag to stdout the generated TTS wav file for shell pipe. """ wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav) - save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) + self.voice_converter.save_wav(wav=wav, path=file_path, pipe_out=pipe_out) return file_path def tts_with_vc( @@ -432,7 +465,8 @@ def tts_with_vc_to_file( file_path: str = "output.wav", speaker: str = None, split_sentences: bool = True, - ): + pipe_out=None, + ) -> str: """Convert text to speech with voice conversion and save to file. Check `tts_with_vc` for more details. @@ -455,8 +489,11 @@ def tts_with_vc_to_file( Split text into sentences, synthesize them separately and concatenate the file audio. Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only applicable to the 🐸TTS models. Defaults to True. + pipe_out (BytesIO, optional): + Flag to stdout the generated TTS wav file for shell pipe. """ wav = self.tts_with_vc( text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences ) - save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) + self.voice_converter.save_wav(wav=wav, path=file_path, pipe_out=pipe_out) + return file_path diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 454f528ab4..885f6d6f0c 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -9,8 +9,6 @@ from argparse import RawTextHelpFormatter # pylint: disable=redefined-outer-name, unused-argument -from pathlib import Path - from TTS.utils.generic_utils import ConsoleFormatter, setup_logger logger = logging.getLogger(__name__) @@ -253,11 +251,6 @@ def parse_args() -> argparse.Namespace: action="store_true", ) # aux args - parser.add_argument( - "--save_spectogram", - action="store_true", - help="Save raw spectogram for further (vocoder) processing in out_path.", - ) parser.add_argument( "--reference_wav", type=str, @@ -317,7 +310,8 @@ def parse_args() -> argparse.Namespace: return args -def main(): +def main() -> None: + """Entry point for `tts` command line interface.""" setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) args = parse_args() @@ -325,12 +319,11 @@ def main(): with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout): # Late-import to make things load faster + from TTS.api import TTS from TTS.utils.manage import ModelManager - from TTS.utils.synthesizer import Synthesizer # load model manager - path = Path(__file__).parent / "../.models.json" - manager = ModelManager(path, progress_bar=args.progress_bar) + manager = ModelManager(models_file=TTS.get_models_file_path(), progress_bar=args.progress_bar) tts_path = None tts_config_path = None @@ -344,12 +337,12 @@ def main(): vc_config_path = None model_dir = None - # CASE1 #list : list pre-trained TTS models + # 1) List pre-trained TTS models if args.list_models: manager.list_models() sys.exit() - # CASE2 #info : model info for pre-trained TTS models + # 2) Info about pre-trained TTS models (without loading a model) if args.model_info_by_idx: model_query = args.model_info_by_idx manager.model_info_by_idx(model_query) @@ -360,91 +353,50 @@ def main(): manager.model_info_by_full_name(model_query_full_name) sys.exit() - # CASE3: load pre-trained model paths - if args.model_name is not None and not args.model_path: - model_path, config_path, model_item = manager.download_model(args.model_name) 
- # tts model - if model_item["model_type"] == "tts_models": - tts_path = model_path - tts_config_path = config_path - if args.vocoder_name is None and "default_vocoder" in model_item: - args.vocoder_name = model_item["default_vocoder"] - - # voice conversion model - if model_item["model_type"] == "voice_conversion_models": - vc_path = model_path - vc_config_path = config_path - - # tts model with multiple files to be loaded from the directory path - if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list): - model_dir = model_path - tts_path = None - tts_config_path = None - args.vocoder_name = None - - # load vocoder - if args.vocoder_name is not None and not args.vocoder_path: - vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) - - # CASE4: set custom model paths - if args.model_path is not None: - tts_path = args.model_path - tts_config_path = args.config_path - speakers_file_path = args.speakers_file_path - language_ids_file_path = args.language_ids_file_path - - if args.vocoder_path is not None: - vocoder_path = args.vocoder_path - vocoder_config_path = args.vocoder_config_path - - if args.encoder_path is not None: - encoder_path = args.encoder_path - encoder_config_path = args.encoder_config_path - - + # 3) Load a model for further info or TTS/VC device = args.device if args.use_cuda: device = "cuda" - - # load models - synthesizer = Synthesizer( - tts_checkpoint=tts_path, - tts_config_path=tts_config_path, - tts_speakers_file=speakers_file_path, - tts_languages_file=language_ids_file_path, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config_path, - encoder_checkpoint=encoder_path, - encoder_config=encoder_config_path, - vc_checkpoint=vc_path, - vc_config=vc_config_path, - model_dir=model_dir, - voice_dir=args.voice_dir, + # A local model will take precedence if specified via model_path + model_name = args.model_name if args.model_path is None else None + api = TTS( + model_name=model_name, + model_path=args.model_path, + config_path=args.config_path, + vocoder_name=args.vocoder_name, + vocoder_path=args.vocoder_path, + vocoder_config_path=args.vocoder_config_path, + encoder_path=args.encoder_path, + encoder_config_path=args.encoder_config_path, + speakers_file_path=args.speakers_file_path, + language_ids_file_path=args.language_ids_file_path, + progress_bar=args.progress_bar, ).to(device) # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: - if synthesizer.tts_model.speaker_manager is None: + if not api.is_multi_speaker: logger.info("Model only has a single speaker.") return logger.info( "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." ) - logger.info(list(synthesizer.tts_model.speaker_manager.name_to_id.keys())) + logger.info(api.speakers) return # query langauge ids of a multi-lingual model. if args.list_language_idxs: - if synthesizer.tts_model.language_manager is None: + if not api.is_multi_lingual: logger.info("Monolingual model.") return logger.info( "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) - logger.info(synthesizer.tts_model.language_manager.name_to_id) + logger.info(api.languages) return # check the arguments against a multi-speaker model. 
- if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): + if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav): logger.error( "Looks like you use a multi-speaker model. Define `--speaker_idx` to " "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." @@ -455,31 +407,29 @@ def main(): if args.text: logger.info("Text: %s", args.text) - # kick it - if tts_path is not None: - wav = synthesizer.tts( - args.text, - speaker_name=args.speaker_idx, - language_name=args.language_idx, + if args.text is not None: + api.tts_to_file( + text=args.text, + speaker=args.speaker_idx, + language=args.language_idx, speaker_wav=args.speaker_wav, + pipe_out=pipe_out, + file_path=args.out_path, reference_wav=args.reference_wav, style_wav=args.capacitron_style_wav, style_text=args.capacitron_style_text, reference_speaker_name=args.reference_speaker_idx, + voice_dir=args.voice_dir, ) - elif vc_path is not None: - wav = synthesizer.voice_conversion( + logger.info("Saved TTS output to %s", args.out_path) + elif args.source_wav is not None and args.target_wav is not None: + api.voice_conversion_to_file( source_wav=args.source_wav, target_wav=args.target_wav, + file_path=args.out_path, + pipe_out=pipe_out, ) - elif model_dir is not None: - wav = synthesizer.tts( - args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav - ) - - # save the results - synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) - logger.info("Saved output to %s", args.out_path) + logger.info("Saved VC output to %s", args.out_path) if __name__ == "__main__": diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index f838297af3..411a9b0dbe 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -5,7 +5,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index 4c3733e691..1bbf676393 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -70,11 +70,10 @@ def forward(self, qkv, mask=None, rel_pos=None): weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( bs * self.n_heads, weight.shape[-2], weight.shape[-1] ) - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) if mask is not None: - # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. 
- mask = mask.repeat(self.n_heads, 1).unsqueeze(1) - weight = weight * mask + mask = mask.repeat(self.n_heads, 1, 1) + weight[mask.logical_not()] = -torch.inf + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -93,7 +92,9 @@ def __init__( channels, num_heads=1, num_head_channels=-1, + *, relative_pos_embeddings=False, + tortoise_norm=False, ): super().__init__() self.channels = channels @@ -108,6 +109,7 @@ def __init__( self.qkv = nn.Conv1d(channels, channels * 3, 1) # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) + self.tortoise_norm = tortoise_norm self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) if relative_pos_embeddings: @@ -124,10 +126,13 @@ def __init__( def forward(self, x, mask=None): b, c, *spatial = x.shape x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) + x_norm = self.norm(x) + qkv = self.qkv(x_norm) h = self.attention(qkv, mask, self.relative_pos_embeddings) h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) + if self.tortoise_norm: + return (x + h).reshape(b, c, *spatial) + return (x_norm + h).reshape(b, c, *spatial) class Upsample(nn.Module): diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index 07cf3d542b..00c884e973 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -176,12 +176,14 @@ def __init__( embedding_dim, attn_blocks=6, num_attn_heads=4, + *, + tortoise_norm=False, ): super().__init__() attn = [] self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=tortoise_norm)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim diff --git a/TTS/tts/layers/tortoise/classifier.py b/TTS/tts/layers/tortoise/classifier.py index c72834e9a8..337323db67 100644 --- a/TTS/tts/layers/tortoise/classifier.py +++ b/TTS/tts/layers/tortoise/classifier.py @@ -97,7 +97,7 @@ def __init__( self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) attn = [] for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=True)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim diff --git a/TTS/tts/layers/tortoise/diffusion_decoder.py b/TTS/tts/layers/tortoise/diffusion_decoder.py index 15bbfb7121..cfdeaff8bb 100644 --- a/TTS/tts/layers/tortoise/diffusion_decoder.py +++ b/TTS/tts/layers/tortoise/diffusion_decoder.py @@ -130,7 +130,7 @@ def __init__(self, model_channels, dropout, num_heads): dims=1, use_scale_shift_norm=True, ) - self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True) def forward(self, x, time_emb): y = self.resblk(x, time_emb) @@ -177,17 +177,17 @@ def __init__( # transformer network. 
self.code_embedding = nn.Embedding(in_tokens, model_channels) self.code_converter = nn.Sequential( - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), ) self.code_norm = normalization(model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), ) self.contextual_embedder = nn.Sequential( nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), @@ -196,26 +196,31 @@ def __init__( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), ) self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1)) diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py deleted file mode 100644 index 6becffb8b7..0000000000 --- a/TTS/tts/layers/xtts/latent_encoder.py +++ /dev/null @@ -1,95 +0,0 @@ -# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts - -import math - -import torch -from torch import nn -from torch.nn import functional as F - -from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module - - -def conv_nd(dims, *args, **kwargs): - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class QKVAttention(nn.Module): - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv, mask=None, qk_bias=0): - """ - Apply QKV attention. - - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = weight + qk_bias - if mask is not None: - mask = mask.repeat(self.n_heads, 1, 1) - weight[mask.logical_not()] = -torch.inf - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v) - - return a.reshape(bs, -1, length) - - -class AttentionBlock(nn.Module): - """An attention block that allows spatial positions to attend to each other.""" - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - out_channels=None, - do_activation=False, - ): - super().__init__() - self.channels = channels - out_channels = channels if out_channels is None else out_channels - self.do_activation = do_activation - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, out_channels * 3, 1) - self.attention = QKVAttention(self.num_heads) - - self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) - self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) - - def forward(self, x, mask=None, qk_bias=0): - b, c, *spatial = x.shape - if mask is not None: - if len(mask.shape) == 2: - mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) - if mask.shape[1] != x.shape[-1]: - mask = mask[:, : x.shape[-1], : x.shape[-1]] - - x = x.reshape(b, c, -1) - x = self.norm(x) - if self.do_activation: - x = F.silu(x, inplace=True) - qkv = self.qkv(x) - h = self.attention(qkv, mask=mask, qk_bias=qk_bias) - h = self.proj_out(h) - xp = self.x_proj(x) - return (xp + h).reshape(b, xp.shape[1], *spatial) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 0253d65ddd..107054189c 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig +from TTS.tts.models.xtts import Xtts, XttsArgs from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) @@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig): test_sentences: List[dict] = field(default_factory=lambda: []) -@dataclass -class XttsAudioConfig(XttsAudioConfig): - dvae_sample_rate: int = 22050 - - @dataclass class GPTArgs(XttsArgs): min_conditioning_length: int = 66150 diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 35de91e359..f05863ae1d 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Optional import librosa import torch @@ -101,10 +102,12 @@ class XttsAudioConfig(Coqpit): Args: sample_rate (int): The sample rate in which the GPT operates. output_sample_rate (int): The sample rate of the output audio waveform. 
+ dvae_sample_rate (int): The sample rate of the DVAE """ sample_rate: int = 22050 output_sample_rate: int = 24000 + dvae_sample_rate: int = 22050 @dataclass @@ -719,14 +722,14 @@ def get_compatible_checkpoint_state_dict(self, model_path): def load_checkpoint( self, - config, - checkpoint_dir=None, - checkpoint_path=None, - vocab_path=None, - eval=True, - strict=True, - use_deepspeed=False, - speaker_file_path=None, + config: "XttsConfig", + checkpoint_dir: Optional[str] = None, + checkpoint_path: Optional[str] = None, + vocab_path: Optional[str] = None, + eval: bool = True, + strict: bool = True, + use_deepspeed: bool = False, + speaker_file_path: Optional[str] = None, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -742,7 +745,9 @@ def load_checkpoint( Returns: None """ - + if checkpoint_dir is not None and Path(checkpoint_dir).is_file(): + msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead." + raise ValueError(msg) model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") if vocab_path is None: if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file(): diff --git a/docs/source/models/bark.md b/docs/source/models/bark.md index a180afbb91..77f99c0d3a 100644 --- a/docs/source/models/bark.md +++ b/docs/source/models/bark.md @@ -37,7 +37,7 @@ from TTS.api import TTS # Load the model to GPU # Bark is really slow on CPU, so we recommend using GPU. -tts = TTS("tts_models/multilingual/multi-dataset/bark", gpu=True) +tts = TTS("tts_models/multilingual/multi-dataset/bark").to("cuda") # Cloning a new speaker @@ -57,7 +57,7 @@ tts.tts_to_file(text="Hello, my name is Manmay , how are you?", # random speaker -tts = TTS("tts_models/multilingual/multi-dataset/bark", gpu=True) +tts = TTS("tts_models/multilingual/multi-dataset/bark").to("cuda") tts.tts_to_file("hello world", file_path="out.wav") ``` diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index c07d879f7c..7c0f1c4a60 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -118,7 +118,7 @@ You can optionally disable sentence splitting for better coherence but more VRAM ```python from TTS.api import TTS -tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) +tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda") # generate speech by cloning a voice using default settings tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", @@ -137,15 +137,15 @@ You can pass multiple audio files to the `speaker_wav` argument for better voice from TTS.api import TTS # using the default version set in 🐸TTS -tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) +tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda") # using a specific version # 👀 see the branch names for versions on https://huggingface.co/coqui/XTTS-v2/tree/main # ❗some versions might be incompatible with the API -tts = TTS("xtts_v2.0.2", gpu=True) +tts = TTS("xtts_v2.0.2").to("cuda") # getting the latest XTTS_v2 -tts = TTS("xtts", gpu=True) +tts = TTS("xtts").to("cuda") # generate speech by cloning a voice using default settings tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", @@ -160,7 +160,7 @@ You can do inference using one of the available speakers using the following cod ```python from TTS.api import TTS -tts = 
TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) +tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda") # generate speech by cloning a voice using default settings tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", diff --git a/pyproject.toml b/pyproject.toml index 5386d274ac..bf0a1d88c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ build-backend = "hatchling.build" [project] name = "coqui-tts" -version = "0.25.0" +version = "0.25.1" description = "Deep learning for Text to Speech." readme = "README.md" requires-python = ">=3.9, <3.13" diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index d31ec8f1ed..a077a18064 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -4,7 +4,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index ccaa97f1e4..362f45008e 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -4,7 +4,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index b8b9a4e388..bb592f1f2d 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -8,7 +8,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig config_dataset = BaseDatasetConfig( formatter="ljspeech", diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 6663433c12..454e867385 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -8,7 +8,8 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig config_dataset = BaseDatasetConfig( formatter="ljspeech", diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index b944423988..f38880b51f 100644 --- 
a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -34,30 +34,27 @@ def run_models(offset=0, step=1): # download and run the model speaker_files = glob.glob(local_download_dir + "/speaker*") language_files = glob.glob(local_download_dir + "/language*") - language_id = "" + speaker_arg = "" + language_arg = "" if len(speaker_files) > 0: # multi-speaker model if "speaker_ids" in speaker_files[0]: speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) elif "speakers" in speaker_files[0]: speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) - - # multi-lingual model - Assuming multi-lingual models are also multi-speaker - if len(language_files) > 0 and "language_ids" in language_files[0]: - language_manager = LanguageManager(language_ids_file_path=language_files[0]) - language_id = language_manager.language_names[0] - - speaker_id = list(speaker_manager.name_to_id.keys())[0] - run_cli( - f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --no-progress_bar' - ) - else: - # single-speaker model - run_cli( - f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --no-progress_bar' - ) + speakers = list(speaker_manager.name_to_id.keys()) + if len(speakers) > 1: + speaker_arg = f'--speaker_idx "{speakers[0]}"' + if len(language_files) > 0 and "language_ids" in language_files[0]: + # multi-lingual model + language_manager = LanguageManager(language_ids_file_path=language_files[0]) + languages = language_manager.language_names + if len(languages) > 1: + language_arg = f'--language_idx "{languages[0]}"' + run_cli( + f'tts --model_name {model_name} --text "This is an example." ' + f'--out_path "{output_path}" {speaker_arg} {language_arg} --no-progress_bar' + ) # remove downloaded models shutil.rmtree(local_download_dir) shutil.rmtree(get_user_data_dir("tts"))
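For reference, a minimal usage sketch of the API surface this patch exposes: the keyword-only custom-path arguments and `vocoder_name` override on `TTS.__init__`, device selection via `.to()` in place of the deprecated `gpu` flag, and `voice_conversion_to_file`. The vocoder and voice-conversion model names below are illustrative assumptions (only the Thorsten TTS model name appears in the diff), and all file paths are placeholders.

```python
from TTS.api import TTS

# Pre-trained model with an explicitly chosen (non-default) vocoder; the device is
# set with .to() instead of the deprecated gpu=True flag.
tts = TTS(
    model_name="tts_models/de/thorsten/tacotron2-DDC",
    vocoder_name="vocoder_models/de/thorsten/hifigan_v1",  # assumed vocoder name
).to("cuda")
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")

# Local checkpoint with auxiliary files, now passed as keyword-only arguments and
# picked up by load_tts_model_by_path() via the attributes stored in __init__.
tts_local = TTS(
    model_path="/path/to/checkpoint_100000.pth",
    config_path="/path/to/config.json",
    vocoder_path="/path/to/vocoder.pth",
    vocoder_config_path="/path/to/vocoder_config.json",
    speakers_file_path="/path/to/speakers.json",
    progress_bar=False,
)
tts_local.tts_to_file(text="Hello world!", file_path="output_local.wav")

# Voice conversion; saving now goes through the Synthesizer's save_wav (pipe_out optional).
vc = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24")  # assumed model name
vc.voice_conversion_to_file(
    source_wav="source.wav", target_wav="target.wav", file_path="converted.wav"
)
```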