Skip to content

Commit

Permalink
Merge pull request #206 from kan-bayashi/inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kan-bayashi authored Aug 18, 2020
2 parents 9a399c0 + 2236aa8 commit 6041c9f
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 84 deletions.
2 changes: 1 addition & 1 deletion parallel_wavegan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# -*- coding: utf-8 -*-

__version__ = "0.4.1"
__version__ = "0.4.4"
44 changes: 6 additions & 38 deletions parallel_wavegan/bin/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,16 @@
import os
import time

from distutils.version import LooseVersion

import numpy as np
import soundfile as sf
import torch
import yaml

from tqdm import tqdm

import parallel_wavegan.models

from parallel_wavegan.datasets import MelDataset
from parallel_wavegan.datasets import MelSCPDataset
from parallel_wavegan.layers import PQMF
from parallel_wavegan.utils import load_model
from parallel_wavegan.utils import read_hdf5


Expand Down Expand Up @@ -102,59 +98,31 @@ def main():
)
logging.info(f"The number of features to be decoded = {len(dataset)}.")

# setup
# setup model
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model_class = getattr(
parallel_wavegan.models,
config.get("generator_type", "ParallelWaveGANGenerator"))
model = model_class(**config["generator_params"])
model.load_state_dict(
torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
model = load_model(args.checkpoint, config)
logging.info(f"Loaded model parameters from {args.checkpoint}.")
model.remove_weight_norm()
model = model.eval().to(device)
use_noise_input = not isinstance(
model, parallel_wavegan.models.MelGANGenerator)
pad_fn = torch.nn.ReplicationPad1d(
config["generator_params"].get("aux_context_window", 0))
if config["generator_params"]["out_channels"] > 1:
pqmf_params = {}
if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"):
# For compatibility, here we set default values in version <= 0.4.2
pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0, use_legacy=True)
pqmf = PQMF(
subbands=config["generator_params"]["out_channels"],
**config.get("pqmf_params", pqmf_params),
).to(device)

# start generation
total_rtf = 0.0
with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
for idx, (utt_id, c) in enumerate(pbar, 1):
# setup input
x = ()
if use_noise_input:
z = torch.randn(1, 1, len(c) * config["hop_size"]).to(device)
x += (z,)
c = pad_fn(torch.tensor(c, dtype=torch.float).unsqueeze(0).transpose(2, 1)).to(device)
x += (c,)

# generate
c = torch.tensor(c, dtype=torch.float).to(device)
start = time.time()
if config["generator_params"]["out_channels"] == 1:
y = model(*x).view(-1).cpu().numpy()
else:
y = pqmf.synthesis(model(*x)).view(-1).cpu().numpy()
y = model.inference(c).view(-1)
rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
pbar.set_postfix({"RTF": rtf})
total_rtf += rtf

# save as PCM 16 bit wav file
sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"),
y, config["sampling_rate"], "PCM_16")
y.cpu().numpy(), config["sampling_rate"], "PCM_16")

# report average RTF
logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")
Expand Down
55 changes: 13 additions & 42 deletions parallel_wavegan/layers/pqmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class PQMF(torch.nn.Module):
"""

def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, use_legacy=False):
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
"""Initilize PQMF module.
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
Expand All @@ -69,17 +69,23 @@ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, use_legacy
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
use_legacy (bool): Whether to use legacy PQMF coefficients (for <= 0.4.2).
"""
super(PQMF, self).__init__()

# define filter coefficient
# build analysis & synthesis filter coefficients
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
if use_legacy:
h_analysis, h_synthesis = self._build_filter_legacy(h_proto, subbands, taps)
else:
h_analysis, h_synthesis = self._build_filter(h_proto, subbands, taps)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = 2 * h_proto * np.cos(
(2 * k + 1) * (np.pi / (2 * subbands)) *
(np.arange(taps + 1) - (taps / 2)) +
(-1) ** k * np.pi / 4)
h_synthesis[k] = 2 * h_proto * np.cos(
(2 * k + 1) * (np.pi / (2 * subbands)) *
(np.arange(taps + 1) - (taps / 2)) -
(-1) ** k * np.pi / 4)

# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
Expand All @@ -99,41 +105,6 @@ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, use_legacy
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

def _build_filter(self, h_proto, subbands, taps):
    """Build cosine-modulated analysis and synthesis filter coefficients.

    Args:
        h_proto (ndarray): Prototype filter coefficients. The time axis is
            built with ``taps + 1`` points, so ``h_proto`` must broadcast
            against that length.
        subbands (int): The number of subbands.
        taps (int): The number of filter taps.

    Returns:
        ndarray: Analysis filter coefficients (subbands, len(h_proto)).
        ndarray: Synthesis filter coefficients (subbands, len(h_proto)).

    """
    # These do not depend on the band index, so compute them once instead of
    # twice per loop iteration.
    time_axis = np.arange(taps + 1) - (taps / 2)
    band_step = np.pi / (2 * subbands)

    h_analysis = np.zeros((subbands, len(h_proto)))
    h_synthesis = np.zeros((subbands, len(h_proto)))
    for k in range(subbands):
        carrier = (2 * k + 1) * band_step * time_axis
        # The phase offset flips sign with the band parity; analysis and
        # synthesis use opposite signs of the same offset.
        phase = (-1) ** k * np.pi / 4
        h_analysis[k] = 2 * h_proto * np.cos(carrier + phase)
        h_synthesis[k] = 2 * h_proto * np.cos(carrier - phase)

    return h_analysis, h_synthesis

def _build_filter_legacy(self, h_proto, subbands, taps):
    """Build legacy analysis and synthesis filter coefficients.

    NOTE(kan-bayashi): The legacy version exists for compatibility with
        versions <= 0.4.2, where the time axis is centred on
        ``(taps - 1) / 2`` instead of ``taps / 2``.

    Args:
        h_proto (ndarray): Prototype filter coefficients. The time axis is
            built with ``taps + 1`` points, so ``h_proto`` must broadcast
            against that length.
        subbands (int): The number of subbands.
        taps (int): The number of filter taps.

    Returns:
        ndarray: Analysis filter coefficients (subbands, len(h_proto)).
        ndarray: Synthesis filter coefficients (subbands, len(h_proto)).

    """
    # These do not depend on the band index, so compute them once instead of
    # twice per loop iteration.
    # NOTE(kan-bayashi): (taps - 1) is used for <= v.0.4.2
    time_axis = np.arange(taps + 1) - ((taps - 1) / 2)
    band_step = np.pi / (2 * subbands)

    h_analysis = np.zeros((subbands, len(h_proto)))
    h_synthesis = np.zeros((subbands, len(h_proto)))
    for k in range(subbands):
        carrier = (2 * k + 1) * band_step * time_axis
        # The phase offset flips sign with the band parity; analysis and
        # synthesis use opposite signs of the same offset.
        phase = (-1) ** k * np.pi / 4
        h_analysis[k] = 2 * h_proto * np.cos(carrier + phase)
        h_synthesis[k] = 2 * h_proto * np.cos(carrier - phase)

    return h_analysis, h_synthesis

def analysis(self, x):
"""Analysis with PQMF.
Expand Down
20 changes: 20 additions & 0 deletions parallel_wavegan/models/melgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def __init__(self,
# reset parameters
self.reset_parameters()

# initialize pqmf for inference
self.pqmf = None

def forward(self, c):
"""Calculate forward propagation.
Expand Down Expand Up @@ -190,6 +193,23 @@ def _reset_parameters(m):

self.apply(_reset_parameters)

def inference(self, c):
    """Perform inference.

    Args:
        c (Union[Tensor, ndarray]): Input tensor (T, in_channels).

    Returns:
        Tensor: Output tensor (T * prod(upsample_scales), out_channels).
            When ``self.pqmf`` is set, this is the result of
            ``self.pqmf.synthesis`` applied to the generator output.

    """
    # Non-tensor inputs become float32 tensors; tensors keep their dtype.
    if not isinstance(c, torch.Tensor):
        c = torch.tensor(c, dtype=torch.float)
    # Move every input (not only converted arrays) to the device of the model
    # parameters so that a CPU tensor can be fed to a GPU model.
    c = c.to(next(self.parameters()).device)

    # (T, in_channels) -> (1, in_channels, T)
    c = self.melgan(c.transpose(1, 0).unsqueeze(0))
    if self.pqmf is not None:
        c = self.pqmf.synthesis(c)

    # Drop the batch axis and put time first again.
    return c.squeeze(0).transpose(1, 0)


class MelGANDiscriminator(torch.nn.Module):
"""MelGAN discriminator module."""
Expand Down
29 changes: 29 additions & 0 deletions parallel_wavegan/models/parallel_wavegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import math

import numpy as np
import torch

from parallel_wavegan.layers import Conv1d
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(self,
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.aux_context_window = aux_context_window
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
Expand Down Expand Up @@ -96,8 +98,10 @@ def __init__(self,
"aux_context_window": aux_context_window,
})
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
self.upsample_factor = np.prod(upsample_params["upsample_scales"])
else:
self.upsample_net = None
self.upsample_factor = 1

# define residual blocks
self.conv_layers = torch.nn.ModuleList()
Expand Down Expand Up @@ -192,6 +196,31 @@ def receptive_field_size(self):
"""Return receptive field size."""
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

def inference(self, c=None, x=None):
    """Perform inference.

    Args:
        c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T', C).
        x (Union[Tensor, ndarray]): Input noise signal (T, 1). If not given,
            Gaussian noise of length ``len(c) * self.upsample_factor`` is drawn.

    Returns:
        Tensor: Output tensor (T, out_channels).

    """
    device = next(self.parameters()).device

    if x is not None:
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float)
        # Move tensors as well as converted arrays to the model device.
        # (T, 1) -> (1, 1, T)
        x = x.to(device).transpose(1, 0).unsqueeze(0)
    else:
        assert c is not None, "Either c or x must be provided."
        # upsample_factor may be a numpy integer (it comes from np.prod), so
        # cast the length to a plain int before handing it to torch.
        x = torch.randn(1, 1, int(len(c) * self.upsample_factor)).to(device)

    if c is not None:
        if not isinstance(c, torch.Tensor):
            c = torch.tensor(c, dtype=torch.float)
        # (T', C) -> (1, C, T')
        c = c.to(device).transpose(1, 0).unsqueeze(0)
        # Replicate the edge frames by aux_context_window on both sides of
        # the time axis.
        c = torch.nn.ReplicationPad1d(self.aux_context_window)(c)

    # Drop the batch axis and put time first again.
    return self.forward(x, c).squeeze(0).transpose(1, 0)


class ParallelWaveGANDiscriminator(torch.nn.Module):
"""Parallel WaveGAN Discriminator module."""
Expand Down
50 changes: 50 additions & 0 deletions parallel_wavegan/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,16 @@
import os
import sys

from distutils.version import LooseVersion

import h5py
import numpy as np
import torch
import yaml

import parallel_wavegan.models

from parallel_wavegan.layers import PQMF


def find_files(root_dir, query="*.wav", include_root_dir=True):
Expand Down Expand Up @@ -240,3 +248,45 @@ def values(self):
"""Return the values of the scp file."""
for key in self.keys():
yield self[key]


def load_model(checkpoint, config=None):
    """Load trained model.

    Args:
        checkpoint (str): Checkpoint path.
        config (dict): Configuration dict. If None, it is read from
            ``config.yml`` located in the same directory as the checkpoint.

    Returns:
        torch.nn.Module: Model instance. When the generator has more than one
            output channel, a PQMF module is attached as ``model.pqmf``.

    """
    # load config if not provided
    if config is None:
        config_path = os.path.join(os.path.dirname(checkpoint), "config.yml")
        with open(config_path, encoding="utf-8") as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects.
            # Only load config files that come from a trusted source.
            config = yaml.load(f, Loader=yaml.Loader)

    generator_params = config["generator_params"]

    # get model and load parameters
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    model = model_class(**generator_params)
    # NOTE(review): torch.load unpickles the file. Only load checkpoints that
    # come from a trusted source.
    state_dict = torch.load(checkpoint, map_location="cpu")
    model.load_state_dict(state_dict["model"]["generator"])

    # add pqmf if needed
    if generator_params["out_channels"] > 1:
        pqmf_params = {}
        if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"):
            # For compatibility, here we set default values in version <= 0.4.2
            pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0)
        # An explicit "pqmf_params" entry in the config replaces the defaults
        # above entirely (it is not merged with them).
        model.pqmf = PQMF(
            subbands=generator_params["out_channels"],
            **config.get("pqmf_params", pqmf_params),
        )

    return model
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

requirements = {
"install": [
"torch>=1.0.1,<=1.5.1",
"torch>=1.0.1",
"setuptools>=38.5.1",
"librosa>=0.7.0",
"soundfile>=0.10.2",
Expand Down Expand Up @@ -65,7 +65,7 @@

dirname = os.path.dirname(__file__)
setup(name="parallel_wavegan",
version="0.4.3",
version="0.4.4",
url="http://github.com/kan-bayashi/ParallelWaveGAN",
author="Tomoki Hayashi",
author_email="[email protected]",
Expand Down
2 changes: 1 addition & 1 deletion tools/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
PYTHON:= python3.6
CUDA_VERSION:= 10.0
PYTORCH_VERSION:= 1.5
PYTORCH_VERSION:= 1.4
DOT:= .
.PHONY: all clean

Expand Down

0 comments on commit 6041c9f

Please sign in to comment.