Making predictor model-unspecific
WenkelF committed Aug 7, 2023
1 parent da0d058 commit a3d4715
Showing 5 changed files with 21 additions and 27 deletions.
1 change: 1 addition & 0 deletions graphium/cli/hydra.py
@@ -73,6 +73,7 @@ def run_training_finetuning(cfg: DictConfig) -> None:
model_class=model_class,
model_kwargs=model_kwargs,
metrics=metrics,
+ task_levels=datamodule.get_task_levels(),
accelerator_type=accelerator_type,
task_norms=datamodule.task_norms,
)
2 changes: 2 additions & 0 deletions graphium/config/_loader.py
@@ -278,6 +278,7 @@ def load_predictor(
model_class: Type[torch.nn.Module],
model_kwargs: Dict[str, Any],
metrics: Dict[str, MetricWrapper],
+ task_levels: Dict[str, str],
accelerator_type: str,
task_norms: Optional[Dict[Callable, Any]] = None,
) -> PredictorModule:
@@ -302,6 +303,7 @@
model_class=model_class,
model_kwargs=model_kwargs,
metrics=metrics,
+ task_levels=task_levels,
task_norms=task_norms,
**cfg_pred,
)
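For reference (not part of the diff): callers of load_predictor now supply the task-level mapping between metrics and accelerator_type. A sketch of the updated call, assembled from the call sites in graphium/cli/hydra.py and tests/test_finetuning.py shown in this commit:

predictor = load_predictor(
    cfg,
    model_class=model_class,
    model_kwargs=model_kwargs,
    metrics=metrics,
    task_levels=datamodule.get_task_levels(),  # new argument introduced by this commit
    accelerator_type=accelerator_type,
    task_norms=datamodule.task_norms,
)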
12 changes: 12 additions & 0 deletions graphium/data/datamodule.py
@@ -936,6 +936,18 @@ def _get_task_key(self, task_level: str, task: str):
task = task_prefix + task
return task

+    def get_task_levels(self):

Comment from @DomInvivo (Collaborator), Aug 7, 2023: Document this method

+        task_level_map = {}
+
+        for task, task_args in self.task_specific_args.items():
+            if isinstance(task_args, DatasetProcessingParams):
+                task_args = task_args.__dict__  # Convert the class to a dictionary
+            task_level_map.update({
+                task: task_args["task_level"]
+            })
+
+        return task_level_map

def prepare_data(self):
"""Called only from a single process in distributed settings. Steps:
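As an aside (not part of the commit): get_task_levels maps each task name in task_specific_args to its declared task_level, so the predictor can look levels up without inspecting model_kwargs. A minimal sketch of the expected return value, using hypothetical task names and levels:

# Hypothetical output of datamodule.get_task_levels(); the real keys and levels
# come from the task_specific_args section of the data config.
task_levels = {
    "zinc_logp": "graph",    # a graph-level regression task
    "atom_charge": "node",   # a node-level task
}
# The predictor then builds full task keys (e.g. "graph_zinc_logp") via _get_task_key.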
31 changes: 5 additions & 26 deletions graphium/trainer/predictor.py
@@ -27,6 +27,7 @@ def __init__(
model_class: Type[nn.Module],
model_kwargs: Dict[str, Any],
loss_fun: Dict[str, Union[str, Callable]],
+ task_levels: Dict[str, str],

Comment from @DomInvivo (Collaborator), Aug 7, 2023: Add docstring explaining what is task_levels
random_seed: int = 42,
optim_kwargs: Optional[Dict[str, Any]] = None,
torch_scheduler_kwargs: Optional[Dict[str, Any]] = None,
@@ -69,6 +70,7 @@ def __init__(

self.target_nan_mask = target_nan_mask
self.multitask_handling = multitask_handling
+ self.task_levels = task_levels
self.task_norms = task_norms

super().__init__()
@@ -95,22 +97,11 @@ def __init__(
)
eval_options[task].check_metrics_validity()

-        # Work-around to retain task level when model_kwargs are modified for FullGraphFinetuningNetwork
-        if "task_heads_kwargs" in model_kwargs.keys():
-            task_heads_kwargs = model_kwargs["task_heads_kwargs"]
-        elif "pretrained_model_kwargs" in model_kwargs.keys():
-            # This covers finetuning cases where we finetune from the task_heads
-            task_heads_kwargs = model_kwargs["pretrained_model_kwargs"]["task_heads_kwargs"]
-        else:
-            raise ValueError("incorrect model_kwargs")
-        self.task_heads_kwargs = task_heads_kwargs

         self._eval_options_dict: Dict[str, EvalOptions] = eval_options
         self._eval_options_dict = {
             self._get_task_key(
-                task_level=task_heads_kwargs[key]["task_level"],
+                task_level=task_levels[key],
                 task=key
-                # task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
             ): value
             for key, value in self._eval_options_dict.items()
         }
@@ -119,22 +110,10 @@ def __init__(

self.model = self._model_options.model_class(**self._model_options.model_kwargs)

-        # Maintain module map to easily select modules
-        # We now need to define the module_map in pretrained_model in FinetuningNetwork
-        # self._module_map = OrderedDict(
-        #     pe_encoders=self.model.encoder_manager,
-        #     pre_nn=self.model.pre_nn,
-        #     pre_nn_edges=self.model.pre_nn_edges,
-        #     gnn=self.model.gnn,
-        #     graph_output_nn=self.model.task_heads.graph_output_nn,
-        #     task_heads=self.model.task_heads.task_heads,
-        # )

         loss_fun = {
             self._get_task_key(
-                task_level=task_heads_kwargs[key]["task_level"],
+                task_level=task_levels[key],
                 task=key
-                # task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
             ): value
             for key, value in loss_fun.items()
         }
@@ -338,7 +317,7 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool)
preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}

         preds = {
-            self._get_task_key(task_level=self.task_heads_kwargs[key]["task_level"], task=key): value
+            self._get_task_key(task_level=self.task_levels[key], task=key): value
             for key, value in preds.items()
         }
         # preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
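For illustration (not part of the diff): passing task_levels from the datamodule lets the predictor re-key eval_options, loss_fun, and predictions without reaching into model_kwargs["task_heads_kwargs"], which is what makes it model-unspecific. A standalone sketch of that re-keying, with made-up task names and a simplified stand-in for _get_task_key:

# Simplified, self-contained sketch; the task names, loss names, and the helper
# below are illustrative only.
task_levels = {"zinc_logp": "graph", "atom_charge": "node"}
loss_fun = {"zinc_logp": "mse", "atom_charge": "bce"}

def get_task_key(task_level: str, task: str) -> str:
    # Stand-in for PredictorModule._get_task_key: prepend the level if it is missing.
    prefix = task_level + "_"
    return task if task.startswith(prefix) else prefix + task

loss_fun = {
    get_task_key(task_level=task_levels[task], task=task): fn
    for task, fn in loss_fun.items()
}
# loss_fun is now {"graph_zinc_logp": "mse", "node_atom_charge": "bce"}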
2 changes: 1 addition & 1 deletion tests/test_finetuning.py
@@ -67,7 +67,7 @@ def test_finetuning_pipeline(self):
metrics = load_metrics(cfg)

predictor = load_predictor(
- cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms
+ cfg, model_class, model_kwargs, metrics, datamodule.get_task_levels(), accelerator_type, datamodule.task_norms
)

self.assertEqual(
