diff --git a/graphium/cli/hydra.py b/graphium/cli/hydra.py
index 46f678ed4..0ca626ed3 100644
--- a/graphium/cli/hydra.py
+++ b/graphium/cli/hydra.py
@@ -73,6 +73,7 @@ def run_training_finetuning(cfg: DictConfig) -> None:
         model_class=model_class,
         model_kwargs=model_kwargs,
         metrics=metrics,
+        task_levels=datamodule.get_task_levels(),
         accelerator_type=accelerator_type,
         task_norms=datamodule.task_norms,
     )
diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py
index a562fd54a..461ad218b 100644
--- a/graphium/config/_loader.py
+++ b/graphium/config/_loader.py
@@ -278,6 +278,7 @@ def load_predictor(
     model_class: Type[torch.nn.Module],
     model_kwargs: Dict[str, Any],
     metrics: Dict[str, MetricWrapper],
+    task_levels: Dict[str, str],
     accelerator_type: str,
     task_norms: Optional[Dict[Callable, Any]] = None,
 ) -> PredictorModule:
@@ -302,6 +303,7 @@ def load_predictor(
         model_class=model_class,
         model_kwargs=model_kwargs,
         metrics=metrics,
+        task_levels=task_levels,
         task_norms=task_norms,
         **cfg_pred,
     )
diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py
index 48d096803..4b795dd76 100644
--- a/graphium/data/datamodule.py
+++ b/graphium/data/datamodule.py
@@ -936,6 +936,18 @@ def _get_task_key(self, task_level: str, task: str):
             task = task_prefix + task
         return task
 
+    def get_task_levels(self):
+        task_level_map = {}
+
+        for task, task_args in self.task_specific_args.items():
+            if isinstance(task_args, DatasetProcessingParams):
+                task_args = task_args.__dict__  # Convert the class to a dictionary
+            task_level_map.update({
+                task: task_args["task_level"]
+            })
+
+        return task_level_map
+
     def prepare_data(self):
         """Called only from a single process in distributed settings. Steps:
 
diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py
index 35d284179..74270bd74 100644
--- a/graphium/trainer/predictor.py
+++ b/graphium/trainer/predictor.py
@@ -27,6 +27,7 @@ def __init__(
         model_class: Type[nn.Module],
         model_kwargs: Dict[str, Any],
         loss_fun: Dict[str, Union[str, Callable]],
+        task_levels: Dict[str, str],
         random_seed: int = 42,
         optim_kwargs: Optional[Dict[str, Any]] = None,
         torch_scheduler_kwargs: Optional[Dict[str, Any]] = None,
@@ -69,6 +70,7 @@ def __init__(
 
         self.target_nan_mask = target_nan_mask
         self.multitask_handling = multitask_handling
+        self.task_levels = task_levels
         self.task_norms = task_norms
 
         super().__init__()
@@ -95,22 +97,11 @@ def __init__(
             )
             eval_options[task].check_metrics_validity()
 
-        # Work-around to retain task level when model_kwargs are modified for FullGraphFinetuningNetwork
-        if "task_heads_kwargs" in model_kwargs.keys():
-            task_heads_kwargs = model_kwargs["task_heads_kwargs"]
-        elif "pretrained_model_kwargs" in model_kwargs.keys():
-            # This covers finetuning cases where we finetune from the task_heads
-            task_heads_kwargs = model_kwargs["pretrained_model_kwargs"]["task_heads_kwargs"]
-        else:
-            raise ValueError("incorrect model_kwargs")
-        self.task_heads_kwargs = task_heads_kwargs
-
         self._eval_options_dict: Dict[str, EvalOptions] = eval_options
         self._eval_options_dict = {
             self._get_task_key(
-                task_level=task_heads_kwargs[key]["task_level"],
+                task_level=task_levels[key],
                 task=key
-                # task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
             ): value
             for key, value in self._eval_options_dict.items()
         }
@@ -119,22 +110,10 @@ def __init__(
 
         self.model = self._model_options.model_class(**self._model_options.model_kwargs)
 
-        # Maintain module map to easily select modules
-        # We now need to define the module_map in pretrained_model in FinetuningNetwork
-        # self._module_map = OrderedDict(
-        #     pe_encoders=self.model.encoder_manager,
-        #     pre_nn=self.model.pre_nn,
-        #     pre_nn_edges=self.model.pre_nn_edges,
-        #     gnn=self.model.gnn,
-        #     graph_output_nn=self.model.task_heads.graph_output_nn,
-        #     task_heads=self.model.task_heads.task_heads,
-        # )
-
         loss_fun = {
             self._get_task_key(
-                task_level=task_heads_kwargs[key]["task_level"],
+                task_level=task_levels[key],
                 task=key
-                # task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
             ): value
             for key, value in loss_fun.items()
         }
@@ -338,7 +317,7 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool)
         preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
 
         preds = {
-            self._get_task_key(task_level=self.task_heads_kwargs[key]["task_level"], task=key): value
+            self._get_task_key(task_level=self.task_levels[key], task=key): value
             for key, value in preds.items()
         }
         # preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index ffda3a8ca..51c2746f7 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -67,7 +67,7 @@ def test_finetuning_pipeline(self):
         metrics = load_metrics(cfg)
 
         predictor = load_predictor(
-            cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms
+            cfg, model_class, model_kwargs, metrics, datamodule.get_task_levels(), accelerator_type, datamodule.task_norms
        )
 
         self.assertEqual(
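
For reviewers, a minimal standalone sketch of what the new DataModule.get_task_levels() helper produces. The DatasetProcessingParams stand-in and the sample task arguments below are hypothetical and only model the single attribute the diff actually reads; this is an illustration of the mapping logic, not the library's full class.

from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class DatasetProcessingParams:
    # Hypothetical stand-in for graphium's DatasetProcessingParams;
    # only the task_level attribute used by get_task_levels() is modeled.
    task_level: str = "graph"


def get_task_levels(task_specific_args: Dict[str, Union[dict, DatasetProcessingParams]]) -> Dict[str, str]:
    # Mirrors the new DataModule.get_task_levels(): build a {task_name: task_level} map,
    # accepting either plain-dict or DatasetProcessingParams-style task arguments.
    task_level_map = {}
    for task, task_args in task_specific_args.items():
        if isinstance(task_args, DatasetProcessingParams):
            task_args = task_args.__dict__  # Convert the class to a dictionary
        task_level_map[task] = task_args["task_level"]
    return task_level_map


# Hypothetical task arguments: one dict-style entry, one object-style entry.
example_args = {
    "homolumo": {"task_level": "graph"},
    "tox21": DatasetProcessingParams(task_level="node"),
}
print(get_task_levels(example_args))  # {'homolumo': 'graph', 'tox21': 'node'}

Passing this map through load_predictor into PredictorModule lets _get_task_key look up task_levels[key] directly, so the previous workaround that dug the task level out of model_kwargs["task_heads_kwargs"] (or pretrained_model_kwargs when finetuning from the task heads) is no longer needed.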