Scaling molecular GNNs to infinity

A deep learning library focused on graph representation learning for real-world chemical tasks.

  • ✅ State-of-the-art GNN architectures.
  • 🐍 Extensible API: build your own GNN model and train it with ease.
  • ⚗️ Rich featurization: powerful and flexible built-in molecular featurization.
  • 🧠 Pretrained models: for fast and easy inference or transfer learning.
  • ⮔ Ready-to-use training loop based on PyTorch Lightning.
  • 🔌 Have a new dataset? Graphium provides a simple plug-and-play interface. Change the path, the names of the columns to predict, and the atomic featurization, and you’re ready to play!



Installation for users


The recommended way to install Graphium is from conda-forge. To install Graphium via conda-forge, run the following command:

mamba install graphium -c conda-forge

Note: we recommend using mamba instead of conda, as it is a faster drop-in replacement.


To install Graphium via PyPI, run the following command:

pip install graphium

Note: the latest version of Graphium available on PyPI is 2.4.7. This is because version 3.0.0 added C++ code that depends on packages only available via conda-forge. We plan to eventually support Graphium >=3.0.0 on PyPI.

Installation for developers

If you are using a GPU, we recommend enforcing the CUDA version that you need with CONDA_OVERRIDE_CUDA=XX.X.

# Install Graphium's dependencies in a new environment named `graphium`
mamba env create -f env.yml -n graphium

# To force the CUDA version to 11.2, or any other version you prefer, use the following command:
# CONDA_OVERRIDE_CUDA=11.2 mamba env create -f env.yml -n graphium

# Activate the mamba environment containing Graphium's dependencies
mamba activate graphium

# Install Graphium in dev mode
pip install --no-deps --no-build-isolation -e .

Training a model

To learn how to train a model, we invite you to look at the documentation or the Jupyter notebooks available here.

If you are not familiar with PyTorch or PyTorch Lightning, we highly recommend going through their tutorials first.

Running an experiment


Graphium provides configs for two dataset mixes: ToyMix and LargeMix. ToyMix combines three datasets (QM9, Tox21 and ZINC), which are referenced in the datamodule config here. These datasets and their split files can be downloaded from here:

# Create (if needed) and move to the directory where the datasets will be downloaded
mkdir -p expts/data/neurips2023/small-dataset
cd expts/data/neurips2023/small-dataset

# QM9 

# Tox21

# Zinc

LargeMix combines four datasets (L1000_VCAP, L1000_MCF7, PCBA_1328 and PCQM4M_G25), which are referenced in the datamodule config here. These datasets and their split files can be downloaded from here:

# Create (if needed) and move to the directory where the datasets will be downloaded
mkdir -p ../data/graphium/large-dataset/
cd ../data/graphium/large-dataset/

# L1000_VCAP

# L1000_MCF7

# PCBA_1328

# PCQM4M_G25


These datasets can also be used for pretraining.


Graphium uses Hydra to manage config files. To run an experiment, go to the expts/ folder. For example, to benchmark a GCN on the ToyMix dataset, run:

graphium-train architecture=toymix tasks=toymix training=toymix model=gcn

To change parameters specific to this experiment, such as switching from fp16 to fp32 precision, you can either override them directly on the command line:

graphium-train architecture=toymix tasks=toymix training=toymix model=gcn trainer.trainer.precision=32

or change them permanently in the dedicated experiment config, expts/hydra-configs/toymix_gcn.yaml. Hydra also makes it quick to switch between accelerators. For example, running

graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu

automatically selects the correct configs to run the experiment on GPU. To use the LargeMix dataset instead, replace toymix with largemix in the commands above.
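Because the mix name has to be changed consistently in the architecture, tasks and training groups, a small shell helper can assemble the command for you. The sketch below is a hypothetical convenience function (train_cmd is not part of Graphium); it only prints the command, using the overrides shown above, so you can check it before running anything.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not shipped with Graphium): print the graphium-train
# command for a given mix, swapping the mix name in all three config groups.
train_cmd() {
  local mix="${1:-toymix}"   # toymix or largemix
  echo "graphium-train architecture=${mix} tasks=${mix} training=${mix} model=gcn accelerator=gpu"
}

# Dry run: print the command for LargeMix without executing it.
train_cmd largemix
```

Once the printed command looks right, it can be executed with eval "$(train_cmd largemix)".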

To use a config file you built from scratch, run:

graphium-train --config-path [PATH] --config-name [CONFIG]

Thanks to Hydra's modular design, you can reuse many of our config settings in your own Graphium experiments.


After pretraining a model and saving a checkpoint, you can finetune the model on a new task:

graphium-train +finetuning [example-custom OR example-tdc] finetuning.pretrained_model=[model_identifier]

The [model_identifier] selects one of the pretrained models listed in GRAPHIUM_PRETRAINED_MODELS_DICT in graphium/utils/, which maps each [model_identifier] to the location of the corresponding pretrained model checkpoint.

We provide two example YAML configs under expts/hydra-configs/finetuning: one for finetuning on a custom dataset (example-custom.yaml) and one for finetuning on a task from the TDC benchmark collection (example-tdc.yaml).

When using example-custom.yaml to finetune on a custom dataset, we need to provide the location of the data (constants.data_path=[path_to_data]) and the type of task (constants.task_type=[cls OR reg]).

When using example-tdc.yaml, to finetune on a TDC task, we only need to provide the task name (constants.task=[task_name]) and the task type is inferred automatically.

A custom dataset for finetuning consists of two files, raw.csv and split.csv. raw.csv contains two columns: smiles, with the SMILES strings, and target, with the corresponding targets. split.csv has three columns, train, val and test, which contain the indices of the raw.csv rows assigned to each split. Examples can be found under expts/data/finetuning_example-reg (regression) and expts/data/finetuning_example-cls (binary classification).
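The layout described above can be illustrated with a toy regression dataset. This is a minimal sketch under several assumptions not stated in this README: comma-separated files with a header row, 0-based row indices that exclude the header, and a hypothetical directory name. The targets are arbitrary placeholders, not real measurements, and the splits are equal-sized only to keep split.csv rectangular; compare with the shipped examples before relying on these conventions.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical output directory for the toy dataset.
data_dir=my_finetuning_data
mkdir -p "$data_dir"

# raw.csv: one SMILES string and one target per row (placeholder values).
cat > "$data_dir/raw.csv" <<'EOF'
smiles,target
CCO,0.12
CC(=O)O,0.85
c1ccccc1,1.73
CCN,0.40
CC(C)O,0.29
OCCO,0.95
EOF

# split.csv: each column lists the raw.csv row indices of one split
# (assumed 0-based, counting data rows only).
cat > "$data_dir/split.csv" <<'EOF'
train,val,test
0,2,4
1,3,5
EOF

# Sanity check: every non-empty index must point at an existing data row.
n_rows=$(( $(wc -l < "$data_dir/raw.csv") - 1 ))
awk -F, -v n="$n_rows" '
  NR > 1 {
    for (i = 1; i <= NF; i++)
      if ($i != "" && ($i < 0 || $i >= n)) { print "bad index: " $i; bad = 1 }
  }
  END { exit bad }
' "$data_dir/split.csv"
echo "raw.csv has $n_rows molecules; split.csv indices are valid"
```

With files like these in place, point constants.data_path at the data and set constants.task_type=reg, since the targets here are continuous.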


Alternatively, we can obtain molecular embeddings (fingerprints) from a pretrained model:

graphium fps create [example-custom OR example-tdc] pretrained.model=[model_identifier] pretrained.layers=[layer_identifiers]

We provide two example YAML configs under expts/hydra-configs/fingerprinting: one for extracting fingerprints for a custom dataset (example-custom.yaml) and one for a dataset from the TDC benchmark collection (example-tdc.yaml).

After specifying the [model_identifier], we need to provide, via [layer_identifiers], the list of layers of that model from which to read out embeddings (this requires knowing the architecture of the pretrained model).

When using example-custom.yaml, the location of the SMILES to be embedded must be passed via datamodule.df_path=[path_to_data]. The data can be a CSV or Parquet file with a smiles column, similar to expts/data/finetuning_example-reg/raw.csv.
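A minimal input for fingerprint extraction with example-custom.yaml is a CSV with a smiles column. The sketch below writes such a file and checks its header before a potentially long run; the file name molecules_to_embed.csv is hypothetical, and the header check assumes a file whose only column is smiles.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical input file: a single "smiles" column, one molecule per row.
cat > molecules_to_embed.csv <<'EOF'
smiles
CCO
c1ccccc1O
CC(=O)Nc1ccc(O)cc1
EOF

# This toy file has only one column, so its header must be exactly "smiles".
header=$(head -n 1 molecules_to_embed.csv)
if [ "$header" != "smiles" ]; then
  echo "expected a 'smiles' column, found: $header" >&2
  exit 1
fi
echo "$(( $(wc -l < molecules_to_embed.csv) - 1 )) molecules ready"
```

Following the command shown above, the file would then be passed as datamodule.df_path=molecules_to_embed.csv, together with pretrained.model=[model_identifier] and pretrained.layers=[layer_identifiers].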

When extracting fingerprints for a TDC task using example-tdc.yaml, we need to specify datamodule.benchmark and datamodule.task instead of datamodule.df_path.


Under the Apache-2.0 license. See LICENSE.


  • Diagram for data processing in Graphium.

Data Processing Chart

  • Diagram for multi-task network in Graphium.

Full Graph Multi-task Network