Skip to content

Commit

Permalink
fix: torch.stft will soon require return_complex=True
Browse files Browse the repository at this point in the history
Refactor that removes the deprecation warning:
torch.view_as_real(torch.stft(*, return_complex=True)) is equal to
torch.stft(*, return_complex=False)

https://pytorch.org/docs/stable/generated/torch.stft.html
  • Loading branch information
eginhard committed Mar 13, 2024
1 parent 0c6c20f commit e95f895
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 77 deletions.
48 changes: 26 additions & 22 deletions TTS/tts/models/delightful_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,19 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

return spec
Expand Down Expand Up @@ -274,17 +276,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
Expand Down
48 changes: 26 additions & 22 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,19 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
Expand Down Expand Up @@ -189,17 +191,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
Expand Down
24 changes: 13 additions & 11 deletions TTS/utils/audio/torch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,19 @@ def __call__(self, x):
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=self.normalized,
onesided=True,
return_complex=False,
o = torch.view_as_real(
torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=self.normalized,
onesided=True,
return_complex=True,
)
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
Expand Down
48 changes: 26 additions & 22 deletions TTS/vc/modules/freevc/mel_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,19 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
Expand Down Expand Up @@ -104,17 +106,19 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
Expand Down

0 comments on commit e95f895

Please sign in to comment.