From 4d0e024e815f0e4292badcf92bbd615ad3d98cc3 Mon Sep 17 00:00:00 2001 From: JarbasAI <33701864+JarbasAl@users.noreply.github.com> Date: Fri, 13 Dec 2024 18:25:47 +0000 Subject: [PATCH] reduce dependencies (#10) * reduce dependencies * simplify * simplify * @coderabbitai suggestion --- ovos_stt_plugin_citrinet/__init__.py | 14 ++-- ovos_stt_plugin_citrinet/engine.py | 100 ++++++++++----------------- requirements.txt | 3 - setup.py | 3 + 4 files changed, 46 insertions(+), 74 deletions(-) diff --git a/ovos_stt_plugin_citrinet/__init__.py b/ovos_stt_plugin_citrinet/__init__.py index 721b4dd..06d815d 100644 --- a/ovos_stt_plugin_citrinet/__init__.py +++ b/ovos_stt_plugin_citrinet/__init__.py @@ -5,7 +5,7 @@ from ovos_utils.log import LOG from speech_recognition import AudioData -from ovos_stt_plugin_citrinet.engine import Model, available_languages +from ovos_stt_plugin_citrinet.engine import Model class CitrinetSTT(STT): @@ -15,9 +15,9 @@ def __init__(self, config: dict = None): self.lang = self.config.get('lang') or "ca" self.models: Dict[str, Model] = {} lang = self.lang.split("-")[0] - if lang not in available_languages: - raise ValueError(f"unsupported language, must be one of {available_languages}") - LOG.info(f"preloading model: {Model.langs[lang]}") + if lang not in self.available_languages: + raise ValueError(f"unsupported language, must be one of {self.available_languages}") + LOG.info(f"preloading model: {Model.default_models[lang]}") self.load_model(lang) def load_model(self, lang: str): @@ -27,7 +27,7 @@ def load_model(self, lang: str): @property def available_languages(self) -> set: - return set(available_languages) + return set(Model.default_models.keys()) def execute(self, audio: AudioData, language: Optional[str] = None): ''' @@ -40,8 +40,8 @@ def execute(self, audio: AudioData, language: Optional[str] = None): ''' language = language or self.lang lang = language.split("-")[0] - if lang not in available_languages: - raise ValueError(f"unsupported language, must be one of {available_languages}") + if lang not in self.available_languages: + raise ValueError(f"unsupported language, must be one of {self.available_languages}") model = self.load_model(lang) audio_buffer = np.frombuffer(audio.get_raw_data(), dtype=np.int16) diff --git a/ovos_stt_plugin_citrinet/engine.py b/ovos_stt_plugin_citrinet/engine.py index f0e6cff..fe7d026 100644 --- a/ovos_stt_plugin_citrinet/engine.py +++ b/ovos_stt_plugin_citrinet/engine.py @@ -1,4 +1,4 @@ -# taken from https://github.com/NeonGeckoCom/streaming-stt-nemo +# modified from https://github.com/NeonGeckoCom/streaming-stt-nemo # NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework # All trademark and other rights reserved by their respective owners @@ -31,64 +31,41 @@ import ctypes import gc import os.path +from typing import Optional import numpy as np import onnxruntime as ort import sentencepiece as spm -import soxr -import torch +import torch # TODO - try to drop dependency if we can convert preprocessor to onnx, currently not possible from huggingface_hub import hf_hub_download -from pydub import AudioSegment - -languages = { - "en": { - "model": "neongeckocom/stt_en_citrinet_512_gamma_0_25", - }, - "es": { - "model": "neongeckocom/stt_es_citrinet_512_gamma_0_25", - }, - "fr": { - "model": "neongeckocom/stt_fr_citrinet_512_gamma_0_25", - }, - "de": { - "model": "neongeckocom/stt_de_citrinet_512_gamma_0_25", - }, - "it": { - "model": "neongeckocom/stt_it_citrinet_512_gamma_0_25", - }, - "uk": { - "model": "neongeckocom/stt_uk_citrinet_512_gamma_0_25", - }, - "nl": { - "model": "neongeckocom/stt_nl_citrinet_512_gamma_0_25", - }, - "pt": { - "model": "neongeckocom/stt_pt_citrinet_512_gamma_0_25", - }, - "ca": { - "model": "projecte-aina/stt-ca-citrinet-512" - }, -} - -sample_rate = 16000 -subfolder_name = "onnx" -available_languages = list(languages.keys()) +from ovos_utils.log import LOG class Model: - langs = languages - sample_rate = sample_rate - - def __init__(self, lang="en", model_folder=None): + default_models = { + "en": "neongeckocom/stt_en_citrinet_512_gamma_0_25", + "es": "neongeckocom/stt_es_citrinet_512_gamma_0_25", + "fr": "neongeckocom/stt_fr_citrinet_512_gamma_0_25", + "de": "neongeckocom/stt_de_citrinet_512_gamma_0_25", + "it": "neongeckocom/stt_it_citrinet_512_gamma_0_25", + "uk": "neongeckocom/stt_uk_citrinet_512_gamma_0_25", + "nl": "neongeckocom/stt_nl_citrinet_512_gamma_0_25", + "pt": "neongeckocom/stt_pt_citrinet_512_gamma_0_25", + "ca": "projecte-aina/stt-ca-citrinet-512", + } + sample_rate = 16000 + subfolder_name = "onnx" + + def __init__(self, lang: str, model_folder: Optional[str] = None): if model_folder: self._init_model_from_path(model_folder) else: self._init_model(lang) def _init_model(self, lang: str): - if lang not in self.langs: - raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.langs.keys())}") - model_name = self.langs[lang]["model"] + if lang not in self.default_models: + raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.default_models.keys())}") + model_name = self.default_models[lang] self._init_preprocessor(model_name) self._init_encoder(model_name) self._init_tokenizer(model_name) @@ -109,21 +86,21 @@ def _init_preprocessor(self, model_name: str): if os.path.isfile(model_name): preprocessor_path = model_name else: - preprocessor_path = hf_hub_download(model_name, "preprocessor.ts", subfolder=subfolder_name) + preprocessor_path = hf_hub_download(model_name, "preprocessor.ts", subfolder=self.subfolder_name) self.preprocessor = torch.jit.load(preprocessor_path) def _init_encoder(self, model_name: str): if os.path.isfile(model_name): encoder_path = model_name else: - encoder_path = hf_hub_download(model_name, "model.onnx", subfolder=subfolder_name) + encoder_path = hf_hub_download(model_name, "model.onnx", subfolder=self.subfolder_name) self.encoder = ort.InferenceSession(encoder_path) def _init_tokenizer(self, model_name: str): if os.path.isfile(model_name): tokenizer_path = model_name else: - tokenizer_path = hf_hub_download(model_name, "tokenizer.spm", subfolder=subfolder_name) + tokenizer_path = hf_hub_download(model_name, "tokenizer.spm", subfolder=self.subfolder_name) self.tokenizer = spm.SentencePieceProcessor(tokenizer_path) def _run_preprocessor(self, audio_16k: np.array): @@ -152,7 +129,6 @@ def _run_tokenizer(self, logits: np.array): @staticmethod def _ctc_decode(logits: np.array, blank_id: int): labels = logits.argmax(axis=1).tolist() - previous = blank_id decoded_prediction = [] for p in labels: @@ -172,19 +148,6 @@ def stt(self, audio_buffer: np.array, sr: int): self._trim_memory() return current_hypotheses - def stt_file(self, file_path: str): - audio_buffer, sr = self.read_file(file_path) - current_hypotheses = self.stt(audio_buffer, sr) - return current_hypotheses - - def read_file(self, file_path: str): - audio_file = AudioSegment.from_file(file_path) - sr = audio_file.frame_rate - - samples = audio_file.get_array_of_samples() - audio_buffer = np.array(samples) - return audio_buffer, sr - @staticmethod def _trim_memory(): """ @@ -195,12 +158,21 @@ def _trim_memory(): gc.collect() def _resample(self, audio_fp32: np.array, sr: int): + if sr == self.sample_rate: + return audio_fp32 + try: + import soxr + except ImportError: + LOG.error( + f"Either provide audio at {self.sample_rate} sample rate or install soxr for automatic resampling") + raise audio_16k = soxr.resample(audio_fp32, sr, self.sample_rate) return audio_16k - def _to_float32(self, audio_buffer: np.array): + @staticmethod + def _to_float32(audio_buffer: np.array): audio_fp32 = np.divide(audio_buffer, np.iinfo(audio_buffer.dtype).max, dtype=np.float32) return audio_fp32 -__all__ = ["Model", "available_languages"] +__all__ = ["Model"] diff --git a/requirements.txt b/requirements.txt index 7d706c8..8afa2bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,9 +5,6 @@ SpeechRecognition~=3.8 torch>=1.13.1 onnxruntime sentencepiece -# resampling -soxr -pydub # huggingface huggingface-hub numpy<2.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 592daac..64dc3ae 100755 --- a/setup.py +++ b/setup.py @@ -62,6 +62,9 @@ def required(requirements_file): license='Apache-2.0', packages=['ovos_stt_plugin_citrinet'], install_requires=required("requirements.txt"), + extras_require={ + 'resampling': ["soxr"] + }, zip_safe=True, classifiers=[ 'Development Status :: 3 - Alpha',