diff --git a/modules/nsf_hifigan/models.py b/modules/nsf_hifigan/models.py index a77eb0a3..59300ebb 100644 --- a/modules/nsf_hifigan/models.py +++ b/modules/nsf_hifigan/models.py @@ -7,13 +7,19 @@ import torch.nn.functional as F from lightning.pytorch.utilities.rank_zero import rank_zero_info from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm +# from torch.nn.utils import weight_norm, remove_weight_norm from .env import AttrDict from .utils import init_weights, get_padding LRELU_SLOPE = 0.1 - +_OLD_WEIGHT_NORM = False +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm + from torch.nn.utils import remove_weight_norm + _OLD_WEIGHT_NORM = True def load_model(model_path: pathlib.Path): config_file = model_path.with_name('config.json') @@ -67,10 +73,17 @@ def forward(self, x): return x def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) + global _OLD_WEIGHT_NORM + if _OLD_WEIGHT_NORM: + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + else: + for l in self.convs1: + torch.nn.utils.parametrize.remove_parametrizations(l) + for l in self.convs2: + torch.nn.utils.parametrize.remove_parametrizations(l) class ResBlock2(torch.nn.Module): @@ -93,8 +106,15 @@ def forward(self, x): return x def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) + + global _OLD_WEIGHT_NORM + if _OLD_WEIGHT_NORM: + for l in self.convs: + remove_weight_norm(l) + + else: + for l in self.convs: + torch.nn.utils.parametrize.remove_parametrizations(l) class SineGen(torch.nn.Module): @@ -285,10 +305,22 @@ def forward(self, x, f0): return x def remove_weight_norm(self): - rank_zero_info('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) + # rank_zero_info('Removing weight norm...') + print('Removing weight norm...') + global _OLD_WEIGHT_NORM + if _OLD_WEIGHT_NORM: + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + else: + for l in self.ups: + torch.nn.utils.parametrize.remove_parametrizations(l) + for l in self.resblocks: + l.remove_weight_norm() + + torch.nn.utils.parametrize.remove_parametrizations(self.conv_pre) + torch.nn.utils.parametrize.remove_parametrizations(self.conv_post)