diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index feb51a97..db819d78 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -264,7 +264,7 @@ def _fit_classifier(self, level, separator): f"Loaded trained model for local classifier {level} from file {filename}" ) return classifier - except: + except pickle.UnpicklingError: self.logger_.warning(f"Could not load model from file {filename}") self.logger_.info(f"Training local classifier {level}") X, y, sample_weight = self._remove_empty_leaves( diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 707c618d..7e76862e 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -255,7 +255,7 @@ def _fit_classifier(self, node): f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" ) return classifier - except: + except pickle.UnpicklingError: self.logger_.warning(f"Could not load model from file {filename}") self.logger_.info( f"Training local classifier {str(node).split(self.separator_)[-1]}" diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 8eb504c2..439c5e75 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -224,7 +224,7 @@ def _fit_classifier(self, node): f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" ) return classifier - except: + except pickle.UnpicklingError: self.logger_.warning(f"Could not load model from file {filename}") self.logger_.info( f"Training local classifier {str(node).split(self.separator_)[-1]}" diff --git a/tests/test_LocalClassifiers.py b/tests/test_LocalClassifiers.py index b42c9736..abd7bddf 100644 --- a/tests/test_LocalClassifiers.py +++ b/tests/test_LocalClassifiers.py @@ -142,6 +142,4 @@ def test_tmp_dir(classifier): (name, classifier) = pickle.load(open(filename, "rb")) assert expected_name == name check_is_fitted(classifier) - patcher.fs.remove("0cc175b9c0f1b6a831c399e269772661.sav") - patcher.fs.create_file("0cc175b9c0f1b6a831c399e269772661.sav") clf.fit(x, y)