diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py
index 0f0e14a..ade9220 100644
--- a/scraibe/diarisation.py
+++ b/scraibe/diarisation.py
@@ -38,6 +38,8 @@
 from torch import Tensor
 from torch import device as torch_device
 from torch.cuda import is_available, current_device
+from huggingface_hub import HfApi
+from huggingface_hub.utils import RepositoryNotFoundError
 from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
 
 Annotation = TypeVar('Annotation')
@@ -183,9 +185,9 @@ def _save_token(token):
 
     @classmethod
     def load_model(cls,
-                   model: str = PYANNOTE_DEFAULT_CONFIG,
+                   model: str = PYANNOTE_DEFAULT_CONFIG,
                    use_auth_token: str = None,
-                   cache_token: bool = True,
+                   cache_token: bool = False,
                    cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
                    hparams_file: Union[str, Path] = None,
                    device: str = None,
@@ -194,11 +196,12 @@ def load_model(cls,
 
         """
         Loads a pretrained model from pyannote.audio,
-        either from a local cache or online repository.
+        either from a local cache or some online repository.
 
         Args:
             model: Path or identifier for the pyannote model.
-                default: /models/pyannote/speaker_diarization/config.yaml
+                default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
+                or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
             token: Optional HUGGINGFACE_TOKEN for authenticated access.
             cache_token: Whether to cache the token locally for future use.
             cache_dir: Directory for caching models.
@@ -210,33 +213,29 @@ def load_model(cls,
         Returns:
             Pipeline: A pyannote.audio Pipeline object,
             encapsulating the loaded model.
         """
-
-
-        if cache_token and use_auth_token is not None:
-            cls._save_token(use_auth_token)
-
-        if not os.path.exists(model) and use_auth_token is None:
-            use_auth_token = cls._get_token()
-
-        elif os.path.exists(model) and not use_auth_token:
+        if isinstance(model, str) and os.path.exists(model):
             # check if model can be found locally nearby the config file
             with open(model, 'r') as file:
                 config = yaml.safe_load(file)
-
+
             path_to_model = config['pipeline']['params']['segmentation']
 
             if not os.path.exists(path_to_model):
-                warnings.warn(f"Model not found at {path_to_model}. "\
-                    "Trying to find it nearby the config file.")
-
+                warnings.warn(f"Model not found at {path_to_model}. "
+                              "Trying to find it nearby the config file.")
+
                 pwd = model.split("/")[:-1]
                 pwd = "/".join(pwd)
-
+
                 path_to_model = os.path.join(pwd, "pytorch_model.bin")
                 if not os.path.exists(path_to_model):
                     warnings.warn(f"Model not found at {path_to_model}. \
                         'Trying to find it nearby .bin files instead.")
+                    warnings.warn(
+                        'Searching for nearby files in a folder path is '
+                        'deprecated and will be removed in future versions.',
+                        category=DeprecationWarning)
                     # list elementes with the ending .bin
                     bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
                     if len(bin_files) == 1:
@@ -245,30 +244,49 @@ def load_model(cls,
                         warnings.warn("Found more than one .bin file. "\
                             "or none. Please specify the path to the model " \
                             "or setup a huggingface token.")
-
+                        raise FileNotFoundError
+
                 warnings.warn(f"Found model at {path_to_model} overwriting config file.")
-
+
                 config['pipeline']['params']['segmentation'] = path_to_model
-
+
                 with open(model, 'w') as file:
                     yaml.dump(config, file)
-
-        _model = Pipeline.from_pretrained(model,
-                                          use_auth_token = use_auth_token,
-                                          cache_dir = cache_dir,
-                                          hparams_file = hparams_file,)
-
-        # try to move the model to the device
-        if device is None:
-            device = "cuda" if is_available() else "cpu"
-
-        _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
-
+        elif isinstance(model, tuple):
+            try:
+                _model = model[0]
+                HfApi().model_info(_model)
+                model = _model
+                use_auth_token = None
+            except RepositoryNotFoundError:
+                print(f'{model[0]} not found on Huggingface, \
+                    trying {model[1]}')
+                _model = model[1]
+                HfApi().model_info(_model)
+                model = _model
+                if cache_token and use_auth_token is not None:
+                    cls._save_token(use_auth_token)
+
+                if use_auth_token is None:
+                    use_auth_token = cls._get_token()
+        else:
+            raise FileNotFoundError(f'No local model or directory found at {model}.')
+
+        _model = Pipeline.from_pretrained(model,
+                                          use_auth_token=use_auth_token,
+                                          cache_dir=cache_dir,
+                                          hparams_file=hparams_file,)
+
         if _model is None:
             raise ValueError('Unable to load model either from local cache' \
                 'or from huggingface.co models. Please check your token' \
                 'or your local model path')
-
+
+        # try to move the model to the device
+        if device is None:
+            device = "cuda" if is_available() else "cpu"
+
+        _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
+
         return cls(_model)
 
     @staticmethod
diff --git a/scraibe/misc.py b/scraibe/misc.py
index 992e40c..c1d5484 100644
--- a/scraibe/misc.py
+++ b/scraibe/misc.py
@@ -15,7 +15,7 @@
 PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
 PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
     if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
-    else 'pyannote/speaker-diarization-3.1'
+    else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
 
 def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
     """Configure diarization pipeline from a YAML file.
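
For reviewers, a minimal sketch of the repo fallback this patch introduces (assumptions: only huggingface_hub is required, the repo ids are the pair added to misc.py, and resolve_repo is a hypothetical helper, not part of scraibe):

    from huggingface_hub import HfApi
    from huggingface_hub.utils import RepositoryNotFoundError

    def resolve_repo(candidates: tuple) -> str:
        # Mirror load_model above: probe the first repo id on the Hugging Face
        # Hub and fall back to the second when the first cannot be found.
        api = HfApi()
        try:
            api.model_info(candidates[0])
            return candidates[0]
        except RepositoryNotFoundError:
            api.model_info(candidates[1])  # let errors for the fallback propagate
            return candidates[1]

    # e.g. resolve_repo(('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'))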