Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(models): add output_dir option to model_download #178

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/kagglehub/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from shutil import copytree
from typing import Optional, Union

from kagglehub import registry
Expand All @@ -13,22 +14,44 @@
DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"]


def model_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
def model_download(
handle: str,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
output_dir: Optional[str] = None) -> str:
"""Download model files.

Args:
handle: (string) the model handle.
path: (string) Optional path to a file within the model bundle.
force_download: (bool) Optional flag to force download a model, even if it's cached.

output_dir: (string) Optional path to copy model files to after successful download.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the user's expectation would be that files are downloaded directly to this output_dir and the cache folder is skipped entirely.


Returns:
A string representing the path to the requested model files.
"""
h = parse_model_handle(handle)
logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
return registry.model_resolver(h, path, force_download=force_download)
cached_dir = registry.model_resolver(h, path, force_download=force_download)

if output_dir is None or output_dir == cached_dir:
return cached_dir

try:
# only copying so that we can maintain the cached files
logger.info(
f"Copying model files to requested directory: {output_dir} ...",
extra={**EXTRA_CONSOLE_BLOCK}
)
true_output_dir = copytree(cached_dir, output_dir, dirs_exist_ok=True)
return true_output_dir
except Exception as e:
logger.warn(
f"Successfully downloaded {handle}, but failed to copy from {cached_dir} "
f"to requested output directory {output_dir}. Encountered error: {e}"
)
return cached_dir

def model_upload(
handle: str,
Expand Down
29 changes: 28 additions & 1 deletion tests/test_http_model_download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from tempfile import TemporaryDirectory
from typing import Optional
from unittest import mock

import requests

Expand Down Expand Up @@ -147,6 +149,32 @@ def test_versioned_model_download_with_path_with_force_download(self) -> None:
with create_test_cache() as d:
self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, force_download=True)

def test_versioned_model_download_with_output_dir(self) -> None:
with create_test_cache() as d:
with TemporaryDirectory() as expected_output_dir:
self._download_model_and_assert_downloaded(
d,
VERSIONED_MODEL_HANDLE,
expected_output_dir,
output_dir=expected_output_dir
)

def test_versioned_model_download_with_bad_output_dir(self) -> None:
with (
create_test_cache() as d,
TemporaryDirectory() as placeholder_dir,
mock.patch("kagglehub.models.copytree") as mock_copytree
):
mock_copytree.side_effect = Exception("Mock exception")
expected_output_dir = EXPECTED_MODEL_SUBDIR # falls back to default
self._download_model_and_assert_downloaded(
d,
VERSIONED_MODEL_HANDLE,
expected_output_dir,
# note: placeholder name is irrelevant since copytree is mocked to throw
output_dir=placeholder_dir
)

def test_unversioned_model_download_with_path_with_force_download(self) -> None:
with create_test_cache() as d:
self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, force_download=True)
Expand Down Expand Up @@ -188,7 +216,6 @@ def test_versioned_model_download_with_path_already_cached_with_force_download_e

self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path)


class TestHttpNoInternet(BaseTestCase):
def test_versioned_model_download_already_cached_with_force_download(self) -> None:
with create_test_cache():
Expand Down