Finetuning pipeline #414

Merged
57 commits merged on Aug 9, 2023
Changes from 14 commits
de1c1d7
Adding finetuning pipeline (v1)
Jul 21, 2023
47d2936
Adding finetuning pipeline (v2)
Jul 21, 2023
4fe9e63
Merge branch 'main' into explore_finetuning
DomInvivo Jul 23, 2023
6fbafad
Merge branch 'datamol-io:main' into explore_finetuning
WenkelF Jul 26, 2023
5dbcf2a
Removing redundant files
Jul 26, 2023
89f765b
Removing redundant files
Jul 26, 2023
815ae34
Integrating finetuning with hydra
Jul 26, 2023
fa53287
Merge pull request #1 from WenkelF/explore_finetuning
WenkelF Jul 27, 2023
bc8b8c8
Implementing separate class FullFinetuningNetwork
Jul 28, 2023
4b9e6ec
Moving parameter overwriting logic from FullGraphMultitaskNetwork to …
Jul 28, 2023
243d9c8
Extending modify_cfg_for_finetuning function also to gnn and graph_ou…
Jul 28, 2023
d9f9d47
Fixing (un)freezing logic and extending to finetuning from graph_outp…
Jul 28, 2023
65811b8
Adding preliminary unit test for finetuning
Jul 29, 2023
064c610
Adding preliminary unit test for finetuning
Jul 29, 2023
c6836c0
Reformatting with black
Jul 29, 2023
b3e7910
Fixing different devices in tests/test_finetuning.py
Jul 31, 2023
76e2ba6
Addressing comments
Aug 1, 2023
d02fd93
Addressing comments
Aug 1, 2023
4da7089
Reformatting with black
Aug 1, 2023
2ea1535
Updating preliminary hydra configs
Aug 1, 2023
86e6f70
Committing intermediate changes
Aug 2, 2023
0f33aed
Changing datahash for TDC benchmarks
Aug 2, 2023
0bbf244
Updating finetuning pipeline
Aug 2, 2023
03c65a1
Reformatting with black
Aug 2, 2023
d6d7aad
Merge pull request #6 from datamol-io/main
WenkelF Aug 2, 2023
5c3effc
Merge pull request #7 from WenkelF/main
WenkelF Aug 2, 2023
7efe817
WIP: Migrating to the new Hydra configs for fine-tuning
cwognum Aug 2, 2023
7337762
Optimizing module map for overwriting and freezing when finetuning
Aug 3, 2023
e820053
Fixing bug in data_hash; skipping affected unit tests when PyTDC pack…
Aug 3, 2023
91d7197
Fixing bug in data_hash function
Aug 3, 2023
e24acf0
Migrated to new Hydra configs and CLI
cwognum Aug 3, 2023
c0609c4
Resolve merge conflicts
cwognum Aug 3, 2023
692c26d
Everything until the freezing now works
cwognum Aug 3, 2023
ca79e45
Finishing (un)freezing weights for finetuning
Aug 3, 2023
4c9258f
Merged
cwognum Aug 3, 2023
ad58903
Fixed some bugs with WandB
cwognum Aug 3, 2023
fd21c87
Added back to old CLI
cwognum Aug 3, 2023
d919dc2
Don't overwrite task-heads during fine-tuning
cwognum Aug 3, 2023
e496f3d
Fixed failing test cases. Added script to loop over all ADMET benchmarks
cwognum Aug 3, 2023
f0f999b
Save results to YAML
cwognum Aug 3, 2023
53a8e57
Merge pull request #10 from WenkelF/explore_finetuning_cas
WenkelF Aug 3, 2023
feed196
Finishing unit test for finetuning
Aug 4, 2023
6463e48
Minor changes and updating doc strings
Aug 4, 2023
da0d058
Fixing doc string
Aug 4, 2023
a3d4715
Making predictor model-unspecific
Aug 7, 2023
529362a
Reformatting with black
Aug 7, 2023
98db36c
Saving featurization to predictor checkpoint and load from pretrained…
Aug 8, 2023
cbf2ad7
Documentation pass and address PR review
cwognum Aug 8, 2023
38748f9
Merge pull request #12 from WenkelF/finetuning_docs
cwognum Aug 8, 2023
3e393d8
Updated deprecation warning
cwognum Aug 8, 2023
119a862
Minor change to __main__.py
cwognum Aug 8, 2023
febdf2d
Fixing bug in finetuning training
Aug 8, 2023
24354ee
Finishing touches
Aug 9, 2023
b97289a
Reformatting
Aug 9, 2023
014b6a3
samuelm - Testing change RE CLI for IPU
s-maddrellmander Aug 9, 2023
7d2eb5c
Revert
s-maddrellmander Aug 9, 2023
6625e0c
Merge remote-tracking branch 'origin/main' into explore_finetuning
DomInvivo Aug 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions expts/hydra-configs/dataset/accelerator/lipophilicity_gpu.yaml
@@ -0,0 +1,27 @@
# @package _global_

accelerator:
  float32_matmul_precision: medium

architecture:
  task_heads:
    tox21:
      last_activation: none

datamodule:
  args:
    batch_size_training: 200
    batch_size_inference: 200
    featurization_n_jobs: 0
    num_workers: 0

predictor:
  optim_kwargs: {}
  # metrics_every_n_steps: 300
  torch_scheduler_kwargs:
    max_num_epochs: &max_epochs 300

trainer:
  trainer:
    accumulate_grad_batches: 1
    max_epochs: *max_epochs
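Note that the `&max_epochs` anchor and `*max_epochs` alias above are plain YAML: the alias is expanded at parse time, so the scheduler and the trainer always see the same epoch count. A minimal check of this behavior with OmegaConf (illustrative only, not part of this PR):

from omegaconf import OmegaConf

yaml_str = """
predictor:
  torch_scheduler_kwargs:
    max_num_epochs: &max_epochs 300
trainer:
  trainer:
    max_epochs: *max_epochs
"""

# The alias is expanded by the YAML parser, so both values end up equal.
cfg = OmegaConf.create(yaml_str)
assert cfg.trainer.trainer.max_epochs == 300
assert cfg.trainer.trainer.max_epochs == cfg.predictor.torch_scheduler_kwargs.max_num_epochs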
309 changes: 309 additions & 0 deletions expts/hydra-configs/dataset/lipophilicity.yaml
@@ -0,0 +1,309 @@
# @package _global_

architecture:
  model_type: FullGraphFinetuningNetwork # FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn:
    out_dim: 64
    hidden_dims: 256
    depth: 2
    activation: relu
    last_activation: none
    dropout: &dropout 0.18
    normalization: &normalization layer_norm
    last_normalization: *normalization
    residual_type: none

  pre_nn_edges: null

  pe_encoders:
    out_dim: 32
    pool: "sum" #"mean" "max"
    last_norm: None #"batch_norm", "layer_norm"
    encoders: #la_pos | rw_pos
      la_pos: # Set as null to avoid a pre-nn network
        encoder_type: "laplacian_pe"
        input_keys: ["laplacian_eigvec", "laplacian_eigval"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        model_type: 'DeepSet' #'Transformer' or 'DeepSet'
        num_layers: 2
        num_layers_post: 1 # Num. layers to apply after pooling
        dropout: 0.1
        first_normalization: "none" #"batch_norm" or "layer_norm"
      rw_pos:
        encoder_type: "mlp"
        input_keys: ["rw_return_probs"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: "layer_norm" #"batch_norm" or "layer_norm"
        first_normalization: "layer_norm" #"batch_norm" or "layer_norm"

  gnn: # Set as null to avoid a post-nn network
    in_dim: 64 # or otherwise the correct value
    out_dim: &gnn_dim 96
    hidden_dims: *gnn_dim
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.1
    normalization: "layer_norm"
    last_normalization: *normalization
    residual_type: simple
    virtual_node: 'none'
    layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine, pyg:gps
    layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1

  graph_output_nn:
    graph:
      pooling: [sum]
      out_dim: *gnn_dim
      hidden_dims: *gnn_dim
      depth: 1
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none

  task_heads:
    qm9:
      task_level: graph
      out_dim: 19
      hidden_dims: 128
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    tox21:
      task_level: graph
      out_dim: 12
      hidden_dims: 64
      depth: 2
      activation: relu
      last_activation: sigmoid
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    zinc:
      task_level: graph
      out_dim: 3
      hidden_dims: 32
      depth: 2
      activation: relu
      last_activation: none
Review comment (Collaborator):
This doesn't make sense. We should not have any architectural choice from the original pre-trained model in here. Only things that would change.

That way, we can take different pre-trained models that have different hparams/seed and fine-tune them all with the same file.

Reply (Collaborator, PR author):
I agree. The configurations are still structured in a way that gives us access to both the full config of the pretrained model and the pretraining-related config, and the modify_cfg_for_finetuning function consolidates that information into a single config (a rough sketch of this consolidation step follows the full config below).

This will be fixed once we incorporate the new Hydra config from #421. We will still need modify_cfg_for_finetuning for now, so it may be best to wait for the final version.

      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
      ###########################
      last_layer_is_readout: false
      ###########################

  finetuning_head: # none
    task: lipophilicity_astrazeneca
    previous_module: task_heads
    incoming_level: graph
    model_type: mlp
    in_dim: 8
    out_dim: 1
    hidden_dims: 8
    depth: 2
    last_layer_is_readout: true

predictor:
  ### Changes for finetuning ##############################
  metrics_on_progress_bar:
    lipophilicity_astrazeneca: ["mae"]
  loss_fun:
    lipophilicity_astrazeneca: mae
  #########################################################
  random_seed: ${constants.seed}
  optim_kwargs:
    lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
    # weight_decay: 1.e-7
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: &max_epochs 100
    warmup_epochs: 10
    verbose: False
  scheduler_kwargs:
  target_nan_mask: null
  multitask_handling: flatten # flatten, mean-per-label

### Changes for finetuning ##############################
metrics:
  lipophilicity_astrazeneca:
    - name: mae
      metric: mae
      target_nan_mask: null
      multitask_handling: flatten
      threshold_kwargs: null
    - name: spearman
      metric: spearmanr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: pearson
      metric: pearsonr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: r2_score
      metric: r2
      target_nan_mask: null
      multitask_handling: mean-per-label
      threshold_kwargs: null
  #########################################################
  # qm9: &qm9_metrics
  #   - name: mae
  #     metric: mae_ipu
  #     target_nan_mask: null
  #     multitask_handling: flatten
  #     threshold_kwargs: null
  #   - name: pearsonr
  #     metric: pearsonr_ipu
  #     threshold_kwargs: null
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  #   - name: r2_score
  #     metric: r2_score_ipu
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  # tox21:
  #   - name: auroc
  #     metric: auroc_ipu
  #     task: binary
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  #   - name: avpr
  #     metric: average_precision_ipu
  #     task: binary
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  #   - name: f1 > 0.5
  #     metric: f1
  #     multitask_handling: mean-per-label
  #     target_to_int: True
  #     num_classes: 2
  #     average: micro
  #     threshold_kwargs: &threshold_05
  #       operator: greater
  #       threshold: 0.5
  #       th_on_preds: True
  #       th_on_target: True
  #   - name: precision > 0.5
  #     metric: precision
  #     multitask_handling: mean-per-label
  #     average: micro
  #     threshold_kwargs: *threshold_05
  # zinc: *qm9_metrics

trainer:
  seed: ${constants.seed}
  logger:
    save_dir: logs/neurips2023-small/
    name: ${constants.name}
    project: ${constants.name}
  model_checkpoint:
    dirpath: saved_models/pretrained_models/
    filename: dummy-pretrained-model-{epoch}
    save_on_train_epoch_end: true
  trainer:
    precision: 16
    max_epochs: *max_epochs
    min_epochs: 1
    check_val_every_n_epoch: 20

datamodule:
  ### Changes for finetuning ##############################
  module_type: "ADMETBenchmarkDataModule"
  args:
    # TDC specific
    tdc_benchmark_names: [lipophilicity_astrazeneca]
    tdc_train_val_seed: ${constants.seed}
    #########################################################
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 30
    featurization_progress: True
    featurization_backend: "loky"
    processed_graph_data_path: "../datacache/neurips2023-small/"
    num_workers: 30 # -1 to use all
    persistent_workers: False
    # task_specific_args:
    #   qm9:
    #     df: null
    #     df_path: ${constants.data_dir}/qm9.csv.gz
    #     # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz
    #     # or set path as the URL directly
    #     smiles_col: "smiles"
    #     label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
    #     # sample_size: 2000 # use sample_size for test
    #     splits_path: ${constants.data_dir}/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
    #     seed: ${constants.seed} #*seed
    #     task_level: graph
    #     label_normalization:
    #       normalize_val_test: True
    #       method: "normal"

    #   tox21:
    #     df: null
    #     df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz
    #     # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
    #     # or set path as the URL directly
    #     smiles_col: "smiles"
    #     label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
    #     # sample_size: 2000 # use sample_size for test
    #     splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
    #     seed: ${constants.seed}
    #     task_level: graph

    #   zinc:
    #     df: null
    #     df_path: ${constants.data_dir}/ZINC12k.csv.gz
    #     # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
    #     # or set path as the URL directly
    #     smiles_col: "smiles"
    #     label_cols: ["SA", "logp", "score"]
    #     # sample_size: 2000 # use sample_size for test
    #     splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
    #     seed: ${constants.seed}
    #     task_level: graph
    #     label_normalization:
    #       normalize_val_test: True
    #       method: "normal"
    featurization:
      atom_property_list_onehot: [atomic-number, group, period, total-valence]
      atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
      edge_property_list: [bond-type-onehot, stereo, in-ring]
      add_self_loop: False
      explicit_H: False # if H is included
      use_bonds_weights: False
      pos_encoding_as_features:
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
            disconnected_comp: True # if eigen values/vector for disconnected graph are included
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
            disconnected_comp: True # if eigen values/vector for disconnected graph are included
          rw_pos: # use same name as pe_encoder
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16
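For readers following the review thread above: the sketch below shows roughly what a consolidation step like modify_cfg_for_finetuning has to accomplish, assuming OmegaConf-style configs, together with the kind of freezing the PR's (un)freezing logic performs. The names consolidate_cfg, freeze_for_finetuning, and trainable_prefixes are hypothetical and are not the PR's actual implementation.

# Hypothetical sketch (not the code added in this PR): overlay the
# finetuning-specific overrides onto the full config of the pretrained model,
# then freeze everything except the modules that should keep training.
from omegaconf import OmegaConf
import torch.nn as nn


def consolidate_cfg(pretrained_cfg, finetuning_overrides):
    # Later arguments win, so the overrides replace the pretrained values
    # (e.g. task_heads, loss_fun, metrics) while everything else is inherited.
    return OmegaConf.merge(pretrained_cfg, finetuning_overrides)


def freeze_for_finetuning(model: nn.Module, trainable_prefixes=("finetuning_head",)):
    # Freeze all parameters except those whose name starts with a trainable prefix,
    # e.g. keep only the newly attached finetuning head trainable.
    for name, param in model.named_parameters():
        param.requires_grad = any(name.startswith(p) for p in trainable_prefixes)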
15 changes: 15 additions & 0 deletions expts/hydra-configs/experiment/lipophilicity_gcn.yaml
@@ -0,0 +1,15 @@
# @package _global_

constants:
  name: neurips2023_small_data_gcn
  entity: "multitask-gnn"
  seed: 42
  max_epochs: 100
  data_dir: expts/data/neurips2023/small-dataset
  raise_train_error: true

trainer:
  model_checkpoint:
    dirpath: saved_models/finetuned_models/
    filename: dummy-finetuned-model-{epoch}
    save_on_train_epoch_end: true
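The ${constants.seed} and ${constants.name} references used throughout these configs are OmegaConf interpolations, resolved when the composed config is accessed. A small illustrative example (values copied from this file; the snippet itself is not part of the PR):

# Illustrative only: shows how `${constants.*}` interpolations resolve
# once the composed config is loaded with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "constants": {"name": "neurips2023_small_data_gcn", "seed": 42},
        "trainer": {
            "seed": "${constants.seed}",
            "logger": {"name": "${constants.name}"},
        },
    }
)

assert cfg.trainer.seed == 42
assert cfg.trainer.logger.name == "neurips2023_small_data_gcn"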
9 changes: 9 additions & 0 deletions expts/hydra-configs/finetune.yaml
@@ -0,0 +1,9 @@
defaults:
  - accelerator: ipu
  - dataset: toymix
  - model: gcn

  # Specializations
  - experiment: ${dataset}_${model}
  - dataset/accelerator: ${dataset}_${accelerator}
  - finetuning: finetuning
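This defaults list is standard Hydra composition: the experiment and dataset/accelerator entries are specialized from whichever dataset, model, and accelerator are selected. A rough sketch of composing it programmatically with Hydra's compose API (the config path and override names are assumptions, not taken from this PR):

# Rough sketch: compose the finetuning config without going through the CLI.
# Overrides pick the dataset/model/accelerator groups; the specializations
# (${dataset}_${model} and ${dataset}_${accelerator}) resolve from them.
from hydra import initialize, compose

with initialize(version_base=None, config_path="expts/hydra-configs"):
    cfg = compose(
        config_name="finetune",
        overrides=["dataset=lipophilicity", "model=gcn", "accelerator=gpu"],
    )
    print(cfg.architecture.model_type)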