diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 075b2ab5..e49b93a2 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -164,6 +164,7 @@ def load( if latest: cached_artifact = self._load_from_cache(model_key) if cached_artifact: + _LOGGER.debug("Found cached artifact for key: %s", model_key) return cached_artifact version_info = self.client.get_latest_versions(model_key, stages=[self.model_stage]) if not version_info: diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index b84ae404..35e32529 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -129,6 +129,7 @@ def __get_artifact_data( def __load_latest_artifact(self, key: str) -> ArtifactData: cached_artifact = self._load_from_cache(key) if cached_artifact: + _LOGGER.debug("Found cached artifact for key: %s", key) return cached_artifact production_key = self.__construct_production_key(key) if not self.client.exists(production_key): diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 09332cb1..6c96a513 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -10,7 +10,7 @@ from sklearn.ensemble import RandomForestRegressor from numalogic.models.autoencoder.variants import VanillaAE -from numalogic.registry import MLflowRegistry, ArtifactData +from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache from numalogic.registry.mlflow_registry import ModelStage from numalogic.tools.exceptions import ModelVersionError from tests.registry._mlflow_utils import ( @@ -338,6 +338,31 @@ def test_no_cache(self): self.assertIsNone(registry._load_from_cache("key")) self.assertIsNone(registry._clear_cache("key")) + def test_cache(self): + cache_registry = LocalLRUCache() + registry = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) + registry._save_in_cache("key", ArtifactData(artifact=self.model, extras={}, metadata={})) + self.assertIsNotNone(registry._load_from_cache("key")) + self.assertIsNotNone(registry._clear_cache("key")) + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch("mlflow.log_params", {"lr": 0.01}) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) + def test_cache_loading(self): + cache_registry = LocalLRUCache(ttl=50000) + ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) + ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.model, **{"lr": 0.01}) + ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") + key = MLflowRegistry.construct_key(self.skeys, self.dkeys) + self.assertIsNotNone(ml._load_from_cache(key)) + data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") + self.assertIsNotNone(data) + if __name__ == "__main__": unittest.main()