From 316424b3825db2ac474042917a9c0d9bf4d8c84f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 16 Feb 2025 22:45:24 +0800 Subject: [PATCH] Add C++ and Python API for FireRedASR AED models (#1867) --- .gitignore | 1 + .../offline-fire-red-asr-decode-files.py | 75 +++++ sherpa-onnx/csrc/CMakeLists.txt | 3 + .../csrc/offline-fire-red-asr-decoder.h | 39 +++ ...line-fire-red-asr-greedy-search-decoder.cc | 87 ++++++ ...fline-fire-red-asr-greedy-search-decoder.h | 29 ++ .../csrc/offline-fire-red-asr-model-config.cc | 56 ++++ .../csrc/offline-fire-red-asr-model-config.h | 31 +++ .../offline-fire-red-asr-model-meta-data.h | 28 ++ .../csrc/offline-fire-red-asr-model.cc | 256 ++++++++++++++++++ sherpa-onnx/csrc/offline-fire-red-asr-model.h | 92 +++++++ sherpa-onnx/csrc/offline-model-config.cc | 8 +- sherpa-onnx/csrc/offline-model-config.h | 4 + .../offline-recognizer-fire-red-asr-impl.h | 158 +++++++++++ sherpa-onnx/csrc/offline-recognizer-impl.cc | 9 + sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../csrc/offline-fire-red-asr-model-config.cc | 24 ++ .../csrc/offline-fire-red-asr-model-config.h | 16 ++ .../python/csrc/offline-model-config.cc | 55 ++-- .../python/sherpa_onnx/offline_recognizer.py | 73 +++++ 20 files changed, 1019 insertions(+), 26 deletions(-) create mode 100644 python-api-examples/offline-fire-red-asr-decode-files.py create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-decoder.h create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-model-config.h create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-model.cc create mode 100644 sherpa-onnx/csrc/offline-fire-red-asr-model.h create mode 100644 sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h create mode 100644 sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h diff --git a/.gitignore b/.gitignore index ea356b0652..53bddb47e2 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ lexicon.txt us_gold.json us_silver.json kokoro-multi-lang-v1_0 +sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 diff --git a/python-api-examples/offline-fire-red-asr-decode-files.py b/python-api-examples/offline-fire-red-asr-decode-files.py new file mode 100644 index 0000000000..3b4c69644f --- /dev/null +++ b/python-api-examples/offline-fire-red-asr-decode-files.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming FireRedAsr AED model from +https://github.com/FireRedTeam/FireRedASR +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +""" + +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx" + decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx" + tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt" + test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav" + + if ( + not Path(encoder).is_file() + or not Path(decoder).is_file() + or not Path(test_wav).is_file() + ): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_fire_red_asr( + encoder=encoder, + decoder=decoder, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 5ee7a50507..40f0ee6cd3 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -27,6 +27,9 @@ set(sources offline-ctc-fst-decoder.cc offline-ctc-greedy-search-decoder.cc offline-ctc-model.cc + offline-fire-red-asr-greedy-search-decoder.cc + offline-fire-red-asr-model-config.cc + offline-fire-red-asr-model.cc offline-lm-config.cc offline-lm.cc offline-model-config.cc diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-decoder.h b/sherpa-onnx/csrc/offline-fire-red-asr-decoder.h new file mode 100644 index 0000000000..9d60cb4962 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-decoder.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ + +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OfflineFireRedAsrDecoderResult { + /// The decoded token IDs + std::vector tokens; +}; + +class OfflineFireRedAsrDecoder { + public: + virtual ~OfflineFireRedAsrDecoder() = 
default; + + /** Run beam search given the output from the FireRedAsr encoder model. + * + * @param n_layer_cross_k A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model). + * @param n_layer_cross_v A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model). + * + * @return Return a vector of size `N` containing the decoded results. + */ + virtual std::vector Decode( + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc new file mode 100644 index 0000000000..4e758fa427 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc @@ -0,0 +1,87 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +// Note: this functions works only for batch size == 1 at present +std::vector +OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, + Ort::Value cross_v) { + const auto &meta_data = model_->GetModelMetadata(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // For multilingual models, initial_tokens contains [sot, language, task] + // - language is English by default + // - task is transcribe by default + // + // For non-multilingual models, initial_tokens contains [sot] + std::array token_shape = {1, 1}; + int64_t token = meta_data.sos_id; + + int32_t batch_size = 1; + + Ort::Value tokens = Ort::Value::CreateTensor( + memory_info, &token, 1, token_shape.data(), token_shape.size()); + + std::array offset_shape{1}; + Ort::Value offset = Ort::Value::CreateTensor( + model_->Allocator(), offset_shape.data(), offset_shape.size()); + *(offset.GetTensorMutableData()) = 0; + + std::vector ans(1); + + auto self_kv_cache = model_->GetInitialSelfKVCache(); + + std::tuple + decoder_out = {Ort::Value{nullptr}, + std::move(self_kv_cache.first), + std::move(self_kv_cache.second), + std::move(cross_k), + std::move(cross_v), + std::move(offset)}; + + for (int32_t i = 0; i < meta_data.max_len; ++i) { + decoder_out = model_->ForwardDecoder(View(&tokens), + std::move(std::get<1>(decoder_out)), + std::move(std::get<2>(decoder_out)), + std::move(std::get<3>(decoder_out)), + std::move(std::get<4>(decoder_out)), + std::move(std::get<5>(decoder_out))); + + const auto &logits = std::get<0>(decoder_out); + const float *p_logits = logits.GetTensorData(); + + auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); + int32_t vocab_size = logits_shape[2]; + + int32_t max_token_id = static_cast(std::distance( + p_logits, std::max_element(p_logits, p_logits + vocab_size))); + if (max_token_id == meta_data.eos_id) { + break; + } + + ans[0].tokens.push_back(max_token_id); + + token = max_token_id; + + // increment offset + *(std::get<5>(decoder_out).GetTensorMutableData()) += 1; + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h new file mode 100644 index 0000000000..30302df534 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h @@ -0,0 +1,29 @@ 
+// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h" +#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h" + +namespace sherpa_onnx { + +class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder { + public: + explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model) + : model_(model) {} + + std::vector Decode( + Ort::Value cross_k, Ort::Value cross_v) override; + + private: + OfflineFireRedAsrModel *model_; // not owned +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc b/sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc new file mode 100644 index 0000000000..53eb933763 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc @@ -0,0 +1,56 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) { + po->Register("fire-red-asr-encoder", &encoder, + "Path to onnx encoder of FireRedAsr"); + + po->Register("fire-red-asr-decoder", &decoder, + "Path to onnx decoder of FireRedAsr"); +} + +bool OfflineFireRedAsrModelConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist", + encoder.c_str()); + return false; + } + + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder"); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist", + decoder.c_str()); + return false; + } + + return true; +} + +std::string OfflineFireRedAsrModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineFireRedAsrModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-model-config.h b/sherpa-onnx/csrc/offline-fire-red-asr-model-config.h new file mode 100644 index 0000000000..48d3b9d1d6 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-model-config.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +// see https://github.com/FireRedTeam/FireRedASR +struct OfflineFireRedAsrModelConfig { + std::string encoder; + std::string decoder; + + OfflineFireRedAsrModelConfig() = default; + OfflineFireRedAsrModelConfig(const std::string &encoder, + const std::string &decoder) + : encoder(encoder), decoder(decoder) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // 
SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h b/sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h new file mode 100644 index 0000000000..0d2a57d1ea --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +struct OfflineFireRedAsrModelMetaData { + int32_t sos_id; + int32_t eos_id; + int32_t max_len; + + int32_t num_decoder_layers; + int32_t num_head; + int32_t head_dim; + + std::vector mean; + std::vector inv_stddev; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-model.cc b/sherpa-onnx/csrc/offline-fire-red-asr-model.cc new file mode 100644 index 0000000000..bf45399456 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-model.cc @@ -0,0 +1,256 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h" + +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineFireRedAsrModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.fire_red_asr.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.fire_red_asr.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.fire_red_asr.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.fire_red_asr.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + std::pair ForwardEncoder(Ort::Value features, + Ort::Value features_length) { + std::array inputs{std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); + + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; + } + + std::tuple + ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, + Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, + Ort::Value n_layer_cross_v, Ort::Value offset) { + std::array decoder_input = {std::move(tokens), + std::move(n_layer_self_k_cache), + std::move(n_layer_self_v_cache), + std::move(n_layer_cross_k), + std::move(n_layer_cross_v), + std::move(offset)}; + + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), decoder_input.data(), + decoder_input.size(), decoder_output_names_ptr_.data(), + 
decoder_output_names_ptr_.size()); + + return std::tuple{ + std::move(decoder_out[0]), std::move(decoder_out[1]), + std::move(decoder_out[2]), std::move(decoder_input[3]), + std::move(decoder_input[4]), std::move(decoder_input[5])}; + } + + std::pair GetInitialSelfKVCache() { + int32_t batch_size = 1; + std::array shape{meta_data_.num_decoder_layers, batch_size, + meta_data_.max_len, meta_data_.num_head, + meta_data_.head_dim}; + + Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + + Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + + auto n = shape[0] * shape[1] * shape[2] * shape[3] * shape[4]; + + float *p_k = n_layer_self_k_cache.GetTensorMutableData(); + float *p_v = n_layer_self_v_cache.GetTensorMutableData(); + + memset(p_k, 0, sizeof(float) * n); + memset(p_v, 0, sizeof(float) * n); + + return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; + } + + OrtAllocator *Allocator() { return allocator_; } + + const OfflineFireRedAsrModelMetaData &GetModelMetadata() const { + return meta_data_; + } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.num_decoder_layers, + "num_decoder_layers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_head, "num_head"); + SHERPA_ONNX_READ_META_DATA(meta_data_.head_dim, "head_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.sos_id, "sos"); + SHERPA_ONNX_READ_META_DATA(meta_data_.eos_id, "eos"); + SHERPA_ONNX_READ_META_DATA(meta_data_.max_len, "max_len"); + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "cmvn_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, + "cmvn_inv_stddev"); + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + OfflineFireRedAsrModelMetaData meta_data_; +}; + +OfflineFireRedAsrModel::OfflineFireRedAsrModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template 
+OfflineFireRedAsrModel::OfflineFireRedAsrModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineFireRedAsrModel::~OfflineFireRedAsrModel() = default; + +std::pair OfflineFireRedAsrModel::ForwardEncoder( + Ort::Value features, Ort::Value features_length) const { + return impl_->ForwardEncoder(std::move(features), std::move(features_length)); +} + +std::tuple +OfflineFireRedAsrModel::ForwardDecoder(Ort::Value tokens, + Ort::Value n_layer_self_k_cache, + Ort::Value n_layer_self_v_cache, + Ort::Value n_layer_cross_k, + Ort::Value n_layer_cross_v, + Ort::Value offset) const { + return impl_->ForwardDecoder( + std::move(tokens), std::move(n_layer_self_k_cache), + std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), + std::move(n_layer_cross_v), std::move(offset)); +} + +std::pair +OfflineFireRedAsrModel::GetInitialSelfKVCache() const { + return impl_->GetInitialSelfKVCache(); +} + +OrtAllocator *OfflineFireRedAsrModel::Allocator() const { + return impl_->Allocator(); +} + +const OfflineFireRedAsrModelMetaData &OfflineFireRedAsrModel::GetModelMetadata() + const { + return impl_->GetModelMetadata(); +} + +#if __ANDROID_API__ >= 9 +template OfflineFireRedAsrModel::OfflineFireRedAsrModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineFireRedAsrModel::OfflineFireRedAsrModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-model.h b/sherpa-onnx/csrc/offline-fire-red-asr-model.h new file mode 100644 index 0000000000..13aff32dc2 --- /dev/null +++ b/sherpa-onnx/csrc/offline-fire-red-asr-model.h @@ -0,0 +1,92 @@ +// sherpa-onnx/csrc/offline-fire-red-asr-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +class OfflineFireRedAsrModel { + public: + explicit OfflineFireRedAsrModel(const OfflineModelConfig &config); + + template + OfflineFireRedAsrModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineFireRedAsrModel(); + + /** Run the encoder model. + * + * @param features A tensor of shape (N, T, C). + * @param features_len A tensor of shape (N,) with dtype int64. + * + * @return Return a pair containing: + * - n_layer_cross_k: A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model) + * - n_layer_cross_v: A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model) + */ + std::pair ForwardEncoder( + Ort::Value features, Ort::Value features_length) const; + + /** Run the decoder model. + * + * @param tokens A int64 tensor of shape (N, num_words) + * @param n_layer_self_k_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + * @param n_layer_self_v_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + * @param n_layer_cross_k A 5-D tensor of shape + * (num_decoder_layers, N, T, d_model). + * @param n_layer_cross_v A 5-D tensor of shape + * (num_decoder_layers, N, T, d_model). 
+ * @param offset A int64 tensor of shape (N,) + * + * @return Return a tuple containing 6 tensors: + * + * - logits A 3-D tensor of shape (N, num_words, vocab_size) + * - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache + * - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache + * - out_n_layer_cross_k Same as n_layer_cross_k + * - out_n_layer_cross_v Same as n_layer_cross_v + * - out_offset Same as offset + */ + std::tuple + ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, + Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, + Ort::Value n_layer_cross_v, Ort::Value offset) const; + + /** Return the initial self kv cache in a pair + * - n_layer_self_k_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + * - n_layer_self_v_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + */ + std::pair GetInitialSelfKVCache() const; + + const OfflineFireRedAsrModelMetaData &GetModelMetadata() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 7872903272..2aee77c824 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { paraformer.Register(po); nemo_ctc.Register(po); whisper.Register(po); + fire_red_asr.Register(po); tdnn.Register(po); zipformer_ctc.Register(po); wenet_ctc.Register(po); @@ -38,7 +39,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { po->Register("model-type", &model_type, "Specify it to reduce model initialization time. " "Valid values are: transducer, paraformer, nemo_ctc, whisper, " - "tdnn, zipformer2_ctc, telespeech_ctc." + "tdnn, zipformer2_ctc, telespeech_ctc, fire_red_asr." 
"All other values lead to loading the model twice."); po->Register("modeling-unit", &modeling_unit, "The modeling unit of the model, commonly used units are bpe, " @@ -84,6 +85,10 @@ bool OfflineModelConfig::Validate() const { return whisper.Validate(); } + if (!fire_red_asr.encoder.empty()) { + return fire_red_asr.Validate(); + } + if (!tdnn.model.empty()) { return tdnn.Validate(); } @@ -125,6 +130,7 @@ std::string OfflineModelConfig::ToString() const { os << "paraformer=" << paraformer.ToString() << ", "; os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; os << "whisper=" << whisper.ToString() << ", "; + os << "fire_red_asr=" << fire_red_asr.ToString() << ", "; os << "tdnn=" << tdnn.ToString() << ", "; os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index cfff5eed22..e99e39a584 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -6,6 +6,7 @@ #include +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" #include "sherpa-onnx/csrc/offline-moonshine-model-config.h" #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" @@ -23,6 +24,7 @@ struct OfflineModelConfig { OfflineParaformerModelConfig paraformer; OfflineNemoEncDecCtcModelConfig nemo_ctc; OfflineWhisperModelConfig whisper; + OfflineFireRedAsrModelConfig fire_red_asr; OfflineTdnnModelConfig tdnn; OfflineZipformerCtcModelConfig zipformer_ctc; OfflineWenetCtcModelConfig wenet_ctc; @@ -54,6 +56,7 @@ struct OfflineModelConfig { const OfflineParaformerModelConfig ¶former, const OfflineNemoEncDecCtcModelConfig &nemo_ctc, const OfflineWhisperModelConfig &whisper, + const OfflineFireRedAsrModelConfig &fire_red_asr, const OfflineTdnnModelConfig &tdnn, const OfflineZipformerCtcModelConfig &zipformer_ctc, const OfflineWenetCtcModelConfig &wenet_ctc, @@ -68,6 +71,7 @@ struct OfflineModelConfig { paraformer(paraformer), nemo_ctc(nemo_ctc), whisper(whisper), + fire_red_asr(fire_red_asr), tdnn(tdnn), zipformer_ctc(zipformer_ctc), wenet_ctc(wenet_ctc), diff --git a/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h b/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h new file mode 100644 index 0000000000..f206e314c1 --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h @@ -0,0 +1,158 @@ +// sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h" +#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h" +#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +static OfflineRecognitionResult Convert( + const OfflineFireRedAsrDecoderResult &src, const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + if (!sym_table.Contains(i)) { + continue; + } + + 
const auto &s = sym_table[i]; + text += s; + r.tokens.push_back(s); + } + + r.text = text; + + return r; +} + +class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerFireRedAsrImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + Init(); + } + + template + OfflineRecognizerFireRedAsrImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique(mgr, + config.model_config)) { + Init(); + } + + void Init() { + if (config_.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + SHERPA_ONNX_LOGE( + "Only greedy_search is supported at present for FireRedAsr. Given %s", + config_.decoding_method.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + const auto &meta_data = model_->GetModelMetadata(); + + config_.feat_config.normalize_samples = false; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // batch decoding is not implemented yet + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeStream(OfflineStream *s) const { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = s->FeatureDim(); + std::vector f = s->GetFrames(); + ApplyCMVN(&f); + + int64_t num_frames = f.size() / feat_dim; + + std::array shape{1, num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int64_t len_shape = 1; + Ort::Value x_len = + Ort::Value::CreateTensor(memory_info, &num_frames, 1, &len_shape, 1); + + auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len)); + + auto results = + decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); + + auto r = Convert(results[0], symbol_table_); + + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + } + + void ApplyCMVN(std::vector *v) const { + const auto &meta_data = model_->GetModelMetadata(); + const auto &mean = meta_data.mean; + const auto &inv_stddev = meta_data.inv_stddev; + int32_t feat_dim = static_cast(mean.size()); + int32_t num_frames = static_cast(v->size()) / feat_dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != feat_dim; ++k) { + p[k] = (p[k] - mean[k]) * inv_stddev[k]; + } + + p += feat_dim; + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 1867bf39be..59e74d4cdc 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -24,6 +24,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" +#include 
"sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" @@ -56,6 +57,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } + if (!config.model_config.fire_red_asr.encoder.empty()) { + return std::make_unique(config); + } + if (!config.model_config.moonshine.preprocessor.empty()) { return std::make_unique(config); } @@ -237,6 +242,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(mgr, config); } + if (!config.model_config.fire_red_asr.encoder.empty()) { + return std::make_unique(mgr, config); + } + if (!config.model_config.moonshine.preprocessor.empty()) { return std::make_unique(mgr, config); } diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index a4c15713c2..95b96bf980 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -9,6 +9,7 @@ set(srcs features.cc keyword-spotter.cc offline-ctc-fst-decoder-config.cc + offline-fire-red-asr-model-config.cc offline-lm-config.cc offline-model-config.cc offline-moonshine-model-config.cc diff --git a/sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc b/sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc new file mode 100644 index 0000000000..fe5929fc02 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc @@ -0,0 +1,24 @@ +// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" + +#include +#include + +#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineFireRedAsrModelConfig(py::module *m) { + using PyClass = OfflineFireRedAsrModelConfig; + py::class_(*m, "OfflineFireRedAsrModelConfig") + .def(py::init(), + py::arg("encoder"), py::arg("decoder")) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h b/sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h new file mode 100644 index 0000000000..d2e290ebda --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineFireRedAsrModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index d999486bc4..92f8a2a868 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -8,6 +8,7 @@ #include #include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h" #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" 
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" @@ -25,6 +26,7 @@ void PybindOfflineModelConfig(py::module *m) { PybindOfflineParaformerModelConfig(m); PybindOfflineNemoEncDecCtcModelConfig(m); PybindOfflineWhisperModelConfig(m); + PybindOfflineFireRedAsrModelConfig(m); PybindOfflineTdnnModelConfig(m); PybindOfflineZipformerCtcModelConfig(m); PybindOfflineWenetCtcModelConfig(m); @@ -33,35 +35,38 @@ void PybindOfflineModelConfig(py::module *m) { using PyClass = OfflineModelConfig; py::class_(*m, "OfflineModelConfig") - .def( - py::init< - const OfflineTransducerModelConfig &, - const OfflineParaformerModelConfig &, - const OfflineNemoEncDecCtcModelConfig &, - const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, - const OfflineZipformerCtcModelConfig &, - const OfflineWenetCtcModelConfig &, - const OfflineSenseVoiceModelConfig &, - const OfflineMoonshineModelConfig &, const std::string &, - const std::string &, int32_t, bool, const std::string &, - const std::string &, const std::string &, const std::string &>(), - py::arg("transducer") = OfflineTransducerModelConfig(), - py::arg("paraformer") = OfflineParaformerModelConfig(), - py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), - py::arg("whisper") = OfflineWhisperModelConfig(), - py::arg("tdnn") = OfflineTdnnModelConfig(), - py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), - py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), - py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), - py::arg("moonshine") = OfflineMoonshineModelConfig(), - py::arg("telespeech_ctc") = "", py::arg("tokens"), - py::arg("num_threads"), py::arg("debug") = false, - py::arg("provider") = "cpu", py::arg("model_type") = "", - py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") + .def(py::init(), + py::arg("transducer") = OfflineTransducerModelConfig(), + py::arg("paraformer") = OfflineParaformerModelConfig(), + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), + py::arg("whisper") = OfflineWhisperModelConfig(), + py::arg("fire_red_asr") = OfflineFireRedAsrModelConfig(), + py::arg("tdnn") = OfflineTdnnModelConfig(), + py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), + py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), + py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), + py::arg("moonshine") = OfflineMoonshineModelConfig(), + py::arg("telespeech_ctc") = "", py::arg("tokens"), + py::arg("num_threads"), py::arg("debug") = false, + py::arg("provider") = "cpu", py::arg("model_type") = "", + py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) .def_readwrite("whisper", &PyClass::whisper) + .def_readwrite("fire_red_asr", &PyClass::fire_red_asr) .def_readwrite("tdnn", &PyClass::tdnn) .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 3916660054..d8ab6709df 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -6,6 +6,7 @@ from _sherpa_onnx import ( FeatureExtractorConfig, OfflineCtcFstDecoderConfig, + OfflineFireRedAsrModelConfig, OfflineLMConfig, OfflineModelConfig, OfflineMoonshineModelConfig, @@ -571,6 +572,78 @@ def from_whisper( self.config = recognizer_config 
return self + @classmethod + def from_fire_red_asr( + cls, + encoder: str, + decoder: str, + tokens: str, + num_threads: int = 1, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different kinds of FireRedAsr models, + e.g., xs, large, etc. + + Args: + encoder: + Path to the encoder model. + decoder: + Path to the decoder model. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + num_threads: + Number of threads for neural network computation. + decoding_method: + Valid values: greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + fire_red_asr=OfflineFireRedAsrModelConfig( + encoder=encoder, + decoder=decoder, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=16000, + feature_dim=80, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + @classmethod def from_moonshine( cls,
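
Usage note (illustration only, not part of the diff): the patch ships a Python example (python-api-examples/offline-fire-red-asr-decode-files.py) but no C++ one, so the sketch below shows how the new fire_red_asr config fields are expected to be wired into the existing C++ OfflineRecognizer API. It assumes the pre-existing sherpa-onnx C++ interfaces (OfflineRecognizerConfig, OfflineRecognizer, OfflineStream::AcceptWaveform, OfflineRecognizer::DecodeStream, OfflineStream::GetResult, and the ReadWave helper from wave-reader.h); exact signatures may differ slightly from the current headers, and the model paths match the release archive referenced above.

// Minimal C++ usage sketch for the FireRedAsr AED model added in this patch.
// Assumption: the existing sherpa-onnx C++ offline API is used as-is; only
// the fire_red_asr config fields come from this patch.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.fire_red_asr.encoder =
      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx";
  config.model_config.fire_red_asr.decoder =
      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx";
  config.model_config.tokens =
      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt";
  config.decoding_method = "greedy_search";  // the only method supported here

  sherpa_onnx::OfflineRecognizer recognizer(config);

  // Assumed ReadWave signature: returns float samples in [-1, 1] and fills
  // the actual sampling rate of the file.
  int32_t sampling_rate = -1;
  bool is_ok = false;
  std::vector<float> samples = sherpa_onnx::ReadWave(
      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav",
      &sampling_rate, &is_ok);
  if (!is_ok) {
    return -1;
  }

  auto stream = recognizer.CreateStream();
  stream->AcceptWaveform(sampling_rate, samples.data(),
                         static_cast<int32_t>(samples.size()));
  recognizer.DecodeStream(stream.get());

  // result.text holds the recognized transcript
  auto result = stream->GetResult();
  std::cout << result.text << "\n";

  return 0;
}

The equivalent command-line invocation uses the flags registered in offline-fire-red-asr-model-config.cc (--fire-red-asr-encoder and --fire-red-asr-decoder) together with the existing --tokens flag of the offline recognizer.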