From a07e582dd49cfc7617803c8ddd2c8e00a7a1f482 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:41:17 +0900 Subject: [PATCH 01/14] support inference method --- parallel_wavegan/models/melgan.py | 16 +++++++++++++ parallel_wavegan/models/parallel_wavegan.py | 25 +++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/parallel_wavegan/models/melgan.py b/parallel_wavegan/models/melgan.py index 429419ca..9e7e8e3d 100644 --- a/parallel_wavegan/models/melgan.py +++ b/parallel_wavegan/models/melgan.py @@ -144,6 +144,9 @@ def __init__(self, # reset parameters self.reset_parameters() + # initialize pqmf for inference + self.pqmf = None + def forward(self, c): """Calculate forward propagation. @@ -190,6 +193,19 @@ def _reset_parameters(m): self.apply(_reset_parameters) + def inference(self, c): + """Perform inference. + + Args: + c (Tensor): Input tensor (T, in_channels). + + Returns: + Tensor: Output tensor (T ** prod(upsample_scales), out_channels). + + """ + c = self.melgan(c.transpose(1, 0).unsqueeze(0)) + return c.squeeze(0).transpose(1, 0) + class MelGANDiscriminator(torch.nn.Module): """MelGAN discriminator module.""" diff --git a/parallel_wavegan/models/parallel_wavegan.py b/parallel_wavegan/models/parallel_wavegan.py index 986838b6..2d98b10e 100644 --- a/parallel_wavegan/models/parallel_wavegan.py +++ b/parallel_wavegan/models/parallel_wavegan.py @@ -8,6 +8,7 @@ import logging import math +import numpy as np import torch from parallel_wavegan.layers import Conv1d @@ -66,6 +67,7 @@ def __init__(self, self.in_channels = in_channels self.out_channels = out_channels self.aux_channels = aux_channels + self.aux_context_window = aux_context_window self.layers = layers self.stacks = stacks self.kernel_size = kernel_size @@ -96,8 +98,10 @@ def __init__(self, "aux_context_window": aux_context_window, }) self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) + self.upsample_factor = np.prod(upsample_params["upsample_scales"]) else: self.upsample_net = None + self.upsample_factor = 0 # define residual blocks self.conv_layers = torch.nn.ModuleList() @@ -192,6 +196,27 @@ def receptive_field_size(self): """Return receptive field size.""" return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + def inference(self, c=None, x=None): + """Perform inference. + + Args: + c (Tensor): Local conditioning auxiliary features (T' ,C). + x (Tensor): Input noise signal (T, 1). + + Returns: + Tensor: Output tensor (T, out_channels) + + """ + if x is not None: + x = x.transpose(1, 0).unsqueeze(0) + else: + assert c is not None + x = c.new_tensor(torch.randn(1, 1, len(c) * self.upsample_factor)) + if c is not None: + c = c.transpose(1, 0).unsqueeze(0) + c = torch.nn.ReplicationPad1d(self.aux_context_window)(c) + return self.forward(x, c).squeeze(0).transpose(1, 0) + class ParallelWaveGANDiscriminator(torch.nn.Module): """Parallel WaveGAN Discriminator module.""" From 6ae2812a195023ba13b3a8fb561fca6d8c24b1fc Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:41:30 +0900 Subject: [PATCH 02/14] fixed default pytorch version --- tools/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/Makefile b/tools/Makefile index 07caf2a9..7a41fd45 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -1,6 +1,6 @@ PYTHON:= python3.6 CUDA_VERSION:= 10.0 -PYTORCH_VERSION:= 1.5 +PYTORCH_VERSION:= 1.4 DOT:= . 
.PHONY: all clean From 162332cf998c31842b2022e3a09c6a19b6a6088c Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:43:19 +0900 Subject: [PATCH 03/14] use inference method for simplicity --- parallel_wavegan/bin/decode.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/parallel_wavegan/bin/decode.py b/parallel_wavegan/bin/decode.py index db906875..6c2278a5 100755 --- a/parallel_wavegan/bin/decode.py +++ b/parallel_wavegan/bin/decode.py @@ -109,45 +109,33 @@ def main(): device = torch.device("cpu") model_class = getattr( parallel_wavegan.models, - config.get("generator_type", "ParallelWaveGANGenerator")) + config.get("generator_type", "ParallelWaveGANGenerator") + ) model = model_class(**config["generator_params"]) model.load_state_dict( - torch.load(args.checkpoint, map_location="cpu")["model"]["generator"]) + torch.load(args.checkpoint, map_location="cpu")["model"]["generator"] + ) logging.info(f"Loaded model parameters from {args.checkpoint}.") - model.remove_weight_norm() - model = model.eval().to(device) - use_noise_input = not isinstance( - model, parallel_wavegan.models.MelGANGenerator) - pad_fn = torch.nn.ReplicationPad1d( - config["generator_params"].get("aux_context_window", 0)) if config["generator_params"]["out_channels"] > 1: pqmf_params = {} if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"): # For compatibility, here we set default values in version <= 0.4.2 pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0, use_legacy=True) - pqmf = PQMF( + model.pqmf = PQMF( subbands=config["generator_params"]["out_channels"], **config.get("pqmf_params", pqmf_params), - ).to(device) + ) + model.remove_weight_norm() + model = model.eval().to(device) # start generation total_rtf = 0.0 with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar: for idx, (utt_id, c) in enumerate(pbar, 1): - # setup input - x = () - if use_noise_input: - z = torch.randn(1, 1, len(c) * config["hop_size"]).to(device) - x += (z,) - c = pad_fn(torch.tensor(c, dtype=torch.float).unsqueeze(0).transpose(2, 1)).to(device) - x += (c,) - # generate + c = torch.tensor(c, dtype=torch.float).to(device) start = time.time() - if config["generator_params"]["out_channels"] == 1: - y = model(*x).view(-1).cpu().numpy() - else: - y = pqmf.synthesis(model(*x)).view(-1).cpu().numpy() + y = model.inference(c).view(-1).cpu().numpy() rtf = (time.time() - start) / (len(y) / config["sampling_rate"]) pbar.set_postfix({"RTF": rtf}) total_rtf += rtf From 6b893162dfd1bd24fb976645345281bb32dfcb57 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:44:07 +0900 Subject: [PATCH 04/14] relax dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7d71f754..154cdd48 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ requirements = { "install": [ - "torch>=1.0.1,<=1.5.1", + "torch>=1.0.1", "setuptools>=38.5.1", "librosa>=0.7.0", "soundfile>=0.10.2", From f76af273c81ce012575326bfd00bb16daced85f5 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:45:02 +0900 Subject: [PATCH 05/14] fixed --- parallel_wavegan/models/parallel_wavegan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parallel_wavegan/models/parallel_wavegan.py b/parallel_wavegan/models/parallel_wavegan.py index 2d98b10e..2cb087f2 100644 --- a/parallel_wavegan/models/parallel_wavegan.py +++ b/parallel_wavegan/models/parallel_wavegan.py @@ -101,7 +101,7 @@ def __init__(self, 
self.upsample_factor = np.prod(upsample_params["upsample_scales"]) else: self.upsample_net = None - self.upsample_factor = 0 + self.upsample_factor = 1 # define residual blocks self.conv_layers = torch.nn.ModuleList() From ef9f040bbfd5f2ffec236a89dfcb9d4e878f8fe3 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:47:10 +0900 Subject: [PATCH 06/14] removed use_legacy option --- parallel_wavegan/bin/decode.py | 2 +- parallel_wavegan/layers/pqmf.py | 55 ++++++++------------------------- 2 files changed, 14 insertions(+), 43 deletions(-) diff --git a/parallel_wavegan/bin/decode.py b/parallel_wavegan/bin/decode.py index 6c2278a5..4e16d371 100755 --- a/parallel_wavegan/bin/decode.py +++ b/parallel_wavegan/bin/decode.py @@ -120,7 +120,7 @@ def main(): pqmf_params = {} if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"): # For compatibility, here we set default values in version <= 0.4.2 - pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0, use_legacy=True) + pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0) model.pqmf = PQMF( subbands=config["generator_params"]["out_channels"], **config.get("pqmf_params", pqmf_params), diff --git a/parallel_wavegan/layers/pqmf.py b/parallel_wavegan/layers/pqmf.py index c39ae756..f5179a99 100644 --- a/parallel_wavegan/layers/pqmf.py +++ b/parallel_wavegan/layers/pqmf.py @@ -58,7 +58,7 @@ class PQMF(torch.nn.Module): """ - def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, use_legacy=False): + def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0): """Initilize PQMF module. The cutoff_ratio and beta parameters are optimized for #subbands = 4. @@ -69,17 +69,23 @@ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, use_legacy taps (int): The number of filter taps. cutoff_ratio (float): Cut-off frequency ratio. beta (float): Beta coefficient for kaiser window. - use_legacy (bool): Whether to use legacy PQMF coefficients (for <= 0.4.2). 
""" super(PQMF, self).__init__() - # define filter coefficient + # build analysis & synthesis filter coefficients h_proto = design_prototype_filter(taps, cutoff_ratio, beta) - if use_legacy: - h_analysis, h_synthesis = self._build_filter_legacy(h_proto, subbands, taps) - else: - h_analysis, h_synthesis = self._build_filter(h_proto, subbands, taps) + h_analysis = np.zeros((subbands, len(h_proto))) + h_synthesis = np.zeros((subbands, len(h_proto))) + for k in range(subbands): + h_analysis[k] = 2 * h_proto * np.cos( + (2 * k + 1) * (np.pi / (2 * subbands)) * + (np.arange(taps + 1) - (taps / 2)) + + (-1) ** k * np.pi / 4) + h_synthesis[k] = 2 * h_proto * np.cos( + (2 * k + 1) * (np.pi / (2 * subbands)) * + (np.arange(taps + 1) - (taps / 2)) - + (-1) ** k * np.pi / 4) # convert to tensor analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) @@ -99,41 +105,6 @@ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, use_legacy # keep padding info self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) - def _build_filter(self, h_proto, subbands, taps): - # build analysis & synthesis filter coefficients - h_analysis = np.zeros((subbands, len(h_proto))) - h_synthesis = np.zeros((subbands, len(h_proto))) - for k in range(subbands): - h_analysis[k] = 2 * h_proto * np.cos( - (2 * k + 1) * (np.pi / (2 * subbands)) * - (np.arange(taps + 1) - (taps / 2)) + - (-1) ** k * np.pi / 4) - h_synthesis[k] = 2 * h_proto * np.cos( - (2 * k + 1) * (np.pi / (2 * subbands)) * - (np.arange(taps + 1) - (taps / 2)) - - (-1) ** k * np.pi / 4) - - return h_analysis, h_synthesis - - def _build_filter_legacy(self, h_proto, subbands, taps): - # NOTE(kan-bayashi): legacy version is for the <= 0.4.2 compatibility - # build analysis & synthesis filter coefficients - h_analysis = np.zeros((subbands, len(h_proto))) - h_synthesis = np.zeros((subbands, len(h_proto))) - for k in range(subbands): - h_analysis[k] = 2 * h_proto * np.cos( - (2 * k + 1) * (np.pi / (2 * subbands)) * - # NOTE(kan-bayashi): (taps - 1) is used for <= v.0.4.2 - (np.arange(taps + 1) - ((taps - 1) / 2)) + - (-1) ** k * np.pi / 4) - h_synthesis[k] = 2 * h_proto * np.cos( - (2 * k + 1) * (np.pi / (2 * subbands)) * - # NOTE(kan-bayashi): (taps - 1) is used for <= v.0.4.2 - (np.arange(taps + 1) - ((taps - 1) / 2)) - - (-1) ** k * np.pi / 4) - - return h_analysis, h_synthesis - def analysis(self, x): """Analysis with PQMF. 
From 1b330f6051743fce33fb3c28f25f0fd7c98d2f8b Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:52:22 +0900 Subject: [PATCH 07/14] updated --- parallel_wavegan/models/melgan.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/parallel_wavegan/models/melgan.py b/parallel_wavegan/models/melgan.py index 9e7e8e3d..8a161d30 100644 --- a/parallel_wavegan/models/melgan.py +++ b/parallel_wavegan/models/melgan.py @@ -204,6 +204,8 @@ def inference(self, c): """ c = self.melgan(c.transpose(1, 0).unsqueeze(0)) + if self.pqmf is not None: + c = self.pqmf.synthesis(c) return c.squeeze(0).transpose(1, 0) From caf7e5da457a9a89bbc4843f1b8532cd507da862 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 12:58:47 +0900 Subject: [PATCH 08/14] fixed userwarning --- parallel_wavegan/models/parallel_wavegan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parallel_wavegan/models/parallel_wavegan.py b/parallel_wavegan/models/parallel_wavegan.py index 2cb087f2..ee38fd86 100644 --- a/parallel_wavegan/models/parallel_wavegan.py +++ b/parallel_wavegan/models/parallel_wavegan.py @@ -211,7 +211,7 @@ def inference(self, c=None, x=None): x = x.transpose(1, 0).unsqueeze(0) else: assert c is not None - x = c.new_tensor(torch.randn(1, 1, len(c) * self.upsample_factor)) + x = torch.randn(1, 1, len(c) * self.upsample_factor).to(c.device) if c is not None: c = c.transpose(1, 0).unsqueeze(0) c = torch.nn.ReplicationPad1d(self.aux_context_window)(c) From 4452874265235e11e21ce4b9c1ae1c415dbfd77e Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 13:05:34 +0900 Subject: [PATCH 09/14] fixed --- parallel_wavegan/bin/decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parallel_wavegan/bin/decode.py b/parallel_wavegan/bin/decode.py index 4e16d371..fa1eeb3c 100755 --- a/parallel_wavegan/bin/decode.py +++ b/parallel_wavegan/bin/decode.py @@ -135,14 +135,14 @@ def main(): # generate c = torch.tensor(c, dtype=torch.float).to(device) start = time.time() - y = model.inference(c).view(-1).cpu().numpy() + y = model.inference(c).view(-1) rtf = (time.time() - start) / (len(y) / config["sampling_rate"]) pbar.set_postfix({"RTF": rtf}) total_rtf += rtf # save as PCM 16 bit wav file sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"), - y, config["sampling_rate"], "PCM_16") + y.cpu().numpy(), config["sampling_rate"], "PCM_16") # report average RTF logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).") From 5143953758f16f2c8fa1ec6f2389e22495645f65 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 13:26:54 +0900 Subject: [PATCH 10/14] added load model function --- parallel_wavegan/utils/utils.py | 50 +++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/parallel_wavegan/utils/utils.py b/parallel_wavegan/utils/utils.py index c970d2e2..9892b9de 100644 --- a/parallel_wavegan/utils/utils.py +++ b/parallel_wavegan/utils/utils.py @@ -10,8 +10,16 @@ import os import sys +from distutils.version import LooseVersion + import h5py import numpy as np +import torch +import yaml + +import parallel_wavegan.models + +from parallel_wavegan.layers import PQMF def find_files(root_dir, query="*.wav", include_root_dir=True): @@ -240,3 +248,45 @@ def values(self): """Return the values of the scp file.""" for key in self.keys(): yield self[key] + + +def load_model(checkpoint, config=None): + """Load trained model. + + Args: + checkpoint (str): Checkpoint path. 
+ config (dict): Configuration dict. + + Return: + torch.nn.Module: Model instance. + + """ + # load config if not provided + if config is None: + dirname = os.path.dirname(checkpoint) + config = os.path.join(dirname, "config.yml") + with open(config) as f: + config = yaml.load(f, Loader=yaml.Loader) + + # get model and load parameters + model_class = getattr( + parallel_wavegan.models, + config.get("generator_type", "ParallelWaveGANGenerator") + ) + model = model_class(**config["generator_params"]) + model.load_state_dict( + torch.load(checkpoint, map_location="cpu")["model"]["generator"] + ) + + # add pqmf if needed + if config["generator_params"]["out_channels"] > 1: + pqmf_params = {} + if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"): + # For compatibility, here we set default values in version <= 0.4.2 + pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0) + model.pqmf = PQMF( + subbands=config["generator_params"]["out_channels"], + **config.get("pqmf_params", pqmf_params), + ) + + return model From 44a31ce96153baf7701a6e9b582fa6e1df958a35 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 13:27:14 +0900 Subject: [PATCH 11/14] use load_model function --- parallel_wavegan/bin/decode.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/parallel_wavegan/bin/decode.py b/parallel_wavegan/bin/decode.py index fa1eeb3c..62898b0d 100755 --- a/parallel_wavegan/bin/decode.py +++ b/parallel_wavegan/bin/decode.py @@ -11,8 +11,6 @@ import os import time -from distutils.version import LooseVersion - import numpy as np import soundfile as sf import torch @@ -20,11 +18,9 @@ from tqdm import tqdm -import parallel_wavegan.models - from parallel_wavegan.datasets import MelDataset from parallel_wavegan.datasets import MelSCPDataset -from parallel_wavegan.layers import PQMF +from parallel_wavegan.utils import load_model from parallel_wavegan.utils import read_hdf5 @@ -102,29 +98,13 @@ def main(): ) logging.info(f"The number of features to be decoded = {len(dataset)}.") - # setup + # setup model if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") - model_class = getattr( - parallel_wavegan.models, - config.get("generator_type", "ParallelWaveGANGenerator") - ) - model = model_class(**config["generator_params"]) - model.load_state_dict( - torch.load(args.checkpoint, map_location="cpu")["model"]["generator"] - ) + model = load_model(args.checkpoint, config) logging.info(f"Loaded model parameters from {args.checkpoint}.") - if config["generator_params"]["out_channels"] > 1: - pqmf_params = {} - if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"): - # For compatibility, here we set default values in version <= 0.4.2 - pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0) - model.pqmf = PQMF( - subbands=config["generator_params"]["out_channels"], - **config.get("pqmf_params", pqmf_params), - ) model.remove_weight_norm() model = model.eval().to(device) From c8d59e15bf61061c4eb5182dad48a63ad0e230c8 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 14:40:10 +0900 Subject: [PATCH 12/14] updated version --- parallel_wavegan/models/parallel_wavegan.py | 10 +++++++--- setup.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/parallel_wavegan/models/parallel_wavegan.py b/parallel_wavegan/models/parallel_wavegan.py index ee38fd86..e82c4b0d 100644 --- a/parallel_wavegan/models/parallel_wavegan.py +++ 
b/parallel_wavegan/models/parallel_wavegan.py @@ -200,19 +200,23 @@ def inference(self, c=None, x=None): """Perform inference. Args: - c (Tensor): Local conditioning auxiliary features (T' ,C). - x (Tensor): Input noise signal (T, 1). + c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T' ,C). + x (Union[Tensor, ndarray]): Input noise signal (T, 1). Returns: Tensor: Output tensor (T, out_channels) """ if x is not None: + if not isinstance(x, torch.Tensor): + x = torch.tensor(x, dtype=torch.float).to(next(self.parameters()).device) x = x.transpose(1, 0).unsqueeze(0) else: assert c is not None - x = torch.randn(1, 1, len(c) * self.upsample_factor).to(c.device) + x = torch.randn(1, 1, len(c) * self.upsample_factor).to(next(self.parameters()).device) if c is not None: + if not isinstance(c, torch.Tensor): + c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) c = c.transpose(1, 0).unsqueeze(0) c = torch.nn.ReplicationPad1d(self.aux_context_window)(c) return self.forward(x, c).squeeze(0).transpose(1, 0) diff --git a/setup.py b/setup.py index 154cdd48..2c32b6b8 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ dirname = os.path.dirname(__file__) setup(name="parallel_wavegan", - version="0.4.3", + version="0.4.4", url="http://github.com/kan-bayashi/ParallelWaveGAN", author="Tomoki Hayashi", author_email="hayashi.tomoki@g.sp.m.is.nagoya-u.ac.jp", From 84c052e8c3f0e8114408125a89ea49e3dca699d4 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 14:40:27 +0900 Subject: [PATCH 13/14] updated version --- parallel_wavegan/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parallel_wavegan/__init__.py b/parallel_wavegan/__init__.py index ad668cca..cfb0319b 100644 --- a/parallel_wavegan/__init__.py +++ b/parallel_wavegan/__init__.py @@ -1,3 +1,3 @@ # -*- coding: utf-8 -*- -__version__ = "0.4.1" +__version__ = "0.4.4" From 2236aa83465a21de1871e135c411b48620aaa29a Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 14:41:23 +0900 Subject: [PATCH 14/14] support ndarray input for inference --- parallel_wavegan/models/melgan.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/parallel_wavegan/models/melgan.py b/parallel_wavegan/models/melgan.py index 8a161d30..caf7e050 100644 --- a/parallel_wavegan/models/melgan.py +++ b/parallel_wavegan/models/melgan.py @@ -197,12 +197,14 @@ def inference(self, c): """Perform inference. Args: - c (Tensor): Input tensor (T, in_channels). + c (Union[Tensor, ndarray]): Input tensor (T, in_channels). Returns: Tensor: Output tensor (T ** prod(upsample_scales), out_channels). """ + if not isinstance(c, torch.Tensor): + c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) c = self.melgan(c.transpose(1, 0).unsqueeze(0)) if self.pqmf is not None: c = self.pqmf.synthesis(c)
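Taken together, these patches let a trained generator be restored with load_model() and run end-to-end with inference(), which is exactly what the rewritten parallel_wavegan/bin/decode.py now does per utterance. Below is a minimal usage sketch of that API. The checkpoint path, the 80-dimensional dummy feature array, the sampling rate, and the output filename are placeholder assumptions, not values taken from the patches; real features must come from the same feature-extraction setup used at training time.

    import numpy as np
    import soundfile as sf
    import torch

    from parallel_wavegan.utils import load_model  # added in PATCH 10, imported this way in PATCH 11

    # Hypothetical checkpoint path; config.yml is assumed to sit in the same directory,
    # since load_model() reads it from there when no config dict is passed.
    checkpoint = "exp/train_example/checkpoint-400000steps.pkl"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the generator, load its weights, and (for out_channels > 1) attach a PQMF module.
    model = load_model(checkpoint)
    model.remove_weight_norm()
    model = model.eval().to(device)

    # (T', aux_channels) acoustic features; random values here only to keep the sketch runnable.
    # The channel count must match aux_channels / num_mels in the training config.
    mel = np.random.randn(100, 80).astype(np.float32)

    with torch.no_grad():
        # ndarray input is accepted directly after PATCH 12/14; a Tensor also works.
        y = model.inference(mel).view(-1)

    # Save as 16-bit PCM, mirroring decode.py; use config["sampling_rate"] in practice.
    sf.write("utt_gen.wav", y.cpu().numpy(), 22050, "PCM_16")

Because load_model() attaches a PQMF module whenever out_channels > 1 (falling back to the old filter defaults for configs from version <= 0.4.2) and MelGANGenerator.inference() applies pqmf.synthesis() when that module is set, decode.py no longer needs to branch on out_channels or build the noise input and padding by hand.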