Skip to content

Commit

Permalink
reduce dependencies (#10)
Browse files Browse the repository at this point in the history
* reduce dependencies

* simplify

* simplify

* @coderabbitai suggestion
  • Loading branch information
JarbasAl authored Dec 13, 2024
1 parent 5872b5d commit 4d0e024
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 74 deletions.
14 changes: 7 additions & 7 deletions ovos_stt_plugin_citrinet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ovos_utils.log import LOG
from speech_recognition import AudioData

from ovos_stt_plugin_citrinet.engine import Model, available_languages
from ovos_stt_plugin_citrinet.engine import Model


class CitrinetSTT(STT):
Expand All @@ -15,9 +15,9 @@ def __init__(self, config: dict = None):
self.lang = self.config.get('lang') or "ca"
self.models: Dict[str, Model] = {}
lang = self.lang.split("-")[0]
if lang not in available_languages:
raise ValueError(f"unsupported language, must be one of {available_languages}")
LOG.info(f"preloading model: {Model.langs[lang]}")
if lang not in self.available_languages:
raise ValueError(f"unsupported language, must be one of {self.available_languages}")
LOG.info(f"preloading model: {Model.default_models[lang]}")
self.load_model(lang)

def load_model(self, lang: str):
Expand All @@ -27,7 +27,7 @@ def load_model(self, lang: str):

@property
def available_languages(self) -> set:
return set(available_languages)
return set(Model.default_models.keys())

def execute(self, audio: AudioData, language: Optional[str] = None):
'''
Expand All @@ -40,8 +40,8 @@ def execute(self, audio: AudioData, language: Optional[str] = None):
'''
language = language or self.lang
lang = language.split("-")[0]
if lang not in available_languages:
raise ValueError(f"unsupported language, must be one of {available_languages}")
if lang not in self.available_languages:
raise ValueError(f"unsupported language, must be one of {self.available_languages}")
model = self.load_model(lang)

audio_buffer = np.frombuffer(audio.get_raw_data(), dtype=np.int16)
Expand Down
100 changes: 36 additions & 64 deletions ovos_stt_plugin_citrinet/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# taken from https://github.com/NeonGeckoCom/streaming-stt-nemo
# modified from https://github.com/NeonGeckoCom/streaming-stt-nemo

# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework
# All trademark and other rights reserved by their respective owners
Expand Down Expand Up @@ -31,64 +31,41 @@
import ctypes
import gc
import os.path
from typing import Optional

import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import soxr
import torch
import torch # TODO - try to drop dependency if we can convert preprocessor to onnx, currently not possible
from huggingface_hub import hf_hub_download
from pydub import AudioSegment

languages = {
"en": {
"model": "neongeckocom/stt_en_citrinet_512_gamma_0_25",
},
"es": {
"model": "neongeckocom/stt_es_citrinet_512_gamma_0_25",
},
"fr": {
"model": "neongeckocom/stt_fr_citrinet_512_gamma_0_25",
},
"de": {
"model": "neongeckocom/stt_de_citrinet_512_gamma_0_25",
},
"it": {
"model": "neongeckocom/stt_it_citrinet_512_gamma_0_25",
},
"uk": {
"model": "neongeckocom/stt_uk_citrinet_512_gamma_0_25",
},
"nl": {
"model": "neongeckocom/stt_nl_citrinet_512_gamma_0_25",
},
"pt": {
"model": "neongeckocom/stt_pt_citrinet_512_gamma_0_25",
},
"ca": {
"model": "projecte-aina/stt-ca-citrinet-512"
},
}

sample_rate = 16000
subfolder_name = "onnx"
available_languages = list(languages.keys())
from ovos_utils.log import LOG


class Model:
langs = languages
sample_rate = sample_rate

def __init__(self, lang="en", model_folder=None):
default_models = {
"en": "neongeckocom/stt_en_citrinet_512_gamma_0_25",
"es": "neongeckocom/stt_es_citrinet_512_gamma_0_25",
"fr": "neongeckocom/stt_fr_citrinet_512_gamma_0_25",
"de": "neongeckocom/stt_de_citrinet_512_gamma_0_25",
"it": "neongeckocom/stt_it_citrinet_512_gamma_0_25",
"uk": "neongeckocom/stt_uk_citrinet_512_gamma_0_25",
"nl": "neongeckocom/stt_nl_citrinet_512_gamma_0_25",
"pt": "neongeckocom/stt_pt_citrinet_512_gamma_0_25",
"ca": "projecte-aina/stt-ca-citrinet-512",
}
sample_rate = 16000
subfolder_name = "onnx"

def __init__(self, lang: str, model_folder: Optional[str] = None):
if model_folder:
self._init_model_from_path(model_folder)
else:
self._init_model(lang)

def _init_model(self, lang: str):
if lang not in self.langs:
raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.langs.keys())}")
model_name = self.langs[lang]["model"]
if lang not in self.default_models:
raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.default_models.keys())}")
model_name = self.default_models[lang]
self._init_preprocessor(model_name)
self._init_encoder(model_name)
self._init_tokenizer(model_name)
Expand All @@ -109,21 +86,21 @@ def _init_preprocessor(self, model_name: str):
if os.path.isfile(model_name):
preprocessor_path = model_name
else:
preprocessor_path = hf_hub_download(model_name, "preprocessor.ts", subfolder=subfolder_name)
preprocessor_path = hf_hub_download(model_name, "preprocessor.ts", subfolder=self.subfolder_name)
self.preprocessor = torch.jit.load(preprocessor_path)

def _init_encoder(self, model_name: str):
if os.path.isfile(model_name):
encoder_path = model_name
else:
encoder_path = hf_hub_download(model_name, "model.onnx", subfolder=subfolder_name)
encoder_path = hf_hub_download(model_name, "model.onnx", subfolder=self.subfolder_name)
self.encoder = ort.InferenceSession(encoder_path)

def _init_tokenizer(self, model_name: str):
if os.path.isfile(model_name):
tokenizer_path = model_name
else:
tokenizer_path = hf_hub_download(model_name, "tokenizer.spm", subfolder=subfolder_name)
tokenizer_path = hf_hub_download(model_name, "tokenizer.spm", subfolder=self.subfolder_name)
self.tokenizer = spm.SentencePieceProcessor(tokenizer_path)

def _run_preprocessor(self, audio_16k: np.array):
Expand Down Expand Up @@ -152,7 +129,6 @@ def _run_tokenizer(self, logits: np.array):
@staticmethod
def _ctc_decode(logits: np.array, blank_id: int):
labels = logits.argmax(axis=1).tolist()

previous = blank_id
decoded_prediction = []
for p in labels:
Expand All @@ -172,19 +148,6 @@ def stt(self, audio_buffer: np.array, sr: int):
self._trim_memory()
return current_hypotheses

def stt_file(self, file_path: str):
audio_buffer, sr = self.read_file(file_path)
current_hypotheses = self.stt(audio_buffer, sr)
return current_hypotheses

def read_file(self, file_path: str):
audio_file = AudioSegment.from_file(file_path)
sr = audio_file.frame_rate

samples = audio_file.get_array_of_samples()
audio_buffer = np.array(samples)
return audio_buffer, sr

@staticmethod
def _trim_memory():
"""
Expand All @@ -195,12 +158,21 @@ def _trim_memory():
gc.collect()

def _resample(self, audio_fp32: np.array, sr: int):
if sr == self.sample_rate:
return audio_fp32
try:
import soxr
except ImportError:
LOG.error(
f"Either provide audio at {self.sample_rate} sample rate or install soxr for automatic resampling")
raise
audio_16k = soxr.resample(audio_fp32, sr, self.sample_rate)
return audio_16k

def _to_float32(self, audio_buffer: np.array):
@staticmethod
def _to_float32(audio_buffer: np.array):
audio_fp32 = np.divide(audio_buffer, np.iinfo(audio_buffer.dtype).max, dtype=np.float32)
return audio_fp32


__all__ = ["Model", "available_languages"]
__all__ = ["Model"]
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ SpeechRecognition~=3.8
torch>=1.13.1
onnxruntime
sentencepiece
# resampling
soxr
pydub
# huggingface
huggingface-hub
numpy<2.0.0
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def required(requirements_file):
license='Apache-2.0',
packages=['ovos_stt_plugin_citrinet'],
install_requires=required("requirements.txt"),
extras_require={
'resampling': ["soxr"]
},
zip_safe=True,
classifiers=[
'Development Status :: 3 - Alpha',
Expand Down

0 comments on commit 4d0e024

Please sign in to comment.