packaging:drop dependency on neon package (#8)
* packaging:drop dependency on neon package

* cpu only pytorch

* Apply suggestions from code review

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* allow lower pytorch version

* .

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
JarbasAl and coderabbitai[bot] authored Dec 13, 2024
1 parent 8eaef9e commit 4bfcbeb
Showing 5 changed files with 225 additions and 71 deletions.
66 changes: 0 additions & 66 deletions .github/workflows/unit_tests.yml

This file was deleted.

6 changes: 6 additions & 0 deletions README.md
@@ -9,6 +9,12 @@ for [Nemo Citrinet](https://docs.nvidia.com/nemo-framework/user-guide/latest/nem
## Install

By default this plugin installs the full PyTorch. To avoid pulling in all of its dependencies, it is recommended to install the CPU-only build of PyTorch **before** installing the plugin:

`pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html`

If you skip the step above, the full PyTorch will be installed together with the plugin:

`pip install ovos-stt-plugin-citrinet`
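
A quick smoke test after installing (a minimal sketch; `audio.wav` is a hypothetical file path, and the `execute` call follows the standard OVOS STT plugin interface):

```python
from speech_recognition import AudioFile, Recognizer

from ovos_stt_plugin_citrinet import CitrinetSTT

stt = CitrinetSTT({"lang": "ca"})  # Catalan is the plugin's default language
with AudioFile("audio.wav") as source:
    audio = Recognizer().record(source)  # AudioData, as the plugin expects
print(stt.execute(audio))
```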

## Configuration
5 changes: 2 additions & 3 deletions ovos_stt_plugin_citrinet/__init__.py
@@ -4,15 +4,14 @@
 from ovos_plugin_manager.templates.stt import STT
 from ovos_utils.log import LOG
 from speech_recognition import AudioData
-from streaming_stt_nemo import Model, available_languages
+
+from ovos_stt_plugin_citrinet.engine import Model, available_languages


 class CitrinetSTT(STT):

     def __init__(self, config: dict = None):
         super().__init__(config)
-        # replace default Neon model with project aina model
-        Model.langs["ca"]["model"] = "projecte-aina/stt-ca-citrinet-512"
         self.lang = self.config.get('lang') or "ca"
         self.models: Dict[str, Model] = {}
         lang = self.lang.split("-")[0]
206 changes: 206 additions & 0 deletions ovos_stt_plugin_citrinet/engine.py
@@ -0,0 +1,206 @@
# taken from https://github.com/NeonGeckoCom/streaming-stt-nemo

# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2022 Neongecko.com Inc.
# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds,
# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo
# BSD-3 License
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import ctypes
import gc
import os.path

import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import soxr
import torch
from huggingface_hub import hf_hub_download
from pydub import AudioSegment

languages = {
    "en": {
        "model": "neongeckocom/stt_en_citrinet_512_gamma_0_25",
    },
    "es": {
        "model": "neongeckocom/stt_es_citrinet_512_gamma_0_25",
    },
    "fr": {
        "model": "neongeckocom/stt_fr_citrinet_512_gamma_0_25",
    },
    "de": {
        "model": "neongeckocom/stt_de_citrinet_512_gamma_0_25",
    },
    "it": {
        "model": "neongeckocom/stt_it_citrinet_512_gamma_0_25",
    },
    "uk": {
        "model": "neongeckocom/stt_uk_citrinet_512_gamma_0_25",
    },
    "nl": {
        "model": "neongeckocom/stt_nl_citrinet_512_gamma_0_25",
    },
    "pt": {
        "model": "neongeckocom/stt_pt_citrinet_512_gamma_0_25",
    },
    "ca": {
        "model": "projecte-aina/stt-ca-citrinet-512"
    },
}

sample_rate = 16000
subfolder_name = "onnx"
available_languages = list(languages.keys())
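
# Note (added comment): models are fetched from the HuggingFace Hub using the
# repo ids in `languages` above, unless a local folder containing
# preprocessor.ts, model.onnx and tokenizer.spm is passed via
# Model(model_folder=...).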


class Model:
    langs = languages
    sample_rate = sample_rate

    def __init__(self, lang="en", model_folder=None):
        if model_folder:
            self._init_model_from_path(model_folder)
        else:
            self._init_model(lang)

    def _init_model(self, lang: str):
        if lang not in self.langs:
            raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.langs.keys())}")
        model_name = self.langs[lang]["model"]
        self._init_preprocessor(model_name)
        self._init_encoder(model_name)
        self._init_tokenizer(model_name)
        self._trim_memory()

    def _init_model_from_path(self, path: str):
        if not os.path.isdir(path):
            raise ValueError(f"'{path}' is not a valid NemoSTT ONNX model folder")
        preprocessor_path = f"{path}/preprocessor.ts"
        encoder_path = f"{path}/model.onnx"
        tokenizer_path = f"{path}/tokenizer.spm"
        self._init_preprocessor(preprocessor_path)
        self._init_encoder(encoder_path)
        self._init_tokenizer(tokenizer_path)
        self._trim_memory()

    def _init_preprocessor(self, model_name: str):
        if os.path.isfile(model_name):
            preprocessor_path = model_name
        else:
            preprocessor_path = hf_hub_download(model_name, "preprocessor.ts", subfolder=subfolder_name)
        self.preprocessor = torch.jit.load(preprocessor_path)

    def _init_encoder(self, model_name: str):
        if os.path.isfile(model_name):
            encoder_path = model_name
        else:
            encoder_path = hf_hub_download(model_name, "model.onnx", subfolder=subfolder_name)
        self.encoder = ort.InferenceSession(encoder_path)

    def _init_tokenizer(self, model_name: str):
        if os.path.isfile(model_name):
            tokenizer_path = model_name
        else:
            tokenizer_path = hf_hub_download(model_name, "tokenizer.spm", subfolder=subfolder_name)
        self.tokenizer = spm.SentencePieceProcessor(tokenizer_path)

    def _run_preprocessor(self, audio_16k: np.array):
        input_signal = torch.tensor(audio_16k).unsqueeze(0)
        length = torch.tensor(len(audio_16k)).unsqueeze(0)
        processed_signal, processed_signal_len = self.preprocessor.forward(
            input_signal=input_signal, length=length
        )
        processed_signal = processed_signal.numpy()
        processed_signal_len = processed_signal_len.numpy()
        return processed_signal, processed_signal_len

    def _run_encoder(self, processed_signal: np.array, processed_signal_len: np.array):
        outputs = self.encoder.run(None, {'audio_signal': processed_signal,
                                          'length': processed_signal_len})
        logits = outputs[0][0]
        return logits

    def _run_tokenizer(self, logits: np.array):
        blank_id = self.tokenizer.vocab_size()
        decoded_prediction = self._ctc_decode(logits, blank_id)
        text = self.tokenizer.decode_ids(decoded_prediction)
        current_hypotheses = [text]
        return current_hypotheses

    @staticmethod
    def _ctc_decode(logits: np.array, blank_id: int):
        labels = logits.argmax(axis=1).tolist()

        previous = blank_id
        decoded_prediction = []
        for p in labels:
            if (p != previous or previous == blank_id) and p != blank_id:
                decoded_prediction.append(p)
            previous = p
        return decoded_prediction
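    # Worked example for _ctc_decode (illustrative comment, not from the
    # original source): with blank_id = 4, per-frame argmax labels
    # [1, 1, 4, 1, 2, 2, 4] decode to [1, 1, 2]; consecutive repeats merge
    # unless separated by a blank, and blank frames are dropped.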

    def stt(self, audio_buffer: np.array, sr: int):
        audio_fp32 = self._to_float32(audio_buffer)
        audio_16k = self._resample(audio_fp32, sr)

        processed_signal, processed_signal_len = self._run_preprocessor(audio_16k)
        logits = self._run_encoder(processed_signal, processed_signal_len)
        current_hypotheses = self._run_tokenizer(logits)

        self._trim_memory()
        return current_hypotheses

    def stt_file(self, file_path: str):
        audio_buffer, sr = self.read_file(file_path)
        current_hypotheses = self.stt(audio_buffer, sr)
        return current_hypotheses

    def read_file(self, file_path: str):
        audio_file = AudioSegment.from_file(file_path)
        sr = audio_file.frame_rate

        samples = audio_file.get_array_of_samples()
        audio_buffer = np.array(samples)
        return audio_buffer, sr

    @staticmethod
    def _trim_memory():
        """
        If possible, gives memory allocated by PyTorch back to the system
        """
        try:
            # malloc_trim is glibc-specific; unavailable on non-Linux platforms
            libc = ctypes.CDLL("libc.so.6")
            libc.malloc_trim(0)
        except OSError:
            pass
        gc.collect()

    def _resample(self, audio_fp32: np.array, sr: int):
        audio_16k = soxr.resample(audio_fp32, sr, self.sample_rate)
        return audio_16k

    def _to_float32(self, audio_buffer: np.array):
        audio_fp32 = np.divide(audio_buffer, np.iinfo(audio_buffer.dtype).max, dtype=np.float32)
        return audio_fp32


__all__ = ["Model", "available_languages"]
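
A minimal usage sketch of the engine above (illustrative; `sample.wav` is a hypothetical file path, everything else comes from the code in this commit):

```python
from ovos_stt_plugin_citrinet.engine import Model, available_languages

print(available_languages)  # ['en', 'es', 'fr', 'de', 'it', 'uk', 'nl', 'pt', 'ca']

# downloads preprocessor.ts / model.onnx / tokenizer.spm from the HF Hub
model = Model(lang="en")

# reads any format pydub/ffmpeg understands, resamples to 16 kHz and
# returns a list of hypotheses (a single greedy CTC decode)
hypotheses = model.stt_file("sample.wav")
print(hypotheses[0])
```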
13 changes: 11 additions & 2 deletions requirements.txt
@@ -1,4 +1,13 @@
 ovos-plugin-manager>=0.0.24
 ovos-utils~=0.0,>=0.0.30
-streaming-stt-nemo~=0.2
-SpeechRecognition~=3.8
+SpeechRecognition~=3.8
+# model
+torch>=1.13.1
+onnxruntime
+sentencepiece
+# resampling
+soxr
+pydub
+# huggingface
+huggingface-hub
+numpy<2.0.0
