Skip to content

Commit

Permalink
Catch unpickling error
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 committed Mar 10, 2024
1 parent 79100b7 commit ee7c294
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion hiclass/LocalClassifierPerLevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _fit_classifier(self, level, separator):
f"Loaded trained model for local classifier {level} from file {filename}"
)
return classifier
except:
except pickle.UnpicklingError:
self.logger_.warning(f"Could not load model from file {filename}")
self.logger_.info(f"Training local classifier {level}")
X, y, sample_weight = self._remove_empty_leaves(
Expand Down
2 changes: 1 addition & 1 deletion hiclass/LocalClassifierPerNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _fit_classifier(self, node):
f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}"
)
return classifier
except:
except pickle.UnpicklingError:
self.logger_.warning(f"Could not load model from file {filename}")
self.logger_.info(
f"Training local classifier {str(node).split(self.separator_)[-1]}"
Expand Down
2 changes: 1 addition & 1 deletion hiclass/LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _fit_classifier(self, node):
f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}"
)
return classifier
except:
except pickle.UnpicklingError:
self.logger_.warning(f"Could not load model from file {filename}")
self.logger_.info(
f"Training local classifier {str(node).split(self.separator_)[-1]}"
Expand Down
2 changes: 0 additions & 2 deletions tests/test_LocalClassifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,4 @@ def test_tmp_dir(classifier):
(name, classifier) = pickle.load(open(filename, "rb"))
assert expected_name == name
check_is_fitted(classifier)
patcher.fs.remove("0cc175b9c0f1b6a831c399e269772661.sav")
patcher.fs.create_file("0cc175b9c0f1b6a831c399e269772661.sav")
clf.fit(x, y)

0 comments on commit ee7c294

Please sign in to comment.