diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 8523940..1643db2 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -187,7 +187,7 @@ def _save_token(token): def load_model(cls, model: str = PYANNOTE_DEFAULT_CONFIG, use_auth_token: str = None, - cache_token: bool = True, + cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, device: str = None, @@ -196,11 +196,12 @@ def load_model(cls, """ Loads a pretrained model from pyannote.audio, - either from a local cache or online repository. + either from a local cache or some online repository. Args: model: Path or identifier for the pyannote model. - default: /models/pyannote/speaker_diarization/config.yaml + default: '/home/[user]/.cache/torch/models/pyannote/config.yaml' + or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1' token: Optional HUGGINGFACE_TOKEN for authenticated access. cache_token: Whether to cache the token locally for future use. cache_dir: Directory for caching models. @@ -261,8 +262,8 @@ def load_model(cls, model = _model if cache_token and use_auth_token is not None: cls._save_token(use_auth_token) - - if not os.path.exists(model) and use_auth_token is None: + + if use_auth_token is None: use_auth_token = cls._get_token() else: raise FileNotFoundError(f'No local model or directory found at {model}.') @@ -271,18 +272,17 @@ def load_model(cls, use_auth_token=use_auth_token, cache_dir=cache_dir, hparams_file=hparams_file,) - - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - - _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict - if _model is None: raise ValueError('Unable to load model either from local cache' \ 'or from huggingface.co models. Please check your token' \ 'or your local model path') - + + # try to move the model to the device + if device is None: + device = "cuda" if is_available() else "cpu" + + _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + return cls(_model) @staticmethod