From 08dfbd134c3cf797722e00e84f4800309f7841f1 Mon Sep 17 00:00:00 2001 From: autumn-2-net <109412646+autumn-2-net@users.noreply.github.com> Date: Sat, 3 Feb 2024 20:20:16 +0800 Subject: [PATCH 1/2] Compatible with torch2.2 --- modules/nsf_hifigan/models.py | 62 ++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 15 deletions(-) diff --git a/modules/nsf_hifigan/models.py b/modules/nsf_hifigan/models.py index a77eb0a3..d5c32c7c 100644 --- a/modules/nsf_hifigan/models.py +++ b/modules/nsf_hifigan/models.py @@ -7,13 +7,19 @@ import torch.nn.functional as F from lightning.pytorch.utilities.rank_zero import rank_zero_info from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm +# from torch.nn.utils import weight_norm, remove_weight_norm from .env import AttrDict from .utils import init_weights, get_padding LRELU_SLOPE = 0.1 - +_OLD_WEIGHT_NORM = False +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm + from torch.nn.utils import remove_weight_norm + _OLD_WEIGHT_NORM = True def load_model(model_path: pathlib.Path): config_file = model_path.with_name('config.json') @@ -67,10 +73,17 @@ def forward(self, x): return x def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) + global _OLD_WEIGHT_NORM + if _OLD_WEIGHT_NORM: + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + else: + for l in self.convs1: + torch.nn.utils.parametrize.remove_parametrizations(l) + for l in self.convs2: + torch.nn.utils.parametrize.remove_parametrizations(l) class ResBlock2(torch.nn.Module): @@ -93,8 +106,15 @@ def forward(self, x): return x def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) + + global _OLD_WEIGHT_NORM + if _OLD_WEIGHT_NORM: + for l in self.convs: + remove_weight_norm(l) + + else: + for l in self.convs: + torch.nn.utils.parametrize.remove_parametrizations(l) class SineGen(torch.nn.Module): @@ -285,10 +305,22 @@ def forward(self, x, f0): return x def remove_weight_norm(self): - rank_zero_info('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) + # rank_zero_info('Removing weight norm...') + print('Removing weight norm...') + global _OLD_WEIGHT_NORM + if _OLD_WEIGHT_NORM: + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + else: + for l in self.ups: + torch.nn.utils.parametrize.remove_parametrizations(l) + for l in self.resblocks: + l.remove_weight_norm() + + torch.nn.utils.parametrize.remove_parametrizations(self.conv_pre) + torch.nn.utils.parametrize.remove_parametrizations(self.conv_post) From cd3129f852a601777219ba1ab74de5e1bb5a73f8 Mon Sep 17 00:00:00 2001 From: autumn-2-net <109412646+autumn-2-net@users.noreply.github.com> Date: Sat, 3 Feb 2024 23:19:56 +0800 Subject: [PATCH 2/2] fix --- modules/nsf_hifigan/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/nsf_hifigan/models.py b/modules/nsf_hifigan/models.py index d5c32c7c..59300ebb 100644 --- a/modules/nsf_hifigan/models.py +++ b/modules/nsf_hifigan/models.py @@ -113,7 +113,7 @@ def remove_weight_norm(self): remove_weight_norm(l) else: - for l in self.convs: + for l in self.convs: torch.nn.utils.parametrize.remove_parametrizations(l)