Various updates
WenkelF committed Oct 12, 2023
1 parent 3800525 commit 36930e1
Showing 15 changed files with 170 additions and 35 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -27,6 +27,8 @@ tests/temp_cache*
predictions/
draft/
scripts-expts/
sweeps/
mup/

# Data and predictions
graphium/data/ZINC_bench_gnn/
71 changes: 71 additions & 0 deletions expts/hydra-configs/finetuning/admet_baseline.yaml
@@ -0,0 +1,71 @@
# @package _global_

defaults:
- override /tasks/loss_metrics_datamodule: admet

constants:
task: tbd
name: finetune_${constants.task}
wandb:
name: ${constants.name}
project: finetuning
entity: recursion
seed: 42
max_epochs: 100
data_dir: ../data/graphium/admet/${constants.task}
datacache_path: ../datacache/admet/${constants.task}
raise_train_error: true
metric: ${get_metric_name:${constants.task}}

datamodule:
args:
batch_size_training: 32
dataloading_from: ram
persistent_workers: true
num_workers: 4

trainer:
model_checkpoint:
# save_top_k: 1
# monitor: graph_${constants.task}/${constants.metric}/val
# mode: ${get_metric_mode:${constants.task}}
# save_last: true
# filename: best
dirpath: model_checkpoints/finetuning/${constants.task}/${now:%Y-%m-%d_%H-%M-%S.%f}/
every_n_epochs: 200
trainer:
precision: 32
check_val_every_n_epoch: 1
# early_stopping:
# monitor: graph_${constants.task}/${constants.metric}/val
# mode: ${get_metric_mode:${constants.task}}
# min_delta: 0.001
# patience: 10
accumulate_grad_batches: none
# test_from_checkpoint: best.ckpt
# test_from_checkpoint: ${trainer.model_checkpoint.dirpath}/best.ckpt

predictor:
optim_kwargs:
lr: 0.000005


# == Fine-tuning config ==

finetuning:
task: ${constants.task}
level: graph
pretrained_model: tbd
finetuning_module: graph_output_nn
sub_module_from_pretrained: graph
new_sub_module: graph

keep_modules_after_finetuning_module: # optional
task_heads-pcqm4m_g25:
new_sub_module: ${constants.task}
hidden_dims: 256
depth: 2
last_activation: ${get_last_activation:${constants.task}}
out_dim: 1

epoch_unfreeze_all: tbd
@@ -23,7 +23,7 @@ predictor:
half_life_obach: ["spearman"]
clearance_microsome_az: ["spearman"]
clearance_hepatocyte_az: ["spearman"]
herg: ["mae"]
herg: ["auroc"]
ames: ["auroc"]
dili: ["auroc"]
ld50_zhu: ["auroc"]
34 changes: 34 additions & 0 deletions graphium/cli/finetune_utils.py
@@ -134,3 +134,37 @@ def get_fingerprints_from_model(
path = fs.join(save_destination, "fingerprints.pt")
logger.info(f"Saving fingerprints to {path}")
torch.save(fps, path)


def get_tdc_task_specific(task: str, output: Literal['name', 'mode', 'last_activation']):

if output == 'last_activation':
config_arch_path="expts/hydra-configs/tasks/task_heads/admet.yaml"
with open(config_arch_path, 'r') as yaml_file:
config_tdc_arch = yaml.load(yaml_file, Loader=yaml.FullLoader)

return config_tdc_arch['architecture']['task_heads'][task]['last_activation']

else:
config_metrics_path="expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml"
with open(config_metrics_path, 'r') as yaml_file:
config_tdc_task_metric = yaml.load(yaml_file, Loader=yaml.FullLoader)

metric = config_tdc_task_metric['predictor']['metrics_on_progress_bar'][task][0]

metric_mode_map = {
'mae': 'min',
'auroc': 'max',
'auprc': 'max',
'spearman': 'max',
}

if output == 'name':
return metric
elif output == 'mode':
return metric_mode_map[metric]

OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output='name'))
OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output='mode'))
OmegaConf.register_new_resolver("get_last_activation", lambda x: get_tdc_task_specific(x, output='last_activation'))
OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))
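
A minimal sketch of how these resolvers are meant to be consumed (assuming graphium and its dependencies are installed and the snippet is run from the repository root so the relative YAML paths inside get_tdc_task_specific resolve; the task name "herg" is only an example):

from omegaconf import OmegaConf
import graphium.cli.finetune_utils  # importing the module registers the resolvers above

cfg = OmegaConf.create(
    {
        "constants": {
            "task": "herg",
            "metric": "${get_metric_name:${constants.task}}",  # resolves to "auroc"
            "mode": "${get_metric_mode:${constants.task}}",     # resolves to "max"
        }
    }
)
print(cfg.constants.metric, cfg.constants.mode)
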
1 change: 1 addition & 0 deletions graphium/cli/parameters.py
@@ -17,6 +17,7 @@

from graphium.trainer.predictor_options import ModelOptions


param_app = typer.Typer(help="Parameter counts.")
app.add_typer(param_app, name="params")

16 changes: 14 additions & 2 deletions graphium/cli/train_finetune_test.py
@@ -1,3 +1,4 @@
from typing import List, Literal, Union
import os
import time
import timeit
@@ -37,6 +38,7 @@
from graphium.trainer.predictor import PredictorModule
from graphium.utils.safe_run import SafeRun

import graphium.cli.finetune_utils

TESTING_ONLY_CONFIG_KEY = "testing_only"

@@ -139,9 +141,12 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
logger.info("Computing the maximum number of nodes and edges per graph")
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])

# When resuming training from a checkpoint, we need to provide the path to the checkpoint in the config
resume_ckpt_path = cfg["trainer"].pop("resume_from_checkpoint", None)

# Run the model training
with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
trainer.fit(model=predictor, datamodule=datamodule)
trainer.fit(model=predictor, datamodule=datamodule, ckpt_path=resume_ckpt_path)

# Save validation metrics - Base utility in case someone doesn't use a logger.
results = trainer.callback_metrics
@@ -152,9 +157,16 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
# Determine the max num nodes and edges in testing
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"])

# When checkpoints are logged during training, we can, e.g., use the best or last checkpoint for testing
test_ckpt_path = None
test_ckpt_name = cfg['trainer'].pop('test_from_checkpoint', None)
test_ckpt_dir = cfg['trainer']['model_checkpoint'].pop('dirpath', None)
if test_ckpt_name is not None and test_ckpt_dir is not None:
test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name)

# Run the model testing
with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
trainer.test(model=predictor, datamodule=datamodule) # , ckpt_path=ckpt_path)
trainer.test(model=predictor, datamodule=datamodule, ckpt_path=test_ckpt_path)

logger.info("-" * 50)
logger.info("Total compute time:", timeit.default_timer() - st)
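
For reference, a hedged sketch of how the two new trainer keys could be supplied through the config (the checkpoint paths below are placeholders, not values from this commit): resume_from_checkpoint is popped and forwarded to trainer.fit, while test_from_checkpoint is joined with model_checkpoint.dirpath before trainer.test.

from omegaconf import OmegaConf

ckpt_override = OmegaConf.create(
    {
        "trainer": {
            "resume_from_checkpoint": "model_checkpoints/finetuning/herg/last.ckpt",  # placeholder path
            "test_from_checkpoint": "best.ckpt",  # resolved against model_checkpoint.dirpath
            "model_checkpoint": {"dirpath": "model_checkpoints/finetuning/herg/"},  # placeholder path
        }
    }
)
# cfg = OmegaConf.merge(cfg, ckpt_override)  # merged into the full run config before training
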
8 changes: 6 additions & 2 deletions graphium/config/_loader.py
@@ -189,7 +189,7 @@ def load_architecture(
architecture: The datamodule used to process and load the data
"""

if isinstance(config, dict):
if isinstance(config, dict) and 'finetuning' not in config:
config = omegaconf.OmegaConf.create(config)
cfg_arch = config["architecture"]

@@ -249,7 +249,8 @@ def load_architecture(
gnn_kwargs.setdefault("in_dim", edge_in_dim)

# Set the parameters for the full network
task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs)
if 'finetuning' not in config:
task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs)

# Set all the input arguments for the model
model_kwargs = dict(
@@ -323,6 +324,9 @@ def load_predictor(
model_class=model_class,
model_kwargs=scaled_model_kwargs,
metrics=metrics,
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
**cfg_pred,
)

12 changes: 4 additions & 8 deletions graphium/config/dummy_finetuning_from_gnn.yaml
@@ -34,18 +34,18 @@ finetuning:
level: graph

# Pretrained model
pretrained_model_name: dummy-pretrained-model
pretrained_model: dummy-pretrained-model
finetuning_module: gnn

# Changes to finetuning_module
drop_depth: 2
added_depth: 3

keep_modules_after_finetuning_module: # optional
graph_output_nn/graph:
graph_output_nn-graph:
pooling: [mean]
depth: 2
task_heads/zinc:
task_heads-zinc:
new_sub_module: lipophilicity_astrazeneca
out_dim: 1

@@ -134,8 +134,4 @@ datamodule:
prepare_dict_or_graph: pyg:graph
featurization_progress: True
featurization_backend: "loky"
persistent_workers: False




persistent_workers: False
2 changes: 1 addition & 1 deletion graphium/data/datamodule.py
@@ -1961,7 +1961,7 @@ def _sub_sample_df(
n = min(sample_size, df.shape[0])
df = df.sample(n=n, random_state=seed)
elif isinstance(sample_size, float):
df = df.sample(f=sample_size, random_state=seed)
df = df.sample(frac=sample_size, random_state=seed)
elif sample_size is None:
pass
else:
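
The fix above replaces an invalid keyword argument: pandas.DataFrame.sample expects frac for fractional sampling (a small standalone illustration, independent of graphium's datamodule):

import pandas as pd

df = pd.DataFrame({"x": range(10)})
print(len(df.sample(n=3, random_state=0)))       # 3 rows: integer sample size
print(len(df.sample(frac=0.5, random_state=0)))  # 5 rows: 50% of the frame
# df.sample(f=0.5) raises TypeError, since "f" is not a valid keyword argument.
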
13 changes: 9 additions & 4 deletions graphium/finetuning/finetuning.py
@@ -13,7 +13,7 @@ class GraphFinetuning(BaseFinetuning):
def __init__(
self,
finetuning_module: str,
added_depth: int,
added_depth: int = 0,
unfreeze_pretrained_depth: Optional[int] = None,
epoch_unfreeze_all: int = 0,
train_bn: bool = False,
@@ -51,13 +51,13 @@ def freeze_before_training(self, pl_module: pl.LightningModule):
module_map = pl_module.model.pretrained_model.net._module_map

for module_name in module_map.keys():
self.freeze_module(module_name, module_map)
self.freeze_module(pl_module, module_name, module_map)

if module_name.startswith(self.finetuning_module):
# Do not freeze modules after finetuning module
break

def freeze_module(self, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]):
def freeze_module(self, pl_module, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]):
"""
Freeze specific modules
@@ -66,10 +66,15 @@ def freeze_module(self, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]):
module_map: Dictionary mapping from module_name to corresponding module(s)
"""
modules = module_map[module_name]

if module_name == "pe_encoders":
for param in pl_module.model.pretrained_model.net.encoder_manager.parameters():
param.requires_grad = False

# We only partially freeze the finetuning module
if module_name.startswith(self.finetuning_module):
modules = modules[: -self.training_depth]
if self.training_depth > 0:
modules = modules[: -self.training_depth]

self.freeze(modules=modules, train_bn=self.train_bn)

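
The new pe_encoders branch freezes the positional-encoder manager at the parameter level rather than through the module map; the underlying pattern is standard PyTorch (a generic illustration, not graphium-specific):

import torch.nn as nn

pe_encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))
for param in pe_encoder.parameters():
    param.requires_grad = False  # excluded from autograd and from optimizer updates
assert not any(p.requires_grad for p in pe_encoder.parameters())
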
10 changes: 5 additions & 5 deletions graphium/finetuning/finetuning_architecture.py
@@ -199,7 +199,7 @@ def __init__(
super().__init__()

# Load pretrained model
pretrained_model = PredictorModule.load_pretrained_model(pretrained_model).model
pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device='cpu').model
pretrained_model.create_module_map()

# Initialize new model with architecture after desired modifications to architecture.
Expand All @@ -219,7 +219,7 @@ def overwrite_with_pretrained(
self,
pretrained_model,
finetuning_module: str,
added_depth: int,
added_depth: int = 0,
sub_module_from_pretrained: str = None,
):
"""
@@ -236,7 +236,7 @@

module_names_from_pretrained = module_map_from_pretrained.keys()
super_module_names_from_pretrained = set(
[module_name.split("/")[0] for module_name in module_names_from_pretrained]
[module_name.split("-")[0] for module_name in module_names_from_pretrained]
)

for module_name in module_map.keys():
@@ -254,10 +254,10 @@
if module_name in module_map_from_pretrained.keys():
for idx in range(shared_depth):
module_map[module_name][idx] = module_map_from_pretrained[module_name][idx]
elif module_name.split("/")[0] in super_module_names_from_pretrained:
elif module_name.split("-")[0] in super_module_names_from_pretrained:
for idx in range(shared_depth):
module_map[module_name][idx] = module_map_from_pretrained[
"".join([module_name.split("/")[0], "/", sub_module_from_pretrained])
"".join([module_name.split("-")[0], "-", sub_module_from_pretrained])
][idx]
else:
raise RuntimeError("Mismatch between loaded pretrained model and model to be overwritten.")
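
A short illustration of the module-map naming convention this commit switches to, with "-" instead of "/" separating the super-module from the sub-module (the names below follow the configs in this commit, e.g. task_heads-zinc in dummy_finetuning_from_gnn.yaml; the lookup mirrors overwrite_with_pretrained):

module_name = "task_heads-lipophilicity_astrazeneca"   # module in the model being finetuned
sub_module_from_pretrained = "zinc"                     # pretrained sub-module to copy weights from
super_module = module_name.split("-")[0]                # "task_heads"
pretrained_key = "".join([super_module, "-", sub_module_from_pretrained])
assert pretrained_key == "task_heads-zinc"
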