diff --git a/README.md b/README.md
index 11b707bba..a83f7ab40 100644
--- a/README.md
+++ b/README.md
@@ -80,7 +80,7 @@ To change parameters specific to this experiment like switching from `fp16` to `
```bash
graphium-train dataset=toymix model=gcn trainer.trainer.precision=32
```
-or change them permamently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
+or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
```bash
graphium-train dataset=toymix model=gcn accelerator=gpu
diff --git a/docs/baseline.md b/docs/baseline.md
index 029996554..cac1ee282 100644
--- a/docs/baseline.md
+++ b/docs/baseline.md
@@ -4,6 +4,8 @@ From the paper to be released soon. Below, you can see the baselines for the `To
One can observe that the smaller datasets (`Zinc12k` and `Tox21`) benefit from adding another, unrelated task (`QM9`), whose labels are computed from DFT simulations.
+**NEW baselines added 2023/09/18**: Multitask baselines have been added for GatedGCN and MPNN++ (sum aggregator) using 3 random seeds. They achieve the best performance by a significant margin on Zinc12k and Tox21, while sacrificing a little performance on QM9.
+
| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ | MAE ↓ | Pearson ↑ | R² ↑ |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
| | Single-Task Model | Multi-Task Model | |
@@ -11,18 +13,24 @@ One can observe that the smaller datasets (`Zinc12k` and `Tox21`) beneficiate fr
| **QM9** | GCN | 0.102 ± 0.0003 | 0.958 ± 0.0007 | 0.920 ± 0.002 | 0.119 ± 0.01 | 0.955 ± 0.001 | 0.915 ± 0.001 |
| | GIN | 0.0976 ± 0.0006 | **0.959 ± 0.0002** | **0.922 ± 0.0004** | 0.117 ± 0.01 | 0.950 ± 0.002 | 0.908 ± 0.003 |
| | GINE | **0.0959 ± 0.0002** | 0.955 ± 0.002 | 0.918 ± 0.004 | 0.102 ± 0.01 | 0.956 ± 0.0009 | 0.918 ± 0.002 |
-|
-| **Zinc12k** | GCN | 0.348 ± 0.02 | 0.941 ± 0.002 | 0.863 ± 0.01 | 0.226 ± 0.004 | 0.973 ± 0.0005 | 0.940 ± 0.003 |
+| | GatedGCN | | | | 0.1212 ± 0.0009 | 0.9457 ± 0.0002 | 0.8964 ± 0.0006 |
+| | MPNN++ (sum) | | | | 0.1174 ± 0.0012 | 0.9460 ± 0.0005 | 0.8989 ± 0.0008 |
+| **Zinc12k** | GCN | 0.348 ± 0.02 | 0.941 ± 0.002 | 0.863 ± 0.01 | 0.226 ± 0.004 | 0.973 ± 0.0005 | 0.940 ± 0.003 |
| | GIN | 0.303 ± 0.007 | 0.950 ± 0.003 | 0.889 ± 0.003 | 0.189 ± 0.004 | 0.978 ± 0.006 | 0.953 ± 0.002 |
-| | GINE | 0.266 ± 0.02 | 0.961 ± 0.003 | 0.915 ± 0.01 | **0.147 ± 0.009** | **0.987 ± 0.001** | **0.971 ± 0.003** |
+| | GINE | 0.266 ± 0.02 | 0.961 ± 0.003 | 0.915 ± 0.01 | 0.147 ± 0.009 | 0.987 ± 0.001 | 0.971 ± 0.003 |
+| | GatedGCN | | | | 0.1282 ± 0.0045 | 0.9850 ± 0.0006 | 0.9639 ± 0.0024 |
+| | MPNN++ (sum) | | | | **0.1002 ± 0.0025** | **0.9909 ± 0.0004** | **0.9777 ± 0.0014** |
| | | BCE ↓ | AUROC ↑ | AP ↑ | BCE ↓ | AUROC ↑ | AP ↑ |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
| | Single-Task Model | Multi-Task Model | |
|
-| **Tox21** | GCN | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | **0.176 ± 0.001** | **0.850 ± 0.006** | 0.446 ± 0.01 |
+| **Tox21** | GCN | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | 0.176 ± 0.001 | 0.850 ± 0.006 | 0.446 ± 0.01 |
| | GIN | 0.200 ± 0.002 | 0.789 ± 0.009 | 0.350 ± 0.01 | 0.176 ± 0.001 | 0.841 ± 0.005 | 0.454 ± 0.009 |
-| | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | **0.455 ± 0.008** |
+| | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | 0.455 ± 0.008 |
+| | GatedGCN | | | | 0.1733 ± 0.0015 | 0.8522 ± 0.0022 | **0.4620 ± 0.0118** |
+| | MPNN++ (sum) | | | | **0.1725 ± 0.0012** | **0.8569 ± 0.0005** | 0.4598 ± 0.0044 |
+
# LargeMix Baseline
## LargeMix test set metrics
@@ -88,6 +96,40 @@ This is not surprising as they contain two orders of magnitude more datapoints a
| | GIN | 0.1873 ± 0.0033 | **0.1701 ± 0.0142** |
| | GINE | 0.1883 ± 0.0039 | **0.1771 ± 0.0010** |
+## NEW: LargeMix improved sweep - 2023/08/18
+
+Unsatisfied with the prior results, we ran a Bayesian search over a broader set of hyperparameters, including only the more expressive models, namely GINE, GatedGCN and MPNN++. We further increased the number of parameters to 10M due to evidence of underfitting. We evaluated only the multitask setting.
+
+We observe a significant improvement on all tasks, with a very notable R² increase of +0.53 (0.27 -> 0.80) compared to the best node-level property prediction on PCQM4M_N4.
+
+The results below are reported for a single seed. We are currently running more seeds of the same models.
+
+| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ |
+|---------------|----------------|--------|---------|--------|
+| **PCQM4M_G25** | GINE | 0.2250 | 0.8840 | 0.7911 |
+| | GatedGCN | 0.2457 | 0.8698 | 0.7688 |
+| | MPNN++ (sum) | 0.2269 | 0.8802 | 0.7855 |
+|
+| **PCQM4M_N4** | GINE | 0.2699 | 0.8475 | 0.7182 |
+| | GatedGCN | 0.3337 | 0.8102 | 0.6566 |
+| | MPNN++ (sum) | 0.2114 | 0.8942 | 0.8000 |
+
+| Dataset | Model | BCE ↓ | AUROC ↑ | AP ↑ |
+|---------------|----------------|--------|---------|--------|
+| **PCBA_1328** | GINE | 0.0334 | 0.7879 | 0.2808 |
+| | GatedGCN | 0.0351 | 0.7788 | 0.2611 |
+| | MPNN++ (sum) | 0.0344 | 0.7815 | 0.2666 |
+|
+| **L1000_VCAP** | GINE | 0.1907 | 0.6416 | 0.4042 |
+| | GatedGCN | 0.1866 | 0.6395 | 0.4092 |
+| | MPNN++ (sum) | 0.1867 | 0.6478 | 0.4131 |
+|
+| **L1000_MCF7** | GINE | 0.1931 | 0.6352 | 0.4235 |
+| | GatedGCN | 0.1859 | 0.6547 | 0.4224 |
+| | MPNN++ (sum) | 0.1870 | 0.6593 | 0.4254 |
+
+
+
# UltraLarge Baseline
## UltraLarge test set metrics
diff --git a/docs/datasets.md b/docs/datasets.md
index fc4e0f292..6733736f4 100644
--- a/docs/datasets.md
+++ b/docs/datasets.md
@@ -1,6 +1,8 @@
# Graphium Datasets
-Graphium datasets are hosted at on Zenodo on [this link](https://zenodo.org/record/8206704).
+Graphium datasets are hosted on Zenodo:
+- ***ToyMix*** and ***LargeMix*** datasets are hosted at [this link](https://doi.org/10.5281/zenodo.7998401)
+- ***UltraLarge*** dataset is hosted at [this link](https://doi.org/10.5281/zenodo.8370547)
Instead of providing datasets as a single entity, our aim is to provide dataset mixes containing a variety of datasets that are meant to be predicted simultaneously using multi-tasking.
diff --git a/expts/hydra-configs/accelerator/ipu_pipeline.yaml b/expts/hydra-configs/accelerator/ipu_pipeline.yaml
new file mode 100644
index 000000000..996218646
--- /dev/null
+++ b/expts/hydra-configs/accelerator/ipu_pipeline.yaml
@@ -0,0 +1,22 @@
+type: ipu
+ipu_config:
+  - deviceIterations(60) # IPUs require large batches to be ready for the model.
+ # 60 for PCQM4mv2
+ # 30 for largemix
+ - replicationFactor(4)
+ # - enableProfiling("graph_analyser") # The folder where the profile will be stored
+ # - enableExecutableCaching("pop_compiler_cache")
+ - TensorLocations.numIOTiles(128)
+ - _Popart.set("defaultBufferingDepth", 96)
+ - Precision.enableStochasticRounding(True)
+
+ipu_inference_config:
+ # set device iteration and replication factor to 1 during inference
+ # gradient accumulation was set to 1 in the code
+ - deviceIterations(60)
+ - replicationFactor(1)
+ - Precision.enableStochasticRounding(False)
+
+accelerator_kwargs:
+ _accelerator: "ipu"
+ gnn_layers_per_ipu: [4, 4, 4, 4]
\ No newline at end of file
diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml
index cfff5f689..89176f2b6 100644
--- a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml
+++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml
@@ -80,7 +80,7 @@ metrics:
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
- metric: r2
+ metric: r2_score
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
@@ -138,4 +138,4 @@ datamodule:
args:
# TDC specific
tdc_benchmark_names: null
- tdc_train_val_seed: ${constants.seed}
\ No newline at end of file
+ tdc_train_val_seed: ${constants.seed}
diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 5d9a29b97..1be59e652 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -6,6 +6,7 @@
import fsspec
import hydra
+import numpy as np
import torch
import wandb
import yaml
@@ -42,6 +43,8 @@
TESTING_ONLY_CONFIG_KEY = "testing_only"
+OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))
+
@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
def cli(cfg: DictConfig) -> None:
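
For illustration, here is a minimal, self-contained sketch of how the `eval` resolver registered above can be exercised from an OmegaConf/Hydra config. The key `warmup_steps` and the expression are made up for this example and are not part of the patch:

```python
import numpy as np
from omegaconf import OmegaConf

# Same registration as in the patch: expressions are evaluated with numpy available as `np`.
OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))

# Hypothetical config value using the resolver (quoting the expression keeps the
# interpolation grammar happy).
cfg = OmegaConf.create({"warmup_steps": "${eval:'int(np.rint(0.1 * 3000))'}"})
print(cfg.warmup_steps)  # -> 300
```
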
@@ -51,13 +54,78 @@ def cli(cfg: DictConfig) -> None:
return run_training_finetuning_testing(cfg)
+def get_replication_factor(cfg):
+ try:
+ ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
+ for item in ipu_config:
+ if "replicationFactor" in item:
+ # Extract the number between parentheses
+ start = item.find("(") + 1
+ end = item.find(")")
+ if start != 0 and end != -1:
+ return int(item[start:end])
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+ # Return default value if replicationFactor is not found or an error occurred
+ return 1
+
+
+def get_gradient_accumulation_factor(cfg):
+ try:
+ # Navigate through the nested dictionaries and get the gradient accumulation factor
+ grad_accumulation_factor = (
+ cfg.get("accelerator", {})
+ .get("config_override", {})
+ .get("trainer", {})
+ .get("trainer", {})
+ .get("accumulate_grad_batches", 1)
+ )
+
+ # Ensure that the extracted value is an integer
+ return int(grad_accumulation_factor)
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+ # Return default value if an error occurred
+ return 1
+
+
+def get_training_batch_size(cfg):
+ try:
+ # Navigate through the nested dictionaries and get the training batch size
+ batch_size_training = (
+ cfg.get("accelerator", {})
+ .get("config_override", {})
+ .get("datamodule", {})
+ .get("args", {})
+ .get("batch_size_training", 1)
+ )
+
+ # Ensure that the extracted value is an integer
+ return int(batch_size_training)
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+ # Return default value if an error occurred
+ return 1
+
+
def run_training_finetuning_testing(cfg: DictConfig) -> None:
"""
The main (pre-)training and fine-tuning loop.
"""
+ unresolved_cfg = OmegaConf.to_container(cfg, resolve=False)
cfg = OmegaConf.to_container(cfg, resolve=True)
+ # Get the current date and time
+ now = datetime.now()
+ # Format the datetime as a string
+ filename_datetime_suffix = now.strftime("%Y%m%d_%H%M%S")
+ # Append the datetime string to the existing filename in the cfg dictionary
+ cfg["trainer"]["model_checkpoint"]["filename"] += f"_{filename_datetime_suffix}"
+
dst_dir = cfg["constants"].get("results_dir")
hydra_cfg = HydraConfig.get()
output_dir = hydra_cfg["runtime"]["output_dir"]
@@ -75,6 +143,12 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
st = timeit.default_timer()
+ replicas = get_replication_factor(cfg)
+ gradient_acc = get_gradient_accumulation_factor(cfg)
+ micro_bs = get_training_batch_size(cfg)
+
+ global_bs = replicas * gradient_acc * micro_bs
+
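
For intuition, the `global_bs` computed above is simply the product of the three factors. A quick check with the `replicationFactor(4)` from `ipu_pipeline.yaml` and hypothetical values for the other two factors:

```python
# Illustrative only: replicas taken from ipu_pipeline.yaml's replicationFactor(4);
# accumulate_grad_batches=2 and batch_size_training=32 are hypothetical values.
replicas, gradient_acc, micro_bs = 4, 2, 32
global_bs = replicas * gradient_acc * micro_bs
print(global_bs)  # -> 256
```
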
# Disable wandb if the user is not logged in.
wandb_cfg = cfg["constants"].get("wandb")
if wandb_cfg is not None and wandb.login() is False:
@@ -119,6 +193,9 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
accelerator_type=accelerator_type,
featurization=datamodule.featurization,
task_norms=datamodule.task_norms,
+ replicas=replicas,
+ gradient_acc=gradient_acc,
+ global_bs=global_bs,
)
logger.info(predictor.model)
@@ -135,7 +212,7 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs))
if wandb_cfg is not None:
- save_params_to_wandb(trainer.logger, cfg, predictor, datamodule)
+ save_params_to_wandb(trainer.logger, cfg, predictor, datamodule, unresolved_config=unresolved_cfg)
# Determine the max num nodes and edges in training and validation
logger.info("Computing the maximum number of nodes and edges per graph")
@@ -173,6 +250,11 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
logger.info("-" * 50)
if wandb_cfg is not None:
+ # Save initial model state - and upload checkpoint to wandb
+ if cfg["trainer"]["model_checkpoint"]["save_last"] is True:
+ checkpoint_path = f"{cfg['trainer']['model_checkpoint']['dirpath']}{cfg['trainer']['model_checkpoint']['filename']}.ckpt"
+ # Log the initial model checkpoint to wandb
+ wandb.save(checkpoint_path)
wandb.finish()
# Save test metrics - Base utility in case someone doesn't use a logger.
diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py
index 48d7e9078..259c61e34 100644
--- a/graphium/config/_loader.py
+++ b/graphium/config/_loader.py
@@ -13,7 +13,7 @@
# Lightning
from lightning import Trainer
-from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import Logger, WandbLogger
from loguru import logger
@@ -76,7 +76,6 @@ def _get_ipu_opts(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Tuple[
if accelerator_type != "ipu":
return None, None
-
ipu_opts = accelerator_options["ipu_config"]
ipu_inference_opts = accelerator_options.get("ipu_inference_config", None)
@@ -126,6 +125,7 @@ def load_datamodule(
ipu_inference_opts=ipu_inference_opts,
precision=config["trainer"]["trainer"].get("precision"),
)
+
# Define the Dataloader options for the IPU on the training sets
bz_train = cfg_data["batch_size_training"]
ipu_dataloader_training_opts = IPUDataloaderOptions(
@@ -261,6 +261,10 @@ def load_architecture(
graph_output_nn_kwargs=graph_output_nn_kwargs,
task_heads_kwargs=task_heads_kwargs,
)
+ # Get accelerator_kwargs if they exist
+ accelerator_kwargs = config["accelerator"].get("accelerator_kwargs", None)
+ if accelerator_kwargs is not None:
+ model_kwargs["accelerator_kwargs"] = accelerator_kwargs
if model_class is FullGraphFinetuningNetwork:
finetuning_head_kwargs = config["finetuning"].pop("finetuning_head", None)
@@ -286,6 +290,9 @@ def load_predictor(
accelerator_type: str,
featurization: Dict[str, str] = None,
task_norms: Optional[Dict[Callable, Any]] = None,
+ replicas: int = 1,
+ gradient_acc: int = 1,
+ global_bs: int = 1,
) -> PredictorModule:
"""
Defining the predictor module, which handles the training logic from `lightning.LighningModule`
@@ -311,6 +318,9 @@ def load_predictor(
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
+ replicas=replicas,
+ gradient_acc=gradient_acc,
+ global_bs=global_bs,
**cfg_pred,
)
@@ -327,6 +337,9 @@ def load_predictor(
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
+ replicas=replicas,
+ gradient_acc=gradient_acc,
+ global_bs=global_bs,
**cfg_pred,
)
@@ -415,13 +428,18 @@ def load_trainer(
if "model_checkpoint" in cfg_trainer.keys():
callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"]))
+ if "learning_rate_monitor" in cfg_trainer.keys():
+ callbacks.append(LearningRateMonitor(**cfg_trainer["learning_rate_monitor"]))
+ else:
+ callbacks.append(LearningRateMonitor())
+
# Define the logger parameters
wandb_cfg = config["constants"].get("wandb")
if wandb_cfg is not None:
name = wandb_cfg.pop("name", "main")
if len(date_time_suffix) > 0:
name += f"_{date_time_suffix}"
- trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg)
+ trainer_kwargs["logger"] = WandbLogger(name=name, log_model=True, **wandb_cfg)
trainer_kwargs["callbacks"] = callbacks
trainer = Trainer(
@@ -440,6 +458,7 @@ def save_params_to_wandb(
config: Union[omegaconf.DictConfig, Dict[str, Any]],
predictor: PredictorModule,
datamodule: MultitaskFromSmilesDataModule,
+ unresolved_config: Optional[Union[omegaconf.DictConfig, Dict[str, Any]]] = None,
):
"""
Save a few stuff to weights-and-biases WandB
@@ -448,13 +467,16 @@ def save_params_to_wandb(
config: The config file, with key `trainer`
predictor: The predictor used to handle the train/val/test steps logic
datamodule: The datamodule used to load the data into training
+ unresolved_config: The unresolved config file
"""
# Get the wandb runner and directory
wandb_run = logger.experiment
+
if wandb_run is None:
- wandb_run = ""
- wandb_dir = wandb_run.dir
+ wandb_dir = ""
+ else:
+ wandb_dir = wandb_run.dir
# Save the mup base model to WandB as a yaml file
mup.save_base_shapes(predictor.model, os.path.join(wandb_dir, "mup_base_params.yaml"))
@@ -463,14 +485,18 @@ def save_params_to_wandb(
with open(os.path.join(wandb_dir, "full_configs.yaml"), "w") as file:
yaml.dump(config, file)
+ if unresolved_config is not None:
+ with open(os.path.join(wandb_dir, "unresolved_config.yaml"), "w") as file:
+ yaml.dump(unresolved_config, file)
+
# Save the featurizer into wandb
featurizer_path = os.path.join(wandb_dir, "featurizer.pickle")
joblib.dump(datamodule.smiles_transformer, featurizer_path)
# Save the featurizer and configs into wandb
if wandb_run is not None:
- wandb_run.save("*.yaml")
- wandb_run.save("*.pickle")
+ wandb_run.save(os.path.join(wandb_dir, "*.yaml"), wandb_dir)
+ wandb_run.save(os.path.join(wandb_dir, "*.pickle"), wandb_dir)
def load_accelerator(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Tuple[Dict[str, Any], str]:
diff --git a/graphium/config/zinc_default_multitask_pyg.yaml b/graphium/config/zinc_default_multitask_pyg.yaml
index 07ae4bf9b..b9435ec7e 100644
--- a/graphium/config/zinc_default_multitask_pyg.yaml
+++ b/graphium/config/zinc_default_multitask_pyg.yaml
@@ -181,3 +181,5 @@ architecture: # The parameters for the full graph network are taken from `co
dropout: 0.2
normalization: none
residual_type: none
+accelerator:
+ type: cpu
\ No newline at end of file
diff --git a/graphium/features/featurizer.py b/graphium/features/featurizer.py
index 66f241663..d8efdb2ab 100644
--- a/graphium/features/featurizer.py
+++ b/graphium/features/featurizer.py
@@ -1062,11 +1062,9 @@ def mol_to_graph_dict(
mol = Chem.AddHs(mol)
else:
mol = Chem.RemoveHs(mol)
-
num_atoms = mol.GetNumAtoms()
if (max_num_atoms is not None) and (num_atoms > max_num_atoms):
raise ValueError(f"Maximum number of atoms greater than permitted {num_atoms}>{max_num_atoms}")
-
(
adj,
ndata,
diff --git a/graphium/finetuning/utils.py b/graphium/finetuning/utils.py
index abcd11644..a2bd20d68 100644
--- a/graphium/finetuning/utils.py
+++ b/graphium/finetuning/utils.py
@@ -47,7 +47,6 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]):
"""
Function combining information from configuration and pretrained model for finetuning.
"""
-
task = cfg["finetuning"]["task"]
# Filter the config based on the task name
diff --git a/graphium/nn/architectures/encoder_manager.py b/graphium/nn/architectures/encoder_manager.py
index e3e48aeba..464d9e9cc 100644
--- a/graphium/nn/architectures/encoder_manager.py
+++ b/graphium/nn/architectures/encoder_manager.py
@@ -135,6 +135,8 @@ def _initialize_positional_encoders(self, pe_encoders_kwargs: Dict[str, Any]) ->
if pe_out_dim2 is not None:
assert edge_pe_out_dim == pe_out_dim2, f"values mismatch {pe_out_dim}!={pe_out_dim2}"
pe_encoders[encoder_name] = encoder(out_dim=edge_pe_out_dim, **this_in_dims, **encoder_kwargs)
+ else:
+ pe_encoders[encoder_name] = encoder(**this_in_dims, **encoder_kwargs)
return pe_encoders
diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py
index 0e4599b24..dc05dbe60 100644
--- a/graphium/nn/architectures/global_architectures.py
+++ b/graphium/nn/architectures/global_architectures.py
@@ -12,6 +12,7 @@
from torch import Tensor, nn
import torch
from torch_geometric.data import Data
+from omegaconf import DictConfig, OmegaConf
# graphium imports
from graphium.data.utils import get_keys
@@ -421,6 +422,7 @@ def __init__(
residual_skip_steps: int = 1,
in_dim_edges: int = 0,
hidden_dims_edges: List[int] = [],
+ out_dim_edges: Optional[int] = None,
name: str = "GNN",
layer_kwargs: Optional[Dict] = None,
virtual_node: str = "none",
@@ -508,6 +510,11 @@ def __init__(
Hidden dimensions for the edges. Most models don't support it, so it
should only be used for those that do, i.e. `GatedGCNLayer`
+ out_dim_edges:
+ Output edge-feature dimensions of the network. Keep at 0 if not using
+ edge features, or if the layer doesn't support edges. Defaults to the
+ last value of hidden_dims_edges.
+
name:
Name attributed to the current network, for display and printing
purposes.
@@ -551,9 +558,17 @@ def __init__(
else:
self.hidden_dims_edges = list(hidden_dims_edges)
assert depth is None
+ self.out_dim_edges = (
+ out_dim_edges
+ if out_dim_edges is not None
+ else self.hidden_dims_edges[-1]
+ if self.hidden_dims_edges
+ else 0
+ )
self.full_dims_edges = None
- if len(self.hidden_dims_edges) > 0:
- self.full_dims_edges = [self.in_dim_edges] + self.hidden_dims_edges + [self.hidden_dims_edges[-1]]
+ if len(self.hidden_dims_edges) or self.out_dim_edges > 0:
+ assert self.out_dim_edges > 0, self.out_dim_edges
+ self.full_dims_edges = [self.in_dim_edges] + self.hidden_dims_edges + [self.out_dim_edges]
self.virtual_node = virtual_node.lower() if virtual_node is not None else "none"
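
A small sketch of the `out_dim_edges` defaulting rule above, restated as a standalone helper with made-up dimensions:

```python
# Illustrative re-statement of the defaulting rule above (all values are made up).
from typing import List, Optional

def default_out_dim_edges(out_dim_edges: Optional[int], hidden_dims_edges: List[int]) -> int:
    return (
        out_dim_edges
        if out_dim_edges is not None
        else hidden_dims_edges[-1] if hidden_dims_edges else 0
    )

print(default_out_dim_edges(None, [32, 32]))  # -> 32  (falls back to the last hidden dim)
print(default_out_dim_edges(None, []))        # -> 0   (no edge features)
print(default_out_dim_edges(16, [32, 32]))    # -> 16  (explicit value wins)
```
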
@@ -593,6 +608,26 @@ def _check_bad_arguments(self):
) and not self.layer_class.layer_supports_edges:
raise ValueError(f"Cannot use edge features with class `{self.layer_class}`")
+ def get_nested_key(self, d, target_key):
+ """
+ Get the value associated with a key in a nested dictionary.
+
+ Parameters:
+ - d: The dictionary to search in
+ - target_key: The key to search for
+
+ Returns:
+ - The value associated with the key if found, None otherwise
+ """
+ if target_key in d:
+ return d[target_key]
+ for key, value in d.items():
+ if isinstance(value, (dict, DictConfig)):
+ nested_result = self.get_nested_key(value, target_key)
+ if nested_result is not None:
+ return nested_result
+ return None
+
def _create_layers(self):
r"""
Create all the necessary layers for the network.
@@ -639,7 +674,8 @@ def _create_layers(self):
this_out_dim_edges = self.full_dims_edges[ii + 1]
this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
else:
- this_out_dim_edges = self.layer_kwargs.get("out_dim_edges")
+ this_out_dim_edges = self.get_nested_key(self.layer_kwargs, "out_dim_edges")
+ this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
layer_out_dims_edges.append(this_out_dim_edges)
# Create the GNN layer
@@ -900,6 +936,7 @@ def get_init_kwargs(self) -> Dict[str, Any]:
new_kwargs = dict(
in_dim_edges=self.in_dim_edges,
hidden_dims_edges=self.hidden_dims_edges,
+ out_dim_edges=self.out_dim_edges,
virtual_node=self.virtual_node,
use_virtual_edges=self.use_virtual_edges,
)
@@ -931,6 +968,7 @@ def make_mup_base_kwargs(
kwargs["in_dim_edges"] = round(kwargs["in_dim_edges"] / divide_factor)
if not self.last_layer_is_readout:
kwargs["out_dim"] = round(kwargs["out_dim"] / divide_factor)
+ kwargs["out_dim_edges"] = round(kwargs["out_dim_edges"] / divide_factor)
def _recursive_divide_dim(x: collections.abc.Mapping):
for k, v in x.items():
diff --git a/graphium/nn/encoders/laplace_pos_encoder.py b/graphium/nn/encoders/laplace_pos_encoder.py
index ccf642e9d..7cc69919b 100644
--- a/graphium/nn/encoders/laplace_pos_encoder.py
+++ b/graphium/nn/encoders/laplace_pos_encoder.py
@@ -3,7 +3,7 @@
import torch.nn as nn
from torch_geometric.data import Batch
-from graphium.nn.base_layers import MLP, get_norm, FCLayer
+from graphium.nn.base_layers import MLP, get_norm, FCLayer, TransformerEncoderLayerMup
from graphium.nn.encoders.base_encoder import BaseEncoder
@@ -70,7 +70,8 @@ def __init__(
if self.model_type == "Transformer":
# Transformer model for LapPE
model_kwargs.setdefault("nhead", 1)
- encoder_layer = nn.TransformerEncoderLayer(
+ encoder_layer = TransformerEncoderLayerMup(
+ None,
d_model=hidden_dim,
batch_first=True,
dropout=dropout,
diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py
index f3da56979..7af7107ac 100644
--- a/graphium/nn/pyg_layers/gps_pyg.py
+++ b/graphium/nn/pyg_layers/gps_pyg.py
@@ -47,9 +47,10 @@ def __init__(
activation: Union[Callable, str] = "relu",
dropout: float = 0.0,
node_residual: Optional[bool] = True,
+ edge_residual: Optional[bool] = True,
normalization: Union[str, Callable] = "none",
mpnn_type: str = "pyg:gine",
- mpnn_kwargs=None,
+ mpnn_kwargs: Optional[dict] = None,
attn_type: str = "full-attention",
precision: str = "32",
biased_attention_key: Optional[str] = None,
@@ -57,6 +58,7 @@ def __init__(
droppath_rate_attn: float = 0.0,
droppath_rate_ffn: float = 0.0,
hidden_dim_scaling: float = 4.0,
+ output_scale: float = 1.0,
**kwargs,
):
r"""
@@ -99,6 +101,9 @@ def __init__(
node_residual:
                Whether to apply a residual connection to the node features after the GNN layer output
+            edge_residual:
+                Whether to apply a residual connection to the edge features after the GNN layer output
+
normalization:
Normalization to use. Choices:
@@ -141,6 +146,11 @@ def __init__(
attn_kwargs:
Keyword arguments to pass to the attention layer
+            output_scale:
+                Float value used to scale the activations, which helps reduce the growth of activations
+                as the model gets deeper. The default value of 1.0 leaves the layer unchanged.
+
"""
super().__init__(
@@ -165,6 +175,7 @@ def __init__(
# Residual connections
self.node_residual = node_residual
+ self.edge_residual = edge_residual
self.precision = precision
@@ -190,6 +201,37 @@ def __init__(
self.mpnn = self._parse_mpnn_layer(mpnn_type, mpnn_kwargs)
self.attn_layer = self._parse_attn_layer(attn_type, self.biased_attention_key, attn_kwargs)
+ self.output_scale = output_scale
+ self.use_edges = True if self.in_dim_edges is not None else False
+
+ def residual_add(self, feature: Tensor, input_feature: Tensor) -> Tensor:
+ r"""
+        Residual addition layer. Allows information to propagate through the model
+ by skipping the computational layers.
+ Parameters:
+ feature: The feature (typically nodes or edges) after message passing
+ input_feature: The same feature from before message passing
+ Returns:
+ The addition of the two tensors.
+ """
+ feature += input_feature
+ return feature
+
+    def scale_activations(self, feature: Tensor, scale_factor: float) -> Tensor:
+        """Scale the activations by a constant factor to curb the growth of the activation scale
+        and reduce numerical instability at low precision.
+
+ Args:
+ feature (Tensor): The feature to scale
+ scale_factor (float): The floating point scale factor
+
+ Returns:
+ Tensor: The scaled features
+ """
+ scale_factor = torch.tensor(scale_factor).to(feature.device)
+ feature *= scale_factor.to(dtype=feature.dtype)
+ return feature
+
def forward(self, batch: Batch) -> Batch:
r"""
forward function of the layer
@@ -200,6 +242,8 @@ def forward(self, batch: Batch) -> Batch:
"""
# pe, feat, edge_index, edge_feat = batch.pos_enc_feats_sign_flip, batch.feat, batch.edge_index, batch.edge_feat
feat = batch.feat
+ if self.use_edges:
+ edges_feat_in = batch.edge_feat
feat_in = feat # for first residual connection
@@ -208,10 +252,21 @@ def forward(self, batch: Batch) -> Batch:
if self.mpnn is not None:
batch_out = self.mpnn(batch_out)
h_local = batch_out.feat
+ e_local = batch_out.edge_feat
if self.dropout_local is not None:
h_local = self.dropout_local(h_local)
+ # Apply the residual connection for the node features
if self.node_residual:
- h_local = feat_in + h_local # Residual connection for nodes, not used in gps++.
+ h_local = self.residual_add(h_local, feat_in)
+ # Scale the activations by some value to help reduce activation growth
+ h_local = self.scale_activations(h_local, self.output_scale)
+ # Apply the residual connection for the edge features
+ if self.edge_residual and self.use_edges:
+ e_local = self.residual_add(e_local, edges_feat_in)
+ # Scale the activations by some value to help reduce activation growth
+ if self.use_edges:
+ e_local = self.scale_activations(e_local, self.output_scale)
+
if self.norm_layer_local is not None:
h_local = self.norm_layer_local(h_local)
@@ -240,7 +295,7 @@ def forward(self, batch: Batch) -> Batch:
def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) -> Optional[Module]:
"""Parse the MPNN layer."""
- if mpnn_type is None:
+ if mpnn_type is None or mpnn_type == "none":
return
mpnn_kwargs = deepcopy(mpnn_kwargs)
@@ -375,7 +430,7 @@ def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: Batch) ->
)
attn_bias = None
- if self.biased_attention_key is not None:
+ if self.biased_attention_key is not None and self.biased_attention_key != "none":
attn_bias = batch[self.biased_attention_key]
# h_dense[num_graphs, max_num_nodes, hidden_dim] -> feat_attn[num_graphs, max_num_nodes, hidden_dim]
@@ -463,6 +518,8 @@ def layer_outputs_edges(self) -> bool:
bool:
Always ``False`` for the current class
"""
+ if self.mpnn is None:
+ return False
return self.mpnn.layer_outputs_edges
@property
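
To see why a constant `output_scale` helps, here is a toy calculation. It assumes independent, equal-variance branches, so each residual add grows the RMS by roughly √2 per layer; the depth and the scale value below are made up:

```python
# Toy illustration only: residual-stream RMS growth with and without output_scale damping.
depth = 16
output_scale = 2 ** -0.5  # hypothetical value chosen to exactly cancel the sqrt(2) growth

rms_plain, rms_damped = 1.0, 1.0
for _ in range(depth):
    rms_plain *= 2 ** 0.5                  # residual add of equal-variance branches
    rms_damped *= 2 ** 0.5 * output_scale  # same add, then scaled as in scale_activations()
print(round(rms_plain, 1), round(rms_damped, 1))  # -> 256.0 1.0
```
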
diff --git a/graphium/nn/pyg_layers/mpnn_pyg.py b/graphium/nn/pyg_layers/mpnn_pyg.py
index f2cdcb16c..25df03714 100644
--- a/graphium/nn/pyg_layers/mpnn_pyg.py
+++ b/graphium/nn/pyg_layers/mpnn_pyg.py
@@ -130,14 +130,15 @@ def __init__(
self.num_edge_mlp = num_edge_mlp
self.edge_dropout_rate = edge_dropout_rate
- self.aggregator = MultiAggregation(aggregation_method)
+ self.aggregator = MultiAggregation(list(aggregation_method))
+ n_agg = len(aggregation_method)
# node_model:
edge_dim = self.out_dim_edges if use_edges else self.in_dim_edges
if self.node_combine_method == "concat":
- node_model_in_dim = 3 * self.in_dim + 2 * edge_dim
+ node_model_in_dim = (1 + 2 * n_agg) * self.in_dim + 2 * n_agg * edge_dim
elif self.node_combine_method == "sum":
- node_model_in_dim = 2 * self.in_dim + edge_dim
+ node_model_in_dim = (1 + n_agg) * self.in_dim + n_agg * edge_dim
else:
raise ValueError(f"node_combine_method {self.node_combine_method} not recognised.")
node_model_hidden_dim = self.mlp_expansion_ratio * self.in_dim
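
As a sanity check on the widened `node_model_in_dim` formula above (all dimensions and aggregators below are hypothetical):

```python
# Illustrative only: node-model input width with MultiAggregation over two aggregators.
in_dim, edge_dim = 64, 32    # made-up node / edge dimensions
n_agg = len(["sum", "max"])  # aggregation_method with two aggregators

concat_in = (1 + 2 * n_agg) * in_dim + 2 * n_agg * edge_dim  # node_combine_method="concat"
sum_in = (1 + n_agg) * in_dim + n_agg * edge_dim             # node_combine_method="sum"
print(concat_in, sum_in)  # -> 448 256
```
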
diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py
index c4e700895..588d7e3f2 100644
--- a/graphium/trainer/predictor.py
+++ b/graphium/trainer/predictor.py
@@ -46,6 +46,9 @@ def __init__(
flag_kwargs: Dict[str, Any] = None,
task_norms: Optional[Dict[Callable, Any]] = None,
metrics_every_n_train_steps: Optional[int] = None,
+ replicas: int = 1,
+ gradient_acc: int = 1,
+ global_bs: Optional[int] = 1,
):
"""
The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc.
@@ -175,6 +178,9 @@ def __init__(
self.metrics_every_n_train_steps = metrics_every_n_train_steps
        # Whether to save preds and targets for each training step.
+ self.samples_seen = 0
+ self.global_bs = global_bs
+
def forward(
self, inputs: Dict
) -> Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]:
@@ -221,6 +227,7 @@ def configure_optimizers(self, impl=None):
# Define the optimizer and schedulers
optimiser = MuAdam(self.parameters(), **self.optim_options.optim_kwargs, impl=impl)
+ self.optim_options.torch_scheduler_kwargs.pop("module_type")
torch_scheduler = self.optim_options.scheduler_class(
optimizer=optimiser, **self.optim_options.torch_scheduler_kwargs
)
@@ -461,6 +468,10 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
# Get the metrics that are logged at every step (loss, grad_norm, batch_time, batch_tput)
concatenated_metrics_logs = {}
concatenated_metrics_logs["train/loss"] = outputs["loss"]
+ concatenated_metrics_logs["epoch_count"] = self.current_epoch
+        # Increment by the global batch size
+ self.samples_seen += self.global_bs
+ concatenated_metrics_logs["samples_seen"] = self.samples_seen
# report the training loss for each individual tasks
for task in self.tasks:
@@ -618,11 +629,6 @@ def on_validation_epoch_end(self) -> None:
concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value)
concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value
-
- if hasattr(self.optimizers(), "param_groups"):
- lr = self.optimizers().param_groups[0]["lr"]
- concatenated_metrics_logs["lr"] = torch.tensor(lr)
- concatenated_metrics_logs["n_epochs"] = torch.tensor(self.current_epoch, dtype=torch.float32)
self.log_dict(concatenated_metrics_logs)
# Save yaml file with the per-task metrics summaries
diff --git a/graphium/trainer/predictor_options.py b/graphium/trainer/predictor_options.py
index 20e193bca..25fc6fad0 100644
--- a/graphium/trainer/predictor_options.py
+++ b/graphium/trainer/predictor_options.py
@@ -76,6 +76,7 @@ class OptimOptions:
# Instead of passing a dictionary to be processed by the predictor,
# this class will process the dictionary in advance and return the optimizer
def set_kwargs(self):
+ torch_scheduler_kwargs = deepcopy(self.torch_scheduler_kwargs)
# Set the parameters and default value for the optimizer, and check values
if self.optim_kwargs is None:
self.optim_kwargs = {}
@@ -94,12 +95,12 @@ def set_kwargs(self):
self.scheduler_kwargs.setdefault("strict", True)
# Set the pytorch scheduler arguments
- if self.torch_scheduler_kwargs is None:
- self.torch_scheduler_kwargs = {}
- self.torch_scheduler_kwargs.setdefault("module_type", "ReduceLROnPlateau")
+ if torch_scheduler_kwargs is None:
+ torch_scheduler_kwargs = {}
+ torch_scheduler_kwargs.setdefault("module_type", "ReduceLROnPlateau")
# Get the class for the scheduler
- scheduler_class = self.torch_scheduler_kwargs.pop("module_type")
+ scheduler_class = torch_scheduler_kwargs.pop("module_type")
if self.scheduler_class is None:
if isinstance(scheduler_class, str):
self.scheduler_class = SCHEDULER_DICT[scheduler_class]
@@ -112,9 +113,9 @@ def set_kwargs(self):
sig = signature(self.scheduler_class.__init__)
key_args = [p.name for p in sig.parameters.values()]
if "monitor" in key_args:
- self.torch_scheduler_kwargs.setdefault("monitor", self.scheduler_kwargs["monitor"])
+ torch_scheduler_kwargs.setdefault("monitor", self.scheduler_kwargs["monitor"])
if "mode" in key_args:
- self.torch_scheduler_kwargs.setdefault("mode", self.scheduler_kwargs["mode"])
+ torch_scheduler_kwargs.setdefault("mode", self.scheduler_kwargs["mode"])
@dataclass
diff --git a/graphium/trainer/predictor_summaries.py b/graphium/trainer/predictor_summaries.py
index d62e50a42..8ce863e74 100644
--- a/graphium/trainer/predictor_summaries.py
+++ b/graphium/trainer/predictor_summaries.py
@@ -248,8 +248,6 @@ def get_metrics_logs(self) -> Dict[str, Any]:
metric_logs[self.metric_log_name(self.task_name, "median_target", self.step_name)] = nan_median(
targets
)
- if torch.cuda.is_available():
- metric_logs[f"gpu_allocated_GB"] = torch.tensor(torch.cuda.memory_allocated() / (2**30))
# Specify which metrics to use
metrics_to_use = self.metrics
diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py
index 6658d0ca0..d821223a4 100644
--- a/graphium/utils/spaces.py
+++ b/graphium/utils/spaces.py
@@ -52,6 +52,8 @@
}
LOSS_DICT = {
+    "bce_logits": torch.nn.BCEWithLogitsLoss,
"mse": torch.nn.MSELoss,
"bce": torch.nn.BCELoss,
"l1": torch.nn.L1Loss,
@@ -106,7 +108,7 @@
"msle": TorchMetrics.mean_squared_log_error,
"pearsonr": TorchMetrics.pearson_corrcoef,
"spearmanr": TorchMetrics.spearman_corrcoef,
- "r2": TorchMetrics.r2_score,
+ "r2_score": TorchMetrics.r2_score,
"cosine": TorchMetrics.cosine_similarity,
"pearsonr_ipu": Metrics.pearson_ipu,
"spearmanr_ipu": Metrics.spearman_ipu,
diff --git a/scripts/scale_mpnn.sh b/scripts/scale_mpnn.sh
new file mode 100644
index 000000000..8cd61fb86
--- /dev/null
+++ b/scripts/scale_mpnn.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+graphium-train \
+ --config-path=/home/frederik_valencediscovery_com/projects/graphium_hps/expts/configs/ \
+ --config-name=config_mpnn_base.yaml \
+ constants.max_epochs=100 \
+ trainer.model_checkpoint.dirpath=model_checkpoints/large-dataset/scale_mpnn/ \
+ +architecture.mup_scale_factor=2 +architecture.mup_base_path=mup/mpnn_base/base_shapes.yaml \
+    datamodule.args.batch_size_inference=1024 datamodule.args.batch_size_training=1024 +trainer.trainer.accumulate_grad_batches=2
\ No newline at end of file