Finetuning pipeline #414
Merged
Commits (57)
de1c1d7
Adding finetuning pipeline (v1)
47d2936
Adding finetuning pipeline (v2)
4fe9e63
Merge branch 'main' into explore_finetuning
DomInvivo 6fbafad
Merge branch 'datamol-io:main' into explore_finetuning
WenkelF 5dbcf2a
Removing redundant files
89f765b
Removing redundant files
815ae34
Integrating finetuning with hydra
fa53287
Merge pull request #1 from WenkelF/explore_finetuning
WenkelF bc8b8c8
Implementing separate class FullFinetuningNetwork
4b9e6ec
Moving parameter overwriting logic from FullGraphMultitaskNetwork to …
243d9c8
Extending modify_cfg_for_finetuning function also to gnn and graph_ou…
d9f9d47
Fixing (un)freezing logic and extending to finetuning from graph_outp…
65811b8
Adding preliminary unit test for finetuning
064c610
Adding preliminary unit test for finetuning
c6836c0
Reformatting with black
b3e7910
Fixing different devices in tests/test_finetuning.py
76e2ba6
Addressing comments
d02fd93
Addressing comments
4da7089
Reformatting with black
2ea1535
Updating preliminary hydra configs
86e6f70
Committing intermediate changes
0f33aed
Changing datahash for TDC benchmarks
0bbf244
Updating finetuning pipeline
03c65a1
Reformatting with black
d6d7aad
Merge pull request #6 from datamol-io/main
WenkelF 5c3effc
Merge pull request #7 from WenkelF/main
WenkelF 7efe817
WIP: Migrating to the new Hydra configs for fine-tuning
cwognum 7337762
Optimizing module map for overwriting and freezing when finetuning
e820053
Fixing bug in data_hash; skipping affected unit tests when PyTDC pack…
91d7197
Fixing bug in data_hash function
e24acf0
Migrated to new Hydra configs and CLI
cwognum c0609c4
Resolve merge conflicts
cwognum 692c26d
Everything until the freezing now works
cwognum ca79e45
Finishing (un)freezing weights for finetuning
4c9258f
Merged
cwognum ad58903
Fixed some bugs with WandB
cwognum fd21c87
Added back to old CLI
cwognum d919dc2
Don't overwrite task-heads during fine-tuning
cwognum e496f3d
Fixed failing test cases. Added script to loop over all ADMET benchmarks
cwognum f0f999b
Save results to YAML
cwognum 53a8e57
Merge pull request #10 from WenkelF/explore_finetuning_cas
WenkelF feed196
Finishing unit test for finetuning
6463e48
Minor changes and updating doc strings
da0d058
Fixing doc string
a3d4715
Making predictor model-unspecific
529362a
Reformatting with black
98db36c
Saving featurization to predictor checkpoint and load from pretrained…
cbf2ad7
Documentation pass and address PR review
cwognum 38748f9
Merge pull request #12 from WenkelF/finetuning_docs
cwognum 3e393d8
Updated deprecation warning
cwognum 119a862
Minor change to __main__.py
cwognum febdf2d
Fixing bug in finetuning training
24354ee
Finishing touches
b97289a
Reformatting
014b6a3
samuelm - Testing change RE CLI for IPU
s-maddrellmander 7d2eb5c
Revert
s-maddrellmander 6625e0c
Merge remote-tracking branch 'origin/main' into explore_finetuning
DomInvivo
27 changes: 27 additions & 0 deletions
27
expts/hydra-configs/dataset/accelerator/lipophilicity_gpu.yaml
```yaml
# @package _global_

accelerator:
  float32_matmul_precision: medium

architecture:
  task_heads:
    tox21:
      last_activation: none

datamodule:
  args:
    batch_size_training: 200
    batch_size_inference: 200
    featurization_n_jobs: 0
    num_workers: 0

predictor:
  optim_kwargs: {}
  # metrics_every_n_steps: 300
  torch_scheduler_kwargs:
    max_num_epochs: &max_epochs 300

trainer:
  trainer:
    accumulate_grad_batches: 1
    max_epochs: *max_epochs
```
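As a side note, the `&max_epochs` anchor above keeps the scheduler's epoch budget and the trainer's `max_epochs` in sync. A minimal sketch of how to sanity-check that, assuming PyYAML is available and the file sits at the path shown above:

```python
# Minimal sketch: verify that the YAML anchor keeps the scheduler and trainer
# epoch settings in sync. Assumes PyYAML is installed and the file lives at
# the path shown in this diff.
import yaml

with open("expts/hydra-configs/dataset/accelerator/lipophilicity_gpu.yaml") as f:
    cfg = yaml.safe_load(f)

# Both values come from the same `&max_epochs 300` anchor.
assert cfg["predictor"]["torch_scheduler_kwargs"]["max_num_epochs"] == 300
assert cfg["trainer"]["trainer"]["max_epochs"] == 300
```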
New file (309 additions)
```yaml
# @package _global_

architecture:
  model_type: FullGraphFinetuningNetwork # FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn:
    out_dim: 64
    hidden_dims: 256
    depth: 2
    activation: relu
    last_activation: none
    dropout: &dropout 0.18
    normalization: &normalization layer_norm
    last_normalization: *normalization
    residual_type: none

  pre_nn_edges: null

  pe_encoders:
    out_dim: 32
    pool: "sum" #"mean" "max"
    last_norm: None #"batch_norm", "layer_norm"
    encoders: #la_pos | rw_pos
      la_pos: # Set as null to avoid a pre-nn network
        encoder_type: "laplacian_pe"
        input_keys: ["laplacian_eigvec", "laplacian_eigval"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        model_type: 'DeepSet' #'Transformer' or 'DeepSet'
        num_layers: 2
        num_layers_post: 1 # Num. layers to apply after pooling
        dropout: 0.1
        first_normalization: "none" #"batch_norm" or "layer_norm"
      rw_pos:
        encoder_type: "mlp"
        input_keys: ["rw_return_probs"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: "layer_norm" #"batch_norm" or "layer_norm"
        first_normalization: "layer_norm" #"batch_norm" or "layer_norm"

  gnn: # Set as null to avoid a post-nn network
    in_dim: 64 # or otherwise the correct value
    out_dim: &gnn_dim 96
    hidden_dims: *gnn_dim
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.1
    normalization: "layer_norm"
    last_normalization: *normalization
    residual_type: simple
    virtual_node: 'none'
    layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine, pyg:gps
    layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1

  graph_output_nn:
    graph:
      pooling: [sum]
      out_dim: *gnn_dim
      hidden_dims: *gnn_dim
      depth: 1
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none

  task_heads:
    qm9:
      task_level: graph
      out_dim: 19
      hidden_dims: 128
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    tox21:
      task_level: graph
      out_dim: 12
      hidden_dims: 64
      depth: 2
      activation: relu
      last_activation: sigmoid
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    zinc:
      task_level: graph
      out_dim: 3
      hidden_dims: 32
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
  ###########################
  last_layer_is_readout: false
  ###########################

  finetuning_head: # none
    task: lipophilicity_astrazeneca
    previous_module: task_heads
    incoming_level: graph
    model_type: mlp
    in_dim: 8
    out_dim: 1
    hidden_dims: 8
    depth: 2
    last_layer_is_readout: true

predictor:
  ### Changes for finetuning ##############################
  metrics_on_progress_bar:
    lipophilicity_astrazeneca: ["mae"]
  loss_fun:
    lipophilicity_astrazeneca: mae
  #########################################################
  random_seed: ${constants.seed}
  optim_kwargs:
    lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
    # weight_decay: 1.e-7
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: &max_epochs 100
    warmup_epochs: 10
    verbose: False
  scheduler_kwargs:
  target_nan_mask: null
  multitask_handling: flatten # flatten, mean-per-label

### Changes for finetuning ##############################
metrics:
  lipophilicity_astrazeneca:
    - name: mae
      metric: mae
      target_nan_mask: null
      multitask_handling: flatten
      threshold_kwargs: null
    - name: spearman
      metric: spearmanr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: pearson
      metric: pearsonr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: r2_score
      metric: r2
      target_nan_mask: null
      multitask_handling: mean-per-label
      threshold_kwargs: null
  #########################################################
  # qm9: &qm9_metrics
  #   - name: mae
  #     metric: mae_ipu
  #     target_nan_mask: null
  #     multitask_handling: flatten
  #     threshold_kwargs: null
  #   - name: pearsonr
  #     metric: pearsonr_ipu
  #     threshold_kwargs: null
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  #   - name: r2_score
  #     metric: r2_score_ipu
  #     target_nan_mask: null
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  # tox21:
  #   - name: auroc
  #     metric: auroc_ipu
  #     task: binary
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  #   - name: avpr
  #     metric: average_precision_ipu
  #     task: binary
  #     multitask_handling: mean-per-label
  #     threshold_kwargs: null
  #   - name: f1 > 0.5
  #     metric: f1
  #     multitask_handling: mean-per-label
  #     target_to_int: True
  #     num_classes: 2
  #     average: micro
  #     threshold_kwargs: &threshold_05
  #       operator: greater
  #       threshold: 0.5
  #       th_on_preds: True
  #       th_on_target: True
  #   - name: precision > 0.5
  #     metric: precision
  #     multitask_handling: mean-per-label
  #     average: micro
  #     threshold_kwargs: *threshold_05
  # zinc: *qm9_metrics

trainer:
  seed: ${constants.seed}
  logger:
    save_dir: logs/neurips2023-small/
    name: ${constants.name}
    project: ${constants.name}
  model_checkpoint:
    dirpath: saved_models/pretrained_models/
    filename: dummy-pretrained-model-{epoch}
    save_on_train_epoch_end: true
  trainer:
    precision: 16
    max_epochs: *max_epochs
    min_epochs: 1
    check_val_every_n_epoch: 20

datamodule:
  ### Changes for finetuning ##############################
  module_type: "ADMETBenchmarkDataModule"
  args:
    # TDC specific
    tdc_benchmark_names: [lipophilicity_astrazeneca]
    tdc_train_val_seed: ${constants.seed}
    #########################################################
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 30
    featurization_progress: True
    featurization_backend: "loky"
    processed_graph_data_path: "../datacache/neurips2023-small/"
    num_workers: 30 # -1 to use all
    persistent_workers: False
    # task_specific_args:
    #   qm9:
    #     df: null
    #     df_path: ${constants.data_dir}/qm9.csv.gz
    #     # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz
    #     # or set path as the URL directly
    #     smiles_col: "smiles"
    #     label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
    #     # sample_size: 2000 # use sample_size for test
    #     splits_path: ${constants.data_dir}/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
    #     seed: ${constants.seed} #*seed
    #     task_level: graph
    #     label_normalization:
    #       normalize_val_test: True
    #       method: "normal"

    #   tox21:
    #     df: null
    #     df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz
    #     # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
    #     # or set path as the URL directly
    #     smiles_col: "smiles"
    #     label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
    #     # sample_size: 2000 # use sample_size for test
    #     splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
    #     seed: ${constants.seed}
    #     task_level: graph

    #   zinc:
    #     df: null
    #     df_path: ${constants.data_dir}/ZINC12k.csv.gz
    #     # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
    #     # or set path as the URL directly
    #     smiles_col: "smiles"
    #     label_cols: ["SA", "logp", "score"]
    #     # sample_size: 2000 # use sample_size for test
    #     splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
    #     seed: ${constants.seed}
    #     task_level: graph
    #     label_normalization:
    #       normalize_val_test: True
    #       method: "normal"
    featurization:
      atom_property_list_onehot: [atomic-number, group, period, total-valence]
      atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
      edge_property_list: [bond-type-onehot, stereo, in-ring]
      add_self_loop: False
      explicit_H: False # if H is included
      use_bonds_weights: False
      pos_encoding_as_features:
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
            disconnected_comp: True # if eigen values/vectors for disconnected graphs are included
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
            disconnected_comp: True # if eigen values/vectors for disconnected graphs are included
          rw_pos: # use same name as pe_encoder
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16
```
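The `finetuning_head` section above attaches a small MLP on top of the output of the pretrained `task_heads`, and the commit history mentions (un)freezing logic for the pretrained modules. The sketch below only illustrates that general idea in plain PyTorch; the module and attribute names are hypothetical and not Graphium's actual API.

```python
# Conceptual sketch only: freeze the pretrained modules while keeping a newly
# attached finetuning head trainable. The attribute name below is hypothetical
# and does not reflect Graphium's actual classes or attributes.
import torch.nn as nn

def freeze_pretrained(model: nn.Module, trainable_submodule: str = "finetuning_head") -> None:
    """Disable gradients everywhere, then re-enable them on the finetuning head."""
    for param in model.parameters():
        param.requires_grad = False
    head = getattr(model, trainable_submodule, None)
    if head is not None:
        for param in head.parameters():
            param.requires_grad = True
```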
New file (15 additions)
```yaml
# @package _global_

constants:
  name: neurips2023_small_data_gcn
  entity: "multitask-gnn"
  seed: 42
  max_epochs: 100
  data_dir: expts/data/neurips2023/small-dataset
  raise_train_error: true

trainer:
  model_checkpoint:
    dirpath: saved_models/finetuned_models/
    filename: dummy-finetuned-model-{epoch}
    save_on_train_epoch_end: true
```
New file (9 additions)
```yaml
defaults:
  - accelerator: ipu
  - dataset: toymix
  - model: gcn

  # Specializations
  - experiment: ${dataset}_${model}
  - dataset/accelerator: ${dataset}_${accelerator}
  - finetuning: finetuning
```
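For context, Hydra resolves a defaults list like this into a single composed config at startup. A minimal entry point consuming it could look like the sketch below; the `config_path` and `config_name` are assumptions for illustration, not the project's actual CLI.

```python
# Minimal sketch of a Hydra entry point consuming a composed config like the
# defaults list above. The config_path/config_name are assumptions for
# illustration; this is not Graphium's actual command-line interface.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="expts/hydra-configs", config_name="main", version_base=None)
def run(cfg: DictConfig) -> None:
    # The defaults list (accelerator, dataset, model, experiment, finetuning)
    # has already been merged into a single DictConfig at this point.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    run()
```

Individual groups can then be swapped from the command line (for example `model=gcn accelerator=gpu`, assuming those options exist in the config tree).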
This doesn't make sense. We should not have any architectural choices from the original pre-trained model in here, only the things that would change.
That way, we can take different pre-trained models with different hparams/seeds and fine-tune them all with the same file.
I agree. The configurations are still structured so that we have access to both the full config of the pretrained model and the pretraining-related config, and the modify_cfg_for_finetuning function consolidates that information into a single config.
This will be fixed once we incorporate the new Hydra configs from #421. For now we still need modify_cfg_for_finetuning, so it may be best to wait for the final version.
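To make the overlay idea discussed above concrete: the finetuning file would carry only the deltas and be layered onto whatever full config the pretrained checkpoint was trained with. A rough sketch of that consolidation step with plain OmegaConf follows; this shows only the merge concept, not the actual modify_cfg_for_finetuning implementation, and the file paths are hypothetical.

```python
# Conceptual sketch of the "delta on top of the pretrained config" idea
# discussed above, using plain OmegaConf. This is not the actual
# modify_cfg_for_finetuning implementation; the file paths are hypothetical.
from omegaconf import OmegaConf

pretrained_cfg = OmegaConf.load("saved_models/pretrained_models/config.yaml")   # full pretraining config
finetuning_delta = OmegaConf.load("expts/hydra-configs/finetuning/finetuning.yaml")  # only the changes

# Later arguments win, so the delta overrides only what it explicitly sets
# (e.g. finetuning head, loss, metrics, datamodule) and inherits everything else.
cfg = OmegaConf.merge(pretrained_cfg, finetuning_delta)
```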