Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapted to 8k #1175

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions runtime/onnxruntime/include/audio.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ class Audio {
queue<AudioFrame *> frame_queue;
queue<AudioFrame *> asr_online_queue;
queue<AudioFrame *> asr_offline_queue;

int dest_sample_rate;
public:
Audio(int data_type);
Audio(int data_type, int size);
Audio(int model_sample_rate,int data_type);
Audio(int model_sample_rate,int data_type, int size);
~Audio();
void Disp();
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
Expand Down
2 changes: 2 additions & 0 deletions runtime/onnxruntime/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Model {
virtual void InitSegDict(const std::string &seg_dict_model){};
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
virtual std::string GetLang(){return "";};
virtual int GetAsrSampleRate() = 0;

};

Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
Expand Down
1 change: 1 addition & 0 deletions runtime/onnxruntime/include/vad-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class VadModel {
virtual ~VadModel(){};
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
virtual int GetVadSampleRate() = 0;
};

VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
Expand Down
44 changes: 27 additions & 17 deletions runtime/onnxruntime/src/audio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,28 @@ int AudioFrame::Disp()
return 0;
}

Audio::Audio(int data_type) : data_type(data_type)
Audio::Audio(int data_type) : dest_sample_rate(MODEL_SAMPLE_RATE), data_type(data_type)
{
speech_buff = NULL;
speech_data = NULL;
align_size = 1360;
seg_sample = dest_sample_rate / 1000;
}

Audio::Audio(int data_type, int size) : data_type(data_type)
Audio::Audio(int model_sample_rate, int data_type) : dest_sample_rate(model_sample_rate), data_type(data_type)
{
speech_buff = NULL;
speech_data = NULL;
align_size = 1360;
seg_sample = dest_sample_rate / 1000;
}

Audio::Audio(int model_sample_rate, int data_type, int size) : dest_sample_rate(model_sample_rate), data_type(data_type)
{
speech_buff = NULL;
speech_data = NULL;
align_size = (float)size;
seg_sample = dest_sample_rate / 1000;
}

Audio::~Audio()
Expand All @@ -222,12 +232,12 @@ Audio::~Audio()

void Audio::Disp()
{
LOG(INFO) << "Audio time is " << (float)speech_len / MODEL_SAMPLE_RATE << " s. len is " << speech_len;
LOG(INFO) << "Audio time is " << (float)speech_len / dest_sample_rate << " s. len is " << speech_len;
}

float Audio::GetTimeLen()
{
return (float)speech_len / MODEL_SAMPLE_RATE;
return (float)speech_len / dest_sample_rate;
}

void Audio::WavResample(int32_t sampling_rate, const float *waveform,
Expand All @@ -237,13 +247,13 @@ void Audio::WavResample(int32_t sampling_rate, const float *waveform,
<< " in_sample_rate: "<< sampling_rate << "\n"
<< " output_sample_rate: " << static_cast<int32_t>(MODEL_SAMPLE_RATE);
float min_freq =
std::min<int32_t>(sampling_rate, MODEL_SAMPLE_RATE);
std::min<int32_t>(sampling_rate, dest_sample_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;

int32_t lowpass_filter_width = 6;

auto resampler = std::make_unique<LinearResample>(
sampling_rate, MODEL_SAMPLE_RATE, lowpass_cutoff, lowpass_filter_width);
sampling_rate, dest_sample_rate, lowpass_cutoff, lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
//reset speech_data
Expand Down Expand Up @@ -311,7 +321,7 @@ bool Audio::FfmpegLoad(const char *filename, bool copy2char){
nullptr, // allocate a new context
AV_CH_LAYOUT_MONO, // output channel layout (stereo)
AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
16000, // output sample rate (same as input)
dest_sample_rate, // output sample rate (same as input)
av_get_default_channel_layout(codecContext->channels), // input channel layout
codecContext->sample_fmt, // input sample format
codecContext->sample_rate, // input sample rate
Expand Down Expand Up @@ -347,7 +357,7 @@ bool Audio::FfmpegLoad(const char *filename, bool copy2char){
int in_samples = frame->nb_samples;
uint8_t **in_data = frame->extended_data;
int out_samples = av_rescale_rnd(in_samples,
16000,
dest_sample_rate,
codecContext->sample_rate,
AV_ROUND_DOWN);

Expand Down Expand Up @@ -494,7 +504,7 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
nullptr, // allocate a new context
AV_CH_LAYOUT_MONO, // output channel layout (stereo)
AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
16000, // output sample rate (same as input)
dest_sample_rate, // output sample rate (same as input)
av_get_default_channel_layout(codecContext->channels), // input channel layout
codecContext->sample_fmt, // input sample format
codecContext->sample_rate, // input sample rate
Expand Down Expand Up @@ -532,7 +542,7 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
int in_samples = frame->nb_samples;
uint8_t **in_data = frame->extended_data;
int out_samples = av_rescale_rnd(in_samples,
16000,
dest_sample_rate,
codecContext->sample_rate,
AV_ROUND_DOWN);

Expand Down Expand Up @@ -666,7 +676,7 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
}

//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}

Expand Down Expand Up @@ -752,7 +762,7 @@ bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
}

//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}

Expand Down Expand Up @@ -795,7 +805,7 @@ bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
}

//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}

Expand Down Expand Up @@ -840,7 +850,7 @@ bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_r
}

//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}

Expand Down Expand Up @@ -898,7 +908,7 @@ bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
}

//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}

Expand Down Expand Up @@ -1009,7 +1019,7 @@ int Audio::Fetch(float *&dout, int &len, int &flag, float &start_time)
AudioFrame *frame = frame_queue.front();
frame_queue.pop();

start_time = (float)(frame->GetStart())/MODEL_SAMPLE_RATE;
start_time = (float)(frame->GetStart())/ dest_sample_rate;
dout = speech_data + frame->GetStart();
len = frame->GetLen();
delete frame;
Expand Down Expand Up @@ -1248,7 +1258,7 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
}

// erase all_samples
int vector_cache = MODEL_SAMPLE_RATE*2;
int vector_cache = dest_sample_rate*2;
if(speech_offline_start == -1){
if(all_samples.size() > vector_cache){
int erase_num = all_samples.size() - vector_cache;
Expand Down
5 changes: 4 additions & 1 deletion runtime/onnxruntime/src/fsmn-vad-online.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,11 @@ void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
vad_max_len_ = vad_max_len;
vad_speech_noise_thres_ = vad_speech_noise_thres;

frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;

// 2pass
audio_handle = make_unique<Audio>(1);
audio_handle = make_unique<Audio>(vad_sample_rate,1);
}

FsmnVadOnline::~FsmnVadOnline() {
Expand Down
2 changes: 2 additions & 0 deletions runtime/onnxruntime/src/fsmn-vad-online.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class FsmnVadOnline : public VadModel {
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
void Reset();
int GetVadSampleRate() { return vad_sample_rate_; };

// 2pass
std::unique_ptr<Audio> audio_handle = nullptr;

Expand Down
2 changes: 2 additions & 0 deletions runtime/onnxruntime/src/fsmn-vad.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class FsmnVad : public VadModel {
std::vector<std::vector<float>> *in_cache,
bool is_final);
void Reset();

int GetVadSampleRate() { return vad_sample_rate_; };

std::shared_ptr<Ort::Session> vad_session_ = nullptr;
Ort::Env env_;
Expand Down
12 changes: 6 additions & 6 deletions runtime/onnxruntime/src/funasrruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
if (!recog_obj)
return nullptr;

funasr::Audio audio(1);
funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
if(wav_format == "pcm" || wav_format == "PCM"){
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
Expand Down Expand Up @@ -93,7 +93,7 @@
if (!recog_obj)
return nullptr;

funasr::Audio audio(1);
funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
if(funasr::is_target_file(sz_filename, "wav")){
int32_t sampling_rate_ = -1;
if(!audio.LoadWav(sz_filename, &sampling_rate_))
Expand Down Expand Up @@ -134,7 +134,7 @@
if (!vad_obj)
return nullptr;

funasr::Audio audio(1);
funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
if(wav_format == "pcm" || wav_format == "PCM"){
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
Expand Down Expand Up @@ -162,7 +162,7 @@
if (!vad_obj)
return nullptr;

funasr::Audio audio(1);
funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
if(funasr::is_target_file(sz_filename, "wav")){
int32_t sampling_rate_ = -1;
if(!audio.LoadWav(sz_filename, &sampling_rate_))
Expand Down Expand Up @@ -222,7 +222,7 @@
if (!offline_stream)
return nullptr;

funasr::Audio audio(1);
funasr::Audio audio(offline_stream->asr_handle->GetAsrSampleRate(),1);
try{
if(wav_format == "pcm" || wav_format == "PCM"){
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
Expand Down Expand Up @@ -314,7 +314,7 @@
if (!offline_stream)
return nullptr;

funasr::Audio audio(1);
funasr::Audio audio((offline_stream->asr_handle)->GetAsrSampleRate(),1);
try{
if(funasr::is_target_file(sz_filename, "wav")){
int32_t sampling_rate_ = -1;
Expand Down
8 changes: 6 additions & 2 deletions runtime/onnxruntime/src/paraformer-online.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ void ParaformerOnline::InitOnline(
for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
fsmn_init_cache_.emplace_back(0);
}
chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;

frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;

}

void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
Expand Down Expand Up @@ -489,7 +493,7 @@ string ParaformerOnline::Forward(float* din, int len, bool input_finished, const
if(is_first_chunk){
is_first_chunk = false;
}
ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
if(wav_feats.size() == 0){
return result;
}
Expand Down
3 changes: 3 additions & 0 deletions runtime/onnxruntime/src/paraformer-online.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ namespace funasr {
string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
string Rescoring();

int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };

// 2pass
std::string online_res;
int chunk_len;
Expand Down
12 changes: 9 additions & 3 deletions runtime/onnxruntime/src/paraformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ Paraformer::Paraformer()

// offline
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
LoadConfigFromYaml(am_config.c_str());
// knf options
fbank_opts_.frame_opts.dither = 0;
fbank_opts_.mel_opts.num_bins = n_mels;
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
fbank_opts_.frame_opts.window_type = window_type;
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
fbank_opts_.frame_opts.frame_length_ms = frame_length;
Expand Down Expand Up @@ -65,7 +66,6 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
vocab = new Vocab(am_config.c_str());
LoadConfigFromYaml(am_config.c_str());
phone_set_ = new PhoneSet(am_config.c_str());
LoadCmvn(am_cmvn.c_str());
}
Expand All @@ -77,7 +77,7 @@ void Paraformer::InitAsr(const std::string &en_model, const std::string &de_mode
// knf options
fbank_opts_.frame_opts.dither = 0;
fbank_opts_.mel_opts.num_bins = n_mels;
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
fbank_opts_.frame_opts.window_type = window_type;
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
fbank_opts_.frame_opts.frame_length_ms = frame_length;
Expand Down Expand Up @@ -216,6 +216,9 @@ void Paraformer::LoadConfigFromYaml(const char* filename){
}

try{
YAML::Node frontend_conf = config["frontend_conf"];
this->asr_sample_rate = frontend_conf["fs"].as<int>();

YAML::Node lang_conf = config["lang"];
if (lang_conf.IsDefined()){
language = lang_conf.as<string>();
Expand Down Expand Up @@ -258,6 +261,9 @@ void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
this->cif_threshold = predictor_conf["threshold"].as<double>();
this->tail_alphas = predictor_conf["tail_threshold"].as<double>();

this->asr_sample_rate = frontend_conf["fs"].as<int>();


}catch(exception const &e){
LOG(ERROR) << "Error when load argument from vad config YAML.";
exit(-1);
Expand Down
5 changes: 2 additions & 3 deletions runtime/onnxruntime/src/paraformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace funasr {

string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
Expand Down Expand Up @@ -107,8 +107,7 @@ namespace funasr {
int fsmn_dims = 512;
float cif_threshold = 1.0;
float tail_alphas = 0.45;


int asr_sample_rate = MODEL_SAMPLE_RATE;
};

} // namespace funasr
Loading