forked from k2-fsa/sherpa-onnx
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Export Pyannote speaker segmentation models to onnx (k2-fsa#1382)
- Loading branch information
1 parent
750515a
commit 04d83aa
Showing
9 changed files
with
707 additions
and
0 deletions.
There are no files selected for viewing
86 changes: 86 additions & 0 deletions
86
.github/workflows/export-pyannote-segmentation-to-onnx.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Exports the pyannote segmentation-3.0 model to ONNX and publishes the
# result both as a GitHub release asset and to Hugging Face.
name: export-pyannote-segmentation-to-onnx

# Manual trigger only: this is a one-shot export job, not CI.
on:
  workflow_dispatch:

# Only one export per ref at a time; a newer run cancels an in-flight one.
concurrency:
  group: export-pyannote-segmentation-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-pyannote-segmentation-to-onnx:
    # Restrict to the upstream org and the maintainer's fork so forks
    # don't attempt (and fail) the release/HF upload steps.
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export Pyannote segmentation models to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install pyannote
        shell: bash
        run: |
          pip install pyannote.audio onnx onnxruntime

      - name: Run
        shell: bash
        run: |
          # Collect the exported models plus docs into a release directory.
          d=sherpa-onnx-pyannote-segmentation-3-0
          src=$PWD/$d
          mkdir -p $src

          pushd scripts/pyannote/segmentation
          ./run.sh
          cp ./*.onnx $src/
          cp ./README.md $src/
          cp ./LICENSE $src/
          cp ./run.sh $src/
          cp ./*.py $src/
          popd

          ls -lh $d
          tar cjfv $d.tar.bz2 $d

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          # Upload to the upstream repo even when run from the fork.
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: speaker-segmentation-models

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        # Retried because HF pushes are flaky under load.
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "[email protected]"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-pyannote-segmentation-3-0
            # Skip LFS blob download; we only add new files on top.
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            git clone https://huggingface.co/csukuangfj/$d huggingface
            cp -v $d/* ./huggingface

            cd huggingface
            git lfs track "*.onnx"
            git status
            git add .
            git status
            git commit -m "add models"
            git push https://csukuangfj:[email protected]/csukuangfj/$d main
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.bin
*.onnx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from typing import Any, Dict | ||
|
||
import onnx | ||
import torch | ||
from onnxruntime.quantization import QuantType, quantize_dynamic | ||
from pyannote.audio import Model | ||
from pyannote.audio.core.task import Problem, Resolution | ||
|
||
|
||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata of an ONNX model file with the given entries.

    The model file is rewritten in place; any metadata it already
    carried is discarded first.

    Args:
      filename:
        Path of the ONNX model file to modify.
      meta_data:
        Key-value pairs to store; values are converted to ``str``.
    """
    model = onnx.load(filename)

    # Start from a clean slate so stale keys cannot survive a re-export.
    del model.metadata_props[:]

    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model, filename)
||
@torch.no_grad()
def main():
    """Export pyannote segmentation-3.0 to ONNX (fp32 + int8) with metadata.

    Loads ``./pytorch_model.bin``, validates the expected model
    specification, exports it via ``torch.onnx.export`` with dynamic
    batch/time axes, attaches sherpa-onnx metadata, and finally writes a
    dynamically-quantized int8 copy.

    Raises:
      AssertionError: if the checkpoint does not match the expected
        segmentation-3.0 specification (7 powerset classes, 16 kHz,
        10 s windows).
    """
    # You can download ./pytorch_model.bin from
    # https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
    pt_filename = "./pytorch_model.bin"
    model = Model.from_pretrained(pt_filename)
    model.eval()
    # 7 = powerset encoding of up-to-3 speakers with <=2 overlapping.
    assert model.dimension == 7, model.dimension
    print(model.specifications)

    assert (
        model.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
    ), model.specifications.problem

    assert (
        model.specifications.resolution == Resolution.FRAME
    ), model.specifications.resolution

    assert model.specifications.duration == 10.0, model.specifications.duration

    assert model.audio.sample_rate == 16000, model.audio.sample_rate

    # (batch, num_channels, num_samples)
    assert list(model.example_input_array.shape) == [
        1,
        1,
        16000 * 10,
    ], model.example_input_array.shape

    example_output = model(model.example_input_array)

    # (batch, num_frames, num_classes)
    assert list(example_output.shape) == [1, 589, 7], example_output.shape

    # One output frame corresponds to 270 samples (16.875 ms); the
    # receptive field spans 991 samples (61.9375 ms).
    assert model.receptive_field.step == 0.016875, model.receptive_field.step
    assert model.receptive_field.duration == 0.0619375, model.receptive_field.duration
    assert model.receptive_field.step * 16000 == 270, model.receptive_field.step * 16000
    assert model.receptive_field.duration * 16000 == 991, (
        model.receptive_field.duration * 16000
    )

    opset_version = 18

    filename = "model.onnx"
    torch.onnx.export(
        model,
        model.example_input_array,
        filename,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["y"],
        # N: batch size, T: number of samples (input) / frames (output).
        dynamic_axes={
            "x": {0: "N", 2: "T"},
            "y": {0: "N", 1: "T"},
        },
    )

    sample_rate = model.audio.sample_rate

    window_size = int(model.specifications.duration) * 16000
    receptive_field_size = int(model.receptive_field.duration * 16000)
    receptive_field_shift = int(model.receptive_field.step * 16000)

    # Metadata consumed by sherpa-onnx at runtime.
    meta_data = {
        "num_speakers": len(model.specifications.classes),
        "powerset_max_classes": model.specifications.powerset_max_classes,
        "num_classes": model.dimension,
        "sample_rate": sample_rate,
        "window_size": window_size,
        "receptive_field_size": receptive_field_size,
        "receptive_field_shift": receptive_field_shift,
        "model_type": "pyannote-segmentation-3.0",
        "version": "1",
        "model_author": "pyannote",
        "maintainer": "k2-fsa",
        "url_1": "https://huggingface.co/pyannote/segmentation-3.0",
        "url_2": "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0",
        "license": "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE",
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    # Fix: the original f-string dropped the first filename placeholder.
    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
|
||
# config.yaml | ||
|
||
|
||
```yaml | ||
task: | ||
_target_: pyannote.audio.tasks.SpeakerDiarization | ||
duration: 10.0 | ||
max_speakers_per_chunk: 3 | ||
max_speakers_per_frame: 2 | ||
model: | ||
_target_: pyannote.audio.models.segmentation.PyanNet | ||
sample_rate: 16000 | ||
num_channels: 1 | ||
sincnet: | ||
stride: 10 | ||
lstm: | ||
hidden_size: 128 | ||
num_layers: 4 | ||
bidirectional: true | ||
monolithic: true | ||
linear: | ||
hidden_size: 128 | ||
num_layers: 2 | ||
``` | ||
# Model architecture of ./pytorch_model.bin | ||
`print(model)`: | ||
|
||
```python3 | ||
PyanNet( | ||
(sincnet): SincNet( | ||
(wav_norm1d): InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) | ||
(conv1d): ModuleList( | ||
(0): Encoder( | ||
(filterbank): ParamSincFB() | ||
) | ||
(1): Conv1d(80, 60, kernel_size=(5,), stride=(1,)) | ||
(2): Conv1d(60, 60, kernel_size=(5,), stride=(1,)) | ||
) | ||
(pool1d): ModuleList( | ||
(0-2): 3 x MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False) | ||
) | ||
(norm1d): ModuleList( | ||
(0): InstanceNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) | ||
(1-2): 2 x InstanceNorm1d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) | ||
) | ||
) | ||
(lstm): LSTM(60, 128, num_layers=4, batch_first=True, dropout=0.5, bidirectional=True) | ||
(linear): ModuleList( | ||
(0): Linear(in_features=256, out_features=128, bias=True) | ||
(1): Linear(in_features=128, out_features=128, bias=True) | ||
) | ||
(classifier): Linear(in_features=128, out_features=7, bias=True) | ||
(activation): LogSoftmax(dim=-1) | ||
) | ||
``` | ||
|
||
```python3 | ||
>>> list(model.specifications) | ||
[Specifications(problem=<Problem.MONO_LABEL_CLASSIFICATION: 1>, resolution=<Resolution.FRAME: 1>, duration=10.0, min_duration=None, warm_up=(0.0, 0.0), classes=['speaker#1', 'speaker#2', 'speaker#3'], powerset_max_classes=2, permutation_invariant=True)] | ||
``` | ||
|
||
```python3 | ||
>>> model.hparams | ||
"linear": {'hidden_size': 128, 'num_layers': 2} | ||
"lstm": {'hidden_size': 128, 'num_layers': 4, 'bidirectional': True, 'monolithic': True, 'dropout': 0.5, 'batch_first': True} | ||
"num_channels": 1 | ||
"sample_rate": 16000 | ||
"sincnet": {'stride': 10, 'sample_rate': 16000} | ||
``` | ||
|
||
## Papers | ||
|
||
- [pyannote.audio 2.1 speaker diarization pipeline: principle, benchmark, and recipe](https://hal.science/hal-04247212/document) | ||
- [pyannote.audio speaker diarization pipeline at VoxSRC 2023](https://mmai.io/datasets/voxceleb/voxsrc/data_workshop_2023/reports/pyannote_report.pdf) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#!/usr/bin/env bash
# Run onnxruntime's shape-inference/optimization preprocessing on the
# exported model so that dynamic_quantize can work on it, then print the
# resulting input/output shapes.

python3 -m onnxruntime.quantization.preprocess --input model.onnx --output tmp.preprocessed.onnx
mv ./tmp.preprocessed.onnx ./model.onnx

./show-onnx.py --filename ./model.onnx

# The heredoc below is a no-op; it documents the symbolic output-shape
# expression that show-onnx.py prints and simplifies it by hand.
<<EOF
=========./model.onnx==========
NodeArg(name='x', type='tensor(float)', shape=[1, 1, 'T'])
-----
NodeArg(name='y', type='tensor(float)', shape=[1, 'floor(floor(floor(floor(T/10 - 251/10)/3 - 2/3)/3)/3 - 8/3) + 1', 7])

floor(floor(floor(floor(T/10 - 251/10)/3 - 2/3)/3)/3 - 8/3) + 1
= floor(floor(floor(floor(T - 251)/30 - 2/3)/3)/3 - 8/3) + 1
= floor(floor(floor(floor(T - 271)/30)/3)/3 - 8/3) + 1
= floor(floor(floor(floor(T - 271)/90))/3 - 8/3) + 1
= floor(floor(floor(T - 271)/90)/3 - 8/3) + 1
= floor(floor((T - 271)/90)/3 - 8/3) + 1
= floor(floor((T - 271)/90 - 8)/3) + 1
= floor(floor((T - 271 - 720)/90)/3) + 1
= floor(floor((T - 991)/90)/3) + 1
= floor(floor((T - 991)/270)) + 1
= (T - 991)/270 + 1
= (T - 991 + 270)/270
= (T - 721)/270

It means:
 - Number of input samples should be at least 721
 - One frame corresponds to 270 samples. (If we use T + 270, it outputs one more frame)
EOF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env bash
# Copyright      2024  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# End-to-end driver: install dependencies, download the checkpoint and a
# test wav, export to ONNX, preprocess, sanity-check both the torch and
# ONNX models, and generate README/LICENSE for the release bundle.

set -ex

function install_pyannote() {
  pip install pyannote.audio onnx onnxruntime
}

function download_test_files() {
  # Checkpoint mirrored by the maintainer; wav is a shared ASR test file.
  curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
}

install_pyannote
download_test_files

./export-onnx.py
./preprocess.sh

# Compare VAD output of the original torch model against both ONNX
# variants on the same wav; any divergence shows up in the logs.
echo "----------torch----------"
./vad-torch.py

echo "----------onnx model.onnx----------"
./vad-onnx.py --model ./model.onnx --wav ./lei-jun-test.wav

echo "----------onnx model.int8.onnx----------"
./vad-onnx.py --model ./model.int8.onnx --wav ./lei-jun-test.wav

cat >README.md << EOF
# Introduction

Models in this file are converted from
https://huggingface.co/pyannote/segmentation-3.0/tree/main
EOF

cat >LICENSE <<EOF
MIT License

Copyright (c) 2022 CNRS

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
EOF
Oops, something went wrong.