Commit

Merge pull request #208 from kan-bayashi/minor_update_5
kan-bayashi authored Aug 19, 2020
2 parents b66a877 + e2a101c commit 69d22bd
Showing 8 changed files with 105 additions and 50 deletions.
112 changes: 80 additions & 32 deletions README.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions parallel_wavegan/datasets/__init__.py
@@ -1,2 +1,2 @@
-from parallel_wavegan.datasets.audio_mel_dataset import * # NOQA
-from parallel_wavegan.datasets.scp_dataset import * # NOQA
+from .audio_mel_dataset import * # NOQA
+from .scp_dataset import * # NOQA
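This and the remaining __init__.py hunks below switch the package-internal star imports from absolute parallel_wavegan.* paths to relative ones; the re-exported names are unchanged, only how the submodules are located differs. A minimal sketch of the same pattern for a hypothetical package layout (names are illustrative, not part of this repo):

# mypkg/datasets/__init__.py  (hypothetical)
# Absolute form couples this file to the installed top-level name:
#   from mypkg.datasets.audio import *  # NOQA
# Relative form resolves against the enclosing package instead:
from .audio import *  # NOQA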
10 changes: 5 additions & 5 deletions parallel_wavegan/layers/__init__.py
@@ -1,5 +1,5 @@
-from parallel_wavegan.layers.causal_conv import * # NOQA
-from parallel_wavegan.layers.pqmf import * # NOQA
-from parallel_wavegan.layers.residual_block import * # NOQA
-from parallel_wavegan.layers.residual_stack import * # NOQA
-from parallel_wavegan.layers.upsample import * # NOQA
+from .causal_conv import * # NOQA
+from .pqmf import * # NOQA
+from .residual_block import * # NOQA
+from .residual_stack import * # NOQA
+from .upsample import * # NOQA
2 changes: 1 addition & 1 deletion parallel_wavegan/losses/__init__.py
@@ -1 +1 @@
-from parallel_wavegan.losses.stft_loss import * # NOQA
+from .stft_loss import * # NOQA
4 changes: 2 additions & 2 deletions parallel_wavegan/models/__init__.py
@@ -1,2 +1,2 @@
-from parallel_wavegan.models.melgan import * # NOQA
-from parallel_wavegan.models.parallel_wavegan import * # NOQA
+from .melgan import * # NOQA
+from .parallel_wavegan import * # NOQA
3 changes: 2 additions & 1 deletion parallel_wavegan/optimizers/__init__.py
@@ -1,2 +1,3 @@
 from torch.optim import * # NOQA
-from parallel_wavegan.optimizers.radam import * # NOQA
+
+from .radam import * # NOQA
2 changes: 1 addition & 1 deletion parallel_wavegan/utils/__init__.py
@@ -1 +1 @@
-from parallel_wavegan.utils.utils import * # NOQA
+from .utils import * # NOQA
18 changes: 12 additions & 6 deletions parallel_wavegan/utils/utils.py
@@ -18,11 +18,6 @@
 import torch
 import yaml

-import parallel_wavegan.models
-
-from parallel_wavegan.layers import PQMF
-
-
 PRETRAINED_MODEL_LIST = {
     "ljspeech_parallel_wavegan.v1": "1PdZv37JhAQH6AwNh31QlqruqrvjTBq7U",
     "ljspeech_parallel_wavegan.v1.long": "1A9TsrD9fHxFviJVFjCk5W6lkzWXwhftv",
@@ -291,6 +286,9 @@ def load_model(checkpoint, config=None):
         with open(config) as f:
             config = yaml.load(f, Loader=yaml.Loader)

+    # lazy load for circular error
+    import parallel_wavegan.models
+
     # get model and load parameters
     model_class = getattr(
         parallel_wavegan.models,
@@ -303,6 +301,9 @@

     # add pqmf if needed
     if config["generator_params"]["out_channels"] > 1:
+        # lazy load for circular error
+        from parallel_wavegan.layers import PQMF
+
         pqmf_params = {}
         if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"):
             # For compatibility, here we set default values in version <= 0.4.2
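Together with the import removals in the first hunk, these two hunks defer the parallel_wavegan.models and PQMF imports to call time, so importing parallel_wavegan.utils no longer triggers the circular-import error the comments refer to. A minimal sketch of the deferred-import pattern, using a hypothetical two-module package (not this repo's actual layout):

# pkg/utils.py  (hypothetical; assume pkg.models imports pkg.utils at module level)
def load_model(name):
    # Importing inside the function runs only after both modules have
    # finished loading, which sidesteps the module-level import cycle.
    import pkg.models
    return getattr(pkg.models, name)

The cost is a small one-time import on the first call; Python caches the module in sys.modules afterwards.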
@@ -333,10 +334,15 @@ def download_pretrained_model(tag, download_dir=None):
     output_path = f"{download_dir}/{tag}.tar.gz"
     os.makedirs(f"{download_dir}", exist_ok=True)
     if not os.path.exists(output_path):
+        # lazy load for compatibility
+        import gdown
+
         gdown.download(f"https://drive.google.com/uc?id={id_}", output_path, quiet=False)
         with tarfile.open(output_path, 'r:*') as tar:
-            tar.extractall(f"{download_dir}/{tag}")
+            for member in tar.getmembers():
+                if member.isreg():
+                    member.name = os.path.basename(member.name)
+                    tar.extract(member, f"{download_dir}/{tag}")
     checkpoint_path = find_files(f"{download_dir}/{tag}", "checkpoint*.pkl")

     return checkpoint_path[0]
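The new extraction loop replaces extractall: only regular file members are extracted, and each member's path is cut down to its basename, so checkpoints land directly in {download_dir}/{tag} regardless of how the archive was packed. A rough standalone equivalent of that loop (the helper name is hypothetical; the tarfile calls are the same ones used above):

import os
import tarfile

def extract_flat(archive_path, dest_dir):
    # Extract only regular files and drop their directory components so
    # every file lands directly in dest_dir.
    with tarfile.open(archive_path, "r:*") as tar:
        for member in tar.getmembers():
            if member.isreg():
                member.name = os.path.basename(member.name)
                tar.extract(member, dest_dir)

Assuming download_pretrained_model and load_model are re-exported by parallel_wavegan.utils (as the star import in its __init__.py suggests), end-to-end use would look roughly like:

from parallel_wavegan.utils import download_pretrained_model, load_model

checkpoint = download_pretrained_model("ljspeech_parallel_wavegan.v1")
model = load_model(checkpoint)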
