Skip to content

Commit

Permalink
feat/lang_detection_plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Apr 20, 2024
1 parent ad5743a commit ecbf85d
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish_alpha.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:
- 'LICENSE'
- 'CHANGELOG.md'
- 'MANIFEST.in'
- 'readme.md'
- 'README.md'
- 'scripts/**'
workflow_dispatch:

Expand Down
28 changes: 17 additions & 11 deletions readme.md → README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,26 @@ ovos-stt-server --help
usage: ovos-stt-server [-h] [--engine ENGINE] [--port PORT] [--host HOST]

options:
-h, --help show this help message and exit
--engine ENGINE stt plugin to be used
--port PORT port number
--host HOST host
--lang LANG default language
--gradio flag to enable Gradio web UI
--cache flag to pre-cache examples in Gradio web UI
--title TITLE title for Gradio UI
--description DESCRIPTION Description for Gradio UI
--info INFO Text to display in Gradio UI
--badge BADGE URL of badge to show in Gradio UI
-h, --help show this help message and exit
--engine ENGINE stt plugin to be used
--lang-engine LANG_ENGINE
audio language detection plugin to be used
--port PORT port number
--host HOST host
--lang LANG default language supported by plugin
--multi Load a plugin instance per language (force lang
support)
--gradio Enable Gradio Web UI
--cache Cache models for Gradio demo
--title TITLE Title for webUI
--description DESCRIPTION
Text description to print in UI
--info INFO Text to display at end of UI
--badge BADGE URL of visitor badge
```
> Note: `ffmpeg` is required for Gradio

Example: `ovos-stt-server --engine ovos-stt-plugin-fasterwhisper --lang-engine ovos-audio-transformer-plugin-fasterwhisper`

## Docker

Expand Down
161 changes: 85 additions & 76 deletions ovos_stt_http_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,73 +14,75 @@

from fastapi import FastAPI
from ovos_plugin_manager.stt import load_stt_plugin
from ovos_plugin_manager.audio_transformers import load_audio_transformer_plugin, AudioLanguageDetector
from ovos_utils.log import LOG
from speech_recognition import Recognizer, AudioFile, AudioData
from starlette.requests import Request


LOG.set_level("ERROR") # avoid server side logs


class ModelContainer:
def __init__(self, plugin: str, config: dict=None):
self.plugin = load_stt_plugin(plugin)
if not self.plugin:
def __init__(self, plugin: str, lang_plugin: str = None, config: dict = None):
plugin = load_stt_plugin(plugin)
self.lang_plugin = None
if not plugin:
raise ValueError(f"Failed to load STT: {plugin}")
if lang_plugin:
lang_plugin = load_audio_transformer_plugin(lang_plugin)
if not lang_plugin:
raise ValueError(f"Failed to load lang detection plugin: {plugin}")
assert issubclass(lang_plugin, AudioLanguageDetector)
self.lang_plugin = lang_plugin()
self.engine = plugin(config)

def process_audio(self, audio: AudioData, lang):
if lang == "auto":
lang, prob = self.lang_plugin.detect(audio)
if audio or self.engine.can_stream:
return self.engine.execute(audio, language=lang) or ""
return ""


class MultiModelContainer:
""" loads 1 model per language """
def __init__(self, plugin: str, lang_plugin: str = None, config: dict = None):
self.plugin_class = load_stt_plugin(plugin)
self.lang_plugin = None
if not self.plugin_class:
raise ValueError(f"Failed to load STT: {plugin}")
if lang_plugin:
lang_plugin = load_audio_transformer_plugin(lang_plugin)
if not lang_plugin:
raise ValueError(f"Failed to load lang detection plugin: {plugin}")
assert issubclass(lang_plugin, AudioLanguageDetector)
self.lang_plugin = lang_plugin()
self.engines = {}
self.data = {}
self.config = config or {}

def get_engine(self, session_id):
if session_id not in self.engines:
self.load_engine(session_id)
return self.engines[session_id]
def get_engine(self, lang):
if lang not in self.engines:
self.load_engine(lang)
return self.engines[lang]

def load_engine(self, session_id, config=None):
def load_engine(self, lang, config=None):
# might need to load multiple models per language
config = config or self.config
self.engines[session_id] = self.plugin(config=config)
config["lang"] = lang
self.engines[lang] = self.plugin_class(config=config)

def unload_engine(self, session_id):
if session_id in self.engines:
self.engines.pop(session_id)
if session_id in self.data:
self.data.pop(session_id)
def unload_engine(self, lang):
if lang in self.engines:
self.engines.pop(lang)

def process_audio(self, audio: AudioData, lang, session_id=None):
session_id = session_id or lang # shared model for non-streaming stt
engine = self.get_engine(session_id)
def process_audio(self, audio: AudioData, lang):
if lang == "auto":
lang, prob = self.lang_plugin.detect(audio)
engine = self.get_engine(lang)
if audio or engine.can_stream:
return engine.execute(audio, language=lang) or ""
return ""

def stream_start(self, session_id):
engine = self.get_engine(session_id)
if engine.can_stream:
engine.stream_start()

def stream_data(self, audio, session_id):
engine = self.get_engine(session_id)
if engine.can_stream:
# streaming plugin in server + streaming plugin in core
return engine.stream_data(audio)
else:
# non streaming plugin in server + streaming plugin in core
if session_id not in self.data:
self.data[session_id] = b""
self.data[session_id] += audio
return ""

def stream_stop(self, session_id):
engine = self.get_engine(session_id)
if engine.can_stream:
transcript = engine.stream_stop()
else:
audio = AudioData(self.data[session_id],
sample_rate=16000, sample_width=2)
transcript = engine.execute(audio)
self.unload_engine(session_id)
return transcript or ""


def bytes2audiodata(data):
recognizer = Recognizer()
Expand All @@ -91,14 +93,18 @@ def bytes2audiodata(data):
return audio


def create_app(stt_plugin, has_gradio=False):
def create_app(stt_plugin, lang_plugin=None, multi=False, has_gradio=False):
app = FastAPI()
model = ModelContainer(stt_plugin)
if multi:
model = MultiModelContainer(stt_plugin, lang_plugin)
else:
model = ModelContainer(stt_plugin, lang_plugin)

@app.get("/status")
def stats(request: Request):
return {"status": "ok",
"plugin": stt_plugin,
"lang_plugin": lang_plugin,
"gradio": has_gradio}

@app.post("/stt")
Expand All @@ -108,35 +114,38 @@ async def get_stt(request: Request):
audio = bytes2audiodata(audio_bytes)
return model.process_audio(audio, lang)

@app.post("/stream/start")
def stream_start(request: Request):
lang = str(request.query_params.get("lang", "en-us")).lower()
uuid = str(request.query_params.get("uuid") or lang)
model.load_engine(uuid, {"lang": lang})
model.stream_start(uuid)
return {"status": "ok", "uuid": uuid, "lang": lang}

@app.post("/stream/audio")
async def stream(request: Request):
audio = await request.body()
lang = str(request.query_params.get("lang", "en-us")).lower()
uuid = str(request.query_params.get("uuid") or lang)
transcript = model.stream_data(audio, uuid)
return {"status": "ok", "uuid": uuid,
"lang": lang, "transcript": transcript}

@app.post("/stream/end")
def stream_end(request: Request):
lang = str(request.query_params.get("lang", "en-us")).lower()
uuid = str(request.query_params.get("uuid") or lang)
# model.wait_until_done(uuid)
transcript = model.stream_stop(uuid)
return {"status": "ok", "uuid": uuid,
"lang": lang, "transcript": transcript}
@app.post("/lang_detect")
async def get_lang(request: Request):
audio_bytes = await request.body()
audio = bytes2audiodata(audio_bytes)
lang, prob = model.lang_plugin.detect(audio)
return {"lang": lang, "conf": prob}

return app, model


def start_stt_server(engine: str, has_gradio: bool = False) -> (FastAPI, ModelContainer):
app, engine = create_app(engine, has_gradio)
def start_stt_server(engine: str,
lang_engine: str = None,
multi: bool = False,
has_gradio: bool = False) -> (FastAPI, ModelContainer):
app, engine = create_app(engine, lang_engine, multi, has_gradio)
return app, engine


if __name__ == "__main__":
model = ModelContainer("ovos-stt-plugin-fasterwhisper",
"ovos-audio-transformer-plugin-fasterwhisper")

from speech_recognition import Recognizer, AudioFile

jfk = "/home/miro/PycharmProjects/ovos-stt-plugin-fasterwhisper/jfk.wav"
with AudioFile(jfk) as source:
audio = Recognizer().record(source)

a = model.process_audio(audio, lang="en")
print(a)
# And so, my fellow Americans, ask not what your country can do for you. Ask what you can do for your country.

lang, prob = model.lang_plugin.detect(audio.get_wav_data())
print(lang, prob)
# en 1.0
7 changes: 6 additions & 1 deletion ovos_stt_http_server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--engine", help="stt plugin to be used", required=True)
parser.add_argument("--lang-engine", help="audio language detection plugin to be used", required=True)
parser.add_argument("--port", help="port number", default=8080)
parser.add_argument("--host", help="host", default="0.0.0.0")
parser.add_argument("--lang", help="default language supported by plugin",
default="en-us")
parser.add_argument("--multi", help="Load a plugin instance per language (force lang support)",
action="store_true")
parser.add_argument("--gradio", help="Enable Gradio Web UI",
action="store_true")
parser.add_argument("--cache", help="Cache models for Gradio demo",
Expand All @@ -38,7 +41,9 @@ def main():
parser.add_argument("--badge", help="URL of visitor badge", default=None)
args = parser.parse_args()

server, engine = start_stt_server(args.engine, has_gradio=bool(args.gradio))
server, engine = start_stt_server(args.engine, lang_engine=args.lang_engine,
multi=bool(args.multi),
has_gradio=bool(args.gradio))
LOG.info("Server Started")
if args.gradio:
bind_gradio_service(server, engine, args.title, args.description,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def required(requirements_file):
if pkg.strip() and not pkg.startswith("#")]


with open(path.join(BASE_PATH, "readme.md"), "r") as f:
with open(path.join(BASE_PATH, "README.md"), "r") as f:
long_description = f.read()


Expand Down

0 comments on commit ecbf85d

Please sign in to comment.