From bd2f992e7ea719a84f629b96ab5d7e5a87744669 Mon Sep 17 00:00:00 2001 From: Daniel Walmsley Date: Sat, 15 Jun 2024 09:18:13 -0700 Subject: [PATCH] Make it work on mps --- TTS/tts/layers/xtts/stream_generator.py | 44 +++++++++++++++++++++---- requirements.txt | 7 ++-- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index dd07e9dc07..fa8b9c730f 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -182,14 +182,44 @@ def generate( accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - pad_token_tensor = torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) if generation_config.pad_token_id is not None else None - eos_token_tensor = torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) if generation_config.eos_token_id is not None else None - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, - pad_token_tensor, - eos_token_tensor, + if ( + model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask + ): + pad_token_tensor = ( + torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) + if generation_config.pad_token_id is not None + else None ) + eos_token_tensor = ( + torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) + if generation_config.eos_token_id is not None + else None + ) + + # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now + # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel + if inputs_tensor.device.type == "mps": + default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device) + + is_pad_token_in_inputs = (pad_token_tensor is not None) and ( + torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~( + torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any() + ) + can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long() + + model_kwargs["attention_mask"] = ( + attention_mask_from_padding * can_infer_attention_mask + + default_attention_mask * ~can_infer_attention_mask + ) + else: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, + pad_token_tensor, + eos_token_tensor, + ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: diff --git a/requirements.txt b/requirements.txt index 2944e6face..2a48e71c95 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,9 @@ numpy==1.22.0;python_version<="3.10" numpy>=1.24.3;python_version>"3.10" cython>=0.29.30 scipy>=1.11.2 -torch>=2.1 -torchaudio +torch==2.3.1 +torchaudio==2.3.1 +torchvision==0.18.1 soundfile>=0.12.0 librosa>=0.10.0 scikit-learn>=1.3.0 @@ -48,7 +49,7 @@ bnnumerizer bnunicodenormalizer #deps for tortoise einops>=0.6.0 -transformers>=4.33.0 +transformers>=4.41.2 #deps for bark encodec>=0.1.1 # deps for XTTS