From 1291ed22d57d1a5c289368b57684256b33a59a5b Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Jan 2025 06:36:26 +0330 Subject: [PATCH] Changed all text_hash from string to size_t for more performance. Also solved bug by accidentally pasted extra code in previous commit --- .../csrc/offline-tts-cache-mechanism.cc | 82 ++++++++++--------- .../csrc/offline-tts-cache-mechanism.h | 14 ++-- sherpa-onnx/csrc/offline-tts.cc | 6 +- sherpa-onnx/jni/offline-tts.cc | 18 ---- 4 files changed, 52 insertions(+), 68 deletions(-) diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc b/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc index 049e0d90b..ecd19f4c6 100644 --- a/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc +++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc @@ -10,6 +10,7 @@ #include #include #include +#include // for std::size_t #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" @@ -60,14 +61,14 @@ OfflineTtsCacheMechanism::~OfflineTtsCacheMechanism() { } void OfflineTtsCacheMechanism::AddWavFile( - const std::string &text_hash, + const std::size_t &text_hash, const std::vector &samples, const int32_t sample_rate) { std::lock_guard lock(mutex_); if (cache_mechanism_inited_ == false) return; - std::string file_path = cache_dir_ + "/" + text_hash + ".wav"; + std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav"; // Check if the file physically exists in the cache directory bool file_exists = std::filesystem::exists(file_path); @@ -92,7 +93,7 @@ void OfflineTtsCacheMechanism::AddWavFile( } std::vector OfflineTtsCacheMechanism::GetWavFile( - const std::string &text_hash, + const std::size_t &text_hash, int32_t *sample_rate) { std::lock_guard lock(mutex_); @@ -100,7 +101,7 @@ std::vector OfflineTtsCacheMechanism::GetWavFile( if (cache_mechanism_inited_ == false) return samples; - std::string file_path = cache_dir_ + "/" + text_hash + ".wav"; + std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav"; if (std::filesystem::exists(file_path)) { bool is_ok = false; @@ -119,12 +120,12 @@ std::vector OfflineTtsCacheMechanism::GetWavFile( } // Save the repeat counts every 10 minutes - auto now = std::chrono::steady_clock::now(); - if (std::chrono::duration_cast( - now - last_save_time_).count() >= 10 * 60) { + //auto now = std::chrono::steady_clock::now(); + //if (std::chrono::duration_cast( + //now - last_save_time_).count() >= 10 * 60) { SaveRepeatCounts(); - last_save_time_ = now; - } + //last_save_time_ = now; + //} return samples; } @@ -168,7 +169,7 @@ void OfflineTtsCacheMechanism::ClearCache() { repeat_counts_.clear(); cache_vector_.clear(); - // Remove repeat counts also in the repeat_counts.txt + // Remove repeat counts also in the repeat_counts file SaveRepeatCounts(); } @@ -183,58 +184,60 @@ int32_t OfflineTtsCacheMechanism::GetTotalUsedCacheSize() const { // Private functions /////////////////////////////////////////////////// void OfflineTtsCacheMechanism::LoadRepeatCounts() { - std::string repeat_count_file = cache_dir_ + "/repeat_counts.txt"; + std::string repeat_count_file = cache_dir_ + "/repeat_counts.bin"; // Check if the file exists if (!std::filesystem::exists(repeat_count_file)) { return; // Skip loading if the file doesn't exist } - // Open the file for reading - std::ifstream ifs(repeat_count_file); + // Open the file for reading in binary mode + std::ifstream ifs(repeat_count_file, std::ios::binary); if (!ifs.is_open()) { SHERPA_ONNX_LOGE("Failed to open repeat count file: %s", repeat_count_file.c_str()); return; // Skip loading if the file cannot be opened } - // Read the file line by line - std::string line; - while (std::getline(ifs, line)) { - size_t pos = line.find(' '); - if (pos != std::string::npos) { - std::string text_hash = line.substr(0, pos); - int32_t count = std::stoi(line.substr(pos + 1)); - repeat_counts_[text_hash] = count; - } + // Read the number of entries + size_t num_entries; + ifs.read(reinterpret_cast(&num_entries), sizeof(num_entries)); + + // Read each entry + for (size_t i = 0; i < num_entries; ++i) { + std::size_t text_hash; + int32_t count; + ifs.read(reinterpret_cast(&text_hash), sizeof(text_hash)); + ifs.read(reinterpret_cast(&count), sizeof(count)); + repeat_counts_[text_hash] = count; } } void OfflineTtsCacheMechanism::SaveRepeatCounts() { - std::string repeat_count_file = cache_dir_ + "/repeat_counts.txt"; + std::string repeat_count_file = cache_dir_ + "/repeat_counts.bin"; - // Open the file for writing - std::ofstream ofs(repeat_count_file); + // Open the file for writing in binary mode + std::ofstream ofs(repeat_count_file, std::ios::binary); if (!ofs.is_open()) { SHERPA_ONNX_LOGE("Failed to open repeat count file for writing: %s", repeat_count_file.c_str()); return; // Skip saving if the file cannot be opened } - // Write the repeat counts to the file + // Write the number of entries + size_t num_entries = repeat_counts_.size(); + ofs.write(reinterpret_cast(&num_entries), sizeof(num_entries)); + + // Write each entry for (const auto &entry : repeat_counts_) { - ofs << entry.first << " " << entry.second; - if (!ofs) { - SHERPA_ONNX_LOGE("Failed to write repeat count for text hash: %s", - entry.first.c_str()); - return; // Stop writing if an error occurs - } - ofs << std::endl; + ofs.write(reinterpret_cast(&entry.first), sizeof(entry.first)); + ofs.write(reinterpret_cast(&entry.second), sizeof(entry.second)); } } -void OfflineTtsCacheMechanism::RemoveWavFile(const std::string &text_hash) { - std::string file_path = cache_dir_ + "/" + text_hash + ".wav"; +void OfflineTtsCacheMechanism::RemoveWavFile(const std::size_t &text_hash) { + std::string file_path = cache_dir_ + "/" + + std::to_string(text_hash) + ".wav"; if (std::filesystem::exists(file_path)) { // Subtract the size of the removed WAV file from the total cache size std::ifstream file(file_path, std::ios::binary | std::ios::ate); @@ -259,7 +262,8 @@ void OfflineTtsCacheMechanism::UpdateCacheVector() { for (const auto &entry : std::filesystem::directory_iterator(cache_dir_)) { if (entry.path().extension() == ".wav") { - std::string text_hash = entry.path().stem().string(); + std::string text_hash_str = entry.path().stem().string(); + std::size_t text_hash = std::stoull(text_hash_str); if (repeat_counts_.find(text_hash) == repeat_counts_.end()) { // Remove the file if it's not in the repeat count file (orphaned file) std::filesystem::remove(entry.path()); @@ -282,14 +286,14 @@ void OfflineTtsCacheMechanism::EnsureCacheLimit() { while (used_cache_size_bytes_> 0 && used_cache_size_bytes_ > target_cache_size) { // Cache is full, remove the least repeated file - std::string least_repeated_file = GetLeastRepeatedFile(); + std::size_t least_repeated_file = GetLeastRepeatedFile(); RemoveWavFile(least_repeated_file); } } } -std::string OfflineTtsCacheMechanism::GetLeastRepeatedFile() { - std::string least_repeated_file; +std::size_t OfflineTtsCacheMechanism::GetLeastRepeatedFile() { + std::size_t least_repeated_file = 0; int32_t min_count = std::numeric_limits::max(); for (const auto &entry : repeat_counts_) { diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism.h b/sherpa-onnx/csrc/offline-tts-cache-mechanism.h index 0c672e1c2..b1e5dd0da 100644 --- a/sherpa-onnx/csrc/offline-tts-cache-mechanism.h +++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism.h @@ -9,6 +9,7 @@ #include #include #include // NOLINT +#include // for std::size_t #include "sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h" @@ -16,19 +17,18 @@ namespace sherpa_onnx { class OfflineTtsCacheMechanism { public: - explicit OfflineTtsCacheMechanism(const OfflineTtsCacheMechanismConfig &config); ~OfflineTtsCacheMechanism(); // Add a new wav file to the cache void AddWavFile( - const std::string &text_hash, + const std::size_t &text_hash, const std::vector &samples, const int32_t sample_rate); // Get the cached wav file if it exists std::vector GetWavFile( - const std::string &text_hash, + const std::size_t &text_hash, int32_t *sample_rate); // Get the current cache size in bytes @@ -51,7 +51,7 @@ class OfflineTtsCacheMechanism { void SaveRepeatCounts(); // Remove a wav file from the cache - void RemoveWavFile(const std::string &text_hash); + void RemoveWavFile(const std::size_t &text_hash); // Update the cache vector with the actual files in the cache folder void UpdateCacheVector(); @@ -60,7 +60,7 @@ class OfflineTtsCacheMechanism { void EnsureCacheLimit(); // Get the least repeated file in the cache - std::string GetLeastRepeatedFile(); + std::size_t GetLeastRepeatedFile(); // Data directory where the cache folder is located std::string cache_dir_; @@ -72,10 +72,10 @@ class OfflineTtsCacheMechanism { int32_t used_cache_size_bytes_; // Map of text hash to repeat count - std::unordered_map repeat_counts_; + std::unordered_map repeat_counts_; // Vector of cached file names - std::vector cache_vector_; + std::vector cache_vector_; // Mutex for thread safety (recursive to avoid deadlocks) mutable std::recursive_mutex mutex_; diff --git a/sherpa-onnx/csrc/offline-tts.cc b/sherpa-onnx/csrc/offline-tts.cc index decb5a526..c7a52a88b 100644 --- a/sherpa-onnx/csrc/offline-tts.cc +++ b/sherpa-onnx/csrc/offline-tts.cc @@ -112,8 +112,7 @@ GeneratedAudio OfflineTts::Generate( GeneratedAudioCallback callback /*= nullptr*/) const { // Generate a hash for the text std::hash hasher; - std::string text_hash = std::to_string(hasher(text)); - // SHERPA_ONNX_LOGE("Generated text hash: %s", text_hash.c_str()); + std::size_t text_hash = hasher(text); // Check if the cache mechanism is active and if the audio is already cached if (cache_mechanism_) { @@ -122,7 +121,7 @@ GeneratedAudio OfflineTts::Generate( = cache_mechanism_->GetWavFile(text_hash, &sample_rate); if (!samples.empty()) { - SHERPA_ONNX_LOGE("Returning cached audio for hash:%s", text_hash.c_str()); + SHERPA_ONNX_LOGE("Returning cached audio for hash: %zu", text_hash); // If a callback is provided, call it with the cached audio if (callback) { @@ -146,7 +145,6 @@ GeneratedAudio OfflineTts::Generate( // Cache the generated audio if the cache mechanism is active if (cache_mechanism_) { cache_mechanism_->AddWavFile(text_hash, audio.samples, audio.sample_rate); - // SHERPA_ONNX_LOGE("Cached audio for text hash: %s", text_hash.c_str()); } return audio; diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc index 1191d47e8..8629e109e 100644 --- a/sherpa-onnx/jni/offline-tts.cc +++ b/sherpa-onnx/jni/offline-tts.cc @@ -196,24 +196,6 @@ static OfflineTtsCacheMechanismConfig GetOfflineTtsCacheConfig(JNIEnv *env, jobj return ans; } - // Get data directory from config - jfieldID model_fid = env->GetFieldID(cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;"); - jobject model_config = env->GetObjectField(config, model_fid); - jclass model_cls = env->GetObjectClass(model_config); - - jfieldID vits_fid = env->GetFieldID(model_cls, "vits", "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); - jobject vits_config = env->GetObjectField(model_config, vits_fid); - - fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;"); - jstring data_dir = (jstring)env->GetObjectField(vits_config, fid); - const char *p_data_dir = env->GetStringUTFChars(data_dir, nullptr); - - // Convert data directory to cache directory - std::string cache_dir = std::string(p_data_dir) + "/../cache"; - ans.cache_dir = cache_dir; - - env->ReleaseStringUTFChars(data_dir, p_data_dir); - } // namespace sherpa_onnx SHERPA_ONNX_EXTERN_C