From 36930e1ccf82fa147fcb46ed9c4215b5623a2f08 Mon Sep 17 00:00:00 2001 From: wenkelf Date: Thu, 12 Oct 2023 04:31:56 +0000 Subject: [PATCH] Various updates --- .gitignore | 2 + .../finetuning/admet_baseline.yaml | 71 +++++++++++++++++++ .../tasks/loss_metrics_datamodule/admet.yaml | 2 +- graphium/cli/finetune_utils.py | 34 +++++++++ graphium/cli/parameters.py | 1 + graphium/cli/train_finetune_test.py | 16 ++++- graphium/config/_loader.py | 8 ++- .../config/dummy_finetuning_from_gnn.yaml | 12 ++-- graphium/data/datamodule.py | 2 +- graphium/finetuning/finetuning.py | 13 ++-- .../finetuning/finetuning_architecture.py | 10 +-- graphium/finetuning/utils.py | 18 +++-- .../nn/architectures/global_architectures.py | 4 +- graphium/trainer/predictor_options.py | 2 +- tests/test_finetuning.py | 10 +-- 15 files changed, 170 insertions(+), 35 deletions(-) create mode 100644 expts/hydra-configs/finetuning/admet_baseline.yaml diff --git a/.gitignore b/.gitignore index b9f39521e..289f10a4d 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,8 @@ tests/temp_cache* predictions/ draft/ scripts-expts/ +sweeps/ +mup/ # Data and predictions graphium/data/ZINC_bench_gnn/ diff --git a/expts/hydra-configs/finetuning/admet_baseline.yaml b/expts/hydra-configs/finetuning/admet_baseline.yaml new file mode 100644 index 000000000..410d0dd64 --- /dev/null +++ b/expts/hydra-configs/finetuning/admet_baseline.yaml @@ -0,0 +1,71 @@ +# @package _global_ + +defaults: + - override /tasks/loss_metrics_datamodule: admet + +constants: + task: tbd + name: finetune_${constants.task} + wandb: + name: ${constants.name} + project: finetuning + entity: recursion + seed: 42 + max_epochs: 100 + data_dir: ../data/graphium/admet/${constants.task} + datacache_path: ../datacache/admet/${constants.task} + raise_train_error: true + metric: ${get_metric_name:${constants.task}} + +datamodule: + args: + batch_size_training: 32 + dataloading_from: ram + persistent_workers: true + num_workers: 4 + +trainer: + model_checkpoint: + # save_top_k: 1 + # monitor: graph_${constants.task}/${constants.metric}/val + # mode: ${get_metric_mode:${constants.task}} + # save_last: true + # filename: best + dirpath: model_checkpoints/finetuning/${constants.task}/${now:%Y-%m-%d_%H-%M-%S.%f}/ + every_n_epochs: 200 + trainer: + precision: 32 + check_val_every_n_epoch: 1 + # early_stopping: + # monitor: graph_${constants.task}/${constants.metric}/val + # mode: ${get_metric_mode:${constants.task}} + # min_delta: 0.001 + # patience: 10 + accumulate_grad_batches: none + # test_from_checkpoint: best.ckpt + # test_from_checkpoint: ${trainer.model_checkpoint.dirpath}/best.ckpt + +predictor: + optim_kwargs: + lr: 0.000005 + + +# == Fine-tuning config == + +finetuning: + task: ${constants.task} + level: graph + pretrained_model: tbd + finetuning_module: graph_output_nn + sub_module_from_pretrained: graph + new_sub_module: graph + + keep_modules_after_finetuning_module: # optional + task_heads-pcqm4m_g25: + new_sub_module: ${constants.task} + hidden_dims: 256 + depth: 2 + last_activation: ${get_last_activation:${constants.task}} + out_dim: 1 + + epoch_unfreeze_all: tbd \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml index 87136b683..cfff5f689 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml @@ -23,7 +23,7 @@ predictor: half_life_obach: ["spearman"] 
clearance_microsome_az: ["spearman"] clearance_hepatocyte_az: ["spearman"] - herg: ["mae"] + herg: ["auroc"] ames: ["auroc"] dili: ["auroc"] ld50_zhu: ["auroc"] diff --git a/graphium/cli/finetune_utils.py b/graphium/cli/finetune_utils.py index ef9089279..847b991f4 100644 --- a/graphium/cli/finetune_utils.py +++ b/graphium/cli/finetune_utils.py @@ -134,3 +134,37 @@ def get_fingerprints_from_model( path = fs.join(save_destination, "fingerprints.pt") logger.info(f"Saving fingerprints to {path}") torch.save(fps, path) + + +def get_tdc_task_specific(task: str, output: Literal['name', 'mode', 'last_activation']): + + if output == 'last_activation': + config_arch_path="expts/hydra-configs/tasks/task_heads/admet.yaml" + with open(config_arch_path, 'r') as yaml_file: + config_tdc_arch = yaml.load(yaml_file, Loader=yaml.FullLoader) + + return config_tdc_arch['architecture']['task_heads'][task]['last_activation'] + + else: + config_metrics_path="expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml" + with open(config_metrics_path, 'r') as yaml_file: + config_tdc_task_metric = yaml.load(yaml_file, Loader=yaml.FullLoader) + + metric = config_tdc_task_metric['predictor']['metrics_on_progress_bar'][task][0] + + metric_mode_map = { + 'mae': 'min', + 'auroc': 'max', + 'auprc': 'max', + 'spearman': 'max', + } + + if output == 'name': + return metric + elif output == 'mode': + return metric_mode_map[metric] + +OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output='name')) +OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output='mode')) +OmegaConf.register_new_resolver("get_last_activation", lambda x: get_tdc_task_specific(x, output='last_activation')) +OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np})) diff --git a/graphium/cli/parameters.py b/graphium/cli/parameters.py index 8021c88e1..e0b31ac91 100644 --- a/graphium/cli/parameters.py +++ b/graphium/cli/parameters.py @@ -17,6 +17,7 @@ from graphium.trainer.predictor_options import ModelOptions + param_app = typer.Typer(help="Parameter counts.") app.add_typer(param_app, name="params") diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py index ffb5a7512..4e7fe2961 100644 --- a/graphium/cli/train_finetune_test.py +++ b/graphium/cli/train_finetune_test.py @@ -1,3 +1,4 @@ +from typing import List, Literal, Union import os import time import timeit @@ -37,6 +38,7 @@ from graphium.trainer.predictor import PredictorModule from graphium.utils.safe_run import SafeRun +import graphium.cli.finetune_utils TESTING_ONLY_CONFIG_KEY = "testing_only" @@ -139,9 +141,12 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None: logger.info("Computing the maximum number of nodes and edges per graph") predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"]) + # When resuming training from a checkpoint, we need to provide the path to the checkpoint in the config + resume_ckpt_path = cfg["trainer"].pop("resume_from_checkpoint", None) + # Run the model training with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True): - trainer.fit(model=predictor, datamodule=datamodule) + trainer.fit(model=predictor, datamodule=datamodule, ckpt_path=resume_ckpt_path) # Save validation metrics - Base utility in case someone doesn't use a logger. 
results = trainer.callback_metrics @@ -152,9 +157,16 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None: # Determine the max num nodes and edges in testing predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"]) + # When checkpoints are logged during training, we can, e.g., use the best or last checkpoint for testing + test_ckpt_path = None + test_ckpt_name = cfg['trainer'].pop('test_from_checkpoint', None) + test_ckpt_dir = cfg['trainer']['model_checkpoint'].pop('dirpath', None) + if test_ckpt_name is not None and test_ckpt_dir is not None: + test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name) + # Run the model testing with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True): - trainer.test(model=predictor, datamodule=datamodule) # , ckpt_path=ckpt_path) + trainer.test(model=predictor, datamodule=datamodule, ckpt_path=test_ckpt_path) logger.info("-" * 50) logger.info("Total compute time:", timeit.default_timer() - st) diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index 3235c9b68..000cd59f1 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -189,7 +189,7 @@ def load_architecture( architecture: The datamodule used to process and load the data """ - if isinstance(config, dict): + if isinstance(config, dict) and 'finetuning' not in config: config = omegaconf.OmegaConf.create(config) cfg_arch = config["architecture"] @@ -249,7 +249,8 @@ def load_architecture( gnn_kwargs.setdefault("in_dim", edge_in_dim) # Set the parameters for the full network - task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs) + if 'finetuning' not in config: + task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs) # Set all the input arguments for the model model_kwargs = dict( @@ -323,6 +324,9 @@ def load_predictor( model_class=model_class, model_kwargs=scaled_model_kwargs, metrics=metrics, + task_levels=task_levels, + featurization=featurization, + task_norms=task_norms, **cfg_pred, ) diff --git a/graphium/config/dummy_finetuning_from_gnn.yaml b/graphium/config/dummy_finetuning_from_gnn.yaml index 3e7224b5e..9fe058e28 100644 --- a/graphium/config/dummy_finetuning_from_gnn.yaml +++ b/graphium/config/dummy_finetuning_from_gnn.yaml @@ -34,7 +34,7 @@ finetuning: level: graph # Pretrained model - pretrained_model_name: dummy-pretrained-model + pretrained_model: dummy-pretrained-model finetuning_module: gnn # Changes to finetuning_module @@ -42,10 +42,10 @@ finetuning: added_depth: 3 keep_modules_after_finetuning_module: # optional - graph_output_nn/graph: + graph_output_nn-graph: pooling: [mean] depth: 2 - task_heads/zinc: + task_heads-zinc: new_sub_module: lipophilicity_astrazeneca out_dim: 1 @@ -134,8 +134,4 @@ datamodule: prepare_dict_or_graph: pyg:graph featurization_progress: True featurization_backend: "loky" - persistent_workers: False - - - - + persistent_workers: False \ No newline at end of file diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index f150ba83d..1a4e49114 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1961,7 +1961,7 @@ def _sub_sample_df( n = min(sample_size, df.shape[0]) df = df.sample(n=n, random_state=seed) elif isinstance(sample_size, float): - df = df.sample(f=sample_size, random_state=seed) + df = df.sample(frac=sample_size, random_state=seed) elif sample_size is None: pass else: diff --git a/graphium/finetuning/finetuning.py b/graphium/finetuning/finetuning.py index afa33cad5..3829db9e4 
100644 --- a/graphium/finetuning/finetuning.py +++ b/graphium/finetuning/finetuning.py @@ -13,7 +13,7 @@ class GraphFinetuning(BaseFinetuning): def __init__( self, finetuning_module: str, - added_depth: int, + added_depth: int = 0, unfreeze_pretrained_depth: Optional[int] = None, epoch_unfreeze_all: int = 0, train_bn: bool = False, @@ -51,13 +51,13 @@ def freeze_before_training(self, pl_module: pl.LightningModule): module_map = pl_module.model.pretrained_model.net._module_map for module_name in module_map.keys(): - self.freeze_module(module_name, module_map) + self.freeze_module(pl_module, module_name, module_map) if module_name.startswith(self.finetuning_module): # Do not freeze modules after finetuning module break - def freeze_module(self, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]): + def freeze_module(self, pl_module, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]): """ Freeze specific modules @@ -66,10 +66,15 @@ def freeze_module(self, module_name: str, module_map: Dict[str, Union[nn.ModuleL module_map: Dictionary mapping from module_name to corresponding module(s) """ modules = module_map[module_name] + + if module_name == "pe_encoders": + for param in pl_module.model.pretrained_model.net.encoder_manager.parameters(): + param.requires_grad = False # We only partially freeze the finetuning module if module_name.startswith(self.finetuning_module): - modules = modules[: -self.training_depth] + if self.training_depth > 0: + modules = modules[: -self.training_depth] self.freeze(modules=modules, train_bn=self.train_bn) diff --git a/graphium/finetuning/finetuning_architecture.py b/graphium/finetuning/finetuning_architecture.py index 1f9649277..55718fed0 100644 --- a/graphium/finetuning/finetuning_architecture.py +++ b/graphium/finetuning/finetuning_architecture.py @@ -199,7 +199,7 @@ def __init__( super().__init__() # Load pretrained model - pretrained_model = PredictorModule.load_pretrained_model(pretrained_model).model + pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device='cpu').model pretrained_model.create_module_map() # Initialize new model with architecture after desired modifications to architecture. 
@@ -219,7 +219,7 @@ def overwrite_with_pretrained( self, pretrained_model, finetuning_module: str, - added_depth: int, + added_depth: int = 0, sub_module_from_pretrained: str = None, ): """ @@ -236,7 +236,7 @@ def overwrite_with_pretrained( module_names_from_pretrained = module_map_from_pretrained.keys() super_module_names_from_pretrained = set( - [module_name.split("/")[0] for module_name in module_names_from_pretrained] + [module_name.split("-")[0] for module_name in module_names_from_pretrained] ) for module_name in module_map.keys(): @@ -254,10 +254,10 @@ def overwrite_with_pretrained( if module_name in module_map_from_pretrained.keys(): for idx in range(shared_depth): module_map[module_name][idx] = module_map_from_pretrained[module_name][idx] - elif module_name.split("/")[0] in super_module_names_from_pretrained: + elif module_name.split("-")[0] in super_module_names_from_pretrained: for idx in range(shared_depth): module_map[module_name][idx] = module_map_from_pretrained[ - "".join([module_name.split("/")[0], "/", sub_module_from_pretrained]) + "".join([module_name.split("-")[0], "-", sub_module_from_pretrained]) ][idx] else: raise RuntimeError("Mismatch between loaded pretrained model and model to be overwritten.") diff --git a/graphium/finetuning/utils.py b/graphium/finetuning/utils.py index ede0f639c..abcd11644 100644 --- a/graphium/finetuning/utils.py +++ b/graphium/finetuning/utils.py @@ -5,6 +5,8 @@ from graphium.trainer import PredictorModule +import graphium + def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], names: Union[List[str], str]): """ @@ -64,6 +66,7 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): arch_keys = pretrained_architecture.keys() arch_keys = [key.replace("_kwargs", "") for key in arch_keys] cfg_arch = {arch_keys[idx]: value for idx, value in enumerate(pretrained_architecture.values())} + cfg_arch_from_pretrained = deepcopy(cfg_arch) # Featurization cfg["datamodule"]["args"]["featurization"] = pretrained_predictor.featurization @@ -91,6 +94,9 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): else cfg_arch[finetuning_module][sub_module_from_pretrained].get("out_dim") ) + if new_module_kwargs["depth"] is None: + new_module_kwargs["depth"] = len(new_module_kwargs["hidden_dims"]) + 1 + upd_kwargs = { "out_dim": cfg_finetune.pop("new_out_dim", out_dim), "depth": new_module_kwargs["depth"] @@ -98,6 +104,10 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): - cfg_finetune.pop("drop_depth", 0), } + new_last_activation = cfg_finetune.pop("new_last_activation", None) + if new_last_activation is not None: + upd_kwargs["last_activation"] = new_last_activation + # Update config new_module_kwargs.update(upd_kwargs) @@ -110,8 +120,8 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): module_list = list(module_map_from_pretrained.keys()) super_module_list = [] for module in module_list: - if module.split("/")[0] not in super_module_list: # Only add each supermodule once - super_module_list.append(module.split("/")[0]) + if module.split("-")[0] not in super_module_list: # Only add each supermodule once + super_module_list.append(module.split("-")[0]) # Set configuration of modules after finetuning module to None cutoff_idx = ( @@ -170,7 +180,7 @@ def update_cfg_arch_for_module( updates: Changes to apply to key-work arguments of selected module """ # We need to distinguish between modules with & without submodules - if "/" not in module_name: + if "-" not in module_name: if cfg_arch[module_name] is None: cfg_arch[module_name] = {} @@ 
-179,7 +189,7 @@ def update_cfg_arch_for_module( cfg_arch.update({module_name, cfg_arch_from_pretrained}) else: - module_name, sub_module = module_name.split("/") + module_name, sub_module = module_name.split("-") new_sub_module = updates.pop("new_sub_module", sub_module) if cfg_arch[module_name] is None: diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 903803642..0e4599b24 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -1229,7 +1229,7 @@ def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] = if self.task_heads is not None: self._module_map.update( { - "graph_output_nn/" + "graph_output_nn-" + output_level: self.task_heads.graph_output_nn[output_level].graph_output_nn for output_level in self.task_heads.graph_output_nn.keys() } @@ -1237,7 +1237,7 @@ def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] = self._module_map.update( { - "task_heads/" + task_head_name: self.task_heads.task_heads[task_head_name] + "task_heads-" + task_head_name: self.task_heads.task_heads[task_head_name] for task_head_name in self.task_heads.task_heads.keys() } ) diff --git a/graphium/trainer/predictor_options.py b/graphium/trainer/predictor_options.py index 0bab97674..20e193bca 100644 --- a/graphium/trainer/predictor_options.py +++ b/graphium/trainer/predictor_options.py @@ -99,7 +99,7 @@ def set_kwargs(self): self.torch_scheduler_kwargs.setdefault("module_type", "ReduceLROnPlateau") # Get the class for the scheduler - scheduler_class = self.torch_scheduler_kwargs.pop("module_type", None) + scheduler_class = self.torch_scheduler_kwargs.pop("module_type") if self.scheduler_class is None: if isinstance(scheduler_class, str): self.scheduler_class = SCHEDULER_DICT[scheduler_class] diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py index 5ce760264..d2cd09743 100644 --- a/tests/test_finetuning.py +++ b/tests/test_finetuning.py @@ -75,9 +75,9 @@ def test_finetuning_from_task_head(self): module_map = deepcopy(predictor.model.pretrained_model.net._module_map) cfg_finetune = cfg["finetuning"] - finetuning_module = "".join([cfg_finetune["finetuning_module"], "/", cfg_finetune["task"]]) + finetuning_module = "".join([cfg_finetune["finetuning_module"], "-", cfg_finetune["task"]]) finetuning_module_from_pretrained = "".join( - [cfg_finetune["finetuning_module"], "/", cfg_finetune["sub_module_from_pretrained"]] + [cfg_finetune["finetuning_module"], "-", cfg_finetune["sub_module_from_pretrained"]] ) # Test for correctly modified shapes and number of layers in finetuning module @@ -145,7 +145,7 @@ def on_train_epoch_start(self, trainer, pl_module): module_map = pl_module.model.pretrained_model.net._module_map finetuning_module = "".join( - [self.cfg_finetune["finetuning_module"], "/", self.cfg_finetune["task"]] + [self.cfg_finetune["finetuning_module"], "-", self.cfg_finetune["task"]] ) training_depth = self.cfg_finetune["added_depth"] + self.cfg_finetune.pop( "unfreeze_pretrained_depth", 0 @@ -272,7 +272,7 @@ def test_finetuning_from_gnn(self): 5, ) self.assertEqual(module_map[finetuning_module][-1].model.lin.weight.size(0), 96) - self.assertEqual(len(module_map["graph_output_nn/graph"]), 2) + self.assertEqual(len(module_map["graph_output_nn-graph"]), 2) assert predictor.model.pretrained_model.net.task_heads.graph_output_nn[ "graph" @@ -284,7 +284,7 @@ def test_finetuning_from_gnn(self): # Load pretrained & replace in 
predictor pretrained_model = PredictorModule.load_pretrained_model( - cfg["finetuning"]["pretrained_model_name"], device="cpu" + cfg["finetuning"]["pretrained_model"], device="cpu" ).model pretrained_model.create_module_map()
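
Note on the config resolvers: the new admet_baseline.yaml fills `metric`, checkpoint `mode`, and `last_activation` through the OmegaConf resolvers registered in graphium/cli/finetune_utils.py. A minimal, self-contained sketch of that mechanism follows (not the repo's code); the small metric table stands in for the values that get_tdc_task_specific() actually reads from expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml.

    from omegaconf import OmegaConf

    # Hypothetical stand-in for the per-task metrics parsed from admet.yaml
    _METRICS = {"caco2_wang": "mae", "herg": "auroc", "half_life_obach": "spearman"}
    _MODES = {"mae": "min", "auroc": "max", "auprc": "max", "spearman": "max"}

    OmegaConf.register_new_resolver("get_metric_name", lambda task: _METRICS[task])
    OmegaConf.register_new_resolver("get_metric_mode", lambda task: _MODES[_METRICS[task]])

    cfg = OmegaConf.create(
        {
            "constants": {"task": "herg", "metric": "${get_metric_name:${constants.task}}"},
            "trainer": {
                "model_checkpoint": {
                    "monitor": "graph_${constants.task}/${constants.metric}/val",
                    "mode": "${get_metric_mode:${constants.task}}",
                }
            },
        }
    )
    print(OmegaConf.to_container(cfg, resolve=True))
    # {'constants': {'task': 'herg', 'metric': 'auroc'},
    #  'trainer': {'model_checkpoint': {'monitor': 'graph_herg/auroc/val', 'mode': 'max'}}}

Because the resolvers are evaluated lazily at interpolation time, setting `constants.task` on the command line is enough to switch the monitored metric and its mode for every TDC task.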
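
Note on checkpoint handling: train_finetune_test.py now pops `resume_from_checkpoint` before `trainer.fit` and joins `model_checkpoint.dirpath` with `test_from_checkpoint` before `trainer.test`. The sketch below paraphrases that logic as a standalone helper (the wrapper function and its name are illustrative, not part of the patch).

    import os
    from omegaconf import DictConfig, OmegaConf

    def resolve_ckpt_paths(cfg: DictConfig):
        """Return (resume_ckpt_path, test_ckpt_path) derived from the trainer config."""
        # Resume training from an explicit checkpoint, if one is configured.
        resume_ckpt_path = cfg["trainer"].pop("resume_from_checkpoint", None)

        # For testing, combine the checkpoint directory with a checkpoint
        # name such as "best.ckpt"; both keys are optional.
        test_ckpt_path = None
        test_ckpt_name = cfg["trainer"].pop("test_from_checkpoint", None)
        test_ckpt_dir = cfg["trainer"]["model_checkpoint"].pop("dirpath", None)
        if test_ckpt_name is not None and test_ckpt_dir is not None:
            test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name)
        return resume_ckpt_path, test_ckpt_path

    cfg = OmegaConf.create(
        {
            "trainer": {
                "test_from_checkpoint": "best.ckpt",
                "model_checkpoint": {"dirpath": "model_checkpoints/finetuning/herg/"},
            }
        }
    )
    print(resolve_ckpt_paths(cfg))
    # (None, 'model_checkpoints/finetuning/herg/best.ckpt')

Popping `dirpath` here also keeps Lightning's ModelCheckpoint from receiving the directory twice when the same config section is reused elsewhere.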
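
Note on module-map naming: the patch switches the separator in module-map keys from "/" to "-", so fine-tuning configs address sub-modules as "<super_module>-<sub_module>" (e.g. task_heads-pcqm4m_g25, graph_output_nn-graph). A small sketch, using an illustrative map rather than the exact output of create_module_map(), shows how overwrite_with_pretrained() recovers the super-module names by splitting on the new separator.

    # Illustrative module map; real values are nn.ModuleList objects.
    module_map = {
        "pe_encoders": "...",
        "gnn": "...",
        "graph_output_nn-graph": "...",
        "task_heads-pcqm4m_g25": "...",
    }

    # Recover the set of super-modules, as overwrite_with_pretrained() does.
    super_modules = {name.split("-")[0] for name in module_map}
    print(sorted(super_modules))
    # ['gnn', 'graph_output_nn', 'pe_encoders', 'task_heads']

The same "-" convention is used consistently in dummy_finetuning_from_gnn.yaml, utils.py, and the tests, so keys defined in configs line up with the keys produced by create_module_map().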