diff --git a/.github/workflows/publish_alpha.yml b/.github/workflows/publish_alpha.yml
index 23f89e4..a1e50bc 100644
--- a/.github/workflows/publish_alpha.yml
+++ b/.github/workflows/publish_alpha.yml
@@ -13,7 +13,7 @@ on:
       - 'LICENSE'
       - 'CHANGELOG.md'
       - 'MANIFEST.in'
-      - 'readme.md'
+      - 'README.md'
       - 'scripts/**'
   workflow_dispatch:
 
diff --git a/readme.md b/README.md
similarity index 56%
rename from readme.md
rename to README.md
index 32901df..93d9ad1 100644
--- a/readme.md
+++ b/README.md
@@ -30,20 +30,26 @@ ovos-stt-server --help
 usage: ovos-stt-server [-h] [--engine ENGINE] [--port PORT] [--host HOST]
 
 options:
-  -h, --help            show this help message and exit
-  --engine ENGINE       stt plugin to be used
-  --port PORT           port number
-  --host HOST           host
-  --lang LANG           default language
-  --gradio              flag to enable Gradio web UI
-  --cache               flag to pre-cache examples in Gradio web UI
-  --title TITLE         title for Gradio UI
-  --description DESCRIPTION  Description for Gradio UI
-  --info INFO           Text to display in Gradio UI
-  --badge BADGE         URL of badge to show in Gradio UI
+  -h, --help            show this help message and exit
+  --engine ENGINE       stt plugin to be used
+  --lang-engine LANG_ENGINE
+                        audio language detection plugin to be used
+  --port PORT           port number
+  --host HOST           host
+  --lang LANG           default language supported by plugin
+  --multi               Load a plugin instance per language (force lang
+                        support)
+  --gradio              Enable Gradio Web UI
+  --cache               Cache models for Gradio demo
+  --title TITLE         Title for webUI
+  --description DESCRIPTION
+                        Text description to print in UI
+  --info INFO           Text to display at end of UI
+  --badge BADGE         URL of visitor badge
 ```
 
 > Note: `ffmpeg` is required for Gradio
 
+eg `ovos-stt-server --engine ovos-stt-plugin-fasterwhisper --lang-engine ovos-audio-transformer-plugin-fasterwhisper`
 
 ## Docker
diff --git a/ovos_stt_http_server/__init__.py b/ovos_stt_http_server/__init__.py
index 120e708..45445d0 100644
--- a/ovos_stt_http_server/__init__.py
+++ b/ovos_stt_http_server/__init__.py
@@ -14,73 +14,75 @@
 from fastapi import FastAPI
 from ovos_plugin_manager.stt import load_stt_plugin
+from ovos_plugin_manager.audio_transformers import load_audio_transformer_plugin, AudioLanguageDetector
 from ovos_utils.log import LOG
 from speech_recognition import Recognizer, AudioFile, AudioData
 from starlette.requests import Request
 
-
 LOG.set_level("ERROR")  # avoid server side logs
 
 
 class ModelContainer:
-    def __init__(self, plugin: str, config: dict=None):
-        self.plugin = load_stt_plugin(plugin)
-        if not self.plugin:
+    def __init__(self, plugin: str, lang_plugin: str = None, config: dict = None):
+        plugin = load_stt_plugin(plugin)
+        self.lang_plugin = None
+        if not plugin:
+            raise ValueError(f"Failed to load STT: {plugin}")
+        if lang_plugin:
+            lang_plugin = load_audio_transformer_plugin(lang_plugin)
+            if not lang_plugin:
+                raise ValueError(f"Failed to load lang detection plugin: {plugin}")
+            assert issubclass(lang_plugin, AudioLanguageDetector)
+            self.lang_plugin = lang_plugin()
+        self.engine = plugin(config)
+
+    def process_audio(self, audio: AudioData, lang):
+        if lang == "auto":
+            lang, prob = self.lang_plugin.detect(audio)
+        if audio or self.engine.can_stream:
+            return self.engine.execute(audio, language=lang) or ""
+        return ""
+
+
+class MultiModelContainer:
+    """ loads 1 model per language """
+    def __init__(self, plugin: str, lang_plugin: str = None, config: dict = None):
+        self.plugin_class = load_stt_plugin(plugin)
+        self.lang_plugin = None
+        if not self.plugin_class:
             raise ValueError(f"Failed to load STT: {plugin}")
+        if lang_plugin:
+            lang_plugin = load_audio_transformer_plugin(lang_plugin)
+            if not lang_plugin:
+                raise ValueError(f"Failed to load lang detection plugin: {plugin}")
+            assert issubclass(lang_plugin, AudioLanguageDetector)
+            self.lang_plugin = lang_plugin()
         self.engines = {}
-        self.data = {}
         self.config = config or {}
 
-    def get_engine(self, session_id):
-        if session_id not in self.engines:
-            self.load_engine(session_id)
-        return self.engines[session_id]
+    def get_engine(self, lang):
+        if lang not in self.engines:
+            self.load_engine(lang)
+        return self.engines[lang]
 
-    def load_engine(self, session_id, config=None):
+    def load_engine(self, lang, config=None):
+        # might need to load multiple models per language
         config = config or self.config
-        self.engines[session_id] = self.plugin(config=config)
+        config["lang"] = lang
+        self.engines[lang] = self.plugin_class(config=config)
 
-    def unload_engine(self, session_id):
-        if session_id in self.engines:
-            self.engines.pop(session_id)
-        if session_id in self.data:
-            self.data.pop(session_id)
+    def unload_engine(self, lang):
+        if lang in self.engines:
+            self.engines.pop(lang)
 
-    def process_audio(self, audio: AudioData, lang, session_id=None):
-        session_id = session_id or lang  # shared model for non-streaming stt
-        engine = self.get_engine(session_id)
+    def process_audio(self, audio: AudioData, lang):
+        if lang == "auto":
+            lang, prob = self.lang_plugin.detect(audio)
+        engine = self.get_engine(lang)
         if audio or engine.can_stream:
             return engine.execute(audio, language=lang) or ""
         return ""
 
-    def stream_start(self, session_id):
-        engine = self.get_engine(session_id)
-        if engine.can_stream:
-            engine.stream_start()
-
-    def stream_data(self, audio, session_id):
-        engine = self.get_engine(session_id)
-        if engine.can_stream:
-            # streaming plugin in server + streaming plugin in core
-            return engine.stream_data(audio)
-        else:
-            # non streaming plugin in server + streaming plugin in core
-            if session_id not in self.data:
-                self.data[session_id] = b""
-            self.data[session_id] += audio
-            return ""
-
-    def stream_stop(self, session_id):
-        engine = self.get_engine(session_id)
-        if engine.can_stream:
-            transcript = engine.stream_stop()
-        else:
-            audio = AudioData(self.data[session_id],
-                              sample_rate=16000, sample_width=2)
-            transcript = engine.execute(audio)
-        self.unload_engine(session_id)
-        return transcript or ""
-
 
 def bytes2audiodata(data):
     recognizer = Recognizer()
@@ -91,14 +93,18 @@ def bytes2audiodata(data):
     return audio
 
 
-def create_app(stt_plugin, has_gradio=False):
+def create_app(stt_plugin, lang_plugin=None, multi=False, has_gradio=False):
     app = FastAPI()
-    model = ModelContainer(stt_plugin)
+    if multi:
+        model = MultiModelContainer(stt_plugin, lang_plugin)
+    else:
+        model = ModelContainer(stt_plugin, lang_plugin)
 
     @app.get("/status")
     def stats(request: Request):
         return {"status": "ok",
                 "plugin": stt_plugin,
+                "lang_plugin": lang_plugin,
                 "gradio": has_gradio}
 
     @app.post("/stt")
@@ -108,35 +114,38 @@ async def get_stt(request: Request):
         audio = bytes2audiodata(audio_bytes)
         return model.process_audio(audio, lang)
 
-    @app.post("/stream/start")
-    def stream_start(request: Request):
-        lang = str(request.query_params.get("lang", "en-us")).lower()
-        uuid = str(request.query_params.get("uuid") or lang)
-        model.load_engine(uuid, {"lang": lang})
-        model.stream_start(uuid)
-        return {"status": "ok", "uuid": uuid, "lang": lang}
-
-    @app.post("/stream/audio")
-    async def stream(request: Request):
-        audio = await request.body()
-        lang = str(request.query_params.get("lang", "en-us")).lower()
-        uuid = str(request.query_params.get("uuid") or lang)
-        transcript = model.stream_data(audio, uuid)
-        return {"status": "ok", "uuid": uuid,
-                "lang": lang, "transcript": transcript}
-
-    @app.post("/stream/end")
-    def stream_end(request: Request):
-        lang = str(request.query_params.get("lang", "en-us")).lower()
-        uuid = str(request.query_params.get("uuid") or lang)
-        # model.wait_until_done(uuid)
-        transcript = model.stream_stop(uuid)
-        return {"status": "ok", "uuid": uuid,
-                "lang": lang, "transcript": transcript}
+    @app.post("/lang_detect")
+    async def get_lang(request: Request):
+        audio_bytes = await request.body()
+        audio = bytes2audiodata(audio_bytes)
+        lang, prob = model.lang_plugin.detect(audio)
+        return {"lang": lang, "conf": prob}
 
     return app, model
 
 
-def start_stt_server(engine: str, has_gradio: bool = False) -> (FastAPI, ModelContainer):
-    app, engine = create_app(engine, has_gradio)
+def start_stt_server(engine: str,
+                     lang_engine: str = None,
+                     multi: bool = False,
+                     has_gradio: bool = False) -> (FastAPI, ModelContainer):
+    app, engine = create_app(engine, lang_engine, multi, has_gradio)
     return app, engine
+
+
+if __name__ == "__main__":
+    model = ModelContainer("ovos-stt-plugin-fasterwhisper",
+                           "ovos-audio-transformer-plugin-fasterwhisper")
+
+    from speech_recognition import Recognizer, AudioFile
+
+    jfk = "/home/miro/PycharmProjects/ovos-stt-plugin-fasterwhisper/jfk.wav"
+    with AudioFile(jfk) as source:
+        audio = Recognizer().record(source)
+
+    a = model.process_audio(audio, lang="en")
+    print(a)
+    # And so, my fellow Americans, ask not what your country can do for you. Ask what you can do for your country.
+
+    lang, prob = model.lang_plugin.detect(audio.get_wav_data())
+    print(lang, prob)
+    # en 1.0
diff --git a/ovos_stt_http_server/__main__.py b/ovos_stt_http_server/__main__.py
index 0642fcb..31291a0 100644
--- a/ovos_stt_http_server/__main__.py
+++ b/ovos_stt_http_server/__main__.py
@@ -21,10 +21,13 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--engine", help="stt plugin to be used",
                         required=True)
+    parser.add_argument("--lang-engine", help="audio language detection plugin to be used", required=True)
     parser.add_argument("--port", help="port number", default=8080)
     parser.add_argument("--host", help="host", default="0.0.0.0")
     parser.add_argument("--lang", help="default language supported by plugin",
                         default="en-us")
+    parser.add_argument("--multi", help="Load a plugin instance per language (force lang support)",
+                        action="store_true")
     parser.add_argument("--gradio", help="Enable Gradio Web UI",
                         action="store_true")
     parser.add_argument("--cache", help="Cache models for Gradio demo",
@@ -38,7 +41,9 @@ def main():
     parser.add_argument("--badge", help="URL of visitor badge",
                         default=None)
     args = parser.parse_args()
-    server, engine = start_stt_server(args.engine, has_gradio=bool(args.gradio))
+    server, engine = start_stt_server(args.engine, lang_engine=args.lang_engine,
+                                      multi=bool(args.multi),
+                                      has_gradio=bool(args.gradio))
     LOG.info("Server Started")
     if args.gradio:
         bind_gradio_service(server, engine, args.title, args.description,
diff --git a/setup.py b/setup.py
index aba298f..5b220f6 100755
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def required(requirements_file):
             if pkg.strip() and not pkg.startswith("#")]
 
 
-with open(path.join(BASE_PATH, "readme.md"), "r") as f:
+with open(path.join(BASE_PATH, "README.md"), "r") as f:
     long_description = f.read()