diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index b1cf886bea..a4aa563f48 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -179,17 +179,19 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) return spec @@ -274,17 +276,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index e91d26b9ed..0d2187d206 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -121,17 +121,19 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -189,17 +191,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index fd40ebb048..632969c51a 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -119,17 +119,19 @@ def __call__(self, x): padding = int((self.n_fft - self.hop_length) / 2) x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") # B x D x T x 2 - o = torch.stft( - x.squeeze(1), - self.n_fft, - self.hop_length, - self.win_length, - self.window, - center=True, - pad_mode="reflect", # compatible with audio.py - normalized=self.normalized, - onesided=True, - return_complex=False, + o = torch.view_as_real( + torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=self.normalized, + onesided=True, + return_complex=True, + ) ) M = o[:, :, :, 0] P = o[:, :, :, 1] diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index 2dcbf21493..1955e758ac 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -54,17 +54,19 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -104,17 +106,19 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)