Skip to content

Commit

Permalink
fix: use logging instead of print statements
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Apr 2, 2024
1 parent 018daa0 commit 674326f
Show file tree
Hide file tree
Showing 66 changed files with 531 additions and 313 deletions.
3 changes: 3 additions & 0 deletions TTS/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import tempfile
import warnings
from pathlib import Path
Expand All @@ -9,6 +10,8 @@
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

logger = logging.getLogger(__name__)


class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""
Expand Down
16 changes: 10 additions & 6 deletions TTS/encoder/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import random

import torch
Expand All @@ -6,6 +7,9 @@
from TTS.encoder.utils.generic_utils import AugmentWAV


logger = logging.getLogger(__name__)


class EncoderDataset(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -51,12 +55,12 @@ def __init__(
self.gaussian_augmentation_config = augmentation_config["gaussian"]

if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Classes per Batch: {num_classes_in_batch}")
print(f" | > Number of instances : {len(self.items)}")
print(f" | > Sequence length: {self.seq_len}")
print(f" | > Num Classes: {len(self.classes)}")
print(f" | > Classes: {self.classes}")
logger.info("DataLoader initialization")
logger.info(" | Classes per batch: %d", num_classes_in_batch)
logger.info(" | Number of instances: %d", len(self.items))
logger.info(" | Sequence length: %d", self.seq_len)
logger.info(" | Number of classes: %d", len(self.classes))
logger.info(" | Classes: %d", self.classes)

def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
Expand Down
13 changes: 9 additions & 4 deletions TTS/encoder/losses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging

import torch
import torch.nn.functional as F
from torch import nn


logger = logging.getLogger(__name__)


# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
Expand All @@ -23,7 +28,7 @@ def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
self.b = nn.Parameter(torch.tensor(init_b))
self.loss_method = loss_method

print(" > Initialized Generalized End-to-End loss")
logger.info("Initialized Generalized End-to-End loss")

assert self.loss_method in ["softmax", "contrast"]

Expand Down Expand Up @@ -139,7 +144,7 @@ def __init__(self, init_w=10.0, init_b=-5.0):
self.b = nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()

print(" > Initialized Angular Prototypical loss")
logger.info("Initialized Angular Prototypical loss")

def forward(self, x, _label=None):
"""
Expand Down Expand Up @@ -177,7 +182,7 @@ def __init__(self, embedding_dim, n_speakers):
self.criterion = torch.nn.CrossEntropyLoss()
self.fc = nn.Linear(embedding_dim, n_speakers)

print("Initialised Softmax Loss")
logger.info("Initialised Softmax Loss")

def forward(self, x, label=None):
# reshape for compatibility
Expand Down Expand Up @@ -212,7 +217,7 @@ def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
self.angleproto = AngleProtoLoss(init_w, init_b)

print("Initialised SoftmaxAnglePrototypical Loss")
logger.info("Initialised SoftmaxAnglePrototypical Loss")

def forward(self, x, label=None):
"""
Expand Down
10 changes: 7 additions & 3 deletions TTS/encoder/models/base_encoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import numpy as np
import torch
import torchaudio
Expand All @@ -8,6 +10,8 @@
from TTS.utils.generic_utils import set_init_dict
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)


class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
Expand Down Expand Up @@ -118,13 +122,13 @@ def load_checkpoint(
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
logger.info("Model fully restored. ")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error

print(" > Partial model initialization.")
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"], c)
self.load_state_dict(model_dict)
Expand All @@ -135,7 +139,7 @@ def load_checkpoint(
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
logger.exception("Criterion load ignored because of: %s", error)

# instance and load the criterion for the encoder classifier in inference time
if (
Expand Down
12 changes: 9 additions & 3 deletions TTS/encoder/utils/generic_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import glob
import logging
import os
import random

Expand All @@ -9,6 +10,9 @@
from TTS.encoder.models.resnet import ResNetSpeakerEncoder


logger = logging.getLogger(__name__)


class AugmentWAV(object):
def __init__(self, ap, augmentation_config):
self.ap = ap
Expand Down Expand Up @@ -38,8 +42,10 @@ def __init__(self, ap, augmentation_config):
self.noise_list[noise_dir] = []
self.noise_list[noise_dir].append(wav_file)

print(
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
logger.info(
"Using Additive Noise Augmentation: with %d audios instances from %s",
len(additive_files),
self.additive_noise_types,
)

self.use_rir = False
Expand All @@ -50,7 +56,7 @@ def __init__(self, ap, augmentation_config):
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
self.use_rir = True

print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))

self.create_augmentation_global_list()

Expand Down
30 changes: 16 additions & 14 deletions TTS/encoder/utils/prepare_voxceleb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@

import csv
import hashlib
import logging
import os
import subprocess
import sys
import zipfile

import soundfile as sf
from absl import logging

logger = logging.getLogger(__name__)

SUBSETS = {
"vox1_dev_wav": [
Expand Down Expand Up @@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls):
zip_filepath = os.path.join(directory, url.split("/")[-1])
if os.path.exists(zip_filepath):
continue
logging.info("Downloading %s to %s" % (url, zip_filepath))
logger.info("Downloading %s to %s" % (url, zip_filepath))
subprocess.call(
"wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
shell=True,
)

statinfo = os.stat(zip_filepath)
logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))

# concatenate all parts into zip files
if ".zip" not in zip_filepath:
Expand Down Expand Up @@ -118,9 +120,9 @@ def exec_cmd(cmd):
try:
retcode = subprocess.call(cmd, shell=True)
if retcode < 0:
logging.info(f"Child was terminated by signal {retcode}")
logger.info(f"Child was terminated by signal {retcode}")
except OSError as e:
logging.info(f"Execution failed: {e}")
logger.info(f"Execution failed: {e}")
retcode = -999
return retcode

Expand All @@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
bool, True if success.
"""
cmd = f"ffmpeg -i {aac_file} {wav_file}"
logging.info(f"Decoding aac file using command line: {cmd}")
logger.info(f"Decoding aac file using command line: {cmd}")
ret = exec_cmd(cmd)
if ret != 0:
logging.error(f"Failed to decode aac file with retcode {ret}")
logging.error("Please check your ffmpeg installation.")
logger.error(f"Failed to decode aac file with retcode {ret}")
logger.error("Please check your ffmpeg installation.")
return False
return True

Expand All @@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
"""

logging.info("Preprocessing audio and label for subset %s" % subset)
logger.info("Preprocessing audio and label for subset %s" % subset)
source_dir = os.path.join(input_dir, subset)

files = []
Expand Down Expand Up @@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
for wav_file in files:
writer.writerow(wav_file)
logging.info("Successfully generated csv file {}".format(csv_file_path))
logger.info("Successfully generated csv file {}".format(csv_file_path))


def processor(directory, subset, force_process):
Expand All @@ -203,16 +205,16 @@ def processor(directory, subset, force_process):
if not force_process and os.path.exists(subset_csv):
return subset_csv

logging.info("Downloading and process the voxceleb in %s", directory)
logging.info("Preparing subset %s", subset)
logger.info("Downloading and process the voxceleb in %s", directory)
logger.info("Preparing subset %s", subset)
download_and_extract(directory, subset, urls[subset])
convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
logging.info("Finished downloading and processing")
logger.info("Finished downloading and processing")
return subset_csv


if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
logging.getLogger("TTS").setLevel(logging.INFO)
if len(sys.argv) != 4:
print("Usage: python prepare_data.py save_directory user password")
sys.exit()
Expand Down
11 changes: 7 additions & 4 deletions TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import io
import json
import logging
import os
import sys
from pathlib import Path
Expand All @@ -18,6 +19,8 @@
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

logger = logging.getLogger(__name__)


def create_argparser():
def convert_boolean(x):
Expand Down Expand Up @@ -200,9 +203,9 @@ def tts():
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)

print(f" > Model input: {text}")
print(f" > Speaker Idx: {speaker_idx}")
print(f" > Language Idx: {language_idx}")
logger.info("Model input: %s", text)
logger.info("Speaker idx: %s", speaker_idx)
logger.info("Language idx: %s", language_idx)
wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
Expand Down Expand Up @@ -246,7 +249,7 @@ def mary_tts_api_process():
text = data.get("INPUT_TEXT", [""])[0]
else:
text = request.args.get("INPUT_TEXT", "")
print(f" > Model input: {text}")
logger.info("Model input: %s", text)
wavs = synthesizer.tts(text)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
Expand Down
17 changes: 10 additions & 7 deletions TTS/tts/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import sys
from collections import Counter
Expand All @@ -10,6 +11,9 @@
from TTS.tts.datasets.formatters import *


logger = logging.getLogger(__name__)


def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Expand Down Expand Up @@ -122,7 +126,7 @@ def load_tts_samples(

meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)

print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve())
# load evaluation split if set
if eval_split:
if meta_file_val:
Expand Down Expand Up @@ -166,16 +170,15 @@ def _get_formatter_by_name(name):
return getattr(thismodule, name.lower())


def find_unique_chars(data_samples, verbose=True):
def find_unique_chars(data_samples):
texts = "".join(item["text"] for item in data_samples)
chars = set(texts)
lower_chars = filter(lambda c: c.islower(), chars)
chars_force_lower = [c.lower() for c in chars]
chars_force_lower = set(chars_force_lower)

if verbose:
print(f" > Number of unique characters: {len(chars)}")
print(f" > Unique characters: {''.join(sorted(chars))}")
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
logger.info("Number of unique characters: %d", len(chars))
logger.info("Unique characters: %s", "".join(sorted(chars)))
logger.info("Unique lower characters: %s", "".join(sorted(lower_chars)))
logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower)))
return chars_force_lower
Loading

0 comments on commit 674326f

Please sign in to comment.