From 9476c95b4889d99be84360993fbff54b5da96256 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 10 Feb 2024 19:36:29 +0330 Subject: [PATCH] :pencil2: Add backends verification to the `Trainer` --- hezar/constants.py | 1 - hezar/trainer/trainer.py | 6 +++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/hezar/constants.py b/hezar/constants.py index 79bd525e..79ac90cb 100644 --- a/hezar/constants.py +++ b/hezar/constants.py @@ -58,7 +58,6 @@ class Backends(ExplicitEnum): NLTK = "nltk" SCIKIT = "sklearn" SEQEVAL = "seqeval" - EVALUATE = "evaluate" ROUGE = "rouge_score" diff --git a/hezar/trainer/trainer.py b/hezar/trainer/trainer.py index 99e4b5be..0bc7fc8c 100644 --- a/hezar/trainer/trainer.py +++ b/hezar/trainer/trainer.py @@ -38,7 +38,7 @@ from ..data.datasets import Dataset from ..models import Model from ..preprocessors import Preprocessor, PreprocessorsContainer -from ..utils import Logger, colorize_text, is_backend_available, sanitize_function_parameters +from ..utils import Logger, colorize_text, is_backend_available, sanitize_function_parameters, verify_dependencies if TYPE_CHECKING: from accelerate import Accelerator @@ -97,6 +97,7 @@ class Trainer: trainer_state_file = DEFAULT_TRAINER_STATE_FILE default_optimizer = OptimizerType.ADAM default_lr_scheduler = None + _required_backends = [] def __init__( self, @@ -111,6 +112,9 @@ def __init__( lr_scheduler=None, accelerator: "Accelerator" = None, ): + # Check if all required dependencies are installed + verify_dependencies(self._required_backends) + self.config = config self.device = "cuda" if torch.cuda.is_available() and not self.config.use_cpu else "cpu"