Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ support for streaming NeMo CTC models. #857

Merged
merged 5 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .github/scripts/test-online-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,28 @@ echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run streaming NeMo CTC "
log "------------------------------------------------------------"

# Fetch the streaming NeMo CTC release archive, unpack it, decode the bundled
# test waves with $EXE, and remove the unpacked model directory afterwards.
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
# Archive file name: everything after the last '/' of the URL.
name="${url##*/}"
# Unpacked directory name: the archive name without its .tar.bz2 suffix.
repo="${name%.tar.bz2}"

curl -SL -O "$url"
tar xvf "$name"
rm "$name"
ls -lh "$repo"

# NOTE(review): $EXE is left unquoted exactly as before, in case callers put
# more than one word in it — confirm against the workflows that export EXE.
$EXE \
  --nemo-ctc-model="$repo/model.onnx" \
  --tokens="$repo/tokens.txt" \
  "$repo/test_wavs/0.wav" \
  "$repo/test_wavs/1.wav" \
  "$repo/test_wavs/8k.wav"

rm -rf "$repo"

log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC HLG decoding "
log "------------------------------------------------------------"
Expand Down
13 changes: 13 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test online NeMo CTC"

# Fetch and unpack the streaming NeMo CTC model, run the Python example that
# decodes with it, then remove the unpacked model directory.
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
# Archive file name: everything after the last '/' of the URL.
name="${url##*/}"
# Unpacked directory name: the archive name without its .tar.bz2 suffix.
repo="${name%.tar.bz2}"

curl -SL -O "$url"
tar xvf "$name"
rm "$name"
ls -lh "$repo"
python3 ./python-api-examples/online-nemo-ctc-decode-files.py
rm -rf "$repo"

log "test offline punctuation"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
Expand Down
16 changes: 8 additions & 8 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx

.github/scripts/test-online-ctc.sh

- name: Test offline transducer
shell: bash
run: |
Expand Down Expand Up @@ -163,14 +171,6 @@ jobs:

.github/scripts/test-offline-ctc.sh

- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx

.github/scripts/test-online-ctc.sh

- name: Test offline punctuation
shell: bash
run: |
Expand Down
67 changes: 67 additions & 0 deletions python-api-examples/online-nemo-ctc-decode-files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

"""
This file shows how to use a streaming CTC model from NeMo
to decode files.

Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models


The example model is converted from
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms
"""

from pathlib import Path

import numpy as np
import sherpa_onnx
import soundfile as sf


def create_recognizer():
    """Create a streaming NeMo CTC recognizer for the example model.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is built with
      ``sherpa_onnx.OnlineRecognizer.from_nemo_ctc`` and ``test_wav`` is the
      path of the test wave file shipped with the model.

    Raises:
      ValueError: If the model file or the test wave file is not present.
    """
    model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx"
    tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt"

    test_wav = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav"

    # Both the model and the test wave must exist before a recognizer is built.
    have_all_files = Path(model).is_file() and Path(test_wav).is_file()
    if not have_all_files:
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
        model=model,
        tokens=tokens,
        debug=True,
    )
    return recognizer, test_wav


def main():
    """Decode the example wave file with the streaming NeMo CTC recognizer.

    Feeds the whole file plus a short run of trailing silence into a single
    stream, decodes until the recognizer is no longer ready, and prints the
    file name followed by the full result.
    """
    recognizer, wave_filename = create_recognizer()

    # always_2d=True yields a (num_samples, num_channels) array; keep channel 0.
    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    samples = samples[:, 0]

    # samples is a 1-D float32 numpy array normalized to the range [-1, 1]
    # sample_rate does not need to be 16000 Hz

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)

    # Append 0.66 seconds of zeros after the real audio before finishing input.
    num_tail_samples = int(0.66 * sample_rate)
    stream.accept_waveform(
        sample_rate, np.zeros(num_tail_samples, dtype=np.float32)
    )
    stream.input_finished()

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    print(wave_filename)
    print(recognizer.get_result_all(stream))


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def init_cache_state(self):
dtype=torch.float32,
).numpy()

self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()

def __call__(self, x: np.ndarray):
# x: (T, C)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def init_cache_state(self):
dtype=torch.float32,
).numpy()

self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()

def run_encoder(self, x: np.ndarray):
# x: (T, C)
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ set(sources
online-lm.cc
online-lstm-transducer-model.cc
online-model-config.cc
online-nemo-ctc-model-config.cc
online-nemo-ctc-model.cc
online-paraformer-model-config.cc
online-paraformer-model.cc
online-recognizer-impl.cc
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_

#include <math.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <math.h>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
Expand Down Expand Up @@ -61,7 +62,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {

int32_t segment_size = 20;
int32_t max_len = 200;
int32_t num_segments = ceil(((float)token_ids.size() + segment_size - 1) / segment_size);
int32_t num_segments =
ceil((static_cast<float>(token_ids.size()) + segment_size - 1) /
segment_size);

std::vector<int32_t> punctuations;
int32_t last = -1;
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/online-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <string>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
Expand All @@ -22,6 +23,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineWenetCtcModel>(config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
Expand All @@ -36,6 +39,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
Expand Down
16 changes: 11 additions & 5 deletions sherpa-onnx/csrc/online-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
paraformer.Register(po);
wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
nemo_ctc.Register(po);

po->Register("tokens", &tokens, "Path to tokens.txt");

Expand All @@ -31,11 +32,11 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");

po->Register(
"model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
"All other values lead to loading the model twice.");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, "
"wenet_ctc, nemo_ctc. "
"All other values lead to loading the model twice.");
}

bool OnlineModelConfig::Validate() const {
Expand All @@ -61,6 +62,10 @@ bool OnlineModelConfig::Validate() const {
return zipformer2_ctc.Validate();
}

if (!nemo_ctc.model.empty()) {
return nemo_ctc.Validate();
}

return transducer.Validate();
}

Expand All @@ -72,6 +77,7 @@ std::string OnlineModelConfig::ToString() const {
os << "paraformer=" << paraformer.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "warm_up=" << warm_up << ", ";
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/online-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <string>

#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
Expand All @@ -18,6 +19,7 @@ struct OnlineModelConfig {
OnlineParaformerModelConfig paraformer;
OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNeMoCtcModelConfig nemo_ctc;
std::string tokens;
int32_t num_threads = 1;
int32_t warm_up = 0;
Expand All @@ -30,6 +32,7 @@ struct OnlineModelConfig {
// - zipformer, zipformer transducer from icefall
// - zipformer2, zipformer2 transducer or CTC from icefall
// - wenet_ctc, wenet CTC model
// - nemo_ctc, NeMo CTC model
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
Expand All @@ -39,13 +42,15 @@ struct OnlineModelConfig {
const OnlineParaformerModelConfig &paraformer,
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const OnlineNeMoCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider,
const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
nemo_ctc(nemo_ctc),
tokens(tokens),
num_threads(num_threads),
warm_up(warm_up),
Expand Down
36 changes: 36 additions & 0 deletions sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Registers the --nemo-ctc-model command-line option with `po`; the parsed
// value is stored in `model`.
void OnlineNeMoCtcModelConfig::Register(ParseOptions *po) {
  const std::string help =
      "Path to CTC model.onnx from NeMo. Please see "
      "https://github.com/k2-fsa/sherpa-onnx/pull/843";

  po->Register("nemo-ctc-model", &model, help);
}

// Checks that the configured model file exists.
//
// Returns true if `model` names an existing file; otherwise logs an error
// and returns false.
bool OnlineNeMoCtcModelConfig::Validate() const {
  if (FileExists(model)) {
    return true;
  }

  SHERPA_ONNX_LOGE("NeMo CTC model '%s' does not exist", model.c_str());
  return false;
}

// Returns a printable description of this config in the form
//   OnlineNeMoCtcModelConfig(model="<path>")
std::string OnlineNeMoCtcModelConfig::ToString() const {
  std::string repr = "OnlineNeMoCtcModelConfig(";
  repr += "model=\"";
  repr += model;
  repr += "\")";

  return repr;
}

} // namespace sherpa_onnx
28 changes: 28 additions & 0 deletions sherpa-onnx/csrc/online-nemo-ctc-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/online-nemo-ctc-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Configuration for an online NeMo CTC model.
struct OnlineNeMoCtcModelConfig {
  // Path to the model.onnx file exported from NeMo.
  std::string model;

  OnlineNeMoCtcModelConfig() = default;

  // Creates a config that refers to the model file at `model_path`.
  explicit OnlineNeMoCtcModelConfig(const std::string &model_path)
      : model(model_path) {}

  // Registers the command-line options of this config with `po`.
  void Register(ParseOptions *po);

  // Returns true if this config is valid, false otherwise.
  bool Validate() const;

  // Returns a human-readable description of this config.
  std::string ToString() const;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
Loading
Loading