Skip to content

Commit

Permalink
Added advanced settings for XTTS and support for fine-tuned XTTS models
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszliniewicz authored Sep 6, 2024
1 parent c6909c2 commit f3c702d
Showing 1 changed file with 174 additions and 14 deletions.
188 changes: 174 additions & 14 deletions pandrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
import hasami
import wave

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

Expand Down Expand Up @@ -117,7 +116,14 @@ def __init__(self, master):
self.session_name = ctk.StringVar()
self.tts_service = ctk.StringVar(value="XTTS")
self.mark_paragraphs_multiple_newlines = ctk.BooleanVar(value=False)

self.xtts_temperature = ctk.DoubleVar(value=0.75)
self.xtts_length_penalty = ctk.DoubleVar(value=1.0)
self.xtts_repetition_penalty = ctk.DoubleVar(value=5.0)
self.xtts_top_k = ctk.IntVar(value=50)
self.xtts_top_p = ctk.DoubleVar(value=0.85)
self.xtts_speed = ctk.DoubleVar(value=1.0)
self.xtts_enable_text_splitting = ctk.BooleanVar(value=True)
self.xtts_stream_chunk_size = ctk.IntVar(value=100)

# Layout
ctk.set_appearance_mode("dark")
Expand Down Expand Up @@ -197,6 +203,12 @@ def __init__(self, master):
self.voicecraft_model_dropdown.grid(row=3, column=1, padx=10, pady=5, sticky=tk.EW)
self.voicecraft_model_label.grid_remove() # Hide the VoiceCraft model label initially
self.voicecraft_model_dropdown.grid_remove() # Hide the VoiceCraft model dropdown initially

self.xtts_model = ctk.StringVar(value="")
self.xtts_model_label = ctk.CTkLabel(session_settings_frame, text="XTTS Model:")
self.xtts_model_label.grid(row=3, column=0, padx=10, pady=5, sticky=tk.W)
self.xtts_model_dropdown = ctk.CTkOptionMenu(session_settings_frame, variable=self.xtts_model, values=[], command=self.on_xtts_model_change)
self.xtts_model_dropdown.grid(row=3, column=1, padx=10, pady=5, sticky=tk.EW)

self.connect_to_server_button = ctk.CTkButton(session_settings_frame, text="Connect to Server", command=self.connect_to_server)
self.connect_to_server_button.grid(row=2, column=2, columnspan=2, padx=10, pady=5, sticky=tk.EW)
Expand Down Expand Up @@ -249,7 +261,6 @@ def __init__(self, master):
self.show_advanced_tts_settings = ctk.BooleanVar(value=False)
self.advanced_settings_switch = ctk.CTkSwitch(session_settings_frame, text="Advanced TTS Settings", variable=self.show_advanced_tts_settings, command=self.toggle_advanced_tts_settings)
self.advanced_settings_switch.grid(row=9, column=0, padx=5, pady=5, sticky=tk.W)
self.advanced_settings_switch.grid_remove() # Hide the switch initially

# Advanced TTS Settings Frame
self.advanced_tts_settings_frame = ctk.CTkFrame(self.session_tab, fg_color="gray20", corner_radius=10)
Expand Down Expand Up @@ -550,6 +561,7 @@ def __init__(self, master):
ctk.CTkEntry(output_frame, textvariable=self.bitrate).grid(row=0, column=3, padx=5, pady=5, sticky=tk.EW)

self.sentence_audio_data = {} # Dictionary to store sentence audio data
self.create_xtts_advanced_settings_frame()

#self.populate_speaker_dropdown()

Expand Down Expand Up @@ -986,6 +998,7 @@ def connect_to_server(self):
if response.status_code == 200:
self.external_server_connected = True
self.populate_speaker_dropdown()
self.populate_xtts_models()
messagebox.showinfo("Connected", "Successfully connected to the external XTTS server.")
else:
messagebox.showerror("Error", f"Failed to connect to the external XTTS server. Status code: {response.status_code}")
Expand All @@ -1001,6 +1014,7 @@ def connect_to_server(self):
response = requests.get("http://localhost:8020/docs")
if response.status_code == 200:
self.populate_speaker_dropdown()
self.populate_xtts_models()
messagebox.showinfo("Connected", "Successfully connected to the local XTTS server.")
else:
messagebox.showerror("Error", f"Failed to connect to the local XTTS server. Status code: {response.status_code}")
Expand Down Expand Up @@ -1309,8 +1323,11 @@ def load_models(self):
CTkMessagebox(title="Error", message="Failed to load models from the API.", icon="cancel")
except requests.exceptions.ConnectionError:
CTkMessagebox(title="Error", message="Failed to connect to the LLM API.", icon="cancel")

def update_tts_service(self, event=None):
if not hasattr(self, 'xtts_advanced_settings_frame'):
self.create_xtts_advanced_settings_frame()

if self.tts_service.get() == "XTTS":
self.connect_to_server_button.grid()
self.use_external_server_switch.grid()
Expand All @@ -1322,7 +1339,11 @@ def update_tts_service(self, event=None):
self.external_server_url_entry_voicecraft.grid_remove()
self.voicecraft_model_dropdown.grid_remove()
self.voicecraft_model_label.grid_remove()
self.advanced_settings_switch.grid_remove() # Hide advanced settings for XTTS
self.advanced_settings_switch.grid() # Show advanced settings for XTTS
self.xtts_advanced_settings_frame.grid_remove() # Hide XTTS advanced settings initially
self.xtts_model_label.grid()
self.xtts_model_dropdown.grid()

elif self.tts_service.get() == "VoiceCraft":
self.connect_to_server_button.grid()
self.use_external_server_switch.grid_remove()
Expand All @@ -1335,6 +1356,9 @@ def update_tts_service(self, event=None):
self.voicecraft_model_dropdown.grid()
self.voicecraft_model_label.grid()
self.advanced_settings_switch.grid() # Show advanced settings for VoiceCraft
self.xtts_model_label.grid_remove()
self.xtts_model_dropdown.grid_remove()

else: # Silero
self.connect_to_server_button.grid_remove()
self.use_external_server_switch.grid_remove()
Expand All @@ -1346,9 +1370,63 @@ def update_tts_service(self, event=None):
self.voicecraft_model_dropdown.grid_remove()
self.voicecraft_model_label.grid_remove()
self.advanced_settings_switch.grid_remove() # Hide advanced settings for Silero
self.xtts_model_label.grid_remove()
self.xtts_model_dropdown.grid_remove()

self.update_language_dropdown()

def populate_xtts_models(self):
    """Fetch the available XTTS models from the server and fill the model dropdown.

    Queries ``/get_models_list`` on the external server when the external
    server option is enabled and connected, otherwise on the local server at
    ``http://localhost:8020``. On success the dropdown values are replaced
    and, if at least one model was returned, the first one is selected and
    activated through ``switch_xtts_model``. Every failure (connection
    problem, timeout, non-200 status, malformed body) is reported in an
    error dialog and leaves the dropdown untouched.
    """
    if self.use_external_server.get() and self.external_server_connected:
        url = f"{self.external_server_url.get()}/get_models_list"
    else:
        url = "http://localhost:8020/get_models_list"

    try:
        # Without a timeout an unresponsive server would block this call forever.
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            messagebox.showerror("Error", f"Failed to fetch XTTS models. Status code: {response.status_code}")
            return

        # response.json() raises a ValueError subclass on a non-JSON body;
        # anything other than a list cannot be used as dropdown values.
        models = response.json()
        if not isinstance(models, list):
            raise ValueError(f"unexpected response format: {models!r}")
    except (requests.exceptions.RequestException, ValueError) as e:
        messagebox.showerror("Error", f"Failed to fetch XTTS models: {str(e)}")
        return

    self.xtts_model_dropdown.configure(values=models)
    if models:
        first_model = models[0]
        self.xtts_model.set(first_model)
        self.switch_xtts_model(first_model)

def switch_xtts_model(self, model_name):
    """Ask the XTTS server to load ``model_name`` and print the outcome.

    Posts ``{"model_name": model_name}`` to ``/switch_model`` on the external
    server when the external server option is enabled and connected,
    otherwise on the local server at ``http://localhost:8020``.

    A 200 response is reported as a successful switch. A 400 response whose
    JSON ``detail`` string contains "already loaded in memory" is reported as
    "already loaded". Every other response, and any request error, is
    reported as a failure. Nothing is raised to the caller.
    """
    if self.use_external_server.get() and self.external_server_connected:
        url = f"{self.external_server_url.get()}/switch_model"
    else:
        url = "http://localhost:8020/switch_model"

    try:
        # Short connect timeout, generous read timeout: the server may need a
        # while to answer this request, but an unreachable host should fail fast.
        response = requests.post(url, json={"model_name": model_name}, timeout=(10, 600))
    except requests.exceptions.RequestException as e:
        print(f"Failed to switch XTTS model: {str(e)}")
        return

    if response.status_code == 200:
        print(f"Switched to XTTS model: {model_name}")
        return

    if response.status_code == 400:
        # The error body is not guaranteed to be JSON, nor a dict with a
        # string "detail"; treat anything else as a plain failure.
        try:
            payload = response.json()
        except ValueError:
            payload = None
        detail = payload.get("detail") if isinstance(payload, dict) else None
        if isinstance(detail, str) and "already loaded in memory" in detail:
            print(f"XTTS model {model_name} is already loaded.")
            return

    print(f"Failed to switch XTTS model. Status code: {response.status_code}")
    print(f"Response: {response.text}")

def update_speaker_dropdown_state(self):
    """Disable the speaker dropdown when ``self.is_custom_model`` is truthy, enable it otherwise."""
    new_state = "disabled" if self.is_custom_model else "normal"
    self.speaker_dropdown.configure(state=new_state)

def on_xtts_model_change(self, model_name):
    """Dropdown callback: forward the newly chosen model name to ``switch_xtts_model``."""
    chosen_model = model_name
    self.switch_xtts_model(chosen_model)

def toggle_playback(self):
if self.playing:
if self.paused:
Expand Down Expand Up @@ -1639,11 +1717,90 @@ def toggle_advanced_tts_settings(self):
self.advanced_tts_settings_frame.grid()
else:
self.advanced_tts_settings_frame.grid_remove()
elif self.tts_service.get() == "XTTS":
if self.show_advanced_tts_settings.get():
self.xtts_advanced_settings_frame.grid()
else:
self.xtts_advanced_settings_frame.grid_remove()
else:
self.advanced_tts_settings_frame.grid_remove()
self.xtts_advanced_settings_frame.grid_remove()

def create_xtts_advanced_settings_frame(self):
    """Build the XTTS advanced-settings panel inside the session tab and leave it hidden.

    The panel holds one labelled entry per numeric setting, a speed slider
    with a live value label, the text-splitting switch and an "Apply" button
    wired to ``apply_xtts_settings``.
    """
    panel = ctk.CTkFrame(self.session_tab, fg_color="gray20", corner_radius=10)
    self.xtts_advanced_settings_frame = panel
    panel.grid(row=6, column=0, columnspan=4, padx=10, pady=(0, 20), sticky=tk.EW)
    for column_index in (0, 1):
        panel.grid_columnconfigure(column_index, weight=1)

    # One (caption, bound variable) pair per entry row, top to bottom.
    entry_rows = (
        ("Stream Chunk Size:", self.xtts_stream_chunk_size),
        ("Temperature:", self.xtts_temperature),
        ("Length Penalty:", self.xtts_length_penalty),
        ("Repetition Penalty:", self.xtts_repetition_penalty),
        ("Top K:", self.xtts_top_k),
        ("Top P:", self.xtts_top_p),
    )
    for row_index, (caption, bound_var) in enumerate(entry_rows):
        caption_label = ctk.CTkLabel(panel, text=caption)
        caption_label.grid(row=row_index, column=0, padx=5, pady=5, sticky=tk.W)
        value_entry = ctk.CTkEntry(panel, textvariable=bound_var)
        value_entry.grid(row=row_index, column=1, padx=5, pady=5, sticky=tk.EW)

    # Speed: caption, slider, and a label showing the current value.
    ctk.CTkLabel(panel, text="Speed:").grid(row=6, column=0, padx=5, pady=5, sticky=tk.W)
    slider = ctk.CTkSlider(panel, from_=0.2, to=2.0, number_of_steps=180, variable=self.xtts_speed)
    slider.grid(row=6, column=1, padx=5, pady=5, sticky=tk.EW)

    initial_speed_text = f"Speed: {self.xtts_speed.get():.2f}"
    self.speed_value_label = ctk.CTkLabel(panel, text=initial_speed_text)
    self.speed_value_label.grid(row=6, column=2, padx=5, pady=5, sticky=tk.W)

    # Attach the callback only after the value label exists, since the callback updates it.
    slider.configure(command=self.update_speed_label)

    splitting_switch = ctk.CTkSwitch(panel, text="Enable Text Splitting", variable=self.xtts_enable_text_splitting)
    splitting_switch.grid(row=7, column=0, columnspan=2, padx=5, pady=5, sticky=tk.W)

    submit_button = ctk.CTkButton(panel, text="Apply", command=self.apply_xtts_settings)
    submit_button.grid(row=8, column=0, columnspan=2, padx=5, pady=10, sticky=tk.EW)

    # Start hidden; grid_remove keeps the grid options for a later grid() call.
    panel.grid_remove()

def apply_xtts_settings(self):
    """Send the current XTTS advanced settings to the server.

    Reads the values bound to the advanced-settings widgets, posts them as
    JSON to ``/set_tts_settings`` (external server when enabled and
    connected, otherwise ``http://localhost:8020``) and shows a dialog with
    the result. If a field holds text that cannot be read as a number, an
    error dialog is shown and no request is sent.
    """
    try:
        # The entries are free-text fields bound to IntVar/DoubleVar; reading
        # a variable whose text is not a valid number raises TclError.
        settings = {
            "stream_chunk_size": int(self.xtts_stream_chunk_size.get()),
            "temperature": float(self.xtts_temperature.get()),
            "speed": float(self.xtts_speed.get()),
            "length_penalty": float(self.xtts_length_penalty.get()),
            "repetition_penalty": float(self.xtts_repetition_penalty.get()),
            "top_p": float(self.xtts_top_p.get()),
            "top_k": int(self.xtts_top_k.get()),
            "enable_text_splitting": self.xtts_enable_text_splitting.get()
        }
    except (tk.TclError, ValueError) as e:
        messagebox.showerror("Error", f"Invalid XTTS setting value: {str(e)}")
        return

    if self.use_external_server.get() and self.external_server_connected:
        url = f"{self.external_server_url.get()}/set_tts_settings"
    else:
        url = "http://localhost:8020/set_tts_settings"

    try:
        # Without a timeout an unresponsive server would block this call forever.
        response = requests.post(url, json=settings, timeout=30)
    except requests.exceptions.RequestException as e:
        messagebox.showerror("Error", f"Failed to connect to the XTTS server: {str(e)}")
        return

    if response.status_code == 200:
        messagebox.showinfo("Success", "XTTS settings updated successfully.")
    else:
        messagebox.showerror("Error", f"Failed to update XTTS settings. Status code: {response.status_code}")

def update_speed_label(self, value):
    """Slider callback: show ``value`` in the speed label, formatted to two decimals."""
    speed_text = f"Speed: {float(value):.2f}"
    self.speed_value_label.configure(text=speed_text)

def convert_digits_to_words(self, sentence):
import re

def replace_numbers(match):
number = match.group(0)
Expand Down Expand Up @@ -2427,11 +2584,12 @@ def update_language_dropdown(self, event=None):
messagebox.showerror("Error", "Failed to set Silero language.")
except requests.exceptions.ConnectionError:
messagebox.showerror("Error", "Failed to connect to the Silero API.")

def tts_to_audio(self, text):
best_audio = None
best_mos = -1
if self.tts_service.get() == "XTTS":
language = self.language_dropdown.get() # Get the value directly from the language dropdown/combobox
language = self.language_dropdown.get()
speaker = self.selected_speaker.get()

speaker_path = os.path.join(self.tts_voices_folder, speaker)
Expand All @@ -2447,25 +2605,27 @@ def tts_to_audio(self, text):
"speaker_wav": speaker_arg,
"language": language
}
print(f"Request data: {data}") # Print the request data

print(f"Request data: {data}")
if self.external_server_connected:
external_server_url = self.external_server_url.get()
response = requests.post(f"{external_server_url}/tts_to_audio/", json=data)
else:
response = requests.post("http://localhost:8020/tts_to_audio/", json=data)
print(f"Response status code: {response.status_code}") # Print the response status code
print(f"Response status code: {response.status_code}")
if response.status_code == 200:
audio_data = io.BytesIO(response.content)
audio = AudioSegment.from_file(audio_data, format="wav")

if self.enable_tts_evaluation.get():
mos_score = self.evaluate_tts(text, audio)
if mos_score > best_mos:
best_audio = audio
best_mos = mos_score
if mos_score is not None:
if mos_score > best_mos:
best_audio = audio
best_mos = mos_score

if mos_score >= float(self.target_mos_value.get()):
return best_audio
if mos_score >= float(self.target_mos_value.get()):
return best_audio
else:
return audio
else:
Expand Down

0 comments on commit f3c702d

Please sign in to comment.