diff --git a/.gitignore b/.gitignore index 18c7986..65fa59f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,239 @@ scraibe/app/*__pycache__ scraibe/.pyannotetoken +# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,linux,windows +# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,linux,windows + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,linux,windows diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 2664e3f..7d54ba8 
def update_transcriber(self, whisper_model : Union[str, Transcriber], **kwargs) -> None:
    """
    Replace the transcriber used by this instance.

    Args:
        whisper_model (Union[str, Transcriber]):
            Either a model name, which is loaded through
            ``Transcriber.load_model``, or an already constructed
            ``Transcriber`` instance, which is used as is.
        **kwargs:
            Forwarded to ``Transcriber.load_model`` when a model name is
            given. Ignored when a ``Transcriber`` instance is given.

    Returns:
        None

    Warns:
        RuntimeWarning: If ``whisper_model`` is neither a ``str`` nor a
            ``Transcriber``. The current transcriber is kept unchanged.
    """
    if isinstance(whisper_model, str):
        self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
    elif isinstance(whisper_model, Transcriber):
        self.transcriber = whisper_model
    else:
        # Look the old name up only here: it is needed for the message alone,
        # and a transcriber without a ``model_name`` attribute must not make
        # a valid update fail before it even starts.
        _old_model = getattr(self.transcriber, "model_name", "unknown")
        # stacklevel=2 attributes the warning to the caller's line.
        warn("Invalid model type. Please provide a valid model. "
             f"Fallback to old {_old_model} Model.",
             RuntimeWarning, stacklevel=2)


def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None:
    """
    Replace the diariser used by this instance.

    Args:
        dia_model (Union[str, DiarisationType]):
            Either a model identifier, which is loaded through
            ``Diariser.load_model``, or an already constructed
            ``Diariser`` instance, which is used as is.
        **kwargs:
            Forwarded to ``Diariser.load_model`` when a model identifier is
            given. Ignored when a ``Diariser`` instance is given.

    Returns:
        None

    Warns:
        RuntimeWarning: If ``dia_model`` is neither a ``str`` nor a
            ``Diariser``. The current diariser is kept unchanged.
    """
    if isinstance(dia_model, str):
        self.diariser = Diariser.load_model(dia_model, **kwargs)
    elif isinstance(dia_model, Diariser):
        self.diariser = dia_model
    else:
        # Plain string: the message has no placeholders, so no f-prefix.
        # stacklevel=2 attributes the warning to the caller's line.
        warn("Invalid model type. Please provide a valid model. "
             "Fallback to old Model.",
             RuntimeWarning, stacklevel=2)
""" + self.model = model + + self.model_name = model_name def transcribe(self, audio : Union[str, Tensor, ndarray] , *args, **kwargs) -> str: @@ -137,6 +141,7 @@ def load_model(cls, - 'medium' - 'large-v1' - 'large-v2' + - 'large-v3' - 'large' download_root (str, optional): Path to download the model. @@ -156,7 +161,7 @@ def load_model(cls, _model = load_model(model, download_root=download_root, device=device, in_memory=in_memory) - return cls(_model) + return cls(_model, model_name=model) @staticmethod def _get_whisper_kwargs(**kwargs) -> dict: @@ -179,4 +184,4 @@ def _get_whisper_kwargs(**kwargs) -> dict: return whisper_kwargs def __repr__(self) -> str: - return f"Transcriber(model={self.model})" \ No newline at end of file + return f"Transcriber(model_name={self.model_name}, model={self.model})" \ No newline at end of file