forked from tuanh123789/Train_Hifigan_XTTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_lapa.py
78 lines (61 loc) · 2.97 KB
/
train_lapa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import random
import json
import torch
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from datasets.preprocess import load_wav_feat_spk_data
from configs.gpt_hifigan_config import GPTHifiganConfig
from models.gpt_gan import GPTGAN
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # Restrict training to GPU 0; use e.g. '0,1' for multi-GPU
class GPTHifiganTrainer:
    """Fine-tune the HiFi-GAN vocoder (and optionally the speaker encoder)
    of an XTTS checkpoint using the 🐸TTS ``Trainer``.

    Parameters
    ----------
    config : GPTHifiganConfig
        Training configuration; must provide ``audio``, ``data_path``,
        ``mel_path``, ``spk_path``, ``eval_split_size``, ``pretrain_path``,
        ``train_spk_encoder`` and ``output_path``.
    """

    def __init__(self, config):
        self.config = config
        self.ap = AudioProcessor(**config.audio.to_dict())

        # Quick visibility into the dataset layout before loading.
        print(f"Files in data path: {os.listdir(config.data_path)}")
        print(f"Files in mel path: {os.listdir(config.mel_path)}")
        print(f"Files in speaker path: {os.listdir(config.spk_path)}")

        # NOTE(review): the loader appears to return (eval, train) in this
        # order — confirm against load_wav_feat_spk_data's definition.
        self.eval_samples, self.train_samples = load_wav_feat_spk_data(
            config.data_path,
            config.mel_path,
            config.spk_path,
            eval_split_size=config.eval_split_size,
        )

        # Fall back to training on everything if the split left the
        # training set empty (e.g. a very small dataset).
        if not self.train_samples:
            print("[!] Training set is empty. Using all samples for training.")
            self.train_samples = self.eval_samples
            self.eval_samples = []

        # Log dataset split
        print(f"Training samples: {len(self.train_samples)}")
        print(f"Evaluation samples: {len(self.eval_samples)}")

        # Initialize model
        self.model = GPTGAN(config, self.ap)

        # Load pretrained weights if provided.
        if config.pretrain_path is not None:
            # map_location="cpu" makes loading robust to checkpoints saved
            # on a different CUDA device.
            # SECURITY NOTE: torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            state_dict = torch.load(config.pretrain_path, map_location="cpu")

            # Strip the XTTS prefixes so keys match the bare HiFi-GAN decoder.
            hifigan_state_dict = {
                k.replace("xtts.hifigan_decoder.waveform_decoder.", "")
                 .replace("hifigan_decoder.waveform_decoder.", ""): v
                for k, v in state_dict["model"].items()
                if "hifigan_decoder" in k and "speaker_encoder" not in k
            }
            self.model.model_g.load_state_dict(hifigan_state_dict, strict=False)

            if config.train_spk_encoder:
                # BUG FIX: the second replace originally stripped the
                # "waveform_decoder" prefix (copy-paste error); these keys
                # are speaker-encoder keys, so strip that prefix instead.
                speaker_encoder_state_dict = {
                    k.replace("xtts.hifigan_decoder.speaker_encoder.", "")
                     .replace("hifigan_decoder.speaker_encoder.", ""): v
                    for k, v in state_dict["model"].items()
                    if "hifigan_decoder" in k and "speaker_encoder" in k
                }
                self.model.speaker_encoder.load_state_dict(
                    speaker_encoder_state_dict, strict=True
                )

    def train(self):
        """Build a 🐸 ``Trainer`` around the model and run the fit loop."""
        # BUG FIX: the original read the module-level global ``config``
        # here; use the instance's own config so the class works outside
        # this script's __main__ block.
        trainer = Trainer(
            TrainerArgs(),
            self.config,
            self.config.output_path,
            model=self.model,
            train_samples=self.train_samples,
            eval_samples=self.eval_samples,
        )
        trainer.fit()
if __name__ == "__main__":
    # Read the run configuration from disk and launch fine-tuning.
    with open("config_v00.json", "r") as config_file:
        raw_config = json.load(config_file)

    # Expand the JSON keys into a typed training config.
    config = GPTHifiganConfig(**raw_config)

    hifigan_trainer = GPTHifiganTrainer(config=config)
    hifigan_trainer.train()