Skip to content

Commit

Permalink
Merge pull request #71 from JSchmie/develop_hf_wrapper
Browse files Browse the repository at this point in the history
Add default path to pyannote model with fallback option.
  • Loading branch information
JSchmie authored Apr 29, 2024
2 parents b075271 + 69b4e22 commit 37e28c5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 35 deletions.
86 changes: 52 additions & 34 deletions scraibe/diarisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available, current_device
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
Expand Down Expand Up @@ -183,9 +185,9 @@ def _save_token(token):

@classmethod
def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG,
model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None,
cache_token: bool = True,
cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None,
device: str = None,
Expand All @@ -194,11 +196,12 @@ def load_model(cls,

"""
Loads a pretrained model from pyannote.audio,
either from a local cache or online repository.
either from a local cache or some online repository.
Args:
model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml
default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
Expand All @@ -210,33 +213,29 @@ def load_model(cls,
Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""


if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)

if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()

elif os.path.exists(model) and not use_auth_token:
if isinstance(model, str) and os.path.exists(model):
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)

path_to_model = config['pipeline']['params']['segmentation']

if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
warnings.warn(f"Model not found at {path_to_model}. "
"Trying to find it nearby the config file.")

pwd = model.split("/")[:-1]
pwd = "/".join(pwd)

path_to_model = os.path.join(pwd, "pytorch_model.bin")

if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
warnings.warn(
'Searching for nearby files in a folder path is '
'deprecated and will be removed in future versions.',
category=DeprecationWarning)
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
Expand All @@ -245,30 +244,49 @@ def load_model(cls,
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")

raise FileNotFoundError

warnings.warn(f"Found model at {path_to_model} overwriting config file.")

config['pipeline']['params']['segmentation'] = path_to_model

with open(model, 'w') as file:
yaml.dump(config, file)

_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
hparams_file = hparams_file,)

# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"

_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict

elif isinstance(model, tuple):
try:
_model = model[0]
HfApi().model_info(_model)
model = _model
use_auth_token = None
except RepositoryNotFoundError:
print(f'{model[0]} not found on Huggingface, \
trying {model[1]}')
_model = model[1]
HfApi().model_info(_model)
model = _model
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)

if use_auth_token is None:
use_auth_token = cls._get_token()
else:
raise FileNotFoundError(f'No local model or directory found at {model}.')

_model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
hparams_file=hparams_file,)
if _model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \
'or your local model path')


# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"

_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict

return cls(_model)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion scraibe/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else 'pyannote/speaker-diarization-3.1'
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')

def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
Expand Down

0 comments on commit 37e28c5

Please sign in to comment.