Skip to content

Commit

Permalink
Merge pull request #25 from DeepLabCut/minor_fixes
Browse files Browse the repository at this point in the history
Refactor and other formatting improvements
  • Loading branch information
AlexEMG authored Dec 18, 2023
2 parents 15c13ef + 209b461 commit efcd6ad
Showing 1 changed file with 50 additions and 41 deletions.
91 changes: 50 additions & 41 deletions dlclibrary/dlcmodelzoo/modelzoo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

import json
import os
import tarfile
from pathlib import Path

from huggingface_hub import hf_hub_download
from ruamel.yaml.comments import CommentedBase

# just expand this list when adding new models:
MODELOPTIONS = [
Expand Down Expand Up @@ -52,34 +57,54 @@ def parse_available_supermodels():
return json.load(file)


def _handle_downloaded_file(
file_path: str, target_dir: str, rename_mapping: dict | None = None
):
"""Handle the downloaded file from HuggingFace"""
file_name = os.path.basename(file_path)
try:
with tarfile.open(file_path, mode="r:gz") as tar:
for member in tar:
if not member.isdir():
fname = Path(member.name).name
tar.makefile(member, os.path.join(target_dir, fname))
except tarfile.ReadError: # The model is a .pt file
if rename_mapping is not None:
file_name = rename_mapping.get(file_name, file_name)
if os.path.islink(file_path):
file_path_ = os.readlink(file_path)
if not os.path.isabs(file_path_):
file_path_ = os.path.abspath(
os.path.join(os.path.dirname(file_path), file_path_)
)
file_path = file_path_
os.rename(file_path, os.path.join(target_dir, file_name))


def download_huggingface_model(
modelname, target_dir=".", remove_hf_folder=True, rename_mapping: dict | None = None
model_name: str,
target_dir: str = ".",
remove_hf_folder: bool = True,
rename_mapping: dict | None = None,
):
"""
Download a DeepLabCut Model Zoo Project from Hugging Face
Parameters
----------
modelname : string
Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
target_dir : directory (as string)
Directory where to store the model weights and pose_cfg.yaml file
remove_hf_folder : bool, default True
Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
rename_mapping : dict, default None
Dictionary to rename the downloaded file. If None, the original filename is used.
Downloads a DeepLabCut Model Zoo Project from Hugging Face.
Args:
model_name (str): Name of the ModelZoo model.
For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
after downloading and decompressing the data into DeepLabCut format. Defaults to True.
rename_mapping (dict, optional): A dictionary to rename the downloaded file.
If None, the original filename is used. Defaults to None.
"""
from huggingface_hub import hf_hub_download
import tarfile
from pathlib import Path
from ruamel.yaml.comments import CommentedBase

neturls = _load_model_names()
if modelname not in neturls:
raise ValueError(f"`modelname` should be one of: {', '.join(modelname)}.")
net_urls = _load_model_names()
if model_name not in net_urls:
raise ValueError(f"`modelname` should be one of: {', '.join(net_urls)}.")

print("Loading....", modelname)
urls = neturls[modelname]
print("Loading....", model_name)
urls = net_urls[model_name]
if isinstance(urls, CommentedBase):
urls = list(urls)
else:
Expand All @@ -98,26 +123,10 @@ def download_huggingface_model(
hf_folder = f"models--{url[0]}--{url[1]}"
path_ = os.path.join(target_dir, hf_folder, "snapshots")
commit = os.listdir(path_)[0]
filename = os.path.join(path_, commit, targzfn)
try:
with tarfile.open(filename, mode="r:gz") as tar:
for member in tar:
if not member.isdir():
fname = Path(member.name).name
tar.makefile(member, os.path.join(target_dir, fname))
except tarfile.ReadError: # The model is a .pt file
if rename_mapping is not None:
targzfn = rename_mapping.get(targzfn, targzfn)
if os.path.islink(filename):
filename_ = os.readlink(filename)
if not os.path.isabs(filename_):
filename_ = os.path.abspath(os.path.join(os.path.dirname(filename), filename_))
filename = filename_
os.rename(filename, os.path.join(target_dir, targzfn))
file_name = os.path.join(path_, commit, targzfn)
_handle_downloaded_file(file_name, target_dir, rename_mapping)

if remove_hf_folder:
import shutil

shutil.rmtree(os.path.join(target_dir, hf_folder))

'../../blobs/6c9c66d48f25cac9f8adaea7a485b07f4bd781ba656785bc4e077d9064e8e5df'

0 comments on commit efcd6ad

Please sign in to comment.