Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python API for speaker diarization. #1400

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test offline speaker diarization"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

python3 ./python-api-examples/offline-speaker-diarization.py

rm -rf *.wav *.onnx ./sherpa-onnx-pyannote-segmentation-3-0


log "test_clustering"
pushd /tmp/
mkdir test-cluster
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe

.github/scripts/test-speaker-diarization.sh
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe

.github/scripts/test-speaker-diarization.sh
Expand Down
118 changes: 118 additions & 0 deletions python-api-examples/offline-speaker-diarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/usr/bin/env python3
# Copyright (c) 2024 Xiaomi Corporation

"""
This file shows how to use sherpa-onnx Python API for
offline/non-streaming speaker diarization.

Usage:

Step 1: Download a speaker segmentation model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

Step 2: Download a speaker embedding extractor model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

Step 3. Download test wave files

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

Step 4. Run it

python3 ./python-api-examples/offline-speaker-diarization.py

"""
from pathlib import Path

import sherpa_onnx
import soundfile as sf


def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5):
"""
Args:
num_speakers:
If you know the actual number of speakers in the wave file, then please
specify it. Otherwise, leave it to -1
cluster_threshold:
If num_speakers is -1, then this threshold is used for clustering.
A smaller cluster_threshold leads to more clusters, i.e., more speakers.
A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.
"""
segmentation_model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
embedding_extractor_model = (
"./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
)

config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
model=segmentation_model
),
),
embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
model=embedding_extractor_model
),
clustering=sherpa_onnx.FastClusteringConfig(
num_clusters=num_speakers, threshold=cluster_threshold
),
min_duration_on=0.3,
min_duration_off=0.5,
)
if not config.validate():
raise RuntimeError(
"Please check your config and make sure all required files exist"
)

return sherpa_onnx.OfflineSpeakerDiarization(config)


def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
progress = num_processed_chunk / num_total_chunks * 100
print(f"Progress: {progress:.3f}%")
return 0


def main():
wave_filename = "./0-four-speakers-zh.wav"
if not Path(wave_filename).is_file():
raise RuntimeError(f"{wave_filename} does not exist")

audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel

# Since we know there are 4 speakers in the above test wave file, we use
# num_speakers 4 here
sd = init_speaker_diarization(num_speakers=4)
if sample_rate != sd.sample_rate:
raise RuntimeError(
f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}"
)

show_porgress = True

if show_porgress:
result = sd.process(audio, callback=progress_callback).sort_by_start_time()
else:
result = sd.process(audio).sort_by_start_time()

for r in result:
print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}")
# print(r) # this one is simpler


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
callback, callback_arg);
std::move(callback), callback_arg);

std::vector<int32_t> cluster_labels = clustering_.Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-result.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment {
const std::string &Text() const { return text_; }
float Duration() const { return end_ - start_; }

void SetText(const std::string &text) { text_ = text; }

std::string ToString() const;

private:
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig {
OfflineSpeakerDiarizationConfig(
const OfflineSpeakerSegmentationModelConfig &segmentation,
const SpeakerEmbeddingExtractorConfig &embedding,
const FastClusteringConfig &clustering)
const FastClusteringConfig &clustering, float min_duration_on,
float min_duration_off)
: segmentation(segmentation),
embedding(embedding),
clustering(clustering) {}
clustering(clustering),
min_duration_on(min_duration_on),
min_duration_off(min_duration_off) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ endif()
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND srcs
fast-clustering.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
)
endif()

Expand Down
32 changes: 32 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"

#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"

namespace sherpa_onnx {

static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
using PyClass = OfflineSpeakerDiarizationSegment;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationSegment")
.def_property_readonly("start", &PyClass::Start)
.def_property_readonly("end", &PyClass::End)
.def_property_readonly("duration", &PyClass::Duration)
.def_property_readonly("speaker", &PyClass::Speaker)
.def_property("text", &PyClass::Text, &PyClass::SetText)
.def("__str__", &PyClass::ToString);
}

void PybindOfflineSpeakerDiarizationResult(py::module *m) {
PybindOfflineSpeakerDiarizationSegment(m);
using PyClass = OfflineSpeakerDiarizationResult;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationResult")
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
.def_property_readonly("num_segments", &PyClass::NumSegments)
.def("sort_by_start_time", &PyClass::SortByStartTime)
.def("sort_by_speaker", &PyClass::SortBySpeaker);
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineSpeakerDiarizationResult(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
92 changes: 92 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"

namespace sherpa_onnx {

static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig;
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationPyannoteModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);

using PyClass = OfflineSpeakerSegmentationModelConfig;
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationModelConfig")
.def(py::init<>())
.def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
int32_t, bool, const std::string &>(),
py::arg("pyannote"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("pyannote", &PyClass::pyannote)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
PybindOfflineSpeakerSegmentationModelConfig(m);

using PyClass = OfflineSpeakerDiarizationConfig;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationConfig")
.def(py::init<const OfflineSpeakerSegmentationModelConfig &,
const SpeakerEmbeddingExtractorConfig &,
const FastClusteringConfig &, float, float>(),
py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"),
py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5)
.def_readwrite("segmentation", &PyClass::segmentation)
.def_readwrite("embedding", &PyClass::embedding)
.def_readwrite("clustering", &PyClass::clustering)
.def_readwrite("min_duration_on", &PyClass::min_duration_on)
.def_readwrite("min_duration_off", &PyClass::min_duration_off)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

void PybindOfflineSpeakerDiarization(py::module *m) {
PybindOfflineSpeakerDiarizationConfig(m);

using PyClass = OfflineSpeakerDiarization;
py::class_<PyClass>(*m, "OfflineSpeakerDiarization")
.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
py::arg("config"))
.def_property_readonly("sample_rate", &PyClass::SampleRate)
.def(
"process",
[](const PyClass &self, const std::vector<float> samples,
std::function<int32_t(int32_t, int32_t)> callback) {
if (!callback) {
return self.Process(samples.data(), samples.size());
}

std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
[callback](int32_t processed_chunks, int32_t num_chunks,
void *) -> int32_t {
callback(processed_chunks, num_chunks);
return 0;
};

return self.Process(samples.data(), samples.size(),
callback_wrapper);
},
py::arg("samples"), py::arg("callback") = py::none());
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineSpeakerDiarization(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
12 changes: 8 additions & 4 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/python/csrc/fast-clustering.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
#endif

namespace sherpa_onnx {
Expand Down Expand Up @@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineTts(&m);
#endif

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
#endif

PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
PybindSpokenLanguageIdentification(&m);

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m);
#endif

PybindAlsa(&m);
}

Expand Down
Loading
Loading