Skip to content

Commit

Permalink
corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
mah92 committed Feb 16, 2025
1 parent 680c679 commit 5d69dab
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ class MainActivity : ComponentActivity() {
val RTF = String.format(
"Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f",
TtsEngine.tts!!.config.model.numThreads,
audioDuration,
elapsed,
audioDuration,
elapsed,
audioDuration,
elapsed / audioDuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ object TtsEngine {
//
// This model supports many languages, e.g., English, Chinese, etc.
// We set lang to eng here.


}

fun createTts(context: Context) {
Expand Down Expand Up @@ -221,7 +223,10 @@ object TtsEngine {
speed = PreferenceHelper(context).getSpeed()
speakerId = PreferenceHelper(context).getSid()

tts = OfflineTts(assetManager = assets, config = config, cacheConfig = cacheConfig)
OfflineTtsCacheMechanismConfig config
auto cache = new OfflineTtsCacheMechanism(config)

tts = OfflineTts(assetManager = assets, config = config, cache = cache)
}


Expand Down
23 changes: 16 additions & 7 deletions sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <filesystem>
#include <iostream>
#include <limits>
#include <thread>
#include <cstddef> // for std::size_t

#include "sherpa-onnx/csrc/file-utils.h"
Expand Down Expand Up @@ -74,8 +75,6 @@ void OfflineTtsCacheMechanism::AddWavFile(
bool file_exists = std::filesystem::exists(file_path);

if (!file_exists) { // If the file does not exist, add it to the cache
// Ensure the cache does not exceed its size limit
EnsureCacheLimit();

// Write the audio samples to a WAV file
bool success = WriteWave(file_path,
Expand All @@ -86,6 +85,10 @@ void OfflineTtsCacheMechanism::AddWavFile(
if (file.is_open()) {
used_cache_size_bytes_ += file.tellg();
}

// Ensure the cache does not exceed its size limit, non-blocking
EnsureCacheLimit();

} else {
SHERPA_ONNX_LOGE("Failed to write wav file: %s", file_path.c_str());
}
Expand Down Expand Up @@ -290,15 +293,21 @@ void OfflineTtsCacheMechanism::UpdateCacheVector() {
}

void OfflineTtsCacheMechanism::EnsureCacheLimit() {
std::lock_guard<std::recursive_mutex> lock(mutex_); // Lock the mutex for the entire function

if (used_cache_size_bytes_ > cache_size_bytes_) {
auto target_cache_size
= std::max(static_cast<int> (cache_size_bytes_*0.95), 0);
while (used_cache_size_bytes_> 0
&& used_cache_size_bytes_ > target_cache_size) {
// Launch a new thread to handle cache cleanup in a non-blocking way
std::thread([this]() {
std::lock_guard<std::recursive_mutex> lock(mutex_); // Lock the mutex for the cleanup process

auto target_cache_size = std::max(static_cast<int>(cache_size_bytes_ * 0.95), 0);
while (used_cache_size_bytes_ > 0
&& used_cache_size_bytes_ > target_cache_size) {
// Cache is full, remove the least repeated file
std::size_t least_repeated_file = GetLeastRepeatedFile();
RemoveWavFile(least_repeated_file);
}
}
}).detach(); // Detach the thread to run independently
}
}

Expand Down
58 changes: 53 additions & 5 deletions sherpa-onnx/csrc/offline-tts-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ std::vector<int64_t> OfflineTtsImpl::AddBlank(const std::vector<int64_t> &x,
}

std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
const OfflineTtsConfig &config) {
const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache) {
cache_ = cache;
if (!config.model.vits.model.empty()) {
return std::make_unique<OfflineTtsVitsImpl>(config);
} else if (!config.model.matcha.acoustic_model.empty()) {
Expand All @@ -47,7 +48,8 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(

template <typename Manager>
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
Manager *mgr, const OfflineTtsConfig &config) {
Manager *mgr, const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache) {
cache_ = cache;
if (!config.model.vits.model.empty()) {
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
} else if (!config.model.matcha.acoustic_model.empty()) {
Expand All @@ -59,12 +61,58 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(

#if __ANDROID_API__ >= 9
template std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
AAssetManager *mgr, const OfflineTtsConfig &config);
AAssetManager *mgr, const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache);
#endif

#if __OHOS__
template std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
NativeResourceManager *mgr, const OfflineTtsConfig &config);
NativeResourceManager *mgr, const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache);
#endif

// Synthesizes `text`, consulting the optional audio cache first.
//
// Only short texts (at most 50 bytes) are looked up in, or stored to, the
// cache: on phones, long texts come from messages, websites and books, which
// are usually not repeated, whereas repeated texts come from menus and
// settings and are usually short.
//
// The cache key covers the text together with the speaker ID and the speed,
// so a request for a different voice or speaking rate is never answered with
// audio that was synthesized using other parameters.
//
// @param text     Text to synthesize.
// @param sid      Speaker ID; forwarded to Generate() and part of the key.
// @param speed    Speech speed; forwarded to Generate() and part of the key.
// @param callback Optional callback. On a cache hit it is invoked exactly
//                 once with all cached samples and progress 1.0; on a miss
//                 it is forwarded to Generate().
// @return The cached audio on a hit, otherwise the newly generated audio.
GeneratedAudio OfflineTtsImpl::GenerateWitchCache(
    const std::string &text, int64_t sid, float speed,
    GeneratedAudioCallback callback) const {
  // Caching is skipped when no cache was supplied or when the text is long.
  const bool use_cache = cache_ != nullptr && text.length() <= 50;

  std::size_t cache_key = 0;
  if (use_cache) {
    // Hash text, speaker and speed together. std::to_string(float) prints
    // six decimals, which is enough to tell practical speed settings apart.
    const std::string key_material =
        text + '\n' + std::to_string(sid) + '\n' + std::to_string(speed);
    cache_key = std::hash<std::string>{}(key_material);

    int32_t sample_rate = 0;
    std::vector<float> samples = cache_->GetWavFile(cache_key, &sample_rate);

    // An empty result is treated as a cache miss.
    if (!samples.empty()) {
      SHERPA_ONNX_LOGE("Returning cached audio for hash: %zu", cache_key);

      // The whole utterance is delivered in one call, so a "stop" request
      // from the callback (return value 0) has nothing left to cancel; it
      // is only logged.
      if (callback) {
        const int32_t result =
            callback(samples.data(), samples.size(), 1.0f /* progress */);
        if (result == 0) {
          SHERPA_ONNX_LOGE("Callback requested to stop processing.");
        }
      }

      return {std::move(samples), sample_rate};
    }
  }

  auto audio = Generate(text, sid, speed, callback);

  // Remember the freshly generated audio under the same composite key.
  if (use_cache) {
    cache_->AddWavFile(cache_key, audio.samples, audio.sample_rate);
  }

  return audio;
}

} // namespace sherpa_onnx
13 changes: 11 additions & 2 deletions sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@ class OfflineTtsImpl {
public:
virtual ~OfflineTtsImpl() = default;

static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
static std::unique_ptr<OfflineTtsImpl> Create(
const OfflineTtsConfig &config,
OfflineTtsCacheMechanism* cache = nullptr);

template <typename Manager>
static std::unique_ptr<OfflineTtsImpl> Create(Manager *mgr,
const OfflineTtsConfig &config);
const OfflineTtsConfig &config,
OfflineTtsCacheMechanism* cache = nullptr);

virtual GeneratedAudio Generate(
const std::string &text, int64_t sid = 0, float speed = 1.0,
GeneratedAudioCallback callback = nullptr) const = 0;

GeneratedAudio GenerateWitchCache(
const std::string &text, int64_t sid = 0, float speed = 1.0,
GeneratedAudioCallback callback = nullptr) const;

// Return the sample rate of the generated audio
virtual int32_t SampleRate() const = 0;

Expand All @@ -36,6 +43,8 @@ class OfflineTtsImpl {

std::vector<int64_t> AddBlank(const std::vector<int64_t> &x,
int32_t blank_id = 0) const;
private:
static OfflineTtsCacheMechanism *cache_; // not owned here
};

} // namespace sherpa_onnx
Expand Down
72 changes: 10 additions & 62 deletions sherpa-onnx/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,68 +162,27 @@ std::string OfflineTtsConfig::ToString() const {
return os.str();
}

OfflineTts::OfflineTts(const OfflineTtsConfig &config)
: impl_(OfflineTtsImpl::Create(config)) {}

OfflineTts::OfflineTts(const OfflineTtsConfig &config,
const OfflineTtsCacheMechanismConfig &cache_config)
: impl_(OfflineTtsImpl::Create(config)) {
cache_mechanism_ = std::make_unique<OfflineTtsCacheMechanism>(cache_config);
}

template <typename Manager>
OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config)
: impl_(OfflineTtsImpl::Create(mgr, config)) {}
OfflineTtsCacheMechanism *cache)
: impl_(OfflineTtsImpl::Create(config, cache)) {}

template <typename Manager>
OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config,
const OfflineTtsCacheMechanismConfig &cache_config)
: impl_(OfflineTtsImpl::Create(mgr, config)) {
cache_mechanism_ = std::make_unique<OfflineTtsCacheMechanism>(cache_config);
}
OfflineTtsCacheMechanism *cache)
: impl_(OfflineTtsImpl::Create(mgr, config, cache)) {}

OfflineTts::~OfflineTts() = default;

GeneratedAudio OfflineTts::Generate(
const std::string &text, int64_t sid /*=0*/, float speed /*= 1.0*/,
GeneratedAudioCallback callback /*= nullptr*/) const {

// Generate a hash for the text
std::hash<std::string> hasher;
std::size_t text_hash = hasher(text);
GeneratedAudio audio;

// Check if the cache mechanism is active and if the audio is already cached
if (cache_mechanism_) {
int32_t sample_rate;
std::vector<float> samples
= cache_mechanism_->GetWavFile(text_hash, &sample_rate);

if (!samples.empty()) {
SHERPA_ONNX_LOGE("Returning cached audio for hash: %zu", text_hash);

// If a callback is provided, call it with the cached audio
if (callback) {
int32_t result
= callback(samples.data(), samples.size(), 1.0f /* progress */);
if (result == 0) {
// If the callback returns 0, stop further processing
SHERPA_ONNX_LOGE("Callback requested to stop processing.");
return {samples, sample_rate};
}
}

// Return the cached audio
return {samples, sample_rate};
}
}

// Generate the audio if not cached
#if !defined(_WIN32)
audio = impl_->Generate(text, sid, speed, std::move(callback));
return impl_->GenerateWitchCache(text, sid, speed, std::move(callback));
#else
if (IsUtf8(text)) {
audio = impl_->Generate(text, sid, speed, std::move(callback));
return impl_->GenerateWitchCache(text, sid, speed, std::move(callback));
} else if (IsGB2312(text)) {
auto utf8_text = Gb2312ToUtf8(text);
static bool printed = false;
Expand All @@ -232,41 +191,30 @@ GeneratedAudio OfflineTts::Generate(
"Detected GB2312 encoded string! Converting it to UTF8.");
printed = true;
}
audio = impl_->Generate(utf8_text, sid, speed, std::move(callback));
return impl_->GenerateWitchCache(utf8_text, sid, speed, std::move(callback));
} else {
SHERPA_ONNX_LOGE(
"Non UTF8 encoded string is received. You would not get expected "
"results!");
audio = impl_->Generate(text, sid, speed, std::move(callback));
return impl_->GenerateWitchCache(text, sid, speed, std::move(callback));
}
#endif

// Cache the generated audio if the cache mechanism is active
if (cache_mechanism_) {
cache_mechanism_->AddWavFile(text_hash, audio.samples, audio.sample_rate);
}

return audio;
}

int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); }

int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); }

#if __ANDROID_API__ >= 9
template OfflineTts::OfflineTts(AAssetManager *mgr,
const OfflineTtsConfig &config);
template OfflineTts::OfflineTts(AAssetManager *mgr,
const OfflineTtsConfig &config,
const OfflineTtsCacheMechanismConfig &cache_config);
OfflineTtsCacheMechanism *cache = nullptr);
#endif

#if __OHOS__
template OfflineTts::OfflineTts(NativeResourceManager *mgr,
const OfflineTtsConfig &config);
template OfflineTts::OfflineTts(NativeResourceManager *mgr,
const OfflineTtsConfig &config,
const OfflineTtsCacheMechanismConfig &cache_config);
OfflineTtsCacheMechanism *cache = nullptr);
#endif

} // namespace sherpa_onnx
10 changes: 2 additions & 8 deletions sherpa-onnx/csrc/offline-tts.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,12 @@ using GeneratedAudioCallback = std::function<int32_t(
class OfflineTts {
public:
~OfflineTts();
explicit OfflineTts(const OfflineTtsConfig &config);
explicit OfflineTts(const OfflineTtsConfig &config,
const OfflineTtsCacheMechanismConfig &cache_config);

template <typename Manager>
OfflineTts(Manager *mgr, const OfflineTtsConfig &config);
OfflineTtsCacheMechanism *cache = nullptr);

template <typename Manager>
OfflineTts(Manager *mgr, const OfflineTtsConfig &config,
const OfflineTtsCacheMechanismConfig &cache_config);
OfflineTtsCacheMechanism *cache = nullptr);

// @param text A string containing words separated by spaces
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
Expand All @@ -110,8 +106,6 @@ class OfflineTts {
// If it supports only a single speaker, then it return 0 or 1.
int32_t NumSpeakers() const;

std::unique_ptr<OfflineTtsCacheMechanism> cache_mechanism_; // not owned here

private:
std::unique_ptr<OfflineTtsImpl> impl_;
};
Expand Down
Loading

0 comments on commit 5d69dab

Please sign in to comment.