From 546959f2d8627cb78189b5e1d5c8e6afd9f5369b Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Mon, 4 Nov 2024 18:35:10 +0000 Subject: [PATCH 1/2] feat(models): add output-dir option to model_download --- src/kagglehub/models.py | 29 ++++++++++++++++++++++++++--- tests/test_http_model_download.py | 24 +++++++++++++++++++++++- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 19bca04..f63ae52 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -1,4 +1,5 @@ import logging +from shutil import copytree from typing import Optional, Union from kagglehub import registry @@ -13,22 +14,44 @@ DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"] -def model_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: +def model_download( + handle: str, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + output_dir: Optional[str] = None) -> str: """Download model files. Args: handle: (string) the model handle. path: (string) Optional path to a file within the model bundle. force_download: (bool) Optional flag to force download a model, even if it's cached. - + output_dir: (str) Optional path to copy model files to after successful download. Returns: A string representing the path to the requested model files. """ h = parse_model_handle(handle) logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.model_resolver(h, path, force_download=force_download) + cached_dir = registry.model_resolver(h, path, force_download=force_download) + + if output_dir is None: + return cached_dir + try: + # only copying so that we can maintain the cached files + logger.info( + f"Copying model files to requested directory: {output_dir} ...", + extra={**EXTRA_CONSOLE_BLOCK} + ) + true_output_dir = copytree(cached_dir, output_dir, dirs_exist_ok=True) + return true_output_dir + except Exception as e: + logger.warn( + f"Successfully downloaded {handle}, but failed to copy from {cached_dir} " + f"to requested output directory {output_dir}. Encountered error: {e}" + ) + return cached_dir def model_upload( handle: str, diff --git a/tests/test_http_model_download.py b/tests/test_http_model_download.py index 9ba6c39..ca6cd18 100644 --- a/tests/test_http_model_download.py +++ b/tests/test_http_model_download.py @@ -7,6 +7,7 @@ from kagglehub.cache import MODELS_CACHE_SUBFOLDER, get_cached_archive_path from kagglehub.handle import parse_model_handle from tests.fixtures import BaseTestCase +from unittest import mock from .server_stubs import model_download_stub as stub from .server_stubs import serv @@ -147,6 +148,28 @@ def test_versioned_model_download_with_path_with_force_download(self) -> None: with create_test_cache() as d: self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, force_download=True) + def test_versioned_model_download_with_output_dir(self) -> None: + with create_test_cache() as d: + expected_ouput_dir = "/tmp/downloaded_model" + self._download_model_and_assert_downloaded( + d, + VERSIONED_MODEL_HANDLE, + expected_ouput_dir, + output_dir=expected_ouput_dir + ) + + def test_versioned_model_download_with_bad_output_dir(self) -> None: + with create_test_cache() as d: + mock.patch("kagglehub.models.copytree", side_effect=Exception()) + bad_output_dir = "/bad/path/that/fails" + expected_output_dir = EXPECTED_MODEL_SUBDIR # falls back to default + self._download_model_and_assert_downloaded( + d, + VERSIONED_MODEL_HANDLE, + expected_output_dir, + output_dir=bad_output_dir + ) + def test_unversioned_model_download_with_path_with_force_download(self) -> None: with create_test_cache() as d: self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, force_download=True) @@ -188,7 +211,6 @@ def test_versioned_model_download_with_path_already_cached_with_force_download_e self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path) - class TestHttpNoInternet(BaseTestCase): def test_versioned_model_download_already_cached_with_force_download(self) -> None: with create_test_cache(): From 0997d406b84c90d0ebad80ac10d44c99358cfc6f Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Mon, 4 Nov 2024 21:45:57 +0000 Subject: [PATCH 2/2] fix(pr): lint + extra equality checks --- src/kagglehub/models.py | 4 ++-- tests/test_http_model_download.py | 29 +++++++++++++++++------------ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index f63ae52..a3a0aec 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -26,7 +26,7 @@ def model_download( handle: (string) the model handle. path: (string) Optional path to a file within the model bundle. force_download: (bool) Optional flag to force download a model, even if it's cached. - output_dir: (str) Optional path to copy model files to after successful download. + output_dir: (string) Optional path to copy model files to after successful download. Returns: A string representing the path to the requested model files. @@ -35,7 +35,7 @@ def model_download( logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) cached_dir = registry.model_resolver(h, path, force_download=force_download) - if output_dir is None: + if output_dir is None or output_dir == cached_dir: return cached_dir try: diff --git a/tests/test_http_model_download.py b/tests/test_http_model_download.py index ca6cd18..e2f55d9 100644 --- a/tests/test_http_model_download.py +++ b/tests/test_http_model_download.py @@ -1,5 +1,7 @@ import os +from tempfile import TemporaryDirectory from typing import Optional +from unittest import mock import requests @@ -7,7 +9,6 @@ from kagglehub.cache import MODELS_CACHE_SUBFOLDER, get_cached_archive_path from kagglehub.handle import parse_model_handle from tests.fixtures import BaseTestCase -from unittest import mock from .server_stubs import model_download_stub as stub from .server_stubs import serv @@ -150,24 +151,28 @@ def test_versioned_model_download_with_path_with_force_download(self) -> None: def test_versioned_model_download_with_output_dir(self) -> None: with create_test_cache() as d: - expected_ouput_dir = "/tmp/downloaded_model" - self._download_model_and_assert_downloaded( - d, - VERSIONED_MODEL_HANDLE, - expected_ouput_dir, - output_dir=expected_ouput_dir - ) + with TemporaryDirectory() as expected_output_dir: + self._download_model_and_assert_downloaded( + d, + VERSIONED_MODEL_HANDLE, + expected_output_dir, + output_dir=expected_output_dir + ) def test_versioned_model_download_with_bad_output_dir(self) -> None: - with create_test_cache() as d: - mock.patch("kagglehub.models.copytree", side_effect=Exception()) - bad_output_dir = "/bad/path/that/fails" + with ( + create_test_cache() as d, + TemporaryDirectory() as placeholder_dir, + mock.patch("kagglehub.models.copytree") as mock_copytree + ): + mock_copytree.side_effect = Exception("Mock exception") expected_output_dir = EXPECTED_MODEL_SUBDIR # falls back to default self._download_model_and_assert_downloaded( d, VERSIONED_MODEL_HANDLE, expected_output_dir, - output_dir=bad_output_dir + # note: placeholder name is irrelevant since copytree is mocked to throw + output_dir=placeholder_dir ) def test_unversioned_model_download_with_path_with_force_download(self) -> None: