diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py index fa44b9972..f1ccf2c8a 100644 --- a/deepface/commons/weight_utils.py +++ b/deepface/commons/weight_utils.py @@ -19,6 +19,8 @@ logger = Logger() +ALLOWED_COMPRESS_TYPES = ["zip", "bz2"] + def download_weights_if_necessary( file_name: str, source_url: str, compress_type: Optional[str] = None @@ -40,12 +42,15 @@ def download_weights_if_necessary( logger.debug(f"{file_name} is already available at {target_file}") return target_file + if compress_type is not None and compress_type not in ALLOWED_COMPRESS_TYPES: + raise ValueError(f"unimplemented compress type - {compress_type}") + try: logger.info(f"🔗 {file_name} will be downloaded from {source_url} to {target_file}...") if compress_type is None: gdown.download(source_url, target_file, quiet=False) - elif compress_type is not None: + elif compress_type is not None and compress_type in ALLOWED_COMPRESS_TYPES: gdown.download(source_url, f"{target_file}.{compress_type}", quiet=False) except Exception as err: diff --git a/tests/test_commons.py b/tests/test_commons.py index ded1754b1..01a2210ac 100644 --- a/tests/test_commons.py +++ b/tests/test_commons.py @@ -1,15 +1,19 @@ # built-in dependencies import os +from unittest import mock import pytest # project dependencies from deepface.commons import folder_utils, weight_utils, package_utils from deepface.commons.logger import Logger +# pylint: disable=unused-argument + logger = Logger() tf_version = package_utils.get_tf_major_version() +# conditional imports if tf_version == 1: from keras.models import Sequential from keras.layers import ( @@ -41,9 +45,208 @@ def test_loading_broken_weights(): # vgg's weights cannot be loaded to this model with pytest.raises( - ValueError, - match="An exception occurred while loading the pre-trained weights from" + ValueError, match="An exception occurred while loading the pre-trained weights from" ): model = weight_utils.load_model_weights(model=model, weight_file=weight_file) logger.info("✅ test loading broken weight file is done") + + +@mock.patch("deepface.commons.folder_utils.get_deepface_home") # Update with your actual module +@mock.patch("gdown.download") # Mocking gdown's download function +@mock.patch("os.path.isfile") # Mocking os.path.isfile +@mock.patch("os.makedirs") # Mocking os.makedirs to avoid FileNotFoundError +@mock.patch("zipfile.ZipFile") # Mocking the ZipFile class +@mock.patch("bz2.BZ2File") # Mocking the BZ2File class +@mock.patch("builtins.open", new_callable=mock.mock_open()) # Mocking open +class TestDownloadWeightFeature: + def test_download_weights_for_available_file( + self, + mock_open, + mock_zipfile, + mock_bz2file, + mock_makedirs, + mock_isfile, + mock_gdown, + mock_get_deepface_home, + ): + mock_isfile.return_value = True + mock_get_deepface_home.return_value = "/mock/home" + + file_name = "model_weights.h5" + source_url = "http://example.com/model_weights.zip" + + result = weight_utils.download_weights_if_necessary(file_name, source_url) + + assert result == os.path.join("/mock/home", ".deepface/weights", file_name) + + mock_gdown.assert_not_called() + mock_zipfile.assert_not_called() + mock_bz2file.assert_not_called() + logger.info("✅ test download weights for available file is done") + + def test_download_weights_if_necessary_gdown_failure( + self, + mock_open, + mock_zipfile, + mock_bz2file, + mock_makedirs, + mock_isfile, + mock_gdown, + mock_get_deepface_home, + ): + # Setting up the mock return values + mock_get_deepface_home.return_value = "/mock/home" + mock_isfile.return_value = False # Simulate file not being present + + file_name = "model_weights.h5" + source_url = "http://example.com/model_weights.h5" + + # Simulate gdown.download raising an exception + mock_gdown.side_effect = Exception("Download failed!") + + # Call the function and check for ValueError + with pytest.raises( + ValueError, + match=f"⛓️‍💥 An exception occurred while downloading {file_name} from {source_url}.", + ): + weight_utils.download_weights_if_necessary(file_name, source_url) + + logger.info("✅ test for downloading weights while gdown fails done") + + def test_download_weights_if_necessary_no_compression( + self, + mock_open, + mock_zipfile, + mock_bz2file, + mock_makedirs, + mock_isfile, + mock_gdown, + mock_get_deepface_home, + ): + # Setting up the mock return values + mock_get_deepface_home.return_value = "/mock/home" + mock_isfile.return_value = False # Simulate file not being present + + file_name = "model_weights.h5" + source_url = "http://example.com/model_weights.h5" + + # Call the function + result = weight_utils.download_weights_if_necessary(file_name, source_url) + + # Assert that gdown.download was called with the correct parameters + mock_gdown.assert_called_once_with( + source_url, "/mock/home/.deepface/weights/model_weights.h5", quiet=False + ) + + # Assert that the return value is correct + assert result == "/mock/home/.deepface/weights/model_weights.h5" + + # Assert that zipfile.ZipFile and bz2.BZ2File were not called + mock_zipfile.assert_not_called() + mock_bz2file.assert_not_called() + + logger.info("✅ test download weights with no compression is done") + + def test_download_weights_if_necessary_zip( + self, + mock_open, + mock_zipfile, + mock_bz2file, + mock_makedirs, + mock_isfile, + mock_gdown, + mock_get_deepface_home, + ): + # Setting up the mock return values + mock_get_deepface_home.return_value = "/mock/home" + mock_isfile.return_value = False # Simulate file not being present + + file_name = "model_weights.h5" + source_url = "http://example.com/model_weights.zip" + compress_type = "zip" + + # Call the function + result = weight_utils.download_weights_if_necessary(file_name, source_url, compress_type) + + # Assert that gdown.download was called with the correct parameters + mock_gdown.assert_called_once_with( + source_url, "/mock/home/.deepface/weights/model_weights.h5.zip", quiet=False + ) + + # Simulate the unzipping behavior + mock_zipfile.return_value.__enter__.return_value.extractall = mock.Mock() + + # Call the function again to simulate unzipping + with mock_zipfile.return_value as zip_ref: + zip_ref.extractall("/mock/home/.deepface/weights") + + # Assert that the zip file was unzipped correctly + zip_ref.extractall.assert_called_once_with("/mock/home/.deepface/weights") + + # Assert that the return value is correct + assert result == "/mock/home/.deepface/weights/model_weights.h5" + + logger.info("✅ test download weights for zip is done") + + def test_download_weights_if_necessary_bz2( + self, + mock_open, + mock_zipfile, + mock_bz2file, + mock_makedirs, + mock_isfile, + mock_gdown, + mock_get_deepface_home, + ): + + # Setting up the mock return values + mock_get_deepface_home.return_value = "/mock/home" + mock_isfile.return_value = False # Simulate file not being present + + file_name = "model_weights.h5" + source_url = "http://example.com/model_weights.bz2" + compress_type = "bz2" + + # Simulate the download success + mock_gdown.return_value = None + + # Simulate the BZ2 file reading behavior + mock_bz2file.return_value.__enter__.return_value.read.return_value = b"fake data" + + # Call the function under test + result = weight_utils.download_weights_if_necessary(file_name, source_url, compress_type) + + # Assert that gdown.download was called with the correct parameters + mock_gdown.assert_called_once_with( + source_url, "/mock/home/.deepface/weights/model_weights.h5.bz2", quiet=False + ) + + # Ensure open() is called once for writing the decompressed data + mock_open.assert_called_once_with("/mock/home/.deepface/weights/model_weights.h5", "wb") + + # TODO: find a way to check write is called + + # Assert that the return value is correct + assert result == "/mock/home/.deepface/weights/model_weights.h5" + + logger.info("✅ test download weights for bz2 is done") + + def test_download_weights_for_non_supported_compress_type( + self, + mock_open, + mock_zipfile, + mock_bz2file, + mock_makedirs, + mock_isfile, + mock_gdown, + mock_get_deepface_home, + ): + mock_isfile.return_value = False + + file_name = "model_weights.h5" + source_url = "http://example.com/model_weights.bz2" + compress_type = "7z" + with pytest.raises(ValueError, match="unimplemented compress type - 7z"): + _ = weight_utils.download_weights_if_necessary(file_name, source_url, compress_type) + logger.info("✅ test download weights for unsupported compress type is done")