Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ and Python API for FireRedASR AED models #1867

Merged
merged 4 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ lexicon.txt
us_gold.json
us_silver.json
kokoro-multi-lang-v1_0
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
75 changes: 75 additions & 0 deletions python-api-examples/offline-fire-red-asr-decode-files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

"""
This file shows how to use a non-streaming FireRedAsr AED model from
https://github.com/FireRedTeam/FireRedASR
to decode files.

Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models

For instance,

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
"""

from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    """Create a FireRedAsr recognizer and choose the wave file to decode.

    All paths are relative to the current working directory and point into
    the extracted ``sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16`` archive.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is the object
      returned by ``sherpa_onnx.OfflineRecognizer.from_fire_red_asr`` and
      ``test_wav`` is the path of the wave file to decode.

    Raises:
      ValueError: If the encoder, decoder, tokens, or test wave file does
        not exist.
    """
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"

    # Uncomment exactly one of the lines below to decode a different test file.
    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    # Check every file we are about to use, including tokens.txt, so that a
    # partially downloaded model directory produces the download hint below
    # instead of an error from inside the recognizer.
    required_files = (encoder, decoder, tokens, test_wav)
    if not all(Path(f).is_file() for f in required_files):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        debug=True,
    )
    return recognizer, test_wav


def main():
    """Decode the selected test wave file and print its recognition result."""
    recognizer, wave_filename = create_recognizer()

    # Read the file as float32 with an explicit channel axis, then keep only
    # the first channel so that a 1-D array is passed to the recognizer.
    # The sample rate is passed along as-is; it does not need to be 16000 Hz.
    samples, sample_rate = sf.read(
        wave_filename,
        dtype="float32",
        always_2d=True,
    )
    first_channel = samples[:, 0]

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, first_channel)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ set(sources
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-fire-red-asr-greedy-search-decoder.cc
offline-fire-red-asr-model-config.cc
offline-fire-red-asr-model.cc
offline-lm-config.cc
offline-lm.cc
offline-model-config.cc
Expand Down
39 changes: 39 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_

#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

/// Result produced by a FireRedAsr decoder; the decoder returns one of
/// these per entry of the input batch.
struct OfflineFireRedAsrDecoderResult {
  /// The decoded token IDs
  std::vector<int32_t> tokens;
};

/// Abstract interface for decoders that turn the output of the FireRedAsr
/// encoder into token IDs. The search strategy is defined by the subclass.
class OfflineFireRedAsrDecoder {
 public:
  virtual ~OfflineFireRedAsrDecoder() = default;

  /** Decode the output from the FireRedAsr encoder model.
   *
   * @param n_layer_cross_k       A 4-D tensor of shape
   *                              (num_decoder_layers, N, T, d_model).
   * @param n_layer_cross_v       A 4-D tensor of shape
   *                              (num_decoder_layers, N, T, d_model).
   *
   * @return Return a vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
87 changes: 87 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"

#include <algorithm>
#include <array>
#include <iterator>
#include <tuple>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

// Greedy-search decoding: at every step feed the previously selected token to
// the decoder and pick the token with the largest logit, until the
// end-of-sentence token is selected or `max_len` steps have been run.
//
// Note: this function works only for batch size == 1 at present.
//
// @param cross_k  Cross-attention keys computed from the encoder output.
// @param cross_v  Cross-attention values computed from the encoder output.
//
// @return A vector with a single entry holding the decoded token IDs. The
//         start-of-sentence and end-of-sentence tokens are not included.
std::vector<OfflineFireRedAsrDecoderResult>
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
                                             Ort::Value cross_v) {
  const auto &meta_data = model_->GetModelMetadata();

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // Decoding starts from the start-of-sentence token.
  //
  // `tokens` does not own its buffer: it wraps the local variable `token`.
  // Assigning to `token` at the end of each step therefore changes the
  // decoder input used by the next step.
  std::array<int64_t, 2> token_shape = {1, 1};
  int64_t token = meta_data.sos_id;

  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token, 1, token_shape.data(), token_shape.size());

  // Step counter handed to the decoder; it starts at 0 and is incremented
  // after every emitted token.
  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  std::vector<OfflineFireRedAsrDecoderResult> ans(1);

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  // Slot 0 holds the logits of the latest step; slots 1-5 hold the state that
  // is threaded through successive ForwardDecoder() calls.
  // NOTE(review): presumably ForwardDecoder() returns
  // (logits, self_k, self_v, cross_k, cross_v, offset) in this order - confirm
  // against OfflineFireRedAsrModel.
  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
      decoder_out = {Ort::Value{nullptr},
                     std::move(self_kv_cache.first),
                     std::move(self_kv_cache.second),
                     std::move(cross_k),
                     std::move(cross_v),
                     std::move(offset)};

  for (int32_t i = 0; i < meta_data.max_len; ++i) {
    decoder_out = model_->ForwardDecoder(View(&tokens),
                                         std::move(std::get<1>(decoder_out)),
                                         std::move(std::get<2>(decoder_out)),
                                         std::move(std::get<3>(decoder_out)),
                                         std::move(std::get<4>(decoder_out)),
                                         std::move(std::get<5>(decoder_out)));

    const auto &logits = std::get<0>(decoder_out);
    const float *p_logits = logits.GetTensorData<float>();

    // The vocabulary size is read from the third dimension of the logits.
    // NOTE(review): assumes logits has shape (1, 1, vocab_size) - confirm.
    auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
    int32_t vocab_size = static_cast<int32_t>(logits_shape[2]);

    int32_t max_token_id = static_cast<int32_t>(std::distance(
        p_logits, std::max_element(p_logits, p_logits + vocab_size)));
    if (max_token_id == meta_data.eos_id) {
      break;
    }

    ans[0].tokens.push_back(max_token_id);

    // Becomes the decoder input of the next step (see `tokens` above).
    token = max_token_id;

    // increment offset
    *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;
  }

  return ans;
}

} // namespace sherpa_onnx
29 changes: 29 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"

namespace sherpa_onnx {

/// Greedy-search implementation of OfflineFireRedAsrDecoder.
class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
 public:
  /// @param model  The model used for decoding. The pointer is stored but
  ///               not owned.
  ///               NOTE(review): presumably it must outlive this decoder -
  ///               confirm at the call site.
  explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model)
      : model_(model) {}

  /// Decode the given cross-attention key/value tensors; see
  /// OfflineFireRedAsrDecoder::Decode() for the contract.
  std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value cross_k, Ort::Value cross_v) override;

 private:
  OfflineFireRedAsrModel *model_;  // not owned
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
56 changes: 56 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Register the FireRedAsr options with the given option parser.
//
// Two options are registered, bound to the `encoder` and `decoder` members:
//   --fire-red-asr-encoder
//   --fire-red-asr-decoder
//
// @param po  The parser that receives the options.
void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
  po->Register("fire-red-asr-encoder", &encoder,
               "Path to onnx encoder of FireRedAsr");

  po->Register("fire-red-asr-decoder", &decoder,
               "Path to onnx decoder of FireRedAsr");
}

// Check that both model paths were supplied and refer to existing files.
//
// The encoder is checked before the decoder. Only the first problem found is
// logged (via SHERPA_ONNX_LOGE) before false is returned.
//
// @return true if both paths are non-empty and both files exist;
//         false otherwise.
bool OfflineFireRedAsrModelConfig::Validate() const {
  bool is_valid = false;

  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder");
  } else if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist",
                     encoder.c_str());
  } else if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder");
  } else if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist",
                     decoder.c_str());
  } else {
    is_valid = true;
  }

  return is_valid;
}

// Render this config as a single-line, human-readable string of the form
//   OfflineFireRedAsrModelConfig(encoder="...", decoder="...")
std::string OfflineFireRedAsrModelConfig::ToString() const {
  std::ostringstream out;

  out << "OfflineFireRedAsrModelConfig("
      << "encoder=\"" << encoder << "\", "
      << "decoder=\"" << decoder << "\")";

  return out.str();
}

} // namespace sherpa_onnx
31 changes: 31 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// see https://github.com/FireRedTeam/FireRedASR
/// Configuration of a FireRedAsr model: the paths of its two onnx files.
struct OfflineFireRedAsrModelConfig {
  /// Path to the onnx encoder model (option --fire-red-asr-encoder).
  std::string encoder;
  /// Path to the onnx decoder model (option --fire-red-asr-decoder).
  std::string decoder;

  OfflineFireRedAsrModelConfig() = default;
  /// Construct a config from explicit encoder and decoder paths.
  OfflineFireRedAsrModelConfig(const std::string &encoder,
                               const std::string &decoder)
      : encoder(encoder), decoder(decoder) {}

  /// Register the command-line options of this config with `po`.
  void Register(ParseOptions *po);
  /// Return true if this config is usable, false otherwise.
  bool Validate() const;

  /// Return a human-readable description of this config.
  std::string ToString() const;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
28 changes: 28 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_

#include <string>
#include <unordered_map>
#include <vector>

namespace sherpa_onnx {

/// Metadata describing a FireRedAsr model.
///
/// Every scalar member has a default member initializer so that a
/// default-constructed object never holds indeterminate values.
/// NOTE(review): presumably the model loader overwrites these fields with
/// values read from the onnx model files - confirm.
struct OfflineFireRedAsrModelMetaData {
  /// ID of the start-of-sentence token.
  int32_t sos_id = 0;
  /// ID of the end-of-sentence token.
  int32_t eos_id = 0;
  /// Maximum number of decoding steps.
  int32_t max_len = 0;

  /// Number of layers in the decoder.
  int32_t num_decoder_layers = 0;
  /// NOTE(review): presumably the number of attention heads - confirm.
  int32_t num_head = 0;
  /// NOTE(review): presumably the dimension of each attention head - confirm.
  int32_t head_dim = 0;

  /// NOTE(review): presumably per-dimension feature means used for
  /// normalization - confirm.
  std::vector<float> mean;
  /// NOTE(review): presumably per-dimension inverse standard deviations used
  /// for normalization - confirm.
  std::vector<float> inv_stddev;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
Loading
Loading