Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add new TTS examples for Kokoro, Vits, and Matcha models #75

Merged
merged 4 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ sherpa-onnx-kws-*
jniLibs/
build/
kokoro-en-*/
matcha-*
/
14 changes: 12 additions & 2 deletions crates/sherpa-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,19 @@ cuda = ["sherpa-rs-sys/cuda"]
directml = ["sherpa-rs-sys/directml"]

[[example]]
name = "tts"
name = "tts_kokoro"
required-features = ["tts"]
path = "../../examples/tts.rs"
path = "../../examples/tts_kokoro.rs"

[[example]]
name = "tts_vits"
required-features = ["tts"]
path = "../../examples/tts_vits.rs"

[[example]]
name = "tts_matcha"
required-features = ["tts"]
path = "../../examples/tts_matcha.rs"

[[example]]
name = "audio_tag"
Expand Down
62 changes: 51 additions & 11 deletions crates/sherpa-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ pub use sherpa_rs_sys;
use eyre::{bail, Result};

pub fn get_default_provider() -> String {
if cfg!(feature = "cuda") {
"cuda"
} else if cfg!(target_os = "macos") {
"coreml"
} else if cfg!(feature = "directml") {
"directml"
} else {
"cpu"
}
.into()
"cpu".into()
// Other providers has many issues with different models!!
// if cfg!(feature = "cuda") {
// "cuda"
// } else if cfg!(target_os = "macos") {
// "coreml"
// } else if cfg!(feature = "directml") {
// "directml"
// } else {
// "cpu"
// }
// .into()
}

pub fn read_audio_file(path: &str) -> Result<(Vec<f32>, u32)> {
Expand All @@ -45,8 +47,46 @@ pub fn read_audio_file(path: &str) -> Result<(Vec<f32>, u32)> {
// Collect samples into a Vec<f32>
let samples: Vec<f32> = reader
.samples::<i16>()
.map(|s| s.unwrap() as f32 / i16::MAX as f32)
.map(|s| (s.unwrap() as f32) / (i16::MAX as f32))
.collect();

Ok((samples, sample_rate))
}

pub fn write_audio_file(path: &str, samples: &[f32], sample_rate: u32) -> Result<()> {
// Create a WAV file writer
let spec = hound::WavSpec {
channels: 1,
sample_rate,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};

let mut writer = hound::WavWriter::create(path, spec)?;

// Convert samples from f32 to i16 and write them to the WAV file
for &sample in samples {
let scaled_sample =
(sample * (i16::MAX as f32)).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
writer.write_sample(scaled_sample)?;
}

writer.finalize()?;
Ok(())
}

pub struct OnnxConfig {
pub provider: String,
pub debug: bool,
pub num_threads: i32,
}

impl Default for OnnxConfig {
fn default() -> Self {
Self {
provider: get_default_provider(),
debug: false,
num_threads: 1,
}
}
}
189 changes: 0 additions & 189 deletions crates/sherpa-rs/src/tts.rs

This file was deleted.

65 changes: 65 additions & 0 deletions crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::{mem, ptr::null};

use crate::{utils::RawCStr, OnnxConfig};
use eyre::Result;
use sherpa_rs_sys;

use super::{CommonTtsConfig, TtsAudio};

pub struct KokoroTts {
tts: *const sherpa_rs_sys::SherpaOnnxOfflineTts,
}

#[derive(Default)]
pub struct KokoroTtsConfig {
pub model: String,
pub voices: String,
pub tokens: String,
pub data_dir: String,
pub length_scale: f32,
pub onnx_config: OnnxConfig,
pub common_config: CommonTtsConfig,
}

impl KokoroTts {
pub fn new(config: KokoroTtsConfig) -> Self {
let tts = unsafe {
let model = RawCStr::new(&config.model);
let voices = RawCStr::new(&config.voices);
let tokens = RawCStr::new(&config.tokens);
let data_dir = RawCStr::new(&config.data_dir);

let provider = RawCStr::new(&config.onnx_config.provider);

let tts_config = config.common_config.to_raw();

let model_config = sherpa_rs_sys::SherpaOnnxOfflineTtsModelConfig {
vits: mem::zeroed::<_>(),
num_threads: config.onnx_config.num_threads,
debug: config.onnx_config.debug.into(),
provider: provider.as_ptr(),
matcha: mem::zeroed::<_>(),
kokoro: sherpa_rs_sys::SherpaOnnxOfflineTtsKokoroModelConfig {
model: model.as_ptr(),
voices: voices.as_ptr(),
tokens: tokens.as_ptr(),
data_dir: data_dir.as_ptr(),
length_scale: config.length_scale,
},
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
};
sherpa_rs_sys::SherpaOnnxCreateOfflineTts(&config)
};

Self { tts }
}

pub fn create(&mut self, text: &str, sid: i32, speed: f32) -> Result<TtsAudio> {
unsafe { super::create(self.tts, text, sid, speed) }
}
}
Loading