diff --git a/README.md b/README.md index 290e84b1b..39764b620 100644 --- a/README.md +++ b/README.md @@ -85,21 +85,26 @@ If you are not familiar with [PyTorch](https://pytorch.org/docs) or [PyTorch-Lig ## Running an experiment We have setup Graphium with `hydra` for managing config files. To run an experiment go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset run ```bash -python main_run_multitask.py dataset=toymix model=gcn +graphium-train dataset=toymix model=gcn ``` To change parameters specific to this experiment like switching from `fp16` to `fp32` precision, you can either override them directly in the CLI via ```bash -python main_run_multitask.py dataset=toymix model=gcn trainer.trainer.precision=32 +graphium-train dataset=toymix model=gcn trainer.trainer.precision=32 ``` or change them permamently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`. Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running ```bash -python main_run_multitask.py dataset=toymix model=gcn accelerator=gpu +graphium-train dataset=toymix model=gcn accelerator=gpu ``` automatically selects the correct configs to run the experiment on GPU. +Finally, you can also run a fine-tuning loop: +```bash +graphium-train +finetuning=admet +``` + To use a config file you built from scratch you can run ```bash -python main_run_multitask.py --config-path [PATH] --config-name [CONFIG] +graphium-train --config-path [PATH] --config-name [CONFIG] ``` Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium. diff --git a/docs/cli_references.md b/docs/cli_references.md index 52d72720f..b65bb2fba 100644 --- a/docs/cli_references.md +++ b/docs/cli_references.md @@ -5,4 +5,5 @@ This page provides documentation for our command line tools. ::: mkdocs-click :module: graphium.cli :command: main_cli - :command: data_cli + :style: table + :prog_name: graphium diff --git a/docs/tutorials/model_training/running-multitask-ipu.ipynb b/docs/tutorials/model_training/running-multitask-ipu.ipynb index 8da432e4d..05a972e9b 100644 --- a/docs/tutorials/model_training/running-multitask-ipu.ipynb +++ b/docs/tutorials/model_training/running-multitask-ipu.ipynb @@ -420,7 +420,14 @@ "logger.info(metrics)\n", "\n", "predictor = load_predictor(\n", - " cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms\n", + " cfg,\n", + " model_class,\n", + " model_kwargs,\n", + " metrics,\n", + " datamodule.get_task_levels(),\n", + " accelerator_type,\n", + " datamodule.featurization,\n", + " datamodule.task_norms\n", ")\n", "logger.info(predictor.model)\n", "logger.info(ModelSummary(predictor, max_depth=4))" diff --git a/docs/tutorials/model_training/simple-molecular-model.ipynb b/docs/tutorials/model_training/simple-molecular-model.ipynb index 717e1413e..26a45cfa0 100644 --- a/docs/tutorials/model_training/simple-molecular-model.ipynb +++ b/docs/tutorials/model_training/simple-molecular-model.ipynb @@ -10,30 +10,30 @@ "\n", "The work flow of testing your code on the entire pipeline is as follows:\n", "\n", - "1. select a corresponding yaml file in the [expts/main_run_multitask.py](https://github.com/datamol-io/graphium/blob/master/expts/main_run_multitask.py) i.e. by `CONFIG_FILE = \"expts/configs/config_gps_10M_pcqm4m_mod.yaml\"`\n", - "2. modify the yaml config file\n", - "3. 
`python expts/main_run_multitask.py`\n", - "\n", - "There are multiple examples of YAML files located in the folder `graphium/expts/configs` that one can refer to when training a new model. The file `config_gps_10M_pcqm4m_mod.yaml` shows an example of running the GPS model on the pcqm4m dataset.\n", + "1. Select a subset of the [available configs](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs) as a starting point.\n", + "2. Create additional configs or modify the existing configs to suit your needs.\n", + "3. Train or fine-tune a model with the `graphium-train` CLI.\n", "\n", "## Creating the yaml file\n", "\n", - "The first step is to create a YAML file containing all the required configurations, with an example given at `graphium/expts/config_gps_10M_pcqm4m_mod.yaml`. We will go through each part of the configurations." + "The first step is to create a YAML file containing all the required configurations, with an example given at `graphium/expts/hydra-configs/main.yaml`. We will go through each part of the configurations. See also the README [here](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs)." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "import yaml\n", - "import omegaconf" + "import omegaconf\n", + "\n", + "from hydra import compose, initialize" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -57,8 +57,8 @@ ], "source": [ "# First, let's read the yaml configuration file\n", - "with open(\"../../../expts/configs/config_gps_10M_pcqm4m_mod.yaml\", \"r\") as file:\n", - " yaml_config = yaml.load(file, Loader=yaml.FullLoader)\n", + "with initialize(version_base=None, config_path=\"../../../expts/hydra-configs\"):\n", + " yaml_config = compose(config_name=\"main\")\n", "\n", "print(\"Yaml file loaded\")" ] @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -82,11 +82,11 @@ "output_type": "stream", "text": [ "constants:\n", - " name: pcqm4mv2_mpnn_4layer\n", + " name: neurips2023_small_data_gcn\n", " seed: 42\n", + " max_epochs: 100\n", + " data_dir: expts/data/neurips2023/small-dataset\n", " raise_train_error: true\n", - " accelerator:\n", - " type: gpu\n", "\n" ] } @@ -108,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -118,22 +118,14 @@ "datamodule:\n", " module_type: MultitaskFromSmilesDataModule\n", " args:\n", - " task_specific_args:\n", - " homolumo:\n", - " df: null\n", - " task_level: graph\n", - " df_path: ~/scratch/data/graphium/data/PCQM4M/pcqm4mv2-20k.csv\n", - " smiles_col: cxsmiles\n", - " label_cols:\n", - " - homo_lumo_gap\n", - " split_val: 0.1\n", - " split_test: 0.1\n", " prepare_dict_or_graph: pyg:graph\n", - " featurization_n_jobs: 30\n", + " featurization_n_jobs: 4\n", " featurization_progress: true\n", " featurization_backend: loky\n", + " processed_graph_data_path: ../datacache/neurips2023-small/\n", + " num_workers: 4\n", + " persistent_workers: false\n", " featurization:\n", - " mask_nan: 0\n", " atom_property_list_onehot:\n", " - atomic-number\n", " - group\n", @@ -149,55 +141,94 @@ " - bond-type-onehot\n", " - stereo\n", " - in-ring\n", - " conformer_property_list:\n", - " - positions_3d\n", " 
add_self_loop: false\n", " explicit_H: false\n", " use_bonds_weights: false\n", " pos_encoding_as_features:\n", " pos_types:\n", - " node_laplacian_eigvec:\n", - " pos_type: laplacian_eigvec\n", + " lap_eigvec:\n", " pos_level: node\n", + " pos_type: laplacian_eigvec\n", " num_pos: 8\n", " normalization: none\n", " disconnected_comp: true\n", - " node_laplacian_eigval:\n", - " pos_type: laplacian_eigval\n", + " lap_eigval:\n", " pos_level: node\n", + " pos_type: laplacian_eigval\n", " num_pos: 8\n", " normalization: none\n", " disconnected_comp: true\n", - " rw_return_probs:\n", - " pos_type: rw_return_probs\n", + " rw_pos:\n", " pos_level: node\n", - " ksteps:\n", - " - 4\n", - " - 8\n", - " nodepair_rw_transition_probs:\n", - " pos_type: rw_transition_probs\n", - " pos_level: edge\n", - " ksteps:\n", - " - 2\n", - " - 4\n", - " nodepair_rw_return_probs:\n", " pos_type: rw_return_probs\n", - " pos_level: nodepair\n", - " ksteps:\n", - " - 4\n", - " electrostatic:\n", - " pos_type: electrostatic\n", - " pos_level: node\n", - " edge_commute:\n", - " pos_type: commute\n", - " pos_level: edge\n", - " nodepair_graphormer:\n", - " pos_type: graphormer\n", - " pos_level: nodepair\n", - " batch_size_training: 64\n", - " batch_size_inference: 16\n", - " num_workers: 0\n", - " persistent_workers: false\n", + " ksteps: 16\n", + " task_specific_args:\n", + " qm9:\n", + " df: null\n", + " df_path: ${constants.data_dir}/qm9.csv.gz\n", + " smiles_col: smiles\n", + " label_cols:\n", + " - A\n", + " - B\n", + " - C\n", + " - mu\n", + " - alpha\n", + " - homo\n", + " - lumo\n", + " - gap\n", + " - r2\n", + " - zpve\n", + " - u0\n", + " - u298\n", + " - h298\n", + " - g298\n", + " - cv\n", + " - u0_atom\n", + " - u298_atom\n", + " - h298_atom\n", + " - g298_atom\n", + " splits_path: ${constants.data_dir}/qm9_random_splits.pt\n", + " seed: ${constants.seed}\n", + " task_level: graph\n", + " label_normalization:\n", + " normalize_val_test: true\n", + " method: normal\n", + " tox21:\n", + " df: null\n", + " df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz\n", + " smiles_col: smiles\n", + " label_cols:\n", + " - NR-AR\n", + " - NR-AR-LBD\n", + " - NR-AhR\n", + " - NR-Aromatase\n", + " - NR-ER\n", + " - NR-ER-LBD\n", + " - NR-PPAR-gamma\n", + " - SR-ARE\n", + " - SR-ATAD5\n", + " - SR-HSE\n", + " - SR-MMP\n", + " - SR-p53\n", + " splits_path: ${constants.data_dir}/Tox21_random_splits.pt\n", + " seed: ${constants.seed}\n", + " task_level: graph\n", + " zinc:\n", + " df: null\n", + " df_path: ${constants.data_dir}/ZINC12k.csv.gz\n", + " smiles_col: smiles\n", + " label_cols:\n", + " - SA\n", + " - logp\n", + " - score\n", + " splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt\n", + " seed: ${constants.seed}\n", + " task_level: graph\n", + " label_normalization:\n", + " normalize_val_test: true\n", + " method: normal\n", + " batch_size_training: 200\n", + " batch_size_inference: 200\n", "\n" ] } @@ -226,7 +257,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -237,134 +268,106 @@ " model_type: FullGraphMultiTaskNetwork\n", " mup_base_path: null\n", " pre_nn:\n", - " out_dim: 32\n", - " hidden_dims: 64\n", + " out_dim: 64\n", + " hidden_dims: 256\n", " depth: 2\n", " activation: relu\n", " last_activation: none\n", - " dropout: 0.1\n", + " dropout: 0.18\n", " normalization: layer_norm\n", - " last_normalization: layer_norm\n", - " residual_type: none\n", - " pre_nn_edges:\n", - " out_dim: 16\n", - " hidden_dims: 32\n", - " depth: 2\n", 
- " activation: relu\n", - " last_activation: none\n", - " dropout: 0.1\n", - " normalization: layer_norm\n", - " last_normalization: layer_norm\n", + " last_normalization: ${architecture.pre_nn.normalization}\n", " residual_type: none\n", + " pre_nn_edges: null\n", " pe_encoders:\n", " out_dim: 32\n", - " edge_out_dim: 16\n", " pool: sum\n", " last_norm: None\n", " encoders:\n", - " emb_la_pos:\n", + " la_pos:\n", " encoder_type: laplacian_pe\n", " input_keys:\n", " - laplacian_eigvec\n", " - laplacian_eigval\n", " output_keys:\n", " - feat\n", - " hidden_dim: 32\n", + " hidden_dim: 64\n", + " out_dim: 32\n", " model_type: DeepSet\n", " num_layers: 2\n", " num_layers_post: 1\n", " dropout: 0.1\n", " first_normalization: none\n", - " emb_rwse:\n", + " rw_pos:\n", " encoder_type: mlp\n", " input_keys:\n", " - rw_return_probs\n", " output_keys:\n", " - feat\n", - " hidden_dim: 32\n", + " hidden_dim: 64\n", + " out_dim: 32\n", " num_layers: 2\n", " dropout: 0.1\n", " normalization: layer_norm\n", " first_normalization: layer_norm\n", - " emb_electrostatic:\n", - " encoder_type: mlp\n", - " input_keys:\n", - " - electrostatic\n", - " output_keys:\n", - " - feat\n", - " hidden_dim: 32\n", - " num_layers: 1\n", - " dropout: 0.1\n", - " normalization: layer_norm\n", - " first_normalization: layer_norm\n", - " emb_edge_rwse:\n", - " encoder_type: mlp\n", - " input_keys:\n", - " - edge_rw_transition_probs\n", - " output_keys:\n", - " - edge_feat\n", - " hidden_dim: 32\n", - " num_layers: 1\n", - " dropout: 0.1\n", - " normalization: layer_norm\n", - " emb_edge_pes:\n", - " encoder_type: cat_mlp\n", - " input_keys:\n", - " - edge_rw_transition_probs\n", - " - edge_commute\n", - " output_keys:\n", - " - edge_feat\n", - " hidden_dim: 32\n", - " num_layers: 1\n", - " dropout: 0.1\n", - " normalization: layer_norm\n", - " gaussian_pos:\n", - " encoder_type: gaussian_kernel\n", - " input_keys:\n", - " - positions_3d\n", - " output_keys:\n", - " - feat\n", - " - nodepair_gaussian_bias_3d\n", - " num_heads: 2\n", - " num_layers: 2\n", - " embed_dim: 32\n", - " use_input_keys_prefix: false\n", " gnn:\n", - " out_dim: 32\n", - " hidden_dims: 32\n", + " in_dim: 64\n", + " out_dim: 96\n", + " hidden_dims: 96\n", " depth: 4\n", " activation: gelu\n", " last_activation: none\n", - " dropout: 0.0\n", + " dropout: 0.1\n", " normalization: layer_norm\n", - " last_normalization: layer_norm\n", + " last_normalization: ${architecture.pre_nn.normalization}\n", " residual_type: simple\n", - " pooling:\n", - " - sum\n", " virtual_node: none\n", - " layer_type: pyg:gps\n", - " layer_kwargs:\n", - " node_residual: false\n", - " mpnn_type: pyg:mpnnplus\n", - " mpnn_kwargs:\n", - " in_dim: 32\n", - " out_dim: 32\n", - " in_dim_edges: 16\n", - " out_dim_edges: 16\n", - " attn_type: full-attention\n", - " attn_kwargs:\n", - " num_heads: 2\n", - " biased_attention_key: nodepair_gaussian_bias_3d\n", - " post_nn: null\n", + " layer_type: pyg:gcn\n", + " layer_kwargs: null\n", + " graph_output_nn:\n", + " graph:\n", + " pooling:\n", + " - sum\n", + " out_dim: 96\n", + " hidden_dims: 96\n", + " depth: 1\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: ${architecture.pre_nn.dropout}\n", + " normalization: ${architecture.pre_nn.normalization}\n", + " last_normalization: none\n", + " residual_type: none\n", " task_heads:\n", - " homolumo:\n", - " out_dim: 1\n", - " hidden_dims: 256\n", + " qm9:\n", + " task_level: graph\n", + " out_dim: 19\n", + " hidden_dims: 128\n", " depth: 2\n", " activation: relu\n", " 
last_activation: none\n", - " dropout: 0.1\n", - " normalization: layer_norm\n", + " dropout: ${architecture.pre_nn.dropout}\n", + " normalization: ${architecture.pre_nn.normalization}\n", + " last_normalization: none\n", + " residual_type: none\n", + " tox21:\n", + " task_level: graph\n", + " out_dim: 12\n", + " hidden_dims: 64\n", + " depth: 2\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: ${architecture.pre_nn.dropout}\n", + " normalization: ${architecture.pre_nn.normalization}\n", + " last_normalization: none\n", + " residual_type: none\n", + " zinc:\n", + " task_level: graph\n", + " out_dim: 3\n", + " hidden_dims: 32\n", + " depth: 2\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: ${architecture.pre_nn.dropout}\n", + " normalization: ${architecture.pre_nn.normalization}\n", " last_normalization: none\n", " residual_type: none\n", "\n" @@ -386,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -395,24 +398,28 @@ "text": [ "predictor:\n", " metrics_on_progress_bar:\n", - " homolumo:\n", + " qm9:\n", + " - mae\n", + " tox21:\n", + " - auroc\n", + " zinc:\n", " - mae\n", - " - pearsonr\n", " loss_fun:\n", - " homolumo: mse_ipu\n", - " random_seed: 42\n", + " qm9: mae_ipu\n", + " tox21: bce_logits_ipu\n", + " zinc: mae_ipu\n", + " random_seed: ${constants.seed}\n", " optim_kwargs:\n", - " lr: 0.0004\n", + " lr: 4.0e-05\n", " torch_scheduler_kwargs:\n", " module_type: WarmUpLinearLR\n", - " max_num_epochs: 5\n", + " max_num_epochs: ${constants.max_epochs}\n", " warmup_epochs: 10\n", " verbose: false\n", " scheduler_kwargs: null\n", " target_nan_mask: null\n", - " flag_kwargs:\n", - " n_steps: 0\n", - " alpha: 0.0\n", + " multitask_handling: flatten\n", + " metrics_every_n_train_steps: 300\n", "\n" ] } @@ -434,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -442,7 +449,54 @@ "output_type": "stream", "text": [ "metrics:\n", - " homolumo:\n", + " qm9:\n", + " - name: mae\n", + " metric: mae_ipu\n", + " target_nan_mask: null\n", + " multitask_handling: flatten\n", + " threshold_kwargs: null\n", + " - name: pearsonr\n", + " metric: pearsonr_ipu\n", + " threshold_kwargs: null\n", + " target_nan_mask: null\n", + " multitask_handling: mean-per-label\n", + " - name: r2_score\n", + " metric: r2_score_ipu\n", + " target_nan_mask: null\n", + " multitask_handling: mean-per-label\n", + " threshold_kwargs: null\n", + " tox21:\n", + " - name: auroc\n", + " metric: auroc_ipu\n", + " task: binary\n", + " multitask_handling: mean-per-label\n", + " threshold_kwargs: null\n", + " - name: avpr\n", + " metric: average_precision_ipu\n", + " task: binary\n", + " multitask_handling: mean-per-label\n", + " threshold_kwargs: null\n", + " - name: f1 > 0.5\n", + " metric: f1\n", + " multitask_handling: mean-per-label\n", + " target_to_int: true\n", + " num_classes: 2\n", + " average: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 0.5\n", + " th_on_preds: true\n", + " th_on_target: true\n", + " - name: precision > 0.5\n", + " metric: precision\n", + " multitask_handling: mean-per-label\n", + " average: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 0.5\n", + " th_on_preds: true\n", + " th_on_target: true\n", + " zinc:\n", " - name: mae\n", " metric: mae_ipu\n", " target_nan_mask: null\n", @@ -453,6 +507,11 @@ " threshold_kwargs: null\n", " target_nan_mask: null\n", " 
multitask_handling: mean-per-label\n", + " - name: r2_score\n", + " metric: r2_score_ipu\n", + " target_nan_mask: null\n", + " multitask_handling: mean-per-label\n", + " threshold_kwargs: null\n", "\n" ] } @@ -472,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -480,21 +539,17 @@ "output_type": "stream", "text": [ "trainer:\n", - " logger:\n", - " save_dir: logs/PCQMv2\n", - " name: pcqm4mv2_mpnn_4layer\n", - " project: PCQMv2_mpnn\n", + " seed: ${constants.seed}\n", " model_checkpoint:\n", - " dirpath: models_checkpoints/PCMQv2/\n", - " filename: pcqm4mv2_mpnn_4layer\n", - " save_top_k: 1\n", - " every_n_epochs: 100\n", + " filename: ${constants.name}\n", + " save_last: true\n", + " dirpath: models_checkpoints/neurips2023-small-gcn/\n", " trainer:\n", " precision: 32\n", - " max_epochs: 5\n", + " max_epochs: ${constants.max_epochs}\n", " min_epochs: 1\n", - " accumulate_grad_batches: 2\n", " check_val_every_n_epoch: 20\n", + " accumulate_grad_batches: 1\n", "\n" ] } @@ -511,16 +566,13 @@ "\n", "Now that we defined all the configuration files, we want to train the model. The steps are fairly easy using the config loaders, and are given below.\n", "\n", - "First make sure the dataset file is downloaded. \n", - "Using `config_gps_10M_pcqm4m.yaml` as an example, if the file at `df_path` in the config is downloaded.\n", - "In this case, we need to download `pcqm4mv2-20k.csv` into the specified directory `graphium/data/PCQM4M/pcqm4mv2-20k.csv`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "$`python expts/main_run_multitask.py`" + "First make sure the dataset file is downloaded. Using `config_gps_10M_pcqm4m.yaml` as an example, make sure the file specified by `df_path` in the config is available.\n", + "In this case, we need to download `pcqm4mv2-20k.csv` into the specified directory `graphium/data/PCQM4M/pcqm4mv2-20k.csv`.\n", + "\n", + "After that, we can simply run a training through the CLI:\n", + "```bash\n", + "graphium-train\n", + "```" ] } ], @@ -543,7 +595,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/env.yml b/env.yml index 7fc668692..e49d071a4 100644 --- a/env.yml +++ b/env.yml @@ -28,7 +28,7 @@ dependencies: - gcsfs >=2021.6 # ML packages - - cudatoolkit # works also with CPU-only system. + - cuda-version # works also with CPU-only system. - pytorch >=1.12 - lightning >=2.0 - torchmetrics >=0.7.0,<0.11 diff --git a/expts/__init__.py b/expts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/expts/configs/config_tdc_admet_demo.yaml b/expts/configs/config_tdc_admet_demo.yaml deleted file mode 100644 index aac4d0e50..000000000 --- a/expts/configs/config_tdc_admet_demo.yaml +++ /dev/null @@ -1,315 +0,0 @@ -# Testing the gcn model with the PCQMv2 dataset on IPU. 
-constants: - name: &name tdc_admet_demo - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -accelerator: - type: gpu # cpu or ipu or gpu - -datamodule: - module_type: "ADMETBenchmarkDataModule" - args: - # TDC specific - tdc_benchmark_names: null - tdc_train_val_seed: *seed - # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - processed_graph_data_path: "../datacache/tdc-admet-demo/" - featurization: - atom_property_list_onehot: [atomic-number, group, period, total-valence] - atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False # if H is included - use_bonds_weights: False - pos_encoding_as_features: - pos_types: - lap_eigvec: - pos_level: node - pos_type: laplacian_eigvec - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - lap_eigval: - pos_level: node - pos_type: laplacian_eigval - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - rw_pos: # use same name as pe_encoder - pos_level: node - pos_type: rw_return_probs - ksteps: 16 - - num_workers: -1 # -1 to use all - persistent_workers: False # if use persistent worker at the start of each epoch. - - -architecture: - model_type: FullGraphMultiTaskNetwork - mup_base_path: null - pre_nn: # Set as null to avoid a pre-nn network - out_dim: 64 - hidden_dims: 256 - depth: 2 - activation: relu - last_activation: none - dropout: &dropout 0.18 - normalization: &normalization layer_norm - last_normalization: *normalization - residual_type: none - - pre_nn_edges: null # Set as null to avoid a pre-nn network - - pe_encoders: - out_dim: 32 - pool: "sum" #"mean" "max" - last_norm: None #"batch_norm", "layer_norm" - encoders: #la_pos | rw_pos - la_pos: # Set as null to avoid a pre-nn network - encoder_type: "laplacian_pe" - input_keys: ["laplacian_eigvec", "laplacian_eigval"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - model_type: 'DeepSet' #'Transformer' or 'DeepSet' - num_layers: 2 - num_layers_post: 1 # Num. layers to apply after pooling - dropout: 0.1 - first_normalization: "none" #"batch_norm" or "layer_norm" - rw_pos: - encoder_type: "mlp" - input_keys: ["rw_return_probs"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - num_layers: 2 - dropout: 0.1 - normalization: "layer_norm" #"batch_norm" or "layer_norm" - first_normalization: "layer_norm" #"batch_norm" or "layer_norm" - - - - gnn: # Set as null to avoid a post-nn network - in_dim: 64 # or otherwise the correct value - out_dim: &gnn_dim 96 - hidden_dims: *gnn_dim - depth: 4 - activation: gelu - last_activation: none - dropout: 0.1 - normalization: "layer_norm" - last_normalization: *normalization - residual_type: simple - virtual_node: 'none' - layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps - layer_kwargs: null # Parameters for the model itself. 
You could define dropout_attn: 0.1 - - - graph_output_nn: - graph: - pooling: [sum] - out_dim: *gnn_dim - hidden_dims: *gnn_dim - depth: 1 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - - task_heads: - caco2_wang: ®ression_head - task_level: graph - out_dim: 1 - hidden_dims: 64 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - hia_hou: &classification_head - task_level: graph - out_dim: 1 - hidden_dims: 64 - depth: 2 - activation: relu - last_activation: sigmoid - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - pgp_broccatelli: *classification_head - bioavailability_ma: *classification_head - lipophilicity_astrazeneca: *regression_head - solubility_aqsoldb: *regression_head - bbb_martins: *classification_head - ppbr_az: *regression_head - vdss_lombardo: *regression_head - cyp2d6_veith: *classification_head - cyp3a4_veith: *classification_head - cyp2c9_veith: *classification_head - cyp2d6_substrate_carbonmangels: *classification_head - cyp3a4_substrate_carbonmangels: *classification_head - cyp2c9_substrate_carbonmangels: *classification_head - half_life_obach: *regression_head - clearance_microsome_az: *regression_head - clearance_hepatocyte_az: *regression_head - herg: *classification_head - ames: *classification_head - dili: *classification_head - ld50_zhu: *regression_head - -#Task-specific -predictor: - metrics_on_progress_bar: - # All below metrics are directly copied from the TDC website. - # For more information, see https://tdcommons.ai/benchmark/admet_group/overview/ - caco2_wang: ["mae"] - hia_hou: ["auroc"] - pgp_broccatelli: ["auroc"] - bioavailability_ma: ["auroc"] - lipophilicity_astrazeneca: ["mae"] - solubility_aqsoldb: ["mae"] - bbb_martins: ["auroc"] - ppbr_az: ["mae"] - vdss_lombardo: ["spearman"] - cyp2d6_veith: ["auprc"] - cyp3a4_veith: ["auprc"] - cyp2c9_veith: ["auprc"] - cyp2d6_substrate_carbonmangels: ["auprc"] - cyp3a4_substrate_carbonmangels: ["auprc"] - cyp2c9_substrate_carbonmangels: ["auprc"] - half_life_obach: ["spearman"] - clearance_microsome_az: ["spearman"] - clearance_hepatocyte_az: ["spearman"] - herg: ["mae"] - ames: ["auroc"] - dili: ["auroc"] - ld50_zhu: ["auroc"] - loss_fun: - caco2_wang: mae - hia_hou: bce - pgp_broccatelli: bce - bioavailability_ma: bce - lipophilicity_astrazeneca: mae - solubility_aqsoldb: mae - bbb_martins: bce - ppbr_az: mae - vdss_lombardo: mae - cyp2d6_veith: bce - cyp3a4_veith: bce - cyp2c9_veith: bce - cyp2d6_substrate_carbonmangels: bce - cyp3a4_substrate_carbonmangels: bce - cyp2c9_substrate_carbonmangels: bce - half_life_obach: mae - clearance_microsome_az: mae - clearance_hepatocyte_az: mae - herg: bce - ames: bce - dili: bce - ld50_zhu: mae - random_seed: *seed - optim_kwargs: - lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs - torch_scheduler_kwargs: - module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 10 - warmup_epochs: 10 - verbose: False - target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label - multitask_handling: flatten # flatten, mean-per-label - -# Task-specific -metrics: - caco2_wang: ®ression_metrics - - name: mae - metric: mae - target_nan_mask: null - multitask_handling: flatten - threshold_kwargs: null - - name: spearman - metric: spearmanr - threshold_kwargs: null - target_nan_mask: null 
- multitask_handling: mean-per-label - - name: pearson - metric: pearsonr - threshold_kwargs: null - target_nan_mask: null - multitask_handling: mean-per-label - - name: r2_score - metric: r2 - target_nan_mask: null - multitask_handling: mean-per-label - threshold_kwargs: null - hia_hou: &classification_metrics - - name: auroc - metric: auroc - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: auprc - metric: average_precision - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: accuracy - metric: accuracy - multitask_handling: mean-per-label - target_to_int: True - average: micro - threshold_kwargs: &threshold_05 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: True - - name: mcc - metric: mcc - num_classes: 2 - multitask_handling: mean-per-label - target_to_int: True - average: micro - threshold_kwargs: *threshold_05 - pgp_broccatelli: *classification_metrics - bioavailability_ma: *classification_metrics - lipophilicity_astrazeneca: *regression_metrics - solubility_aqsoldb: *regression_metrics - bbb_martins: *classification_metrics - ppbr_az: *regression_metrics - vdss_lombardo: *regression_metrics - cyp2d6_veith: *classification_metrics - cyp3a4_veith: *classification_metrics - cyp2c9_veith: *classification_metrics - cyp2d6_substrate_carbonmangels: *classification_metrics - cyp3a4_substrate_carbonmangels: *classification_metrics - cyp2c9_substrate_carbonmangels: *classification_metrics - half_life_obach: *regression_metrics - clearance_microsome_az: *regression_metrics - clearance_hepatocyte_az: *regression_metrics - herg: *classification_metrics - ames: *classification_metrics - dili: *classification_metrics - ld50_zhu: *regression_metrics -trainer: - seed: *seed - logger: - save_dir: logs/tdc-admet-demo/ - name: *name - project: *name - model_checkpoint: - dirpath: models_checkpoints/tdc-admet-demo/ - filename: *name - save_last: True - trainer: - max_epochs: *max_epochs - min_epochs: 1 - check_val_every_n_epoch: 20 diff --git a/expts/hydra-configs/README.md b/expts/hydra-configs/README.md index 5e189c304..f695ae20c 100644 --- a/expts/hydra-configs/README.md +++ b/expts/hydra-configs/README.md @@ -38,18 +38,32 @@ trainer: We can now utilize `hydra` to e.g., run a sweep over our models on the ToyMix dataset via ```bash -python main_run_multitask.py -m model=gcn,gin +graphium-train -m model=gcn,gin ``` where the ToyMix dataset is pre-configured in `main.yaml`. Read on to find out how to define new datasets and architectures for pre-training and fine-tuning. ## Pre-training / Fine-tuning -From a configuration point-of-view, fine-tuning requires us to load a pre-trained model and attach new task heads. However, in a highly configurable library such as ours changing the task heads also requires changes to the logged metrics, loss functions and the source of the fine-tuning data. To allow a quick switch between pre-training and fine-tuning, by default, we configure models and the corresponding tasks in a separate manner. More specifically, +Say you trained a model with the following command: +```bash +graphium-train --config-name "main" +``` + +Fine-tuning this model on downstream tasks is then as simple as: +```bash +graphium-train --config-name "main" +finetuning=... +``` + +From a configuration point-of-view, fine-tuning requires us to load a pre-trained model and override part of the training configuration to fine-tune it on downstream tasks. 
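+For example, since Hydra sweeps comma-separated overrides in multirun mode, a single command could fine-tune the same pre-trained backbone on several ADMET benchmarks in turn (a sketch, assuming the corresponding TDC data has been prepared locally):
+```bash
+graphium-train -m --config-name "main" +finetuning=admet constants.task=caco2_wang,lipophilicity_astrazeneca
+```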
To allow a quick switch between pre-training and fine-tuning, by default, we configure models and the corresponding tasks in a separate manner. More specifically, - under `architecture/` we store architecture related configurations such as the definition of the GNN/Transformer layers or positional/structural encoders - under `tasks/` we store configurations specific to one task set, such as the multi-task dataset ToyMix + - under `tasks/task_heads` we specify the task-specific heads to add on top of the base architecture. + - under `tasks/loss_metrics_datamodule` we specify the data-module to use and the task-specific loss functions and metrics - under `training/` we store configurations specific to training models which could be different for each combination of `architecture` and `tasks` +- under `finetuning/` we store configurations with overrides Since architecture and tasks are logically separated it now becomes very easy to e.g., use an existing architecture backbone on a new set of tasks or a new dataset altogether. Additionally, separating training allows us to specify different training parameters for e.g., pre-training and fine-tuning of the same architecture and task set. + We will now detail how you can add new architectures, tasks and training configurations. ### Adding an architecture @@ -88,7 +102,7 @@ datamodule: ``` You can then select your new architecture during training, e.g., by running ```bash -python main_run_multitask.py architecture=my_architecture +graphium-train architecture=my_architecture ``` ### Adding tasks @@ -125,7 +139,7 @@ predictor: ``` You can then select your new dataset during training, e.g., by running ```bash -python main_run_multitask.py tasks=my_tasks +graphium-train tasks=my_tasks ``` ### Adding training configs diff --git a/expts/hydra-configs/__init__.py b/expts/hydra-configs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/expts/hydra-configs/finetuning/admet.yaml b/expts/hydra-configs/finetuning/admet.yaml new file mode 100644 index 000000000..80fb20e35 --- /dev/null +++ b/expts/hydra-configs/finetuning/admet.yaml @@ -0,0 +1,91 @@ +# @package _global_ + +# == Fine-tuning configs in Graphium == +# +# A fine-tuning config is a appendum to a (pre-)training config. +# Since many things (e.g. the architecture), will stay constant between (pre-)training and fine-tuning, +# this config should be as minimal as possible to avoid unnecessary duplication. It only specifies +# what to override with regards to the config used for (pre-)training. +# +# Given the following training command: +# >>> graphium-train --cfg /path/to/train.yaml +# +# Fine-tuning now is as easy as: +# >>> graphium-train --cfg /path/to/train.yaml +finetune=admet +# +# NOTE: This config can be used for each of the benchmarks in the TDC ADMET benchmark suite. +# The only thing that needs to be changed is the `constants.task` key. + + +## == Overrides == + +defaults: + # This file contains all metrics and loss function info for all ADMET tasks. + # This config is filtered at runtime based on the `constants.task` key. + - override /tasks/loss_metrics_datamodule: admet + +constants: + + # For now, we assume a model is always fine-tuned on a single task at a time. + # You can override this value with any of the benchmark names in the TDC benchmark suite. 
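+  # e.g., add `constants.task=caco2_wang` on the command line to fine-tune on a different benchmark (illustrative override).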
+ # See also https://tdcommons.ai/benchmark/admet_group/overview/ + task: &task lipophilicity_astrazeneca + + name: finetuning_${constants.task}_gcn + wandb: + name: ${constants.name} + project: *task + entity: multitask-gnn + save_dir: logs/${constants.task} + seed: 42 + max_epochs: 10 + data_dir: expts/data/admet/${constants.task} + raise_train_error: true + +predictor: + optim_kwargs: + lr: 4.e-5 + +# == Fine-tuning config == + +finetuning: + + # For now, we assume a model is always fine-tuned on a single task at a time. + # You can override this value with any of the benchmark names in the TDC benchmark suite. + # See also https://tdcommons.ai/benchmark/admet_group/overview/ + task: ${constants.task} + level: graph + + # Pretrained model + pretrained_model_name: dummy-pretrained-model + finetuning_module: task_heads # gnn + sub_module_from_pretrained: zinc # optional + new_sub_module: lipophilicity_astrazeneca # optional + + # keep_modules_after_finetuning_module: # optional + # graph_output_nn/graph: {} + # task_heads/zinc: + # new_sub_module: lipophilicity_astrazeneca + # out_dim: 1 + + + # Changes to finetuning_module + drop_depth: 1 + new_out_dim: 8 + added_depth: 2 + + # Training + unfreeze_pretrained_depth: 0 + epoch_unfreeze_all: none + + # Optional finetuning head appended to model after finetuning_module + finetuning_head: + task: ${constants.task} + previous_module: task_heads + incoming_level: graph + model_type: mlp + in_dim: 8 + out_dim: 1 + hidden_dims: 8 + depth: 2 + last_layer_is_readout: true diff --git a/expts/hydra-configs/main.yaml b/expts/hydra-configs/main.yaml index d4b3beceb..a57dd22ca 100644 --- a/expts/hydra-configs/main.yaml +++ b/expts/hydra-configs/main.yaml @@ -1,7 +1,7 @@ defaults: # Accelerators - - accelerator: ipu + - accelerator: cpu # Pre-training/fine-tuning - architecture: toymix diff --git a/expts/hydra-configs/tasks/admet.yaml b/expts/hydra-configs/tasks/admet.yaml new file mode 100644 index 000000000..30dec61e0 --- /dev/null +++ b/expts/hydra-configs/tasks/admet.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. + +defaults: + - task_heads: admet + - loss_metrics_datamodule: admet \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml new file mode 100644 index 000000000..87136b683 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml @@ -0,0 +1,141 @@ +# @package _global_ + +#Task-specific +predictor: + metrics_on_progress_bar: + # All below metrics are directly copied from the TDC website. 
+ # For more information, see https://tdcommons.ai/benchmark/admet_group/overview/ + caco2_wang: ["mae"] + hia_hou: ["auroc"] + pgp_broccatelli: ["auroc"] + bioavailability_ma: ["auroc"] + lipophilicity_astrazeneca: ["mae"] + solubility_aqsoldb: ["mae"] + bbb_martins: ["auroc"] + ppbr_az: ["mae"] + vdss_lombardo: ["spearman"] + cyp2d6_veith: ["auprc"] + cyp3a4_veith: ["auprc"] + cyp2c9_veith: ["auprc"] + cyp2d6_substrate_carbonmangels: ["auprc"] + cyp3a4_substrate_carbonmangels: ["auprc"] + cyp2c9_substrate_carbonmangels: ["auprc"] + half_life_obach: ["spearman"] + clearance_microsome_az: ["spearman"] + clearance_hepatocyte_az: ["spearman"] + herg: ["mae"] + ames: ["auroc"] + dili: ["auroc"] + ld50_zhu: ["auroc"] + loss_fun: + caco2_wang: mae + hia_hou: bce + pgp_broccatelli: bce + bioavailability_ma: bce + lipophilicity_astrazeneca: mae + solubility_aqsoldb: mae + bbb_martins: bce + ppbr_az: mae + vdss_lombardo: mae + cyp2d6_veith: bce + cyp3a4_veith: bce + cyp2c9_veith: bce + cyp2d6_substrate_carbonmangels: bce + cyp3a4_substrate_carbonmangels: bce + cyp2c9_substrate_carbonmangels: bce + half_life_obach: mae + clearance_microsome_az: mae + clearance_hepatocyte_az: mae + herg: bce + ames: bce + dili: bce + ld50_zhu: mae + random_seed: ${constants.seed} + optim_kwargs: + lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: &max_epochs 10 + warmup_epochs: 10 + verbose: False + target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label + multitask_handling: flatten # flatten, mean-per-label + +# Task-specific +metrics: + caco2_wang: ®ression_metrics + - name: mae + metric: mae + target_nan_mask: null + multitask_handling: flatten + threshold_kwargs: null + - name: spearman + metric: spearmanr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: pearson + metric: pearsonr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2_score + metric: r2 + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + hia_hou: &classification_metrics + - name: auroc + metric: auroc + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: auprc + metric: averageprecision + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: accuracy + metric: accuracy + multitask_handling: mean-per-label + target_to_int: True + average: micro + threshold_kwargs: &threshold_05 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: True + - name: mcc + metric: mcc + num_classes: 2 + multitask_handling: mean-per-label + target_to_int: True + average: micro + threshold_kwargs: *threshold_05 + pgp_broccatelli: *classification_metrics + bioavailability_ma: *classification_metrics + lipophilicity_astrazeneca: *regression_metrics + solubility_aqsoldb: *regression_metrics + bbb_martins: *classification_metrics + ppbr_az: *regression_metrics + vdss_lombardo: *regression_metrics + cyp2d6_veith: *classification_metrics + cyp3a4_veith: *classification_metrics + cyp2c9_veith: *classification_metrics + cyp2d6_substrate_carbonmangels: *classification_metrics + cyp3a4_substrate_carbonmangels: *classification_metrics + cyp2c9_substrate_carbonmangels: *classification_metrics + half_life_obach: *regression_metrics + clearance_microsome_az: *regression_metrics + clearance_hepatocyte_az: *regression_metrics + herg: *classification_metrics + 
ames: *classification_metrics + dili: *classification_metrics + ld50_zhu: *regression_metrics + +datamodule: + module_type: "ADMETBenchmarkDataModule" + args: + # TDC specific + tdc_benchmark_names: null + tdc_train_val_seed: ${constants.seed} \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml new file mode 100644 index 000000000..d5b302dd1 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml @@ -0,0 +1,48 @@ +# @package _global_ + +#Task-specific +predictor: + metrics_on_progress_bar: + homolumo: [] + metrics_on_training_set: + homolumo: ["pearsonr"] + loss_fun: + homolumo: mae_ipu + +# Task-specific +metrics: + homolumo: + - name: mae + metric: mae_ipu + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + - name: pearsonr + metric: pearsonr_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + +datamodule: + module_type: "MultitaskFromSmilesDataModule" + # module_type: "FakeDataModule" # Option to use generated data + args: # Matches that in the test_multitask_datamodule.py case. + task_specific_args: # To be replaced by a new class "DatasetParams" + homolumo: + df: null + task_level: "graph" + df_path: graphium/data/PCQM4M/pcqm4mv2.csv + # wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv + # or set path as https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv directly + smiles_col: "cxsmiles" + label_cols: ["homo_lumo_gap"] + # sample_size: 8000 # use sample_size for test + splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt` + split_names: ["train", "valid", "test-dev"] + # graphium/data/PCQM4Mv2/split_dict.pt + # graphium/data/PCQM4Mv2/pcqm4m_split.csv + # split_val: 0.1 + # split_test: 0.1 + seed: ${constants.seed} + label_normalization: + method: "normal" diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml new file mode 100644 index 000000000..9ac744a52 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml @@ -0,0 +1,102 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + qm9: ["mae"] + tox21: ["auroc"] + zinc: ["mae"] + loss_fun: + qm9: mae_ipu + tox21: bce_logits_ipu + zinc: mae_ipu + +metrics: + qm9: &qm9_metrics + - name: mae + metric: mae_ipu + target_nan_mask: null + multitask_handling: flatten + threshold_kwargs: null + - name: pearsonr + metric: pearsonr_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2_score + metric: r2_score_ipu + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + tox21: + - name: auroc + metric: auroc_ipu + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: avpr + metric: average_precision_ipu + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: f1 > 0.5 + metric: f1 + multitask_handling: mean-per-label + target_to_int: True + num_classes: 2 + average: micro + threshold_kwargs: &threshold_05 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: True + - name: precision > 0.5 + metric: precision + multitask_handling: mean-per-label + average: micro + 
threshold_kwargs: *threshold_05 + zinc: *qm9_metrics + +datamodule: + args: + task_specific_args: + qm9: + df: null + df_path: ${constants.data_dir}/qm9.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz + # or set path as the URL directly + smiles_col: "smiles" + label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"] + # sample_size: 2000 # use sample_size for test + splits_path: ${constants.data_dir}/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt` + seed: ${constants.seed} #*seed + task_level: graph + label_normalization: + normalize_val_test: True + method: "normal" + + tox21: + df: null + df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz + # or set path as the URL directly + smiles_col: "smiles" + label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] + # sample_size: 2000 # use sample_size for test + splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt` + seed: ${constants.seed} + task_level: graph + + zinc: + df: null + df_path: ${constants.data_dir}/ZINC12k.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz + # or set path as the URL directly + smiles_col: "smiles" + label_cols: ["SA", "logp", "score"] + # sample_size: 2000 # use sample_size for test + splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt` + seed: ${constants.seed} + task_level: graph + label_normalization: + normalize_val_test: True + method: "normal" \ No newline at end of file diff --git a/expts/hydra-configs/tasks/pcqm4m.yaml b/expts/hydra-configs/tasks/pcqm4m.yaml index d92d381f7..4bd477dcc 100644 --- a/expts/hydra-configs/tasks/pcqm4m.yaml +++ b/expts/hydra-configs/tasks/pcqm4m.yaml @@ -1,62 +1,7 @@ -# @package _global_ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. -architecture: - task_heads: - homolumo: - task_level: graph - out_dim: 1 - hidden_dims: 256 - depth: 2 # Not needed if we have hidden_dims - activation: relu - last_activation: none - dropout: 0.18 - normalization: layer_norm - last_normalization: "none" - residual_type: none - -#Task-specific -predictor: - metrics_on_progress_bar: - homolumo: [] - metrics_on_training_set: - homolumo: ["pearsonr"] - loss_fun: - homolumo: mae_ipu - -# Task-specific -metrics: - homolumo: - - name: mae - metric: mae_ipu - target_nan_mask: null - multitask_handling: mean-per-label - threshold_kwargs: null - - name: pearsonr - metric: pearsonr_ipu - threshold_kwargs: null - target_nan_mask: null - multitask_handling: mean-per-label - -datamodule: - module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data - args: # Matches that in the test_multitask_datamodule.py case. 
- task_specific_args: # To be replaced by a new class "DatasetParams" - homolumo: - df: null - task_level: "graph" - df_path: graphium/data/PCQM4M/pcqm4mv2.csv - # wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv - # or set path as https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv directly - smiles_col: "cxsmiles" - label_cols: ["homo_lumo_gap"] - # sample_size: 8000 # use sample_size for test - splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt` - split_names: ["train", "valid", "test-dev"] - # graphium/data/PCQM4Mv2/split_dict.pt - # graphium/data/PCQM4Mv2/pcqm4m_split.csv - # split_val: 0.1 - # split_test: 0.1 - seed: ${constants.seed} - label_normalization: - method: "normal" +defaults: + - task_heads: pcqm4m + - loss_metrics_datamodule: pcqm4m \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/admet.yaml b/expts/hydra-configs/tasks/task_heads/admet.yaml new file mode 100644 index 000000000..2e697b15d --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/admet.yaml @@ -0,0 +1,47 @@ +# @package _global_ + +architecture: + task_heads: + caco2_wang: ®ression_head + task_level: graph + out_dim: 1 + hidden_dims: 64 + depth: 2 + activation: relu + last_activation: none + dropout: &dropout 0.5 + normalization: &normalization "layer_norm" + last_normalization: "none" + residual_type: none + hia_hou: &classification_head + task_level: graph + out_dim: 1 + hidden_dims: 64 + depth: 2 + activation: relu + last_activation: sigmoid + dropout: *dropout + normalization: *normalization + last_normalization: "none" + residual_type: none + pgp_broccatelli: *classification_head + bioavailability_ma: *classification_head + lipophilicity_astrazeneca: *regression_head + solubility_aqsoldb: *regression_head + bbb_martins: *classification_head + ppbr_az: *regression_head + vdss_lombardo: *regression_head + cyp2d6_veith: *classification_head + cyp3a4_veith: *classification_head + cyp2c9_veith: *classification_head + cyp2d6_substrate_carbonmangels: *classification_head + cyp3a4_substrate_carbonmangels: *classification_head + cyp2c9_substrate_carbonmangels: *classification_head + half_life_obach: *regression_head + clearance_microsome_az: *regression_head + clearance_hepatocyte_az: *regression_head + herg: *classification_head + ames: *classification_head + dili: *classification_head + ld50_zhu: *regression_head + \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/pcqm4m.yaml b/expts/hydra-configs/tasks/task_heads/pcqm4m.yaml new file mode 100644 index 000000000..b45ee9e62 --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/pcqm4m.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +architecture: + task_heads: + homolumo: + task_level: graph + out_dim: 1 + hidden_dims: 256 + depth: 2 # Not needed if we have hidden_dims + activation: relu + last_activation: none + dropout: 0.18 + normalization: layer_norm + last_normalization: "none" + residual_type: none diff --git a/expts/hydra-configs/tasks/task_heads/toymix.yaml b/expts/hydra-configs/tasks/task_heads/toymix.yaml new file mode 100644 index 000000000..c1df2522e --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/toymix.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +architecture: + task_heads: + qm9: + task_level: graph + out_dim: 19 + hidden_dims: 128 + depth: 2 + activation: relu + last_activation: none + dropout: 
${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none + tox21: + task_level: graph + out_dim: 12 + hidden_dims: 64 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none + zinc: + task_level: graph + out_dim: 3 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none diff --git a/expts/hydra-configs/tasks/toymix.yaml b/expts/hydra-configs/tasks/toymix.yaml index e120c13a8..16d582982 100644 --- a/expts/hydra-configs/tasks/toymix.yaml +++ b/expts/hydra-configs/tasks/toymix.yaml @@ -1,138 +1,7 @@ -# @package _global_ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. -architecture: - task_heads: - qm9: - task_level: graph - out_dim: 19 - hidden_dims: 128 - depth: 2 - activation: relu - last_activation: none - dropout: ${architecture.pre_nn.dropout} - normalization: ${architecture.pre_nn.normalization} - last_normalization: "none" - residual_type: none - tox21: - task_level: graph - out_dim: 12 - hidden_dims: 64 - depth: 2 - activation: relu - last_activation: none - dropout: ${architecture.pre_nn.dropout} - normalization: ${architecture.pre_nn.normalization} - last_normalization: "none" - residual_type: none - zinc: - task_level: graph - out_dim: 3 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: ${architecture.pre_nn.dropout} - normalization: ${architecture.pre_nn.normalization} - last_normalization: "none" - residual_type: none - -predictor: - metrics_on_progress_bar: - qm9: ["mae"] - tox21: ["auroc"] - zinc: ["mae"] - loss_fun: - qm9: mae_ipu - tox21: bce_logits_ipu - zinc: mae_ipu - -metrics: - qm9: &qm9_metrics - - name: mae - metric: mae_ipu - target_nan_mask: null - multitask_handling: flatten - threshold_kwargs: null - - name: pearsonr - metric: pearsonr_ipu - threshold_kwargs: null - target_nan_mask: null - multitask_handling: mean-per-label - - name: r2_score - metric: r2_score_ipu - target_nan_mask: null - multitask_handling: mean-per-label - threshold_kwargs: null - tox21: - - name: auroc - metric: auroc_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: avpr - metric: average_precision_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: f1 > 0.5 - metric: f1 - multitask_handling: mean-per-label - target_to_int: True - num_classes: 2 - average: micro - threshold_kwargs: &threshold_05 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: True - - name: precision > 0.5 - metric: precision - multitask_handling: mean-per-label - average: micro - threshold_kwargs: *threshold_05 - zinc: *qm9_metrics - -datamodule: - args: - task_specific_args: - qm9: - df: null - df_path: ${constants.data_dir}/qm9.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"] - # sample_size: 
2000 # use sample_size for test - splits_path: ${constants.data_dir}/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt` - seed: ${constants.seed} #*seed - task_level: graph - label_normalization: - normalize_val_test: True - method: "normal" - - tox21: - df: null - df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] - # sample_size: 2000 # use sample_size for test - splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt` - seed: ${constants.seed} - task_level: graph - - zinc: - df: null - df_path: ${constants.data_dir}/ZINC12k.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["SA", "logp", "score"] - # sample_size: 2000 # use sample_size for test - splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt` - seed: ${constants.seed} - task_level: graph - label_normalization: - normalize_val_test: True - method: "normal" \ No newline at end of file +defaults: + - task_heads: toymix + - loss_metrics_datamodule: toymix \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml index eb12c8935..9022eeb84 100644 --- a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml @@ -11,10 +11,10 @@ predictor: optim_kwargs: {} metrics_every_n_train_steps: 300 torch_scheduler_kwargs: - max_num_epochs: &max_epochs 300 + max_num_epochs: ${constants.max_epochs} trainer: trainer: precision: 32 accumulate_grad_batches: 1 - max_epochs: *max_epochs \ No newline at end of file + max_epochs: ${constants.max_epochs} \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/toymix_gpu.yaml b/expts/hydra-configs/training/accelerator/toymix_gpu.yaml index 3712373c3..c2c8e4066 100644 --- a/expts/hydra-configs/training/accelerator/toymix_gpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_gpu.yaml @@ -14,9 +14,9 @@ predictor: optim_kwargs: {} metrics_every_n_train_steps: 300 torch_scheduler_kwargs: - max_num_epochs: &max_epochs 300 + max_num_epochs: ${constants.max_epochs} trainer: trainer: accumulate_grad_batches: 1 - max_epochs: *max_epochs \ No newline at end of file + max_epochs: ${constants.max_epochs} \ No newline at end of file diff --git a/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml b/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml index e13c44aa0..7fb1e1ee5 100644 --- a/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml +++ b/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml @@ -3,7 +3,6 @@ # GPS++ model with the PCQMv2 dataset. 
constants: name: pcqm4mv2_gpspp_4layer - entity: "multitask-gnn" seed: 42 max_epochs: 100 raise_train_error: true # Whether the code should raise an error if it crashes during training diff --git a/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml b/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml index 41b55eba1..ca643fe39 100644 --- a/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml +++ b/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml @@ -3,7 +3,6 @@ # MPNN model with the PCQMv2 dataset. constants: name: pcqm4mv2_mpnn_4layer - entity: "multitask-gnn" seed: 42 max_epochs: 100 raise_train_error: true # Whether the code should raise an error if it crashes during training diff --git a/expts/hydra-configs/training/model/toymix_gcn.yaml b/expts/hydra-configs/training/model/toymix_gcn.yaml index f422e37fc..48eabe003 100644 --- a/expts/hydra-configs/training/model/toymix_gcn.yaml +++ b/expts/hydra-configs/training/model/toymix_gcn.yaml @@ -2,7 +2,6 @@ constants: name: neurips2023_small_data_gcn - entity: "multitask-gnn" seed: 42 max_epochs: 100 data_dir: expts/data/neurips2023/small-dataset diff --git a/expts/hydra-configs/training/model/toymix_gin.yaml b/expts/hydra-configs/training/model/toymix_gin.yaml index 605671c68..ed2885efb 100644 --- a/expts/hydra-configs/training/model/toymix_gin.yaml +++ b/expts/hydra-configs/training/model/toymix_gin.yaml @@ -2,7 +2,6 @@ constants: name: neurips2023_small_data_gin - entity: "multitask-gnn" seed: 42 data_dir: expts/data/neurips2023/small-dataset raise_train_error: true diff --git a/expts/hydra-configs/training/pcqm4m.yaml b/expts/hydra-configs/training/pcqm4m.yaml index 910c78c67..871a2a5f1 100644 --- a/expts/hydra-configs/training/pcqm4m.yaml +++ b/expts/hydra-configs/training/pcqm4m.yaml @@ -7,7 +7,7 @@ predictor: # weight_decay: 1.e-7 torch_scheduler_kwargs: module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 100 + max_num_epochs: ${constants.max_epochs} warmup_epochs: 10 verbose: False scheduler_kwargs: @@ -22,10 +22,6 @@ predictor: trainer: seed: ${constants.seed} - logger: - save_dir: logs/PCQMv2 - name: ${constants.name} - project: PCQMv2_mpnn #early_stopping: # monitor: *monitor # min_delta: 0 @@ -39,6 +35,6 @@ trainer: save_top_k: 1 every_n_epochs: 100 trainer: - max_epochs: *max_epochs + max_epochs: ${constants.max_epochs} min_epochs: 1 check_val_every_n_epoch: 20 diff --git a/expts/hydra-configs/training/toymix.yaml b/expts/hydra-configs/training/toymix.yaml index 05d7c4715..4afcbd56a 100644 --- a/expts/hydra-configs/training/toymix.yaml +++ b/expts/hydra-configs/training/toymix.yaml @@ -7,7 +7,7 @@ predictor: # weight_decay: 1.e-7 torch_scheduler_kwargs: module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 100 + max_num_epochs: ${constants.max_epochs} warmup_epochs: 10 verbose: False scheduler_kwargs: null @@ -16,15 +16,11 @@ predictor: trainer: seed: ${constants.seed} - logger: - save_dir: logs/neurips2023-small/ - name: ${constants.name} - project: ${constants.name} model_checkpoint: filename: ${constants.name} save_last: True trainer: precision: 16 - max_epochs: *max_epochs + max_epochs: ${constants.max_epochs} min_epochs: 1 check_val_every_n_epoch: 20 \ No newline at end of file diff --git a/expts/main_run_multitask.py b/expts/main_run_multitask.py index c14670377..c68663a08 100644 --- a/expts/main_run_multitask.py +++ b/expts/main_run_multitask.py @@ -32,67 +32,10 @@ @hydra.main(version_base=None, config_path="hydra-configs", config_name="main") def main(cfg: DictConfig) -> None: - cfg = 
OmegaConf.to_container(cfg, resolve=True) - - run_name: str = "main" - add_date_time: bool = True - - st = timeit.default_timer() - - date_time_suffix = "" - if add_date_time: - date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S") - - wandb.init(entity=cfg["constants"]["entity"], project=cfg["constants"]["name"], config=cfg) - - # Initialize the accelerator - cfg, accelerator_type = load_accelerator(cfg) - - # Load and initialize the dataset - datamodule = load_datamodule(cfg, accelerator_type) - - # Initialize the network - model_class, model_kwargs = load_architecture( - cfg, - in_dims=datamodule.in_dims, + raise DeprecationWarning( + "This script is deprecated. Use `python graphium/cli/train_finetune.py` (or `graphium-train`) instead!" ) - datamodule.prepare_data() - - metrics = load_metrics(cfg) - logger.info(metrics) - - predictor = load_predictor( - cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms - ) - - logger.info(predictor.model) - logger.info(ModelSummary(predictor, max_depth=4)) - - trainer = load_trainer(cfg, run_name, accelerator_type, date_time_suffix) - save_params_to_wandb(trainer.logger, cfg, predictor, datamodule) - - # Determine the max num nodes and edges in training and validation - predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"]) - - # Run the model training - with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True): - trainer.fit(model=predictor, datamodule=datamodule) - - # Determine the max num nodes and edges in testing - predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"]) - - # Run the model testing - with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True): - trainer.test(model=predictor, datamodule=datamodule) # , ckpt_path=ckpt_path) - - logger.info("--------------------------------------------") - logger.info("total computation used", timeit.default_timer() - st) - logger.info("--------------------------------------------") - wandb.finish() - - return trainer.callback_metrics - if __name__ == "__main__": main() diff --git a/graphium/cli/README.md b/graphium/cli/README.md deleted file mode 100644 index eb744ac7e..000000000 --- a/graphium/cli/README.md +++ /dev/null @@ -1,9 +0,0 @@ -
-
-The Graph Of LIfe Library.
-
- - -## What is in this folder? - -- files for handling command line arguments \ No newline at end of file diff --git a/graphium/cli/__init__.py b/graphium/cli/__init__.py index 5c5f801f3..8928b9836 100644 --- a/graphium/cli/__init__.py +++ b/graphium/cli/__init__.py @@ -1,2 +1,3 @@ -from .main import main_cli from .data import data_cli +from .finetune_utils import finetune_cli +from .main import main_cli diff --git a/graphium/cli/__main__.py b/graphium/cli/__main__.py new file mode 100644 index 000000000..0baa7638c --- /dev/null +++ b/graphium/cli/__main__.py @@ -0,0 +1,4 @@ +from .main import main_cli + +if __name__ == "__main__": + main_cli() diff --git a/graphium/cli/finetune_utils.py b/graphium/cli/finetune_utils.py new file mode 100644 index 000000000..80d437f98 --- /dev/null +++ b/graphium/cli/finetune_utils.py @@ -0,0 +1,78 @@ +import yaml +import click +import fsspec + +from loguru import logger +from hydra import compose, initialize +from datamol.utils import fs + +from .main import main_cli +from .train_finetune import run_training_finetuning + + +@main_cli.group(name="finetune", help="Utility CLI for extra fine-tuning utilities.") +def finetune_cli(): + pass + + +@finetune_cli.command(name="admet") +@click.argument("save_dir") +@click.option("--wandb/--no-wandb", default=True, help="Whether to log to Weights & Biases.") +@click.option( + "--name", + "-n", + multiple=True, + help="One or multiple benchmarks to filter on. See also --inclusive-filter/--exclusive-filter.", +) +@click.option( + "--inclusive-filter/--exclusive-filter", + default=True, + help="Whether to include or exclude the benchmarks specified by `--name`.", +) +def benchmark_tdc_admet_cli(save_dir, wandb, name, inclusive_filter): + """ + Utility CLI to easily fine-tune a model on (a subset of) the benchmarks in the TDC ADMET group. + The results are saved to the SAVE_DIR. + """ + + try: + from tdc.utils import retrieve_benchmark_names + except ImportError: + raise ImportError("TDC needs to be installed to use this CLI. 
Run `pip install PyTDC`.") + + # Get the benchmarks to run this for + if name is None: + name = retrieve_benchmark_names("admet_group") + elif not inclusive_filter: + name = [n for n in name if n not in retrieve_benchmark_names("admet_group")] + + results = {} + + # Use the Compose API to construct the config + for n in name: + overrides = [ + "+finetuning=admet", + f"finetuning.task={n}", + f"finetuning.finetuning_head.task={n}", + ] + + if not wandb: + overrides.append("~constants.wandb") + + with initialize(version_base=None, config_path="../../expts/hydra-configs"): + cfg = compose( + config_name="main", + overrides=overrides, + ) + + # Run the training loop + ret = run_training_finetuning(cfg) + ret = {k: v.item() for k, v in ret.items()} + results[n] = ret + + fs.mkdir(save_dir, exist_ok=True) + path = fs.join(save_dir, "results.yaml") + logger.info(f"Saving results to {path}") + + with fsspec.open(path, "w") as f: + yaml.dump(results, f) diff --git a/graphium/cli/train_finetune.py b/graphium/cli/train_finetune.py new file mode 100644 index 000000000..e6ead5122 --- /dev/null +++ b/graphium/cli/train_finetune.py @@ -0,0 +1,123 @@ +import hydra +import wandb +import timeit + +from omegaconf import DictConfig, OmegaConf +from loguru import logger +from datetime import datetime +from lightning.pytorch.utilities.model_summary import ModelSummary + +from graphium.config._loader import ( + load_datamodule, + load_metrics, + load_architecture, + load_predictor, + load_trainer, + load_accelerator, + save_params_to_wandb, +) +from graphium.finetuning import modify_cfg_for_finetuning, GraphFinetuning +from graphium.utils.safe_run import SafeRun + + +FINETUNING_CONFIG_KEY = "finetuning" + + +@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main") +def cli(cfg: DictConfig) -> None: + """ + The main CLI endpoint for training and fine-tuning Graphium models. + """ + run_training_finetuning(cfg) + + +def run_training_finetuning(cfg: DictConfig) -> None: + """ + The main (pre-)training and fine-tuning loop. 
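+    Assumes a composed Hydra/OmegaConf config with `constants`, `datamodule`,
+    `architecture`, `predictor`, `metrics` and `trainer` sections. When a
+    `finetuning` section is present, the config is first adapted via
+    `modify_cfg_for_finetuning` and the `GraphFinetuning` callback is added
+    to the trainer.
+
+    Returns:
+        The trainer's `callback_metrics` after training and testing.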
+ """ + + cfg = OmegaConf.to_container(cfg, resolve=True) + + # Modify the config for finetuning + if FINETUNING_CONFIG_KEY in cfg: + cfg = modify_cfg_for_finetuning(cfg) + + st = timeit.default_timer() + + wandb_cfg = cfg["constants"].get("wandb") + if wandb_cfg is not None: + wandb.init( + entity=wandb_cfg["entity"], + project=wandb_cfg["project"], + config=cfg, + ) + + ## == Instantiate all required objects from their respective configs == + # Accelerator + cfg, accelerator_type = load_accelerator(cfg) + + ## Data-module + datamodule = load_datamodule(cfg, accelerator_type) + + ## Architecture + model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims) + + datamodule.prepare_data() + + ## Metrics + metrics = load_metrics(cfg) + + ## Predictor + predictor = load_predictor( + config=cfg, + model_class=model_class, + model_kwargs=model_kwargs, + metrics=metrics, + task_levels=datamodule.get_task_levels(), + accelerator_type=accelerator_type, + featurization=datamodule.featurization, + task_norms=datamodule.task_norms, + ) + + logger.info(predictor.model) + logger.info(ModelSummary(predictor, max_depth=4)) + + ## Trainer + date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S") + trainer = load_trainer(cfg, accelerator_type, date_time_suffix) + + # Add the fine-tuning callback to trainer + if FINETUNING_CONFIG_KEY in cfg: + finetuning_training_kwargs = cfg["finetuning"]["training_kwargs"] + trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs)) + + if wandb_cfg is not None: + save_params_to_wandb(trainer.logger, cfg, predictor, datamodule) + + # Determine the max num nodes and edges in training and validation + logger.info("Computing the maximum number of nodes and edges per graph") + predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"]) + + # Run the model training + with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True): + trainer.fit(model=predictor, datamodule=datamodule) + + # Determine the max num nodes and edges in testing + predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"]) + + # Run the model testing + with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True): + trainer.test(model=predictor, datamodule=datamodule) # , ckpt_path=ckpt_path) + + logger.info("-" * 50) + logger.info("Total compute time:", timeit.default_timer() - st) + logger.info("-" * 50) + + if wandb_cfg is not None: + wandb.finish() + + return trainer.callback_metrics + + +if __name__ == "__main__": + cli() diff --git a/graphium/config/__init__.py b/graphium/config/__init__.py index 2f13795b6..fd5ade2eb 100644 --- a/graphium/config/__init__.py +++ b/graphium/config/__init__.py @@ -5,3 +5,5 @@ from ._loader import load_metrics from ._loader import load_predictor from ._loader import load_trainer +from ._loader import save_params_to_wandb +from ._loader import load_accelerator diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index b06279897..3c7a654e9 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -24,6 +24,7 @@ from graphium.ipu.ipu_dataloader import IPUDataloaderOptions from graphium.trainer.metrics import MetricWrapper from graphium.nn.architectures import FullGraphMultiTaskNetwork +from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork from graphium.nn.utils import MupMixin from graphium.trainer.predictor import PredictorModule from graphium.utils.spaces import DATAMODULE_DICT 
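The new `graphium-train` entry point is a thin Hydra wrapper around `run_training_finetuning`, so the same loop can also be driven programmatically, mirroring `finetune_utils.py` above. A sketch under the assumption that the caller lives inside `graphium/cli/` (Hydra's `config_path` is resolved relative to the calling module, as in `finetune_utils.py`); the `+finetuning=admet` override is the one shipped in this patch:

```python
from hydra import compose, initialize

from graphium.cli.train_finetune import run_training_finetuning

# Compose the same config that `graphium-train +finetuning=admet` would build,
# then run the (pre-)training / fine-tuning loop directly.
with initialize(version_base=None, config_path="../../expts/hydra-configs"):
    cfg = compose(config_name="main", overrides=["+finetuning=admet"])

metrics = run_training_finetuning(cfg)  # returns trainer.callback_metrics
```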
@@ -192,12 +193,12 @@ def load_architecture( config = omegaconf.OmegaConf.create(config) cfg_arch = config["architecture"] - kwargs = {} - # Select the architecture model_type = cfg_arch["model_type"].lower() if model_type == "fullgraphmultitasknetwork": model_class = FullGraphMultiTaskNetwork + elif model_type == "fullgraphfinetuningnetwork": + model_class = FullGraphFinetuningNetwork else: raise ValueError(f"Unsupported model_type=`{model_type}`") @@ -260,6 +261,18 @@ def load_architecture( task_heads_kwargs=task_heads_kwargs, ) + if model_class is FullGraphFinetuningNetwork: + finetuning_head_kwargs = config["finetuning"].pop("finetuning_head", None) + pretrained_overwriting_kwargs = config["finetuning"].pop("overwriting_kwargs") + pretrained_model_name = pretrained_overwriting_kwargs.pop("pretrained_model_name") + + model_kwargs = { + "pretrained_model_kwargs": deepcopy(model_kwargs), + "pretrained_overwriting_kwargs": pretrained_overwriting_kwargs, + "pretrained_model_name": pretrained_model_name, + "finetuning_head_kwargs": finetuning_head_kwargs, + } + return model_class, model_kwargs @@ -268,7 +281,9 @@ def load_predictor( model_class: Type[torch.nn.Module], model_kwargs: Dict[str, Any], metrics: Dict[str, MetricWrapper], + task_levels: Dict[str, str], accelerator_type: str, + featurization: Dict[str, str] = None, task_norms: Optional[Dict[Callable, Any]] = None, ) -> PredictorModule: """ @@ -292,6 +307,8 @@ def load_predictor( model_class=model_class, model_kwargs=model_kwargs, metrics=metrics, + task_levels=task_levels, + featurization=featurization, task_norms=task_norms, **cfg_pred, ) @@ -340,7 +357,6 @@ def load_mup(mup_base_path: str, predictor: PredictorModule) -> PredictorModule: def load_trainer( config: Union[omegaconf.DictConfig, Dict[str, Any]], - run_name: str, accelerator_type: str, date_time_suffix: str = "", ) -> Trainer: @@ -348,7 +364,6 @@ def load_trainer( Defining the pytorch-lightning Trainer module. Parameters: config: The config file, with key `trainer` - run_name: The name of the current run. To be used for logging. accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu" date_time_suffix: The date and time of the current run. To be used for logging. Returns: @@ -395,12 +410,12 @@ def load_trainer( callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"])) # Define the logger parameters - logger = cfg_trainer.pop("logger", None) - if logger is not None: - name = logger.pop("name", run_name) + wandb_cfg = config["constants"].get("wandb") + if wandb_cfg is not None: + name = wandb_cfg.pop("name", "main") if len(date_time_suffix) > 0: name += f"_{date_time_suffix}" - trainer_kwargs["logger"] = WandbLogger(name=name, **logger) + trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg) trainer_kwargs["callbacks"] = callbacks trainer = Trainer( diff --git a/graphium/config/dummy_finetuning.yaml b/graphium/config/dummy_finetuning.yaml new file mode 100644 index 000000000..38cdf029b --- /dev/null +++ b/graphium/config/dummy_finetuning.yaml @@ -0,0 +1,146 @@ +# Here, we are finetuning a FullGraphMultitaskNetwork +# trained on ToyMix. 
We finetune from the zinc task-head +# (graph-level) on the TDC dataset lipophilicity_astraceneca + +# Here are the changes to the architecture: +# +# Change zinc task-head: +# depth: 2 -> 2 - 1 + 2 = 3 +# out_dim: 3 -> 8 +# +# Add finetuning head +# model_type: FeedForwardNN +# out_dim: 1 +# hidden_dims: 8 +# depth: 2 + + +################################################### +########### How to combine information ########### +################################################### + + +########################### +### FINETUNING-SPECIFIC ### +########################### + +finetuning: + # New task + task: lipophilicity_astrazeneca + level: graph + + # Pretrained model + pretrained_model_name: dummy-pretrained-model + finetuning_module: task_heads + sub_module_from_pretrained: zinc # optional + new_sub_module: lipophilicity_astrazeneca # optional + # keep_modules_after_finetuning_module: # optional + + # Changes to finetuning_module + drop_depth: 1 + new_out_dim: 8 + added_depth: 2 + + # Optional finetuning head appended to model after finetuning_module + finetuning_head: # none + task: lipophilicity_astrazeneca + previous_module: task_heads + incoming_level: graph + model_type: mlp + in_dim: 8 + out_dim: 1 + hidden_dims: 8 + depth: 2 + last_layer_is_readout: true + + # Finetuning training + unfreeze_pretrained_depth: 0 + epoch_unfreeze_all: 2 + +constants: + seed: 42 + max_epochs: 3 + +accelerator: + float32_matmul_precision: medium + type: cpu + +predictor: + random_seed: ${constants.seed} + optim_kwargs: + lr: 4.e-5 + scheduler_kwargs: null + target_nan_mask: null + multitask_handling: flatten # flatten, mean-per-label + + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: 3 + warmup_epochs: 1 + verbose: False + + metrics_on_progress_bar: + lipophilicity_astrazeneca: ["mae"] + loss_fun: + lipophilicity_astrazeneca: mae + +metrics: + lipophilicity_astrazeneca: + - name: mae + metric: mae + target_nan_mask: null + multitask_handling: flatten + threshold_kwargs: null + - name: spearman + metric: spearmanr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: pearson + metric: pearsonr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2_score + metric: r2 + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + +trainer: + seed: ${constants.seed} + trainer: + precision: 32 + max_epochs: 3 + min_epochs: 1 + check_val_every_n_epoch: 1 + accumulate_grad_batches: 1 + +################## +### DATAMODULE ### +################## + +datamodule: + +### FROM FINETUNING ### + + module_type: "ADMETBenchmarkDataModule" + args: + # TDC specific + tdc_benchmark_names: [lipophilicity_astrazeneca] + tdc_train_val_seed: ${constants.seed} + + batch_size_training: 200 + batch_size_inference: 200 + featurization_n_jobs: 0 + num_workers: 0 + + prepare_dict_or_graph: pyg:graph + featurization_progress: True + featurization_backend: "loky" + processed_graph_data_path: "../datacache/neurips2023-small/" + persistent_workers: False + + + + diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 0724b50ec..b85fd664c 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -2,6 +2,8 @@ from contextlib import redirect_stderr, redirect_stdout from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable, Literal +from dataclasses import dataclass + import os from functools import partial import importlib.resources @@ 
-642,6 +644,7 @@ def get_max_num_edges_datamodule(self, stages: Optional[List[str]] = None) -> in return max_num_edges +@dataclass class DatasetProcessingParams: def __init__( self, @@ -916,6 +919,8 @@ def __init__( if featurization is None: featurization = {} + self.featurization = featurization + # Whether to transform the smiles into a pyg `Data` graph or a dictionary compatible with pyg if prepare_dict_or_graph == "pyg:dict": self.smiles_transformer = partial(mol_to_graph_dict, **featurization) @@ -933,6 +938,16 @@ def _get_task_key(self, task_level: str, task: str): task = task_prefix + task return task + def get_task_levels(self): + task_level_map = {} + + for task, task_args in self.task_specific_args.items(): + if isinstance(task_args, DatasetProcessingParams): + task_args = task_args.__dict__ # Convert the class to a dictionary + task_level_map.update({task: task_args["task_level"]}) + + return task_level_map + def prepare_data(self): """Called only from a single process in distributed settings. Steps: @@ -1896,6 +1911,8 @@ def _get_split_indices( # Split from an indices file file_type = self._get_data_file_type(splits_path) + train, val, test = split_names + if file_type == "pt": splits = torch.load(splits_path) elif file_type in ["csv", "tsv"]: @@ -1954,13 +1971,23 @@ def get_data_hash(self): Get a hash specific to a dataset and smiles_transformer. Useful to cache the pre-processed data. """ - args = deepcopy(self.task_specific_args) + args = {} # pop epoch_sampling_fraction out when creating hash # so that the data cache does not need to be regenerated # when epoch_sampling_fraction has changed. - for task in self.task_specific_args.keys(): - if "epoch_sampling_fraction" in args[task].keys(): - args[task].pop("epoch_sampling_fraction") + for task_key, task_args in deepcopy(self.task_specific_args).items(): + if isinstance(task_args, DatasetProcessingParams): + task_args = task_args.__dict__ # Convert the class to a dictionary + + # Keep only first 5 rows of a dataframe + if "df" in task_args.keys(): + if task_args["df"] is not None: + task_args["df"] = task_args["df"].iloc[:5] + + # Remove the `epoch_sampling_fraction` + task_args.pop("epoch_sampling_fraction", None) + args[task_key] = task_args + hash_dict = { "smiles_transformer": self.smiles_transformer, "task_specific_args": args, diff --git a/graphium/finetuning/__init__.py b/graphium/finetuning/__init__.py new file mode 100644 index 000000000..5ef566743 --- /dev/null +++ b/graphium/finetuning/__init__.py @@ -0,0 +1,3 @@ +from .utils import modify_cfg_for_finetuning +from .finetuning import GraphFinetuning +from .finetuning_architecture import FullGraphFinetuningNetwork diff --git a/graphium/finetuning/finetuning.py b/graphium/finetuning/finetuning.py new file mode 100644 index 000000000..afa33cad5 --- /dev/null +++ b/graphium/finetuning/finetuning.py @@ -0,0 +1,86 @@ +from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type + +from collections import OrderedDict + +import torch.nn as nn +import pytorch_lightning as pl + +from torch.optim.optimizer import Optimizer +from pytorch_lightning.callbacks import BaseFinetuning + + +class GraphFinetuning(BaseFinetuning): + def __init__( + self, + finetuning_module: str, + added_depth: int, + unfreeze_pretrained_depth: Optional[int] = None, + epoch_unfreeze_all: int = 0, + train_bn: bool = False, + ): + """ + Finetuning training callback that (un)freezes modules as specified in the configuration file. 
+ By default, the modified layers of the fineuning module and the finetuning head are unfrozen. + + Parameters: + finetuning_module: Module to finetune from + added_depth: Number of layers of finetuning module that have been modified rel. to pretrained model + unfreeze_pretrained_depth: Number of additional layers to unfreeze before layers modified rel. to pretrained model + epoch_unfreeze_all: Epoch to unfreeze entire model + train_bn: Boolean value indicating if batchnorm layers stay in training mode + + """ + super().__init__() + + self.finetuning_module = finetuning_module + self.training_depth = added_depth + if unfreeze_pretrained_depth is not None: + self.training_depth += unfreeze_pretrained_depth + self.epoch_unfreeze_all = epoch_unfreeze_all + self.train_bn = train_bn + + def freeze_before_training(self, pl_module: pl.LightningModule): + """ + Freeze everything up to finetuning module (and parts of finetuning module) + + Parameters: + pl_module: PredictorModule used for finetuning + """ + + # Access module map of pretrained module + module_map = pl_module.model.pretrained_model.net._module_map + + for module_name in module_map.keys(): + self.freeze_module(module_name, module_map) + + if module_name.startswith(self.finetuning_module): + # Do not freeze modules after finetuning module + break + + def freeze_module(self, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]): + """ + Freeze specific modules + + Parameters: + module_name: Name of module to (partally) freeze + module_map: Dictionary mapping from module_name to corresponding module(s) + """ + modules = module_map[module_name] + + # We only partially freeze the finetuning module + if module_name.startswith(self.finetuning_module): + modules = modules[: -self.training_depth] + + self.freeze(modules=modules, train_bn=self.train_bn) + + def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer): + """ + Function unfreezing entire model at specified epoch + + Parameters: + pl_module: PredictorModule used for finetuning + epoch: Current training epoch + optimizer: Optimizer used for finetuning + """ + if epoch == self.epoch_unfreeze_all: + self.unfreeze_and_add_param_group(modules=pl_module, optimizer=optimizer, train_bn=self.train_bn) diff --git a/graphium/finetuning/finetuning_architecture.py b/graphium/finetuning/finetuning_architecture.py new file mode 100644 index 000000000..8fd6263b5 --- /dev/null +++ b/graphium/finetuning/finetuning_architecture.py @@ -0,0 +1,347 @@ +from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type + +from copy import deepcopy + +from loguru import logger + +import torch +import torch.nn as nn + +from torch import Tensor +from torch_geometric.data import Batch + +from graphium.data.utils import get_keys +from graphium.nn.base_graph_layer import BaseGraphStructure +from graphium.nn.architectures.encoder_manager import EncoderManager +from graphium.nn.architectures import FullGraphMultiTaskNetwork, FeedForwardNN, FeedForwardPyg, TaskHeads +from graphium.nn.architectures.global_architectures import FeedForwardGraph +from graphium.trainer.predictor_options import ModelOptions +from graphium.nn.utils import MupMixin + +from graphium.trainer.predictor import PredictorModule +from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT, FINETUNING_HEADS_DICT + + +class FullGraphFinetuningNetwork(nn.Module, MupMixin): + def __init__( + self, + pretrained_model_name: str, + pretrained_model_kwargs: Dict[str, Any], + 
pretrained_overwriting_kwargs: Dict[str, Any], + finetuning_head_kwargs: Optional[Dict[str, Any]] = None, + num_inference_to_average: int = 1, + last_layer_is_readout: bool = False, + name: str = "FullFinetuningGNN", + ): + r""" + Flexible class that allows to implement an end-to-end graph finetuning network architecture, supporting flexible pretrained models and finetuning heads. + The network decomposes into two parts of class PretrainedModel and FinetuningHead. The PretrainedModel class allows basic finetuning such as + finetuning from a specified module of the pretrained model and dropping/adding layers in this module. The (optional) FinetuningHead class allows more + flexible finetuning with a custom network applied after the pretrained model. If not specified, we fall back to basic finetuning integrated in PretrainedModel. + + Parameters: + + pretrained_model_name: + Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT + + pretrained_model_kwargs: + Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)) + + pretrained_overwriting_kwargs: + Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork + + finetuning_head_kwargs: + Key-word arguments to use for the finetuning head. + It must respect the following criteria: + - pretrained_model_kwargs[last_used_module]["out_level"] must be equal to finetuning_head_kwargs["in_level"] + - pretrained_model_kwargs[last_used_module]["out_dim"] must be equal to finetuning_head_kwargs["in_dim"] + + Here, [last_used_module] represents the module that is finetuned from, + e.g., gnn, graph_output or (one of the) task_heads + + num_inference_to_average: + Number of inferences to average at val/test time. This is used to avoid the noise introduced + by positional encodings with sign-flips. In case no such encoding is given, + this parameter is ignored. + NOTE: The inference time will be slowed-down proportionaly to this parameter. + + last_layer_is_readout: Whether the last layer should be treated as a readout layer. + Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + + name: + Name attributed to the current network, for display and printing + purposes. + """ + + super().__init__() + + self.name = name + self.num_inference_to_average = num_inference_to_average + self.last_layer_is_readout = last_layer_is_readout + self._concat_last_layers = None + self.pretrained_model_name = pretrained_model_name + self.pretrained_overwriting_kwargs = pretrained_overwriting_kwargs + self.finetuning_head_kwargs = finetuning_head_kwargs + self.max_num_nodes_per_graph = None + self.max_num_edges_per_graph = None + self.finetuning_head = None + + self.pretrained_model = PretrainedModel( + pretrained_model_name, pretrained_model_kwargs, pretrained_overwriting_kwargs + ) + + if finetuning_head_kwargs is not None: + self.finetuning_head = FinetuningHead(finetuning_head_kwargs) + + def forward(self, g: Batch) -> Tensor: + r""" + Apply the pre-processing neural network, the graph neural network, + and the post-processing neural network on the graph features. + + Parameters: + + g: + pyg Batch graph on which the convolution is done. + Must contain the following elements: + + - Node key `"feat"`: `torch.Tensor[..., N, Din]`. + Input node feature tensor, before the network. 
+ `N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim`` + + - Edge key `"edge_feat"`: `torch.Tensor[..., N, Ein]` **Optional**. + The edge features to use. It will be ignored if the + model doesn't supporte edge features or if + `self.in_dim_edges==0`. + + - Other keys related to positional encodings `"pos_enc_feats_sign_flip"`, + `"pos_enc_feats_no_flip"`. + + Returns: + + `torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`: + Node or graph feature tensor, after the network. + `N` is the number of nodes, `M` is the number of graphs, + `Dout` is the output dimension ``self.graph_output_nn.out_dim`` + If the `self.gnn.pooling` is [`None`], then it returns node features and the output dimension is `N`, + otherwise it returns graph features and the output dimension is `M` + + """ + + g = self.pretrained_model.forward(g) + + if self.finetuning_head is not None: + g = self.finetuning_head.forward(g) + + return g + + def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]: + """ + Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model. + The base model is usually identical to the regular model, but with the + layers width divided by a given factor (2 by default) + + Parameter: + divide_factor: Factor by which to divide the width. + + Returns: + Dictionary with the kwargs to create the base model. + """ + kwargs = dict( + pretrained_model_name=self.pretrained_model_name, + pretrained_model_kwargs=None, + finetuning_head_kwargs=None, + num_inference_to_average=self.num_inference_to_average, + last_layer_is_readout=self.last_layer_is_readout, + name=self.name, + ) + + kwargs["pretrained_model_kwargs"] = self.pretrained_model.make_mup_base_kwargs( + divide_factor=divide_factor + ) + + if self.finetuning_head is not None: + kwargs["finetuning_head_kwargs"] = self.finetuning_head.make_mup_base_kwargs( + divide_factor=divide_factor, factor_in_dim=True + ) + + kwargs["pretrained_overwriting_kwargs"] = self.pretrained_overwriting_kwargs + + return kwargs + + def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], max_edges: Optional[int]) -> None: + r""" + Set the maximum number of nodes and edges for all gnn layers and encoder layers + + Parameters: + max_nodes: Maximum number of nodes in the dataset. + This will be useful for certain architecture, but ignored by others. + + max_edges: Maximum number of edges in the dataset. + This will be useful for certain architecture, but ignored by others. + """ + + self.pretrained_model.net.set_max_num_nodes_edges_per_graph(max_nodes, max_edges) + + +class PretrainedModel(nn.Module, MupMixin): + def __init__( + self, + pretrained_model_name: str, + pretrained_model_kwargs: Dict[str, Any], + pretrained_overwriting_kwargs: Dict[str, Any], + ): + r""" + Flexible class allowing to finetune pretrained models from GRAPHIUM_PRETRAINED_MODELS_DICT. 
+ Can be any model that inherits from nn.Module, MupMixin and comes with a module map (e.g., FullGraphMultitaskNetwork) + + Parameters: + + pretrained_model_name: + Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT + + pretrained_model_kwargs: + Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)) + + pretrained_overwriting_kwargs: + Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork + + """ + + super().__init__() + + # Load pretrained model + pretrained_model = PredictorModule.load_from_checkpoint( + GRAPHIUM_PRETRAINED_MODELS_DICT[pretrained_model_name] + ).model + pretrained_model.create_module_map() + + # Initialize new model with architecture after desired modifications to architecture. + net = type(pretrained_model) + self.net = net(**pretrained_model_kwargs) + self.net.create_module_map() + + # Overwrite parameters shared between loaded and modified pretrained model + self.overwrite_with_pretrained(pretrained_model, **pretrained_overwriting_kwargs) + + def forward(self, g: Union[torch.Tensor, Batch]): + g = self.net.forward(g) + + return g + + def overwrite_with_pretrained( + self, + pretrained_model, + finetuning_module: str, + added_depth: int, + sub_module_from_pretrained: str = None, + ): + """ + Overwrite parameters shared between loaded and modified pretrained model + + Parameters: + pretrained_model: Model from GRAPHIUM_PRETRAINED_MODELS_DICT + finetuning_module: Module to finetune from + added_depth: Number of modified layers at the end of finetuning module + sub_module_from_pretrained: Optional submodule to finetune from + """ + module_map = self.net._module_map + module_map_from_pretrained = pretrained_model._module_map + + module_names_from_pretrained = module_map_from_pretrained.keys() + super_module_names_from_pretrained = set( + [module_name.split("/")[0] for module_name in module_names_from_pretrained] + ) + + for module_name in module_map.keys(): + # Below exception handles some modules (e.g., pe_encoders in FullGraphMultitaskNetwork) that do not support len()); + # They can always be replaced entirely + try: + shared_depth = len(module_map[module_name]) + except: + module_map[module_name] = module_map_from_pretrained[module_name] + continue + + if module_name.startswith(finetuning_module): + shared_depth -= added_depth + + if module_name in module_map_from_pretrained.keys(): + for idx in range(shared_depth): + module_map[module_name][idx] = module_map_from_pretrained[module_name][idx] + elif module_name.split("/")[0] in super_module_names_from_pretrained: + for idx in range(shared_depth): + module_map[module_name][idx] = module_map_from_pretrained[ + "".join([module_name.split("/")[0], "/", sub_module_from_pretrained]) + ][idx] + else: + raise RuntimeError("Mismatch between loaded pretrained model and model to be overwritten.") + + if module_name.startswith(finetuning_module): + break + + def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]: + """ + Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model. + The base model is usually identical to the regular model, but with the + layers width divided by a given factor (2 by default) + + Parameter: + divide_factor: Factor by which to divide the width. + factor_in_dim: Whether to factor the input dimension + + Returns: + Dictionary with the kwargs to create the base model. 
+ """ + # For the post-nn network, all the dimension are divided + + return self.net.make_mup_base_kwargs(divide_factor=divide_factor) + + +class FinetuningHead(nn.Module, MupMixin): + def __init__(self, finetuning_head_kwargs: Dict[str, Any]): + r""" + Flexible class allowing to use a custom finetuning head on top of the pretrained model. + Can be any model that inherits from nn.Module, MupMixin. + + Parameters: + + finetuning_head_kwargs: Key-word arguments needed to instantiate a custom (or existing) finetuning head from FINETUNING_HEADS_DICT + + """ + + super().__init__() + self.task = finetuning_head_kwargs.pop("task", None) + self.previous_module = finetuning_head_kwargs.pop("previous_module", "task_heads") + self.incoming_level = finetuning_head_kwargs.pop("incoming_level", "graph") + + model_type = finetuning_head_kwargs.pop("model_type", "mlp") + net = FINETUNING_HEADS_DICT[model_type] + self.net = net(**finetuning_head_kwargs) + + def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torch.Tensor, Batch]): + if isinstance(g, Union[torch.Tensor, Batch]): + pass + elif isinstance(g, Dict) and len(g) == 1: + g = list(g.values())[0] + else: + raise TypeError("Output type from pretrained model not appropriate for finetuning head") + + g = self.net.forward(g) + + return {self.task: g} + + def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]: + """ + Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model. + The base model is usually identical to the regular model, but with the + layers width divided by a given factor (2 by default) + + Parameter: + divide_factor: Factor by which to divide the width. + factor_in_dim: Whether to factor the input dimension + + Returns: + Dictionary with the kwargs to create the base model. + """ + # For the post-nn network, all the dimension are divided + + return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim) diff --git a/graphium/finetuning/utils.py b/graphium/finetuning/utils.py new file mode 100644 index 000000000..605ca536f --- /dev/null +++ b/graphium/finetuning/utils.py @@ -0,0 +1,196 @@ +from typing import Union, List, Dict, Any + +from copy import deepcopy +from loguru import logger +from graphium.trainer import PredictorModule + +from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT + + +def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], names: Union[List[str], str]): + """ + Filter a base config for the full TDC ADMET benchmarking group to only + have settings related to a subset of the endpoints + """ + + if config["datamodule"]["module_type"] != "ADMETBenchmarkDataModule": + # NOTE (cwognum): For now, this implies we only support the ADMET benchmark from TDC. + # It is easy to extend this in the future to support more datasets. 
+ raise ValueError("You can only use this method for the `ADMETBenchmarkDataModule`") + + if isinstance(names, str): + names = [names] + + def _filter(d): + return {k: v for k, v in d.items() if k in names} + + cfg = deepcopy(config) + + # Update the datamodule arguments + cfg["datamodule"]["args"]["tdc_benchmark_names"] = names + + # Filter the relevant config sections + if "architecture" in cfg and "task_heads" in cfg["architecture"]: + cfg["architecture"]["task_heads"] = _filter(cfg["architecture"]["task_heads"]) + if "predictor" in cfg and "metrics_on_progress_bar" in cfg["predictor"]: + cfg["predictor"]["metrics_on_progress_bar"] = _filter(cfg["predictor"]["metrics_on_progress_bar"]) + if "predictor" in cfg and "loss_fun" in cfg["predictor"]: + cfg["predictor"]["loss_fun"] = _filter(cfg["predictor"]["loss_fun"]) + if "metrics" in cfg: + cfg["metrics"] = _filter(cfg["metrics"]) + + return cfg + + +def modify_cfg_for_finetuning(cfg: Dict[str, Any]): + """ + Function combining information from configuration and pretrained model for finetuning. + """ + + task = cfg["finetuning"]["task"] + + # Filter the config based on the task name + # NOTE (cwognum): This prevents the need for having many different files for each of the tasks + # with lots and lots of config repetition. + cfg = filter_cfg_based_on_admet_benchmark_name(cfg, task) + cfg_finetune = cfg["finetuning"] + + # Load pretrained model + pretrained_model_name = cfg_finetune["pretrained_model_name"] + pretrained_predictor = PredictorModule.load_from_checkpoint( + GRAPHIUM_PRETRAINED_MODELS_DICT[pretrained_model_name], device="cpu" + ) + + # Inherit shared configuration from pretrained + # Architecture + pretrained_architecture = pretrained_predictor.model_kwargs + arch_keys = pretrained_architecture.keys() + arch_keys = [key.replace("_kwargs", "") for key in arch_keys] + cfg_arch = {arch_keys[idx]: value for idx, value in enumerate(pretrained_architecture.values())} + cfg_arch_from_pretrained = deepcopy(cfg_arch) + # Featurization + cfg["datamodule"]["args"]["featurization"] = pretrained_predictor.featurization + + finetuning_module = cfg_finetune["finetuning_module"] + sub_module_from_pretrained = cfg_finetune.get("sub_module_from_pretrained", None) + new_sub_module = cfg_finetune.pop("new_sub_module", None) + keep_modules_after_finetuning_module = cfg_finetune.pop("keep_modules_after_finetuning_module", None) + + # Find part of config of module to finetune from + pretrained_predictor.model.create_module_map() + module_map_from_pretrained = pretrained_predictor.model._module_map + + if not any([module.startswith(finetuning_module) for module in module_map_from_pretrained.keys()]): + raise ValueError("Unkown module {finetuning_module}") + elif sub_module_from_pretrained is None: + new_module_kwargs = deepcopy(cfg_arch[finetuning_module]) + else: + new_module_kwargs = deepcopy(cfg_arch[finetuning_module][sub_module_from_pretrained]) + + # Modify config according to desired finetuning architecture + out_dim = ( + cfg_arch[finetuning_module].get("out_dim") + if sub_module_from_pretrained is None + else cfg_arch[finetuning_module][sub_module_from_pretrained].get("out_dim") + ) + + upd_kwargs = { + "out_dim": cfg_finetune.pop("new_out_dim", out_dim), + "depth": new_module_kwargs["depth"] + + cfg_finetune.get("added_depth", 0) + - cfg_finetune.pop("drop_depth", 0), + } + + # Update config + new_module_kwargs.update(upd_kwargs) + + if sub_module_from_pretrained is None: + cfg_arch[finetuning_module] = new_module_kwargs + else: + 
cfg_arch[finetuning_module] = {new_sub_module: new_module_kwargs} + + # Remove modules of pretrained model after module to finetune from unless specified differently + module_list = list(module_map_from_pretrained.keys()) + super_module_list = [] + for module in module_list: + if module.split("/")[0] not in super_module_list: # Only add each supermodule once + super_module_list.append(module.split("/")[0]) + + # Set configuration of modules after finetuning module to None + cutoff_idx = ( + super_module_list.index(finetuning_module) + 1 + ) # Index of module after module to finetune from + for module in super_module_list[cutoff_idx:]: + cfg_arch[module] = None + + # If desired, we can keep specific modules after the finetuning module (specified in cfg/finetuning/keep_modules_after_finetuning_module) + if keep_modules_after_finetuning_module is not None: + for module_name, updates in keep_modules_after_finetuning_module.items(): + cfg_arch = update_cfg_arch_for_module(cfg_arch, cfg_arch_from_pretrained, module_name, updates) + + # Change architecture to FullGraphFinetuningNetwork + cfg_arch["model_type"] = "FullGraphFinetuningNetwork" + + cfg["architecture"] = cfg_arch + + pretrained_overwriting_kwargs = deepcopy(cfg["finetuning"]) + drop_keys = [ + "task", + "level", + "finetuning_head", + "unfreeze_pretrained_depth", + "epoch_unfreeze_all", + ] + + for key in drop_keys: + pretrained_overwriting_kwargs.pop(key, None) + + finetuning_training_kwargs = deepcopy(cfg["finetuning"]) + drop_keys = ["task", "level", "pretrained_model_name", "sub_module_from_pretrained", "finetuning_head"] + for key in drop_keys: + finetuning_training_kwargs.pop(key, None) + + cfg["finetuning"].update( + {"overwriting_kwargs": pretrained_overwriting_kwargs, "training_kwargs": finetuning_training_kwargs} + ) + + return cfg + + +def update_cfg_arch_for_module( + cfg_arch: Dict[str, Any], + cfg_arch_from_pretrained: Dict[str, Any], + module_name: str, + updates: Dict[str, Any], +): + """ + Function to modify the key-word arguments of modules after the finetuning module if they are kept. 
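+    Only called from `modify_cfg_for_finetuning` when `keep_modules_after_finetuning_module`
+    is specified in the finetuning configuration.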
+ + Parameters: + cfg_arch: Configuration of the architecture of the model used for finetuning + cfg_arch_from_pretrained: Configuration of the architecture of the loaded pretrained model + module_name: Module of loaded pretrained model + updates: Changes to apply to key-work arguments of selected module + """ + # We need to distinguish between modules with & without submodules + if "/" not in module_name: + if cfg_arch[module_name] is None: + cfg_arch[module_name] = {} + + cfg_arch_from_pretrained[module_name].update({key: value for key, value in updates.items()}) + + cfg_arch.update({module_name, cfg_arch_from_pretrained}) + + else: + module_name, sub_module = module_name.split("/") + new_sub_module = updates.pop("new_sub_module", sub_module) + + if cfg_arch[module_name] is None: + cfg_arch[module_name] = {} + + cfg_arch_from_pretrained[module_name][sub_module].update( + {key: value for key, value in updates.items()} + ) + cfg_arch[module_name].update({new_sub_module: cfg_arch_from_pretrained[module_name][sub_module]}) + + return cfg_arch diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 42478efa9..08ccddb0a 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -1,10 +1,12 @@ from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type from torch_geometric.data import Batch from graphium.ipu.to_dense_batch import to_dense_batch +from loguru import logger # Misc imports import inspect from copy import deepcopy +from collections import OrderedDict # Torch imports from torch import Tensor, nn @@ -272,6 +274,28 @@ def _create_layers(self): if ii < len(residual_out_dims): this_in_dim = residual_out_dims[ii] + def drop_layers(self, depth: int) -> None: + r""" + Remove the last layers of the model part. + """ + + assert depth >= 0 + assert depth <= len(self.layers) + + if depth > 0: + self.layers = self.layers[:-depth] + + def add_layers(self, layers: int) -> None: + r""" + Add layers to the end of the model. + """ + assert isinstance(layers, nn.ModuleList) + assert len(layers) > 0 + if len(self.layers) > 0: + assert layers[0].in_dim == self.layers[-1].out_dim + + self.layers.extend(layers) + def forward(self, h: torch.Tensor) -> torch.Tensor: r""" Apply the neural network on the input features. @@ -1118,6 +1142,43 @@ def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu): self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id ) + def create_module_map(self): + """ + Function to create mapping between each (sub)module name and corresponding nn.ModuleList() (if possible); + Used for finetuning when (partially) loading or freezing specific modules of the pretrained model + """ + self._module_map = OrderedDict() + + if self.encoder_manager is not None: + self._module_map.update( + {"pe_encoders": self.encoder_manager} + ) # could be extended to submodules, e.g. 
pe_encoders/la_pos/linear_in/..., etc.; not necessary for current finetuning + + if self.pre_nn is not None: + self._module_map.update({"pre_nn": self.pre_nn.layers}) + + if self.pre_nn_edges is not None: + self._module_map.update({"pre_nn_edges": self.pre_nn_edges.layers}) + + # No need to check for NoneType as GNN module is not optional in FullGraphMultitaskNetwork + self._module_map.update({"gnn": self.gnn.layers}) + + if self.task_heads is not None: + self._module_map.update( + { + "graph_output_nn/" + + output_level: self.task_heads.graph_output_nn[output_level].graph_output_nn.layers + for output_level in self.task_heads.graph_output_nn.keys() + } + ) + + self._module_map.update( + { + "task_heads/" + task_head_name: self.task_heads.task_heads[task_head_name].layers + for task_head_name in self.task_heads.task_heads.keys() + } + ) + def forward(self, g: Batch) -> Tensor: r""" Apply the pre-processing neural network, the graph neural network, diff --git a/graphium/nn/utils.py b/graphium/nn/utils.py index 1b10d6497..68a8779c4 100644 --- a/graphium/nn/utils.py +++ b/graphium/nn/utils.py @@ -47,4 +47,8 @@ def scale_kwargs(self, scale_factor: Real, scale_in_dim: bool = False): try: return self.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=scale_in_dim) except TypeError as e: - raise "This error may have been caused by passing scale_in_dim to scale_kwargs for a class that does not support passing factor_in_dim to make_mup_base_kwargs, which cannot be done" from e + raise RuntimeError( + "This error may have been caused by passing scale_in_dim to scale_kwargs " + "for a class that does not support passing factor_in_dim to make_mup_base_kwargs, " + "which cannot be done" + ) from e diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py index 71551d61e..3f2d2e676 100644 --- a/graphium/trainer/predictor.py +++ b/graphium/trainer/predictor.py @@ -1,5 +1,6 @@ from graphium.trainer.metrics import MetricWrapper from typing import Dict, List, Any, Union, Any, Callable, Tuple, Type, Optional +from collections import OrderedDict import numpy as np from copy import deepcopy import time @@ -18,9 +19,7 @@ from graphium.utils.moving_average_tracker import MovingAverageTracker from graphium.utils.tensor import dict_tensor_fp16_to_fp32 -GRAPHIUM_PRETRAINED_MODELS = { - "graphium-zinc-micro-dummy-test": "gcs://graphium-public/pretrained-models/graphium-zinc-micro-dummy-test/model.ckpt" -} +from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT class PredictorModule(lightning.LightningModule): @@ -29,7 +28,9 @@ def __init__( model_class: Type[nn.Module], model_kwargs: Dict[str, Any], loss_fun: Dict[str, Union[str, Callable]], + task_levels: Dict[str, str], random_seed: int = 42, + featurization: Dict[str, str] = None, optim_kwargs: Optional[Dict[str, Any]] = None, torch_scheduler_kwargs: Optional[Dict[str, Any]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, @@ -71,6 +72,8 @@ def __init__( self.target_nan_mask = target_nan_mask self.multitask_handling = multitask_handling + self.task_levels = task_levels + self.featurization = featurization self.task_norms = task_norms super().__init__() @@ -99,19 +102,16 @@ def __init__( self._eval_options_dict: Dict[str, EvalOptions] = eval_options self._eval_options_dict = { - self._get_task_key( - task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key - ): value + self._get_task_key(task_level=task_levels[key], task=key): value for key, value in self._eval_options_dict.items() } # Setting the 
flag options
         self._flag_options = FlagOptions(flag_kwargs=flag_kwargs)
 
         self.model = self._model_options.model_class(**self._model_options.model_kwargs)
+
         loss_fun = {
-            self._get_task_key(
-                task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
-            ): value
+            self._get_task_key(task_level=task_levels[key], task=key): value
             for key, value in loss_fun.items()
         }
         self.tasks = list(loss_fun.keys())
@@ -314,9 +314,7 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool)
         preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
         preds = {
-            self._get_task_key(
-                task_level=self.model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
-            ): value
+            self._get_task_key(task_level=self.task_levels[key], task=key): value
             for key, value in preds.items()
         }
         # preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
@@ -668,10 +666,10 @@ def __repr__(self) -> str:
     @staticmethod
     def list_pretrained_models():
         """List available pretrained models."""
-        return GRAPHIUM_PRETRAINED_MODELS
+        return GRAPHIUM_PRETRAINED_MODELS_DICT
 
     @staticmethod
-    def load_pretrained_models(name: str):
+    def load_pretrained_models(name: str, device: str = None):
         """Load a pretrained model from its name.
 
         Args:
@@ -679,12 +677,14 @@ def load_pretrained_models(name: str):
                 from `graphium.trainer.PredictorModule.list_pretrained_models()`.
         """
 
-        if name not in GRAPHIUM_PRETRAINED_MODELS:
+        if name not in GRAPHIUM_PRETRAINED_MODELS_DICT:
             raise ValueError(
-                f"The model '{name}' is not available. Choose from {set(GRAPHIUM_PRETRAINED_MODELS.keys())}."
+                f"The model '{name}' is not available. Choose from {set(GRAPHIUM_PRETRAINED_MODELS_DICT.keys())}."
             )
 
-        return PredictorModule.load_from_checkpoint(GRAPHIUM_PRETRAINED_MODELS[name])
+        return PredictorModule.load_from_checkpoint(
+            GRAPHIUM_PRETRAINED_MODELS_DICT[name], map_location=device
+        )
 
     def set_max_nodes_edges_per_graph(self, datamodule: BaseDataModule, stages: Optional[List[str]] = None):
         datamodule.setup()
diff --git a/graphium/utils/read_file.py b/graphium/utils/read_file.py
index d11d6501d..09035cb00 100644
--- a/graphium/utils/read_file.py
+++ b/graphium/utils/read_file.py
@@ -53,7 +53,7 @@ def read_file(filepath, as_ext=None, **kwargs):
     else:
         file_ext = as_ext
         if not isinstance(file_ext, str):
-            raise "`file_type` must be a `str`. Provided: {}".format(file_ext)
+            raise TypeError("`file_type` must be a `str`. Provided: {}".format(file_ext))
 
     open_mode = "r"
diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py
index 4a2cf41fc..6641294eb 100644
--- a/graphium/utils/spaces.py
+++ b/graphium/utils/spaces.py
@@ -4,6 +4,7 @@
 import torchmetrics.functional as TorchMetrics
 import graphium.nn.base_layers as BaseLayers
+from graphium.nn.architectures import FeedForwardNN, FeedForwardPyg, TaskHeads
 import graphium.utils.custom_lr as CustomLR
 import graphium.data.datamodule as Datamodules
 import graphium.ipu.ipu_losses as IPULosses
@@ -123,3 +124,9 @@
     "ADMETBenchmarkDataModule": Datamodules.ADMETBenchmarkDataModule,
     "FakeDataModule": Datamodules.FakeDataModule,
 }
+
+GRAPHIUM_PRETRAINED_MODELS_DICT = {
+    "dummy-pretrained-model": "tests/dummy-pretrained-model.ckpt",  # dummy model (to be deleted later)
+}
+
+FINETUNING_HEADS_DICT = {"mlp": FeedForwardNN, "gnn": FeedForwardPyg, "task_head": TaskHeads}
diff --git a/mkdocs.yml b/mkdocs.yml
index be93eb803..e0759cebe 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -36,8 +36,7 @@ nav:
     - Using GNN layers: tutorials/gnn/using_gnn_layers.ipynb
   - model_training:
     - Simple Molecular Model: tutorials/model_training/simple-molecular-model.ipynb
-    - Training on IPU: tutorials/model_training/ipu_training.ipynb
-    - Chosing the right parallelization: tutorials/basics/choosing_parallelization.ipynb
+    - Training on IPU: tutorials/model_training/running-multitask-ipu.ipynb
   - Design: design.md
   - Datasets: datasets.md
   - Pretrained Models: pretrained_models.md
diff --git a/notebooks/finetuning-on-tdc-admet-benchmark.ipynb b/notebooks/finetuning-on-tdc-admet-benchmark.ipynb
index 3c42c0349..43eb47081 100644
--- a/notebooks/finetuning-on-tdc-admet-benchmark.ipynb
+++ b/notebooks/finetuning-on-tdc-admet-benchmark.ipynb
@@ -146,7 +146,14 @@
 "    \n",
 "    # Initialize the predictor\n",
 "    predictor = load_predictor(\n",
-"        cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms\n",
+"        cfg,\n",
+"        model_class,\n",
+"        model_kwargs,\n",
+"        metrics,\n",
+"        datamodule.get_task_levels(),\n",
+"        accelerator_type,\n",
+"        datamodule.featurization,\n",
+"        datamodule.task_norms\n",
 "    )\n",
 "    \n",
 "    # Initialize the trainer\n",
diff --git a/pyproject.toml b/pyproject.toml
index 7b9302931..9e55eb5f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,7 +63,8 @@ dependencies = [
 ]
 
 [project.scripts]
-graphium = "graphium.cli:main_cli"
+graphium = "graphium.cli.main:main_cli"
+graphium-train = "graphium.cli.train_finetune:cli"
 
 [project.urls]
 Website = "https://graphium.datamol.io/"
@@ -79,7 +80,7 @@ fallback_version = "dev"
 
 [tool.setuptools.packages.find]
 where = ["."]
-include = ["graphium", "graphium.*"]
+include = ["graphium", "graphium.*", "expts", "expts.*"]
 exclude = []
 namespaces = true
diff --git a/tests/dummy-pretrained-model.ckpt b/tests/dummy-pretrained-model.ckpt
new file mode 100644
index 000000000..b1312cffa
Binary files /dev/null and b/tests/dummy-pretrained-model.ckpt differ
diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
new file mode 100644
index 000000000..7eeaf5f3f
--- /dev/null
+++ b/tests/test_finetuning.py
@@ -0,0 +1,227 @@
+import os
+from os.path import dirname, abspath
+
+import unittest as ut
+
+import torch
+from copy import deepcopy
+
+from lightning.pytorch.callbacks import Callback
+
+from omegaconf import OmegaConf
+import graphium
+
+from graphium.finetuning import modify_cfg_for_finetuning
+from graphium.trainer import PredictorModule
+
+from graphium.finetuning import GraphFinetuning
+
+from graphium.config._loader import (
+    load_datamodule,
+    load_metrics,
+    load_architecture,
+    load_predictor,
+    load_trainer,
+    save_params_to_wandb,
+    load_accelerator,
+)
+
+
+MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
+CONFIG_FILE = "graphium/config/dummy_finetuning.yaml"
+
+os.chdir(MAIN_DIR)
+
+
+class Test_Finetuning(ut.TestCase):
+    def test_finetuning_pipeline(self):
+        # Skip test if PyTDC package not installed
+        try:
+            import tdc
+        except ImportError:
+            self.skipTest("PyTDC needs to be installed to run this test. Use `pip install PyTDC`.")
+
+        ##################################################
+        ### Test modification of config for finetuning ###
+        ##################################################
+
+        cfg = graphium.load_config(name="dummy_finetuning")
+        cfg = OmegaConf.to_container(cfg, resolve=True)
+
+        cfg = modify_cfg_for_finetuning(cfg)
+
+        # Initialize the accelerator
+        cfg, accelerator_type = load_accelerator(cfg)
+
+        # Load and initialize the dataset
+        datamodule = load_datamodule(cfg, accelerator_type)
+
+        # Initialize the network
+        model_class, model_kwargs = load_architecture(
+            cfg,
+            in_dims=datamodule.in_dims,
+        )
+
+        datamodule.prepare_data()
+
+        metrics = load_metrics(cfg)
+
+        predictor = load_predictor(
+            cfg,
+            model_class,
+            model_kwargs,
+            metrics,
+            datamodule.get_task_levels(),
+            accelerator_type,
+            datamodule.featurization,
+            datamodule.task_norms,
+        )
+
+        self.assertEqual(
+            len(
+                predictor.model.pretrained_model.net.task_heads.task_heads["lipophilicity_astrazeneca"].layers
+            ),
+            3,
+        )
+        self.assertEqual(
+            predictor.model.pretrained_model.net.task_heads.task_heads["lipophilicity_astrazeneca"].out_dim, 8
+        )
+        self.assertEqual(predictor.model.finetuning_head.net.in_dim, 8)
+        self.assertEqual(len(predictor.model.finetuning_head.net.layers), 2)
+        self.assertEqual(predictor.model.finetuning_head.net.out_dim, 1)
+
+        ################################################
+        ### Test overwriting with pretrained weights ###
+        ################################################
+
+        # Load pretrained & replace in predictor
+        pretrained_model = PredictorModule.load_pretrained_models(
+            cfg["finetuning"]["pretrained_model_name"], device="cpu"
+        ).model
+
+        pretrained_model.create_module_map()
+        module_map_from_pretrained = deepcopy(pretrained_model._module_map)
+        module_map = deepcopy(predictor.model.pretrained_model.net._module_map)
+
+        # Finetuning module has only been partially overwritten
+        cfg_finetune = cfg["finetuning"]
+        finetuning_module = "".join([cfg_finetune["finetuning_module"], "/", cfg_finetune["task"]])
+        finetuning_module_from_pretrained = "".join(
+            [cfg_finetune["finetuning_module"], "/", cfg_finetune["sub_module_from_pretrained"]]
+        )
+
+        pretrained_layers = module_map[finetuning_module]
+        overwritten_layers = module_map_from_pretrained[finetuning_module_from_pretrained]
+
+        for idx, (pretrained, overwritten) in enumerate(zip(pretrained_layers, overwritten_layers)):
+            if idx < 1:
+                assert torch.equal(pretrained.linear.weight, overwritten.linear.weight)
+                assert torch.equal(pretrained.linear.bias, overwritten.linear.bias)
+            else:
+                assert not torch.equal(pretrained.linear.weight, overwritten.linear.weight)
+                assert not torch.equal(pretrained.linear.bias, overwritten.linear.bias)
+
+            if idx + 1 == min(len(pretrained_layers), len(overwritten_layers)):
+                break
+
+        _ = module_map.popitem(last=True)
+        overwritten_modules = module_map.values()
+
+        _ = module_map_from_pretrained.popitem(last=True)
+        pretrained_modules = module_map_from_pretrained.values()
+
+        for overwritten_module, pretrained_module in zip(overwritten_modules, pretrained_modules):
+            for overwritten, pretrained in zip(
+                overwritten_module.parameters(), pretrained_module.parameters()
+            ):
+                assert torch.equal(overwritten.data, pretrained.data)
+
+        #################################################
+        ### Test correct (un)freezing during training ###
+        #################################################
+
+        # Define test callback that checks for correct (un)freezing
+        class TestCallback(Callback):
+            def __init__(self, cfg):
+                super().__init__()
+
+                self.cfg_finetune = cfg["finetuning"]
+
+            def on_train_epoch_start(self, trainer, pl_module):
+                module_map = pl_module.model.pretrained_model.net._module_map
+
+                finetuning_module = "".join(
+                    [self.cfg_finetune["finetuning_module"], "/", self.cfg_finetune["task"]]
+                )
+                training_depth = self.cfg_finetune["added_depth"] + self.cfg_finetune.pop(
+                    "unfreeze_pretrained_depth", 0
+                )
+
+                frozen_parameters, unfrozen_parameters = [], []
+
+                if trainer.current_epoch == 0:
+                    frozen = True
+
+                    for module_name, module in module_map.items():
+                        if module_name == finetuning_module:
+                            # After the finetuning module, all parameters are unfrozen
+                            frozen = False
+
+                            frozen_parameters.extend(
+                                [
+                                    parameter.requires_grad
+                                    for parameter in module[:-training_depth].parameters()
+                                ]
+                            )
+                            unfrozen_parameters.extend(
+                                [
+                                    parameter.requires_grad
+                                    for parameter in module[-training_depth:].parameters()
+                                ]
+                            )
+                            continue
+
+                        if frozen:
+                            frozen_parameters.extend(
+                                [parameter.requires_grad for parameter in module.parameters()]
+                            )
+                        else:
+                            unfrozen_parameters.extend(
+                                [parameter.requires_grad for parameter in module.parameters()]
+                            )
+
+                    # Finetuning head is always unfrozen
+                    unfrozen_parameters.extend(
+                        [
+                            parameter.requires_grad
+                            for parameter in pl_module.model.finetuning_head.parameters()
+                        ]
+                    )
+
+                    assert not True in frozen_parameters
+                    assert not False in unfrozen_parameters
+
+                if trainer.current_epoch == 2:
+                    # All parameter are unfrozen starting from epoch_unfreeze_all
+                    unfrozen_parameters = [
+                        parameter.requires_grad for parameter in pl_module.model.parameters()
+                    ]
+
+                    assert not False in unfrozen_parameters
+
+        trainer = load_trainer(cfg, accelerator_type)
+
+        finetuning_training_kwargs = cfg["finetuning"]["training_kwargs"]
+        trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs))
+
+        # Add test callback to trainer
+        trainer.callbacks.append(TestCallback(cfg))
+
+        predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])
+
+        # Run the model training
+        trainer.fit(model=predictor, datamodule=datamodule)
+
+
+if __name__ == "__main__":
+    ut.main()
diff --git a/tests/test_ipu_dataloader.py b/tests/test_ipu_dataloader.py
index 8aedcbf86..a5882d9e8 100644
--- a/tests/test_ipu_dataloader.py
+++ b/tests/test_ipu_dataloader.py
@@ -219,7 +219,14 @@ def test_poptorch_graphium_deviceiterations_gradient_accumulation_full(self):
         model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims)
         # datamodule.setup()
         predictor = load_predictor(
-            cfg, model_class, model_kwargs, metrics, accelerator, datamodule.task_norms
+            cfg,
+            model_class,
+            model_kwargs,
+            metrics,
+            datamodule.get_task_levels(),
+            accelerator,
+            datamodule.featurization,
+            datamodule.task_norms,
         )
         assert poptorch.ipuHardwareIsAvailable()
         trainer = load_trainer(cfg, "test", accelerator, "date_time_suffix")