diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index db819d78..896828cd 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -264,8 +264,8 @@ def _fit_classifier(self, level, separator): f"Loaded trained model for local classifier {level} from file {filename}" ) return classifier - except pickle.UnpicklingError: - self.logger_.warning(f"Could not load model from file {filename}") + except (pickle.UnpicklingError, EOFError): + self.logger_.error(f"Could not load model from file {filename}") self.logger_.info(f"Training local classifier {level}") X, y, sample_weight = self._remove_empty_leaves( separator, self.X_, self.y_[:, level], self.sample_weight_ diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 7e76862e..65c14113 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -255,8 +255,8 @@ def _fit_classifier(self, node): f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" ) return classifier - except pickle.UnpicklingError: - self.logger_.warning(f"Could not load model from file {filename}") + except (pickle.UnpicklingError, EOFError): + self.logger_.error(f"Could not load model from file {filename}") self.logger_.info( f"Training local classifier {str(node).split(self.separator_)[-1]}" ) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 439c5e75..5873b52a 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -224,8 +224,8 @@ def _fit_classifier(self, node): f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" ) return classifier - except pickle.UnpicklingError: - self.logger_.warning(f"Could not load model from file {filename}") + except (pickle.UnpicklingError, EOFError): + self.logger_.error(f"Could not load model from file {filename}") self.logger_.info( f"Training local classifier {str(node).split(self.separator_)[-1]}" )