Skip to content

Commit

Permalink
Merge pull request #207 from kan-bayashi/feature/pretrained_model_dow…
Browse files Browse the repository at this point in the history
…nloader
  • Loading branch information
kan-bayashi authored Aug 18, 2020
2 parents 6041c9f + 1ca8781 commit b66a877
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion parallel_wavegan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# -*- coding: utf-8 -*-

__version__ = "0.4.4"
__version__ = "0.4.5"
50 changes: 50 additions & 0 deletions parallel_wavegan/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import sys
import tarfile

from distutils.version import LooseVersion

Expand All @@ -22,6 +23,28 @@
from parallel_wavegan.layers import PQMF


PRETRAINED_MODEL_LIST = {
"ljspeech_parallel_wavegan.v1": "1PdZv37JhAQH6AwNh31QlqruqrvjTBq7U",
"ljspeech_parallel_wavegan.v1.long": "1A9TsrD9fHxFviJVFjCk5W6lkzWXwhftv",
"ljspeech_parallel_wavegan.v1.no_limit": "1CdWKSiKoFNPZyF1lo7Dsj6cPKmfLJe72",
"ljspeech_parallel_wavegan.v3": "1-oZpwpWZMMolDYsCqeL12dFkXSBD9VBq",
"ljspeech_full_band_melgan.v2": "1Kb7q5zBeQ30Wsnma0X23G08zvgDG5oen",
"ljspeech_multi_band_melgan.v2": "1b70pJefKI8DhGYz4SxbEHpxm92tj1_qC",
"jsut_parallel_wavegan.v1": "1qok91A6wuubuz4be-P9R2zKhNmQXG0VQ",
"jsut_multi_band_melgan.v2": "1chTt-76q2p69WPpZ1t1tt8szcM96IKad",
"csmsc_parallel_wavegan.v1": "1QTOAokhD5dtRnqlMPTXTW91-CG7jf74e",
"csmsc_multi_band_melgan.v2": "1G6trTmt0Szq-jWv2QDhqglMdWqQxiXQT",
"arctic_slt_parallel_wavegan.v1": "1_MXePg40-7DTjD0CDVzyduwQuW_O9aA1",
"jnas_parallel_wavegan.v1": "1D2TgvO206ixdLI90IqG787V6ySoXLsV_",
"vctk_parallel_wavegan.v1": "1bqEFLgAroDcgUy5ZFP4g2O2MwcwWLEca",
"vctk_parallel_wavegan.v1.long": "1tO4-mFrZ3aVYotgg7M519oobYkD4O_0-",
"vctk_multi_band_melgan.v2": "10PRQpHMFPE7RjF-MHYqvupK9S0xwBlJ_",
"libritts_parallel_wavegan.v1": "1zHQl8kUYEuZ_i1qEFU6g2MEu99k3sHmR",
"libritts_parallel_wavegan.v1.long": "1b9zyBYGCCaJu0TIus5GXoMF8M3YEbqOw",
"libritts_multi_band_melgan.v2": "1kIDSBjrQvAsRewHPiFwBZ3FDelTWMp64",
}


def find_files(root_dir, query="*.wav", include_root_dir=True):
"""Find files recursively.
Expand Down Expand Up @@ -290,3 +313,30 @@ def load_model(checkpoint, config=None):
)

return model


def download_pretrained_model(tag, download_dir=None):
"""Download pretrained model form google drive.
Args:
tag (str): Pretrained model tag.
download_dir (str): Directory to save downloaded files.
Returns:
str: Path of downloaded model checkpoint.
"""
assert tag in PRETRAINED_MODEL_LIST, f"{tag} does not exists."
id_ = PRETRAINED_MODEL_LIST[tag]
if download_dir is None:
download_dir = os.path.expanduser("~/.cache/parallel_wavegan")
output_path = f"{download_dir}/{tag}.tar.gz"
os.makedirs(f"{download_dir}", exist_ok=True)
if not os.path.exists(output_path):
import gdown
gdown.download(f"https://drive.google.com/uc?id={id_}", output_path, quiet=False)
with tarfile.open(output_path, 'r:*') as tar:
tar.extractall(f"{download_dir}/{tag}")
checkpoint_path = find_files(f"{download_dir}/{tag}", "checkpoint*.pkl")

return checkpoint_path[0]
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"yq>=2.10.0",
# Fix No module named "numba.decorators"
"numba<=0.48",
"gdown",
],
"setup": [
"numpy",
Expand Down Expand Up @@ -65,7 +66,7 @@

dirname = os.path.dirname(__file__)
setup(name="parallel_wavegan",
version="0.4.4",
version="0.4.5",
url="http://github.com/kan-bayashi/ParallelWaveGAN",
author="Tomoki Hayashi",
author_email="[email protected]",
Expand Down

0 comments on commit b66a877

Please sign in to comment.