Further changes to finetuning + hparam search pipeline
cwognum committed Aug 16, 2023
1 parent 713337c commit 994d2d4
Showing 12 changed files with 162 additions and 148 deletions.
8 changes: 4 additions & 4 deletions expts/hydra-configs/finetuning/admet.yaml
@@ -29,16 +29,16 @@ constants:
# For now, we assume a model is always fine-tuned on a single task at a time.
# You can override this value with any of the benchmark names in the TDC benchmark suite.
# See also https://tdcommons.ai/benchmark/admet_group/overview/
task: &task lipophilicity_astrazeneca
task: lipophilicity_astrazeneca

name: finetuning_${constants.task}_gcn
wandb:
name: ${constants.name}
project: *task
project: ${constants.task}
entity: multitask-gnn
save_dir: logs/${constants.task}
seed: 42
max_epochs: 10
max_epochs: 100
data_dir: expts/data/admet/${constants.task}
raise_train_error: true

@@ -57,7 +57,7 @@ finetuning:
level: graph

# Pretrained model
pretrained_model_name: dummy-pretrained-model
pretrained_model: dummy-pretrained-model
finetuning_module: task_heads # gnn
sub_module_from_pretrained: zinc # optional
new_sub_module: lipophilicity_astrazeneca # optional
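The move from the YAML anchor/alias pair (&task / *task) to the ${constants.task} interpolation matters for overrides: an anchor is resolved when the YAML is parsed, so a command-line override of constants.task would not reach wandb.project, whereas OmegaConf interpolations are resolved lazily and therefore pick up the override. A minimal sketch of that behaviour, assuming only a handful of the keys above (caco2_wang is just another TDC ADMET benchmark name, used for illustration):

from omegaconf import OmegaConf

# Minimal sketch; mirrors only a few keys of the config above.
cfg = OmegaConf.create(
    {
        "constants": {
            "task": "lipophilicity_astrazeneca",
            "name": "finetuning_${constants.task}_gcn",
            "wandb": {"project": "${constants.task}"},
            "data_dir": "expts/data/admet/${constants.task}",
        }
    }
)

# A CLI override such as constants.task=caco2_wang amounts to:
cfg.constants.task = "caco2_wang"

print(cfg.constants.wandb.project)  # -> caco2_wang
print(cfg.constants.data_dir)       # -> expts/data/admet/caco2_wang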
4 changes: 2 additions & 2 deletions expts/hydra-configs/hparam_search/optuna.yaml
@@ -44,11 +44,11 @@ hydra:
direction: minimize
study_name: ${constants.name}
storage: null
n_trials: 20
n_trials: 100
n_jobs: 1

# The hyper-parameter search space definition
# See https://hydra.cc/docs/plugins/optuna_sweeper/#search-space-configuration for the options
params:
constants.seed: choice(0, 42)
predictor.optim_kwargs.lr: tag(log, interval(0.00001, 0.001))
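For intuition, the two entries under params correspond roughly to the Optuna suggestions below, sampled once per trial. This is only a sketch: in practice the sampling is driven by Hydra's Optuna sweeper plugin, and run_training is a hypothetical stand-in for a single training run that returns the metric being minimized.

import optuna


def run_training(seed: int, lr: float) -> float:
    # Hypothetical stand-in: launch one fine-tuning run with these values
    # and return the validation metric that `direction: minimize` refers to.
    return lr  # placeholder


def objective(trial: optuna.Trial) -> float:
    # Mirrors the search space defined under `params` above.
    seed = trial.suggest_categorical("constants.seed", [0, 42])
    lr = trial.suggest_float("predictor.optim_kwargs.lr", 1e-5, 1e-3, log=True)
    return run_training(seed=seed, lr=lr)


study = optuna.create_study(direction="minimize", study_name="example-study")
study.optimize(objective, n_trials=100, n_jobs=1)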

68 changes: 57 additions & 11 deletions graphium/cli/train_finetune.py
@@ -1,25 +1,38 @@
import hydra
import wandb
import os
import time
import timeit
from datetime import datetime

import fsspec
import hydra
import torch
import wandb
import yaml
from datamol.utils import fs
from hydra.core.hydra_config import HydraConfig
from hydra.types import RunMode
from omegaconf import DictConfig, OmegaConf
from loguru import logger
from datetime import datetime
from lightning.pytorch.utilities.model_summary import ModelSummary
from loguru import logger
from omegaconf import DictConfig, OmegaConf

from graphium.config._loader import (
load_accelerator,
load_architecture,
load_datamodule,
load_metrics,
load_architecture,
load_predictor,
load_trainer,
load_accelerator,
save_params_to_wandb,
)
from graphium.hyper_param_search import process_results_for_hyper_param_search, HYPER_PARAM_SEARCH_CONFIG_KEY
from graphium.finetuning import modify_cfg_for_finetuning, GraphFinetuning, FINETUNING_CONFIG_KEY
from graphium.finetuning import (
FINETUNING_CONFIG_KEY,
GraphFinetuning,
modify_cfg_for_finetuning,
)
from graphium.hyper_param_search import (
HYPER_PARAM_SEARCH_CONFIG_KEY,
extract_main_metric_for_hparam_search,
)
from graphium.utils.safe_run import SafeRun


@@ -38,6 +51,18 @@ def run_training_finetuning(cfg: DictConfig) -> None:

cfg = OmegaConf.to_container(cfg, resolve=True)

dst_dir = cfg["constants"].get("results_dir")
hydra_cfg = HydraConfig.get()
output_dir = hydra_cfg["runtime"]["output_dir"]

if dst_dir is not None and fs.exists(dst_dir) and len(fs.get_mapper(dst_dir).fs.ls(dst_dir)) > 0:
logger.warning(
"The destination directory is not empty. "
"If files already exist, this would lead to a crash at the end of training."
)
# We pause here briefly to make sure the notification is seen, as there are a lot of logs afterwards
time.sleep(5)

# Modify the config for finetuning
if FINETUNING_CONFIG_KEY in cfg:
cfg = modify_cfg_for_finetuning(cfg)
@@ -102,6 +127,12 @@ def run_training_finetuning(cfg: DictConfig) -> None:
with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
trainer.fit(model=predictor, datamodule=datamodule)

# Save validation metrics - Base utility in case someone doesn't use a logger.
results = trainer.callback_metrics
results = {k: v.item() if torch.is_tensor(v) else v for k, v in results.items()}
with fsspec.open(fs.join(output_dir, "val_results.yaml"), "w") as f:
yaml.dump(results, f)

# Determine the max num nodes and edges in testing
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"])

@@ -116,12 +147,27 @@ def run_training_finetuning(cfg: DictConfig) -> None:
if wandb_cfg is not None:
wandb.finish()

# Save test metrics - Base utility in case someone doesn't use a logger.
results = trainer.callback_metrics
results = {k: v.item() if torch.is_tensor(v) else v for k, v in results.items()}
with fsspec.open(fs.join(output_dir, "test_results.yaml"), "w") as f:
yaml.dump(results, f)

# When part of a hyper-parameter search, we are very specific about how we save our results
# NOTE (cwognum): We also check if we are in multi-run mode, as the sweeper is otherwise not active.
if HYPER_PARAM_SEARCH_CONFIG_KEY in cfg and HydraConfig.get().mode == RunMode.MULTIRUN:
results = process_results_for_hyper_param_search(results, cfg[HYPER_PARAM_SEARCH_CONFIG_KEY])
if HYPER_PARAM_SEARCH_CONFIG_KEY in cfg and hydra_cfg.mode == RunMode.MULTIRUN:
results = extract_main_metric_for_hparam_search(results, cfg[HYPER_PARAM_SEARCH_CONFIG_KEY])

# Copy the current working directory to remote
# By default, processes should just write results to Hydra's output directory.
# However, this currently does not support remote storage, which is why we copy the results here if needed.
# For more info, see also: https://github.com/facebookresearch/hydra/issues/993

if dst_dir is not None:
src_dir = hydra_cfg["runtime"]["output_dir"]
dst_dir = fs.join(dst_dir, fs.get_basename(src_dir))
fs.mkdir(dst_dir, exist_ok=True)
fs.copy_dir(src_dir, dst_dir)

return results

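Because the val_results.yaml and test_results.yaml files are written through fsspec, they can be read back the same way from local or remote storage. A small sketch of consuming them after a run; the output directory and the metric keys are placeholders, not Graphium defaults:

import fsspec
import yaml

# Placeholder: wherever Hydra placed the run, or the copied results_dir location.
output_dir = "outputs/2023-08-16/12-00-00"

with fsspec.open(f"{output_dir}/val_results.yaml", "r") as f:
    val_results = yaml.safe_load(f)
with fsspec.open(f"{output_dir}/test_results.yaml", "r") as f:
    test_results = yaml.safe_load(f)

print(val_results)   # e.g. {"loss/val": 0.41, ...} (illustrative keys)
print(test_results)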
36 changes: 17 additions & 19 deletions graphium/config/_loader.py
@@ -1,36 +1,35 @@
from typing import Dict, Mapping, Tuple, Type, Union, Any, Optional, Callable

# Misc
import os
import omegaconf
from copy import deepcopy
from loguru import logger
import yaml
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union

import joblib
import pathlib
import warnings
import mup
import omegaconf

# Torch
import torch
import mup
import yaml

# Lightning
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger, Logger
from lightning.pytorch.loggers import Logger, WandbLogger
from loguru import logger

# Graphium
from graphium.utils.mup import set_base_shapes
from graphium.data.datamodule import BaseDataModule, MultitaskFromSmilesDataModule
from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork
from graphium.ipu.ipu_dataloader import IPUDataloaderOptions
from graphium.trainer.metrics import MetricWrapper
from graphium.ipu.ipu_utils import import_poptorch, load_ipu_options
from graphium.nn.architectures import FullGraphMultiTaskNetwork
from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork
from graphium.nn.utils import MupMixin
from graphium.trainer.metrics import MetricWrapper
from graphium.trainer.predictor import PredictorModule
from graphium.utils.command_line_utils import get_anchors_and_aliases, update_config

# Graphium
from graphium.utils.mup import set_base_shapes
from graphium.utils.spaces import DATAMODULE_DICT
from graphium.ipu.ipu_utils import import_poptorch, load_ipu_options
from graphium.data.datamodule import MultitaskFromSmilesDataModule, BaseDataModule
from graphium.utils.command_line_utils import update_config, get_anchors_and_aliases


def get_accelerator(
@@ -264,12 +263,12 @@ def load_architecture(
if model_class is FullGraphFinetuningNetwork:
finetuning_head_kwargs = config["finetuning"].pop("finetuning_head", None)
pretrained_overwriting_kwargs = config["finetuning"].pop("overwriting_kwargs")
pretrained_model_name = pretrained_overwriting_kwargs.pop("pretrained_model_name")
pretrained_model = pretrained_overwriting_kwargs.pop("pretrained_model")

model_kwargs = {
"pretrained_model_kwargs": deepcopy(model_kwargs),
"pretrained_overwriting_kwargs": pretrained_overwriting_kwargs,
"pretrained_model_name": pretrained_model_name,
"pretrained_model": pretrained_model,
"finetuning_head_kwargs": finetuning_head_kwargs,
}

@@ -406,7 +405,6 @@ def load_trainer(

# Define the early model checkpoint parameters
if "model_checkpoint" in cfg_trainer.keys():
cfg_trainer["model_checkpoint"]["dirpath"] += str(cfg_trainer["seed"]) + "/"
callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"]))

# Define the logger parameters
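With the seed no longer appended to dirpath here, the trainer.model_checkpoint block is forwarded to Lightning's ModelCheckpoint unchanged, so it can contain any of that callback's options. A rough sketch; the values below are illustrative, not Graphium defaults:

from lightning.pytorch.callbacks import ModelCheckpoint

# Illustrative keyword arguments; in practice they come from
# cfg_trainer["model_checkpoint"] and are passed through as-is.
checkpoint_callback = ModelCheckpoint(
    dirpath="models_checkpoints/finetuning_lipophilicity_astrazeneca_gcn/",
    filename="best",
    monitor="loss/val",  # hypothetical metric name
    save_top_k=1,
    mode="min",
)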
2 changes: 1 addition & 1 deletion graphium/config/dummy_finetuning.yaml
@@ -30,7 +30,7 @@ finetuning:
level: graph

# Pretrained model
pretrained_model_name: dummy-pretrained-model
pretrained_model: dummy-pretrained-model
finetuning_module: task_heads
sub_module_from_pretrained: zinc # optional
new_sub_module: lipophilicity_astrazeneca # optional
23 changes: 12 additions & 11 deletions graphium/data/dataset.py
@@ -1,16 +1,15 @@
from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable

from multiprocessing import Manager
import numpy as np
from functools import lru_cache
from loguru import logger
from copy import deepcopy
import os
import numpy as np
from copy import deepcopy
from functools import lru_cache
from multiprocessing import Manager
from typing import Any, Dict, List, Optional, Tuple, Union

import fsspec
import numpy as np
import torch
from loguru import logger
from torch.utils.data.dataloader import Dataset
from torch_geometric.data import Data, Batch
from torch_geometric.data import Batch, Data

from graphium.data.smiles_transform import smiles_to_unique_mol_ids
from graphium.features import GraphDict
@@ -247,7 +246,8 @@ def _load_metadata(self):
"_num_edges_list",
]
path = os.path.join(self.data_path, "multitask_metadata.pkl")
attrs = torch.load(path)
with fsspec.open(path, "rb") as f:
attrs = torch.load(f)

if not set(attrs_to_load).issubset(set(attrs.keys())):
raise ValueError(
@@ -409,7 +409,8 @@ def load_graph_from_index(self, data_idx):
filename = os.path.join(
self.data_path, format(data_idx // 1000, "04d"), format(data_idx, "07d") + ".pkl"
)
data_dict = torch.load(filename)
with fsspec.open(filename, "rb") as f:
data_dict = torch.load(f)
return data_dict

def merge(
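Routing torch.load through fsspec means the cached dataset can sit on local disk or on remote object storage without further code changes, as long as the matching fsspec backend (e.g. s3fs or gcsfs) is installed. The pattern in isolation, with placeholder paths:

import fsspec
import torch

# Works identically for a local path or a remote URL such as
# "s3://my-bucket/cache/multitask_metadata.pkl".
path = "expts/cache/multitask_metadata.pkl"
with fsspec.open(path, "rb") as f:
    attrs = torch.load(f)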
49 changes: 19 additions & 30 deletions graphium/finetuning/finetuning_architecture.py
@@ -1,33 +1,23 @@
from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type

from copy import deepcopy

from loguru import logger
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn

from loguru import logger
from torch import Tensor
from torch_geometric.data import Batch

from graphium.data.utils import get_keys
from graphium.nn.base_graph_layer import BaseGraphStructure
from graphium.nn.architectures.encoder_manager import EncoderManager
from graphium.nn.architectures import FullGraphMultiTaskNetwork, FeedForwardNN, FeedForwardPyg, TaskHeads
from graphium.nn.architectures.global_architectures import FeedForwardGraph
from graphium.trainer.predictor_options import ModelOptions
from graphium.nn.utils import MupMixin

from graphium.trainer.predictor import PredictorModule
from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT, FINETUNING_HEADS_DICT
from graphium.utils.spaces import FINETUNING_HEADS_DICT


class FullGraphFinetuningNetwork(nn.Module, MupMixin):
def __init__(
self,
pretrained_model_name: str,
pretrained_model_kwargs: Dict[str, Any],
pretrained_overwriting_kwargs: Dict[str, Any],
pretrained_model: Union[str, "PretrainedModel"],
pretrained_model_kwargs: Dict[str, Any] = {},
pretrained_overwriting_kwargs: Dict[str, Any] = {},
finetuning_head_kwargs: Optional[Dict[str, Any]] = None,
num_inference_to_average: int = 1,
last_layer_is_readout: bool = False,
@@ -41,8 +31,8 @@ def __init__(
Parameters:
pretrained_model_name:
Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT
pretrained_model:
A PretrainedModel or an identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT or a valid .ckpt checkpoint path
pretrained_model_kwargs:
Keyword arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)
@@ -79,16 +69,17 @@ def __init__(
self.num_inference_to_average = num_inference_to_average
self.last_layer_is_readout = last_layer_is_readout
self._concat_last_layers = None
self.pretrained_model_name = pretrained_model_name
self.pretrained_model = pretrained_model
self.pretrained_overwriting_kwargs = pretrained_overwriting_kwargs
self.finetuning_head_kwargs = finetuning_head_kwargs
self.max_num_nodes_per_graph = None
self.max_num_edges_per_graph = None
self.finetuning_head = None

self.pretrained_model = PretrainedModel(
pretrained_model_name, pretrained_model_kwargs, pretrained_overwriting_kwargs
)
if not isinstance(self.pretrained_model, PretrainedModel):
self.pretrained_model = PretrainedModel(
self.pretrained_model, pretrained_model_kwargs, pretrained_overwriting_kwargs
)

if finetuning_head_kwargs is not None:
self.finetuning_head = FinetuningHead(finetuning_head_kwargs)
@@ -147,7 +138,7 @@ def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]:
Dictionary with the kwargs to create the base model.
"""
kwargs = dict(
pretrained_model_name=self.pretrained_model_name,
pretrained_model=self.pretrained_model,
pretrained_model_kwargs=None,
finetuning_head_kwargs=None,
num_inference_to_average=self.num_inference_to_average,
@@ -186,18 +177,18 @@ def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], max_edges:
class PretrainedModel(nn.Module, MupMixin):
def __init__(
self,
pretrained_model_name: str,
pretrained_model: str,
pretrained_model_kwargs: Dict[str, Any],
pretrained_overwriting_kwargs: Dict[str, Any],
):
r"""
Flexible class allowing to finetune pretrained models from GRAPHIUM_PRETRAINED_MODELS_DICT.
Flexible class allowing to finetune pretrained models from GRAPHIUM_PRETRAINED_MODELS_DICT or from a checkpoint path.
Can be any model that inherits from nn.Module, MupMixin and comes with a module map (e.g., FullGraphMultitaskNetwork)
Parameters:
pretrained_model_name:
Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT
pretrained_model:
Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT or a path to a checkpoint
pretrained_model_kwargs:
Keyword arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)
@@ -210,9 +201,7 @@ def __init__(
super().__init__()

# Load pretrained model
pretrained_model = PredictorModule.load_from_checkpoint(
GRAPHIUM_PRETRAINED_MODELS_DICT[pretrained_model_name]
).model
pretrained_model = PredictorModule.load_pretrained_models(pretrained_model).model
pretrained_model.create_module_map()

# Initialize new model with architecture after desired modifications to architecture.
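With pretrained_model_name generalized to pretrained_model, the same argument now accepts either a key of GRAPHIUM_PRETRAINED_MODELS_DICT or a path to a .ckpt file, and loading is delegated to PredictorModule.load_pretrained_models. A minimal sketch; both identifiers below are placeholders:

from graphium.trainer.predictor import PredictorModule

# Either a registered pretrained model name...
model = PredictorModule.load_pretrained_models("dummy-pretrained-model").model

# ...or a checkpoint path.
model = PredictorModule.load_pretrained_models("path/to/my_pretrained.ckpt").model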