diff --git a/scripts/pyannote/segmentation/export-onnx.py b/scripts/pyannote/segmentation/export-onnx.py index 5f6e79c7e..7ebcae960 100755 --- a/scripts/pyannote/segmentation/export-onnx.py +++ b/scripts/pyannote/segmentation/export-onnx.py @@ -72,7 +72,7 @@ def main(): model.receptive_field.duration * 16000 ) - opset_version = 18 + opset_version = 15 filename = "model.onnx" torch.onnx.export( diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 222dbf151..3e6526563 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -169,6 +169,7 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) offline-speaker-diarization.cc offline-speaker-segmentation-model-config.cc offline-speaker-segmentation-pyannote-model-config.cc + offline-speaker-segmentation-pyannote-model.cc ) endif() diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc index 1d2a14ef5..24903a41a 100644 --- a/sherpa-onnx/csrc/offline-sense-voice-model.cc +++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc @@ -9,6 +9,7 @@ #include #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc index cefd30a7d..e41a7767a 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc @@ -6,11 +6,20 @@ #include +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h" + namespace sherpa_onnx { std::unique_ptr OfflineSpeakerDiarizationImpl::Create( const OfflineSpeakerDiarizationConfig &config) { + if (!config.segmentation.pyannote.model.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model."); + return nullptr; } diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index f4be51495..186f6ea1d 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -1,3 +1,33 @@ -// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.cc +// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h // // Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ + +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" + +namespace sherpa_onnx { +class OfflineSpeakerDiarizationPyannoteImpl + : public OfflineSpeakerDiarizationImpl { + public: + ~OfflineSpeakerDiarizationPyannoteImpl() override = default; + + explicit OfflineSpeakerDiarizationPyannoteImpl( + const OfflineSpeakerDiarizationConfig &config) + : config_(config), segmentation_model_(config_.segmentation) {} + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = + nullptr) const override { + return {}; + } + + private: + OfflineSpeakerDiarizationConfig config_; + OfflineSpeakerSegmentationPyannoteModel segmentation_model_; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index 4eaf85498..bc992a295 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/offline-speaker-diarization.h" +#include + #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" namespace sherpa_onnx { @@ -39,7 +41,8 @@ std::string OfflineSpeakerDiarizationConfig::ToString() const { } OfflineSpeakerDiarization::OfflineSpeakerDiarization( - const OfflineSpeakerDiarizationConfig &config) {} + const OfflineSpeakerDiarizationConfig &config) + : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 641886db4..1c4b88a69 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -7,6 +7,7 @@ #include #include +#include #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc index 214892f05..f1c9f7d4a 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include +#include #include "sherpa-onnx/csrc/macros.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h index c3ba7d7de..8e9e4a96e 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h @@ -4,6 +4,8 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ +#include + #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc index 848524b1a..f7417ea83 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" #include +#include #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h index a0467c53d..fb5ca4a48 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h @@ -15,8 +15,9 @@ struct OfflineSpeakerSegmentationPyannoteModelConfig { OfflineSpeakerSegmentationPyannoteModelConfig() = default; - OfflineSpeakerSegmentationPyannoteModelConfig(const std::string &model) - : model(model){}; + explicit OfflineSpeakerSegmentationPyannoteModelConfig( + const std::string &model) + : model(model) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h new file mode 100644 index 000000000..728ed7ff4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_onnx { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. +struct OfflineSpeakerSegmentationPyannoteModelMetaData { + int32_t sample_rate = 0; + int32_t window_size = 0; // in samples + int32_t window_shift = 0; // in samples + int32_t receptive_field_size = 0; // in samples + int32_t receptive_field_shift = 0; // in samples + int32_t num_speakers = 0; + int32_t powerset_max_classes = 0; + int32_t num_classes = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc new file mode 100644 index 000000000..620f22b8e --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -0,0 +1,92 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class OfflineSpeakerSegmentationPyannoteModel::Impl { + public: + explicit Impl(const OfflineSpeakerSegmentationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.pyannote.model); + Init(buf.data(), buf.size()); + } + + Ort::Value Forward(Ort::Value x) { return Ort::Value(nullptr); } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size"); + + meta_data_.window_shift = + static_cast(0.1 * meta_data_.window_size); + + SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size, + "receptive_field_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift, + "receptive_field_shift"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes, + "powerset_max_classes"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes"); + } + + private: + OfflineSpeakerSegmentationModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_; +}; + +OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineSpeakerSegmentationPyannoteModel:: + ~OfflineSpeakerSegmentationPyannoteModel() = default; + +Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(Ort::Value x) { + return impl_->Forward(std::move(x)); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h new file mode 100644 index 000000000..7ea8f44c5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" + +namespace sherpa_onnx { + +class OfflineSpeakerSegmentationPyannoteModel { + public: + explicit OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config); + + ~OfflineSpeakerSegmentationPyannoteModel(); + + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetMetaData() const; + + /** + * @param x A 3-D float tensor of shape (batch_size, 1, num_samples) + * @return Return a float tensor of + * shape (batch_size, num_frames, num_speakers). Note that + * num_speakers here uses powerset encoding. + */ + Ort::Value Forward(Ort::Value x); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 1db62aa6b..165e2d9a2 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) { bool TensorrtConfig::Validate() const { if (trt_max_workspace_size < 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.", - trt_max_workspace_size); + std::ostringstream os; + os << "trt_max_workspace_size: " << trt_max_workspace_size + << " is not valid."; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); return false; } if (trt_max_partition_iterations < 0) { diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 7f6f685e0..9c5eb2b1a 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { api.ReleaseStatus(status); } -static Ort::SessionOptions GetSessionOptionsImpl( +Ort::SessionOptions GetSessionOptionsImpl( int32_t num_threads, const std::string &provider_str, - const ProviderConfig *provider_config = nullptr) { + const ProviderConfig *provider_config /*= nullptr*/) { Provider p = StringToProvider(provider_str); Ort::SessionOptions sess_opts; @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, &config.provider_config); } -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -#if SHERPA_ONNX_ENABLE_TTS -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} -#endif - -Ort::SessionOptions GetSessionOptions( - const SpeakerEmbeddingExtractorConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions( - const SpokenLanguageIdentificationConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions( - const OfflinePunctuationModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions( - const OnlinePunctuationModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 1e8beb114..e19db6c20 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -8,53 +8,28 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" -#include "sherpa-onnx/csrc/offline-model-config.h" -#include "sherpa-onnx/csrc/offline-punctuation-model-config.h" -#include "sherpa-onnx/csrc/online-punctuation-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" -#include "sherpa-onnx/csrc/spoken-language-identification.h" -#include "sherpa-onnx/csrc/vad-model-config.h" - -#if SHERPA_ONNX_ENABLE_TTS -#include "sherpa-onnx/csrc/offline-tts-model-config.h" -#endif namespace sherpa_onnx { -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); - -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, - const std::string &model_type); - -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); +Ort::SessionOptions GetSessionOptionsImpl( + int32_t num_threads, const std::string &provider_str, + const ProviderConfig *provider_config = nullptr); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); - Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); - -#if SHERPA_ONNX_ENABLE_TTS -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); -#endif - -Ort::SessionOptions GetSessionOptions( - const SpeakerEmbeddingExtractorConfig &config); - -Ort::SessionOptions GetSessionOptions( - const SpokenLanguageIdentificationConfig &config); - -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); -Ort::SessionOptions GetSessionOptions( - const OfflinePunctuationModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type); -Ort::SessionOptions GetSessionOptions( - const OnlinePunctuationModelConfig &config); +template +Ort::SessionOptions GetSessionOptions(const T &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index b084a4b1f..db921000c 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -13,7 +13,14 @@ Usage example: sherpa_onnx::OfflineSpeakerDiarizationConfig config; sherpa_onnx::ParseOptions po(kUsageMessage); config.Register(&po); - po.PrintUsage(); po.Read(argc, argv); + std::cout << config.ToString() << "\n"; + + if (!config.Validate()) { + po.PrintUsage(); + std::cerr << "Errors in config!\n"; + exit(-1); + } + sherpa_onnx::OfflineSpeakerDiarization sd(config); } diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc index 1c99de1a0..d90b0b1e0 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { bool SpeakerEmbeddingExtractorConfig::Validate() const { if (model.empty()) { - SHERPA_ONNX_LOGE("Please provide --model"); + SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model"); return false; } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist", + SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist", model.c_str()); return false; }