From f7927fd524bd6a6d7527d18dd1ac5013c0412f01 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Fri, 19 Apr 2024 17:36:34 +0200 Subject: [PATCH 1/4] Add default path to pyannote model with fallback option. --- scraibe/diarisation.py | 100 +++++++++++++++++++++++------------------ scraibe/misc.py | 3 +- 2 files changed, 58 insertions(+), 45 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 0f0e14a..161dae5 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -19,6 +19,7 @@ - TOKEN_PATH (str): Path to the Pyannote token. - PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. - PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. +- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work. Usage: from .diarisation import Diariser @@ -39,7 +40,7 @@ from torch import device as torch_device from torch.cuda import is_available, current_device -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -183,7 +184,7 @@ def _save_token(token): @classmethod def load_model(cls, - model: str = PYANNOTE_DEFAULT_CONFIG, + model: str = PYANNOTE_FALLBACK_CONFIG, use_auth_token: str = None, cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, @@ -210,53 +211,64 @@ def load_model(cls, Returns: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ + try: + hf_model = PYANNOTE_DEFAULT_CONFIG + # if not use_auth_token: + # use_auth_token = cls._get_token() + _model = Pipeline.from_pretrained( + hf_model, use_auth_token=use_auth_token, + cache_dir=cache_dir, hparams_file=hparams_file + ) + except: + print(f'Trying fallback to config file.. ') + _model = None + if _model is None: - - if cache_token and use_auth_token is not None: - cls._save_token(use_auth_token) - - if not os.path.exists(model) and use_auth_token is None: - use_auth_token = cls._get_token() - - elif os.path.exists(model) and not use_auth_token: - # check if model can be found locally nearby the config file - with open(model, 'r') as file: - config = yaml.safe_load(file) - - path_to_model = config['pipeline']['params']['segmentation'] - - if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. "\ - "Trying to find it nearby the config file.") + if cache_token and use_auth_token is not None: + cls._save_token(use_auth_token) + + if not os.path.exists(model) and use_auth_token is None: + use_auth_token = cls._get_token() - pwd = model.split("/")[:-1] - pwd = "/".join(pwd) + elif os.path.exists(model) and not use_auth_token: + # check if model can be found locally nearby the config file + with open(model, 'r') as file: + config = yaml.safe_load(file) - path_to_model = os.path.join(pwd, "pytorch_model.bin") + path_to_model = config['pipeline']['params']['segmentation'] if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. \ - 'Trying to find it nearby .bin files instead.") - # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] - if len(bin_files) == 1: - path_to_model = os.path.join(pwd, bin_files[0]) - else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") - - warnings.warn(f"Found model at {path_to_model} overwriting config file.") - - config['pipeline']['params']['segmentation'] = path_to_model - - with open(model, 'w') as file: - yaml.dump(config, file) - - _model = Pipeline.from_pretrained(model, - use_auth_token = use_auth_token, - cache_dir = cache_dir, - hparams_file = hparams_file,) + warnings.warn(f"Model not found at {path_to_model}. "\ + "Trying to find it nearby the config file.") + + pwd = model.split("/")[:-1] + pwd = "/".join(pwd) + + path_to_model = os.path.join(pwd, "pytorch_model.bin") + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) + + _model = Pipeline.from_pretrained(model, + use_auth_token = use_auth_token, + cache_dir = cache_dir, + hparams_file = hparams_file,) # try to move the model to the device if device is None: diff --git a/scraibe/misc.py b/scraibe/misc.py index 992e40c..549ee67 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -13,7 +13,8 @@ WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ +PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe' +PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ else 'pyannote/speaker-diarization-3.1' From 7d8da3b81c31e5c53bbf7049862638ec6452af2e Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 23 Apr 2024 14:39:18 +0200 Subject: [PATCH 2/4] Reworking the hf wrapper, now without blank except block (wow)! --- scraibe/diarisation.py | 120 +++++++++++++++++++++-------------------- scraibe/misc.py | 5 +- 2 files changed, 63 insertions(+), 62 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 161dae5..8523940 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -19,7 +19,6 @@ - TOKEN_PATH (str): Path to the Pyannote token. - PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. - PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. -- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work. Usage: from .diarisation import Diariser @@ -39,8 +38,10 @@ from torch import Tensor from torch import device as torch_device from torch.cuda import is_available, current_device +from huggingface_hub import HfApi +from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -184,7 +185,7 @@ def _save_token(token): @classmethod def load_model(cls, - model: str = PYANNOTE_FALLBACK_CONFIG, + model: str = PYANNOTE_DEFAULT_CONFIG, use_auth_token: str = None, cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, @@ -211,64 +212,65 @@ def load_model(cls, Returns: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ - try: - hf_model = PYANNOTE_DEFAULT_CONFIG - # if not use_auth_token: - # use_auth_token = cls._get_token() - _model = Pipeline.from_pretrained( - hf_model, use_auth_token=use_auth_token, - cache_dir=cache_dir, hparams_file=hparams_file - ) - except: - print(f'Trying fallback to config file.. ') - _model = None - if _model is None: - - if cache_token and use_auth_token is not None: - cls._save_token(use_auth_token) - - if not os.path.exists(model) and use_auth_token is None: - use_auth_token = cls._get_token() - - elif os.path.exists(model) and not use_auth_token: - # check if model can be found locally nearby the config file - with open(model, 'r') as file: - config = yaml.safe_load(file) - - path_to_model = config['pipeline']['params']['segmentation'] + if isinstance(model, str) and os.path.exists(model): + # check if model can be found locally nearby the config file + with open(model, 'r') as file: + config = yaml.safe_load(file) + + path_to_model = config['pipeline']['params']['segmentation'] + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. " + "Trying to find it nearby the config file.") + + pwd = model.split("/")[:-1] + pwd = "/".join(pwd) + + path_to_model = os.path.join(pwd, "pytorch_model.bin") if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. "\ - "Trying to find it nearby the config file.") - - pwd = model.split("/")[:-1] - pwd = "/".join(pwd) - - path_to_model = os.path.join(pwd, "pytorch_model.bin") - - if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. \ - 'Trying to find it nearby .bin files instead.") - # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] - if len(bin_files) == 1: - path_to_model = os.path.join(pwd, bin_files[0]) - else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") - - warnings.warn(f"Found model at {path_to_model} overwriting config file.") - - config['pipeline']['params']['segmentation'] = path_to_model - - with open(model, 'w') as file: - yaml.dump(config, file) - - _model = Pipeline.from_pretrained(model, - use_auth_token = use_auth_token, - cache_dir = cache_dir, - hparams_file = hparams_file,) + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + raise FileNotFoundError + + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) + elif isinstance(model, tuple): + try: + _model = model[0] + HfApi().model_info(_model) + model = _model + use_auth_token = None + except RepositoryNotFoundError: + print(f'{model[0]} not found on Huggingface, \ + trying {model[1]}') + _model = model[1] + HfApi().model_info(_model) + model = _model + if cache_token and use_auth_token is not None: + cls._save_token(use_auth_token) + + if not os.path.exists(model) and use_auth_token is None: + use_auth_token = cls._get_token() + else: + raise FileNotFoundError(f'No local model or directory found at {model}.') + + _model = Pipeline.from_pretrained(model, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + hparams_file=hparams_file,) # try to move the model to the device if device is None: diff --git a/scraibe/misc.py b/scraibe/misc.py index 549ee67..c1d5484 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -13,10 +13,9 @@ WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe' -PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ +PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else 'pyannote/speaker-diarization-3.1' + else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. From 55a77b861cf40e5354e719e511ebe0e6977fe212 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 23 Apr 2024 16:29:48 +0200 Subject: [PATCH 3/4] Fixed cache default value, moved ValuError t othe right place, added to docstring. --- scraibe/diarisation.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 8523940..1643db2 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -187,7 +187,7 @@ def _save_token(token): def load_model(cls, model: str = PYANNOTE_DEFAULT_CONFIG, use_auth_token: str = None, - cache_token: bool = True, + cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, device: str = None, @@ -196,11 +196,12 @@ def load_model(cls, """ Loads a pretrained model from pyannote.audio, - either from a local cache or online repository. + either from a local cache or some online repository. Args: model: Path or identifier for the pyannote model. - default: /models/pyannote/speaker_diarization/config.yaml + default: '/home/[user]/.cache/torch/models/pyannote/config.yaml' + or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1' token: Optional HUGGINGFACE_TOKEN for authenticated access. cache_token: Whether to cache the token locally for future use. cache_dir: Directory for caching models. @@ -261,8 +262,8 @@ def load_model(cls, model = _model if cache_token and use_auth_token is not None: cls._save_token(use_auth_token) - - if not os.path.exists(model) and use_auth_token is None: + + if use_auth_token is None: use_auth_token = cls._get_token() else: raise FileNotFoundError(f'No local model or directory found at {model}.') @@ -271,18 +272,17 @@ def load_model(cls, use_auth_token=use_auth_token, cache_dir=cache_dir, hparams_file=hparams_file,) - - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - - _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict - if _model is None: raise ValueError('Unable to load model either from local cache' \ 'or from huggingface.co models. Please check your token' \ 'or your local model path') - + + # try to move the model to the device + if device is None: + device = "cuda" if is_available() else "cpu" + + _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + return cls(_model) @staticmethod From 69b4e22b51a9f642566ff4a4a65819bb00f1c0f6 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Mon, 29 Apr 2024 13:38:55 +0200 Subject: [PATCH 4/4] Add deprecation warning. --- scraibe/diarisation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 1643db2..ade9220 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -232,6 +232,10 @@ def load_model(cls, if not os.path.exists(path_to_model): warnings.warn(f"Model not found at {path_to_model}. \ 'Trying to find it nearby .bin files instead.") + warnings.warn( + 'Searching for nearby files in a folder path is ' + 'deprecated and will be removed in future versions.', + category=DeprecationWarning) # list elementes with the ending .bin bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] if len(bin_files) == 1: