
Merge pull request #423 from datamol-io/fix_configs
Fix NeurIPS configs after Hydra integration
DomInvivo authored Aug 1, 2023
2 parents ddec593 + 89ba022 commit debe62f
Showing 15 changed files with 862 additions and 890 deletions.
423 changes: 423 additions & 0 deletions expts/neurips2023_configs/base_config/large.yaml

Large diffs are not rendered by default.

343 changes: 343 additions & 0 deletions expts/neurips2023_configs/base_config/small.yaml
@@ -0,0 +1,343 @@
# @package _global_
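# This Hydra directive places the file's contents at the root (global) level of the
# composed config, rather than under a package named after the file's path.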

constants:
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
entity: multitask-gnn

accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 80
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 80
# Data handling-related
batch_size_training: 50
batch_size_inference: 50
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16
accumulate_grad_batches: 4

ipu_config:
- deviceIterations(5) # Number of iterations the IPU runs per host call; batches must be queued ahead of time.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
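# Rough sketch of the resulting global batch per weight update, assuming PopTorch's
# usual semantics (an estimate, not a measured value):
#   batch_size_training (50) x deviceIterations (5) x replicationFactor (16)
#     x accumulate_grad_batches (4) = 16,000 graphs per optimizer step.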

# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1

datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches the setup used in test_multitask_datamodule.py.
task_specific_args: # To be replaced by a new class "DatasetParams"
qm9:
df: null
df_path: data/neurips2023/small-dataset/qm9.csv.gz
# wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz
# or set df_path to the URL directly
smiles_col: "smiles"
label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
seed: *seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"

tox21:
df: null
df_path: data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz
# wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
# or set df_path to the URL directly
smiles_col: "smiles"
label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
seed: *seed
task_level: graph

zinc:
df: null
df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz
# wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
# or set df_path to the URL directly
smiles_col: "smiles"
label_cols: ["SA", "logp", "score"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
seed: *seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
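      # To add another dataset, follow the same pattern. Hypothetical template for
      # illustration only -- the file, columns, and splits below do not exist:
      # my_dataset:
      #   df: null
      #   df_path: data/neurips2023/small-dataset/my_dataset.csv.gz
      #   smiles_col: "smiles"
      #   label_cols: ["label_1", "label_2"]
      #   splits_path: data/neurips2023/small-dataset/my_dataset_random_splits.pt
      #   seed: *seed
      #   task_level: graph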

# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-small/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # whether explicit hydrogens are included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected components are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected components are included
rw_pos: # must use the same name as the corresponding pe_encoder below
pos_level: node
pos_type: rw_return_probs
ksteps: 16
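        # Note: the pos_type values above (laplacian_eigvec, laplacian_eigval,
        # rw_return_probs) are consumed by architecture.pe_encoders below via
        # their matching input_keys.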

# cache_data_path: .
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs
# Setting persistent_workers to False can make the start of each epoch very slow.


architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none

pre_nn_edges: null # Set as null to avoid a pre-nn network

pe_encoders:
out_dim: 32
pool: "sum" # Options: "sum", "mean", "max"
last_norm: None # Options: None, "batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to skip this encoder
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"



gnn: # Main GNN (message-passing) layers
in_dim: 64 # should match pre_nn.out_dim above
out_dim: &gnn_dim 96
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_kwargs: null # Extra parameters for the GNN layer itself, e.g. dropout_attn: 0.1


graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none

task_heads:
qm9:
task_level: graph
out_dim: 19
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
tox21:
task_level: graph
out_dim: 12
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
zinc:
task_level: graph
out_dim: 3
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
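    # Note: each head's out_dim equals the number of label_cols of its dataset
    # above (qm9: 19, tox21: 12, zinc: 3).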

# Task-specific
predictor:
metrics_on_progress_bar:
qm9: ["mae"]
tox21: ["auroc"]
zinc: ["mae"]
loss_fun:
qm9: mae_ipu
tox21: bce_ipu
zinc: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # Options: null (no mask), 0 (replace NaNs with 0), ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
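  # Roughly: "flatten" pools all task columns together before computing the loss,
  # while "mean-per-label" computes the loss per label and averages the results
  # (a paraphrase of the options, not an exact specification).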

# Task-specific
metrics:
qm9: &qm9_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
metric: f1
multitask_handling: mean-per-label
target_to_int: True
num_classes: 2
average: micro
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
th_on_preds: True
th_on_target: True
- name: precision > 0.5
metric: precision
multitask_handling: mean-per-label
average: micro
threshold_kwargs: *threshold_05
zinc: *qm9_metrics
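  # YAML alias: ZINC reuses the exact metric list defined by the &qm9_metrics anchor above.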

trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-small/
name: ${constants.name}
project: ${constants.name}
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/${constants.name}/
filename: ${constants.name}
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
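# With the Hydra integration, any value in this file can be overridden from the
# command line using dotted keys. Hypothetical invocation (the entry-point name is
# an assumption, not part of this commit):
#   python expts/main_run_multitask.py constants.seed=123 trainer.trainer.max_epochs=50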