
packaging:drop dependency on neon package #8

Merged: 5 commits, Dec 13, 2024
66 changes: 0 additions & 66 deletions .github/workflows/unit_tests.yml

This file was deleted.

6 changes: 6 additions & 0 deletions README.md
@@ -9,6 +9,12 @@ for [Nemo Citrinet](https://docs.nvidia.com/nemo-framework/user-guide/latest/nem

## Install

By default this plugin pulls in the full PyTorch package. To avoid dragging in all of its dependencies, it is recommended to install the CPU-only build of PyTorch **before** installing the plugin:

`pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html`

If you skip the step above, the full PyTorch package will be installed together with the plugin:

`pip install ovos-stt-plugin-citrinet`
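
To confirm which build ended up installed, a quick check (not part of the original README) is:

`python -c "import torch; print(torch.__version__)"  # a CPU-only install reports a version ending in +cpu`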

## Configuration
5 changes: 2 additions & 3 deletions ovos_stt_plugin_citrinet/__init__.py
@@ -4,15 +4,14 @@
from ovos_plugin_manager.templates.stt import STT
from ovos_utils.log import LOG
from speech_recognition import AudioData
from streaming_stt_nemo import Model, available_languages

from ovos_stt_plugin_citrinet.engine import Model, available_languages


class CitrinetSTT(STT):

def __init__(self, config: dict = None):
super().__init__(config)
# replace default Neon model with project aina model
Model.langs["ca"]["model"] = "projecte-aina/stt-ca-citrinet-512"
self.lang = self.config.get('lang') or "ca"
self.models: Dict[str, Model] = {}
lang = self.lang.split("-")[0]
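
Not part of this diff: a minimal usage sketch of the plugin after the import change, assuming the standard ovos-plugin-manager STT interface (execute() accepting a speech_recognition.AudioData) and a hypothetical local recording sample_ca.wav.

import speech_recognition as sr

from ovos_stt_plugin_citrinet import CitrinetSTT

# hedged sketch: load a hypothetical 16 kHz mono WAV and transcribe it in Catalan
stt = CitrinetSTT(config={"lang": "ca"})
recognizer = sr.Recognizer()
with sr.AudioFile("sample_ca.wav") as source:
    audio = recognizer.record(source)  # AudioData holding the raw PCM frames
print(stt.execute(audio, language="ca"))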
206 changes: 206 additions & 0 deletions ovos_stt_plugin_citrinet/engine.py
@@ -0,0 +1,206 @@
# taken from https://github.com/NeonGeckoCom/streaming-stt-nemo

# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2022 Neongecko.com Inc.
# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds,
# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo
# BSD-3 License
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import ctypes
import gc
import os.path

import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import soxr
import torch
from huggingface_hub import hf_hub_download
from pydub import AudioSegment

languages = {
"en": {
"model": "neongeckocom/stt_en_citrinet_512_gamma_0_25",
},
"es": {
"model": "neongeckocom/stt_es_citrinet_512_gamma_0_25",
},
"fr": {
"model": "neongeckocom/stt_fr_citrinet_512_gamma_0_25",
},
"de": {
"model": "neongeckocom/stt_de_citrinet_512_gamma_0_25",
},
"it": {
"model": "neongeckocom/stt_it_citrinet_512_gamma_0_25",
},
"uk": {
"model": "neongeckocom/stt_uk_citrinet_512_gamma_0_25",
},
"nl": {
"model": "neongeckocom/stt_nl_citrinet_512_gamma_0_25",
},
"pt": {
"model": "neongeckocom/stt_pt_citrinet_512_gamma_0_25",
},
"ca": {
"model": "projecte-aina/stt-ca-citrinet-512"
},
}

sample_rate = 16000
subfolder_name = "onnx"
available_languages = list(languages.keys())


class Model:
langs = languages
sample_rate = sample_rate

def __init__(self, lang="en", model_folder=None):
if model_folder:
self._init_model_from_path(model_folder)
else:
self._init_model(lang)

def _init_model(self, lang: str):
if lang not in self.langs:
raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.langs.keys())}")
model_name = self.langs[lang]["model"]
self._init_preprocessor(model_name)
self._init_encoder(model_name)
self._init_tokenizer(model_name)
self._trim_memory()

def _init_model_from_path(self, path: str):
if not os.path.isdir(path):
raise ValueError(f"'{path}' is not valid NemoSTT onnx model folder")
preprocessor_path = f"{path}/preprocessor.ts"
encoder_path = f"{path}/model.onnx"
tokenizer_path = f"{path}/tokenizer.spm"
self._init_preprocessor(preprocessor_path)
self._init_encoder(encoder_path)
self._init_tokenizer(tokenizer_path)
self._trim_memory()

def _init_preprocessor(self, model_name: str):
if os.path.isfile(model_name):
preprocessor_path = model_name
else:
preprocessor_path = hf_hub_download(model_name, "preprocessor.ts", subfolder=subfolder_name)
self.preprocessor = torch.jit.load(preprocessor_path)

def _init_encoder(self, model_name: str):
if os.path.isfile(model_name):
encoder_path = model_name
else:
encoder_path = hf_hub_download(model_name, "model.onnx", subfolder=subfolder_name)
self.encoder = ort.InferenceSession(encoder_path)

def _init_tokenizer(self, model_name: str):
if os.path.isfile(model_name):
tokenizer_path = model_name
else:
tokenizer_path = hf_hub_download(model_name, "tokenizer.spm", subfolder=subfolder_name)
self.tokenizer = spm.SentencePieceProcessor(tokenizer_path)

def _run_preprocessor(self, audio_16k: np.array):
input_signal = torch.tensor(audio_16k).unsqueeze(0)
length = torch.tensor(len(audio_16k)).unsqueeze(0)
processed_signal, processed_signal_len = self.preprocessor.forward(
input_signal=input_signal, length=length
)
processed_signal = processed_signal.numpy()
processed_signal_len = processed_signal_len.numpy()
return processed_signal, processed_signal_len

def _run_encoder(self, processed_signal: np.array, processed_signal_len: np.array):
outputs = self.encoder.run(None, {'audio_signal': processed_signal,
'length': processed_signal_len})
logits = outputs[0][0]
return logits

def _run_tokenizer(self, logits: np.array):
blank_id = self.tokenizer.vocab_size()
decoded_prediction = self._ctc_decode(logits, blank_id)
text = self.tokenizer.decode_ids(decoded_prediction)
current_hypotheses = [text]
return current_hypotheses

@staticmethod
def _ctc_decode(logits: np.array, blank_id: int):
labels = logits.argmax(axis=1).tolist()

previous = blank_id
decoded_prediction = []
for p in labels:
if (p != previous or previous == blank_id) and p != blank_id:
decoded_prediction.append(p)
previous = p
return decoded_prediction

def stt(self, audio_buffer: np.array, sr: int):
audio_fp32 = self._to_float32(audio_buffer)
audio_16k = self._resample(audio_fp32, sr)

processed_signal, processed_signal_len = self._run_preprocessor(audio_16k)
logits = self._run_encoder(processed_signal, processed_signal_len)
current_hypotheses = self._run_tokenizer(logits)

self._trim_memory()
return current_hypotheses

def stt_file(self, file_path: str):
audio_buffer, sr = self.read_file(file_path)
current_hypotheses = self.stt(audio_buffer, sr)
return current_hypotheses

def read_file(self, file_path: str):
audio_file = AudioSegment.from_file(file_path)
sr = audio_file.frame_rate

samples = audio_file.get_array_of_samples()
audio_buffer = np.array(samples)
return audio_buffer, sr

@staticmethod
def _trim_memory():
"""
If possible, gives memory allocated by PyTorch back to the system
"""
libc = ctypes.CDLL("libc.so.6")
libc.malloc_trim(0)
gc.collect()

Comment on lines +189 to +196

⚠️ Potential issue

Ensure cross-platform compatibility in _trim_memory method.

The _trim_memory method uses libc.so.6, which is specific to Linux systems. This will raise an exception on other platforms like Windows or macOS. Consider adding a platform check or using a cross-platform approach for memory trimming.

Modify the method to check the operating system:

import platform

@staticmethod
def _trim_memory():
    """
    If possible, gives memory allocated by PyTorch back to the system.
    """
    if platform.system() == 'Linux':
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    gc.collect()
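
A slightly more defensive variant (a sketch, not part of the review suggestion) also covers non-glibc Linux systems by resolving libc through ctypes.util.find_library and checking that malloc_trim is actually exported:

import ctypes
import ctypes.util
import gc
import platform

@staticmethod
def _trim_memory():
    """
    If possible, gives memory allocated by PyTorch back to the system.
    """
    if platform.system() == 'Linux':
        libc_name = ctypes.util.find_library("c")
        if libc_name:
            libc = ctypes.CDLL(libc_name)
            if hasattr(libc, "malloc_trim"):  # e.g. musl libc does not export malloc_trim
                libc.malloc_trim(0)
    gc.collect()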

def _resample(self, audio_fp32: np.array, sr: int):
audio_16k = soxr.resample(audio_fp32, sr, self.sample_rate)
return audio_16k

def _to_float32(self, audio_buffer: np.array):
audio_fp32 = np.divide(audio_buffer, np.iinfo(audio_buffer.dtype).max, dtype=np.float32)
return audio_fp32
Comment on lines +201 to +203

⚠️ Potential issue

Handle floating-point data types in _to_float32 method.

The _to_float32 method assumes that audio_buffer has an integer data type. If audio_buffer is already a floating-point array, np.iinfo will raise an error. Add a check to handle floating-point inputs appropriately.

Update the method to accommodate different data types:

def _to_float32(self, audio_buffer: np.array):
    if np.issubdtype(audio_buffer.dtype, np.integer):
        max_val = np.iinfo(audio_buffer.dtype).max
        audio_fp32 = np.divide(audio_buffer, max_val, dtype=np.float32)
    elif np.issubdtype(audio_buffer.dtype, np.floating):
        audio_fp32 = audio_buffer.astype(np.float32)
    else:
        raise ValueError(f"Unsupported audio buffer data type: {audio_buffer.dtype}")
    return audio_fp32
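
For illustration (not from the PR), the suggested branching behaves as expected on both integer PCM and float input:

import numpy as np

int_pcm = np.array([0, 16384, -32768], dtype=np.int16)
print(np.divide(int_pcm, np.iinfo(int_pcm.dtype).max, dtype=np.float32))  # ~[0.0, 0.5, -1.0]

float_pcm = np.array([0.25, -0.5], dtype=np.float64)
print(float_pcm.astype(np.float32))  # [0.25, -0.5]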



__all__ = ["Model", "available_languages"]
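
Not part of the diff: a minimal sketch of exercising the bundled engine on its own, assuming a hypothetical local recording sample_ca.wav (any format pydub/ffmpeg can decode).

from ovos_stt_plugin_citrinet.engine import Model

model = Model(lang="ca")  # fetches preprocessor.ts, model.onnx and tokenizer.spm from Hugging Face on first use
hypotheses = model.stt_file("sample_ca.wav")
print(hypotheses[0])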
13 changes: 11 additions & 2 deletions requirements.txt
@@ -1,4 +1,13 @@
ovos-plugin-manager>=0.0.24
ovos-utils~=0.0,>=0.0.30
streaming-stt-nemo~=0.2
SpeechRecognition~=3.8
SpeechRecognition~=3.8
# model
torch>=1.13.1
onnxruntime
sentencepiece
# resampling
soxr
pydub
# huggingface
huggingface-hub
numpy<2.0.0