Skip to content

Commit

Permalink
Merge branch 'main' into updates_fred
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo authored Oct 18, 2023
2 parents 3895d23 + 285d07f commit 66abb9a
Show file tree
Hide file tree
Showing 20 changed files with 334 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ To change parameters specific to this experiment like switching from `fp16` to `
```bash
graphium-train dataset=toymix model=gcn trainer.trainer.precision=32
```
or change them permamently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
```bash
graphium-train dataset=toymix model=gcn accelerator=gpu
Expand Down
52 changes: 47 additions & 5 deletions docs/baseline.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,33 @@ From the paper to be released soon. Below, you can see the baselines for the `To

One can observe that the smaller datasets (`Zinc12k` and `Tox21`) beneficiate from adding another unrelated task (`QM9`), where the labels are computed from DFT simulations.

**NEW baselines added 2023/09/18**: Multitask baselines have been added for GatedGCN and MPNN++ (sum aggretator) using 3 random seeds. They achieve the best performance by a significant margin on Zinc12k and Tox21, while sacrificing a little on QM9.

| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ | MAE ↓ | Pearson ↑ | R² ↑ |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
| | <th colspan="3" style="text-align: center;">Single-Task Model</th> <th colspan="3" style="text-align: center;">Multi-Task Model</th> |
|
| **QM9** | GCN | 0.102 ± 0.0003 | 0.958 ± 0.0007 | 0.920 ± 0.002 | 0.119 ± 0.01 | 0.955 ± 0.001 | 0.915 ± 0.001 |
| | GIN | 0.0976 ± 0.0006 | **0.959 ± 0.0002** | **0.922 ± 0.0004** | 0.117 ± 0.01 | 0.950 ± 0.002 | 0.908 ± 0.003 |
| | GINE | **0.0959 ± 0.0002** | 0.955 ± 0.002 | 0.918 ± 0.004 | 0.102 ± 0.01 | 0.956 ± 0.0009 | 0.918 ± 0.002 |
|
| **Zinc12k** | GCN | 0.348 ± 0.02 | 0.941 ± 0.002 | 0.863 ± 0.01 | 0.226 ± 0.004 | 0.973 ± 0.0005 | 0.940 ± 0.003 |
| | GatedGCN | | | | 0.1212 ± 0.0009 | 0.9457 ± 0.0002 | 0.8964 ± 0.0006 |
| | MPNN++ (sum) | | | | 0.1174 ± 0.0012 | 0.9460 ± 0.0005 | 0.8989 ± 0.0008 |
**Zinc12k** | GCN | 0.348 ± 0.02 | 0.941 ± 0.002 | 0.863 ± 0.01 | 0.226 ± 0.004 | 0.973 ± 0.0005 | 0.940 ± 0.003 |
| | GIN | 0.303 ± 0.007 | 0.950 ± 0.003 | 0.889 ± 0.003 | 0.189 ± 0.004 | 0.978 ± 0.006 | 0.953 ± 0.002 |
| | GINE | 0.266 ± 0.02 | 0.961 ± 0.003 | 0.915 ± 0.01 | **0.147 ± 0.009** | **0.987 ± 0.001** | **0.971 ± 0.003** |
| | GINE | 0.266 ± 0.02 | 0.961 ± 0.003 | 0.915 ± 0.01 | 0.147 ± 0.009 | 0.987 ± 0.001 | 0.971 ± 0.003 |
| | GatedGCN | | | | 0.1282 ± 0.0045 | 0.9850 ± 0.0006 | 0.9639 ± 0.0024 |
| | MPNN++ (sum) | | | | **0.1002 ± 0.0025** | **0.9909 ± 0.0004** | **0.9777 ± 0.0014** |

| | | BCE ↓ | AUROC ↑ | AP ↑ | BCE ↓ | AUROC ↑ | AP ↑ |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
| | <th colspan="3" style="text-align: center;">Single-Task Model</th> <th colspan="3" style="text-align: center;">Multi-Task Model</th> |
|
| **Tox21** | GCN | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | **0.176 ± 0.001** | **0.850 ± 0.006** | 0.446 ± 0.01 |
| **Tox21** | GCN | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | 0.176 ± 0.001 | 0.850 ± 0.006 | 0.446 ± 0.01 |
| | GIN | 0.200 ± 0.002 | 0.789 ± 0.009 | 0.350 ± 0.01 | 0.176 ± 0.001 | 0.841 ± 0.005 | 0.454 ± 0.009 |
| | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | **0.455 ± 0.008** |
| | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | 0.455 ± 0.008 |
| | GatedGCN | | | | 0.1733 ± 0.0015 | 0.8522 ± 0.0022 | **0.4620 ± 0.0118** |
| | MPNN++ (sum) | | | | **0.1725 ± 0.0012** | **0.8569 ± 0.0005** | 0.4598 ± 0.0044 |


# LargeMix Baseline
## LargeMix test set metrics
Expand Down Expand Up @@ -88,6 +96,40 @@ This is not surprising as they contain two orders of magnitude more datapoints a
| | GIN | 0.1873 ± 0.0033 | **0.1701 ± 0.0142** |
| | GINE | 0.1883 ± 0.0039 | **0.1771 ± 0.0010** |

## NEW: Largemix improved sweep - 2023/08-18

Unsatisfied with the prior results, we ran a bayesian search over a broader set of parameters, and including only more expressive models, namely GINE, GatedGCN and MPNN++. We further increase the number of parameters to 10M due to evidence of underfitting. We evaluate only the multitask setting.

We observe a significant improvement over all tasks, with a very notable r2-score increase of +0.53 (0.27 -> 0.80) compared to the best node-level property prediction on PCQM4M_N4.

The results are reported below over 1 seed. We are currently running more seeds of the same models.

| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ |
|---------------|----------------|--------|---------|--------|
| **PCQM4M_G25** | GINE | 0.2250 | 0.8840 | 0.7911 |
| | GatedGCN | 0.2457 | 0.8698 | 0.7688 |
| | MPNN++ (sum) | 0.2269 | 0.8802 | 0.7855 |
|
| **PCQM4M_N4** | GINE | 0.2699 | 0.8475 | 0.7182 |
| | GatedGCN | 0.3337 | 0.8102 | 0.6566 |
| | MPNN++ (sum) | 0.2114 | 0.8942 | 0.8000 |

| Dataset | Model | BCE ↓ | AUROC ↑ | AP ↑ |
|---------------|----------------|--------|---------|--------|
| **PCBA_1328** | GINE | 0.0334 | 0.7879 | 0.2808 |
| | GatedGCN | 0.0351 | 0.7788 | 0.2611 |
| | MPNN++ (sum) | 0.0344 | 0.7815 | 0.2666 |
|
| **L1000_VCAP** | GINE | 0.1907 | 0.6416 | 0.4042 |
| | GatedGCN | 0.1866 | 0.6395 | 0.4092 |
| | MPNN++ (sum) | 0.1867 | 0.6478 | 0.4131 |
|
| **L1000_MCF7** | GINE | 0.1931 | 0.6352 | 0.4235 |
| | GatedGCN | 0.1859 | 0.6547 | 0.4224 |
| | MPNN++ (sum) | 0.1870 | 0.6593 | 0.4254 |



# UltraLarge Baseline

## UltraLarge test set metrics
Expand Down
4 changes: 3 additions & 1 deletion docs/datasets.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Graphium Datasets

Graphium datasets are hosted at on Zenodo on [this link](https://zenodo.org/record/8206704).
Graphium datasets are hosted at on Zenodo
- ***ToyMix*** and ***LargeMix*** dataseets are hosted on [this link](https://doi.org/10.5281/zenodo.7998401)
- ***UltraLarge*** dataset is hosted on [this link](https://doi.org/10.5281/zenodo.8370547)

Instead of provinding datasets as a single entity, our aim is to provide dataset mixes containing a variety of datasets that are meant to be predicted simultaneously using multi-tasking.

Expand Down
22 changes: 22 additions & 0 deletions expts/hydra-configs/accelerator/ipu_pipeline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
type: ipu
ipu_config:
- deviceIterations(60) # IPU would require large batches to be ready for the model.
# 60 for PCQM4mv2
# 30 for largemix
- replicationFactor(4)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)

ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(60)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)

accelerator_kwargs:
_accelerator: "ipu"
gnn_layers_per_ipu: [4, 4, 4, 4]
4 changes: 2 additions & 2 deletions expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ metrics:
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2
metric: r2_score
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
Expand Down Expand Up @@ -138,4 +138,4 @@ datamodule:
args:
# TDC specific
tdc_benchmark_names: null
tdc_train_val_seed: ${constants.seed}
tdc_train_val_seed: ${constants.seed}
84 changes: 83 additions & 1 deletion graphium/cli/train_finetune_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import fsspec
import hydra
import numpy as np
import torch
import wandb
import yaml
Expand Down Expand Up @@ -42,6 +43,8 @@

TESTING_ONLY_CONFIG_KEY = "testing_only"

OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))


@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
def cli(cfg: DictConfig) -> None:
Expand All @@ -51,13 +54,78 @@ def cli(cfg: DictConfig) -> None:
return run_training_finetuning_testing(cfg)


def get_replication_factor(cfg):
try:
ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
for item in ipu_config:
if "replicationFactor" in item:
# Extract the number between parentheses
start = item.find("(") + 1
end = item.find(")")
if start != 0 and end != -1:
return int(item[start:end])
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if replicationFactor is not found or an error occurred
return 1


def get_gradient_accumulation_factor(cfg):
try:
# Navigate through the nested dictionaries and get the gradient accumulation factor
grad_accumulation_factor = (
cfg.get("accelerator", {})
.get("config_override", {})
.get("trainer", {})
.get("trainer", {})
.get("accumulate_grad_batches", 1)
)

# Ensure that the extracted value is an integer
return int(grad_accumulation_factor)
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if an error occurred
return 1


def get_training_batch_size(cfg):
try:
# Navigate through the nested dictionaries and get the training batch size
batch_size_training = (
cfg.get("accelerator", {})
.get("config_override", {})
.get("datamodule", {})
.get("args", {})
.get("batch_size_training", 1)
)

# Ensure that the extracted value is an integer
return int(batch_size_training)
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if an error occurred
return 1


def run_training_finetuning_testing(cfg: DictConfig) -> None:
"""
The main (pre-)training and fine-tuning loop.
"""

unresolved_cfg = OmegaConf.to_container(cfg, resolve=False)
cfg = OmegaConf.to_container(cfg, resolve=True)

# Get the current date and time
now = datetime.now()
# Format the datetime as a string
filename_datetime_suffix = now.strftime("%Y%m%d_%H%M%S")
# Append the datetime string to the existing filename in the cfg dictionary
cfg["trainer"]["model_checkpoint"]["filename"] += f"_{filename_datetime_suffix}"

dst_dir = cfg["constants"].get("results_dir")
hydra_cfg = HydraConfig.get()
output_dir = hydra_cfg["runtime"]["output_dir"]
Expand All @@ -75,6 +143,12 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:

st = timeit.default_timer()

replicas = get_replication_factor(cfg)
gradient_acc = get_gradient_accumulation_factor(cfg)
micro_bs = get_training_batch_size(cfg)

global_bs = replicas * gradient_acc * micro_bs

# Disable wandb if the user is not logged in.
wandb_cfg = cfg["constants"].get("wandb")
if wandb_cfg is not None and wandb.login() is False:
Expand Down Expand Up @@ -119,6 +193,9 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
accelerator_type=accelerator_type,
featurization=datamodule.featurization,
task_norms=datamodule.task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
)

logger.info(predictor.model)
Expand All @@ -135,7 +212,7 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs))

if wandb_cfg is not None:
save_params_to_wandb(trainer.logger, cfg, predictor, datamodule)
save_params_to_wandb(trainer.logger, cfg, predictor, datamodule, unresolved_config=unresolved_cfg)

# Determine the max num nodes and edges in training and validation
logger.info("Computing the maximum number of nodes and edges per graph")
Expand Down Expand Up @@ -173,6 +250,11 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
logger.info("-" * 50)

if wandb_cfg is not None:
# Save initial model state - and upload checkpoint to wandb
if cfg["trainer"]["model_checkpoint"]["save_last"] is True:
checkpoint_path = f"{cfg['trainer']['model_checkpoint']['dirpath']}{cfg['trainer']['model_checkpoint']['filename']}.ckpt"
# Log the initial model checkpoint to wandb
wandb.save(checkpoint_path)
wandb.finish()

# Save test metrics - Base utility in case someone doesn't use a logger.
Expand Down
Loading

0 comments on commit 66abb9a

Please sign in to comment.