diff --git a/.github/scripts/test-online-ctc.sh b/.github/scripts/test-online-ctc.sh
index 7c631dd05..c28d2b3cf 100755
--- a/.github/scripts/test-online-ctc.sh
+++ b/.github/scripts/test-online-ctc.sh
@@ -13,6 +13,28 @@ echo "PATH: $PATH"
 
 which $EXE
 
+log "------------------------------------------------------------"
+log "Run streaming NeMo CTC "
+log "------------------------------------------------------------"
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
+name=$(basename $url)
+repo=$(basename -s .tar.bz2 $name)
+
+curl -SL -O $url
+tar xvf $name
+rm $name
+ls -lh $repo
+
+$EXE \
+  --nemo-ctc-model=$repo/model.onnx \
+  --tokens=$repo/tokens.txt \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/8k.wav
+
+rm -rf $repo
+
 log "------------------------------------------------------------"
 log "Run streaming Zipformer2 CTC HLG decoding "
 log "------------------------------------------------------------"
diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh
index fe0f568f0..a39d0c6bb 100755
--- a/.github/scripts/test-python.sh
+++ b/.github/scripts/test-python.sh
@@ -8,6 +8,19 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "test online NeMo CTC"
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
+name=$(basename $url)
+repo=$(basename -s .tar.bz2 $name)
+
+curl -SL -O $url
+tar xvf $name
+rm $name
+ls -lh $repo
+python3 ./python-api-examples/online-nemo-ctc-decode-files.py
+rm -rf $repo
+
 log "test offline punctuation"
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml
index 59ef986b1..154d6d774 100644
--- a/.github/workflows/linux.yaml
+++ b/.github/workflows/linux.yaml
@@ -128,6 +128,14 @@ jobs:
           name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
           path: install/*
 
+      - name: Test online CTC
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-ctc.sh
+
       - name: Test offline transducer
         shell: bash
         run: |
@@ -163,14 +171,6 @@ jobs:
 
           .github/scripts/test-offline-ctc.sh
 
-      - name: Test online CTC
-        shell: bash
-        run: |
-          export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx
-
-          .github/scripts/test-online-ctc.sh
-
       - name: Test offline punctuation
         shell: bash
         run: |
diff --git a/python-api-examples/online-nemo-ctc-decode-files.py b/python-api-examples/online-nemo-ctc-decode-files.py
new file mode 100755
index 000000000..757537733
--- /dev/null
+++ b/python-api-examples/online-nemo-ctc-decode-files.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a streaming CTC model from NeMo
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+The example model is converted from
+https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms
+"""
+
+from pathlib import Path
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx"
+    tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt"
+
+    test_wav = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav"
+
+    if not Path(model).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
+            model=model,
+            tokens=tokens,
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+
+    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
+    stream.accept_waveform(sample_rate, tail_paddings)
+    stream.input_finished()
+
+    while recognizer.is_ready(stream):
+        recognizer.decode_stream(stream)
+
+    print(wave_filename)
+    print(recognizer.get_result_all(stream))
+
+
+if __name__ == "__main__":
+    main()
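Annotation (not part of the patch): the example above pushes the whole file into the stream at once. A chunk-by-chunk variant is closer to a real streaming setup. The sketch below assumes the same model directory and a mono wave file; the 100 ms chunk size is arbitrary.

```python
import numpy as np
import sherpa_onnx
import soundfile as sf

repo = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms"
recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
    model=f"{repo}/model.onnx",
    tokens=f"{repo}/tokens.txt",
)

audio, sample_rate = sf.read(f"{repo}/test_wavs/0.wav", dtype="float32")
stream = recognizer.create_stream()

chunk = int(0.1 * sample_rate)  # feed 100 ms of audio at a time
for start in range(0, len(audio), chunk):
    stream.accept_waveform(sample_rate, audio[start : start + chunk])
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    print(recognizer.get_result(stream))  # partial result so far

# flush the tail, as in the example above
tail = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
stream.accept_waveform(sample_rate, tail)
stream.input_finished()
while recognizer.is_ready(stream):
    recognizer.decode_stream(stream)
print(recognizer.get_result(stream))
```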
diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py
index 1ed6c4a61..727ba40c1 100755
--- a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py
+++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py
@@ -100,7 +100,7 @@ def init_cache_state(self):
             dtype=torch.float32,
         ).numpy()
 
-        self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
+        self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()
 
     def __call__(self, x: np.ndarray):
         # x: (T, C)
diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py
index c671851f5..fb114f171 100755
--- a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py
+++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py
@@ -142,7 +142,7 @@ def init_cache_state(self):
             dtype=torch.float32,
         ).numpy()
 
-        self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
+        self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()
 
     def run_encoder(self, x: np.ndarray):
         # x: (T, C)
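Annotation: both conversion-test scripts fix the same bug. At start-up the cache tensors are all zeros, i.e. no frames have been cached yet, so the cached length must also start at 0; initializing it to 1 claimed a cached frame that did not exist. This matches InitStates() in online-nemo-ctc-model.cc later in this patch, which also zeroes the length. A one-line numpy illustration of the same invariant:

```python
import numpy as np

# The caches start out empty, so the cached-frame count must agree and be 0.
cache_last_channel_len = np.zeros([1], dtype=np.int64)  # was: ones([1]) -- a bug
```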
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index fc5d240ce..fc32e5a4f 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -61,6 +61,8 @@ set(sources
   online-lm.cc
   online-lstm-transducer-model.cc
   online-model-config.cc
+  online-nemo-ctc-model-config.cc
+  online-nemo-ctc-model.cc
   online-paraformer-model-config.cc
   online-paraformer-model.cc
   online-recognizer-impl.cc
diff --git a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
index 48933dd87..b2b8704a3 100644
--- a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
+++ b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
@@ -4,11 +4,12 @@
 #ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
 #define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
 
+#include <math.h>
+
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
-#include <cmath>
 
 #if __ANDROID_API__ >= 9
 #include "android/asset_manager.h"
@@ -61,7 +62,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
     int32_t segment_size = 20;
     int32_t max_len = 200;
 
-    int32_t num_segments = ceil(((float)token_ids.size() + segment_size - 1) / segment_size);
+    int32_t num_segments =
+        ceil((static_cast<float>(token_ids.size()) + segment_size - 1) /
+             segment_size);
 
     std::vector<int32_t> punctuations;
     int32_t last = -1;
diff --git a/sherpa-onnx/csrc/online-ctc-model.cc b/sherpa-onnx/csrc/online-ctc-model.cc
index 5fa76c192..a3a071a72 100644
--- a/sherpa-onnx/csrc/online-ctc-model.cc
+++ b/sherpa-onnx/csrc/online-ctc-model.cc
@@ -10,6 +10,7 @@
 #include <memory>
 
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
 #include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -22,6 +23,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
     return std::make_unique<OnlineWenetCtcModel>(config);
   } else if (!config.zipformer2_ctc.model.empty()) {
     return std::make_unique<OnlineZipformer2CtcModel>(config);
+  } else if (!config.nemo_ctc.model.empty()) {
+    return std::make_unique<OnlineNeMoCtcModel>(config);
   } else {
     SHERPA_ONNX_LOGE("Please specify a CTC model");
     exit(-1);
@@ -36,6 +39,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
     return std::make_unique<OnlineWenetCtcModel>(mgr, config);
   } else if (!config.zipformer2_ctc.model.empty()) {
     return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
+  } else if (!config.nemo_ctc.model.empty()) {
+    return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
   } else {
     SHERPA_ONNX_LOGE("Please specify a CTC model");
     exit(-1);
diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc
index d2da161e9..bc4d55dcf 100644
--- a/sherpa-onnx/csrc/online-model-config.cc
+++ b/sherpa-onnx/csrc/online-model-config.cc
@@ -15,6 +15,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   paraformer.Register(po);
   wenet_ctc.Register(po);
   zipformer2_ctc.Register(po);
+  nemo_ctc.Register(po);
 
   po->Register("tokens", &tokens, "Path to tokens.txt");
 
@@ -31,11 +32,11 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("provider", &provider,
                "Specify a provider to use: cpu, cuda, coreml");
 
-  po->Register(
-      "model-type", &model_type,
-      "Specify it to reduce model initialization time. "
-      "Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
-      "All other values lead to loading the model twice.");
+  po->Register("model-type", &model_type,
+               "Specify it to reduce model initialization time. "
+               "Valid values are: conformer, lstm, zipformer, zipformer2, "
+               "wenet_ctc, nemo_ctc. "
+               "All other values lead to loading the model twice.");
 }
 
 bool OnlineModelConfig::Validate() const {
@@ -61,6 +62,10 @@ bool OnlineModelConfig::Validate() const {
     return zipformer2_ctc.Validate();
   }
 
+  if (!nemo_ctc.model.empty()) {
+    return nemo_ctc.Validate();
+  }
+
   return transducer.Validate();
 }
 
@@ -72,6 +77,7 @@ std::string OnlineModelConfig::ToString() const {
   os << "paraformer=" << paraformer.ToString() << ", ";
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
+  os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "warm_up=" << warm_up << ", ";
diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h
index 3857ee426..08acf773f 100644
--- a/sherpa-onnx/csrc/online-model-config.h
+++ b/sherpa-onnx/csrc/online-model-config.h
@@ -6,6 +6,7 @@
 
 #include <string>
 
+#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
@@ -18,6 +19,7 @@ struct OnlineModelConfig {
   OnlineParaformerModelConfig paraformer;
   OnlineWenetCtcModelConfig wenet_ctc;
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
+  OnlineNeMoCtcModelConfig nemo_ctc;
   std::string tokens;
   int32_t num_threads = 1;
   int32_t warm_up = 0;
@@ -30,6 +32,7 @@ struct OnlineModelConfig {
   // - zipformer, zipformer transducer from icefall
   // - zipformer2, zipformer2 transducer or CTC from icefall
   // - wenet_ctc, wenet CTC model
+  // - nemo_ctc, NeMo CTC model
   //
   // All other values are invalid and lead to loading the model twice.
   std::string model_type;
@@ -39,6 +42,7 @@ struct OnlineModelConfig {
                     const OnlineParaformerModelConfig &paraformer,
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
+                    const OnlineNeMoCtcModelConfig &nemo_ctc,
                     const std::string &tokens, int32_t num_threads,
                     int32_t warm_up, bool debug, const std::string &provider,
                     const std::string &model_type)
@@ -46,6 +50,7 @@ struct OnlineModelConfig {
         paraformer(paraformer),
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
+        nemo_ctc(nemo_ctc),
         tokens(tokens),
         num_threads(num_threads),
         warm_up(warm_up),
diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model-config.cc b/sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
new file mode 100644
index 000000000..c3c22b971
--- /dev/null
+++ b/sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
@@ -0,0 +1,36 @@
+// sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineNeMoCtcModelConfig::Register(ParseOptions *po) {
+  po->Register("nemo-ctc-model", &model,
+               "Path to CTC model.onnx from NeMo. Please see "
+               "https://github.com/k2-fsa/sherpa-onnx/pull/843");
+}
+
+bool OnlineNeMoCtcModelConfig::Validate() const {
+  if (!FileExists(model)) {
+    SHERPA_ONNX_LOGE("NeMo CTC model '%s' does not exist", model.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OnlineNeMoCtcModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineNeMoCtcModelConfig(";
+  os << "model=\"" << model << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model-config.h b/sherpa-onnx/csrc/online-nemo-ctc-model-config.h
new file mode 100644
index 000000000..4fb1de0ad
--- /dev/null
+++ b/sherpa-onnx/csrc/online-nemo-ctc-model-config.h
@@ -0,0 +1,28 @@
+// sherpa-onnx/csrc/online-nemo-ctc-model-config.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OnlineNeMoCtcModelConfig {
+  std::string model;
+
+  OnlineNeMoCtcModelConfig() = default;
+
+  explicit OnlineNeMoCtcModelConfig(const std::string &model) : model(model) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
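Annotation: these two files only declare and register the new config struct. It is also exposed to Python by the pybind code further down in this patch; a hedged sketch of driving it directly is below (paths are illustrative, and most users should prefer OnlineRecognizer.from_nemo_ctc(), added near the end of the patch).

```python
import sherpa_onnx

# Low-level construction of the new config objects.
nemo_ctc = sherpa_onnx.OnlineNeMoCtcModelConfig(model="./model.onnx")
model_config = sherpa_onnx.OnlineModelConfig(
    nemo_ctc=nemo_ctc,
    tokens="./tokens.txt",
    num_threads=2,
)
print(nemo_ctc)  # __str__ is bound to ToString() by the pybind code below
```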
diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.cc b/sherpa-onnx/csrc/online-nemo-ctc-model.cc
new file mode 100644
index 000000000..3f796e2d7
--- /dev/null
+++ b/sherpa-onnx/csrc/online-nemo-ctc-model.cc
@@ -0,0 +1,324 @@
+// sherpa-onnx/csrc/online-nemo-ctc-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+class OnlineNeMoCtcModel::Impl {
+ public:
+  explicit Impl(const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.nemo_ctc.model);
+      Init(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.nemo_ctc.model);
+      Init(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> Forward(Ort::Value x,
+                                  std::vector<Ort::Value> states) {
+    Ort::Value &cache_last_channel = states[0];
+    Ort::Value &cache_last_time = states[1];
+    Ort::Value &cache_last_channel_len = states[2];
+
+    int32_t batch_size = x.GetTensorTypeAndShapeInfo().GetShape()[0];
+
+    std::array<int64_t, 1> length_shape{batch_size};
+
+    Ort::Value length = Ort::Value::CreateTensor<int64_t>(
+        allocator_, length_shape.data(), length_shape.size());
+
+    int64_t *p_length = length.GetTensorMutableData<int64_t>();
+
+    std::fill(p_length, p_length + batch_size, ChunkLength());
+
+    // (B, T, C) -> (B, C, T)
+    x = Transpose12(allocator_, &x);
+
+    std::array<Ort::Value, 5> inputs = {
+        std::move(x), View(&length), std::move(cache_last_channel),
+        std::move(cache_last_time), std::move(cache_last_channel_len)};
+
+    auto out =
+        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                   output_names_ptr_.data(), output_names_ptr_.size());
+    // out[0]: logit
+    // out[1]: logit_length
+    // out[2:]: states_next
+    //
+    // we need to remove out[1]
+
+    std::vector<Ort::Value> ans;
+    ans.reserve(out.size() - 1);
+
+    for (int32_t i = 0; i != out.size(); ++i) {
+      if (i == 1) {
+        continue;
+      }
+
+      ans.push_back(std::move(out[i]));
+    }
+
+    return ans;
+  }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  int32_t ChunkLength() const { return window_size_; }
+
+  int32_t ChunkShift() const { return chunk_shift_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  // Return a vector containing 3 tensors
+  // - cache_last_channel
+  // - cache_last_time
+  // - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() {
+    std::vector<Ort::Value> ans;
+    ans.reserve(3);
+    ans.push_back(View(&cache_last_channel_));
+    ans.push_back(View(&cache_last_time_));
+    ans.push_back(View(&cache_last_channel_len_));
+
+    return ans;
+  }
+
+  std::vector<Ort::Value> StackStates(
+      std::vector<std::vector<Ort::Value>> states) const {
+    int32_t batch_size = static_cast<int32_t>(states.size());
+    if (batch_size == 1) {
+      return std::move(states[0]);
+    }
+
+    std::vector<Ort::Value> ans;
+
+    // buf is reused to stack each of the 3 states in turn
+    std::vector<const Ort::Value *> buf(batch_size);
+
+    // there are 3 states to be stacked
+    for (int32_t i = 0; i != 3; ++i) {
+      buf.clear();
+      buf.reserve(batch_size);
+
+      for (int32_t b = 0; b != batch_size; ++b) {
+        assert(states[b].size() == 3);
+        buf.push_back(&states[b][i]);
+      }
+
+      Ort::Value c{nullptr};
+      if (i == 2) {
+        // cache_last_channel_len is an int64 tensor
+        c = Cat<int64_t>(allocator_, buf, 0);
+      } else {
+        c = Cat(allocator_, buf, 0);
+      }
+
+      ans.push_back(std::move(c));
+    }
+
+    return ans;
+  }
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      std::vector<Ort::Value> states) const {
+    assert(states.size() == 3);
+
+    std::vector<std::vector<Ort::Value>> ans;
+
+    auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
+    int32_t batch_size = shape[0];
+    ans.resize(batch_size);
+
+    if (batch_size == 1) {
+      ans[0] = std::move(states);
+      return ans;
+    }
+
+    for (int32_t i = 0; i != 3; ++i) {
+      std::vector<Ort::Value> v;
+      if (i == 2) {
+        // cache_last_channel_len is an int64 tensor
+        v = Unbind<int64_t>(allocator_, &states[i], 0);
+      } else {
+        v = Unbind(allocator_, &states[i], 0);
+      }
+
+      assert(v.size() == batch_size);
+
+      for (int32_t b = 0; b != batch_size; ++b) {
+        ans[b].push_back(std::move(v[b]));
+      }
+    }
+
+    return ans;
+  }
+
+ private:
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data,
+                                           model_data_length, sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
+    SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
+    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
+                               "cache_last_channel_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
+                               "cache_last_channel_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
+                               "cache_last_channel_dim3");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");
+
+    // We need to increase vocab_size by 1 since the blank token is not
+    // counted when computing vocab_size in NeMo.
+    vocab_size_ += 1;
+
+    InitStates();
+  }
+
+  void InitStates() {
+    std::array<int64_t, 4> cache_last_channel_shape{
+        1, cache_last_channel_dim1_, cache_last_channel_dim2_,
+        cache_last_channel_dim3_};
+
+    cache_last_channel_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_channel_shape.data(),
+        cache_last_channel_shape.size());
+
+    Fill<float>(&cache_last_channel_, 0);
+
+    std::array<int64_t, 4> cache_last_time_shape{
+        1, cache_last_time_dim1_, cache_last_time_dim2_,
+        cache_last_time_dim3_};
+
+    cache_last_time_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_time_shape.data(), cache_last_time_shape.size());
+
+    Fill<float>(&cache_last_time_, 0);
+
+    int64_t shape = 1;
+    cache_last_channel_len_ =
+        Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
+
+    cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
+  }
+
+ private:
+  OnlineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  int32_t window_size_;
+  int32_t chunk_shift_;
+  int32_t subsampling_factor_;
+  int32_t vocab_size_;
+  int32_t cache_last_channel_dim1_;
+  int32_t cache_last_channel_dim2_;
+  int32_t cache_last_channel_dim3_;
+  int32_t cache_last_time_dim1_;
+  int32_t cache_last_time_dim2_;
+  int32_t cache_last_time_dim3_;
+
+  Ort::Value cache_last_channel_{nullptr};
+  Ort::Value cache_last_time_{nullptr};
+  Ort::Value cache_last_channel_len_{nullptr};
+};
+
+OnlineNeMoCtcModel::OnlineNeMoCtcModel(const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineNeMoCtcModel::OnlineNeMoCtcModel(AAssetManager *mgr,
+                                       const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineNeMoCtcModel::~OnlineNeMoCtcModel() = default;
+
+std::vector<Ort::Value> OnlineNeMoCtcModel::Forward(
+    Ort::Value x, std::vector<Ort::Value> states) const {
+  return impl_->Forward(std::move(x), std::move(states));
+}
+
+int32_t OnlineNeMoCtcModel::VocabSize() const { return impl_->VocabSize(); }
+
+int32_t OnlineNeMoCtcModel::ChunkLength() const { return impl_->ChunkLength(); }
+
+int32_t OnlineNeMoCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
+
+OrtAllocator *OnlineNeMoCtcModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+std::vector<Ort::Value> OnlineNeMoCtcModel::GetInitStates() const {
+  return impl_->GetInitStates();
+}
+
+std::vector<Ort::Value> OnlineNeMoCtcModel::StackStates(
+    std::vector<std::vector<Ort::Value>> states) const {
+  return impl_->StackStates(std::move(states));
+}
+
+std::vector<std::vector<Ort::Value>> OnlineNeMoCtcModel::UnStackStates(
+    std::vector<Ort::Value> states) const {
+  return impl_->UnStackStates(std::move(states));
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.h b/sherpa-onnx/csrc/online-nemo-ctc-model.h
new file mode 100644
index 000000000..c8dd182e8
--- /dev/null
+++ b/sherpa-onnx/csrc/online-nemo-ctc-model.h
@@ -0,0 +1,81 @@
+// sherpa-onnx/csrc/online-nemo-ctc-model.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-ctc-model.h"
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+namespace sherpa_onnx {
+
+class OnlineNeMoCtcModel : public OnlineCtcModel {
+ public:
+  explicit OnlineNeMoCtcModel(const OnlineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineNeMoCtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
+#endif
+
+  ~OnlineNeMoCtcModel() override;
+
+  // A list of 3 tensors:
+  //  - cache_last_channel
+  //  - cache_last_time
+  //  - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() const override;
+
+  std::vector<Ort::Value> StackStates(
+      std::vector<std::vector<Ort::Value>> states) const override;
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      std::vector<Ort::Value> states) const override;
+
+  /**
+   *
+   * @param x A 3-D tensor of shape (N, T, C).
+   * @param states It is from GetInitStates() or returned from this method.
+   *
+   * @return Return a list of tensors
+   *    - ans[0] contains log_probs, of shape (N, T, C)
+   *    - ans[1:] contains next_states
+   */
+  std::vector<Ort::Value> Forward(
+      Ort::Value x, std::vector<Ort::Value> states) const override;
+
+  /** Return the vocabulary size of the model
+   */
+  int32_t VocabSize() const override;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const override;
+
+  // The model accepts this number of frames before subsampling as input
+  int32_t ChunkLength() const override;
+
+  // Similar to frame_shift in feature extractor, after processing
+  // ChunkLength() frames, we advance by ChunkShift() frames
+  // before we process the next chunk.
+  int32_t ChunkShift() const override;
+
+  bool SupportBatchProcessing() const override { return true; }
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc
index c5923c608..56da814f7 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -21,7 +21,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
   }
 
   if (!config.model_config.wenet_ctc.model.empty() ||
-      !config.model_config.zipformer2_ctc.model.empty()) {
+      !config.model_config.zipformer2_ctc.model.empty() ||
+      !config.model_config.nemo_ctc.model.empty()) {
     return std::make_unique<OnlineRecognizerCtcImpl>(config);
   }
 
@@ -41,7 +42,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
   }
 
   if (!config.model_config.wenet_ctc.model.empty() ||
-      !config.model_config.zipformer2_ctc.model.empty()) {
+      !config.model_config.zipformer2_ctc.model.empty() ||
+      !config.model_config.nemo_ctc.model.empty()) {
     return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
   }
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index e4bff01c6..5c6fd4708 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -23,6 +23,7 @@ set(srcs
   online-ctc-fst-decoder-config.cc
   online-lm-config.cc
   online-model-config.cc
+  online-nemo-ctc-model-config.cc
  online-paraformer-model-config.cc
   online-recognizer.cc
   online-stream.cc
diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index 473be930f..7da0089cb 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -9,6 +9,7 @@
 
 #include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
@@ -21,26 +22,30 @@ void PybindOnlineModelConfig(py::module *m) {
   PybindOnlineParaformerModelConfig(m);
   PybindOnlineWenetCtcModelConfig(m);
   PybindOnlineZipformer2CtcModelConfig(m);
+  PybindOnlineNeMoCtcModelConfig(m);
 
   using PyClass = OnlineModelConfig;
   py::class_<PyClass>(*m, "OnlineModelConfig")
-      .def(py::init<const OnlineTransducerModelConfig &,
-                    const OnlineParaformerModelConfig &,
-                    const OnlineWenetCtcModelConfig &,
-                    const OnlineZipformer2CtcModelConfig &,
-                    const std::string &, int32_t, int32_t, bool,
-                    const std::string &, const std::string &>(),
+      .def(py::init<const OnlineTransducerModelConfig &,
+                    const OnlineParaformerModelConfig &,
+                    const OnlineWenetCtcModelConfig &,
+                    const OnlineZipformer2CtcModelConfig &,
+                    const OnlineNeMoCtcModelConfig &, const std::string &,
+                    int32_t, int32_t, bool, const std::string &,
+                    const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
+          py::arg("num_threads"), py::arg("warm_up") = 0,
          py::arg("debug") = false, py::arg("provider") = "cpu",
           py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
       .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
+      .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
diff --git a/sherpa-onnx/python/csrc/online-nemo-ctc-model-config.cc b/sherpa-onnx/python/csrc/online-nemo-ctc-model-config.cc
new file mode 100644
index 000000000..a61180456
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-nemo-ctc-model-config.cc
@@ -0,0 +1,22 @@
+// sherpa-onnx/python/csrc/online-nemo-ctc-model-config.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineNeMoCtcModelConfig(py::module *m) {
+  using PyClass = OnlineNeMoCtcModelConfig;
+  py::class_<PyClass>(*m, "OnlineNeMoCtcModelConfig")
+      .def(py::init<const std::string &>(), py::arg("model"))
+      .def_readwrite("model", &PyClass::model)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h b/sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h
new file mode 100644
index 000000000..b8fcbc3c4
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineNeMoCtcModelConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc
index 79f154699..c402163fe 100644
--- a/sherpa-onnx/python/csrc/online-recognizer.cc
+++ b/sherpa-onnx/python/csrc/online-recognizer.cc
@@ -42,6 +42,8 @@ static void PybindOnlineRecognizerResult(py::module *m) {
           "segment", [](PyClass &self) -> int32_t { return self.segment; })
       .def_property_readonly(
           "is_final", [](PyClass &self) -> bool { return self.is_final; })
+      .def("__str__", &PyClass::AsJsonString,
+           py::call_guard<py::gil_scoped_release>())
       .def("as_json_string", &PyClass::AsJsonString,
            py::call_guard<py::gil_scoped_release>());
 }
 
@@ -50,29 +52,17 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   using PyClass = OnlineRecognizerConfig;
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
       .def(
-          py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
-                   const OnlineLMConfig &, const EndpointConfig &,
-                   const OnlineCtcFstDecoderConfig &, bool,
-                   const std::string &, int32_t, const std::string &,
-                   float, float, float>(),
-          py::arg("feat_config"),
-          py::arg("model_config"),
+          py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
+                   const OnlineLMConfig &, const EndpointConfig &,
+                   const OnlineCtcFstDecoderConfig &, bool, const std::string &,
+                   int32_t, const std::string &, float, float, float>(),
+          py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OnlineLMConfig(),
           py::arg("endpoint_config") = EndpointConfig(),
           py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
-          py::arg("enable_endpoint"),
-          py::arg("decoding_method"),
-          py::arg("max_active_paths") = 4,
-          py::arg("hotwords_file") = "",
-          py::arg("hotwords_score") = 0,
-          py::arg("blank_penalty") = 0.0,
+          py::arg("enable_endpoint"), py::arg("decoding_method"),
+          py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
+          py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
           py::arg("temperature_scale") = 2.0)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
index 520000028..36fb66826 100644
--- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -12,9 +12,11 @@ from _sherpa_onnx import OnlineRecognizer as _Recognizer
 
 from _sherpa_onnx import (
     OnlineRecognizerConfig,
+    OnlineRecognizerResult,
     OnlineStream,
     OnlineTransducerModelConfig,
     OnlineWenetCtcModelConfig,
+    OnlineNeMoCtcModelConfig,
     OnlineZipformer2CtcModelConfig,
     OnlineCtcFstDecoderConfig,
 )
@@ -59,6 +61,7 @@ def from_transducer(
         lm: str = "",
         lm_scale: float = 0.1,
         temperature_scale: float = 2.0,
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -154,6 +157,7 @@ def from_transducer(
             num_threads=num_threads,
             provider=provider,
             model_type=model_type,
+            debug=debug,
         )
 
         feat_config = FeatureExtractorConfig(
@@ -220,6 +224,7 @@ def from_paraformer(
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         provider: str = "cpu",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -283,6 +288,7 @@ def from_paraformer(
             num_threads=num_threads,
             provider=provider,
             model_type="paraformer",
+            debug=debug,
         )
 
         feat_config = FeatureExtractorConfig(
@@ -324,6 +330,7 @@ def from_zipformer2_ctc(
         ctc_graph: str = "",
         ctc_max_active: int = 3000,
         provider: str = "cpu",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -386,6 +393,7 @@ def from_zipformer2_ctc(
             tokens=tokens,
             num_threads=num_threads,
             provider=provider,
+            debug=debug,
         )
 
         feat_config = FeatureExtractorConfig(
@@ -417,6 +425,106 @@ def from_zipformer2_ctc(
         self.config = recognizer_config
         return self
 
+    @classmethod
+    def from_nemo_ctc(
+        cls,
+        tokens: str,
+        model: str,
+        num_threads: int = 2,
+        sample_rate: float = 16000,
+        feature_dim: int = 80,
+        enable_endpoint_detection: bool = False,
+        rule1_min_trailing_silence: float = 2.4,
+        rule2_min_trailing_silence: float = 1.2,
+        rule3_min_utterance_length: float = 20.0,
+        decoding_method: str = "greedy_search",
+        provider: str = "cpu",
+        debug: bool = False,
+    ):
+        """
+        Please refer to
+        `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
+        to download pre-trained models.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          model:
+            Path to ``model.onnx``.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          enable_endpoint_detection:
+            True to enable endpoint detection. False to disable endpoint
+            detection.
+          rule1_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If the duration
+            of trailing silence in seconds is larger than this value, we assume
+            an endpoint is detected.
+          rule2_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If we have decoded
+            something that is nonsilence and if the duration of trailing silence
+            in seconds is larger than this value, we assume an endpoint is
+            detected.
+          rule3_min_utterance_length:
+            Used only when enable_endpoint_detection is True. If the utterance
+            length in seconds is larger than this value, we assume an endpoint
+            is detected.
+          decoding_method:
+            The only valid value is greedy_search.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          debug:
+            True to show meta data in the model.
+        """
+        self = cls.__new__(cls)
+        _assert_file_exists(tokens)
+        _assert_file_exists(model)
+
+        assert num_threads > 0, num_threads
+
+        nemo_ctc_config = OnlineNeMoCtcModelConfig(
+            model=model,
+        )
+
+        model_config = OnlineModelConfig(
+            nemo_ctc=nemo_ctc_config,
+            tokens=tokens,
+            num_threads=num_threads,
+            provider=provider,
+            debug=debug,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        endpoint_config = EndpointConfig(
+            rule1_min_trailing_silence=rule1_min_trailing_silence,
+            rule2_min_trailing_silence=rule2_min_trailing_silence,
+            rule3_min_utterance_length=rule3_min_utterance_length,
+        )
+
+        recognizer_config = OnlineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            endpoint_config=endpoint_config,
+            enable_endpoint=enable_endpoint_detection,
+            decoding_method=decoding_method,
+        )
+
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     @classmethod
     def from_wenet_ctc(
         cls,
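Annotation: from_nemo_ctc() forwards the rule* thresholds straight into EndpointConfig. A hedged usage sketch with endpointing enabled (paths are illustrative; the threshold values just restate the defaults):

```python
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
    model="./model.onnx",
    tokens="./tokens.txt",
    enable_endpoint_detection=True,
    rule1_min_trailing_silence=2.4,   # endpoint after 2.4 s of pure silence
    rule2_min_trailing_silence=1.2,   # 1.2 s of silence after some speech
    rule3_min_utterance_length=20.0,  # or after 20 s of audio in total
)
```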
@@ -433,6 +541,7 @@ def from_wenet_ctc(
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         provider: str = "cpu",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -497,6 +606,7 @@ def from_wenet_ctc(
             tokens=tokens,
             num_threads=num_threads,
             provider=provider,
+            debug=debug,
         )
 
         feat_config = FeatureExtractorConfig(
@@ -537,6 +647,9 @@ def decode_streams(self, ss: List[OnlineStream]):
     def is_ready(self, s: OnlineStream) -> bool:
         return self.recognizer.is_ready(s)
 
+    def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult:
+        return self.recognizer.get_result(s)
+
     def get_result(self, s: OnlineStream) -> str:
         return self.recognizer.get_result(s).text.strip()
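Annotation: two small API notes, illustrated below under the same assumptions as the earlier examples. get_result() returns only the stripped text, while the new get_result_all() returns the underlying OnlineRecognizerResult, whose __str__ now prints JSON via AsJsonString(). And since OnlineNeMoCtcModel reports SupportBatchProcessing() = true, several streams can be decoded together with decode_streams(), which uses StackStates()/UnStackStates() internally.

```python
import sherpa_onnx
import soundfile as sf

repo = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms"
recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
    model=f"{repo}/model.onnx", tokens=f"{repo}/tokens.txt"
)

streams = []
for wav in (f"{repo}/test_wavs/0.wav", f"{repo}/test_wavs/1.wav"):
    audio, sample_rate = sf.read(wav, dtype="float32")
    s = recognizer.create_stream()
    s.accept_waveform(sample_rate, audio)
    s.input_finished()
    streams.append(s)

while True:
    ready = [s for s in streams if recognizer.is_ready(s)]
    if not ready:
        break
    recognizer.decode_streams(ready)  # batched decoding

for s in streams:
    print(recognizer.get_result(s))      # stripped text only
    print(recognizer.get_result_all(s))  # full result, printed as JSON
```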