diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 39d51baff..af85c8616 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11"] pytorch-version: ["2.0"] runs-on: "ubuntu-latest" @@ -52,6 +52,9 @@ jobs: - name: Install test dependencies run: micromamba install -c conda-forge pytdc # Required to run the `test_finetuning.py` + - name: Install C++ library + run: cd graphium/graphium_cpp && git clone https://github.com/pybind/pybind11.git && export PYTHONPATH=$PYTHONPATH:./pybind11 && python -m pip install . && cd ../.. + - name: Run tests run: pytest -m 'not ipu' diff --git a/LICENSE b/LICENSE index 4cef7c9e1..cbca6ebfd 100644 --- a/LICENSE +++ b/LICENSE @@ -189,6 +189,7 @@ Copyright 2023 Valence Labs Copyright 2023 Recursion Pharmaceuticals Copyright 2023 Graphcore Limited + Copyright 2024 NVIDIA CORPORATION & AFFILIATES Various Academic groups have also contributed to this software under the given license. These include, but are not limited, to the following diff --git a/docs/api/graphium.features.md b/docs/api/graphium.features.md index 758d14135..fa9080700 100644 --- a/docs/api/graphium.features.md +++ b/docs/api/graphium.features.md @@ -5,37 +5,8 @@ Feature extraction and manipulation === "Contents" * [Featurizer](#featurizer) - * [Positional Encoding](#positional-encoding) - * [Properties](#properties) - * [Spectral PE](#spectral-pe) - * [Random Walk PE](#random-walk-pe) - * [NMP](#nmp) ## Featurizer ------------ ::: graphium.features.featurizer - -## Positional Encoding ------------- -::: graphium.features.positional_encoding - - -## Properties ------------- -::: graphium.features.properties - - -## Spectral PE ------------- -::: graphium.features.spectral - - -## Random Walk PE ------------- -::: graphium.features.rw - - -## NMP ------------- -::: graphium.features.nmp diff --git a/env.yml b/env.yml index 64169f3a8..aa8ff8eb5 100644 --- a/env.yml +++ b/env.yml @@ -28,7 +28,7 @@ dependencies: - gcsfs >=2021.6 # ML packages - - cuda-version # works also with CPU-only system. + - cuda-version == 11.2 # works also with CPU-only system. - pytorch >=1.12 - lightning >=2.0 - torchmetrics @@ -43,6 +43,7 @@ dependencies: # chemistry - rdkit - datamol >=0.10 + - boost # needed by rdkit # Optional deps - sympy diff --git a/expts/configs/config_gps_10M_pcqm4m.yaml b/expts/configs/config_gps_10M_pcqm4m.yaml index 10faa3b1e..0487a8d04 100644 --- a/expts/configs/config_gps_10M_pcqm4m.yaml +++ b/expts/configs/config_gps_10M_pcqm4m.yaml @@ -59,7 +59,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: @@ -76,10 +75,6 @@ datamodule: split_test: 0.1 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -115,7 +110,6 @@ datamodule: num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long.
- featurization_backend: "loky" architecture: diff --git a/expts/configs/config_gps_10M_pcqm4m_mod.yaml b/expts/configs/config_gps_10M_pcqm4m_mod.yaml index e2cdb44c2..19543302b 100644 --- a/expts/configs/config_gps_10M_pcqm4m_mod.yaml +++ b/expts/configs/config_gps_10M_pcqm4m_mod.yaml @@ -8,7 +8,6 @@ constants: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: @@ -25,10 +24,6 @@ datamodule: split_test: 0.1 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -84,7 +79,6 @@ datamodule: num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" # ipu_dataloader_training_opts: # mode: async diff --git a/expts/configs/config_mpnn_10M_b3lyp.yaml b/expts/configs/config_mpnn_10M_b3lyp.yaml index c385d7689..424dbcd71 100644 --- a/expts/configs/config_mpnn_10M_b3lyp.yaml +++ b/expts/configs/config_mpnn_10M_b3lyp.yaml @@ -60,7 +60,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" betagap: @@ -88,12 +87,7 @@ datamodule: split_test: 0.1 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/b3lyp/" - dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -127,7 +121,6 @@ datamodule: num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/configs/config_mpnn_pcqm4m.yaml b/expts/configs/config_mpnn_pcqm4m.yaml index 9735f9555..70972d370 100644 --- a/expts/configs/config_mpnn_pcqm4m.yaml +++ b/expts/configs/config_mpnn_pcqm4m.yaml @@ -8,7 +8,6 @@ constants: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: @@ -26,12 +25,7 @@ datamodule: split_names: ["train", "valid", "test-dev"] # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 20 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "graphium/data/PCQM4Mv2/" - dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -61,7 +55,6 @@ datamodule: num_workers: 40 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" # ipu_dataloader_training_opts: # mode: async diff --git a/expts/hydra-configs/architecture/largemix.yaml b/expts/hydra-configs/architecture/largemix.yaml index 32efef778..f1f494157 100644 --- a/expts/hydra-configs/architecture/largemix.yaml +++ b/expts/hydra-configs/architecture/largemix.yaml @@ -83,12 +83,7 @@ architecture: datamodule: module_type: "MultitaskFromSmilesDataModule" args: - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 20 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: ${constants.datacache_path} - dataloading_from: "disk" num_workers: 20 # -1 to use all persistent_workers: True featurization: diff --git a/expts/hydra-configs/architecture/pcqm4m.yaml b/expts/hydra-configs/architecture/pcqm4m.yaml index 494875765..f3fc04b63 100644 --- a/expts/hydra-configs/architecture/pcqm4m.yaml +++ b/expts/hydra-configs/architecture/pcqm4m.yaml @@ -81,13 +81,8 @@ architecture: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: ${constants.datacache_path} num_workers: 40 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. 
diff --git a/expts/hydra-configs/architecture/toymix.yaml b/expts/hydra-configs/architecture/toymix.yaml index a62b839cd..f4ae5a5db 100644 --- a/expts/hydra-configs/architecture/toymix.yaml +++ b/expts/hydra-configs/architecture/toymix.yaml @@ -74,12 +74,7 @@ architecture: datamodule: module_type: "MultitaskFromSmilesDataModule" args: - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: ${constants.datacache_path} - dataloading_from: ram num_workers: 30 # -1 to use all persistent_workers: False featurization: diff --git a/expts/hydra-configs/finetuning/admet_baseline.yaml b/expts/hydra-configs/finetuning/admet_baseline.yaml index 410d0dd64..6f9fc1c93 100644 --- a/expts/hydra-configs/finetuning/admet_baseline.yaml +++ b/expts/hydra-configs/finetuning/admet_baseline.yaml @@ -20,7 +20,6 @@ constants: datamodule: args: batch_size_training: 32 - dataloading_from: ram persistent_workers: true num_workers: 4 diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml index d5b302dd1..8dcf2c0c4 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml @@ -25,7 +25,6 @@ metrics: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: diff --git a/expts/hydra-configs/training/accelerator/largemix_cpu.yaml b/expts/hydra-configs/training/accelerator/largemix_cpu.yaml index 6f5e0606a..ea83fdf58 100644 --- a/expts/hydra-configs/training/accelerator/largemix_cpu.yaml +++ b/expts/hydra-configs/training/accelerator/largemix_cpu.yaml @@ -4,7 +4,6 @@ datamodule: args: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 20 num_workers: 20 predictor: diff --git a/expts/hydra-configs/training/accelerator/largemix_gpu.yaml b/expts/hydra-configs/training/accelerator/largemix_gpu.yaml index ac728c982..17ac12ad8 100644 --- a/expts/hydra-configs/training/accelerator/largemix_gpu.yaml +++ b/expts/hydra-configs/training/accelerator/largemix_gpu.yaml @@ -7,7 +7,6 @@ datamodule: args: batch_size_training: 2048 batch_size_inference: 2048 - featurization_n_jobs: 6 num_workers: 6 predictor: diff --git a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml index 9022eeb84..f81662285 100644 --- a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml @@ -4,7 +4,6 @@ datamodule: args: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 4 num_workers: 4 predictor: diff --git a/expts/hydra-configs/training/accelerator/toymix_gpu.yaml b/expts/hydra-configs/training/accelerator/toymix_gpu.yaml index c2c8e4066..ac4e48c26 100644 --- a/expts/hydra-configs/training/accelerator/toymix_gpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_gpu.yaml @@ -7,7 +7,6 @@ datamodule: args: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 4 num_workers: 4 predictor: diff --git a/expts/neurips2023_configs/base_config/large.yaml b/expts/neurips2023_configs/base_config/large.yaml index 8a836f368..18503527c 100644 --- a/expts/neurips2023_configs/base_config/large.yaml +++ 
b/expts/neurips2023_configs/base_config/large.yaml @@ -62,7 +62,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -133,11 +132,6 @@ datamodule: epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/base_config/large_pcba.yaml b/expts/neurips2023_configs/base_config/large_pcba.yaml index f90675e73..a1e3d108f 100644 --- a/expts/neurips2023_configs/base_config/large_pcba.yaml +++ b/expts/neurips2023_configs/base_config/large_pcba.yaml @@ -62,7 +62,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -132,11 +131,6 @@ datamodule: #epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml b/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml index 1fac9176b..b71c43cf2 100644 --- a/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml +++ b/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml @@ -62,7 +62,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -132,11 +131,6 @@ datamodule: # epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml b/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml index f9a9e58b8..464e49581 100644 --- a/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml +++ b/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml @@ -62,7 +62,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" @@ -132,11 +131,6 @@ datamodule: epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/base_config/small.yaml b/expts/neurips2023_configs/base_config/small.yaml index fd7ce3fbe..4914fdda3 100644 --- a/expts/neurips2023_configs/base_config/small.yaml +++ b/expts/neurips2023_configs/base_config/small.yaml @@ -51,7 +51,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -97,10 +96,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml index 7b2d2cbdf..e107fa386 100644 --- a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -96,10 +95,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -134,7 +129,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/config_classifigression_l1000.yaml b/expts/neurips2023_configs/config_classifigression_l1000.yaml index 48f06d9d1..fb77ad457 100644 --- a/expts/neurips2023_configs/config_classifigression_l1000.yaml +++ b/expts/neurips2023_configs/config_classifigression_l1000.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -76,10 +75,6 @@ datamodule: splits_path: graphium/data/neurips2023/small-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 1 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -114,7 +109,6 @@ datamodule: num_workers: 5 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/config_large_gcn_gpu.yaml b/expts/neurips2023_configs/config_large_gcn_gpu.yaml index 2830530aa..31a02e22c 100644 --- a/expts/neurips2023_configs/config_large_gcn_gpu.yaml +++ b/expts/neurips2023_configs/config_large_gcn_gpu.yaml @@ -49,7 +49,6 @@ datamodule: df_path: expts/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet splits_path: expts/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt` - featurization_n_jobs: 4 # 30 processed_graph_data_path: "../datacache/neurips2023-small/" num_workers: 4 # 30 diff --git a/expts/neurips2023_configs/config_luis_jama.yaml b/expts/neurips2023_configs/config_luis_jama.yaml index 5135c5cae..e0549e0f0 100644 --- a/expts/neurips2023_configs/config_luis_jama.yaml +++ b/expts/neurips2023_configs/config_luis_jama.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -84,10 +83,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -122,7 +117,6 @@ datamodule: num_workers: 4 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
- featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/config_small_gcn_gpu.yaml b/expts/neurips2023_configs/config_small_gcn_gpu.yaml index 8b5a46e26..ccad70af6 100644 --- a/expts/neurips2023_configs/config_small_gcn_gpu.yaml +++ b/expts/neurips2023_configs/config_small_gcn_gpu.yaml @@ -41,7 +41,6 @@ datamodule: zinc: df_path: expts/data/neurips2023/small-dataset/ZINC12k.csv.gz splits_path: expts/data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt` - featurization_n_jobs: 4 # 30 processed_graph_data_path: "../datacache/neurips2023-small/" num_workers: 4 # 30 diff --git a/expts/neurips2023_configs/debug/config_debug.yaml b/expts/neurips2023_configs/debug/config_debug.yaml index 3d31e5e8c..21a8c30b2 100644 --- a/expts/neurips2023_configs/debug/config_debug.yaml +++ b/expts/neurips2023_configs/debug/config_debug.yaml @@ -51,7 +51,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -70,10 +69,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 0 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml index ec05bf6eb..236673699 100644 --- a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml @@ -60,7 +60,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -131,10 +130,6 @@ datamodule: epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml index 26b50756f..773ca8814 100644 --- a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml @@ -40,7 +40,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -84,10 +83,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml index e05d1be8d..1aba37eb4 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_mcf7: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/mcf7/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml index cf924850e..77837d750 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcba_1328: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcba/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
- featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml index f1c9bcfd4..1c021a559 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/vcap/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml index 01988e527..bd09385f5 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -68,10 +67,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/g25/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -106,7 +101,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml index fdeb4b399..5abf9790d 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" l1000_mcf7: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/mcf7/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml index 5920a80f6..834967498 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_n4: @@ -69,10 +68,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/n4/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -107,7 +102,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml index de2f7fbc4..f390a7a2b 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcba_1328: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcba/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. 
# Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml index ca820e86b..d13a757f3 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -83,10 +82,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcq/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -121,7 +116,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml index c21b765b3..75f802926 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/vcap/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml index b88314797..02679153c 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -68,10 +67,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/g25/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -106,7 +101,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml index b96fc8daf..0506dbfea 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_mcf7: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/mcf7/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml index e98ae03da..58bad3bbc 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_n4: @@ -69,10 +68,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/n4/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -107,7 +102,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
- featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml index 427f7ca0f..3ce9ffde2 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcba_1328: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcba/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml index 07fc6d009..d541b9b04 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -83,10 +82,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcq/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -121,7 +116,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml index b63263b3d..121d74ddb 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml @@ -50,7 +50,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. 
task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -65,10 +64,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/vcap/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +98,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index dc1da4998..704c9029c 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -204,8 +204,6 @@ def load_architecture( architecture: The datamodule used to process and load the data """ - if isinstance(config, dict) and "finetuning" not in config: - config = omegaconf.OmegaConf.create(config) cfg_arch = config["architecture"] # Select the architecture @@ -263,10 +261,6 @@ def load_architecture( else: gnn_kwargs.setdefault("in_dim", edge_in_dim) - # Set the parameters for the full network - if "finetuning" not in config: - task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs) - # Set all the input arguments for the model model_kwargs = dict( gnn_kwargs=gnn_kwargs, diff --git a/graphium/config/dummy_finetuning_from_gnn.yaml b/graphium/config/dummy_finetuning_from_gnn.yaml index ca9493d30..4de1e79bc 100644 --- a/graphium/config/dummy_finetuning_from_gnn.yaml +++ b/graphium/config/dummy_finetuning_from_gnn.yaml @@ -128,10 +128,6 @@ datamodule: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 0 num_workers: 0 - prepare_dict_or_graph: pyg:graph - featurization_progress: True - featurization_backend: "loky" persistent_workers: False \ No newline at end of file diff --git a/graphium/config/dummy_finetuning_from_task_head.yaml b/graphium/config/dummy_finetuning_from_task_head.yaml index 2682ccee3..90b0d5341 100644 --- a/graphium/config/dummy_finetuning_from_task_head.yaml +++ b/graphium/config/dummy_finetuning_from_task_head.yaml @@ -134,12 +134,8 @@ datamodule: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 0 num_workers: 0 - prepare_dict_or_graph: pyg:graph - featurization_progress: True - featurization_backend: "loky" persistent_workers: False diff --git a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml index 044a0129c..a34399dd1 100644 --- a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml @@ -45,8 +45,6 @@ datamodule: weights_col: null # This may not always be provided # Featurization - featurization_n_jobs: 16 - featurization_progress: True featurization: atom_property_list_onehot: ["atomic-number", "degree"] atom_property_list_float: [] diff --git a/graphium/config/fake_multilevel_multitask_pyg.yaml b/graphium/config/fake_multilevel_multitask_pyg.yaml index 918807cb4..3cce7b5e2 100644 --- a/graphium/config/fake_multilevel_multitask_pyg.yaml +++ 
b/graphium/config/fake_multilevel_multitask_pyg.yaml @@ -45,8 +45,6 @@ datamodule: weights_col: null # This may not always be provided # Featurization - featurization_n_jobs: 16 - featurization_progress: True featurization: atom_property_list_onehot: ["atomic-number", "degree"] atom_property_list_float: [] diff --git a/graphium/config/zinc_default_multitask_pyg.yaml b/graphium/config/zinc_default_multitask_pyg.yaml index b9435ec7e..01d20bc53 100644 --- a/graphium/config/zinc_default_multitask_pyg.yaml +++ b/graphium/config/zinc_default_multitask_pyg.yaml @@ -45,8 +45,6 @@ datamodule: weights_type: null # Featurization - featurization_n_jobs: 16 - featurization_progress: True featurization: atom_property_list_onehot: ["atomic-number", "degree"] atom_property_list_float: [] diff --git a/graphium/data/__init__.py b/graphium/data/__init__.py index 0a8fcd24d..b18cda421 100644 --- a/graphium/data/__init__.py +++ b/graphium/data/__init__.py @@ -6,8 +6,5 @@ from .datamodule import GraphOGBDataModule from .datamodule import MultitaskFromSmilesDataModule from .datamodule import ADMETBenchmarkDataModule -from .datamodule import FakeDataModule -from .dataset import SingleTaskDataset from .dataset import MultitaskDataset -from .dataset import FakeDataset diff --git a/graphium/data/collate.py b/graphium/data/collate.py index 22486b034..cab3151de 100644 --- a/graphium/data/collate.py +++ b/graphium/data/collate.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -22,15 +22,15 @@ from typing import Union, List, Optional, Dict, Type, Any, Iterable from torch_geometric.data import Data, Batch -from graphium.features import GraphDict, to_dense_array from graphium.utils.packing import fast_packing, get_pack_sizes, node_to_pack_indices_mask from loguru import logger from graphium.data.utils import get_keys +from graphium.data.dataset import torch_enum_to_dtype def graphium_collate_fn( elements: Union[List[Any], Dict[str, List[Any]]], - labels_size_dict: Optional[Dict[str, Any]] = None, + labels_num_cols_dict: Optional[Dict[str, Any]] = None, labels_dtype_dict: Optional[Dict[str, Any]] = None, mask_nan: Union[str, float, Type[None]] = "raise", do_not_collate_keys: List[str] = [], @@ -52,7 +52,7 @@ def graphium_collate_fn( elements: The elements to batch. See `torch.utils.data.dataloader.default_collate`. - labels_size_dict: + labels_num_cols_dict: (Note): This is an attribute of the `MultitaskDataset`. A dictionary of the form Dict[tasks, sizes] which has task names as keys and the size of the label tensor as value. The size of the tensor corresponds to how many @@ -86,25 +86,34 @@ def graphium_collate_fn( The batched elements. See `torch.utils.data.dataloader.default_collate`. 
""" + # Skip any elements that failed + if None in elements: + elements = [e for e in elements if e is not None] + elem = elements[0] if isinstance(elem, Mapping): + if "features" in elem: + num_nodes = [d["features"].num_nodes for d in elements] + num_edges = [d["features"].num_edges for d in elements] + else: + num_nodes = [d["num_nodes"] for d in elements] + num_edges = [d["num_edges"] for d in elements] + batch = {} for key in elem: # Multitask setting: We have to pad the missing labels if key == "labels": labels = [d[key] for d in elements] - batch[key] = collate_labels(labels, labels_size_dict, labels_dtype_dict) - - # If the features are a dictionary containing GraphDict elements, - # Convert to pyg graphs and use the pyg batching. - elif isinstance(elem[key], GraphDict): - pyg_graphs = [d[key].make_pyg_graph(mask_nan=mask_nan) for d in elements] - batch[key] = collage_pyg_graph(pyg_graphs) + batch[key] = collate_labels( + labels, labels_num_cols_dict, labels_dtype_dict, num_nodes, num_edges + ) + elif key == "num_nodes" or key == "num_edges": + continue # If a PyG Graph is provided, use the PyG batching elif isinstance(elem[key], Data): pyg_graphs = [d[key] for d in elements] - batch[key] = collage_pyg_graph(pyg_graphs, batch_size_per_pack=batch_size_per_pack) + batch[key] = collage_pyg_graph(pyg_graphs, num_nodes, batch_size_per_pack=batch_size_per_pack) # Ignore the collate for specific keys elif key in do_not_collate_keys: @@ -125,42 +134,33 @@ def graphium_collate_fn( return default_collate(elements) -def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_size_per_pack: Optional[int] = None): +def collage_pyg_graph( + pyg_graphs: List[Data], num_nodes: List[int], batch_size_per_pack: Optional[int] = None +): """ Function to collate pytorch geometric graphs. Convert all numpy types to torch Convert edge indices to int64 Parameters: - pyg_graphs: Iterable of PyG graphs + pyg_graphs: List of PyG graphs batch_size_per_pack: The number of graphs to pack together. This is useful for using packing with the Transformer, """ # Calculate maximum number of nodes per graph in current batch - num_nodes_list = [] - for pyg_graph in pyg_graphs: - num_nodes_list.append(pyg_graph["num_nodes"]) - max_num_nodes_per_graph = max(num_nodes_list) + max_num_nodes_per_graph = max(num_nodes) - pyg_batch = [] for pyg_graph in pyg_graphs: for pyg_key in get_keys(pyg_graph): - tensor = pyg_graph[pyg_key] - - # Convert numpy/scipy to Pytorch - if isinstance(tensor, (ndarray, spmatrix)): - tensor = torch.as_tensor(to_dense_array(tensor, tensor.dtype)) - # pad nodepair-level positional encodings if pyg_key.startswith("nodepair_"): - pyg_graph[pyg_key] = pad_nodepairs(tensor, pyg_graph["num_nodes"], max_num_nodes_per_graph) - else: - pyg_graph[pyg_key] = tensor + pyg_graph[pyg_key] = pad_nodepairs( + pyg_graph[pyg_key], pyg_graph.num_nodes, max_num_nodes_per_graph + ) # Convert edge index to int64 pyg_graph.edge_index = pyg_graph.edge_index.to(torch.int64) - pyg_batch.append(pyg_graph) # Apply the packing at the mini-batch level. This is useful for using packing with the Transformer, # especially in the case of the large graphs being much larger than the small graphs. 
@@ -170,87 +170,68 @@ def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_size_per_pa raise NotImplementedError( "Packing is not yet functional, as it changes the order of the graphs in the batch without changing the label order" ) - num_nodes = [g.num_nodes for g in pyg_batch] packed_graph_idx = fast_packing(num_nodes, batch_size_per_pack) # Get the node to pack indices and the mask pack_from_node_idx, pack_attn_mask = node_to_pack_indices_mask(packed_graph_idx, num_nodes) - for pyg_graph in pyg_batch: + for pyg_graph in pyg_graphs: pyg_graph.pack_from_node_idx = pack_from_node_idx pyg_graph.pack_attn_mask = pack_attn_mask - return Batch.from_data_list(pyg_batch) + return Batch.from_data_list(pyg_graphs) -def pad_to_expected_label_size(labels: torch.Tensor, label_size: List[int]): +def pad_to_expected_label_size(labels: torch.Tensor, label_rows: int, label_cols: int): """Determine difference of ``labels`` shape to expected shape `label_size` and pad with ``torch.nan`` accordingly. """ - if label_size == list(labels.shape): + if len(labels.shape) == 2 and label_rows == labels.shape[0] and label_cols == labels.shape[1]: return labels - missing_dims = len(label_size) - len(labels.shape) + missing_dims = 2 - len(labels.shape) for _ in range(missing_dims): labels.unsqueeze(-1) - pad_sizes = [(0, expected - actual) for expected, actual in zip(label_size, labels.shape)] - pad_sizes = [item for before_after in pad_sizes for item in before_after] - pad_sizes.reverse() + pad_sizes = [label_cols - labels.shape[1], 0, label_rows - labels.shape[0], 0] if any([s < 0 for s in pad_sizes]): - logger.warning(f"More labels available than expected. Will remove data to fit expected size.") + logger.warning( + f"More labels available than expected. Will remove data to fit expected size. cols: {labels.shape[1]}->{label_cols}, rows: {labels.shape[0]}->{label_rows}" + ) return torch.nn.functional.pad(labels, pad_sizes, value=torch.nan) -def collate_pyg_graph_labels(pyg_labels: List[Data]): - """ - Function to collate pytorch geometric labels. - Convert all numpy types to torch - - Parameters: - pyg_labels: Iterable of PyG label Data objects - """ - pyg_batch = [] - for pyg_label in pyg_labels: - for pyg_key in set(get_keys(pyg_label)) - set(["x", "edge_index"]): - tensor = pyg_label[pyg_key] - # Convert numpy/scipy to Pytorch - if isinstance(tensor, (ndarray, spmatrix)): - tensor = torch.as_tensor(to_dense_array(tensor, tensor.dtype)) - - pyg_label[pyg_key] = tensor - - pyg_batch.append(pyg_label) - - return Batch.from_data_list(pyg_batch) - - -def get_expected_label_size(label_data: Data, task: str, label_size: List[int]): +def get_expected_label_rows(label_data: Data, task: str, num_nodes: int, num_edges: int): """Determines expected label size based on the specfic graph properties and the number of targets in the task-dataset. 
""" if task.startswith("graph_"): num_labels = 1 elif task.startswith("node_"): - num_labels = label_data.x.size(0) + num_labels = num_nodes elif task.startswith("edge_"): - num_labels = label_data.edge_index.size(1) + num_labels = num_edges elif task.startswith("nodepair_"): raise NotImplementedError() - return [num_labels] + label_size + else: + print("Task name " + task + " in get_expected_label_rows") + raise NotImplementedError() + return num_labels def collate_labels( labels: List[Data], - labels_size_dict: Optional[Dict[str, Any]] = None, + labels_num_cols_dict: Optional[Dict[str, Any]] = None, labels_dtype_dict: Optional[Dict[str, Any]] = None, + num_nodes: List[int] = None, + num_edges: List[int] = None, ): """Collate labels for multitask learning. Parameters: labels: List of labels - labels_size_dict: Dict of the form Dict[tasks, sizes] which has task names as keys + labels_num_cols_dict: Dict of the form Dict[tasks, sizes] which has task names as keys and the size of the label tensor as value. The size of the tensor corresponds to how many labels/values there are to predict for that task. labels_dtype_dict: @@ -260,25 +241,21 @@ def collate_labels( Returns: A dictionary of the form Dict[tasks, labels] where tasks is the name of the task and labels - is a tensor of shape (batch_size, *labels_size_dict[task]). + is a tensor of shape (batch_size, *labels_num_cols_dict[task]). """ - if labels_size_dict is not None: - for this_label in labels: - for task in labels_size_dict.keys(): - labels_size_dict[task] = list(labels_size_dict[task]) - if len(labels_size_dict[task]) >= 2: - labels_size_dict[task] = labels_size_dict[task][1:] - elif not task.startswith("graph_"): - labels_size_dict[task] = [1] + if labels_num_cols_dict is not None: + for index, this_label in enumerate(labels): label_keys_set = set(get_keys(this_label)) - empty_task_labels = set(labels_size_dict.keys()) - label_keys_set + empty_task_labels = set(labels_num_cols_dict.keys()) - label_keys_set for task in empty_task_labels: - labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) - dtype = labels_dtype_dict[task] - this_label[task] = torch.full([*labels_size_dict[task]], torch.nan, dtype=dtype) + label_rows = get_expected_label_rows(this_label, task, num_nodes[index], num_edges[index]) + dtype = torch_enum_to_dtype(labels_dtype_dict[task]) + this_label[task] = torch.full( + (label_rows, labels_num_cols_dict[task]), fill_value=torch.nan, dtype=dtype + ) for task in label_keys_set - set(["x", "edge_index"]) - empty_task_labels: - labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) + label_rows = get_expected_label_rows(this_label, task, num_nodes[index], num_edges[index]) if not isinstance(this_label[task], (torch.Tensor)): this_label[task] = torch.as_tensor(this_label[task]) @@ -286,21 +263,23 @@ def collate_labels( # Ensure explicit task dimension also for single task labels if len(this_label[task].shape) == 1: # Distinguish whether target dim or entity dim is missing - if labels_size_dict[task][0] == this_label[task].shape[0]: + if label_rows == this_label[task].shape[0]: # num graphs/nodes/edges/nodepairs already matching this_label[task] = this_label[task].unsqueeze(1) else: # data lost unless entity dim is supposed to be 1 - if labels_size_dict[task][0] == 1: + if label_rows == 1: this_label[task] = this_label[task].unsqueeze(0) else: raise ValueError( - f"Labels for {labels_size_dict[task][0]} nodes/edges/nodepairs expected, got 1." 
+ f"Labels for {label_rows} nodes/edges/nodepairs expected, got 1." ) - this_label[task] = pad_to_expected_label_size(this_label[task], labels_size_dict[task]) + this_label[task] = pad_to_expected_label_size( + this_label[task], label_rows, labels_num_cols_dict[task] + ) - return collate_pyg_graph_labels(labels) + return Batch.from_data_list(labels) def pad_nodepairs(pe: torch.Tensor, num_nodes: int, max_num_nodes_per_graph: int): diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 4e89f6728..e2368d3b3 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -51,27 +51,22 @@ from torch.utils.data.dataloader import DataLoader, Dataset from torch.utils.data import Subset +from rdkit import RDLogger + from graphium.utils import fs -from graphium.features import ( - mol_to_graph_dict, - GraphDict, - mol_to_pyggraph, -) +from graphium.features import mol_to_pyggraph from graphium.data.sampler import DatasetSubSampler from graphium.data.utils import graphium_package_path, found_size_mismatch from graphium.utils.arg_checker import check_arg_iterator from graphium.utils.hashing import get_md5_hash -from graphium.data.smiles_transform import ( - did_featurization_fail, - BatchingSmilesTransform, - smiles_to_unique_mol_ids, -) from graphium.data.collate import graphium_collate_fn import graphium.data.dataset as Datasets from graphium.data.normalization import LabelNormalization from graphium.data.multilevel_utils import extract_labels +import graphium_cpp + torch.multiprocessing.set_sharing_strategy("file_system") @@ -153,7 +148,6 @@ def __init__( self._predict_ds = None self._data_is_prepared = False - self._data_is_cached = False def prepare_data(self): raise NotImplementedError() @@ -790,8 +784,7 @@ class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier): def __init__( self, task_specific_args: Union[Dict[str, DatasetProcessingParams], Dict[str, Any]], - processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, - dataloading_from: str = "ram", + processed_graph_data_path: Union[str, os.PathLike], featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -800,12 +793,8 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = False, multiprocessing_context: Optional[str] = None, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - featurization_backend: str = "loky", - featurization_batch_size: int = 1000, collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", + 
preprocessing_n_jobs: int = 0, **kwargs, ): """ @@ -821,33 +810,27 @@ def __init__( - `df_path` - `smiles_col` - `label_cols` - dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data - must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data - will be loaded in RAM and the `processed_graph_data_path` will be ignored. featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. num_workers: Number of workers for the dataloader. Use -1 to use all available cores. pin_memory: Whether to pin on paginated CPU memory for the dataloader. - featurization_n_jobs: Number of cores to use for the featurization. - featurization_progress: whether to show a progress bar during featurization. - featurization_backend: The backend to use for the molecular featurization. - "multiprocessing": Found to cause less memory issues. - "loky": joblib's Default. Found to cause memory leaks. - "threading": Found to be slow. - featurization_batch_size: Batch size to use for the featurization. collate_fn: A custom torch collate function. Default is to `graphium.data.graphium_collate_fn` - prepare_dict_or_graph: Whether to preprocess all molecules as Graph dict or PyG graphs. - Possible options: + preprocessing_n_jobs: Number of threads to use during preprocessing. + Use 0 to use all available cores, or -1 to use all but one core. - - "pyg:dict": Process molecules as a `dict`. It's faster and requires less RAM during - pre-processing. It is slower during training with with `num_workers=0` since - pyg `Data` will be created during data-loading, but faster with large - `num_workers`, and less likely to cause memory issues with the parallelization. - - "pyg:graph": Process molecules as `pyg.data.Data`. + dataloading_from: Deprecated. Behaviour now always matches previous "disk" option. + featurization_n_jobs: Deprecated. + featurization_progress: Deprecated. + featurization_backend: Deprecated. + featurization_batch_size: Deprecated. + prepare_dict_or_graph: Deprecated. Behaviour now always matches previous "pyg:graph" option. 
""" BaseDataModule.__init__( self, @@ -878,26 +861,17 @@ def __init__( task: self.task_dataset_processing_params[task].epoch_sampling_fraction for task in self.task_dataset_processing_params.keys() } - - self.featurization_n_jobs = featurization_n_jobs - self.featurization_progress = featurization_progress - self.featurization_backend = featurization_backend - self.featurization_batch_size = featurization_batch_size + self.task_names = [task for task in self.task_dataset_processing_params.keys()] self.task_train_indices = None self.task_val_indices = None self.task_test_indices = None - self.single_task_datasets = None - self.train_singletask_datasets = None - self.val_singletask_datasets = None - self.test_singletask_datasets = None - self.train_ds = None self.val_ds = None self.test_ds = None - self._parse_caching_args(processed_graph_data_path, dataloading_from) + self._parse_caching_args(processed_graph_data_path) self.task_norms = {} @@ -906,42 +880,66 @@ def __init__( self.featurization = featurization - # Whether to transform the smiles into a pyg `Data` graph or a dictionary compatible with pyg - if prepare_dict_or_graph == "pyg:dict": - self.smiles_transformer = partial(mol_to_graph_dict, **featurization) - elif prepare_dict_or_graph == "pyg:graph": - self.smiles_transformer = partial(mol_to_pyggraph, **featurization) - else: - raise ValueError( - f"`prepare_dict_or_graph` should be either 'pyg:dict' or 'pyg:graph', Provided: `{prepare_dict_or_graph}`" + # Copy featurization for the representation used by graphium_cpp + encoded_featurization = deepcopy(featurization) + self.encoded_featurization = encoded_featurization + + def encode_feature_options(options, name, encoding_function): + if name not in options or options[name] is None: + options[name] = torch.tensor(data=[], dtype=torch.int64) + else: + options[name] = encoding_function(options[name]) + + encode_feature_options( + encoded_featurization, + "atom_property_list_onehot", + graphium_cpp.atom_onehot_feature_names_to_tensor, + ) + encode_feature_options( + encoded_featurization, "atom_property_list_float", graphium_cpp.atom_float_feature_names_to_tensor + ) + encode_feature_options( + encoded_featurization, "edge_property_list", graphium_cpp.bond_feature_names_to_tensor + ) + + if ( + "pos_encoding_as_features" in featurization + and featurization["pos_encoding_as_features"] is not None + and featurization["pos_encoding_as_features"]["pos_types"] is not None + ): + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + featurization["pos_encoding_as_features"]["pos_types"] ) + else: + pos_encoding_names = [] + pos_encoding_tensor = torch.tensor(data=[], dtype=torch.int64) + encoded_featurization["pos_encoding_as_features"] = (pos_encoding_names, pos_encoding_tensor) + + explicit_H = featurization["explicit_H"] if "explicit_H" in featurization else False + add_self_loop = featurization["add_self_loop"] if "add_self_loop" in featurization else False + + # Save these for calling graphium_cpp.prepare_and_save_data later + self.add_self_loop = add_self_loop + self.explicit_H = explicit_H + + self.preprocessing_n_jobs = preprocessing_n_jobs + + self.smiles_transformer = partial(mol_to_pyggraph, **encoded_featurization) self.data_hash = self.get_data_hash() - if self.processed_graph_data_path is not None: - if self._ready_to_load_all_from_file(): - self._data_is_prepared = True - self._data_is_cached = True + if self._ready_to_load_all_from_file(): + self._data_is_prepared = True - def 
_parse_caching_args(self, processed_graph_data_path, dataloading_from): + def _parse_caching_args(self, processed_graph_data_path): """ Parse the caching arguments, and raise errors if the arguments are invalid. """ - # Whether to load the data from RAM or from disk - dataloading_from = dataloading_from.lower() - if dataloading_from not in ["disk", "ram"]: - raise ValueError( - f"`dataloading_from` should be either 'disk' or 'ram', Provided: `{dataloading_from}`" - ) - # If loading from disk, the path to the cached data must be provided - if dataloading_from == "disk" and processed_graph_data_path is None: - raise ValueError( - "When `dataloading_from` is 'disk', `processed_graph_data_path` must be provided." - ) + if processed_graph_data_path is None: + raise ValueError("`processed_graph_data_path` must be provided.") self.processed_graph_data_path = processed_graph_data_path - self.dataloading_from = dataloading_from def _get_task_key(self, task_level: str, task: str): task_prefix = f"{task_level}_" @@ -959,7 +957,27 @@ def get_task_levels(self): return task_level_map - def prepare_data(self, save_smiles_and_ids: bool = False): + @staticmethod + def concat_smiles_tensor_index(): + return 0 + + @staticmethod + def smiles_offsets_tensor_index(): + return 1 + + @staticmethod + def num_nodes_tensor_index(): + return 2 + + @staticmethod + def num_edges_tensor_index(): + return 3 + + @staticmethod + def data_offsets_tensor_index(): + return 4 + + def prepare_data(self): """Called only from a single process in distributed settings. Steps: - If each cache is set and exists, reload from cache and return. Otherwise, @@ -970,30 +988,54 @@ def prepare_data(self, save_smiles_and_ids: bool = False): - In the previous step, we were also able to get the unique smiles, which we use to compute the features - For each single-task dataframe and associated data (smiles, labels, etc.): - Filter out the data corresponding to molecules which failed featurization. - - Create a corresponding SingletaskDataset - - Split the SingletaskDataset according to the task-specific splits for train, val and test + - Split the dataset according to the task-specific splits for train, val and test """ - def has_atoms_after_h_removal(smiles): - # Remove all 'H' characters from the SMILES - smiles_without_h = re.sub("H", "", smiles) - # Check if any letters are remaining in the modified string - has_atoms = bool(re.search("[a-zA-Z]", smiles_without_h)) - if has_atoms == False: - logger.info(f"Removed Hydrogen molecule: {smiles}") - return has_atoms + # Don't log error messages from SMILES parsing in RDKit. 
+ # Common error messages were: + # WARNING: not removing hydrogen atom without neighbors + # SMILES Parse Error: syntax error while parsing: restricted + # SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted' + RDLogger.DisableLog("rdApp.*") + + for task, args in self.task_dataset_processing_params.items(): + if args.label_normalization is None: + args.label_normalization = {} + label_normalization = LabelNormalization(**args.label_normalization) + self.task_norms[task] = label_normalization if self._data_is_prepared: logger.info("Data is already prepared.") - self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) + self.label_num_cols, self.label_dtypes = graphium_cpp.load_num_cols_and_dtypes( + self.processed_graph_data_path, self.data_hash + ) + self.stage_data = { + "train": graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "train", self.data_hash + ), + "val": graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "val", self.data_hash + ), + "test": graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "test", self.data_hash + ), + } + if len(self.label_num_cols) > 0: + for task in self.task_dataset_processing_params.keys(): + stats = graphium_cpp.load_stats(self.processed_graph_data_path, self.data_hash, task) + if len(stats) < 4: + raise RuntimeError(f'Error loading cached stats for task "{task}"') + + self.task_norms[task].set_statistics(stats[0], stats[1], stats[2], stats[3]) return + task_dataset_args = {} + self.task_train_indices = {} + self.task_val_indices = {} + self.task_test_indices = {} + """Load all single-task dataframes.""" - task_df = {} for task, args in self.task_dataset_processing_params.items(): - if args.label_normalization is None: - args.label_normalization = {} - label_normalization = LabelNormalization(**args.label_normalization) logger.info(f"Reading data for task '{task}'") if args.df is None: # Only load the useful columns, as some datasets can be very large when loading all columns. 
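Note (illustrative sketch, not part of the patch): the metadata tensors loaded above for each stage are indexed with the static `*_tensor_index()` helpers, mirroring their later use in `_make_multitask_dataset`. Assuming `dm` is an already-prepared `MultitaskFromSmilesDataModule`, and with the tensor contents described in the comments being assumptions:

    stage_data = dm.stage_data["train"]
    smiles_tensor = stage_data[dm.concat_smiles_tensor_index()]    # concatenated SMILES data (assumed)
    smiles_offsets = stage_data[dm.smiles_offsets_tensor_index()]  # start offset of each molecule's SMILES (assumed)
    num_nodes = stage_data[dm.num_nodes_tensor_index()]            # per-molecule node counts
    num_edges = stage_data[dm.num_edges_tensor_index()]            # per-molecule edge counts
    # The data-offsets tensor is optional, hence the same length check used later in this file
    data_offsets = None
    if dm.data_offsets_tensor_index() < len(stage_data):
        data_offsets = stage_data[dm.data_offsets_tensor_index()]  # offsets into the cached label data (assumed)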
@@ -1007,24 +1049,18 @@ def has_atoms_after_h_removal(smiles): + check_arg_iterator(args.weights_col, enforce_type=list) ) label_dtype = {col: np.float32 for col in label_cols} - task_df[task] = self._read_table(args.df_path, usecols=usecols, dtype=label_dtype) + df = self._read_table(args.df_path, usecols=usecols, dtype=label_dtype) else: label_cols = self._parse_label_cols( df=args.df, df_path=None, label_cols=args.label_cols, smiles_col=args.smiles_col ) - task_df[task] = args.df - task_df[task] = task_df[task] + df = args.df + args.label_cols = label_cols - self.task_norms[task] = label_normalization - logger.info("Done reading datasets") - """Subsample the data frames and extract the necessary data to create SingleTaskDatasets for each task (smiles, labels, extras).""" - task_dataset_args = {} - for task in task_df.keys(): - task_dataset_args[task] = {} + """Subsample the data frames and extract the necessary data for each task (smiles, labels, extras).""" - for task, df in task_df.items(): # Subsample all the dataframes sample_size = self.task_dataset_processing_params[task].sample_size df = self._sub_sample_df(df, sample_size, self.task_dataset_processing_params[task].seed) @@ -1036,7 +1072,7 @@ def has_atoms_after_h_removal(smiles): logger.info("Filtering done") # Extract smiles, labels, extras args = self.task_dataset_processing_params[task] - smiles, labels, sample_idx, extras = self._extract_smiles_labels( + smiles, labels, label_offsets, sample_idx, extras = self._extract_smiles_labels( df, task_level=args.task_level, smiles_col=args.smiles_col, @@ -1046,125 +1082,74 @@ def has_atoms_after_h_removal(smiles): weights_type=args.weights_type, ) - # Store the relevant information for each task's dataset - task_dataset_args[task]["smiles"] = smiles - task_dataset_args[task]["labels"] = labels - task_dataset_args[task]["sample_idx"] = sample_idx - task_dataset_args[task]["extras"] = extras - - """Convert SMILES to features (graphs, fingerprints, etc.) for the unique molecules found.""" - all_smiles = [] - all_tasks = [] - idx_per_task = {} - total_len = 0 - for task, dataset_args in task_dataset_args.items(): - all_smiles.extend(dataset_args["smiles"]) - num_smiles = len(dataset_args["smiles"]) - idx_per_task[task] = (total_len, total_len + num_smiles) - total_len += num_smiles - for count in range(len(dataset_args["smiles"])): - all_tasks.append(task) - # Get all unique mol ids - all_unique_mol_ids = smiles_to_unique_mol_ids( - all_smiles, - n_jobs=self.featurization_n_jobs, - featurization_batch_size=self.featurization_batch_size, - backend=self.featurization_backend, - ) - _, unique_ids_idx, unique_ids_inv = np.unique( - all_unique_mol_ids, return_index=True, return_inverse=True - ) + num_molecules = len(smiles) - smiles_to_featurize = [all_smiles[ii] for ii in unique_ids_idx] - - # Convert SMILES to features - features, _ = self._featurize_molecules(smiles_to_featurize) - - # Store the features (including Nones, which will be filtered in the next step) - for task in task_dataset_args.keys(): - task_dataset_args[task]["features"] = [] - task_dataset_args[task]["idx_none"] = [] - # Create a list of features matching up with the original smiles - all_features = [features[unique_idx] for unique_idx in unique_ids_inv] - - # Add the features to the task-specific data - for all_idx, task in enumerate(all_tasks): - task_dataset_args[task]["features"].append(all_features[all_idx]) - - """Filter data based on molecules which failed featurization. 
Create single task datasets as well.""" - self.single_task_datasets = {} - for task, args in task_dataset_args.items(): - # Find out which molecule failed featurization, and filter them out - idx_none = [] - for idx, (feat, labels, smiles) in enumerate( - zip(args["features"], args["labels"], args["smiles"]) - ): - if did_featurization_fail(feat) or found_size_mismatch(task, feat, labels, smiles): - idx_none.append(idx) - this_unique_ids = all_unique_mol_ids[idx_per_task[task][0] : idx_per_task[task][1]] - df, features, smiles, labels, sample_idx, extras, this_unique_ids = self._filter_none_molecules( - idx_none, - task_df[task], - args["features"], - args["smiles"], - args["labels"], - args["sample_idx"], - args["extras"], - this_unique_ids, - ) - task_dataset_args[task]["smiles"] = smiles - task_dataset_args[task]["labels"] = labels - task_dataset_args[task]["features"] = features - task_dataset_args[task]["sample_idx"] = sample_idx - task_dataset_args[task]["extras"] = extras - - # We have the necessary components to create single-task datasets. - self.single_task_datasets[task] = Datasets.SingleTaskDataset( - features=task_dataset_args[task]["features"], - labels=task_dataset_args[task]["labels"], - smiles=task_dataset_args[task]["smiles"], - unique_ids=this_unique_ids, - indices=task_dataset_args[task]["sample_idx"], - **task_dataset_args[task]["extras"], - ) + # Clear the reference to the DataFrame, so that Python can free up the memory. + df = None - """We split the data up to create train, val and test datasets""" - self.task_train_indices = {} - self.task_val_indices = {} - self.task_test_indices = {} + # Store the relevant information for each task's dataset + task_dataset_args[task] = { + "smiles": smiles, + "extras": extras, + } + if args.label_cols != 0: + task_dataset_args[task]["labels"] = labels + task_dataset_args[task]["label_offsets"] = label_offsets + + """We split the data up to create train, val and test datasets""" - for task, df in task_df.items(): train_indices, val_indices, test_indices = self._get_split_indices( - len(df), + num_molecules, split_val=self.task_dataset_processing_params[task].split_val, split_test=self.task_dataset_processing_params[task].split_test, split_seed=self.task_dataset_processing_params[task].seed, splits_path=self.task_dataset_processing_params[task].splits_path, split_names=self.task_dataset_processing_params[task].split_names, - sample_idx=task_dataset_args[task]["sample_idx"], + # smiles and labels are already sub-sampled, so the split indices need to be + # relative to the sample, not the original. 
+ # sample_idx=task_dataset_args[task]["sample_idx"], ) self.task_train_indices[task] = train_indices self.task_val_indices[task] = val_indices self.task_test_indices[task] = test_indices + logger.info("Done reading datasets") + + # The rest of the data preparation and caching is done in graphium_cpp.prepare_and_save_data + normalizations = { + task: self.task_dataset_processing_params[task].label_normalization + for task in self.task_dataset_processing_params.keys() + } ( - self.train_singletask_datasets, - self.val_singletask_datasets, - self.test_singletask_datasets, - ) = self.get_subsets_of_datasets( - self.single_task_datasets, self.task_train_indices, self.task_val_indices, self.task_test_indices + self.stage_data, + all_stats, + self.label_num_cols, + self.label_dtypes, + ) = graphium_cpp.prepare_and_save_data( + self.task_names, + task_dataset_args, + normalizations, + self.processed_graph_data_path, + self.data_hash, + self.task_train_indices, + self.task_val_indices, + self.task_test_indices, + self.add_self_loop, + self.explicit_H, + self.preprocessing_n_jobs, ) - if self.processed_graph_data_path is not None: - self._save_data_to_files(save_smiles_and_ids) - self._data_is_cached = True + for task, stats in all_stats.items(): + if len(stats) < 4: + raise RuntimeError(f'Error loading cached stats for task "{task}"') + + self.task_norms[task].set_statistics(stats[0], stats[1], stats[2], stats[3]) self._data_is_prepared = True def setup( self, stage: str = None, - save_smiles_and_ids: bool = False, ): """ Prepare the torch dataset. Called on every GPUs. Setting state here is ok. @@ -1174,54 +1159,49 @@ def setup( # Can possibly get rid of setup because a single dataset will have molecules exclusively in train, val or test # Produce the label sizes to update the collate function - labels_size = {} - labels_dtype = {} + label_num_cols = {} + label_dtypes = {} if stage == "fit" or stage is None: if self.train_ds is None: - self.train_ds = self._make_multitask_dataset( - self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids - ) + self.train_ds = self._make_multitask_dataset("train") - if self.val_ds is None: - self.val_ds = self._make_multitask_dataset( - self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids - ) + if self.val_ds is None and len(self.stage_data["val"]) >= self.num_edges_tensor_index(): + self.val_ds = self._make_multitask_dataset("val") logger.info(self.train_ds) - logger.info(self.val_ds) - labels_size.update( - self.train_ds.labels_size + label_num_cols.update( + dict(zip(self.train_ds.task_names, self.train_ds.label_num_cols)) ) # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements. 
- labels_size.update(self.val_ds.labels_size) - labels_dtype.update(self.train_ds.labels_dtype) - labels_dtype.update(self.val_ds.labels_dtype) + label_dtypes.update(dict(zip(self.train_ds.task_names, self.train_ds.label_dtypes))) + + if self.val_ds is not None: + logger.info(self.val_ds) + label_num_cols.update(dict(zip(self.val_ds.task_names, self.val_ds.label_num_cols))) + label_dtypes.update(dict(zip(self.val_ds.task_names, self.val_ds.label_dtypes))) if stage == "test" or stage is None: - if self.test_ds is None: - self.test_ds = self._make_multitask_dataset( - self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids - ) + if self.test_ds is None and len(self.stage_data["test"]) >= self.num_edges_tensor_index(): + self.test_ds = self._make_multitask_dataset("test") - logger.info(self.test_ds) + if self.test_ds is not None: + logger.info(self.test_ds) - labels_size.update(self.test_ds.labels_size) - labels_dtype.update(self.test_ds.labels_dtype) + label_num_cols.update(dict(zip(self.test_ds.task_names, self.test_ds.label_num_cols))) + label_dtypes.update(dict(zip(self.test_ds.task_names, self.test_ds.label_dtypes))) - default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None) + default_labels_num_cols_dict = self.collate_fn.keywords.get("labels_num_cols_dict", None) - if default_labels_size_dict is None: - self.collate_fn.keywords["labels_size_dict"] = labels_size + if default_labels_num_cols_dict is None: + self.collate_fn.keywords["labels_num_cols_dict"] = label_num_cols default_labels_dtype_dict = self.collate_fn.keywords.get("labels_dtype_dict", None) if default_labels_dtype_dict is None: - self.collate_fn.keywords["labels_dtype_dict"] = labels_dtype + self.collate_fn.keywords["labels_dtype_dict"] = label_dtypes def _make_multitask_dataset( self, - dataloading_from: Literal["disk", "ram"], stage: Literal["train", "val", "test"], - save_smiles_and_ids: bool, ) -> Datasets.MultitaskDataset: """ Create a MultitaskDataset for the given stage using single task datasets @@ -1229,7 +1209,6 @@ def _make_multitask_dataset( Parameters: stage: Stage to create multitask dataset for - save_smiles_and_ids: Whether to save SMILES strings and unique IDs processed_graph_data_path: path to save and load processed graph data from """ @@ -1237,41 +1216,35 @@ def _make_multitask_dataset( assert stage in allowed_stages, f"Multitask dataset stage `{stage}` not in {allowed_stages}" if stage == "train": - singletask_datasets = self.train_singletask_datasets about = "training set" elif stage == "val": - singletask_datasets = self.val_singletask_datasets about = "validation set" elif stage == "test": - singletask_datasets = self.test_singletask_datasets about = "test set" else: raise ValueError(f"Unknown stage {stage}") processed_graph_data_path = self.processed_graph_data_path + stage_data = self.stage_data[stage] + data_offsets = None + if self.data_offsets_tensor_index() < len(stage_data): + data_offsets = stage_data[self.data_offsets_tensor_index()] + multitask_dataset = Datasets.MultitaskDataset( - singletask_datasets, - n_jobs=self.featurization_n_jobs, - backend=self.featurization_backend, - featurization_batch_size=self.featurization_batch_size, - progress=self.featurization_progress, about=about, - save_smiles_and_ids=save_smiles_and_ids, data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None, - dataloading_from=dataloading_from, - data_is_cached=self._data_is_cached, + featurize_smiles=self.smiles_transformer, + 
task_names=self.task_names, + label_num_cols=self.label_num_cols, + label_dtypes=self.label_dtypes, + mol_file_data_offsets=data_offsets, + concat_smiles_tensor=stage_data[self.concat_smiles_tensor_index()], + smiles_offsets_tensor=stage_data[self.smiles_offsets_tensor_index()], + num_nodes_tensor=stage_data[self.num_nodes_tensor_index()], + num_edges_tensor=stage_data[self.num_edges_tensor_index()], ) # type: ignore - # calculate statistics for the train split and used for all splits normalization - if stage == "train": - self.get_label_statistics( - self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True - ) - # Normalization has already been applied in cached data - if not self._data_is_prepared: - self.normalize_label(multitask_dataset, stage) - return multitask_dataset def _ready_to_load_all_from_file(self) -> bool: @@ -1300,139 +1273,10 @@ def _data_ready_at_path(self, path: str) -> bool: return can_load_from_file - def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None: - """ - Save data to files so that they can be loaded from file during training/validation/test - """ - - stages = ["train", "val", "test"] - - # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save to file - # This is because the combined labels need to be stored together. We can investigate not doing this if this is a problem - temp_datasets = { - stage: self._make_multitask_dataset( - dataloading_from="ram", stage=stage, save_smiles_and_ids=save_smiles_and_ids - ) - for stage in stages - } - for stage in stages: - self.save_featurized_data(temp_datasets[stage], self._path_to_load_from_file(stage)) - temp_datasets[stage].save_metadata(self._path_to_load_from_file(stage)) - # self.train_ds, self.val_ds, self.test_ds will be created during `setup()` - - if self.dataloading_from == "disk": - del temp_datasets - else: - self.train_ds = temp_datasets["train"] - self.val_ds = temp_datasets["val"] - self.test_ds = temp_datasets["test"] - def get_folder_size(self, path): # check if the data items are actually saved into the folders return sum(os.path.getsize(osp.join(path, f)) for f in os.listdir(path)) - def calculate_statistics(self, dataset: Datasets.MultitaskDataset, train: bool = False): - """ - Calculate the statistics of the labels for each task, and overwrites the `self.task_norms` attribute. - - Parameters: - dataset: the dataset to calculate the statistics from - train: whether the dataset is the training set - - """ - - if self.task_norms and train: - for task in dataset.labels_size.keys(): - # if the label type is graph_*, we need to stack them as the tensor shape is (num_labels, ) - if task.startswith("graph"): - labels = np.stack( - np.array([datum["labels"][task] for datum in dataset if task in datum["labels"]]), - axis=0, - ) - # for other tasks with node_ and edge_, the label shape is [num_nodes/num_edges, num_labels] - # we can concatenate them directly - else: - labels = np.concatenate( - [datum["labels"][task] for datum in dataset if task in datum["labels"]], axis=0 - ) - - self.task_norms[task].calculate_statistics(labels) - - def get_label_statistics( - self, - data_path: Union[str, os.PathLike], - data_hash: str, - dataset: Datasets.MultitaskDataset, - train: bool = False, - ): - """ - Get the label statistics from the dataset, and save them to file, if needed. - `self.task_norms` will be modified in-place with the label statistics. - - Parameters: - data_path: the path to save and load the label statistics to. 
If None, no saving and loading will be done. - data_hash: the hash of the dataset generated by `get_data_hash()` - dataset: the dataset to calculate the statistics from - train: whether the dataset is the training set - - """ - if data_path is None: - self.calculate_statistics(dataset, train=train) - else: - path_with_hash = os.path.join(data_path, data_hash) - os.makedirs(path_with_hash, exist_ok=True) - filename = os.path.join(path_with_hash, "task_norms.pkl") - if self.task_norms and train and not os.path.isfile(filename): - self.calculate_statistics(dataset, train=train) - torch.save(self.task_norms, filename, pickle_protocol=4) - # if any of the above three condition does not satisfy, we load from file. - else: - self.task_norms = torch.load(filename) - - def normalize_label(self, dataset: Datasets.MultitaskDataset, stage) -> Datasets.MultitaskDataset: - """ - Normalize the labels in the dataset using the statistics in `self.task_norms`. - - Parameters: - dataset: the dataset to normalize the labels from - - Returns: - the dataset with normalized labels - """ - for task in dataset.labels_size.keys(): - # we normalize the dataset if (it is train split) or (it is val/test splits and normalize_val_test is set to true) - if (stage == "train") or (stage in ["val", "test"] and self.task_norms[task].normalize_val_test): - for i in range(len(dataset)): - if task in dataset[i]["labels"]: - dataset[i]["labels"][task] = self.task_norms[task].normalize( - dataset[i]["labels"][task] - ) - return dataset - - def save_featurized_data(self, dataset: Datasets.MultitaskDataset, processed_data_path): - os.makedirs(processed_data_path) # In case the len(dataset) is 0 - for i in range(0, len(dataset), 1000): - os.makedirs(os.path.join(processed_data_path, format(i // 1000, "04d")), exist_ok=True) - process_params = [(index, datum, processed_data_path) for index, datum in enumerate(dataset)] - - # Check if "about" is in the Dataset object - about = "" - if hasattr(dataset, "about"): - about = dataset.about - for param in tqdm(process_params, desc=f"Saving featurized data {about}"): - self.process_func(param) - return - - def process_func(self, param): - index, datum, folder = param - filename = os.path.join(folder, format(index // 1000, "04d"), format(index, "07d") + ".pkl") - torch.save( - {"graph_with_features": datum["features"], "labels": datum["labels"]}, - filename, - pickle_protocol=4, - ) - return - def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **kwargs) -> Dict[str, Any]: """ Get the options for the dataloader depending on the current stage. @@ -1514,110 +1358,6 @@ def get_collate_fn(self, collate_fn): collate_fn.__name__ = graphium_collate_fn.__name__ return collate_fn - # Cannot be used as is for the multitask version, because sample_idx does not apply. - def _featurize_molecules(self, smiles: Iterable[str]) -> Tuple[List, List]: - """ - Precompute the features (graphs, fingerprints, etc.) from the SMILES. - Features are computed from `self.smiles_transformer`. - A warning is issued to mention which molecules failed featurization. - - Note: - (hadim): in case of very large dataset we could: - - or cache the data and read from it during `next(iter(dataloader))` - - or compute the features on-the-fly during `next(iter(dataloader))` - For now we compute in advance and hold everything in memory. - - Parameters: - smiles: A list of all the molecular SMILES to featurize - sample_idx: The indexes corresponding to the sampled SMILES. - If not provided, computed from `numpy.arange`. 
- - Returns: - features: A list of all the featurized molecules - idx_none: A list of the indexes that failed featurization - """ - - batch_size = BatchingSmilesTransform.parse_batch_size( - numel=len(smiles), - desired_batch_size=self.featurization_batch_size, - n_jobs=self.featurization_n_jobs, - ) - - # Loop all the smiles and compute the features - features = dm.parallelized_with_batches( - BatchingSmilesTransform(self.smiles_transformer), - smiles, - batch_size=batch_size, - progress=True, - n_jobs=self.featurization_n_jobs, - backend=self.featurization_backend, - tqdm_kwargs={"desc": f"featurizing_smiles, batch={batch_size}"}, - ) - - # Warn about None molecules - idx_none = [ii for ii, feat in enumerate(features) if did_featurization_fail(feat)] - if len(idx_none) > 0: - mols_to_msg = [ - f"idx={idx} - smiles={smiles[idx]} - Error_msg[:-200]=\n{str(features[idx])[:-200]}" - for idx in idx_none - ] - msg = "\n".join(mols_to_msg) - logger.warning( - (f"{len(idx_none)} molecules will be removed since they failed featurization:\n" + msg) - ) - - return features, idx_none - - @staticmethod - def _filter_none_molecules( - idx_none: Iterable, - *args: Union[pd.DataFrame, pd.Series, np.ndarray, torch.Tensor, list, tuple, Dict[Any, Iterable]], - ) -> List[Union[pd.DataFrame, pd.Series, np.ndarray, torch.Tensor, list, tuple, Dict[Any, Iterable]]]: - """ - Filter the molecules, labels, etc. for the molecules that failed featurization. - - Parameters: - idx_none: A list of the indexes that failed featurization - args: Any argument from which to filter the failed SMILES. - Can be a `list`, `tuple`, `Tensor`, `np.array`, `Dict`, `pd.DataFrame`, `pd.Series`. - Otherwise, it is not filtered. - WARNING: If a `pd.DataFrame` or `pd.Series` is passed, it filters by the row indexes, - NOT by the `DataFrame.index` or `Series.index`! Be careful! - - Returns: - out: All the `args` with the indexes from `idx_none` removed. 
- """ - if len(idx_none) == 0: - return args - idx_none = np.asarray(idx_none) - - out = [] - for arg in args: - if isinstance(arg, pd.DataFrame): - new = arg.drop(arg.index[idx_none], axis=0) - elif isinstance(arg, pd.Series): - new = arg.drop(arg.index[idx_none], axis=0) - elif isinstance(arg, np.ndarray): - new = np.delete(arg, idx_none, axis=0) - elif isinstance(arg, torch.Tensor): - not_none = torch.ones(arg.shape[0], dtype=bool) - not_none[idx_none] = False - new = arg[not_none] - elif isinstance(arg, (list, tuple)): - arg = list(arg) - new = [elem for ii, elem in enumerate(arg) if ii not in idx_none] - elif isinstance(arg, dict): - new = {} - for key, val in arg.items(): - new[key] = MultitaskFromSmilesDataModule._filter_none_molecules(idx_none, val) # Careful - else: - new = arg - out.append(new) - - out = tuple(out) if len(out) > 1 else out[0] - - return out - def _parse_label_cols( self, df: pd.DataFrame, @@ -1695,8 +1435,6 @@ def in_dims(self): """ graph = self.get_fake_graph() - if isinstance(graph, (GraphDict)): - graph = graph.data # get list of all keys corresponding to positional encoding pe_dim_dict = {} @@ -1735,14 +1473,9 @@ def get_fake_graph(self): return graph ########################## Private methods ###################################### - def _save_to_cache(self): - raise NotImplementedError() - - def _load_from_cache(self): - raise NotImplementedError() + @staticmethod def _extract_smiles_labels( - self, df: pd.DataFrame, task_level: str, smiles_col: Optional[str] = None, @@ -1752,7 +1485,11 @@ def _extract_smiles_labels( weights_col: Optional[str] = None, weights_type: Optional[str] = None, ) -> Tuple[ - np.ndarray, np.ndarray, Union[Type[None], np.ndarray], Dict[str, Union[Type[None], np.ndarray]] + np.ndarray, + np.ndarray, + np.ndarray, + Union[Type[None], np.ndarray], + Dict[str, Union[Type[None], np.ndarray]], ]: """ For a given dataframe extract the SMILES and labels columns. Smiles is returned as a list @@ -1767,7 +1504,7 @@ def _extract_smiles_labels( weights_col: Name of the column containing the weights weights_type: Type of weights to use. Returns: - smiles, labels, sample_idx, extras + smiles, labels, label_offsets, sample_idx, extras """ if smiles_col is None: # Should we specify which dataset has caused the potential issue? 
@@ -1788,17 +1525,18 @@ def _extract_smiles_labels( smiles = df[smiles_col].values if len(label_cols) > 0: if task_level == "graph": - labels = extract_labels(df, "graph", label_cols) + labels, label_offsets = extract_labels(df, "graph", label_cols) elif task_level == "node": - labels = extract_labels(df, "node", label_cols) + labels, label_offsets = extract_labels(df, "node", label_cols) elif task_level == "edge": - labels = extract_labels(df, "edge", label_cols) + labels, label_offsets = extract_labels(df, "edge", label_cols) elif task_level == "nodepair": - labels = extract_labels(df, "nodepair", label_cols) + labels, label_offsets = extract_labels(df, "nodepair", label_cols) else: raise ValueError(f"Unknown task level: {task_level}") else: labels = float("nan") + np.zeros([len(smiles), 0]) + label_offsets = None # Get the indices, used for sub-sampling and splitting the dataset if idx_col is not None: @@ -1837,10 +1575,10 @@ def _extract_smiles_labels( weights /= np.max(weights) # Put the max weight to 1 extras = {"weights": weights, "mol_ids": mol_ids} - return smiles, labels, sample_idx, extras + return smiles, labels, label_offsets, sample_idx, extras + @staticmethod def _get_split_indices( - self, dataset_size: int, split_val: float, split_test: float, @@ -1896,7 +1634,7 @@ def _get_split_indices( splits = splits_path else: # Split from an indices file - file_type = self._get_data_file_type(splits_path) + file_type = BaseDataModule._get_data_file_type(splits_path) train, val, test = split_names @@ -1904,7 +1642,7 @@ def _get_split_indices( splits = torch.load(splits_path) elif file_type in ["csv", "tsv"]: with fsspec.open(str(splits_path)) as f: - splits = self._read_csv(splits_path) + splits = BaseDataModule._read_csv(splits_path) else: raise ValueError( f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv" @@ -1919,15 +1657,16 @@ def _get_split_indices( # Filter train, val and test indices _, train_idx, _ = np.intersect1d(sample_idx, train_indices, return_indices=True) train_indices = train_idx.tolist() - _, valid_idx, _ = np.intersect1d(sample_idx, val_indices, return_indices=True) - val_indices = valid_idx.tolist() + _, val_idx, _ = np.intersect1d(sample_idx, val_indices, return_indices=True) + val_indices = val_idx.tolist() _, test_idx, _ = np.intersect1d(sample_idx, test_indices, return_indices=True) test_indices = test_idx.tolist() return train_indices, val_indices, test_indices + @staticmethod def _sub_sample_df( - self, df: pd.DataFrame, sample_size: Union[int, float, None], seed: Optional[int] = None + df: pd.DataFrame, sample_size: Union[int, float, None], seed: Optional[int] = None ) -> pd.DataFrame: r""" subsample from a pandas dataframe @@ -1954,7 +1693,7 @@ def _sub_sample_df( def get_data_hash(self): """ - Get a hash specific to a dataset and smiles_transformer. + Get a hash specific to a dataset. Useful to cache the pre-processed data. 
""" args = {} @@ -1974,115 +1713,9 @@ def get_data_hash(self): task_args.pop("epoch_sampling_fraction", None) args[task_key] = task_args - hash_dict = { - "smiles_transformer": self.smiles_transformer, - "task_specific_args": args, - } - data_hash = get_md5_hash(hash_dict) + data_hash = get_md5_hash(args) return data_hash - def get_data_cache_fullname(self, compress: bool = False) -> str: - """ - Create a hash for the dataset, and use it to generate a file name - - Parameters: - compress: Whether to compress the data - Returns: - full path to the data cache file - """ - if self.processed_graph_data_path is None: - return - ext = ".datacache" - if compress: - ext += ".gz" - data_cache_fullname = fs.join(self.processed_graph_data_path, self.data_hash + ext) - return data_cache_fullname - - def load_data_from_cache(self, verbose: bool = True, compress: bool = False) -> bool: - """ - Load the datasets from cache. First create a hash for the dataset, and verify if that - hash is available at the path given by `self.processed_graph_data_path`. - - Parameters: - verbose: Whether to print the progress - compress: Whether to compress the data - - Returns: - cache_data_exists: Whether the cache exists (if the hash matches) and the loading succeeded - """ - full_cache_data_path = self.get_data_cache_fullname(compress=compress) - - if full_cache_data_path is None: - logger.info("No cache data path specified. Skipping loading the data from cache.") - return False - - cache_data_exists = fs.exists(full_cache_data_path) - - if cache_data_exists: - try: - logger.info(f"Loading the data from cache at path `{full_cache_data_path}`") - now = time.time() - with fsspec.open(full_cache_data_path, mode="rb", compression="infer") as file: - load_params = torch.load(file) - self.__dict__.update(load_params) - ( - self.train_singletask_datasets, - self.val_singletask_datasets, - self.test_singletask_datasets, - ) = self.get_subsets_of_datasets( - self.single_task_datasets, - self.task_train_indices, - self.task_val_indices, - self.task_test_indices, - ) - elapsed = round(time.time() - now) - logger.info( - f"Successfully loaded the data from cache in {elapsed}s at path: `{full_cache_data_path}`" - ) - return True - except Exception as e: - if verbose: - logger.warning( - f"Data cache failed to load path: `{full_cache_data_path}`.\nThe data will be prepared and cache will be created for future runs." - ) - logger.warning(e.__str__()) - return False - else: - if verbose: - logger.info( - f"Data cache not found at path: `{full_cache_data_path}`.\nThe data will be prepared and cache will be created for future runs." 
- ) - return False - - def get_subsets_of_datasets( - self, - single_task_datasets: Dict[str, Datasets.SingleTaskDataset], - task_train_indices: Dict[str, Iterable], - task_val_indices: Dict[str, Iterable], - task_test_indices: Dict[str, Iterable], - ) -> Tuple[Subset, Subset, Subset]: - """ - From a dictionary of datasets and their associated indices, subset the train/val/test sets - - Parameters: - single_task_datasets: Dictionary of datasets - task_train_indices: Dictionary of train indices - task_val_indices: Dictionary of val indices - task_test_indices: Dictionary of test indices - Returns: - train_singletask_datasets: Dictionary of train subsets - val_singletask_datasets: Dictionary of val subsets - test_singletask_datasets: Dictionary of test subsets - """ - train_singletask_datasets = {} - val_singletask_datasets = {} - test_singletask_datasets = {} - for task in task_train_indices.keys(): - train_singletask_datasets[task] = Subset(single_task_datasets[task], task_train_indices[task]) - val_singletask_datasets[task] = Subset(single_task_datasets[task], task_val_indices[task]) - test_singletask_datasets[task] = Subset(single_task_datasets[task], task_test_indices[task]) - return train_singletask_datasets, val_singletask_datasets, test_singletask_datasets - def __len__(self) -> int: r""" Returns the number of elements of the current DataModule, which is the combined size of all single-task datasets given. @@ -2138,7 +1771,6 @@ def __init__( self, task_specific_args: Dict[str, Union[DatasetProcessingParams, Dict[str, Any]]], processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, - dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -2147,11 +1779,8 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = False, multiprocessing_context: Optional[str] = None, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - featurization_backend: str = "loky", collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", + preprocessing_n_jobs: int = 0, **kwargs, ): r""" @@ -2168,27 +1797,26 @@ def __init__( meaning that all molecules will be considered. processed_graph_data_path: Path to the processed graph data. If None, the data will be downloaded from the OGB website. - dataloading_from: Whether to load the data from RAM or disk. Default is "ram". featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. num_workers: Number of workers for the dataloader. Use -1 to use all available cores. pin_memory: Whether to pin on paginated CPU memory for the dataloader. - featurization_n_jobs: Number of cores to use for the featurization. - featurization_progress: whether to show a progress bar during featurization. - featurization_backend: The backend to use for the molecular featurization. - - - "multiprocessing": Found to cause less memory issues. - - "loky": joblib's Default. Found to cause memory leaks. - - "threading": Found to be slow. - collate_fn: A custom torch collate function. Default is to `graphium.data.graphium_collate_fn` sample_size: - `int`: The maximum number of elements to take from the dataset. - `float`: Value between 0 and 1 representing the fraction of the dataset to consider - `None`: all elements are considered. 
+ preprocessing_n_jobs: Number of threads to use during preprocessing. + Use 0 to use all available cores, or -1 to use all but one core. + + dataloading_from: Deprecated. Behaviour now always matches previous "disk" option. + featurization_n_jobs: Deprecated. + featurization_progress: Deprecated. + featurization_backend: Deprecated. + prepare_dict_or_graph: Deprecated. Behaviour now always matches previous "pyg:graph" option. """ new_task_specific_args = {} @@ -2214,21 +1842,16 @@ def __init__( dm_args = {} dm_args["task_specific_args"] = new_task_specific_args dm_args["processed_graph_data_path"] = processed_graph_data_path - dm_args["dataloading_from"] = dataloading_from - dm_args["dataloader_from"] = dataloading_from dm_args["featurization"] = featurization dm_args["batch_size_training"] = batch_size_training dm_args["batch_size_inference"] = batch_size_inference dm_args["batch_size_per_pack"] = batch_size_per_pack dm_args["num_workers"] = num_workers dm_args["pin_memory"] = pin_memory - dm_args["featurization_n_jobs"] = featurization_n_jobs - dm_args["featurization_progress"] = featurization_progress - dm_args["featurization_backend"] = featurization_backend dm_args["persistent_workers"] = persistent_workers dm_args["multiprocessing_context"] = multiprocessing_context dm_args["collate_fn"] = collate_fn - dm_args["prepare_dict_or_graph"] = prepare_dict_or_graph + dm_args["preprocessing_n_jobs"] = preprocessing_n_jobs super().__init__(**dm_args, **kwargs) @@ -2400,7 +2023,6 @@ def __init__( tdc_train_val_seed: int = 0, # Inherited arguments from superclass processed_graph_data_path: Optional[Union[str, Path]] = None, - dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -2409,11 +2031,8 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = False, multiprocessing_context: Optional[str] = None, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - featurization_backend: str = "loky", collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", + preprocessing_n_jobs: int = 0, **kwargs, ): try: @@ -2460,7 +2079,6 @@ def __init__( task_specific_args=task_specific_args, featurization=featurization, processed_graph_data_path=processed_graph_data_path, - dataloading_from=dataloading_from, batch_size_training=batch_size_training, batch_size_inference=batch_size_inference, batch_size_per_pack=batch_size_per_pack, @@ -2468,11 +2086,8 @@ def __init__( pin_memory=pin_memory, persistent_workers=persistent_workers, multiprocessing_context=multiprocessing_context, - featurization_n_jobs=featurization_n_jobs, - featurization_progress=featurization_progress, - featurization_backend=featurization_backend, collate_fn=collate_fn, - prepare_dict_or_graph=prepare_dict_or_graph, + preprocessing_n_jobs=preprocessing_n_jobs, **kwargs, ) @@ -2532,237 +2147,3 @@ def _get_task_specific_arguments(self, name: str, seed: int, cache_dir: str) -> split_names=["train", "val", "test"], task_level="graph", ) - - -class FakeDataModule(MultitaskFromSmilesDataModule): - """ - A fake datamodule that generates artificial data by mimicking the true data coming - from the provided dataset. - It is useful to test the speed and performance of the model on a dataset without - having to featurize it and wait for the workers to load it. 
- """ - - def __init__( - self, - task_specific_args: Dict[str, Dict[str, Any]], # TODO: Replace this with DatasetParams - featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, - batch_size_training: int = 16, - batch_size_inference: int = 16, - num_workers: int = 0, - pin_memory: bool = True, - persistent_workers: bool = False, - multiprocessing_context: Optional[str] = None, - collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", - num_mols_to_generate: int = 1000000, - indexing_single_elem: bool = True, - **kwargs, - ): - super().__init__( - task_specific_args=task_specific_args, - featurization=featurization, - batch_size_training=batch_size_training, - batch_size_inference=batch_size_inference, - num_workers=num_workers, - pin_memory=pin_memory, - persistent_workers=persistent_workers, - multiprocessing_context=multiprocessing_context, - collate_fn=collate_fn, - prepare_dict_or_graph=prepare_dict_or_graph, - **kwargs, - ) - self.num_mols_to_generate = num_mols_to_generate - self.indexing_single_elem = indexing_single_elem - - def generate_data(self, label_cols: List[str], smiles_col: str): - """ - Parameters: - labels_cols - smiles_col - Returns: - pd.DataFrame - """ - num_generated_mols = int(1) - # Create a dummy generated dataset - singel smiles string, duplicated N times - example_molecules = dict( - smiles="C1N2C3C4C5OC13C2C45", - cxsmiles="[H]C1C2=C(NC(=O)[C@@]1([H])C1=C([H])C([H])=C(C([H])([H])[H])C([H])=C1[H])C([H])=C([H])N=C2[H] |(6.4528,-1.5789,-1.2859;5.789,-0.835,-0.8455;4.8499,-0.2104,-1.5946;3.9134,0.7241,-0.934;3.9796,1.1019,0.3172;5.0405,0.6404,1.1008;5.2985,1.1457,2.1772;5.9121,-0.5519,0.613;6.9467,-0.2303,0.8014;5.677,-1.7955,1.4745;4.7751,-2.7953,1.0929;4.2336,-2.7113,0.154;4.5521,-3.9001,1.914;3.8445,-4.6636,1.5979;5.215,-4.0391,3.1392;4.9919,-5.2514,4.0126;5.1819,-5.0262,5.0671;5.6619,-6.0746,3.7296;3.966,-5.6247,3.925;6.1051,-3.0257,3.52;6.6247,-3.101,4.4725;6.3372,-1.9217,2.7029;7.0168,-1.1395,3.0281;2.8586,1.2252,-1.7853;2.1303,1.9004,-1.3493;2.8118,0.8707,-3.0956;2.0282,1.2549,-3.7434;3.716,0.0207,-3.7371;4.6658,-0.476,-3.0127;5.3755,-1.1468,-3.5021)|", - ) - example_df_entry = {smiles_col: example_molecules[smiles_col]} - for label in label_cols: - example_df_entry[label] = np.random.random() - df = pd.DataFrame([example_df_entry]) - logger.info(f"Generating fake dataset on host... \n Generating {num_generated_mols} rows in the df.") - df = pd.concat([df] * num_generated_mols, ignore_index=True) - return df - - def prepare_data(self): - """Called only from a single process in distributed settings. Steps: - - - If each cache is set and exists, reload from cache and return. Otherwise, - - For each single-task dataset: - - Load its dataframe from a path (if provided) - - Subsample the dataframe - - Extract the smiles, labels from the dataframe - - In the previous step, we were also able to get the unique smiles, which we use to compute the features - - For each single-task dataframe and associated data (smiles, labels, etc.): - - Filter out the data corresponding to molecules which failed featurization. 
- - Create a corresponding SingletaskDataset - - Split the SingletaskDataset according to the task-specific splits for train, val and test - """ - - """Load all single-task dataframes.""" - if self.num_mols_to_generate is None: - num_mols = 0 - - task_df = {} - for task, args in self.task_dataset_processing_params.items(): - logger.info(f"Reading data for task '{task}'") - if args.df is None: - # Only load the useful columns, as some datasets can be very large when loading all columns. - label_cols = self._parse_label_cols( - df=None, df_path=args.df_path, label_cols=args.label_cols, smiles_col=args.smiles_col - ) - task_df[task] = self.generate_data(label_cols=args.label_cols, smiles_col=args.smiles_col) - if self.num_mols_to_generate is None: - num_mols = max(num_mols, len(task_df[task])) - task_df[task] = task_df[task].iloc[0:1] - - args.label_cols = label_cols - if self.num_mols_to_generate is None: - self.num_mols_to_generate = num_mols - logger.info("Done reading datasets") - - """Subsample the data frames and extract the necessary data to create SingleTaskDatasets for each task (smiles, labels, extras).""" - task_dataset_args = {} - for task in task_df.keys(): - task_dataset_args[task] = {} - - for task, df in task_df.items(): - logger.info(f"Prepare single-task dataset for task '{task}' with {len(df)} data points.") - # Extract smiles, labels, extras - args = self.task_dataset_processing_params[task] - smiles, labels, sample_idx, extras = self._extract_smiles_labels( - df, - task_level=args.task_level, - smiles_col=args.smiles_col, - label_cols=args.label_cols, - idx_col=args.idx_col, - weights_col=args.weights_col, - weights_type=args.weights_type, - ) - - # Store the relevant information for each task's dataset - task_dataset_args[task]["smiles"] = smiles - task_dataset_args[task]["labels"] = labels - task_dataset_args[task]["sample_idx"] = sample_idx - task_dataset_args[task]["extras"] = extras - - """Convert SMILES to features (graphs, fingerprints, etc.) for the unique molecules found.""" - all_smiles = [] - idx_per_task = {} - total_len = 0 - for task, dataset_args in task_dataset_args.items(): - all_smiles.extend(dataset_args["smiles"]) - num_smiles = len(dataset_args["smiles"]) - idx_per_task[task] = (total_len, total_len + num_smiles) - total_len += num_smiles - # Get all unique mol ids - all_unique_mol_ids = smiles_to_unique_mol_ids( - all_smiles, - n_jobs=self.featurization_n_jobs, - featurization_batch_size=self.featurization_batch_size, - backend=self.featurization_backend, - ) - # Convert SMILES to features - features, _ = self._featurize_molecules(all_smiles) - task_dataset_args[task]["features"] = features - """Filter data based on molecules which failed featurization. 
Create single task datasets as well.""" - self.single_task_datasets = {} - for task, args in task_dataset_args.items(): - self.single_task_datasets[task] = Datasets.SingleTaskDataset( - features=task_dataset_args[task]["features"], - labels=task_dataset_args[task]["labels"], - smiles=task_dataset_args[task]["smiles"], - indices=task_dataset_args[task]["sample_idx"], - unique_ids=all_unique_mol_ids[idx_per_task[task][0] : idx_per_task[task][1]], - **task_dataset_args[task]["extras"], - ) - - """We split the data up to create train, val and test datasets""" - self.train_singletask_datasets = {} - self.val_singletask_datasets = {} - self.test_singletask_datasets = {} - for task, df in task_df.items(): - self.train_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0]) - self.val_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0]) - self.test_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0]) - - def setup(self, stage=None): - # TODO - """ - Prepare the torch dataset. Called on every GPUs. Setting state here is ok. - Parameters: - stage (str): Either 'fit', 'test', or None. - """ - labels_size = {} - - if stage == "fit" or stage is None: - self.train_ds = Datasets.FakeDataset(self.train_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem) # type: ignore - self.val_ds = Datasets.FakeDataset(self.val_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem) # type: ignore - print(self.train_ds) - print(self.val_ds) - - labels_size.update( - self.train_ds.labels_size - ) # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements. - labels_size.update(self.val_ds.labels_size) - - if stage == "test" or stage is None: - self.test_ds = Datasets.FakeDataset(self.test_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem) # type: ignore - print(self.test_ds) - labels_size.update(self.test_ds.labels_size) - - default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None) - - if default_labels_size_dict is None: - self.collate_fn.keywords["labels_size_dict"] = labels_size - - def get_fake_graph(self): - """ - Low memory footprint method to get the first datapoint DGL graph. - The first 10 rows of the data are read in case the first one has a featurization - error. If all 20 first element, then `None` is returned, otherwise the first - graph to not fail is returned. 
- """ - keys = list(self.task_dataset_processing_params.keys()) - task = keys[0] - args = self.task_dataset_processing_params[task] - if args.df is None: - df = self._read_csv(args.df_path, nrows=20) - else: - df = args.df.iloc[0:20, :] - - df = df.iloc[0:20, :] - label_cols = self._parse_label_cols( - df, df_path=None, label_cols=args.label_cols, smiles_col=args.smiles_col - ) - - smiles, labels, sample_idx, extras = self._extract_smiles_labels( - df, - task_level=args.task_level, - smiles_col=args.smiles_col, - label_cols=label_cols, - idx_col=args.idx_col, - weights_col=args.weights_col, - weights_type=args.weights_type, - ) - - graph = None - for s in smiles: - graph = self.smiles_transformer(s, mask_nan=0.0) - num_nodes = graph.num_nodes - num_edges = graph.num_edges - if (graph is not None) and (num_edges > 0) and (num_nodes > 0): - break - return graph diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index 34c1b30aa..498515fc3 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -16,7 +16,7 @@ from copy import deepcopy from functools import lru_cache from multiprocessing import Manager -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import fsspec import numpy as np @@ -26,125 +26,7 @@ from torch.utils.data.dataloader import Dataset from torch_geometric.data import Batch, Data -from graphium.data.smiles_transform import smiles_to_unique_mol_ids -from graphium.features import GraphDict - - -class SingleTaskDataset(Dataset): - def __init__( - self, - labels: List[Union[torch.Tensor, np.ndarray]], - features: Optional[List[Union[Data, "GraphDict"]]] = None, - smiles: Optional[List[str]] = None, - indices: Optional[List[int]] = None, - weights: Optional[Union[torch.Tensor, np.ndarray]] = None, - unique_ids: Optional[List[str]] = None, - mol_ids: Optional[List[str]] = None, - ): - r""" - dataset for a single task - Parameters: - labels: A list of labels for the given task (one per graph) - features: A list of graphs - smiles: A list of smiles - indices: A list of indices - weights: A list of weights - unique_ids: A list of unique ids for each molecule generated from `datamol.unique_id` - mol_ids: A list of ids coming from the original dataset. Useful to identify the molecule in the original dataset. 
- """ - - # Verify that all lists are the same length - numel = len(labels) - - def _check_if_same_length(to_check, label): - """Simple utility method to throw an error if the length is not as expected.""" - if to_check is not None and len(to_check) != numel: - raise ValueError( - f"{label} must be the same length as `labels`, got {len(to_check)} and {numel}" - ) - - _check_if_same_length(features, "features") - _check_if_same_length(indices, "indices") - _check_if_same_length(weights, "weights") - _check_if_same_length(unique_ids, "unique_ids") - _check_if_same_length(mol_ids, "mol_ids") - - self.labels = labels - if smiles is not None: - manager = Manager() # Avoid memory leaks with `num_workers > 0` by using the Manager - self.smiles = manager.list(smiles) - else: - self.smiles = None - self.features = features - self.indices = indices - if self.indices is not None: - self.indices = np.array( - self.indices - ) # Avoid memory leaks with `num_workers > 0` by using numpy array - self.weights = weights - self.unique_ids = unique_ids - self.mol_ids = mol_ids - - def __len__(self): - r""" - return the size of the dataset - Returns: - size: the size of the dataset - """ - return len(self.labels) - - def __getitem__(self, idx): - """ - get the data at the given index - Parameters: - idx: the index to get the data at - Returns: - datum: a dictionary containing the data at the given index, with keys "features", "labels", "smiles", "indices", "weights", "unique_ids" - """ - datum = {} - - if self.features is not None: - datum["features"] = self.features[idx] - - if self.labels is not None: - datum["labels"] = self.labels[idx] - - if self.smiles is not None: - datum["smiles"] = self.smiles[idx] - - if self.indices is not None: - datum["indices"] = self.indices[idx] - - if self.weights is not None: - datum["weights"] = self.weights[idx] - - if self.unique_ids is not None: - datum["unique_ids"] = self.unique_ids[idx] - - if self.mol_ids is not None: - datum["mol_ids"] = self.mol_ids[idx] - - return datum - - def __getstate__(self): - """Serialize the class for pickling.""" - state = {} - state["labels"] = self.labels - state["smiles"] = list(self.smiles) if self.smiles is not None else None - state["features"] = self.features - state["indices"] = self.indices - state["weights"] = self.weights - state["unique_ids"] = self.unique_ids - state["mol_ids"] = self.mol_ids - return state - - def __setstate__(self, state: dict): - """Reload the class from pickling.""" - if state["smiles"] is not None: - manager = Manager() - state["smiles"] = manager.list(state["smiles"]) - - self.__dict__.update(state) +import graphium_cpp class MultitaskDataset(Dataset): @@ -152,178 +34,48 @@ class MultitaskDataset(Dataset): def __init__( self, - datasets: Dict[str, SingleTaskDataset], - n_jobs=-1, - backend: str = "loky", - featurization_batch_size=1000, - progress: bool = True, - save_smiles_and_ids: bool = False, + featurize_smiles: Callable[[str], dict], + task_names: List[str] = None, + label_num_cols: List[int] = None, + label_dtypes: List[int] = None, + mol_file_data_offsets=None, + concat_smiles_tensor=None, + smiles_offsets_tensor=None, + num_nodes_tensor=None, + num_edges_tensor=None, about: str = "", data_path: Optional[Union[str, os.PathLike]] = None, - dataloading_from: str = "ram", - data_is_cached: bool = False, ): r""" This class holds the information for the multitask dataset. - Several single-task datasets can be merged to create a multi-task dataset. After merging the dictionary of single-task datasets. 
we will have a multitask dataset of the following form: - - self.mol_ids will be a list to contain the unique molecular IDs to identify the molecules - - self.smiles will be a list to contain the corresponding smiles for that molecular ID across all single-task datasets - - self.labels will be a list of dictionaries where the key is the task name and the value is the label(s) for that task. - At this point, any particular molecule will only have entries for tasks for which it has a label. Later, in the collate - function, we fill up the missing task labels with NaNs. - - self.features will be a list of featurized graphs corresponding to that particular unique molecule. - However, for testing purposes we may not require features so that we can make sure that this merge function works. + - self.mol_file_data_offsets will be a Tensor representing where to find + label data about each molecule in the corresponding file + - self.smiles_tensor will be a Tensor containing all smiles strings concatenated, with null terminators + - self.smiles_offsets_tensor will be a Tensor indicating where smiles strings start in smiles_tensor + - self.num_nodes_tensor will be a Tensor of the number of nodes in each graph + - self.num_edges_tensor will be a Tensor of the number of edges in each graph Parameters: - datasets: A dictionary of single-task datasets - n_jobs: Number of jobs to run in parallel - backend: Parallelization backend - featurization_batch_size: The batch size to use for the parallelization of the featurization - progress: Whether to display the progress bar - save_smiles_and_ids: Whether to save the smiles and ids for the dataset. If `False`, `mol_ids` and `smiles` are set to `None` about: A description of the dataset data_path: The location of the data if saved on disk - dataloading_from: Whether to load the data from `"disk"` or `"ram"` - data_is_cached: Whether the data is already cached on `"disk"` """ super().__init__() - self.n_jobs = n_jobs - self.backend = backend - self.featurization_batch_size = featurization_batch_size - self.progress = progress + self.about = about - self.save_smiles_and_ids = save_smiles_and_ids self.data_path = data_path - self.dataloading_from = dataloading_from - - logger.info(f"Dataloading from {dataloading_from.upper()}") - - if data_is_cached: - self._load_metadata() - - if dataloading_from == "disk": - self.features = None - self.labels = None - elif dataloading_from == "ram": - logger.info(f"Transferring {about} from DISK to RAM...") - self.transfer_from_disk_to_ram() - - else: - task = next(iter(datasets)) - self.features = None - if (len(datasets[task]) > 0) and ("features" in datasets[task][0]): - self.mol_ids, self.smiles, self.labels, self.features = self.merge(datasets) - else: - self.mol_ids, self.smiles, self.labels = self.merge(datasets) - # Set mol_ids and smiles to None to save memory as they are not needed. 
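To illustrate the flat layout described above, here is a small, self-contained sketch of packing a list of SMILES into one byte tensor with null terminators plus an offsets tensor. This only mirrors the layout for explanation; the real tensors are produced by `graphium_cpp` during preprocessing, and the pure-Python `extract_string` below is a stand-in for `graphium_cpp.extract_string`, not its actual implementation.

import torch

smiles_list = ["CCO", "c1ccccc1"]  # placeholder molecules
raw = b"".join(s.encode("ascii") + b"\0" for s in smiles_list)
concat_smiles_tensor = torch.tensor(list(raw), dtype=torch.uint8)
# Start offset of each string, plus one past-the-end entry (assumed convention for this sketch)
smiles_offsets_tensor = torch.tensor(
    [0] + [len(s) + 1 for s in smiles_list], dtype=torch.int64
).cumsum(0)

def extract_string(concat: torch.Tensor, offsets: torch.Tensor, idx: int) -> str:
    # Illustrative stand-in for graphium_cpp.extract_string
    start, end = offsets[idx].item(), offsets[idx + 1].item()
    return bytes(concat[start:end].tolist()).rstrip(b"\0").decode("ascii")

assert extract_string(concat_smiles_tensor, smiles_offsets_tensor, 1) == "c1ccccc1"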
- if not save_smiles_and_ids: - self.mol_ids = None - self.smiles = None - self.labels_size = self.set_label_size_dict(datasets) - self.labels_dtype = self.set_label_dtype_dict(datasets) - self.dataset_length = len(self.labels) - self._num_nodes_list = None - self._num_edges_list = None - if self.features is not None: - self._num_nodes_list = get_num_nodes_per_graph(self.features) - self._num_edges_list = get_num_edges_per_graph(self.features) - - def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False): - """ - Function parallelizing transfer from DISK to RAM - """ - - def transfer_mol_from_disk_to_ram(idx): - """ - Function transferring single mol from DISK to RAM - """ - data_dict = self.load_graph_from_index(idx) - mol_in_ram = { - "features": data_dict["graph_with_features"], - "labels": data_dict["labels"], - } - - return mol_in_ram - - if parallel_with_batches and self.featurization_batch_size: - data_in_ram = parallelized_with_batches( - transfer_mol_from_disk_to_ram, - range(self.dataset_length), - batch_size=self.featurization_batch_size, - n_jobs=0, - backend=self.backend, - progress=self.progress, - tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, - ) - else: - data_in_ram = parallelized( - transfer_mol_from_disk_to_ram, - range(self.dataset_length), - n_jobs=0, - backend=self.backend, - progress=self.progress, - tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, - ) - - self.features = [sample["features"] for sample in data_in_ram] - self.labels = [sample["labels"] for sample in data_in_ram] - - def save_metadata(self, directory: str): - """ - Save everything other than features/labels - """ - attrs_to_save = [ - "mol_ids", - "smiles", - "labels_size", - "labels_dtype", - "dataset_length", - "_num_nodes_list", - "_num_edges_list", - ] - attrs = {attr: getattr(self, attr) for attr in attrs_to_save} - - path = os.path.join(directory, "multitask_metadata.pkl") - - torch.save(attrs, path, pickle_protocol=4) - - def _load_metadata(self): - """ - Load everything other than features/labels - """ - attrs_to_load = [ - "mol_ids", - "smiles", - "labels_size", - "labels_dtype", - "dataset_length", - "_num_nodes_list", - "_num_edges_list", - ] - path = os.path.join(self.data_path, "multitask_metadata.pkl") - with fsspec.open(path, "rb") as f: - attrs = torch.load(path) - - if not set(attrs_to_load).issubset(set(attrs.keys())): - raise ValueError( - f"The metadata in the cache at {self.data_path} does not contain the right information. " - f"This may be because the cache was prepared using an earlier version of Graphium. " - f"You can try deleting the cache and running the data preparation again. " - f"\nMetadata keys found: {attrs.keys()}" - f"\nMetadata keys required: {attrs_to_load}" - ) - - for attr, value in attrs.items(): - setattr(self, attr, value) - - if self.save_smiles_and_ids: - if self.smiles is None or self.mol_ids is None: - logger.warning( - f"Argument `save_smiles_and_ids` is set to {self.save_smiles_and_ids} but metadata in the cache at {self.data_path} does not contain smiles and mol_ids. " - f"This may be because `Datamodule.prepare_data(save_smiles_and_ids=False)` was run followed by `Datamodule.setup(save_smiles_and_ids=True)`. " - f"When loading from cached files, the `save_smiles_and_ids` argument of `Datamodule.setup()` is superseeded by the `Datamodule.prepare_data()`. 
" - ) + self.featurize_smiles = featurize_smiles + self.task_names = task_names + self.label_num_cols = label_num_cols + self.label_dtypes = label_dtypes + self.mol_file_data_offsets = mol_file_data_offsets + self.smiles_tensor = concat_smiles_tensor + self.smiles_offsets_tensor = smiles_offsets_tensor + self.num_nodes_tensor = num_nodes_tensor + self.num_edges_tensor = num_edges_tensor + self.dataset_length = num_nodes_tensor.size(dim=0) + + logger.info(f"Dataloading from DISK") def __len__(self): r""" @@ -336,24 +88,14 @@ def num_nodes_list(self): """ The number of nodes per graph """ - if self._num_nodes_list is None: - if len(self) == 0: - self._num_nodes_list = [] - else: - self._num_nodes_list = get_num_nodes_per_graph(self.features) - return self._num_nodes_list + return self.num_nodes_tensor @property def num_edges_list(self): """ The number of edges per graph """ - if self._num_edges_list is None: - if len(self) == 0: - self._num_edges_list = [] - else: - self._num_edges_list = get_num_edges_per_graph(self.features) - return self._num_edges_list + return self.num_edges_tensor @property def num_graphs_total(self): @@ -367,28 +109,30 @@ def num_nodes_total(self): """Total number of nodes for all graphs""" if len(self) == 0: return - return sum(self.num_nodes_list) + return torch.sum(self.num_nodes_list, dtype=torch.int64).item() @property def max_num_nodes_per_graph(self): """Maximum number of nodes per graph""" if len(self) == 0: return - return max(self.num_nodes_list) + return torch.max(self.num_nodes_list).item() @property def std_num_nodes_per_graph(self): """Standard deviation of number of nodes per graph""" if len(self) == 0: return - return np.std(self.num_nodes_list) + # correction is zero to match previous default behaviour of numpy.std + # Consider changing it to 1 (the torch.std default) + return torch.std(self.num_nodes_list.to(torch.float64), correction=0).item() @property def min_num_nodes_per_graph(self): """Minimum number of nodes per graph""" if len(self) == 0: return - return min(self.num_nodes_list) + return torch.min(self.num_nodes_list).item() @property def mean_num_nodes_per_graph(self): @@ -402,28 +146,30 @@ def num_edges_total(self): """Total number of edges for all graphs""" if len(self) == 0: return - return sum(self.num_edges_list) + return torch.sum(self.num_edges_list, dtype=torch.int64).item() @property def max_num_edges_per_graph(self): """Maximum number of edges per graph""" if len(self) == 0: return - return max(self.num_edges_list) + return torch.max(self.num_edges_list).item() @property def min_num_edges_per_graph(self): """Minimum number of edges per graph""" if len(self) == 0: return - return min(self.num_edges_list) + return torch.min(self.num_edges_list).item() @property def std_num_edges_per_graph(self): """Standard deviation of number of nodes per graph""" if len(self) == 0: return - return np.std(self.num_edges_list) + # correction is zero to match previous default behaviour of numpy.std + # Consider changing it to 1 (the torch.std default) + return torch.std(self.num_edges_list.to(torch.float64), correction=0).item() @property def mean_num_edges_per_graph(self): @@ -438,27 +184,26 @@ def __getitem__(self, idx): Parameters: idx: The index of the data to retrieve Returns: - A dictionary containing the data for the specified index with keys "mol_ids", "smiles", "labels", and "features" + A dictionary containing the data for the specified index with keys "labels", "num_nodes", "num_edges", and "features" """ - datum = {} - if 
self.dataloading_from == "disk": - data_dict = self.load_graph_from_index(idx) - datum["features"] = data_dict["graph_with_features"] - datum["labels"] = data_dict["labels"] - if "smiles" in data_dict.keys(): - datum["smiles"] = data_dict["smiles"] - else: - if self.mol_ids is not None: - datum["mol_ids"] = self.mol_ids[idx] + if self.smiles_tensor is None or self.smiles_offsets_tensor is None: + raise ValueError("Missing smiles in MultitaskDataset.__getitem__") - if self.smiles is not None: - datum["smiles"] = self.smiles[idx] + smiles_str = graphium_cpp.extract_string(self.smiles_tensor, self.smiles_offsets_tensor, idx) - if self.labels is not None: - datum["labels"] = self.labels[idx] + if self.mol_file_data_offsets is None: + datum = {"features": self.featurize_smiles(smiles_str)} + else: + datum = { + "labels": self.load_graph_from_index(idx), + "features": self.featurize_smiles(smiles_str), + } - if self.features is not None: - datum["features"] = self.features[idx] + # One of the featurization error handling options returns a string on error, + # instead of throwing an exception, so assume that the intention is to just skip, + # instead of crashing. + if isinstance(datum["features"], str): + datum = None return datum @@ -468,165 +213,23 @@ def load_graph_from_index(self, data_idx): Parameters: data_idx: The index of the data to retrieve Returns: - A dictionary containing the data for the specified index with keys "graph_with_features", "labels" and "smiles" (optional). + A Data object containing the data for the specified index with keys corresponding to the tasks. """ - filename = os.path.join( - self.data_path, format(data_idx // 1000, "04d"), format(data_idx, "07d") + ".pkl" + labels = {} + graphium_cpp.load_labels_from_index( + self.data_path, + data_idx, + self.mol_file_data_offsets, + self.task_names, + self.label_num_cols, + self.label_dtypes, + labels, ) - with fsspec.open(filename, "rb") as f: - data_dict = torch.load(f) - return data_dict - - def merge( - self, datasets: Dict[str, SingleTaskDataset] - ) -> Tuple[List[str], List[str], List[Dict[str, Any]], List[Any]]: - r"""This function merges several single task datasets into a multitask dataset. - - The idea: for each of the smiles, labels, features and tasks, we create a corresponding list that concatenates these items across all tasks. - In particular, for any index, the elements in the smiles, labels, features and task lists at that index will correspond to each other (i.e. match up). - Over this list of all smiles (which we created by concatenating the smiles across all tasks), we compute their molecular ID using functions from Datamol. - Once again, we will have a list of molecular IDs which is the same size as the list of smiles, labels, features and tasks. - We then use numpy's `unique` function to find the exact list of unique molecular IDs as these will identify the molecules in our dataset. We also get the - inverse from numpy's `unique`, which will allow us to index in addition to the list of all molecular IDs, the list of all smiles, labels, features and tasks. - Finally, we use this inverse to construct the list of list of smiles, list of label dictionaries (indexed by task) and the list of features such that - the indices match up. This is what is needed for the `get_item` function to work. 
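Because `__getitem__` now returns `None` when featurization reports an error string instead of raising, anything collating batches from this dataset has to drop those entries. A minimal sketch of such a filter; the wrapper name is hypothetical, and Graphium's own collate function may already handle this case.

from typing import Any, Callable, List, Optional

def skip_failed_collate(collate: Callable[[List[Any]], Any]) -> Callable:
    """Wrap a collate function so that None datums (failed featurization) are skipped."""
    def _collate(batch: List[Optional[dict]]):
        batch = [datum for datum in batch if datum is not None]
        if len(batch) == 0:
            return None  # caller decides what to do with an empty batch
        return collate(batch)
    return _collate

# Usage sketch: DataLoader(dataset, collate_fn=skip_failed_collate(my_collate))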
- - Parameters: - datasets: A dictionary of single-task datasets - Returns: - A tuple of (list of molecular IDs, list of smiles, list of label dictionaries, list of features) - """ + data_dict = Data() + for task, values in labels.items(): + data_dict[task] = values - # Get all the smiles, labels, features and tasks. - all_lists = self._get_all_lists_ids(datasets=datasets) - mol_ids, inv = self._get_inv_of_mol_ids(all_mol_ids=all_lists["mol_ids"]) - - # Store the smiles. - smiles = [[] for _ in range(len(mol_ids))] - for all_idx, unique_idx in enumerate(inv): - smiles[unique_idx].append(all_lists["smiles"][all_idx]) - - # Store the labels. - labels = [Data() for _ in range(len(mol_ids))] - for all_idx, unique_idx in enumerate(inv): - task: str = all_lists["tasks"][all_idx] - label = all_lists["labels"][all_idx] - labels[unique_idx][task] = label - - if all_idx < len(all_lists["features"]): - features = all_lists["features"][all_idx] - labels[unique_idx]["x"] = torch.empty( - (features.num_nodes, 1) - ) # IPU is not happy with zero-sized tensors, so use shape (features.num_nodes, 1) here - labels[unique_idx]["edge_index"] = torch.empty((2, features.num_edges)) - - # Store the features - if len(all_lists["features"]) > 0: - features = [-1 for i in range(len(mol_ids))] - for all_idx, unique_idx in enumerate(inv): - features[unique_idx] = all_lists["features"][all_idx] - return mol_ids, smiles, labels, features - else: - return mol_ids, smiles, labels - - def _get_all_lists_ids(self, datasets: Dict[str, SingleTaskDataset]) -> Dict[str, Any]: - all_smiles = [] - all_features = [] - all_labels = [] - all_mol_ids = [] - all_tasks = [] - - for task, ds in datasets.items(): - if len(ds) == 0: - continue - # Get data from single task dataset - ds_smiles = [ds[i]["smiles"] for i in range(len(ds))] - ds_labels = [ds[i]["labels"] for i in range(len(ds))] - if "unique_ids" in ds[0].keys(): - ds_mol_ids = [ds[i]["unique_ids"] for i in range(len(ds))] - else: - ds_mol_ids = smiles_to_unique_mol_ids( - ds_smiles, - n_jobs=self.n_jobs, - featurization_batch_size=self.featurization_batch_size, - backend=self.backend, - progress=self.progress, - progress_desc=f"{task}: mol to ids", - ) - if "features" in ds[0]: - ds_features = [ds[i]["features"] for i in range(len(ds))] - else: - ds_features = None - all_smiles.extend(ds_smiles) - all_labels.extend(ds_labels) - all_mol_ids.extend(ds_mol_ids) - if ds_features is not None: - all_features.extend(ds_features) - - task_list = [task] * ds.__len__() - all_tasks.extend(task_list) - - all_lists = { - "smiles": all_smiles, - "features": all_features, - "labels": all_labels, - "mol_ids": all_mol_ids, - "tasks": all_tasks, - } - - return all_lists - - def _get_inv_of_mol_ids(self, all_mol_ids): - mol_ids, inv = np.unique(all_mol_ids, return_inverse=True) - return mol_ids, inv - - def _find_valid_label(self, task, ds): - r""" - For a given dataset, find a genuine label for that dataset - """ - valid_label = None - for i in range(len(ds)): - if ds[i] is not None: - valid_label = ds[i]["labels"] - break - - if valid_label is None: - raise ValueError(f"Dataset for task {task} has no valid labels.") - - return valid_label - - def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]): - r""" - This gives the number of labels to predict for a given task. 
- """ - task_labels_size = {} - for task, ds in datasets.items(): - if len(ds) == 0: - continue - - valid_label = self._find_valid_label(task, ds) - - # Assume for a fixed task, the label dimension is the same across data points - torch_label = torch.as_tensor(valid_label) - - # First dimension is graph-specific - task_labels_size[task] = torch_label.size() - return task_labels_size - - def set_label_dtype_dict(self, datasets: Dict[str, SingleTaskDataset]): - r""" - Gets correct dtype for a given label - """ - task_labels_dtype = {} - for task, ds in datasets.items(): - if len(ds) == 0: - continue - - valid_label = self._find_valid_label(task, ds) - - torch_label = torch.as_tensor(valid_label) - task_labels_dtype[task] = torch_label.dtype - return task_labels_dtype + return data_dict def __repr__(self) -> str: """ @@ -643,11 +246,6 @@ def __repr__(self) -> str: ) return out_str - # Faster to compute the statistics if we unbatch first. - features = self.features - if isinstance(self.features, Batch): - self.features = self.features.to_data_list() - out_str = ( f"-------------------\n{self.__class__.__name__}\n" + f"\tabout = {self.about}\n" @@ -665,111 +263,33 @@ def __repr__(self) -> str: + f"-------------------\n" ) - # Restore the original features. - self.features = features - return out_str -class FakeDataset(MultitaskDataset): - """ - A dataset to hold the fake data. - """ - - def __init__( - self, datasets: Dict[str, SingleTaskDataset], num_mols: int = 1234, indexing_same_elem: bool = False - ): - """ - Parameters: - datasets: - A dictionary of datasets. The keys are the task names and the values are the datasets. - num_mols: - The number of molecules to generate. In reality, it is the same molecule, - but `num_mols` will change the length of the dataset. - indexing_same_elem: - If True, the same molecule is used for all samples. - Otherwise, a deepcopied molecule is used for each sample. - """ - self.indexing_same_elem = indexing_same_elem - self.num_mols = num_mols - self.num_datasets = len(datasets) - - self.about = "FakeDatasets" - task = next(iter(datasets)) - if "features" in datasets[task][0]: - self.mol_ids, self.smiles, self.labels, self.features = self.merge(datasets) - if self.indexing_same_elem is False: - self.mol_ids, self.smiles, self.labels, self.features = self.deepcopy_mol( - self.mol_ids, self.smiles, self.labels, self.features - ) - else: - self.mol_ids, self.smiles, self.labels = self.merge(datasets) - if self.indexing_same_elem is False: - self.mol_ids, self.smiles, self.labels, _ = self.deepcopy_mol( - self.mol_ids, self.smiles, self.labels - ) - - self.labels_size = self.set_label_size_dict(datasets) - self.labels_dtype = self.set_label_dtype_dict(datasets) - self.features = self.features - - def _get_inv_of_mol_ids(self, all_mol_ids): - # The generated data is a single molecule duplicated - mol_ids = np.array(all_mol_ids) - inv = [_ for _ in range(len(mol_ids) // self.num_datasets)] * self.num_datasets - mol_ids = np.unique(inv) - return mol_ids, inv - - def deepcopy_mol(self, mol_ids, labels, smiles, features=None): - """ - Create a deepcopy of the single molecule num_mols times - - Args: - mol_ids (array): The single value for the mol ID - labels (List[Dict]): List containing one dict with the label name-value pairs - smiles (List[List[str]]): List of list containing SMILE sting - features (List[Data], optional): list containing Data object. Defaults to None. 
- - Returns: - The deep copy of the inputs - """ - logger.info("Duplicating the single dataset element...") - mol_ids = [deepcopy(mol_ids[0]) for _ in range(self.num_mols)] - logger.info("Finished `mol_ids`") - labels = [deepcopy(labels[0]) for _ in range(self.num_mols)] - logger.info("Finished `labels`") - smiles = [deepcopy(smiles[0]) for _ in range(self.num_mols)] - logger.info("Finished `smiles`") - if features is not None: - features = [deepcopy(features[0]) for _ in range(self.num_mols)] - logger.info("Finished `features`") - return mol_ids, labels, smiles, features - - def __len__(self): - r""" - Returns the number of molecules - """ - return self.num_mols - - def __getitem__(self, idx): - r""" - get the data for at the specified index - Parameters: - idx: The index of the data to retrieve - Returns: - A dictionary containing the data for the specified index with keys "mol_ids", "smiles", "labels", and "features" - """ - datum = {} - if self.indexing_same_elem is True: - # If using a single memory location override the idx value passed - idx = 0 - if self.labels is not None: - datum["labels"] = self.labels[idx] - - if self.features is not None: - datum["features"] = self.features[idx] - - return datum +def torch_enum_to_dtype(v: Union[int, torch.dtype]): + if isinstance(v, torch.dtype): + return v + + mapping = [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.complex32, + torch.complex64, + torch.complex128, + torch.bool, + torch.qint8, + torch.quint8, + torch.qint32, + torch.bfloat16, + torch.quint4x2, + ] + return mapping[v] if (v >= 0 and v < len(mapping)) else None def get_num_nodes_per_graph(graphs): diff --git a/graphium/data/multilevel_utils.py b/graphium/data/multilevel_utils.py index 7f9ed5813..a096979dd 100644 --- a/graphium/data/multilevel_utils.py +++ b/graphium/data/multilevel_utils.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -21,45 +21,110 @@ def extract_labels(df: pd.DataFrame, task_level: str, label_cols: List[str]): - """Extracts labels in label_cols from dataframe df for a given task_level. - Returns a list of numpy arrays converted to the correct shape. Multiple - targets are concatenated for each graph. + """Extracts the labels specified by label_cols from dataframe df. + If task_level is "graph", each entry in df must be a single numeric value, + and this function returns a single, 2D numpy array containing the data. 
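The new `torch_enum_to_dtype` helper simply indexes an ordered list of torch scalar types (matching torch's internal enum values), returns dtypes passed as `torch.dtype` objects unchanged, and returns `None` for out-of-range integers. A quick sanity sketch, assuming the function is importable from `graphium.data.dataset` as in this diff:

import torch
from graphium.data.dataset import torch_enum_to_dtype

assert torch_enum_to_dtype(5) is torch.float16
assert torch_enum_to_dtype(6) is torch.float32
assert torch_enum_to_dtype(7) is torch.float64
assert torch_enum_to_dtype(torch.int64) is torch.int64  # torch.dtype inputs pass through
assert torch_enum_to_dtype(17) is None                  # out-of-range integers map to None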
+ If task_level is something else, each entry in df must be a numpy array, + python list, or single numeric value, and this function returns both a 2D + numpy array of data and a 1D numpy array of integers indicating the row + number in the first array where each molecule's data starts, with an extra + integer at the end that should equal the total number of rows in the first + array. The first array can have type float16, float32, or float64, + depending on the largest precision of input data, and arrays of varying + sizes across columns are padded with nan values, so that a single molecule + occupies a fixed number of rows and len(label_cols) columns. """ - def unpack(graph_data): - graph_data = pd.to_numeric(graph_data, errors="coerce") - if isinstance(graph_data, str): - graph_data_list = ast.literal_eval(graph_data) - return np.array(graph_data_list) - elif isinstance(graph_data, (int, float)): - return np.array([graph_data]) - elif isinstance(graph_data, list): - return np.array(graph_data) - elif isinstance(graph_data, np.ndarray): - if len(graph_data.shape) == 0: - graph_data = np.expand_dims(graph_data, 0) - if graph_data.shape[0] == 0: - graph_data = np.array([np.nan]) - # TODO: Warning - return graph_data - else: - raise ValueError( - f"Graph data should be one of str, float, int, list, np.ndarray, got {type(graph_data)}" - ) - - def unpack_column(data: pd.Series): - return data.apply(unpack) - - def merge_columns(data: pd.Series): - data = data.to_list() - data = [np.array([np.nan]) if not isinstance(d, np.ndarray) and math.isnan(d) else d for d in data] - padded_data = itertools.zip_longest(*data, fillvalue=np.nan) - data = np.stack(list(padded_data), 1).T - return data - - unpacked_df: pd.DataFrame = df[label_cols].apply(unpack_column) - output = unpacked_df.apply(merge_columns, axis="columns").to_list() + num_rows = df.shape[0] + num_cols = len(label_cols) if task_level == "graph": - return np.concatenate(output) - return output + output = np.empty((num_rows, num_cols), dtype=np.float64) + + for col_index, col in enumerate(label_cols): + for i, v in enumerate(df[col]): + if isinstance(v, float): + output[i, col_index] = v + continue + + v = pd.to_numeric(v, errors="coerce") + + if isinstance(v, (int, float)): + output[i, col_index] = v + + else: + raise ValueError(f"Graph data should be one of float or int, got {type(v)}") + + return output, None + + # First, find the max length of each row (likely the number of nodes or edges) + # +1 is for the cumulative sum below + begin_offsets = np.zeros((num_rows + 1,), dtype=np.int64) + max_type = np.float16 + for col in label_cols: + for i, v in enumerate(df[col]): + if not isinstance(v, np.ndarray) and not isinstance(v, (int, float, list)): + v = pd.to_numeric(v, errors="coerce") + length = 0 + if isinstance(v, np.ndarray): + if len(v.shape) == 1: + length = v.shape[0] + elif len(v.shape) == 0: + length = 0 + else: + raise ValueError( + f"Graph data should be 1D np.ndarray, got ndarray with {len(v.shape)} dimensions" + ) + dtype = v.dtype + if dtype == np.float64: + max_type = np.float64 + elif dtype == np.float32 and max_type == np.float16: + max_type = np.float32 + elif isinstance(v, (int, float)): + length = 1 + max_type = np.float64 + elif isinstance(v, list): + length = len(v) + max_type = np.float64 + else: + raise ValueError(f"Graph data should be one of float, int, list, np.ndarray, got {type(v)}") + # The +1 is so that the cumulative sum below gives the beginning offsets + begin_offsets[i + 1] = max(begin_offsets[i + 1], 
length) + + begin_offsets = np.cumsum(begin_offsets) + full_num_rows = begin_offsets[-1] + + output = np.empty((full_num_rows, num_cols), dtype=max_type) + + # Now, fill in the values + for col_index, col in enumerate(label_cols): + for i, v in enumerate(df[col]): + full_row = begin_offsets[i] + end_row = begin_offsets[i + 1] + + if not isinstance(v, np.ndarray): + v = pd.to_numeric(v, errors="coerce") + + if isinstance(v, np.ndarray): + length = v.shape[0] if len(v.shape) == 1 else 0 + for j in range(length): + output[full_row + j, col_index] = v[j] + + elif isinstance(v, (int, float)): + length = 1 + output[full_row, col_index] = v + + elif isinstance(v, list): + length = len(v) + for j in range(length): + output[full_row + j, col_index] = v[j] + + else: + raise ValueError(f"Graph data should be one of float, int, list, np.ndarray, got {type(v)}") + + # Fill the rest of the rows in the column with nan + if full_row + length != end_row: + for row in range(full_row + length, end_row): + output[row, col_index] = np.nan + + return output, begin_offsets diff --git a/graphium/data/normalization.py b/graphium/data/normalization.py index 994e8939b..e57a9bcc8 100644 --- a/graphium/data/normalization.py +++ b/graphium/data/normalization.py @@ -57,6 +57,12 @@ def __init__( self.data_mean = None self.data_std = None + def set_statistics(self, data_min, data_max, data_mean, data_std): + self.data_min = data_min + self.data_max = data_max + self.data_mean = data_mean + self.data_std = data_std + def calculate_statistics(self, array): """ Saves the normalization parameters (e.g. mean and variance) to the object. diff --git a/graphium/data/utils.py b/graphium/data/utils.py index aa5151a90..5136ce60e 100644 --- a/graphium/data/utils.py +++ b/graphium/data/utils.py @@ -25,7 +25,6 @@ import graphium from torch_geometric.data import Data -from graphium.features.featurizer import GraphDict GRAPHIUM_DATASETS_BASE_URL = "gs://graphium-public/datasets" GRAPHIUM_DATASETS = { @@ -129,7 +128,7 @@ def get_keys(pyg_data): return pyg_data.keys() -def found_size_mismatch(task: str, features: Union[Data, GraphDict], labels: np.ndarray, smiles: str) -> bool: +def found_size_mismatch(task: str, features: Data, labels: np.ndarray, smiles: str) -> bool: """Check if a size mismatch exists between features and labels with respect to node/edge/nodepair. Args: diff --git a/graphium/features/README.md b/graphium/features/README.md index 4188948fe..14b123106 100644 --- a/graphium/features/README.md +++ b/graphium/features/README.md @@ -7,8 +7,5 @@ ## What is in this folder? - ✅ `featurizer.py`: featurization code for the molecules, adding node, edge and graph features to the mol object -- `nmp.py`: check if a string can be converted to float, helper function for featurization -- `positional_encoding.py`: code for computing all raw positional and structural encoding of the graph, see `graph_positional_encoder` function -- `properties.py`: code for computing properties of the molecule -- `rw.py`: code for computing random walk positional encoding -- `spectral.py`: code for computing the spectral positional encoding such as the Laplacian eigenvalues and eigenvectors \ No newline at end of file + +Positional encodings, and atom/bond features (`nmp.py`) have been moved to the `/graphium_cpp` folder. 
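As a concrete illustration of the two return conventions of `extract_labels` above, a small sketch with placeholder column names and values, assuming the function is importable from `graphium.data.multilevel_utils` as in this diff:

import numpy as np
import pandas as pd
from graphium.data.multilevel_utils import extract_labels

# Graph-level task: one numeric value per molecule -> 2D array, offsets is None
df_graph = pd.DataFrame({"homo": [1.0, 2.5], "lumo": [0.3, 0.7]})
values, offsets = extract_labels(df_graph, task_level="graph", label_cols=["homo", "lumo"])
# values.shape == (2, 2); offsets is None

# Node-level task: per-molecule arrays of different lengths are stacked, padded with NaN
df_node = pd.DataFrame({"charges": [np.array([0.1, -0.2, 0.0]), np.array([0.4, 0.4])]})
values, offsets = extract_labels(df_node, task_level="node", label_cols=["charges"])
# values.shape == (5, 1); offsets == array([0, 3, 5]);
# molecule i occupies rows offsets[i]:offsets[i+1] of values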
\ No newline at end of file diff --git a/graphium/features/__init__.py b/graphium/features/__init__.py index 40984a2a4..e9cb41d1f 100644 --- a/graphium/features/__init__.py +++ b/graphium/features/__init__.py @@ -1,9 +1,2 @@ -from .featurizer import get_mol_atomic_features_onehot -from .featurizer import get_mol_atomic_features_float -from .featurizer import get_mol_edge_features -from .featurizer import mol_to_adj_and_features -from .featurizer import mol_to_graph_dict from .featurizer import mol_to_graph_signature -from .featurizer import GraphDict from .featurizer import mol_to_pyggraph -from .featurizer import to_dense_array diff --git a/graphium/features/commute.py b/graphium/features/commute.py deleted file mode 100644 index a7cea768c..000000000 --- a/graphium/features/commute.py +++ /dev/null @@ -1,69 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any - -import numpy as np - -from scipy.sparse import spmatrix, issparse -from scipy.linalg import pinv - - -def compute_commute_distances( - adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Dict[str, Any] -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute avg. commute time/distance between nodepairs. This is the avg. number of steps a random walker, starting - at node i, will take before reaching a given node j for the first time, and then return to node i. - - Reference: Saerens et al. "The principal components analysis of a graph, and its relationships to spectral clustering." ECML. 2004. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - Returns: - dist [num_nodes, num_nodes]: 2D array with avg. commute distances between nodepairs - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "nodepair" - - if "commute" in cache: - dist = cache["commute"] - - else: - if issparse(adj): - adj = adj.toarray() - - volG = adj.sum() - - if "pinvL" in cache: - pinvL = cache["pinvL"] - - else: - L = np.diagflat(np.sum(adj, axis=1)) - adj - pinvL = pinv(L) - cache["pinvL"] = pinvL - - dist = volG * np.asarray( - [ - [pinvL[i, i] + pinvL[j, j] - 2 * pinvL[i, j] for j in range(num_nodes)] - for i in range(num_nodes) - ] - ) - cache["commute"] = dist - - return dist, base_level, cache diff --git a/graphium/features/electrostatic.py b/graphium/features/electrostatic.py deleted file mode 100644 index 58dc115f7..000000000 --- a/graphium/features/electrostatic.py +++ /dev/null @@ -1,58 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. 
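With the Python featurization helpers removed, `graphium.features` now only exposes `mol_to_graph_signature` and `mol_to_pyggraph`. A minimal usage sketch, assuming `mol_to_pyggraph` still accepts a SMILES string the way the removed `smiles_transformer` call did, and reusing the example SMILES from the removed FakeDataModule code:

from graphium.features import mol_to_pyggraph

graph = mol_to_pyggraph("C1N2C3C4C5OC13C2C45")  # SMILES from the removed example data
print(graph.num_nodes, graph.num_edges)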
-Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any - -import numpy as np - -from scipy.linalg import pinv -from scipy.sparse import spmatrix, issparse - - -def compute_electrostatic_interactions( - adj: Union[np.ndarray, spmatrix], cache: Dict[str, Any] -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute electrostatic interaction of nodepairs. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - cache: Dictionary of cached objects - Returns: - electrostatic [num_nodes, num_nodes]: 2D array with electrostatic interactions of node nodepairs - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "nodepair" - - if "electrostatic" in cache: - electrostatic = cache["electrostatic"] - - else: - if "pinvL" in cache: - pinvL = cache["pinvL"] - - else: - if issparse(adj): - adj = adj.toarray() - - L = np.diagflat(np.sum(adj, axis=1)) - adj - pinvL = pinv(L) - cache["pinvL"] = pinvL - - electrostatic = pinvL - np.diag(pinvL) # This means that the "ground" is set to any given atom - cache["electrostatic"] = electrostatic - - return electrostatic, base_level, cache diff --git a/graphium/features/featurizer.py b/graphium/features/featurizer.py index 8d8e18159..21d874de1 100644 --- a/graphium/features/featurizer.py +++ b/graphium/features/featurizer.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. 
-------------------------------------------------------------------------------- """ @@ -23,966 +23,27 @@ from torch_geometric.data import Data -from rdkit import Chem -import datamol as dm +import graphium_cpp -from graphium.features import nmp -from graphium.utils.tensor import one_of_k_encoding -from graphium.features.positional_encoding import get_all_positional_encodings +# These are the integers that correspond with the torch data types in C++ +NP_DTYPE_TO_TORCH_INT = {np.float16: 5, np.float32: 6, np.float64: 7} -def to_dense_array(array: np.ndarray, dtype: str = None) -> np.ndarray: - r""" - Assign the node data - Parameters: - array: The array to convert to dense - dtype: The dtype of the array - Returns: - The dense array - """ - if array is not None: - if issparse(array): - if array.dtype == np.float16: # float16 doesn't support `todense` - array = array.astype(np.float32) - array = array.todense() - - if dtype is not None: - array = array.astype(dtype) - return array - - -def to_dense_tensor(tensor: Tensor, dtype: str = None) -> Tensor: - r""" - Assign the node data - Parameters: - array: The array to convert to dense - dtype: The dtype of the array - Returns: - The dense array - """ - if tensor is not None: - if tensor.is_sparse: - tensor = tensor.todense() - if dtype is not None: - tensor = tensor.to(dtype) - return tensor - - -def _mask_nans_inf(mask_nan: Optional[str], array: np.ndarray, array_name: str) -> np.ndarray: - r""" - mask the NaNs in the array - Parameters: - mask_nan: How to mask the NaNs - array: The array to mask - array_name: The name of the array - Returns: - The masked array - """ - if (mask_nan is None) or (array is None): - return array - - new_array = array - if issparse(new_array): - new_array = new_array.data - nans = ~np.isfinite(new_array) - - # Mask the NaNs - if nans.any(): - msg = f"There are {np.sum(nans)} NaNs in `{array_name}`" - if mask_nan == "raise": - raise ValueError(msg) - elif mask_nan == "warn": - logger.warning(msg) - else: - new_array[nans] = mask_nan - if issparse(array): - array.data = new_array - new_array = array - return new_array - - -def get_mol_atomic_features_onehot(mol: dm.Mol, property_list: List[str]) -> Dict[str, Tensor]: - r""" - Get the following set of features for any given atom - - * One-hot representation of the atom - * One-hot representation of the atom degree - * One-hot representation of the atom implicit valence - * One-hot representation of the the atom hybridization - * Whether the atom is aromatic - * The atom's formal charge - * The atom's number of radical electrons - - Additionally, the following features can be set, depending on the value of input Parameters - - * One-hot representation of the number of hydrogen atom in the the current atom neighborhood if `explicit_H` is false - * One-hot encoding of the atom chirality, and whether such configuration is even possible - - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of integer atomic properties to get from the molecule. - The integer values are converted to a one-hot vector. - Callables are not supported by this function. - - Accepted properties are: - - - "atomic-number" - - "degree" - - "valence", "total-valence" - - "implicit-valence" - - "hybridization" - - "chirality" - - "phase" - - "type" - - "group" - - "period" - - Returns: - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N, OH). 
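The `NP_DTYPE_TO_TORCH_INT` constants above are the integer values of torch's C++ scalar-type enum and line up with the `torch_enum_to_dtype` mapping added to `graphium/data/dataset.py` in this same change. A quick consistency check, assuming both names are importable from the modules shown in this diff:

import numpy as np
import torch
from graphium.data.dataset import torch_enum_to_dtype
from graphium.features.featurizer import NP_DTYPE_TO_TORCH_INT

for np_dtype, enum_value in NP_DTYPE_TO_TORCH_INT.items():
    # e.g. np.float32 -> 6 -> torch.float32
    assert torch_enum_to_dtype(enum_value) == torch.from_numpy(np.zeros(1, dtype=np_dtype)).dtype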
N is the number of atoms - in ``mol`` and OH the lenght of the one-hot encoding. - - """ - - prop_dict = {} - - for prop in property_list: - prop = prop.lower() - prop_name = prop - - property_array = [] - for ii, atom in enumerate(mol.GetAtoms()): - if prop in ["atomic-number"]: - one_hot = one_of_k_encoding(atom.GetSymbol(), nmp.ATOM_LIST) - elif prop in ["degree"]: - one_hot = one_of_k_encoding(atom.GetDegree(), nmp.ATOM_DEGREE_LIST) - elif prop in ["valence", "total-valence"]: - prop_name = "valence" - one_hot = one_of_k_encoding(atom.GetTotalValence(), nmp.VALENCE) - elif prop in ["implicit-valence"]: - one_hot = one_of_k_encoding(atom.GetImplicitValence(), nmp.VALENCE) - elif prop in ["hybridization"]: - one_hot = one_of_k_encoding(atom.GetHybridization(), nmp.HYBRIDIZATION_LIST) - elif prop in ["chirality"]: - try: - one_hot = one_of_k_encoding(atom.GetProp("_CIPCode"), nmp.CHIRALITY_LIST) - one_hot.append(int(atom.HasProp("_ChiralityPossible"))) - except: - one_hot = [0, 0, int(atom.HasProp("_ChiralityPossible"))] - elif prop in "phase": - one_hot = one_of_k_encoding(nmp.PHASE[atom.GetAtomicNum() - 1], nmp.PHASE_SET) - elif prop in "type": - one_hot = one_of_k_encoding(nmp.TYPE[atom.GetAtomicNum() - 1], nmp.TYPE_SET) - elif prop in "group": - one_hot = one_of_k_encoding(nmp.GROUP[atom.GetAtomicNum() - 1], nmp.GROUP_SET) - elif prop in "period": - one_hot = one_of_k_encoding(nmp.PERIOD[atom.GetAtomicNum() - 1], nmp.PERIOD_SET) - else: - raise ValueError(f"Unsupported property `{prop}`") - - property_array.append(np.asarray(one_hot, dtype=np.float16)) - - prop_dict[prop_name] = np.stack(property_array, axis=0) - - return prop_dict - - -def get_mol_conformer_features( - mol: dm.Mol, - property_list: Union[List[str], List[Callable]], - mask_nan: Optional[Union[float, str]] = None, -) -> Dict[str, np.ndarray]: - r"""obtain the conformer features of a molecule - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of conformer property to get from the molecule - Accepted properties are: - - "positions_3d" - - Returns: - prop_dict: a dictionary where the element of ``property_list`` are the keys - """ - prop_dict = {} - has_conf = True - - try: - mol.GetConformer() - except: - has_conf = False - # * currently only accepts "positions_3d", raise errors otherwise - for prop in property_list: - if isinstance(prop, str): - if prop in ["positions_3d"]: # locating 3d conformer coordinates - if not has_conf: - positions = np.full((mol.GetNumAtoms(), 3), float("nan"), dtype=np.float16) - else: - positions = [[], [], []] - for i in range(mol.GetNumAtoms()): - pos = mol.GetConformer().GetAtomPosition(i) - positions[0].append(pos.x) - positions[1].append(pos.y) - positions[2].append(pos.z) - positions = np.asarray(positions, dtype=np.float16).T - prop_dict[prop] = positions - else: - raise ValueError( - str(prop) + " is not currently supported as a conformer property in `property_list`" - ) - else: - raise ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`") - - prop_dict[prop] = _mask_nans_inf(mask_nan, prop_dict[prop], prop) - - return prop_dict - - -def get_mol_atomic_features_float( - mol: dm.Mol, - property_list: Union[List[str], List[Callable]], - offset_carbon: bool = True, - mask_nan: Union[str, float, type(None)] = "raise", -) -> Dict[str, np.ndarray]: - r""" - Get a dictionary of floating-point arrays of atomic properties. 
- To ensure all properties are at a similar scale, some of the properties - are divided by a constant. - - There is also the possibility of offseting by the carbon value using - the `offset_carbon` parameter. - - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of atomic properties to get from the molecule, such as 'atomic-number', - 'mass', 'valence', 'degree', 'electronegativity'. - Some elements are divided by a factor to avoid feature explosion. - - Accepted properties are: - - - "atomic-number" - - "mass", "weight" - - "valence", "total-valence" - - "implicit-valence" - - "hybridization" - - "chirality" - - "hybridization" - - "aromatic" - - "ring", "in-ring" - - "min-ring" - - "max-ring" - - "num-ring" - - "degree" - - "radical-electron" - - "formal-charge" - - "vdw-radius" - - "covalent-radius" - - "electronegativity" - - "ionization", "first-ionization" - - "melting-point" - - "metal" - - "single-bond" - - "aromatic-bond" - - "double-bond" - - "triple-bond" - - "is-carbon" - - "group" - - "period" - - offset_carbon: - Whether to subract the Carbon property from the desired atomic property. - For example, if we want the mass of the Lithium (6.941), the mass of the - Carbon (12.0107) will be subracted, resulting in a value of -5.0697 - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan or inf in the featurization - - "warn": Raise a warning when there is a nan or inf in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans or inf by the specified value - - Returns: - - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N,). N is the number of atoms - in ``mol``. 
- - """ - - periodic_table = Chem.GetPeriodicTable() - prop_dict = {} - C = Chem.Atom("C") - C_num = C.GetAtomicNum() - offC = bool(offset_carbon) - atom_list = list(mol.GetAtoms()) - - for prop in property_list: - prop_name = None - - property_array = np.zeros(mol.GetNumAtoms(), dtype=np.float16) - for ii, atom in enumerate(atom_list): - val = None - atomic_num = atom.GetAtomicNum() - - if isinstance(prop, str): - prop = prop.lower() - prop_name = prop - - if prop in ["atomic-number"]: - val = (atomic_num - (offC * C_num)) / 5 - elif prop in ["mass", "weight"]: - prop_name = "mass" - val = (atom.GetMass() - (offC * C.GetMass())) / 10 - elif prop in ["valence", "total-valence"]: - prop_name = "valence" - val = atom.GetTotalValence() - (offC * 4) - elif prop in ["implicit-valence"]: - val = atom.GetImplicitValence() - elif prop in ["hybridization"]: - val = atom.GetHybridization() - elif prop in ["chirality"]: - val = (atom.GetProp("_CIPCode") == "R") if atom.HasProp("_CIPCode") else 2 - elif prop in ["hybridization"]: - val = atom.GetHybridization() - elif prop in ["aromatic"]: - val = atom.GetIsAromatic() - elif prop in ["ring", "in-ring"]: - prop_name = "in-ring" - val = atom.IsInRing() - elif prop in ["min-ring"]: - ring_info = mol.GetRingInfo() - val = ring_info.MinAtomRingSize(atom.GetIdx()) - elif prop in ["max-ring"]: - rings = mol.GetRingInfo().AtomRings() - val = 0 - for ring in rings: - if atom.GetIdx() in ring: - if len(ring) > val: - val = len(ring) - elif prop in ["num-ring"]: - ring_info = mol.GetRingInfo() - val = ring_info.NumAtomRings(atom.GetIdx()) - elif prop in ["degree"]: - val = atom.GetTotalDegree() - (offC * 2) - elif prop in ["radical-electron"]: - val = atom.GetNumRadicalElectrons() - elif prop in ["formal-charge"]: - val = atom.GetFormalCharge() - elif prop in ["vdw-radius"]: - val = periodic_table.GetRvdw(atom.GetAtomicNum()) - offC * periodic_table.GetRvdw(C_num) - elif prop in ["covalent-radius"]: - val = periodic_table.GetRcovalent(atomic_num) - offC * periodic_table.GetRcovalent(C_num) - elif prop in ["electronegativity"]: - val = ( - nmp.ELECTRONEGATIVITY[atom.GetAtomicNum() - 1] - - offC * nmp.ELECTRONEGATIVITY[C_num - 1] - ) - elif prop in ["ionization", "first-ionization"]: - prop_name = "ionization" - val = (nmp.FIRST_IONIZATION[atomic_num - 1] - offC * nmp.FIRST_IONIZATION[C_num - 1]) / 5 - elif prop in ["melting-point"]: - val = (nmp.MELTING_POINT[atomic_num - 1] - offC * nmp.MELTING_POINT[C_num - 1]) / 200 - elif prop in ["metal"]: - val = nmp.METAL[atomic_num - 1] - elif prop in "group": - val = float(nmp.GROUP[atomic_num - 1]) - offC * float(nmp.GROUP[C_num - 1]) - elif prop in "period": - val = float(nmp.PERIOD[atomic_num - 1]) - offC * float(nmp.PERIOD[C_num - 1]) - elif "-bond" in prop: - bonds = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] - if prop in ["single-bond"]: - val = len([bond == 1 for bond in bonds]) - elif prop in ["aromatic-bond"]: - val = len([bond == 1.5 for bond in bonds]) - elif prop in ["double-bond"]: - val = len([bond == 2 for bond in bonds]) - elif prop in ["triple-bond"]: - val = len([bond == 3 for bond in bonds]) - else: - raise ValueError(f"{prop} is not a correct bond.") - val -= offC * 1 - elif prop in ["is-carbon"]: - val = atom.GetAtomicNum() == 6 - val -= offC * 1 - else: - raise ValueError(f"Unsupported property `{prop}`") - - elif callable(prop): - prop_name = str(prop) - val = prop(atom) - else: - ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`") - - if val is 
None: - raise ValueError("val is undefined.") - - property_array[ii] = val - - if prop_name is None: - raise ValueError("prop_name is undefined.") - - # Mask the NaNs - prop_dict[prop_name] = _mask_nans_inf(mask_nan, property_array, "atom featurization") - - return prop_dict - - -def get_simple_mol_conformer(mol: dm.Mol) -> Union[Chem.rdchem.Conformer, None]: - r""" - If the molecule has a conformer, then it will return the conformer at idx `0`. - Otherwise, it generates a simple molecule conformer using `rdkit.Chem.rdDistGeom.EmbedMolecule` - and returns it. This is meant to be used in simple functions like `GetBondLength`, - not in functions requiring complex 3D structure. - - Parameters: - - mol: Rdkit Molecule - - Returns: - conf: A conformer of the molecule, or `None` if it fails - """ - - val = 0 - if mol.GetNumConformers() == 0: - val = Chem.rdDistGeom.EmbedMolecule(mol) - if val == -1: - val = Chem.rdDistGeom.EmbedMolecule( - mol, - enforceChirality=False, - ignoreSmoothingFailures=True, - useBasicKnowledge=True, - useExpTorsionAnglePrefs=True, - forceTol=0.1, - ) - - if val == -1: - conf = None - logger.warn("Couldn't compute conformer for molecule `{}`".format(Chem.MolToSmiles(mol))) - else: - conf = mol.GetConformer(0) - - return conf - - -def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: dm.Mol) -> float: - r""" - Estimate the bond length between atoms by looking at the estimated atomic radius - that depends both on the atom type and the bond type. The resulting bond-length is - then the sum of the radius. - - Keep in mind that this function only provides an estimate of the bond length and not - the true one based on a conformer. The vast majority od estimated bond lengths will - have an error below 5% while some bonds can have an error up to 20%. This function - is mostly useful when conformer generation fails for some molecules, or for - increased computation speed. - - Parameters: - bond: The bond to measure its lenght - mol: The molecule containing the bond (used to get neighbouring atoms) - - Returns: - bond_length: The bond length in Angstrom, typically a value around 1-2. 
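For a rough worked example of the estimate described above (radius value approximate and for illustration only; the real lookup goes through the `nmp.BOND_RADIUS_*` tables):

```python
# Approximate single-bond covalent radius of carbon, in Angstrom (illustrative value)
carbon_single_radius = 0.75
estimated_cc_length = 2 * carbon_single_radius  # ~1.5 A
# The experimental average C-C single-bond length is ~1.54 A, so the error here is
# a few percent, consistent with the typical <5% error quoted above.
```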
- - """ - - # Get the atoms connected by the bond - idx1 = bond.GetBeginAtomIdx() - idx2 = bond.GetEndAtomIdx() - atom1 = mol.GetAtomWithIdx(idx1).GetAtomicNum() - atom2 = mol.GetAtomWithIdx(idx2).GetAtomicNum() - bond_type = bond.GetBondType() - - # Get single bond atomic radius - if bond_type == Chem.rdchem.BondType.SINGLE: - rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1]] - # Get double bond atomic radius - elif bond_type == Chem.rdchem.BondType.DOUBLE: - rad1 = [nmp.BOND_RADIUS_DOUBLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_DOUBLE[atom2 - 1]] - # Get triple bond atomic radius - elif bond_type == Chem.rdchem.BondType.TRIPLE: - rad1 = [nmp.BOND_RADIUS_TRIPLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_TRIPLE[atom2 - 1]] - # Get average of single bond and double bond atomic radius - elif bond_type == Chem.rdchem.BondType.AROMATIC: - rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1], nmp.BOND_RADIUS_DOUBLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1], nmp.BOND_RADIUS_DOUBLE[atom2 - 1]] - - # Average the bond lengths, while ignoring nans in case some missing value - rad1_float = [elem for elem in rad1 if elem is not None] - rad2_float = [elem for elem in rad2 if elem is not None] - - if len(rad1_float) > 0: - rad1_float = sum(rad1_float) / len(rad1_float) - else: - rad1_float = float(nmp.BOND_RADIUS_SINGLE[atom1 - 1]) - - if len(rad2_float) > 0: - rad2_float = sum(rad2_float) / len(rad2_float) - else: - rad2_float = float(nmp.BOND_RADIUS_SINGLE[atom2 - 1]) - - bond_length = rad1_float + rad2_float - return bond_length - - -def get_mol_edge_features( - mol: dm.Mol, property_list: List[str], mask_nan: Union[str, float, type(None)] = "raise" -) -> Dict[str, np.ndarray]: - r""" - Get the following set of features for any given bond - See `graphium.features.nmp` for allowed values in one hot encoding - - * One-hot representation of the bond type. Note that you should not kekulize your - molecules, if you expect this to take aromatic bond into account. - * Bond stereo type, following CIP classification - * Whether the bond is conjugated - * Whether the bond is in a ring - - Parameters: - mol: rdkit.Chem.Molecule - the molecule of interest - - property_list: - A list of edge properties to return for the given molecule. - Accepted properties are: - - - "bond-type-onehot" - - "bond-type-float" - - "stereo" - - "in-ring" - - "conjugated" - - "conformer-bond-length" (might cause problems with complex molecules) - - "estimated-bond-length" - - Returns: - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N,). N is the number of atoms - in ``mol``. 
- - """ - - prop_dict = {} - - # Compute features for each bond - num_bonds = mol.GetNumBonds() - for prop in property_list: - property_array = [] - for ii in range(num_bonds): - prop = prop.lower() - bond = mol.GetBondWithIdx(ii) - - if prop in ["bond-type-onehot"]: - encoding = one_of_k_encoding(bond.GetBondType(), nmp.BOND_TYPES) - elif prop in ["bond-type-float"]: - encoding = [bond.GetBondTypeAsDouble()] - elif prop in ["stereo"]: - encoding = one_of_k_encoding(bond.GetStereo(), nmp.BOND_STEREO) - elif prop in ["in-ring"]: - encoding = [bond.IsInRing()] - elif prop in ["conjugated"]: - encoding = [bond.GetIsConjugated()] - elif prop in ["conformer-bond-length"]: - conf = get_simple_mol_conformer(mol) - if conf is not None: - idx1 = bond.GetBeginAtomIdx() - idx2 = bond.GetEndAtomIdx() - encoding = [Chem.rdMolTransforms.GetBondLength(conf, idx1, idx2)] - else: - encoding = [0] - elif prop in ["estimated-bond-length"]: - encoding = [get_estimated_bond_length(bond, mol)] - - else: - raise ValueError(f"Unsupported property `{prop}`") - - property_array.append(np.asarray(encoding, dtype=np.float16)) - - if num_bonds > 0: - property_array = np.stack(property_array, axis=0) - # Mask the NaNs - prop_dict[prop] = _mask_nans_inf(mask_nan, property_array, "edge property") - else: - # Add an empty vector with the right shape - arr_len = 1 - if prop in ["bond-type-onehot"]: - arr_len = len(nmp.BOND_TYPES) + 1 - elif prop in ["stereo"]: - arr_len = len(nmp.BOND_STEREO) + 1 - - prop_dict[prop] = np.zeros((0, arr_len)) - - return prop_dict - - -def mol_to_adj_and_features( - mol: Union[str, dm.Mol], - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], - conformer_property_list: List[str] = [], - edge_property_list: List[str] = [], - add_self_loop: bool = False, - explicit_H: bool = False, - use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, - dtype: np.dtype = np.float16, - mask_nan: Union[str, float, type(None)] = "raise", -) -> Union[ - coo_matrix, - Union[Tensor, None], - Union[Tensor, None], - Dict[str, Tensor], - Union[Tensor, None], - Dict[str, Tensor], -]: - r""" - Transforms a molecule into an adjacency matrix representing the molecular graph - and a set of atom and bond features. - - It also returns the positional encodings associated to the graph. - - Parameters: - - mol: - The molecule to be converted - - atom_property_list_onehot: - List of the properties used to get one-hot encoding of the atom type, - such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` - - atom_property_list_float: - List of the properties used to get floating-point encoding of the atom type, - such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` - - conformer_property_list: - list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d" - - edge_property_list: - List of the properties used to encode the edges, such as the edge type - and the stereo type. - - add_self_loop: - Whether to add a value of `1` on the diagonal of the adjacency matrix. - - explicit_H: - Whether to consider the Hydrogens explicitely. If `False`, the hydrogens - are implicit. 
- - use_bonds_weights: - Whether to use the floating-point value of the bonds in the adjacency matrix, - such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 - - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. - - dtype: - The torch data type used to build the graph - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan or inf in the featurization - - "warn": Raise a warning when there is a nan or inf in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans or inf by the specified value - Returns: - - adj: - torch coo sparse adjacency matrix of the molecule - - ndata: - Concatenated node data of the atoms, based on the properties from - `atom_property_list_onehot` and `atom_property_list_float`. - If no properties are given, it returns `None` - - edata: - Concatenated node edge of the molecule, based on the properties from - `edge_property_list`. - If no properties are given, it returns `None` - - pe_dict: - Dictionary of all positional encodings. Current supported keys: - - - "pos_enc_feats_sign_flip": - Node positional encoding that requires augmentation via sign-flip. - For example, eigenvectors of the Laplacian are ambiguous to the - sign and are returned here. - - - "pos_enc_feats_no_flip": - Node positional encoding that requires does not use sign-flip. - For example, distance from centroid are returned here. - - - "rwse": - Node structural encoding corresponding to the diagonal of the random - walk matrix - - conf_dict: - contains the 3d positions of a conformer of the molecule or 0s if none is found - - """ - - if isinstance(mol, str): - mol = dm.to_mol(mol, ordered=True) - - # Add or remove explicit hydrogens - if explicit_H: - mol = Chem.AddHs(mol) - else: - mol = Chem.RemoveHs(mol) - - num_nodes = mol.GetNumAtoms() - - adj = mol_to_adjacency_matrix( - mol, use_bonds_weights=use_bonds_weights, add_self_loop=add_self_loop, dtype=dtype - ) - - # Get the node features - atom_features_onehot = get_mol_atomic_features_onehot(mol, atom_property_list_onehot) - atom_features_float = get_mol_atomic_features_float(mol, atom_property_list_float, mask_nan=mask_nan) - conf_dict = get_mol_conformer_features(mol, conformer_property_list, mask_nan=mask_nan) - ndata = list(atom_features_float.values()) + list(atom_features_onehot.values()) - ndata = [d[:, np.newaxis] if d.ndim == 1 else d for d in ndata] - - if len(ndata) > 0: - ndata = np.concatenate(ndata, axis=1).astype(dtype=dtype) - else: - ndata = None - - # Get the edge features - edge_features = get_mol_edge_features(mol, edge_property_list, mask_nan=mask_nan) - edata = list(edge_features.values()) - edata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in edata] - if len(edata) > 0: - edata = np.concatenate(edata, axis=1).astype(dtype=dtype) - else: - edata = None - - # Get all positional encodings - pe_dict = get_all_positional_encodings(adj, num_nodes, pos_encoding_as_features) - - # Mask the NaNs - for pe_key, pe_val in pe_dict.items(): - pe_val = np.asarray(pe_val, dtype=dtype) - pe_dict[pe_key] = _mask_nans_inf(mask_nan, pe_val, pe_key) - - return adj, ndata, edata, pe_dict, conf_dict - - -def mol_to_adjacency_matrix( - mol: dm.Mol, - use_bonds_weights: bool = False, - add_self_loop: bool 
= False, - dtype: np.dtype = np.float32, -) -> coo_matrix: - r""" - Convert a molecule to a sparse adjacency matrix, as a torch Tensor. - Instead of using the Rdkit `GetAdjacencyMatrix()` method, this method - uses the bond ordering from the molecule object, which is the same as - the bond ordering in the bond features. - - Warning: - Do not use `Tensor.coalesce()` on the returned adjacency matrix, as it - will change the ordering of the bonds. - - Args: - mol: A molecule in the form of a SMILES string or an RDKit molecule object. - - use_bonds_weights: - If `True`, the adjacency matrix will contain the bond type as the - value of the edge. If `False`, the adjacency matrix will contain - `1` as the value of the edge. - - add_self_loop: - If `True`, the adjacency matrix will contain a self-loop for each - node. - - dtype: - The data type used to build the graph - - Returns: - adj: - coo sparse adjacency matrix of the molecule - """ - - # Get the indices for the adjacency matrix, and the bond value - adj_idx, adj_val = [], [] - for bond in mol.GetBonds(): - adj_idx.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) - adj_idx.append([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) - if use_bonds_weights: - val = nmp.BOND_TYPES[bond.GetBondType()] - else: - val = 1 - adj_val.extend([val, val]) - - # Convert to torch coo sparse tensor - if len(adj_val) > 0: # ensure tensor is not empty - adj = coo_matrix( - (torch.as_tensor(adj_val), torch.as_tensor(adj_idx).T.reshape(2, -1)), - shape=(mol.GetNumAtoms(), mol.GetNumAtoms()), - dtype=dtype, - ) - else: - # Special case for molecules with one atom - adj = coo_matrix(([], np.array([[], []])), shape=(mol.GetNumAtoms(), mol.GetNumAtoms()), dtype=dtype) - - # Add self loops - if add_self_loop: - arange = np.arange(adj.shape[0], dtype=int) - adj[arange, arange] = 1 - return adj - - -class GraphDict(dict): - def __init__( - self, - dic: Dict, - ): - r""" - Store the parameters required to initialize a `pyg.data.Data`, but - as a dictionary to reduce memory consumption. - - Possible keys for the dictionary: - - - adj: A sparse Tensor containing the adjacency matrix - - - ndata: A dictionnary containing different keys and Tensors - associated to the node features. - - - edata: A dictionnary containing different keys and Tensors - associated to the edge features. - - - dtype: The dtype for the floating data. - - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan or inf in the featurization - - "warn": Raise a warning when there is a nan or inf in the featurization - - "None": DEFAULT. 
Don't do anything - - "Floating value": Replace nans or inf by the specified value - """ - default_dic = { - "dtype": np.float16, - "mask_nan": "raise", - } - data = dic.pop("data", {}) - # ndata = dic.pop("ndata", {}) - # edata = dic.pop("edata", {}) - # for key in edata.keys(): - # assert key.startswith("edge_"), f"Edge features must start with 'edge_' but got {key}" - default_dic.update(dic) - default_dic.update(data) - # default_dic.update(ndata) - # default_dic.update(edata) - super().__init__(default_dic) - - @property - def keys(self): - return list(super().keys()) - - @property - def values(self): - return list(super().self.values()) - - def make_pyg_graph(self, **kwargs) -> Data: - """ - Convert the current dictionary of parameters, containing an adjacency matrix with node/edge data - into a `pyg.data.Data` of torch Tensors. - - `**kwargs` can be used to overwrite any parameter from the current dictionary. See `GraphDict.__init__` - for a list of parameters - """ - - num_nodes = self.adj.shape[0] - data_dict = {} - - # Convert the numpy and numpy sparse data to torch - for key, val in self.items(): - if key in ["adj", "dtype", "mask_nan"]: # Skip the parameters - continue - elif isinstance(val, np.ndarray): - # Convert the data to the specified dtype in torch format - val = val.astype(self.dtype) - data_dict[key] = torch.as_tensor(val) - elif issparse(val): - data_dict[key] = torch.as_tensor(val.astype(np.float32).todense()) - # `torch.sparse_coo_tensor` is too slow. Slows down the multiprocessing of features by >3x on 32 cores. - # indices = torch.from_numpy(np.vstack((val.row, val.col)).astype(np.int64)) - # data_dict[key] = torch.sparse_coo_tensor(indices=indices, values=val.data, size=val.shape) - elif isinstance(val, torch.Tensor): - data_dict[key] = val - else: - pass # Skip the other parameters - - # Create the PyG graph object `Data` - edge_index = torch.as_tensor(np.vstack((self.adj.row, self.adj.col))) - edge_weight = torch.as_tensor(self.adj.data) - data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, **data_dict) - return data - - @property - def adj(self): - return self["adj"] - - @property - def dtype(self): - return self["dtype"] - - @property - def mask_nan(self): - return self["mask_nan"] - - @property - def num_nodes(self) -> int: - return self.adj.shape[0] - - @property - def num_edges(self) -> int: - if issparse(self.adj): - return self.adj.nnz - else: - return np.count_nonzero(self.adj) # No division by 2 because edges are counted twice - - -def mol_to_graph_dict( - mol: dm.Mol, - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], +def mol_to_pyggraph( + mol: str, + atom_property_list_onehot: torch.Tensor = torch.tensor(data=[], dtype=torch.int64), + atom_property_list_float: torch.Tensor = torch.tensor(data=[], dtype=torch.int64), conformer_property_list: List[str] = [], - edge_property_list: List[str] = [], + edge_property_list: torch.Tensor = torch.tensor(data=[], dtype=torch.int64), add_self_loop: bool = False, explicit_H: bool = False, use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, + pos_encoding_as_features: Tuple[List[str], torch.Tensor] = ([], torch.tensor(data=[], dtype=torch.int64)), dtype: np.dtype = np.float16, on_error: str = "ignore", mask_nan: Union[str, float, type(None)] = "raise", max_num_atoms: Optional[int] = None, -) -> Union[GraphDict, str]: +) -> Union[Data, str]: r""" Transforms a molecule into an adjacency matrix 
representing the molecular graph and a set of atom and bond features, and re-organizes them into a dictionary @@ -999,12 +60,10 @@ def mol_to_graph_dict( atom_property_list_onehot: List of the properties used to get one-hot encoding of the atom type, such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` atom_property_list_float: List of the properties used to get floating-point encoding of the atom type, such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` conformer_property_list: list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d" @@ -1068,191 +127,83 @@ def mol_to_graph_dict( - "dtype": The numpy dtype for the floating data. """ - input_mol = mol + if not isinstance(mol, str): + raise ValueError( + f"mol_to_pyggraph requires that molecule be received as a string, not type " + str(type(mol)) + ) + try: - if isinstance(mol, str): - mol = dm.to_mol(mol, ordered=True) - if explicit_H: - mol = Chem.AddHs(mol) + has_conformer = "positions_3d" in conformer_property_list + pe_index = 4 + if has_conformer: + pe_index = 5 + mask_nan_value = 0.0 + if mask_nan is None: + mask_nan_style_int = 0 + elif mask_nan == "raise" or mask_nan == "warn": + mask_nan_style_int = 1 else: - mol = Chem.RemoveHs(mol) - num_atoms = mol.GetNumAtoms() - if (max_num_atoms is not None) and (num_atoms > max_num_atoms): - raise ValueError(f"Maximum number of atoms greater than permitted {num_atoms}>{max_num_atoms}") - ( - adj, - ndata, - edata, - pe_dict, - conf_dict, - ) = mol_to_adj_and_features( - mol=mol, - atom_property_list_onehot=atom_property_list_onehot, - atom_property_list_float=atom_property_list_float, - conformer_property_list=conformer_property_list, - edge_property_list=edge_property_list, - add_self_loop=add_self_loop, - explicit_H=explicit_H, - use_bonds_weights=use_bonds_weights, - pos_encoding_as_features=pos_encoding_as_features, - mask_nan=mask_nan, + mask_nan_style_int = 2 + mask_nan_value = float(mask_nan) + tensors, num_nans, nan_tensor_index = graphium_cpp.featurize_smiles( + mol, + atom_property_list_onehot, + atom_property_list_float, + has_conformer, + edge_property_list, + pos_encoding_as_features[1], + True, # duplicate_edges, so that we don't have to duplicate below + add_self_loop, + explicit_H, + use_bonds_weights, + True, # offset_carbon + NP_DTYPE_TO_TORCH_INT[dtype], + mask_nan_style_int, + mask_nan_value, ) + + if num_nans > 0: + if nan_tensor_index == 2: + array_name = "atom featurization" + elif nan_tensor_index == 3: + array_name = "edge property" + elif nan_tensor_index == 4 and has_conformer: + array_name = "positions_3d" + else: + array_name = pos_encoding_as_features[0][nan_tensor_index - pe_index] + msg = f"There are {num_nans} NaNs in `{array_name}`" + if mask_nan == "raise": + raise ValueError(msg) + elif mask_nan == "warn": + logger.warning(msg) + + num_atoms = tensors[2].size(0) + data_dict = {"feat": tensors[2], "edge_feat": tensors[3]} + if has_conformer: + data_dict["positions_3d"] = tensors[4] + for i in range(len(tensors) - pe_index): + data_dict[pos_encoding_as_features[0][i]] = tensors[i + pe_index] + # Create the PyG graph object `Data` + data = Data(edge_index=tensors[0], edge_weight=tensors[1], num_nodes=num_atoms, **data_dict) + return data + except Exception as e: if on_error.lower() == "raise": raise e elif on_error.lower() == "warn": - smiles = input_mol - if isinstance(smiles, dm.Mol): - smiles = 
Chem.MolToSmiles(input_mol) - msg = str(e) + "\nIgnoring following molecule:" + smiles + msg = str(e) + "\nIgnoring following molecule:" + mol logger.warning(msg) return str(e) elif on_error.lower() == "ignore": return str(e) - - graph_dict = {"adj": adj, "data": {}, "dtype": dtype} - - # Assign the node data - if ndata is not None: - graph_dict["data"]["feat"] = ndata - - # Assign the edge data - if edata is not None: - if issparse(edata): - edata = to_dense_array(edata, dtype=dtype) - hetero_edata = edata.repeat(2, axis=0) - graph_dict["data"]["edge_feat"] = hetero_edata - - # Put the positional encodings as node features - # TODO: add support for PE on edges - for key, pe in pe_dict.items(): - graph_dict["data"][key] = pe - - # put the conformer positions here - for key, val in conf_dict.items(): - graph_dict["data"][key] = val - - graph_dict = GraphDict(graph_dict) - return graph_dict - - -def mol_to_pyggraph( - mol: dm.Mol, - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], - conformer_property_list: List[str] = [], - edge_property_list: List[str] = [], - add_self_loop: bool = False, - explicit_H: bool = False, - use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, - dtype: np.dtype = np.float16, - on_error: str = "ignore", - mask_nan: Union[str, float, type(None)] = "raise", - max_num_atoms: Optional[int] = None, -) -> Union[Data, str]: - r""" - Transforms a molecule into an adjacency matrix representing the molecular graph - and a set of atom and bond features. - - Then, the adjacency matrix and node/edge features are used to build a - `pyg.data.Data` with pytorch Tensors. - - Parameters: - - mol: - The molecule to be converted - - atom_property_list_onehot: - List of the properties used to get one-hot encoding of the atom type, - such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` - - atom_property_list_float: - List of the properties used to get floating-point encoding of the atom type, - such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` - - conformer_property_list: - list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d" - - edge_property_list: - List of the properties used to encode the edges, such as the edge type - and the stereo type. - - add_self_loop: - Whether to add a value of `1` on the diagonal of the adjacency matrix. - - explicit_H: - Whether to consider the Hydrogens explicitely. If `False`, the hydrogens - are implicit. - - use_bonds_weights: - Whether to use the floating-point value of the bonds in the adjacency matrix, - such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 - - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. - - dtype: - The numpy data type used to build the graph - - on_error: - What to do when the featurization fails. This can change the - behavior of `mask_nan`. - - - "raise": Raise an error - - "warn": Raise a warning and return a string of the error - - "ignore": Ignore the error and return a string of the error - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. 
- - - "raise": Raise an error when there is a nan in the featurization - - "warn": Raise a warning when there is a nan in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans by the specified value - - max_num_atoms: - Maximum number of atoms for a given molecule. If a molecule with more atoms - is give, an error is raised, but catpured according to the rules of - `on_error`. - Returns: - - graph: - Pyg graph, with `graph['feat']` corresponding to the concatenated - node data from `atom_property_list_onehot` and `atom_property_list_float`, - `graph['edge_feat']` corresponding to the concatenated edge data from `edge_property_list`. - There are also additional entries for the positional encodings. - - """ - graph_dict = mol_to_graph_dict( - mol=mol, - atom_property_list_onehot=atom_property_list_onehot, - atom_property_list_float=atom_property_list_float, - conformer_property_list=conformer_property_list, - edge_property_list=edge_property_list, - add_self_loop=add_self_loop, - explicit_H=explicit_H, - use_bonds_weights=use_bonds_weights, - pos_encoding_as_features=pos_encoding_as_features, - dtype=dtype, - on_error=on_error, - mask_nan=mask_nan, - max_num_atoms=max_num_atoms, - ) - - if (graph_dict is not None) and not isinstance(graph_dict, str): - return graph_dict.make_pyg_graph() - else: - return graph_dict + else: + # Invalid on_error value, so default to raising an exception. + raise e def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Dict[str, Any]: """ - Get the default arguments of `mol_to_graph_dict` and update it + Get the default arguments of `mol_to_pyggraph` and update it with a provided dict of arguments in order to get a fulle signature of the featurizer args actually used for the features computation. @@ -1262,8 +213,8 @@ def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Dict[str, A dictionary of featurizer arguments """ - # Get the signature of `mol_to_graph_dict` - signature = inspect.signature(mol_to_graph_dict) + # Get the signature of `mol_to_pyggraph` + signature = inspect.signature(mol_to_pyggraph) # Filter out empty arguments (without default value) parameters = list(filter(lambda param: param.default is not param.empty, signature.parameters.values())) diff --git a/graphium/features/graphormer.py b/graphium/features/graphormer.py deleted file mode 100644 index d62010801..000000000 --- a/graphium/features/graphormer.py +++ /dev/null @@ -1,55 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any - -import numpy as np -import networkx as nx - -from scipy.sparse import spmatrix, issparse - - -def compute_graphormer_distances( - adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Dict[str, Any] -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute Graphormer distance between nodepairs. 
- - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - Returns: - dist [num_nodes, num_nodes]: 2D array with Graphormer distances between nodepairs - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "nodepair" - - if "graphormer" in cache: - dist = cache["graphormer"] - - else: - if issparse(adj): - adj = adj.toarray() - - G = nx.from_numpy_array(adj) - paths = nx.all_pairs_shortest_path(G) - - dist_dict = {i: {j: len(path) - 1 for j, path in paths_from_i.items()} for i, paths_from_i in paths} - dist = np.asarray([[dist_dict[i][j] for j in range(num_nodes)] for i in range(num_nodes)]) - cache["graphormer"] = dist - - return dist, base_level, cache diff --git a/graphium/features/positional_encoding.py b/graphium/features/positional_encoding.py deleted file mode 100644 index 8acc231d8..000000000 --- a/graphium/features/positional_encoding.py +++ /dev/null @@ -1,181 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Optional, Dict, Any, OrderedDict -from copy import deepcopy -import numpy as np -import torch -from scipy.sparse import spmatrix -from collections import OrderedDict as OderedDictClass - -from graphium.features.spectral import compute_laplacian_pe -from graphium.features.rw import compute_rwse -from graphium.features.electrostatic import compute_electrostatic_interactions -from graphium.features.commute import compute_commute_distances -from graphium.features.graphormer import compute_graphormer_distances -from graphium.features.transfer_pos_level import transfer_pos_level - - -def get_all_positional_encodings( - adj: Union[np.ndarray, spmatrix], - num_nodes: int, - pos_kwargs: Optional[Dict] = None, -) -> Tuple["OrderedDict[str, np.ndarray]"]: - r""" - Get features positional encoding. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. 
- - Returns: - pe_dict: Dictionary of positional and structural encodings - """ - - pos_kwargs = {} if pos_kwargs is None else pos_kwargs - - pe_dict = OderedDictClass() - - # Initialize cache - cache = {} - - # Get the positional encoding for the features - if len(pos_kwargs) > 0: - for pos_name, this_pos_kwargs in pos_kwargs["pos_types"].items(): - this_pos_kwargs = deepcopy(this_pos_kwargs) - pos_type = this_pos_kwargs.pop("pos_type", None) - pos_level = this_pos_kwargs.pop("pos_level", None) - this_pe, cache = graph_positional_encoder( - deepcopy(adj), - num_nodes, - pos_type=pos_type, - pos_level=pos_level, - pos_kwargs=this_pos_kwargs, - cache=cache, - ) - if pos_level == "node": - pe_dict.update({f"{pos_type}": this_pe}) - else: - pe_dict.update({f"{pos_level}_{pos_type}": this_pe}) - - return pe_dict - - -def graph_positional_encoder( - adj: Union[np.ndarray, spmatrix], - num_nodes: int, - pos_type: Optional[str] = None, - pos_level: Optional[str] = None, - pos_kwargs: Optional[Dict[str, Any]] = None, - cache: Optional[Dict[str, Any]] = None, -) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: - r""" - Get a positional encoding that depends on the parameters. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - pos_type: The type of positional encoding to use. If None, it must be provided by `pos_kwargs["pos_type"]`. Supported types are: - - laplacian_eigvec \ - - laplacian_eigval \ -> cache connected comps. & eigendecomp. - - rwse - - electrostatic \ - - commute \ -> cache pinvL - - graphormer - pos_level: Positional level to output. If None, it must be provided by `pos_kwargs["pos_level"]`. - - node - - edge - - nodepair - - graph - pos_kwargs: Extra keyword arguments for the positional encoding. Can include the keys pos_type and pos_level. - cache: Dictionary of cached objects - - Returns: - pe: Positional or structural encoding - cache: Updated dictionary of cached objects - """ - - pos_kwargs = deepcopy(pos_kwargs) - if pos_kwargs is None: - pos_kwargs = {} - if cache is None: - cache = {} - - # Get the positional type - pos_type2 = pos_kwargs.pop("pos_type", None) - if pos_type is None: - pos_type = pos_type2 - if pos_type2 is not None: - assert ( - pos_type == pos_type2 - ), f"The positional type must be the same in `pos_type` and `pos_kwargs['pos_type']`. Provided: {pos_type} and {pos_type2}" - assert pos_type is not None, "Either `pos_type` or `pos_kwargs['pos_type']` must be provided." - - # Get the positional level - pos_level2 = pos_kwargs.pop("pos_level", None) - if pos_level is None: - pos_level = pos_level2 - if pos_level2 is not None: - assert ( - pos_level == pos_level2 - ), f"The positional level must be the same in `pos_level` and `pos_kwargs['pos_level']`. Provided: {pos_level} and {pos_level2}" - assert pos_level is not None, "Either `pos_level` or `pos_kwargs['pos_level']` must be provided." 
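For reference, a `pos_kwargs` configuration matching the structure parsed above could look like the following (key names are illustrative; the inner kwargs follow the `pos_type` handlers further down, e.g. `num_pos` for the Laplacian encodings and `ksteps` for the random-walk ones):

```python
# Illustrative only: outer key names are made up, values mirror the handlers below
pos_encoding_as_features = {
    "pos_types": {
        "lap_eigvec": {"pos_type": "laplacian_eigvec", "pos_level": "node", "num_pos": 3},
        "rw_probs": {"pos_type": "rw_return_probs", "pos_level": "node", "ksteps": 4},
    }
}
```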
- - # Convert to numpy array - if isinstance(adj, torch.sparse.Tensor): - adj = adj.to_dense().numpy() - elif isinstance(adj, torch.Tensor): - adj = adj.numpy() - adj = adj.astype(np.float64) - - # Calculate positional encoding - if pos_type == "laplacian_eigvec": - _, pe, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs) - - elif pos_type == "laplacian_eigval": - pe, _, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs) - - elif pos_type == "rw_return_probs": - pe, base_level, cache = compute_rwse( - adj.astype(np.float32), num_nodes=num_nodes, cache=cache, pos_type=pos_type, **pos_kwargs - ) - - elif pos_type == "rw_transition_probs": - pe, base_level, cache = compute_rwse( - adj.astype(np.float32), num_nodes=num_nodes, cache=cache, pos_type=pos_type, **pos_kwargs - ) - - elif pos_type == "electrostatic": - pe, base_level, cache = compute_electrostatic_interactions(adj, cache, **pos_kwargs) - - elif pos_type == "commute": - pe, base_level, cache = compute_commute_distances(adj, num_nodes, cache, **pos_kwargs) - - elif pos_type == "graphormer": - pe, base_level, cache = compute_graphormer_distances(adj, num_nodes, cache, **pos_kwargs) - - else: - raise ValueError(f"Unknown `pos_type`: {pos_type}") - - # Convert to float32 and Convert between different pos levels - if isinstance(pe, (list, tuple)): - pe = [this_pe.astype(np.float32) for this_pe in pe] - pe = [transfer_pos_level(this_pe, base_level, pos_level, adj, num_nodes, cache) for this_pe in pe] - else: - pe = np.real(pe).astype(np.float32) - pe = transfer_pos_level(pe, base_level, pos_level, adj, num_nodes, cache) - - return pe, cache diff --git a/graphium/features/properties.py b/graphium/features/properties.py deleted file mode 100644 index 89a90ffee..000000000 --- a/graphium/features/properties.py +++ /dev/null @@ -1,127 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Union, List, Callable - -import numpy as np -import datamol as dm - -from rdkit.Chem import rdMolDescriptors as rdMD -from loguru import logger - - -def get_prop_or_none( - prop: Callable, n: int, *args: Union[dm.Mol, str], **kwargs: Union[dm.Mol, str] -) -> Union[List[float], List[None]]: - r""" - return properties. If error, return list of `None` with lenght `n`. - Parameters: - prop: The property to compute. - n: The number of elements in the property. - *args: The arguments to pass to the property. - **kwargs: The keyword arguments to pass to the property. - Returns: - The property or a list of `None` with lenght `n`. - """ - logger.warning("get_prop_or_none is deprecated. 
Use `datamol.to_fp` instead.") - try: - return prop(*args, **kwargs) - except RuntimeError: - return [None] * n - - -def get_props_from_mol( - mol: Union[dm.Mol, str], - properties: Union[List[str], str] = "autocorr3d", -) -> np.ndarray: - r""" - Function to get a given set of desired properties from a molecule, - and output a property list. - - Parameters: - mol: The molecule from which to compute the properties. - properties: - The list of properties to compute for each molecule. It can be the following: - - - 'descriptors' - - 'autocorr3d' - - 'rdf' - - 'morse' - - 'whim' - - 'all' - - Returns: - props: np.array(float) - The array of properties for the desired molecule - classes_start_idx: list(int) - The list of index specifying the start of each new class of - descriptor or property. For example, if props has 20 elements, - the first 5 are rotatable bonds, the next 8 are morse, and - the rest are whim, then ``classes_start_idx = [0, 5, 13]``. - This will mainly be useful to normalize the features of - each class. - classes_names: list(str) - The name of the classes associated to each starting index. - Will be usefull to understand what property is the network learning. - - """ - - logger.warning("get_props_from_mol is deprecated. Use `datamol.to_fp` instead.") - - if isinstance(mol, str): - mol = dm.to_mol( - mol - ) # Doesn't need `ordered=True` because the fingerprints don't depend on the atom order - - if isinstance(properties, str): - properties = [properties] - - properties = [p.lower() for p in properties] - - # Initialize arrays - props = [] # Property vector for the features - classes_start_idx = [] # The starting index for each property class - classes_names = [] - - # Generate a 3D structure for the molecule - mol = dm.add_hs(mol) - - if ("autocorr3d" in properties) or ("all" in properties): - # Some kind of 3D description of the molecule - classes_names.append("autocorr3d") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcAUTOCORR3D, 80, mol)) - - if ("rdf" in properties) or ("all" in properties): - # The radial distribution function (better than the inertia) - # https://en.wikipedia.org/wiki/Radial_distribution_function - classes_names.append("rdf") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcRDF, 210, mol)) - - if ("morse" in properties) or ("all" in properties): - # Molecule Representation of Structures based on Electron diffraction descriptors - classes_names.append("morse") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcMORSE, 224, mol)) - - if ("whim" in properties) or ("all" in properties): - # WHIM descriptors are 3D structural descriptors obtained from the - # (x,y,z)‐atomic coordinates of a molecular conformation of a chemical, - # and are used successfully in QSAR modelling. - classes_names.append("whim") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcWHIM, 114, mol)) - - return np.array(props), classes_start_idx, classes_names diff --git a/graphium/features/rw.py b/graphium/features/rw.py deleted file mode 100644 index c7eada2ba..000000000 --- a/graphium/features/rw.py +++ /dev/null @@ -1,169 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. 
Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Optional, List, Dict, Any, Iterable - -from scipy.sparse import issparse, spmatrix, coo_matrix -import numpy as np -import torch - -from torch_geometric.utils import to_dense_adj, from_scipy_sparse_matrix -from torch_scatter import scatter_add -from torch_geometric.utils.num_nodes import maybe_num_nodes - - -def compute_rwse( - adj: Union[np.ndarray, spmatrix], - ksteps: Union[int, List[int]], - num_nodes: int, - cache: Dict[str, Any], - pos_type: str = "rw_return_probs" or "rw_transition_probs", - space_dim: int = 0, -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute Random Walk Spectral Embedding (RWSE) for given list of K steps. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - ksteps: List of numbers of steps for the random walks. If int, a list is generated from 1 to ksteps. - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - pos_type: Desired output - space_dim: Estimated dimensionality of the space. Used to - correct the random-walk diagonal by a factor `k^(space_dim/2)`. - In euclidean space, this correction means that the height of - the gaussian distribution stays almost constant across the number of - steps, if `space_dim` is the dimension of the euclidean space. - Returns: - Two possible outputs: - rw_return_probs [num_nodes, len(ksteps)]: Random-Walk k-step landing probabilities - rw_transition_probs [num_nodes, num_nodes, len(ksteps)]: Random-Walk k-step transition probabilities - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here either node or nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "node" if pos_type == "rw_return_probs" else "nodepair" - - # Manually handles edge case of 1 atom molecules here - if not isinstance(ksteps, Iterable): - ksteps = list(range(1, ksteps + 1)) - if num_nodes == 1: - if pos_type == "rw_return_probs": - return np.ones((1, len(ksteps))), base_level, cache - else: - return np.ones((1, 1, len(ksteps))), base_level, cache - - # Get the edge indices from the adjacency matrix - if not issparse(adj): - if "coo_adj" in cache: - adj = cache["coo_adj"] - elif "csr_adj" in cache: - adj = cache["csr_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - edge_index, edge_weight = from_scipy_sparse_matrix(adj) - - # Compute the random-walk transition probabilities - if "ksteps" in cache: - cached_k = cache["ksteps"] - missing_k = [k for k in ksteps if k not in cached_k] - if missing_k == []: - pass - elif min(missing_k) < min(cached_k): - Pk_dict = get_Pks(missing_k, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes) - cache["ksteps"] = sorted(missing_k + cache["ksteps"]) - for k in missing_k: - cache["Pk"][k] = Pk_dict[k] - else: - start_k = min([max(cached_k), min(missing_k)]) - start_Pk = cache["Pk"][start_k] - Pk_dict = get_Pks( - missing_k, - edge_index=edge_index, - edge_weight=edge_weight, - num_nodes=num_nodes, - start_Pk=start_Pk, - start_k=start_k, - ) - cache["ksteps"] = sorted(cache["ksteps"] + missing_k) - for k in missing_k: - cache["Pk"][k] = Pk_dict[k] - else: - Pk_dict = get_Pks(ksteps, edge_index=edge_index, 
edge_weight=edge_weight, num_nodes=num_nodes) - - cache["ksteps"] = list(Pk_dict.keys()) - cache["Pk"] = Pk_dict - - pe_list = [] - if pos_type == "rw_return_probs": - for k in ksteps: - pe_list.append(torch.diagonal(cache["Pk"][k], dim1=-2, dim2=-1) * (k ** (space_dim / 2))) - else: - for k in ksteps: - pe_list.append(cache["Pk"][k]) - - pe = torch.stack(pe_list, dim=-1).numpy() - - return pe, base_level, cache - - -def get_Pks( - ksteps: List[int], - edge_index: Tuple[torch.Tensor, torch.Tensor], - edge_weight: Optional[torch.Tensor] = None, - num_nodes: Optional[int] = None, - start_Pk: Optional[torch.Tensor] = None, - start_k: Optional[int] = None, -) -> Dict[int, np.ndarray]: - """ - Compute Random Walk landing probabilities for given list of K steps. - - Parameters: - ksteps: List of numbers of k-steps for which to compute the RW landings - edge_index: PyG sparse representation of the graph - edge_weight: Edge weights - num_nodes: Number of nodes in the graph - - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - if edge_weight is None: - edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) - num_nodes = maybe_num_nodes(edge_index, num_nodes) - src = edge_index[0] - deg = scatter_add(edge_weight, src, dim=0, dim_size=num_nodes) # Out degrees. - deg_inv = deg.pow(-1.0) - deg_inv.masked_fill_(deg_inv == float("inf"), 0) - - if edge_index.numel() == 0: - P = edge_index.new_zeros((1, num_nodes, num_nodes)) - else: - # P = D^-1 * A - P = torch.diag(deg_inv).float() @ to_dense_adj( - edge_index, max_num_nodes=num_nodes - ) # 1 x (Num nodes) x (Num nodes) - - if start_Pk is not None: - Pk = start_Pk @ P.clone().detach().matrix_power(min(ksteps) - start_k) - else: - Pk = P.clone().detach().matrix_power(min(ksteps)) - - Pk_dict = {} - for k in range(min(ksteps), max(ksteps) + 1): - Pk_dict[k] = Pk.squeeze(0) - Pk = Pk @ P - - return Pk_dict diff --git a/graphium/features/spectral.py b/graphium/features/spectral.py deleted file mode 100644 index 55d8527a4..000000000 --- a/graphium/features/spectral.py +++ /dev/null @@ -1,218 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any -from scipy.linalg import eig -from scipy.sparse import csr_matrix, diags, issparse, spmatrix -import numpy as np -import torch -import networkx as nx - -from graphium.utils.tensor import is_dtype_torch_tensor, is_dtype_numpy_array - - -def compute_laplacian_pe( - adj: Union[np.ndarray, spmatrix], - num_pos: int, - cache: Dict[str, Any], - disconnected_comp: bool = True, - normalization: str = "none", -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - r""" - Compute the Laplacian eigenvalues and eigenvectors of the Laplacian of the graph. 
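A toy sketch of the spectral encoding summarised above, ignoring the normalization, disconnected-component, and caching logic of the removed implementation:

```python
import numpy as np
from scipy.linalg import eigh

# 3-node path graph
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
L = np.diag(adj.sum(axis=1)) - adj   # unnormalized Laplacian (the "none" normalization case)
eigvals, eigvecs = eigh(L)           # ascending eigenvalues; the removed code uses scipy.linalg.eig
num_pos = 2
pe = eigvecs[:, :num_pos]            # first eigenvectors used as node-level positional encodings
```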
- - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_pos: Number of Laplacian eigenvectors to compute - cache: Dictionary of cached objects - disconnected_comp: Whether to compute the eigenvectors for each connected component - normalization: Normalization to apply to the Laplacian - - Returns: - Two possible outputs: - eigvals [num_nodes, num_pos]: Eigenvalues of the Laplacian repeated for each node. - This repetition is necessary in case of disconnected components, where - the eigenvalues of the Laplacian are not the same for each node. - eigvecs [num_nodes, num_pos]: Eigenvectors of the Laplacian - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here node - cache: Updated dictionary of cached objects - """ - - base_level = "node" - - # Sparsify the adjacency patrix - if not issparse(adj): - if "csr_adj" not in cache: - adj = csr_matrix(adj, dtype=np.float64) - cache["csr_adj"] = adj - else: - adj = cache["csr_adj"] - - # Compute the Laplacian, and normalize it - if f"L_{normalization}_sp" not in cache: - D = np.array(np.sum(adj, axis=1)).flatten() - D_mat = diags(D) - L = -adj + D_mat - L_norm = normalize_matrix(L, degree_vector=D, normalization=normalization) - cache[f"L_{normalization}_sp"] = L_norm - else: - L_norm = cache[f"L_{normalization}_sp"] - - components = [] - - if disconnected_comp: - if "components" not in cache: - # Get the list of connected components - components = list(nx.connected_components(nx.from_scipy_sparse_array(adj))) - cache["components"] = components - - else: - components = cache["components"] - - # Compute the eigenvectors for each connected component, and stack them together - if len(components) > 1: - if "lap_eig_comp" not in cache: - eigvals = np.zeros((adj.shape[0], num_pos), dtype=np.complex64) - eigvecs = np.zeros((adj.shape[0], num_pos), dtype=np.complex64) - for component in components: - comp = list(component) - this_L = L_norm[comp][:, comp] - this_eigvals, this_eigvecs = _get_positional_eigvecs(this_L, num_pos=num_pos) - - # Eigenvalues previously set to infinity are now set to 0 - # Any NaN in the eigvals or eigvecs will be set to 0 - this_eigvecs[~np.isfinite(this_eigvecs)] = 0.0 - this_eigvals[~np.isfinite(this_eigvals)] = 0.0 - - eigvals[comp, :] = np.expand_dims(this_eigvals, axis=0) - eigvecs[comp, :] = this_eigvecs - cache["lap_eig_comp"] = (eigvals, eigvecs) - - else: - eigvals, eigvecs = cache["lap_eig_comp"] - - else: - if "lap_eig" not in cache: - eigvals, eigvecs = _get_positional_eigvecs(L, num_pos=num_pos) - - # Eigenvalues previously set to infinity are now set to 0 - # Any NaN in the eigvals or eigvecs will be set to 0 - eigvecs[~np.isfinite(eigvecs)] = 0.0 - eigvals[~np.isfinite(eigvals)] = 0.0 - eigvals = np.repeat(np.expand_dims(eigvals, axis=0), adj.shape[0], axis=0) - - cache["lap_eig"] = (eigvals, eigvecs) - - else: - eigvals, eigvecs = cache["lap_eig"] - - return eigvals, eigvecs, base_level, cache - - -def _get_positional_eigvecs( - matrix: Union[np.ndarray, spmatrix], - num_pos: int, -) -> Tuple[np.ndarray, np.ndarray]: - r""" - compute the eigenvalues and eigenvectors of a matrix - Parameters: - matrix: Matrix to compute the eigenvalues and eigenvectors of - num_pos: Number of eigenvalues and eigenvectors to compute - Returns: - eigvals: Eigenvalues of the matrix - eigvecs: Eigenvectors of the matrix - """ - mat_len = matrix.shape[0] - eigvals, eigvecs = eig(matrix.todense()) - - # Pad with non-sense eigenvectors if required - if num_pos > mat_len: - temp_EigVal 
= np.ones(num_pos - mat_len, dtype=np.float64) + float("inf") - temp_EigVec = np.zeros((mat_len, num_pos - mat_len), dtype=np.float64) - eigvals = np.concatenate([eigvals, temp_EigVal], axis=0) - eigvecs = np.concatenate([eigvecs, temp_EigVec], axis=1) - - # Sort and keep only the first `num_pos` elements - sort_idx = eigvals.argsort() - eigvals = eigvals[sort_idx] - eigvals = eigvals[:num_pos] - eigvecs = eigvecs[:, sort_idx] - eigvecs = eigvecs[:, :num_pos] - - # Normalize the eigvecs - eigvecs = eigvecs / np.maximum(np.sqrt(np.sum(eigvecs**2, axis=0, keepdims=True)), 1e-4) - - return eigvals, eigvecs - - -def normalize_matrix( - matrix: Union[np.ndarray, spmatrix], - degree_vector=None, - normalization: str = None, -) -> Union[np.ndarray, spmatrix]: - r""" - Normalize a given matrix using its degree vector - - Parameters - --------------- - - matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) - A square matrix representing either an Adjacency matrix or a Laplacian. - - degree_vector: torch.tensor(N) or np.ndarray(N) or None - A vector representing the degree of ``matrix``. - ``None`` is only accepted if ``normalization==None`` - - normalization: str or None, Default='none' - Normalization to use on the eig_matrix - - - 'none' or ``None``: no normalization - - - 'sym': Symmetric normalization ``D^-0.5 L D^-0.5`` - - - 'inv': Inverse normalization ``D^-1 L`` - - Returns - ----------- - matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) - The normalized matrix - - """ - - # Transform the degree vector into a matrix - if degree_vector is None: - if not ((normalization is None) or (normalization.lower() == "none")): - raise ValueError("`degree_vector` cannot be `None` if `normalization` is not `None`") - else: - if is_dtype_numpy_array(matrix.dtype): - with np.errstate(divide="ignore", invalid="ignore"): - degree_inv = np.expand_dims(degree_vector**-0.5, axis=1) - degree_inv[np.isinf(degree_inv)] = 0 - elif is_dtype_torch_tensor(matrix.dtype): - degree_inv = torch.unsqueeze(degree_vector**-0.5, dim=1) - degree_inv[torch.isinf(degree_inv)] = 0 - - # Compute the normalized matrix - if (normalization is None) or (normalization.lower() == "none"): - pass - elif normalization.lower() == "sym": - matrix = degree_inv * matrix * degree_inv.T - elif normalization.lower() == "inv": - matrix = (degree_inv**2) * matrix - else: - raise ValueError( - f'`normalization` should be `None`, `"None"`, `"sym"` or `"inv"`, but `{normalization}` was provided' - ) - - return matrix diff --git a/graphium/features/transfer_pos_level.py b/graphium/features/transfer_pos_level.py deleted file mode 100644 index 4bb70e160..000000000 --- a/graphium/features/transfer_pos_level.py +++ /dev/null @@ -1,376 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. 
--------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, List, Dict, Any, Optional - -import numpy as np - -from scipy.sparse import spmatrix, issparse, coo_matrix - -from torch_geometric.utils import from_scipy_sparse_matrix - - -def transfer_pos_level( - pe: np.ndarray, - in_level: str, - out_level: str, - adj: Union[np.ndarray, spmatrix], - num_nodes: int, - cache: Optional[Dict[str, Any]] = None, -) -> np.ndarray: - r""" - Transfer positional encoding between different positional levels (node, edge, nodepair, graph) - - Parameters: - pe: Input pe with pos_level defined by in_level - in_level: pos_level of input pe - out_level: desired pos_level of output pe - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - - Returns: - pe: Output pe with pos_level defined by out_level - """ - - if cache is None: - cache = {} - - if in_level == "node": - if out_level == "node": - pass - - elif out_level == "edge": - pe, cache = node_to_edge(pe, adj, cache) - - elif out_level == "nodepair": - pe = node_to_nodepair(pe, num_nodes) - - elif out_level == "graph": - raise NotImplementedError("Transfer function (node -> graph) not yet implemented.") - - else: - raise ValueError(f"Unknown `pos_level`: {out_level}") - - elif in_level == "edge": - raise NotImplementedError("Transfer function (edge -> *) not yet implemented.") - - elif in_level == "nodepair": - if len(pe.shape) == 2: - pe = np.expand_dims(pe, -1) - - if out_level == "node": - pe = nodepair_to_node(pe) - - elif out_level == "edge": - pe, cache = nodepair_to_edge(pe, adj, cache) - - elif out_level == "nodepair": - pass - - elif out_level == "graph": - raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.") - - else: - raise ValueError(f"Unknown `pos_level`: {out_level}") - - elif in_level == "graph": - if out_level == "node": - pe = graph_to_node(pe, num_nodes, cache) - - elif out_level in ["edge", "nodepair"]: - raise NotImplementedError("Transfer function (graph -> edge/nodepair) not yet implemented.") - - else: - raise ValueError(f"Unknown `pos_level`: {out_level}") - - else: - raise ValueError(f"Unknown `pos_level`: {in_level}") - - return pe - - -# Transfer functions between different levels, i.e., node, edge, nodepair and graph level. - -# TODO: -# - Implement missing transfer functions below -# - Are transfer functions graph -> edge/nodepair and edge -> graph needed? - - -def node_to_edge( - pe: np.ndarray, adj: Union[np.ndarray, spmatrix], cache: Optional[Dict[str, Any]] = None -) -> Tuple[np.ndarray, Dict[str, Any]]: - r""" - Get an edge-level positional encoding from a node-level positional encoding. - -> For each edge, concatenate the sum and absolute difference of pe of source and destination node. 
- - Parameters: - pe [num_nodes, num_feat]: Node-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - cache: Dictionary of cached objects - - Returns: - edge_pe [2 * num_edges, 2 * num_feat]: Edge-level positional encoding - cache: Updated dictionary of cached objects - """ - - if cache is None: - cache = {} - - if not issparse(adj): - if "coo_adj" in cache: - adj = cache["coo_adj"] - elif "csr_adj" in cache: - adj = cache["csr_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - edge_index, _ = from_scipy_sparse_matrix(adj) - src, dst = edge_index[0], edge_index[1] - - pe_sum = pe[src] + pe[dst] - pe_abs_diff = np.abs(pe[src] - pe[dst]) - - edge_pe = np.concatenate((pe_sum, pe_abs_diff), axis=-1) - - return edge_pe, cache - - -def node_to_nodepair(pe: np.ndarray, num_nodes: int) -> np.ndarray: - r""" - Get a nodepair-level positional encoding from a node-level positional encoding. - -> For each nodepair (i,j) concatenate the sum and absolute difference of pe at node i and j. - - Parameters: - pe [num_nodes, num_feat]: Node-level positional encoding - num_nodes: Number of nodes in the graph - - Returns: - nodepair_pe [num_nodes, num_nodes, 2 * num_feat]: Nodepair-level positional encoding - """ - - expanded_pe = np.expand_dims(pe, axis=1) - expanded_pe = np.repeat(expanded_pe, repeats=num_nodes, axis=1) - - pe_sum = expanded_pe + expanded_pe.transpose([1, 0, 2]) - pe_abs_diff = np.abs(expanded_pe - expanded_pe.transpose([1, 0, 2])) - - nodepair_pe = np.concatenate((pe_sum, pe_abs_diff), axis=-1) - - return nodepair_pe - - -def node_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray: - r""" - Get a graph-level positional encoding from a node-level positional encoding. - -> E.g., min/max/mean-pooling of node features. - - Parameters: - pe [num_nodes, num_feat]: Node-level positional encoding - num_nodes: Number of nodes in the graph - - Returns: - graph_pe [1, num_feat]: Graph-level positional encoding - """ - - raise NotImplementedError("Transfer function (node -> graph) not yet implemented.") - - -def edge_to_node(pe: np.ndarray, adj: Union[np.ndarray, spmatrix]) -> np.ndarray: - r""" - Get a node-level positional encoding from an edge-level positional encoding. - -> E.g., min/max/mean-pooling of information from edges (i,j) that contain node i - - Parameters: - pe [num_edges, num_feat]: Edge-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - - Returns: - node_pe [num_edges, num_feat]: Node-level positional encoding - """ - - raise NotImplementedError("Transfer function (edge -> node) not yet implemented.") - - -def edge_to_nodepair( - pe: np.ndarray, adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Optional[Dict[str, Any]] = None -) -> np.ndarray: - r""" - Get a nodepair-level positional encoding from an edge-level positional encoding. - -> Zero-padding of non-existing edges. 
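Note: the node-to-edge and node-to-nodepair transfers above share one operation: pair node encodings and keep their sum and absolute difference. A rough NumPy sketch of both, assuming `pe` is a [num_nodes, num_feat] array and `edge_index` a [2, num_edges] index array such as `from_scipy_sparse_matrix` returns; the function names are illustrative only:

    import numpy as np

    def node_pe_to_edge_pe(pe: np.ndarray, edge_index: np.ndarray) -> np.ndarray:
        """pe: [num_nodes, num_feat] -> [num_edges, 2 * num_feat]."""
        src, dst = edge_index
        return np.concatenate([pe[src] + pe[dst], np.abs(pe[src] - pe[dst])], axis=-1)

    def node_pe_to_nodepair_pe(pe: np.ndarray) -> np.ndarray:
        """pe: [num_nodes, num_feat] -> [num_nodes, num_nodes, 2 * num_feat]."""
        a, b = pe[:, None, :], pe[None, :, :]
        return np.concatenate([a + b, np.abs(a - b)], axis=-1)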
- - Parameters: - pe [num_edges, num_feat]: Edge-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - - Returns: - nodepair_pe [num_edges, num_edges, num_feat]: Nodepair-level positional encoding - cache: Updated dictionary of cached objects - """ - - if cache is None: - cache = {} - - num_feat = pe.shape[-1] - - if not isinstance(adj, coo_matrix): - if "coo_adj" in cache: - adj = cache["coo_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - dst, src = adj.row, adj.col - - nodepair_pe = np.zeros((num_nodes, num_nodes, num_feat)) - - for i in range(len(dst)): - nodepair_pe[dst[i], src[i], ...] = pe[i, ...] - - return nodepair_pe, cache - - -def edge_to_graph(pe: np.ndarray) -> np.ndarray: - r""" - Get a graph-level positional encoding from an edge-level positional encoding. - - Parameters: - pe [num_edges, num_feat]: Edge-level positional encoding - - Returns: - graph_pe [1, num_feat]: Graph-level positional encoding - """ - - raise NotImplementedError("Transfer function (edge -> graph) not yet implemented.") - - -def nodepair_to_node(pe: np.ndarray, stats_list: List = [np.min, np.mean, np.std]) -> np.ndarray: - r""" - Get a node-level positional encoding from a graph-level positional encoding. - -> Calculate statistics over rows & cols of input positional encoding - - Parameters: - pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding - stats_list: List of statistics to calculate per row/col of nodepair-level pe - - Returns: - node_pe [num_nodes, 2 * len(stats_list) * num_feat]: Node-level positional encoding - """ - - num_feat = pe.shape[-1] - - node_pe_list = [] - - for stat in stats_list: - for i in range(num_feat): - node_pe_list.append(stat(pe[..., i], axis=0)) - node_pe_list.append(stat(pe[..., i], axis=1)) - node_pe = np.stack(node_pe_list, axis=-1) - - return node_pe - - -def nodepair_to_edge( - pe: np.ndarray, adj: Union[np.ndarray, spmatrix], cache: Optional[Dict[str, Any]] = None -) -> np.ndarray: - r""" - Get a edge-level positional encoding from a nodepair-level positional encoding. - -> Mask and sparsify nodepair-level positional encoding - - Parameters: - pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - cache: Dictionary of cached objects - - Returns: - edge_pe [num_edges, num_feat]: Edge-level positional encoding - cache: Updated dictionary of cached objects - """ - - if cache is None: - cache = {} - - num_feat = pe.shape[-1] - - if not isinstance(adj, coo_matrix): - if "coo_adj" in cache: - adj = cache["coo_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - dst, src = adj.row, adj.col - - edge_pe = np.zeros((len(dst), num_feat)) - - for i in range(len(src)): - edge_pe[i, ...] = pe[dst[i], src[i]] - - return edge_pe, cache - - -def nodepair_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray: - r""" - Get a graph-level positional encoding from a nodepair-level positional encoding. 
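Note: `nodepair_to_node` above collapses a [num_nodes, num_nodes, num_feat] encoding to node level by taking row-wise and column-wise statistics (min, mean, std by default). Roughly, in NumPy, with an illustrative name and a feature ordering that may differ from the removed per-feature loop:

    import numpy as np

    def nodepair_pe_to_node_pe(pe: np.ndarray) -> np.ndarray:
        """pe: [num_nodes, num_nodes, num_feat] -> [num_nodes, 6 * num_feat]."""
        stats = []
        for reduce in (np.min, np.mean, np.std):
            stats.append(reduce(pe, axis=0))   # statistics down each column
            stats.append(reduce(pe, axis=1))   # statistics across each row
        return np.concatenate(stats, axis=-1)

`nodepair_to_edge` is then essentially a gather of the (dst, src) entry of the nodepair tensor for every directed edge in the sparse adjacency.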
- -> E.g., min/max/mean-pooling of entries of input pe - - Parameters: - pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding - num_nodes: Number of nodes in the graph - - Returns: - graph_pe [1, num_feat]: Graph-level positional encoding - """ - - raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.") - - -def graph_to_node( - pe: Union[np.ndarray, List], num_nodes: int, cache: Optional[Dict[str, Any]] = None -) -> np.ndarray: - r""" - Get a node-level positional encoding from a nodepair-level positional encoding. - -> E.g., expand dimension of graph-level pe - - Parameters: - pe [num_feat]: Nodepair-level positional encoding (or list of them if graph disconnected) - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - - Returns: - node_pe [num_nodes, num_feat]: Node-level positional encoding - """ - - if cache is None: - cache = {} - - node_pe = None - - # The key 'components' is only in cache if disconnected_comp == True when computing base pe - if "components" in cache: - if len(cache["components"]) > 1: - node_pe = np.zeros((num_nodes, len(pe))) - components = cache["components"] - - for i, component in enumerate(components): - comp = list(component) - node_pe[comp, :] = np.real(pe[i]) - - if node_pe is None: - node_pe = np.tile(pe, (num_nodes, 1)) - - return node_pe diff --git a/graphium/graphium_cpp/commute.cpp b/graphium/graphium_cpp/commute.cpp new file mode 100644 index 000000000..a81156815 --- /dev/null +++ b/graphium/graphium_cpp/commute.cpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "commute.h" + +#include "electrostatic.h" +#include "spectral.h" + +#include +#include + +template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights) { + + if (laplacian_pseudoinverse.size() == 0) { + compute_laplacian_pseudoinverse(n, row_starts, neighbors, data, laplacian_pseudoinverse, weights); + } + + T full_sum = T(0); + if (weights != nullptr) { + for (size_t i = 0, weights_size = row_starts[n]; i < weights_size; ++i) { + full_sum += weights[i]; + } + } + else { + // Unweighted, so just twice the unique edge count + // (each edge appears twice in neighbors) + full_sum = T(row_starts[n]); + } + + matrix.resize(n * n); + + for (size_t row = 0, row_diag_index = 0, i = 0; row < n; ++row, row_diag_index += (n + 1)) { + for (size_t col = 0, col_diag_index = 0; col < n; ++col, ++i, col_diag_index += (n + 1)) { + matrix[i] = full_sum * ( + laplacian_pseudoinverse[row_diag_index] + + laplacian_pseudoinverse[col_diag_index] + - 2 * laplacian_pseudoinverse[row*n + col]); + } + } +} + +template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/commute.h b/graphium/graphium_cpp/commute.h new file mode 100644 index 000000000..a8611d74c --- /dev/null +++ b/graphium/graphium_cpp/commute.h @@ -0,0 +1,38 @@ 
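Note: `compute_commute_distances` (commute.cpp above) implements the standard commute-time identity C(i, j) = vol(G) * (Lp[i, i] + Lp[j, j] - 2 * Lp[i, j]), where Lp is the pseudoinverse of the unnormalized Laplacian and vol(G) is the total edge weight (`full_sum`, i.e. twice the number of edges for an unweighted graph). For reference, the same formula in dense NumPy; this is an illustration of the math, not the C++ entry point:

    import numpy as np

    def commute_distances(adj: np.ndarray) -> np.ndarray:
        """adj: dense symmetric [n, n] adjacency -> [n, n] commute distances."""
        lap = np.diag(adj.sum(axis=1)) - adj
        lap_pinv = np.linalg.pinv(lap)
        vol = adj.sum()                       # == 2 * num_edges for an unweighted graph
        diag = np.diag(lap_pinv)
        return vol * (diag[:, None] + diag[None, :] - 2.0 * lap_pinv)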
+// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "spectral.h" + +#include +#include + +template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights = nullptr); + +extern template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +extern template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/electrostatic.cpp b/graphium/graphium_cpp/electrostatic.cpp new file mode 100644 index 000000000..56efd2f5c --- /dev/null +++ b/graphium/graphium_cpp/electrostatic.cpp @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "electrostatic.h" + +#include "spectral.h" + +#include +#include +#include + +template +void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const T* weights) { + + // If we've already computed the eigendecomposition with the correct normalization, + // skip recomputing it. + if (data.eigenvalues.size() != n || data.normalization != Normalization::NONE) { + compute_laplacian_eigendecomp(n, row_starts, neighbors, Normalization::NONE, data, 1, nullptr, weights); + } + + matrix.clear(); + matrix.resize(size_t(n) * n, T(0)); + const T maxEigenvalue = data.eigenvalues.back(); + // zero_threshold is an estimate of how accurately the diagonalization + // algorithm determines eigenvalues close to zero. Anything smaller + // should be considered zero for the pseudoinverse. + const T eigendecomp_relative_threshold = T(1e-6); + const T zero_threshold = n * eigendecomp_relative_threshold * maxEigenvalue; + for (size_t eigenIndex = 0; eigenIndex < n; ++eigenIndex) { + // This is a positive semi-definite matrix, so we don't need to take the absolute value + // when checking the threshold. 
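+        // Eigenvalues at or below zero_threshold are treated as the Laplacian's
+        // null space and skipped; the surviving terms accumulate
+        // (1 / lambda_k) * v_k * v_k^T, i.e. the Moore-Penrose pseudoinverse of
+        // the unnormalized Laplacian (Normalization::NONE is enforced above).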
+ if (data.eigenvalues[eigenIndex] < zero_threshold) { + continue; + } + const T eigenvalueInverse = T(1) / data.eigenvalues[eigenIndex]; + const T* const eigenvector = data.vectors.data() + eigenIndex * n; + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + const T value = eigenvalueInverse * eigenvector[row] * eigenvector[col]; + matrix[i] += value; + } + } + } +} + +template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const float* weights); +template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const double* weights); + +template +void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights) { + + if (laplacian_pseudoinverse.size() == 0) { + compute_laplacian_pseudoinverse(n, row_starts, neighbors, data, laplacian_pseudoinverse, weights); + } + + matrix.resize(n * n); + + // Subtract the diagonal value from each column + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0, diag_index = 0; col < n; ++col, ++i, diag_index += (n+1)) { + matrix[i] = laplacian_pseudoinverse[i] - laplacian_pseudoinverse[diag_index]; + } + } +} + +template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/electrostatic.h b/graphium/graphium_cpp/electrostatic.h new file mode 100644 index 000000000..575dc3f83 --- /dev/null +++ b/graphium/graphium_cpp/electrostatic.h @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "spectral.h" + +#include +#include + +template +void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const T* weights = nullptr); + +extern template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const float* weights); +extern template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const double* weights); + +template +void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights = nullptr); + +extern template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +extern template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/features.cpp b/graphium/graphium_cpp/features.cpp new file mode 100644 index 000000000..f9357eaad --- /dev/null +++ b/graphium/graphium_cpp/features.cpp @@ -0,0 +1,1488 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#define DEBUG_LOGGING 0 + +#include "features.h" + +#include "commute.h" +#include "electrostatic.h" +#include "float_features.h" +#include "graphormer.h" +#include "one_hot.h" +#include "random_walk.h" +#include "spectral.h" + +#include // For RDKit's addHs +#include // For RDKit's EmbedMolecule + +#include + +static GraphData read_graph(const std::string& smiles_string, bool explicit_H) { + std::unique_ptr mol{ parse_mol(smiles_string, explicit_H) }; + + if (!mol) { + return GraphData{ 0, std::unique_ptr(), 0, std::unique_ptr(), std::move(mol) }; + } + + const size_t num_atoms = mol->getNumAtoms(); + const size_t num_bonds = mol->getNumBonds(); +#if DEBUG_LOGGING + printf("# atoms = %zu\n# bonds = %zu\n", num_atoms, num_bonds); +#endif +#if REPORT_STATS + ++statsMolAtomCounts[(num_atoms >= STATS_NUM_MOL_ATOM_COUNTS) ? (STATS_NUM_MOL_ATOM_COUNTS - 1) : num_atoms]; + ++statsMolBondCounts[(num_bonds >= STATS_NUM_MOL_BOND_COUNTS) ? (STATS_NUM_MOL_BOND_COUNTS - 1) : num_bonds]; + statsTotalNumAtoms += num_atoms; + statsTotalNumBonds += num_bonds; +#endif + +#if ORDER_ATOMS + // Determine a canonical ordering of the atoms, if desired. + std::vector atomOrder; + atomOrder.reserve(num_atoms); + RDKit::Canon::rankMolAtoms(*mol, atomOrder); + assert(atomOrder.size() == num_atoms); +#endif + + // Allocate an array of atom data, and fill it from the RDKit atom data. 
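+    // Each CompactAtom packs the per-atom properties queried below (atomic number,
+    // total degree, formal charge, chiral tag, total H count, hybridization,
+    // aromatic flag, mass) into a fixed-size record used by the later
+    // featurization steps.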
+ std::unique_ptr atoms(new CompactAtom[num_atoms]); + for (size_t atomIdx = 0; atomIdx < num_atoms; ++atomIdx) { + const RDKit::Atom* const atom = mol->getAtomWithIdx(atomIdx); + auto atomicNum = atom->getAtomicNum(); + auto totalDegree = atom->getTotalDegree(); + auto formalCharge = atom->getFormalCharge(); + const RDKit::Atom::ChiralType chiralType = atom->getChiralTag(); + auto totalNumHs = atom->getTotalNumHs(); + const RDKit::Atom::HybridizationType hybridization = atom->getHybridization(); + + const bool isAromatic = atom->getIsAromatic(); +#if REPORT_STATS + ++statsElementCounts[(atomicNum < 0 || atomicNum >= STATS_NUM_ELEMENTS) ? (STATS_NUM_ELEMENTS - 1) : atomicNum]; + ++statsDegreeCounts[(totalDegree < 0 || totalDegree >= STATS_NUM_DEGREES) ? (STATS_NUM_DEGREES - 1) : totalDegree]; + size_t formalChargeIndex = formalCharge + int(STATS_CHARGE_OFFSET); + if (formalCharge < -int(STATS_CHARGE_OFFSET)) { + formalChargeIndex = 0; + } + else if (formalCharge > int(STATS_CHARGE_OFFSET)) { + formalChargeIndex = STATS_NUM_CHARGES - 1; + } + + ++statsChargeCounts[formalChargeIndex]; + ++statsChiralityCounts[(size_t(chiralType) >= STATS_NUM_CHIRALITIES) ? (STATS_NUM_CHIRALITIES - 1) : size_t(chiralType)]; + ++statsHCounts[(totalNumHs < 0 || totalNumHs >= STATS_NUM_HS) ? (STATS_NUM_HS - 1) : totalNumHs]; + ++statsHybridizationCounts[(size_t(hybridization) >= STATS_NUM_HYBRIDIZATIONS) ? (STATS_NUM_HYBRIDIZATIONS - 1) : size_t(hybridization)]; + statsAromaticAtomCount += (isAromatic ? 1 : 0); +#endif + const double mass = atom->getMass(); + +#if ORDER_ATOMS + const size_t destAtomIdx = atomOrder[atomIdx]; +#else + const size_t destAtomIdx = atomIdx; +#endif + atoms[destAtomIdx] = CompactAtom{ + uint8_t(atomicNum), + uint8_t(totalDegree), + int8_t(formalCharge), + uint8_t(chiralType), + uint8_t(totalNumHs), + uint8_t(hybridization), + isAromatic, + float(mass) + }; +#if DEBUG_LOGGING + printf( + "atom[%zu] = {%zu, %u, %d, %u, %u, %u, %s, %f}\n", + destAtomIdx, + int(atomicNum), + int(totalDegree), + int(formalCharge), + int(chiralType), + int(totalNumHs), + int(hybridization), + isAromatic ? "true" : "false", + mass + ); +#endif + } + + // Allocate an array of bond data, and fill it from the RDKit bond data. + std::unique_ptr bonds(new CompactBond[num_bonds]); + const RDKit::RingInfo* const ringInfo = mol->getRingInfo(); + for (size_t bondIdx = 0; bondIdx < num_bonds; ++bondIdx) { + const RDKit::Bond* const bond = mol->getBondWithIdx(bondIdx); + const RDKit::Bond::BondType bondType = bond->getBondType(); + const bool isConjugated = bond->getIsConjugated(); + // TODO: Verify that it's the same index as bond->getIdx() + const bool isInRing = (ringInfo->numBondRings(bondIdx) != 0); + const RDKit::Bond::BondStereo stereo = bond->getStereo(); + +#if REPORT_STATS + ++statsBondTypeCounts[(size_t(bondType) >= STATS_NUM_BOND_TYPES) ? (STATS_NUM_BOND_TYPES - 1) : size_t(bondType)]; + ++statsBondStereoCounts[(size_t(stereo) >= STATS_NUM_BOND_STEREOS) ? (STATS_NUM_BOND_STEREOS - 1) : size_t(stereo)]; + statsConjugatedBondCount += (isConjugated ? 1 : 0); + statsBondInRingCount += (isInRing ? 
1 : 0); +#endif + + auto beginAtomIdx = bond->getBeginAtomIdx(); + auto endAtomIdx = bond->getEndAtomIdx(); +#if ORDER_ATOMS + beginAtomIdx = atomOrder[beginAtomIdx]; + endAtomIdx = atomOrder[endAtomIdx]; +#endif + bonds[bondIdx] = CompactBond{ + uint8_t(bondType), + isConjugated, + isInRing, + uint8_t(stereo), + beginAtomIdx, + endAtomIdx + }; +#if DEBUG_LOGGING + printf( + "bond[%zu] = {%u, %s, %s, %u, {%u, %u}}\n", + bondIdx, + int(bondType), + isConjugated ? "true" : "false", + isInRing ? "true" : "false", + int(stereo), + beginAtomIdx, + endAtomIdx + ); +#endif + } + + // Return a GraphData structure, taking ownership of the atom and bond data arrays. + return GraphData{ num_atoms, std::move(atoms), num_bonds, std::move(bonds), std::move(mol) }; +} + +// This is a structure for managing the adjacency data (CSR format) for use by randomSubgraph. +struct NeighbourData { + // This owns the data of all 3 arrays, which are actually a single, contiguous allocation. + std::unique_ptr deleter; + + // This is an array of indices into the other two arrays, indicating where + // each atom's neighbours start, including the first entry being 0 for the start of + // atom 0, and the num_atoms entry being 2*num_bonds (2x because each bond is on 2 atoms), + // so there are num_atoms+1 entries. The number of neighbours of an atom i is + // neighbour_starts[i+1]-neighbour_starts[i] + const uint32_t* neighbour_starts; + + // The neighbour atom for each bond, with each atom having an entry for each of + // its neighbours, so each bond occurs twice. + const uint32_t* neighbours; + + // This is in the same order as neighbours, but indicates the index of the bond. + // Each bond occurs twice, so each number occurs twice. + const uint32_t* bond_indices; +}; + +// Construct a NeighbourData structure representing the molecule's graph in CSR format. +static NeighbourData construct_neighbours(const GraphData& graph) { + const uint32_t num_atoms = graph.num_atoms; + const uint32_t num_bonds = graph.num_bonds; + // Do a single allocation for all 3 arrays. + std::unique_ptr deleter(new uint32_t[num_atoms + 1 + 4 * num_bonds]); + + uint32_t* neighbour_starts = deleter.get(); + for (uint32_t i = 0; i <= num_atoms; ++i) { + neighbour_starts[i] = 0; + } + + // First, get atom neighbour counts + const CompactBond* const bonds = graph.bonds.get(); + for (uint32_t i = 0; i < num_bonds; ++i) { + uint32_t a = bonds[i].beginAtomIdx; + uint32_t b = bonds[i].endAtomIdx; + // NOTE: +1 is because first entry will stay zero. + ++neighbour_starts[a + 1]; + ++neighbour_starts[b + 1]; + } + + // Find the starts by partial-summing the neighbour counts. + // NOTE: +1 is because first entry will stay zero. + std::partial_sum(neighbour_starts + 1, neighbour_starts + 1 + num_atoms, neighbour_starts + 1); + + // Fill in the neighbours and bond_indices arrays. + uint32_t* neighbours = neighbour_starts + num_atoms + 1; + uint32_t* bond_indices = neighbours + 2 * num_bonds; + for (uint32_t i = 0; i < num_bonds; ++i) { + uint32_t a = bonds[i].beginAtomIdx; + uint32_t b = bonds[i].endAtomIdx; + + uint32_t ai = neighbour_starts[a]; + neighbours[ai] = b; + bond_indices[ai] = i; + ++neighbour_starts[a]; + + uint32_t bi = neighbour_starts[b]; + neighbours[bi] = a; + bond_indices[bi] = i; + ++neighbour_starts[b]; + } + + // Shift neighbour_starts forward one after incrementing it. 
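+    // (After the fill loop above, neighbour_starts[i] holds the end of atom i's
+    // neighbour range, i.e. the start of atom i+1's range; shifting every entry
+    // one slot to the right restores neighbour_starts[i] == start of atom i,
+    // with neighbour_starts[0] == 0 and neighbour_starts[num_atoms] left at
+    // 2*num_bonds.)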
+ uint32_t previous = 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + uint32_t next = neighbour_starts[i]; + neighbour_starts[i] = previous; + previous = next; + } + + // NeighbourData takes ownership of the memory. + return NeighbourData{ std::move(deleter), neighbour_starts, neighbours, bond_indices }; +} + +// This fills in 3 values for each atom +template +at::Tensor get_conformer_features( + RDKit::ROMol &mol, + bool already_has_Hs, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t &num_nans, + const std::string& smiles_string) { + + const size_t n = mol.getNumAtoms(); + std::unique_ptr conformer_data(new T[3 * n]); + T* data = conformer_data.get(); + + std::unique_ptr mol_with_Hs_added; + RDKit::ROMol* mol_with_Hs = &mol; + if (mol.beginConformers() == mol.endConformers()) { + // No conformers. + // Before generating conformers, it's recommended to add Hs explicitly. + if (!already_has_Hs) { + // Add Hs. They're added at the end, so the original atoms + // will have the same indices as before. + mol_with_Hs_added.reset(new RDKit::RWMol(mol)); + RDKit::MolOps::addHs(*mol_with_Hs_added); + mol_with_Hs = mol_with_Hs_added.get(); + } + + // Default Python arguments to EmbedMolecule + int conformer_id = RDKit::DGeomHelpers::EmbedMolecule( + *mol_with_Hs, + 0, // maxIterations + -1, // seed + true, // clearConfs + false, // useRandomCoords + 2.0, // boxSizeMult + true, // randNedEig + 1, // numZeroFail + nullptr,// coordMap + 1e-3, // optimizerForceTol + false, // ignoreSmoothingFailures + true, // enforceChirality + true, // useExpTorsionAnglePrefs (default in Python; non-default in C++) + true, // useBasicKnowledge (default in Python; non-default in C++) + false, // verbose + 5.0, // basinThresh + false, // onlyHeavyAtomsForRMS + 1, // ETversion + false, // useSmallRingTorsions + false, // useMacrocycleTorsions + false // useMacrocycle14config + ); + + if (conformer_id == -1) { + // Custom arguments as fallback + RDKit::DGeomHelpers::EmbedMolecule( + *mol_with_Hs, + 0, // maxIterations + -1, // seed + true, // clearConfs + false, // useRandomCoords (TODO: consider using true) + 2.0, // boxSizeMult + true, // randNedEig + 1, // numZeroFail + nullptr,// coordMap + 0.1, // optimizerForceTol (changed) + true, // ignoreSmoothingFailures (changed) + false, // enforceChirality (changed) + true, // useExpTorsionAnglePrefs (default in Python; non-default in C++) + true, // useBasicKnowledge (default in Python; non-default in C++) + false, // verbose + 5.0, // basinThresh + false, // onlyHeavyAtomsForRMS + 1, // ETversion + false, // useSmallRingTorsions + false, // useMacrocycleTorsions + false // useMacrocycle14config + ); + } + } + if (mol_with_Hs->beginConformers() == mol_with_Hs->endConformers()) { + // Still no conformers: treat as NaN + for (size_t i = 0; i < 3 * n; ++i) { + data[i] = mask_nan_value; + } + if (mask_nan_style == MaskNaNStyle::REPORT) { + num_nans += 3*n; + } + printf("Warning: Couldn't compute conformer for molecule \"%s\"\n", smiles_string.c_str()); + } + else { + const RDKit::Conformer& conformer = mol_with_Hs->getConformer(); + const auto& positions = conformer.getPositions(); + assert(positions.size() >= n); + for (size_t i = 0; i < n; ++i, data += 3) { + const auto& position = positions[i]; + data[0] = FeatureValues::convertToFeatureType(position.x); + data[1] = FeatureValues::convertToFeatureType(position.y); + data[2] = FeatureValues::convertToFeatureType(position.z); + } + + num_nans += mask_nans(data, 3 * n, mask_nan_style, 
mask_nan_value); + } + + const int64_t dims[1] = { int64_t(3 * n) }; + return torch_tensor_from_array(std::move(conformer_data), dims, 1, dtype); +} + +static const std::unordered_map atom_float_name_to_enum{ + {std::string("atomic-number"), int64_t(AtomFloatFeature::ATOMIC_NUMBER)}, + {std::string("mass"), int64_t(AtomFloatFeature::MASS)}, + {std::string("weight"), int64_t(AtomFloatFeature::MASS)}, + {std::string("valence"), int64_t(AtomFloatFeature::VALENCE)}, + {std::string("total-valence"), int64_t(AtomFloatFeature::VALENCE)}, + {std::string("implicit-valence"), int64_t(AtomFloatFeature::IMPLICIT_VALENCE)}, + {std::string("hybridization"), int64_t(AtomFloatFeature::HYBRIDIZATION)}, + {std::string("chirality"), int64_t(AtomFloatFeature::CHIRALITY)}, + {std::string("aromatic"), int64_t(AtomFloatFeature::AROMATIC)}, + {std::string("ring"), int64_t(AtomFloatFeature::IN_RING)}, + {std::string("in-ring"), int64_t(AtomFloatFeature::IN_RING)}, + {std::string("min-ring"), int64_t(AtomFloatFeature::MIN_RING)}, + {std::string("max-ring"), int64_t(AtomFloatFeature::MAX_RING)}, + {std::string("num-ring"), int64_t(AtomFloatFeature::NUM_RING)}, + {std::string("degree"), int64_t(AtomFloatFeature::DEGREE)}, + {std::string("radical-electron"), int64_t(AtomFloatFeature::RADICAL_ELECTRON)}, + {std::string("formal-charge"), int64_t(AtomFloatFeature::FORMAL_CHARGE)}, + {std::string("vdw-radius"), int64_t(AtomFloatFeature::VDW_RADIUS)}, + {std::string("covalent-radius"), int64_t(AtomFloatFeature::COVALENT_RADIUS)}, + {std::string("electronegativity"), int64_t(AtomFloatFeature::ELECTRONEGATIVITY)}, + {std::string("ionization"), int64_t(AtomFloatFeature::IONIZATION)}, + {std::string("first-ionization"), int64_t(AtomFloatFeature::IONIZATION)}, + {std::string("melting-point"), int64_t(AtomFloatFeature::MELTING_POINT)}, + {std::string("metal"), int64_t(AtomFloatFeature::METAL)}, + {std::string("group"), int64_t(AtomFloatFeature::GROUP)}, + {std::string("period"), int64_t(AtomFloatFeature::PERIOD)}, + {std::string("single-bond"), int64_t(AtomFloatFeature::SINGLE_BOND)}, + {std::string("aromatic-bond"), int64_t(AtomFloatFeature::AROMATIC_BOND)}, + {std::string("double-bond"), int64_t(AtomFloatFeature::DOUBLE_BOND)}, + {std::string("triple-bond"), int64_t(AtomFloatFeature::TRIPLE_BOND)}, + {std::string("is-carbon"), int64_t(AtomFloatFeature::IS_CARBON)}, +}; + +at::Tensor atom_float_feature_names_to_tensor(const std::vector& features) { + const size_t num_features = features.size(); + std::unique_ptr feature_enum_values(new int64_t[num_features]); + for (size_t i = 0; i < num_features; ++i) { + auto it = atom_float_name_to_enum.find(features[i]); + if (it != atom_float_name_to_enum.end()) { + feature_enum_values[i] = it->second; + } + else { + feature_enum_values[i] = int64_t(AtomFloatFeature::UNKNOWN); + } + } + const int64_t dims[1] = { int64_t(num_features) }; + return torch_tensor_from_array(std::move(feature_enum_values), dims, 1, c10::ScalarType::Long); +} + +static const std::unordered_map atom_onehot_name_to_enum{ + {std::string("atomic-number"), int64_t(AtomOneHotFeature::ATOMIC_NUM)}, + {std::string("degree"), int64_t(AtomOneHotFeature::DEGREE)}, + {std::string("valence"), int64_t(AtomOneHotFeature::VALENCE)}, + {std::string("total-valence"), int64_t(AtomOneHotFeature::VALENCE)}, + {std::string("implicit-valence"), int64_t(AtomOneHotFeature::IMPLICIT_VALENCE)}, + {std::string("hybridization"), int64_t(AtomOneHotFeature::HYBRIDIZATION)}, + {std::string("chirality"), int64_t(AtomOneHotFeature::CHIRALITY)}, + 
{std::string("phase"), int64_t(AtomOneHotFeature::PHASE)}, + {std::string("type"), int64_t(AtomOneHotFeature::TYPE)}, + {std::string("group"), int64_t(AtomOneHotFeature::GROUP)}, + {std::string("period"), int64_t(AtomOneHotFeature::PERIOD)}, +}; + +at::Tensor atom_onehot_feature_names_to_tensor(const std::vector& features) { + const size_t num_features = features.size(); + std::unique_ptr feature_enum_values(new int64_t[num_features]); + for (size_t i = 0; i < num_features; ++i) { + auto it = atom_onehot_name_to_enum.find(features[i]); + if (it != atom_onehot_name_to_enum.end()) { + feature_enum_values[i] = it->second; + } + else { + feature_enum_values[i] = int64_t(AtomOneHotFeature::UNKNOWN); + } + } + const int64_t dims[1] = { int64_t(num_features) }; + return torch_tensor_from_array(std::move(feature_enum_values), dims, 1, c10::ScalarType::Long); +} + +static const std::unordered_map bond_name_to_enum{ + {std::string("bond-type-onehot"), int64_t(BondFeature::TYPE_ONE_HOT)}, + {std::string("bond-type-float"), int64_t(BondFeature::TYPE_FLOAT)}, + {std::string("stereo"), int64_t(BondFeature::STEREO_ONE_HOT)}, + {std::string("in-ring"), int64_t(BondFeature::IN_RING)}, + {std::string("conjugated"), int64_t(BondFeature::CONJUGATED)}, + {std::string("conformer-bond-length"), int64_t(BondFeature::CONFORMER_BOND_LENGTH)}, + {std::string("estimated-bond-length"), int64_t(BondFeature::ESTIMATED_BOND_LENGTH)}, +}; + +at::Tensor bond_feature_names_to_tensor(const std::vector& features) { + const size_t num_features = features.size(); + std::unique_ptr feature_enum_values(new int64_t[num_features]); + for (size_t i = 0; i < num_features; ++i) { + auto it = bond_name_to_enum.find(features[i]); + if (it != bond_name_to_enum.end()) { + feature_enum_values[i] = it->second; + } + else { + feature_enum_values[i] = int64_t(BondFeature::UNKNOWN); + } + } + const int64_t dims[1] = { int64_t(num_features) }; + return torch_tensor_from_array(std::move(feature_enum_values), dims, 1, c10::ScalarType::Long); +} + +static const std::unordered_map positional_name_to_enum{ + {std::string("laplacian_eigvec"), int64_t(PositionalFeature::LAPLACIAN_EIGENVEC)}, + {std::string("laplacian_eigval"), int64_t(PositionalFeature::LAPLACIAN_EIGENVAL)}, + {std::string("rw_return_probs"), int64_t(PositionalFeature::RW_RETURN_PROBS)}, + {std::string("rw_transition_probs"), int64_t(PositionalFeature::RW_TRANSITION_PROBS)}, + {std::string("electrostatic"), int64_t(PositionalFeature::ELECTROSTATIC)}, + {std::string("commute"), int64_t(PositionalFeature::COMMUTE)}, + {std::string("graphormer"), int64_t(PositionalFeature::GRAPHORMER)}, +}; + +static const std::unordered_map feature_level_to_enum{ + {std::string("node"), int64_t(FeatureLevel::NODE)}, + {std::string("edge"), int64_t(FeatureLevel::EDGE)}, + {std::string("nodepair"), int64_t(FeatureLevel::NODEPAIR)}, + {std::string("graph"), int64_t(FeatureLevel::GRAPH)}, +}; + +static const std::unordered_map normalization_to_enum{ + {std::string("none"), int64_t(Normalization::NONE)}, + {std::string("inv"), int64_t(Normalization::INVERSE)}, + {std::string("sym"), int64_t(Normalization::SYMMETRIC)}, +}; + +std::pair,at::Tensor> positional_feature_options_to_tensor( + const pybind11::dict& dict) { + size_t num_features = 0; + size_t num_values = 0; + for (const auto& pair : dict) { + // The string keys (pair.first) of the outer dictionary aren't needed for this + if (!pybind11::isinstance(pair.second)) { + continue; + } + pybind11::dict feature_dict = pair.second.cast(); + pybind11::handle 
feature_name_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_type")); + pybind11::handle feature_level_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_level")); + if (!feature_name_handle || !feature_level_handle) { + continue; + } + std::string feature_name{ pybind11::str(feature_name_handle) }; + std::string feature_level{ pybind11::str(feature_level_handle) }; + + auto feature_it = positional_name_to_enum.find(feature_name); + auto level_it = feature_level_to_enum.find(feature_level); + if (feature_it == positional_name_to_enum.end() || level_it == feature_level_to_enum.end()) { + continue; + } + + PositionalFeature feature = PositionalFeature(feature_it->second); + switch (feature) { + case PositionalFeature::LAPLACIAN_EIGENVEC: + case PositionalFeature::LAPLACIAN_EIGENVAL: { + // Required int num_pos + pybind11::handle num_pos_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "num_pos")); + if (!num_pos_handle || !pybind11::isinstance(num_pos_handle)) { + break; + } + // Optional string normalization + pybind11::handle normalization_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "normalization")); + if (normalization_handle) { + if (!pybind11::isinstance(normalization_handle)) { + break; + } + std::string normalization_name{ pybind11::str(normalization_handle) }; + if (!normalization_to_enum.contains(normalization_name)) { + break; + } + } + // Optional bool disconnected_comp + pybind11::handle disconnected_comp_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "disconnected_comp")); + if (disconnected_comp_handle && !pybind11::isinstance(disconnected_comp_handle)) { + break; + } + num_values += 3 + 3; + ++num_features; + break; + } + case PositionalFeature::RW_RETURN_PROBS: + case PositionalFeature::RW_TRANSITION_PROBS: { + pybind11::handle ksteps_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "ksteps")); + if (!ksteps_handle) { + break; + } + int64_t power_count = 0; + if (pybind11::isinstance(ksteps_handle)) { + power_count = int64_t(ksteps_handle.cast()); + } + else if (pybind11::isinstance(ksteps_handle)) { + power_count = ksteps_handle.cast().size(); + } + if (power_count < 1) { + break; + } + pybind11::handle space_dim_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "space_dim")); + if (space_dim_handle && !pybind11::isinstance(space_dim_handle)) { + break; + } + num_values += 3 + 1 + power_count; + ++num_features; + break; + } + case PositionalFeature::ELECTROSTATIC: + case PositionalFeature::COMMUTE: + case PositionalFeature::GRAPHORMER: + num_values += 3; + ++num_features; + break; + } + } + + std::unique_ptr values(new int64_t[num_values]); + std::vector names(num_features); + + size_t prev_feature_index = 0; + size_t feature_index = 0; + size_t value_index = 0; + for (const auto& pair : dict) { + // The string keys (pair.first) of the outer dictionary aren't needed for this + if (!pybind11::isinstance(pair.second)) { + continue; + } + pybind11::dict feature_dict = pair.second.cast(); + pybind11::handle feature_name_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_type")); + pybind11::handle feature_level_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_level")); + if (!feature_name_handle || !feature_level_handle) { + continue; + } + std::string feature_name{ pybind11::str(feature_name_handle) }; + std::string feature_level{ pybind11::str(feature_level_handle) }; + + auto feature_it = 
positional_name_to_enum.find(feature_name); + auto level_it = feature_level_to_enum.find(feature_level); + if (feature_it == positional_name_to_enum.end() || level_it == feature_level_to_enum.end()) { + continue; + } + + PositionalFeature feature = PositionalFeature(feature_it->second); + switch (feature) { + case PositionalFeature::LAPLACIAN_EIGENVEC: + case PositionalFeature::LAPLACIAN_EIGENVAL: { + // Required int num_pos + pybind11::handle num_pos_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "num_pos")); + if (!num_pos_handle || !pybind11::isinstance(num_pos_handle)) { + continue; + } + // Optional string normalization + pybind11::handle normalization_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "normalization")); + Normalization normalization = Normalization::NONE; + if (normalization_handle) { + if (!pybind11::isinstance(normalization_handle)) { + continue; + } + std::string normalization_name{ pybind11::str(normalization_handle) }; + auto it = normalization_to_enum.find(normalization_name); + if (it == normalization_to_enum.end()) { + continue; + } + normalization = Normalization(it->second); + } + // Optional bool disconnected_comp + pybind11::handle disconnected_comp_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "disconnected_comp")); + if (disconnected_comp_handle && !pybind11::isinstance(disconnected_comp_handle)) { + continue; + } + values[value_index] = feature_it->second; + values[value_index + 1] = 3; + values[value_index + 2] = level_it->second; + values[value_index + 3] = int64_t(num_pos_handle.cast()); + values[value_index + 4] = int64_t(normalization); + values[value_index + 5] = disconnected_comp_handle ? bool(disconnected_comp_handle.cast()) : true; + value_index += 3 + 3; + ++feature_index; + break; + } + case PositionalFeature::RW_RETURN_PROBS: + case PositionalFeature::RW_TRANSITION_PROBS: { + // Required int or list[int] ksteps + pybind11::handle ksteps_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "ksteps")); + if (!ksteps_handle) { + continue; + } + int64_t power_count = 0; + if (pybind11::isinstance(ksteps_handle)) { + // Integer means use all powers from 1 up to this value, inclusive. + power_count = int64_t(ksteps_handle.cast()); + } + else if (pybind11::isinstance(ksteps_handle)) { + power_count = ksteps_handle.cast().size(); + } + if (power_count < 1) { + break; + } + // Optional int space_dim + pybind11::handle space_dim_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "space_dim")); + if (space_dim_handle && !pybind11::isinstance(space_dim_handle)) { + break; + } + values[value_index] = feature_it->second; + values[value_index + 1] = 1 + power_count; + values[value_index + 2] = level_it->second; + + int64_t space_dim = space_dim_handle ? int64_t(space_dim_handle.cast()) : 0; + values[value_index + 3] = space_dim; + if (pybind11::isinstance(ksteps_handle)) { + for (int64_t power = 1; power <= power_count; ++power) { + values[value_index + 3 + power] = power; + } + } + else if (pybind11::isinstance(ksteps_handle)) { + size_t power_index = 0; + int64_t prev_power = 0; + for(const auto item : ksteps_handle.cast()) { + int64_t next_power = pybind11::isinstance(item) ? 
int64_t(item.cast()) : prev_power; + if (next_power < prev_power) { + // Force the integers to be ascending + next_power = prev_power; + } + values[value_index + 3 + 1 + power_index] = next_power; + prev_power = next_power; + ++power_index; + } + } + value_index += 3 + 1 + power_count; + ++feature_index; + break; + } + case PositionalFeature::ELECTROSTATIC: + case PositionalFeature::COMMUTE: + case PositionalFeature::GRAPHORMER: + values[value_index] = feature_it->second; + values[value_index + 1] = 0; + values[value_index + 2] = level_it->second; + value_index += 3; + ++feature_index; + break; + } + if (feature_index != prev_feature_index) { + names[prev_feature_index] = (level_it->second == int64_t(FeatureLevel::NODE)) ? feature_name : (feature_level + std::string("_") + feature_name); + ++prev_feature_index; + } + } + assert(feature_index == num_features && prev_feature_index == num_features && value_index == num_values); + + const int64_t dims[1] = { int64_t(num_values) }; + return std::make_pair( + std::move(names), + torch_tensor_from_array(std::move(values), dims, 1, c10::ScalarType::Long)); +} + +template +at::Tensor create_edge_weights( + const GraphData& graph, + bool duplicate_edges, + bool add_self_loop, + bool use_bonds_weights, + c10::ScalarType dtype) { + + const size_t edge_coo_count = (duplicate_edges ? 2*graph.num_bonds : graph.num_bonds) + + (add_self_loop ? graph.num_atoms : 0); + std::unique_ptr edge_weights(new T[edge_coo_count]); + + // TODO: Use use_bonds_weights to optionally give weights + // in same order as other edge features + for (size_t i = 0; i < edge_coo_count; ++i) { + edge_weights[i] = FeatureValues::one; + } + + const int64_t dims[1] = { int64_t(edge_coo_count) }; + return torch_tensor_from_array(std::move(edge_weights), dims, 1, dtype); +} + +template +at::Tensor create_atom_features( + const GraphData& graph, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool offset_carbon, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t &num_nans) { + + const size_t num_onehot_properties = (atom_property_list_onehot.scalar_type() == c10::ScalarType::Long && atom_property_list_onehot.ndimension() == 1) ? atom_property_list_onehot.size(0) : 0; + // NOTE: If TensorBase::data_ptr is ever removed, change it to TensorBase::const_data_ptr. + // Some torch version being used doesn't have const_data_ptr yet. + const int64_t* const property_list_onehot = (num_onehot_properties != 0) ? atom_property_list_onehot.data_ptr() : nullptr; + const size_t num_float_properties = (atom_property_list_float.scalar_type() == c10::ScalarType::Long && atom_property_list_float.ndimension() == 1) ? atom_property_list_float.size(0) : 0; + const int64_t* const property_list_float = (num_float_properties != 0) ? 
atom_property_list_float.data_ptr() : nullptr; + + size_t single_atom_float_count = num_float_properties; + for (size_t i = 0; i < num_onehot_properties; ++i) { + const int64_t property = property_list_onehot[i]; + single_atom_float_count += get_one_hot_atom_feature_size(AtomOneHotFeature(property)); + } + const size_t atom_float_count = single_atom_float_count * graph.num_atoms; + + std::unique_ptr atom_data(new T[atom_float_count]); + + T* current_atom_data = atom_data.get(); + + for (size_t i = 0; i < num_float_properties; ++i) { + const int64_t property = property_list_float[i]; + get_atom_float_feature(graph, current_atom_data, AtomFloatFeature(property), single_atom_float_count, offset_carbon); + ++current_atom_data; + } + for (size_t i = 0; i < num_onehot_properties; ++i) { + const int64_t property = property_list_onehot[i]; + current_atom_data += get_one_hot_atom_feature(graph, current_atom_data, AtomOneHotFeature(property), single_atom_float_count); + } + + num_nans += mask_nans(atom_data.get(), atom_float_count, mask_nan_style, mask_nan_value); + + const int64_t dims[2] = { int64_t(graph.num_atoms), int64_t(single_atom_float_count) }; + return torch_tensor_from_array(std::move(atom_data), dims, 2, dtype); +} + +template +at::Tensor create_bond_features( + const GraphData& graph, + const at::Tensor& bond_property_list, + const bool duplicate_edges, + bool add_self_loop, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t& num_nans) { + + const size_t num_properties = (bond_property_list.scalar_type() == c10::ScalarType::Long && bond_property_list.ndimension() == 1) ? bond_property_list.size(0) : 0; + const int64_t* const property_list = (num_properties != 0) ? bond_property_list.data_ptr() : nullptr; + + size_t single_bond_float_count = 0; + for (size_t i = 0; i < num_properties; ++i) { + const int64_t property = property_list[i]; + if (BondFeature(property) == BondFeature::TYPE_ONE_HOT || BondFeature(property) == BondFeature::STEREO_ONE_HOT) { + single_bond_float_count += get_one_hot_bond_feature_size(BondFeature(property)); + } + else { + ++single_bond_float_count; + } + } + // add_self_loop is only supported if duplicating edges + add_self_loop = add_self_loop && duplicate_edges; + size_t total_edge_count = graph.num_bonds; + if (duplicate_edges) { + total_edge_count = 2*total_edge_count + size_t(add_self_loop); + } + const size_t bond_float_count = single_bond_float_count * total_edge_count; + + std::unique_ptr bond_data(new T[bond_float_count]); + + T* current_bond_data = bond_data.get(); + + // This is the stride length (in floats) for each unique bond + const size_t duplicated_bond_float_count = duplicate_edges ? 
(2*single_bond_float_count) : single_bond_float_count; + + for (size_t i = 0; i < num_properties; ++i) { + const int64_t property = property_list[i]; + if (BondFeature(property) == BondFeature::TYPE_ONE_HOT || BondFeature(property) == BondFeature::STEREO_ONE_HOT) { + current_bond_data += get_one_hot_bond_feature(graph, current_bond_data, BondFeature(property), duplicated_bond_float_count); + } + else { + get_bond_float_feature(graph, current_bond_data, BondFeature(property), duplicated_bond_float_count); + ++current_bond_data; + } + } + + if (duplicate_edges) { + current_bond_data = bond_data.get(); + // Duplicate the data for each bond + for (size_t i = 0; i < graph.num_bonds; ++i) { + for (size_t j = 0; j < single_bond_float_count; ++j) { + current_bond_data[j+single_bond_float_count] = current_bond_data[j]; + } + current_bond_data += duplicated_bond_float_count; + } + if (add_self_loop) { + // Self loops don't have valid bond data, but don't treat them as NaNs. + // Fill with zeros, instead. + memset(current_bond_data, 0, graph.num_atoms * graph.num_atoms); + } + } + + num_nans += mask_nans(bond_data.get(), bond_float_count, mask_nan_style, mask_nan_value); + + int64_t dims[2] = { int64_t(total_edge_count), int64_t(single_bond_float_count) }; + return torch_tensor_from_array(std::move(bond_data), dims, 2, dtype); +} + +template +void node_to_edge( + std::unique_ptr& output_ptr, + size_t& floats_per_half_edge, + const IN_T* input, + const size_t n, + const size_t floats_per_node, + const GraphData& graph) { + + // Edge order must be consistent with the edges in the graph, + // which is not necessarily lexicographic order. + const size_t num_half_edges = 2*graph.num_bonds; + floats_per_half_edge = 2 * floats_per_node; + output_ptr.reset(new OUT_T[num_half_edges * 2 * floats_per_node]); + OUT_T* output = output_ptr.get(); + for (size_t bond = 0; bond < graph.num_bonds; ++bond, output += 2*floats_per_half_edge) { + const size_t atomi = graph.bonds[bond].beginAtomIdx; + const size_t atomj = graph.bonds[bond].endAtomIdx; + const IN_T* input_i = input + atomi * floats_per_node; + const IN_T* input_j = input + atomj * floats_per_node; + // For each edge, record all of the sums followed by all of the absolute differences + OUT_T* output_sum = output; + OUT_T* output_absdiff = output + floats_per_node; + for (size_t float_index = 0; float_index < floats_per_node; ++float_index) { + const IN_T sum = input_i[float_index] + input_j[float_index]; + const IN_T diff = input_i[float_index] - input_j[float_index]; + const IN_T absdiff = std::abs(diff); + const OUT_T sum_out = FeatureValues::convertToFeatureType(sum); + const OUT_T absdiff_out = FeatureValues::convertToFeatureType(absdiff); + output_sum[float_index] = sum_out; + output_absdiff[float_index] = absdiff_out; + // Same values for opposite direction + output_sum[floats_per_half_edge + float_index] = sum_out; + output_absdiff[floats_per_half_edge + float_index] = absdiff_out; + } + } +} + +template +void node_to_node_pair( + std::unique_ptr& output_ptr, + size_t& floats_per_pair, + const IN_T* input, + const size_t n, + const size_t floats_per_node) { + + floats_per_pair = 2 * floats_per_node; + output_ptr.reset(new OUT_T[n * n * floats_per_pair]); + OUT_T* output = output_ptr.get(); + const IN_T* input_i = input; + for (size_t i = 0; i < n; ++i, input_i += floats_per_node) { + const IN_T* input_j = input; + for (size_t j = 0; j < n; ++j, input_j += floats_per_node, output += floats_per_pair) { + // For each pair, record all of the sums followed 
by all of the absolute differences + OUT_T* output_sum = output; + OUT_T* output_absdiff = output + floats_per_node; + for (size_t float_index = 0; float_index < floats_per_node; ++float_index) { + const IN_T sum = input_i[float_index] + input_j[float_index]; + const IN_T diff = input_i[float_index] - input_j[float_index]; + const IN_T absdiff = std::abs(diff); + output_sum[float_index] = FeatureValues::convertToFeatureType(sum); + output_absdiff[float_index] = FeatureValues::convertToFeatureType(absdiff); + } + } + } +} + +enum class StatOperation { + MINIMUM, + MEAN +}; + +template +T stat_init_accum(T v) { + return v; +} + +template +void stat_accum(T& accum, T v) { + switch (op) { + case StatOperation::MINIMUM: + accum = (v < accum) ? v : accum; + break; + case StatOperation::MEAN: + accum += v; + break; + } +} + +template +T stat_accum_finish(T accum, size_t n) { + switch (op) { + case StatOperation::MINIMUM: + return accum; + case StatOperation::MEAN: + return accum / n; + } +} + +template +void node_pair_to_node_helper( + OUT_T* output, + const IN_T* input, + const size_t n, + const size_t floats_per_pair, + const size_t node_index) { + + // for each float per pair + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, output += 2) { + // across all rows (axis 0) of column node_index, then across all columns (axis 1) of row node_index + IN_T accum = stat_init_accum(input[node_index * floats_per_pair + float_index]); + for (size_t row = 1; row < n; ++row) { + stat_accum(accum, input[(row * n + node_index) * floats_per_pair + float_index]); + } + output[0] = FeatureValues::convertToFeatureType(stat_accum_finish(accum, n)); + accum = stat_init_accum(input[node_index * n * floats_per_pair + float_index]); + for (size_t col = 1; col < n; ++col) { + stat_accum(accum, input[(node_index * n + col) * floats_per_pair + float_index]); + } + output[1] = FeatureValues::convertToFeatureType(stat_accum_finish(accum, n)); + } +} + +template +void node_pair_to_node_helper_stdev( + OUT_T* output, + const IN_T* input, + const size_t n, + const size_t floats_per_pair, + const size_t node_index) { + + // for each float per pair + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, output += 2) { + // across all rows (axis 0) of column node_index, then across all columns (axis 1) of row node_index + IN_T v = input[node_index * floats_per_pair + float_index]; + IN_T accum = v; + IN_T accum2 = v * v; + for (size_t row = 1; row < n; ++row) { + v = input[(row * n + node_index) * floats_per_pair + float_index]; + accum += v; + accum2 += v * v; + } + // NOTE: Using divisor n, the default in numpy.std, not n-1, the default elsewhere + accum /= n; + accum2 /= n; + output[0] = FeatureValues::convertToFeatureType(std::sqrt(accum2 - accum*accum)); + + v = input[node_index * n * floats_per_pair + float_index]; + accum = v; + accum2 = v * v; + for (size_t col = 1; col < n; ++col) { + v = input[(node_index * n + col) * floats_per_pair + float_index]; + accum += v; + accum2 += v * v; + } + // NOTE: Using divisor n, the default in numpy.std, not n-1, the default elsewhere + accum /= n; + accum2 /= n; + output[1] = FeatureValues::convertToFeatureType(std::sqrt(accum2 - accum*accum)); + } +} + +template +void node_pair_to_node( + std::unique_ptr& output_ptr, + size_t& floats_per_node, + const IN_T* input, + const size_t n, + const size_t floats_per_pair) { + + const size_t num_ops = 3; + floats_per_node = num_ops * 2 * floats_per_pair; + output_ptr.reset(new OUT_T[n * 
floats_per_node]); + OUT_T* output = output_ptr.get(); + for (size_t node_index = 0; node_index < n; ++node_index) { + // min, mean, stdev (using divisor N, the default in numpy.std, not N-1, the default elsewhere) + node_pair_to_node_helper(output, input, n, floats_per_pair, node_index); + output += 2 * floats_per_pair; + node_pair_to_node_helper(output, input, n, floats_per_pair, node_index); + output += 2 * floats_per_pair; + node_pair_to_node_helper_stdev(output, input, n, floats_per_pair, node_index); + output += 2 * floats_per_pair; + } +} + +template +void node_pair_to_edge( + std::unique_ptr& output_ptr, + size_t& floats_per_edge, + const IN_T* input, + const size_t n, + const size_t floats_per_pair, + const GraphData& graph) { + + // Edge order must be consistent with the edges in the graph, + // which is not necessarily lexicographic order. + const size_t num_half_edges = 2*graph.num_bonds; + floats_per_edge = floats_per_pair; + output_ptr.reset(new OUT_T[num_half_edges * floats_per_pair]); + OUT_T* output = output_ptr.get(); + for (size_t bond = 0; bond < graph.num_bonds; ++bond) { + const size_t atomi = graph.bonds[bond].beginAtomIdx; + const size_t atomj = graph.bonds[bond].endAtomIdx; + const IN_T* input_ij = input + ((atomi * n) + atomj) * floats_per_pair; + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, ++output) { + *output = FeatureValues::convertToFeatureType(input_ij[float_index]); + } + + const IN_T* input_ji = input + ((atomj * n) + atomi) * floats_per_pair; + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, ++output) { + *output = FeatureValues::convertToFeatureType(input_ji[float_index]); + } + } +} + +template +void create_positional_features( + const GraphData& graph, + const at::Tensor& positional_property_list, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t& num_nans, + int64_t& nan_tensor_index, + std::vector& tensors) { + + const size_t size = (positional_property_list.scalar_type() == c10::ScalarType::Long && positional_property_list.ndimension() == 1) ? positional_property_list.size(0) : 0; + const int64_t* const property_list = (size >= 3) ? positional_property_list.data_ptr() : nullptr; + + if (property_list == nullptr) { + return; + } + NeighbourData neighbours = construct_neighbours(graph); + + LaplacianData laplacian_data; + LaplacianData laplacian_data_comp; + size_t num_components = 0; // 0 indicates that the components haven't been computed yet + std::vector components; + std::vector laplacian_pseudoinverse; + std::vector matrix; + size_t i = 0; + while (size >= i + 3) { + int64_t property = property_list[i]; + int64_t current_size = property_list[i + 1]; + FeatureLevel feature_level = FeatureLevel(property_list[i + 2]); + i += 3; + if (i + current_size > size || i + current_size < i) { + break; + } + FeatureLevel base_level; + std::unique_ptr base_data; + int64_t base_dims[3] = { 1,1,1 }; + size_t base_dim_count; + if ((property == int64_t(PositionalFeature::LAPLACIAN_EIGENVEC) || property == int64_t(PositionalFeature::LAPLACIAN_EIGENVAL)) && current_size == 3) { + size_t num_pos = (property_list[i] >= 0) ? size_t(property_list[i]) : 0; + Normalization normalization = Normalization(property_list[i + 1]); + bool disconnected_comp = (property_list[i + 2] != 0); + i += 3; + + // The common case is that there's only 1 component, even if disconnected_comp is true, + // so find the number of components, first. 
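+            // The eigendecomposition is cached across features: laplacian_data_comp is used when
+            // per-component handling is requested and the molecule has more than one component,
+            // laplacian_data otherwise, and either one is recomputed only if it is still empty or
+            // was computed with a different normalization.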
+ if (disconnected_comp && num_components == 0) { + num_components = find_components(graph.num_atoms, neighbours.neighbour_starts, neighbours.neighbours, components); + } + const bool multiple_components = disconnected_comp && (num_components > 1); + + LaplacianData& current_data = multiple_components ? laplacian_data_comp : laplacian_data; + if (current_data.eigenvalues.size() == 0 || current_data.normalization != normalization) { + compute_laplacian_eigendecomp( + graph.num_atoms, + neighbours.neighbour_starts, + neighbours.neighbours, + normalization, + current_data, + multiple_components ? num_components : 1, + &components); + } + + const bool isVec = (property == int64_t(PositionalFeature::LAPLACIAN_EIGENVEC)); + base_level = FeatureLevel::NODE; + base_dims[0] = graph.num_atoms; + base_dims[1] = num_pos; + base_dim_count = 2; + base_data.reset(new double[graph.num_atoms * num_pos]); + + // Ensure exactly the tensor dimensions of num_atoms x num_pos before changing the level. + if (isVec) { + double* data = base_data.get(); + for (size_t atom_index = 0; atom_index < graph.num_atoms; ++atom_index, data += num_pos) { + for (size_t i = 0; i < num_pos && i < graph.num_atoms; ++i) { + // Row eigenvectors to column eigenvectors + data[i] = current_data.vectors[atom_index + i * graph.num_atoms]; + // There's no plausible way the eigenvectors should end up with NaNs, + // so just assert in debug builds. + assert(std::isfinite(data[i])); + } + // NOTE: Do not treat extra values as NaN. The original code filled them with zeros. + for (size_t i = graph.num_atoms; i < num_pos; ++i) { + data[i] = 0; + } + } + } + else { + double* data = base_data.get(); + const bool is_multi_component = (current_data.eigenvalues.size() == size_t(graph.num_atoms)*graph.num_atoms); + assert(is_multi_component || (current_data.eigenvalues.size() == graph.num_atoms)); + size_t source_row_start = 0; + for (size_t atom_index = 0; atom_index < graph.num_atoms; ++atom_index, data += num_pos) { + for (size_t i = 0; i < num_pos && i < graph.num_atoms; ++i) { + // Duplicate the eigenvalue for each atom + data[i] = current_data.eigenvalues[source_row_start + i]; + // There's no plausible way the eigenvalues should end up with NaNs, + // so just assert in debug builds. + assert(std::isfinite(data[i])); + } + // NOTE: Do not treat extra values as NaN. The original code filled them with zeros. + for (size_t i = graph.num_atoms; i < num_pos; ++i) { + data[i] = 0; + } + if (is_multi_component) { + source_row_start += graph.num_atoms; + } + } + } + } + else if ((property == int64_t(PositionalFeature::RW_RETURN_PROBS) || property == int64_t(PositionalFeature::RW_TRANSITION_PROBS)) && current_size >= 1) { + int space_dim = property_list[i]; + ++i; + uint32_t num_powers = current_size - 1; + const uint64_t* powers = reinterpret_cast(property_list + i); + i += num_powers; + const bool isProbs = (property == int64_t(PositionalFeature::RW_RETURN_PROBS)); + RandomWalkDataOption option = isProbs ? RandomWalkDataOption::PROBABILITIES : RandomWalkDataOption::MATRIX; + + std::vector output; + compute_rwse(num_powers, powers, graph.num_atoms, neighbours.neighbour_starts, neighbours.neighbours, option, output, space_dim); + + base_level = isProbs ? FeatureLevel::NODE : FeatureLevel::NODEPAIR; + + base_dims[0] = graph.num_atoms; + base_dims[1] = isProbs ? num_powers : graph.num_atoms; + base_dims[2] = isProbs ? 1 : num_powers; + base_dim_count = isProbs ? 
2 : 3; + base_data.reset(new double[output.size()]); + std::copy(output.begin(), output.end(), base_data.get()); + } + else if (property == int64_t(PositionalFeature::ELECTROSTATIC) && current_size == 0) { + const double* weights = nullptr; + compute_electrostatic_interactions(graph.num_atoms, neighbours.neighbour_starts, neighbours.neighbours, laplacian_data, laplacian_pseudoinverse, matrix, weights); + + base_level = FeatureLevel::NODEPAIR; + base_dims[0] = graph.num_atoms; + base_dims[1] = graph.num_atoms; + base_dim_count = 2; + assert(matrix.size() == graph.num_atoms * size_t(graph.num_atoms)); + base_data.reset(new double[matrix.size()]); + std::copy(matrix.begin(), matrix.end(), base_data.get()); + } + else if (property == int64_t(PositionalFeature::COMMUTE) && current_size == 0) { + const double* weights = nullptr; + compute_commute_distances(graph.num_atoms, neighbours.neighbour_starts, neighbours.neighbours, laplacian_data, laplacian_pseudoinverse, matrix, weights); + + base_level = FeatureLevel::NODEPAIR; + base_dims[0] = graph.num_atoms; + base_dims[1] = graph.num_atoms; + base_dim_count = 2; + assert(matrix.size() == graph.num_atoms * size_t(graph.num_atoms)); + base_data.reset(new double[matrix.size()]); + std::copy(matrix.begin(), matrix.end(), base_data.get()); + } + else if (property == int64_t(PositionalFeature::GRAPHORMER) && current_size == 0) { + std::vector> queue; + std::vector all_pairs_distances; + compute_graphormer_distances(graph.num_atoms, neighbours.neighbour_starts, neighbours.neighbours, queue, all_pairs_distances); + + base_level = FeatureLevel::NODEPAIR; + base_dims[0] = graph.num_atoms; + base_dims[1] = graph.num_atoms; + base_dim_count = 2; + assert(all_pairs_distances.size() == graph.num_atoms * size_t(graph.num_atoms)); + base_data.reset(new double[all_pairs_distances.size()]); + std::copy(all_pairs_distances.begin(), all_pairs_distances.end(), base_data.get()); + } + + if (base_data.get() == nullptr) { + continue; + } + + // Change the level and convert to the correct type if needed. 
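+        // Supported conversions below are NODE->EDGE, NODE->NODEPAIR, NODEPAIR->NODE and
+        // NODEPAIR->EDGE; any other requested level change leaves final_data null, so the
+        // feature is skipped further down.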
+ std::unique_ptr final_data; + int64_t final_dims[3]; + std::copy(base_dims, base_dims + 3, final_dims); + size_t final_num_dims = base_dim_count; + if (feature_level != base_level) { + if (base_level == FeatureLevel::NODE) { + if (feature_level == FeatureLevel::EDGE) { + size_t floats_per_half_edge; + node_to_edge(final_data, floats_per_half_edge, base_data.get(), base_dims[0], base_dims[1], graph); + final_dims[0] = 2 * graph.num_bonds; + final_dims[1] = floats_per_half_edge; + final_dims[2] = 1; + } + else if (feature_level == FeatureLevel::NODEPAIR) { + size_t floats_per_pair; + node_to_node_pair(final_data, floats_per_pair, base_data.get(), base_dims[0], base_dims[1]); + final_num_dims = 3; + final_dims[1] = base_dims[0]; + final_dims[2] = floats_per_pair; + } + else { + // Not implemented + } + } + else if (base_level == FeatureLevel::NODEPAIR) { + if (feature_level == FeatureLevel::NODE) { + size_t floats_per_node; + node_pair_to_node(final_data, floats_per_node, base_data.get(), base_dims[0], base_dims[2]); + final_num_dims = 2; + final_dims[1] = floats_per_node; + final_dims[2] = 1; + } + else if (feature_level == FeatureLevel::EDGE) { + size_t floats_per_edge; + node_pair_to_edge(final_data, floats_per_edge, base_data.get(), base_dims[0], base_dims[2], graph); + final_num_dims = 2; + final_dims[0] = 2 * graph.num_bonds; + final_dims[1] = floats_per_edge; + final_dims[2] = 1; + } + else { + // Not implemented + } + } + else { + // Not implemented + } + } + else if (dtype != c10::ScalarType::Double) { + // Just convert + const size_t total_num_floats = final_dims[0] * final_dims[1] * final_dims[2]; + final_data.reset(new T[total_num_floats]); + for (size_t i = 0; i < total_num_floats; ++i) { + final_data[i] = FeatureValues::convertToFeatureType(base_data[i]); + } + } + else { + // Perfect match out of the box + // This will only be hit if T is double, but it still needs to compile + // for other cases, which is why the reinterpret_cast is needed. + final_data.reset(reinterpret_cast(base_data.release())); + } + + if (final_data.get() == nullptr) { + continue; + } + + tensors.push_back(torch_tensor_from_array(std::move(final_data), final_dims, final_num_dims, dtype)); + } +} + +template +void create_all_features( + const GraphData& graph, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool create_conformer_feature, + const at::Tensor& bond_property_list, + const at::Tensor& positional_property_list, + bool duplicate_edges, + bool add_self_loop, + bool already_has_Hs, + bool use_bonds_weights, + bool offset_carbon, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t& num_nans, + int64_t& nan_tensor_index, + const std::string& smiles_string, + std::vector& tensors) { + + if (mask_nan_style == MaskNaNStyle::NONE) { + // In some cases, the NONE and REPLACE styles can be combined. 
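+        // Using NaN as the replacement value makes a "replace" effectively a no-op, so downstream
+        // code that always writes mask_nan_value for invalid entries behaves the same as NONE here.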
+ mask_nan_value = FeatureValues::nan_value; + } + at::Tensor edge_weights_tensor = create_edge_weights( + graph, + duplicate_edges, + add_self_loop, + use_bonds_weights, + dtype); + tensors.push_back(std::move(edge_weights_tensor)); + at::Tensor atom_features_tensor = create_atom_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + offset_carbon, + dtype, + mask_nan_style, + mask_nan_value, + num_nans); + tensors.push_back(std::move(atom_features_tensor)); + if (num_nans != 0) { + nan_tensor_index = tensors.size()-1; + } + at::Tensor bond_features_tensor = create_bond_features( + graph, + bond_property_list, + duplicate_edges, + add_self_loop, + dtype, + mask_nan_style, + mask_nan_value, + num_nans); + tensors.push_back(std::move(bond_features_tensor)); + if (nan_tensor_index < 0 && num_nans != 0) { + nan_tensor_index = tensors.size()-1; + } + if (create_conformer_feature) { + at::Tensor conformer_features_tensor = get_conformer_features( + *graph.mol, + already_has_Hs, + dtype, + mask_nan_style, + mask_nan_value, + num_nans, + smiles_string); + tensors.push_back(std::move(conformer_features_tensor)); + if (nan_tensor_index < 0 && num_nans != 0) { + nan_tensor_index = tensors.size(); + } + } + create_positional_features( + graph, + positional_property_list, + dtype, + mask_nan_style, + mask_nan_value, + num_nans, + nan_tensor_index, + tensors); +} + +std::tuple, int64_t, int64_t> featurize_smiles( + const std::string& smiles_string, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool create_conformer_feature, + const at::Tensor& bond_property_list, + const at::Tensor& positional_property_list, + bool duplicate_edges, + bool add_self_loop, + bool explicit_H, + bool use_bonds_weights, + bool offset_carbon, + int dtype_int, + int mask_nan_style_int, + double mask_nan_value) { + + GraphData graph = read_graph(smiles_string, explicit_H); + + const size_t edge_coo_count = 2*graph.num_bonds + (add_self_loop ? graph.num_atoms : 0); + std::unique_ptr edge_index(new int64_t[2*edge_coo_count]); + for (size_t i = 0; i < graph.num_bonds; ++i) { + // PyG has all directed edge begin indices followed by all end indices. 
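+        // For example, bonds (0,1) and (1,2) with no self-loops yield
+        // edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]], i.e. both directed half-edges per bond.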
+ edge_index[2*i] = graph.bonds[i].beginAtomIdx; + edge_index[2*i+1] = graph.bonds[i].endAtomIdx; + edge_index[2*i + edge_coo_count] = graph.bonds[i].endAtomIdx; + edge_index[2*i+1 + edge_coo_count] = graph.bonds[i].beginAtomIdx; + } + if (add_self_loop) { + for (size_t i = 0; i < graph.num_atoms; ++i) { + edge_index[2*graph.num_bonds + i] = i; + edge_index[2*graph.num_bonds + i + edge_coo_count] = i; + } + } + int64_t edge_coo_dims[2] = { int64_t(2), int64_t(edge_coo_count) }; + at::Tensor edge_coo_tensor = torch_tensor_from_array(std::move(edge_index), edge_coo_dims, 2, c10::ScalarType::Long); + + std::vector tensors; + tensors.push_back(std::move(edge_coo_tensor)); + c10::ScalarType dtype = c10::ScalarType(dtype_int); + MaskNaNStyle mask_nan_style = MaskNaNStyle(mask_nan_style_int); + int64_t num_nans = 0; + int64_t nan_tensor_index = -1; + if (dtype == c10::ScalarType::Half) { + create_all_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + create_conformer_feature, + bond_property_list, + positional_property_list, + duplicate_edges, + add_self_loop, + explicit_H, + use_bonds_weights, + offset_carbon, + dtype, + mask_nan_style, + FeatureValues::convertToFeatureType(mask_nan_value), + num_nans, + nan_tensor_index, + smiles_string, + tensors); + } + else if (dtype == c10::ScalarType::Float) { + create_all_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + create_conformer_feature, + bond_property_list, + positional_property_list, + duplicate_edges, + add_self_loop, + explicit_H, + use_bonds_weights, + offset_carbon, + dtype, + mask_nan_style, + FeatureValues::convertToFeatureType(mask_nan_value), + num_nans, + nan_tensor_index, + smiles_string, + tensors); + } + else if (dtype == c10::ScalarType::Double) { + create_all_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + create_conformer_feature, + bond_property_list, + positional_property_list, + duplicate_edges, + add_self_loop, + explicit_H, + use_bonds_weights, + offset_carbon, + dtype, + mask_nan_style, + FeatureValues::convertToFeatureType(mask_nan_value), + num_nans, + nan_tensor_index, + smiles_string, + tensors); + } + + return std::make_tuple(tensors, num_nans, nan_tensor_index); +} diff --git a/graphium/graphium_cpp/features.h b/graphium/graphium_cpp/features.h new file mode 100644 index 000000000..4bbcde001 --- /dev/null +++ b/graphium/graphium_cpp/features.h @@ -0,0 +1,277 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include + +// Torch tensor headers +#include +#include + +#include +#include + +// PyBind and Torch headers +#include +#include +#include + +enum class FeatureLevel { + NODE, + EDGE, + NODEPAIR, + GRAPH +}; + +enum class AtomFloatFeature { + ATOMIC_NUMBER, + MASS, + VALENCE, + IMPLICIT_VALENCE, + HYBRIDIZATION, + CHIRALITY, + AROMATIC, + IN_RING, + MIN_RING, + MAX_RING, + NUM_RING, + DEGREE, + RADICAL_ELECTRON, + FORMAL_CHARGE, + VDW_RADIUS, + COVALENT_RADIUS, + ELECTRONEGATIVITY, + IONIZATION, + MELTING_POINT, + METAL, + GROUP, + PERIOD, + SINGLE_BOND, + AROMATIC_BOND, + DOUBLE_BOND, + TRIPLE_BOND, + IS_CARBON, + UNKNOWN +}; + +enum class AtomOneHotFeature { + ATOMIC_NUM, + DEGREE, + VALENCE, + IMPLICIT_VALENCE, + HYBRIDIZATION, + CHIRALITY, + PHASE, + TYPE, + GROUP, + PERIOD, + UNKNOWN +}; + +enum class BondFeature { + TYPE_FLOAT, + TYPE_ONE_HOT, + IN_RING, + CONJUGATED, + STEREO_ONE_HOT, + CONFORMER_BOND_LENGTH, + ESTIMATED_BOND_LENGTH, + UNKNOWN +}; + +enum class PositionalFeature { + LAPLACIAN_EIGENVEC, + LAPLACIAN_EIGENVAL, + RW_RETURN_PROBS, + RW_TRANSITION_PROBS, + ELECTROSTATIC, + COMMUTE, + GRAPHORMER +}; + +enum class Normalization { + NONE, + SYMMETRIC, + INVERSE +}; + +enum class MaskNaNStyle { + NONE, + REPORT, + REPLACE +}; + +struct PositionalOptions { + PositionalFeature feature; + FeatureLevel level; + + std::vector rw_powers; + int rw_space_dim = 0; + + uint32_t laplacian_num_pos = 8; + Normalization laplacian_normalization = Normalization::NONE; + bool laplacian_disconnected_comp = true; +}; + +template +struct FeatureValues {}; + +template<> struct FeatureValues { + static constexpr int16_t zero = 0x0000; + static constexpr int16_t one = 0x3C00; + static constexpr int16_t nan_value = 0x7C01; + + template + static int16_t convertToFeatureType(T inputType) { + static_assert(std::is_floating_point_v); + return c10::detail::fp16_ieee_from_fp32_value(float(inputType)); + } + + static constexpr bool is_finite(int16_t v) { + // If the exponent bits are the maximum value, v is infinite or NaN + return (v & 0x7C00) != 0x7C00; + } + + using MathType = float; +}; +template<> struct FeatureValues { + static constexpr float zero = 0.0f; + static constexpr float one = 1.0f; + static constexpr float nan_value = std::numeric_limits::quiet_NaN(); + + template + static float convertToFeatureType(T inputType) { + static_assert(std::is_floating_point_v); + return float(inputType); + } + + static bool is_finite(float v) { + return std::isfinite(v); + } + + using MathType = float; +}; +template<> struct FeatureValues { + static constexpr double zero = 0.0; + static constexpr double one = 1.0; + static constexpr double nan_value = std::numeric_limits::quiet_NaN(); + + template + static double convertToFeatureType(T inputType) { + static_assert(std::is_floating_point_v); + return double(inputType); + } + + static constexpr bool is_finite(double v) { + return std::isfinite(v); + } + + using MathType = double; +}; + +template +constexpr int64_t mask_nans(T* data, size_t n, MaskNaNStyle style, T value) { + if (style == MaskNaNStyle::NONE) { + return 0; + } + if (style == MaskNaNStyle::REPLACE) { + for (size_t i = 0; i < n; ++i) { + if (!FeatureValues::is_finite(data[i])) { + data[i] = value; + } + } + return 0; + } + + assert(mask_nan_style == MaskNaNStyle::REPORT); + int64_t num_nans = 0; + for (size_t i = 0; i < n; ++i) { + num_nans += (!FeatureValues::is_finite(data[i])); + } + return num_nans; 
+} + + +// This is just a function to provide to torch, so that we don't have to copy +// the tensor data to put it in a torch tensor, and torch can delete the data +// when it's no longer needed. +template +void deleter(void* p) { + delete[](T*)p; +} + +template +at::Tensor torch_tensor_from_array(std::unique_ptr&& source, const int64_t* dims, size_t num_dims, c10::ScalarType type) { + return at::from_blob( + source.release(), + at::IntArrayRef(dims, num_dims), + deleter, c10::TensorOptions(type)); +} + +// Most of the data needed about an atom +struct CompactAtom { + uint8_t atomicNum; + uint8_t totalDegree; + int8_t formalCharge; + uint8_t chiralTag; + uint8_t totalNumHs; + uint8_t hybridization; + bool isAromatic; + float mass; +}; + +// Most of the data needed about a bond +struct CompactBond { + uint8_t bondType; + bool isConjugated; + bool isInRing; + uint8_t stereo; + uint32_t beginAtomIdx; + uint32_t endAtomIdx; +}; + +// Data representing a molecule before featurization +struct GraphData { + const size_t num_atoms; + std::unique_ptr atoms; + const size_t num_bonds; + std::unique_ptr bonds; + + std::unique_ptr mol; +}; + + +// These functions are in features.cpp, and declared here so that +// graphium_cpp.cpp can expose them to Python via pybind. +at::Tensor atom_float_feature_names_to_tensor(const std::vector& features); +at::Tensor atom_onehot_feature_names_to_tensor(const std::vector& features); +at::Tensor bond_feature_names_to_tensor(const std::vector& features); +std::pair,at::Tensor> positional_feature_options_to_tensor(const pybind11::dict& dict); +std::tuple, int64_t, int64_t> featurize_smiles( + const std::string& smiles_string, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool create_conformer_feature, + const at::Tensor& bond_property_list, + const at::Tensor& positional_property_list, + bool duplicate_edges = true, + bool add_self_loop = false, + bool explicit_H = false, + bool use_bonds_weights = false, + bool offset_carbon = true, + int dtype_int = int(c10::ScalarType::Half), + int mask_nan_style_int = int(MaskNaNStyle::REPORT), + double mask_nan_value = 0.0); + + +// parse_mol is in graphium_cpp.cpp, but is declared in this header so +// that both labels.cpp and features.cpp can call it. +std::unique_ptr parse_mol( + const std::string& smiles_string, + bool explicit_H, + bool ordered = false); diff --git a/graphium/graphium_cpp/float_features.cpp b/graphium/graphium_cpp/float_features.cpp new file mode 100644 index 000000000..315a99f64 --- /dev/null +++ b/graphium/graphium_cpp/float_features.cpp @@ -0,0 +1,526 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#include "float_features.h" + +#include "features.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +static constexpr double qNaN = std::numeric_limits::quiet_NaN(); + +// This table is from the Electronegativity column of graphium/features/periodic_table.csv +const double electronegativityTable[] = { + 2.20, qNaN, 0.98, 1.57, 2.04, 2.55, 3.04, 3.44, 3.98, + qNaN, 0.93, 1.31, 1.61, 1.90, 2.19, 2.58, 3.16, qNaN, 0.82, + 1.00, 1.36, 1.54, 1.63, 1.66, 1.55, 1.83, 1.88, 1.91, 1.90, + 1.65, 1.81, 2.01, 2.18, 2.55, 2.96, qNaN, 0.82, 0.95, 1.22, + 1.33, 1.60, 2.16, 1.90, 2.20, 2.28, 2.20, 1.93, 1.69, 1.78, + 1.96, 2.05, 2.10, 2.66, qNaN, 0.79, 0.89, 1.10, 1.12, 1.13, + 1.14, 1.13, 1.17, 1.20, 1.20, 1.20, 1.22, 1.23, 1.24, 1.25, + 1.10, 1.27, 1.30, 1.50, 2.36, 1.90, 2.20, 2.20, 2.28, 2.54, + 2.00, 2.04, 2.33, 2.02, 2.00, 2.20, qNaN, 0.70, 0.90, 1.10, + 1.30, 1.50, 1.38, 1.36, 1.28, 1.30, 1.30, 1.30, 1.30, 1.30, + 1.30, 1.30, 1.30, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, + qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, +}; + +// This table is from the FirstIonization column of graphium/features/periodic_table.csv +const double firstIonizationTable[] = { + 13.5984, 24.5874, 5.3917, 9.3227, 8.2980, 11.2603, 14.5341, 13.6181, 17.4228, + 21.5645, 5.1391, 7.6462, 5.9858, 8.1517, 10.4867, 10.3600, 12.9676, 15.7596, 4.3407, + 6.1132, 6.5615, 6.8281, 6.7462, 6.7665, 7.4340, 7.9024, 7.8810, 7.6398, 7.7264, + 9.3942, 5.9993, 7.8994, 9.7886, 9.7524, 11.8138, 13.9996, 4.1771, 5.6949, 6.2173, + 6.6339, 6.7589, 7.0924, 7.2800, 7.3605, 7.4589, 8.3369, 7.5762, 8.9938, 5.7864, + 7.3439, 8.6084, 9.0096, 10.4513, 12.1298, 3.8939, 5.2117, 5.5769, 5.5387, 5.4730, + 5.5250, 5.5820, 5.6437, 5.6704, 6.1501, 5.8638, 5.9389, 6.0215, 6.1077, 6.1843, + 6.2542, 5.4259, 6.8251, 7.5496, 7.8640, 7.8335, 8.4382, 8.9670, 8.9587, 9.2255, + 10.4375, 6.1082, 7.4167, 7.2856, 8.4170, 9.3000, 10.7485, 4.0727, 5.2784, 5.1700, + 6.3067, 5.8900, 6.1941, 6.2657, 6.0262, 5.9738, 5.9915, 6.1979, 6.2817, 6.4200, + 6.5000, 6.5800, 6.6500, qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , + qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , +}; + +// This table is from the MeltingPoint column of graphium/features/periodic_table.csv +const double meltingPointTable[] = { + 14.175, qNaN , 453.85, 1560.15, 2573.15, 3948.15, 63.29, 50.50, 53.63, + 24.703, 371.15, 923.15, 933.40, 1683.15, 317.25, 388.51, 172.31, 83.96, 336.50, + 1112.15, 1812.15, 1933.15, 2175.15, 2130.15, 1519.15, 1808.15, 1768.15, 1726.15, 1357.75, + 692.88, 302.91, 1211.45, 1090.15, 494.15, 266.05, 115.93, 312.79, 1042.15, 1799.15, + 2125.15, 2741.15, 2890.15, 2473.15, 2523.15, 2239.15, 1825.15, 1234.15, 594.33, 429.91, + 505.21, 904.05, 722.80, 386.65, 161.45, 301.70, 1002.15, 1193.15, 1071.15, 1204.15, + 1289.15, 1204.15, 1345.15, 1095.15, 1585.15, 1630.15, 1680.15, 1743.15, 1795.15, 1818.15, + 1097.15, 1936.15, 2500.15, 3269.15, 3680.15, 3453.15, 3300.15, 2716.15, 2045.15, 1337.73, + 234.43, 577.15, 600.75, 544.67, 527.15, 575.15, 202.15, 300.15, 973.15, 1323.15, + 2028.15, 1873.15, 1405.15, 913.15, 913.15, 1267.15, 1340.15, 1259.15, 1925.15, 1133.15, + qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , + qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , +}; + +// This table is 2x the Metal column plus the Metalloid column of graphium/features/periodic_table.csv +const uint8_t metalTable[] = { + 0, 0, 2, 2, 1, 0, 0, 0, 0, + 0, 2, 2, 2, 1, 0, 0, 0, 
0, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 1, 1, 0, 0, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 0, 0, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 1, 0, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 0, 0, +}; + +template +void get_atom_float_feature(const GraphData& graph, T* data, AtomFloatFeature feature, size_t stride, bool offset_carbon) { + const uint32_t num_atoms = graph.num_atoms; + constexpr uint32_t carbon_atomic_num = 6; + using MT = typename FeatureValues::MathType; + switch (feature) { + case AtomFloatFeature::ATOMIC_NUMBER: { + const MT offset = offset_carbon ? carbon_atomic_num : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType((MT(graph.atoms[i].atomicNum) - offset) / MT(5)); + data += stride; + } + return; + } + case AtomFloatFeature::MASS: { + const RDKit::ROMol& mol = *graph.mol.get(); + constexpr MT carbon_mass = MT(12.011); + const MT offset = offset_carbon ? carbon_mass : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType((MT(mol.getAtomWithIdx(i)->getMass()) - offset) / MT(10)); + data += stride; + } + return; + } + case AtomFloatFeature::VALENCE: { + const RDKit::ROMol& mol = *graph.mol.get(); + const MT offset = offset_carbon ? 4 : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getTotalValence()) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::IMPLICIT_VALENCE: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getImplicitValence())); + data += stride; + } + return; + } + case AtomFloatFeature::HYBRIDIZATION: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getHybridization())); + data += stride; + } + return; + } + case AtomFloatFeature::CHIRALITY: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + const RDKit::Atom* atom = mol.getAtomWithIdx(i); + std::string prop; + bool has_prop = atom->getPropIfPresent(RDKit::common_properties::_CIPCode, prop); + *data = FeatureValues::convertToFeatureType(has_prop ? 
MT(prop.length() == 1 && prop[0] == 'R') : MT(2)); + data += stride; + } + return; + } + case AtomFloatFeature::AROMATIC: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getIsAromatic())); + data += stride; + } + return; + } + case AtomFloatFeature::IN_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(ring_info->numAtomRings(i) != 0)); + data += stride; + } + return; + } + case AtomFloatFeature::MIN_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(ring_info->minAtomRingSize(i))); + data += stride; + } + return; + } + case AtomFloatFeature::MAX_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + data[i * stride] = FeatureValues::zero; + } + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + const auto& rings = ring_info->atomRings(); + for (const auto& ring : rings) { + const T size = FeatureValues::convertToFeatureType(MT(ring.size())); + for (const auto atom_index : ring) { + if (size > data[atom_index * stride]) { + data[atom_index * stride] = size; + } + } + } + return; + } + case AtomFloatFeature::NUM_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(ring_info->numAtomRings(i))); + data += stride; + } + return; + } + case AtomFloatFeature::DEGREE: { + const RDKit::ROMol& mol = *graph.mol.get(); + const MT offset = offset_carbon ? 2 : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getTotalDegree()) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::RADICAL_ELECTRON: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getNumRadicalElectrons())); + data += stride; + } + return; + } + case AtomFloatFeature::FORMAL_CHARGE: { + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(graph.atoms[i].formalCharge)); + data += stride; + } + return; + } + case AtomFloatFeature::VDW_RADIUS: { + const RDKit::PeriodicTable* table = RDKit::PeriodicTable::getTable(); + const MT offset = offset_carbon ? MT(table->getRvdw(carbon_atomic_num)) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(table->getRvdw(graph.atoms[i].atomicNum)) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::COVALENT_RADIUS: { + const RDKit::PeriodicTable* table = RDKit::PeriodicTable::getTable(); + const MT offset = offset_carbon ? MT(table->getRcovalent(carbon_atomic_num)) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(table->getRcovalent(graph.atoms[i].atomicNum)) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::ELECTRONEGATIVITY: { + const MT offset = offset_carbon ? 
MT(electronegativityTable[carbon_atomic_num-1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i, data += stride) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + if (atomic_num <= 0 || atomic_num > 118 || electronegativityTable[atomic_num - 1] == 0) { + *data = FeatureValues::nan_value; + continue; + } + *data = FeatureValues::convertToFeatureType(MT(electronegativityTable[atomic_num - 1]) - offset); + } + return; + } + case AtomFloatFeature::IONIZATION: { + const T offset = offset_carbon ? T(firstIonizationTable[carbon_atomic_num-1]) : T(0); + for (uint32_t i = 0; i < num_atoms; ++i, data += stride) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + if (atomic_num <= 0 || atomic_num > 118 || firstIonizationTable[atomic_num - 1] == 0) { + *data = FeatureValues::nan_value; + continue; + } + *data = FeatureValues::convertToFeatureType((MT(firstIonizationTable[atomic_num - 1]) - offset) / MT(5)); + } + return; + } + case AtomFloatFeature::MELTING_POINT: { + const MT offset = offset_carbon ? MT(meltingPointTable[carbon_atomic_num-1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i, data += stride) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + if (atomic_num <= 0 || atomic_num > 118 || meltingPointTable[atomic_num - 1] == 0) { + *data = FeatureValues::nan_value; + continue; + } + *data = FeatureValues::convertToFeatureType((MT(meltingPointTable[atomic_num - 1]) - offset) / MT(200)); + } + return; + } + case AtomFloatFeature::METAL: { + for (uint32_t i = 0; i < num_atoms; ++i) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + *data = (atomic_num <= 0 || atomic_num > 118) ? FeatureValues::nan_value : FeatureValues::convertToFeatureType(MT(metalTable[atomic_num - 1])); + data += stride; + } + return; + } + case AtomFloatFeature::GROUP: { + const MT offset = offset_carbon ? MT(atomicNumToGroupTable[carbon_atomic_num - 1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + *data = (atomic_num <= 0 || atomic_num > 118) ? FeatureValues::nan_value : FeatureValues::convertToFeatureType(MT(atomicNumToGroupTable[atomic_num - 1]) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::PERIOD: { + const MT offset = offset_carbon ? MT(atomicNumToPeriodTable[carbon_atomic_num - 1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + *data = (atomic_num <= 0 || atomic_num > 118) ? FeatureValues::nan_value : FeatureValues::convertToFeatureType(MT(atomicNumToPeriodTable[atomic_num - 1]) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::SINGLE_BOND: + case AtomFloatFeature::AROMATIC_BOND: + case AtomFloatFeature::DOUBLE_BOND: + case AtomFloatFeature::TRIPLE_BOND: + { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::Bond::BondType type = + (feature == AtomFloatFeature::SINGLE_BOND) ? RDKit::Bond::SINGLE : ( + (feature == AtomFloatFeature::AROMATIC_BOND) ? RDKit::Bond::AROMATIC : ( + (feature == AtomFloatFeature::DOUBLE_BOND) ? RDKit::Bond::DOUBLE : ( + RDKit::Bond::TRIPLE))); + for (uint32_t i = 0; i < num_atoms; ++i) { + auto [begin, end] = mol.getAtomBonds(mol.getAtomWithIdx(i)); + uint32_t count = 0; + for (; begin != end; ++begin) { + count += (mol[*begin]->getBondType() == type); + } + *data = FeatureValues::convertToFeatureType(MT(count)); + data += stride; + } + return; + } + case AtomFloatFeature::IS_CARBON: { + const MT offset = offset_carbon ? 
MT(1) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(graph.atoms[i].atomicNum == carbon_atomic_num) - offset); + data += stride; + } + return; + } + default: + break; + } + + // Missing implementation + assert(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::nan_value; + data += stride; + } +} + +template void get_atom_float_feature(const GraphData& graph, int16_t* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +template void get_atom_float_feature(const GraphData& graph, float* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +template void get_atom_float_feature(const GraphData& graph, double* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); + +// This table is from the SingleBondRadius column of graphium/features/periodic_table.csv +const double single_bond_lengths[] = { + 0.32, 0.46, 1.33, 1.02, 0.85, 0.75, 0.71, 0.63, 0.64, + 0.67, 1.55, 1.39, 1.26, 1.16, 1.11, 1.03, 0.99, 0.96, 1.96, + 1.71, 1.48, 1.36, 1.34, 1.22, 1.19, 1.16, 1.11, 1.10, 1.12, + 1.18, 1.24, 1.21, 1.21, 1.16, 1.14, 1.17, 2.10, 1.85, 1.63, + 1.54, 1.47, 1.38, 1.28, 1.25, 1.25, 1.20, 1.28, 1.36, 1.42, + 1.40, 1.40, 1.36, 1.33, 1.31, 2.32, 1.96, 1.80, 1.63, 1.76, + 1.74, 1.73, 1.72, 1.68, 1.69, 1.68, 1.67, 1.66, 1.65, 1.64, + 1.70, 1.62, 1.52, 1.46, 1.37, 1.31, 1.29, 1.22, 1.23, 1.24, + 1.33, 1.44, 1.44, 1.51, 1.45, 1.47, 1.42, 2.23, 2.01, 1.86, + 1.75, 1.69, 1.70, 1.71, 1.72, 1.66, 1.66, 1.68, 1.68, 1.65, + 1.67, 1.73, 1.76, 1.61, 1.57, 1.49, 1.43, 1.41, 1.34, 1.29, + 1.28, 1.21, 1.22, 1.36, 1.43, 1.62, 1.75, 1.65, 1.57, +}; +// This table is from the DoubleBondRadius column of graphium/features/periodic_table.csv +const double double_bond_lengths[] = { + qNaN, qNaN, 1.24, 0.90, 0.78, 0.67, 0.60, 0.57, 0.59, + 0.96, 1.60, 1.32, 1.13, 1.07, 1.02, 0.94, 0.95, 1.07, 1.93, + 1.47, 1.16, 1.17, 1.12, 1.11, 1.05, 1.09, 1.03, 1.01, 1.15, + 1.20, 1.17, 1.11, 1.14, 1.07, 1.09, 1.21, 2.02, 1.57, 1.30, + 1.27, 1.25, 1.21, 1.20, 1.14, 1.10, 1.17, 1.39, 1.44, 1.36, + 1.30, 1.33, 1.28, 1.29, 1.35, 2.09, 1.61, 1.39, 1.37, 1.38, + 1.37, 1.35, 1.34, 1.34, 1.35, 1.35, 1.33, 1.33, 1.33, 1.31, + 1.29, 1.31, 1.28, 1.26, 1.20, 1.19, 1.16, 1.15, 1.12, 1.21, + 1.42, 1.42, 1.35, 1.41, 1.35, 1.38, 1.45, 2.18, 1.73, 1.53, + 1.43, 1.38, 1.34, 1.36, 1.35, 1.35, 1.36, 1.39, 1.40, 1.40, + qNaN, 1.39, qNaN, 1.41, 1.40, 1.36, 1.28, 1.28, 1.25, 1.25, + 1.16, 1.16, 1.37, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, +}; +// This table is from the TripleBondRadius column of graphium/features/periodic_table.csv +const double triple_bond_lengths[] = { + qNaN, qNaN, qNaN, 0.85, 0.73, 0.60, 0.54, 0.53, 0.53, + qNaN, qNaN, 1.27, 1.11, 1.02, 0.94, 0.95, 0.93, 0.96, qNaN, + 1.33, 1.14, 1.08, 1.06, 1.03, 1.03, 1.02, 0.96, 1.01, 1.20, + qNaN, 1.21, 1.14, 1.06, 1.07, 1.10, 1.08, qNaN, 1.39, 1.24, + 1.21, 1.16, 1.13, 1.10, 1.03, 1.06, 1.12, 1.37, qNaN, 1.46, + 1.32, 1.27, 1.21, 1.25, 1.22, qNaN, 1.49, 1.39, 1.31, 1.28, + qNaN, qNaN, qNaN, qNaN, 1.32, qNaN, qNaN, qNaN, qNaN, qNaN, + qNaN, 1.31, 1.22, 1.19, 1.15, 1.10, 1.09, 1.07, 1.10, 1.23, + qNaN, 1.50, 1.37, 1.35, 1.29, 1.38, 1.33, qNaN, 1.59, 1.40, + 1.36, 1.29, 1.18, 1.16, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, + qNaN, qNaN, qNaN, qNaN, 1.31, 1.26, 1.21, 1.19, 1.18, 1.13, + 1.12, 1.18, 1.30, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, +}; + +template +void get_bond_float_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride) { + const uint32_t num_bonds = graph.num_bonds; + 
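+    // One value is written per bond, and consecutive bonds are spaced `stride` elements apart in
+    // `data`, so several float features can be interleaved into a single output buffer.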
switch (feature) { + case BondFeature::TYPE_FLOAT: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + auto type = graph.bonds[i].bondType; + double value = 0; + switch (type) { + case RDKit::Bond::BondType::SINGLE: value = 1.0; break; + case RDKit::Bond::BondType::DOUBLE: value = 2.0; break; + case RDKit::Bond::BondType::TRIPLE: value = 3.0; break; + case RDKit::Bond::BondType::AROMATIC: value = 1.5; break; + default: value = mol.getBondWithIdx(i)->getBondTypeAsDouble(); + } + *data = FeatureValues::convertToFeatureType(value); + } + return; + } + case BondFeature::IN_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + bool is_in_ring = mol.getRingInfo()->numBondRings(i) != 0; + *data = is_in_ring ? FeatureValues::one : FeatureValues::zero; + } + return; + } + case BondFeature::CONJUGATED: { + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + bool is_conjugated = graph.bonds[i].isConjugated; + *data = is_conjugated ? FeatureValues::one : FeatureValues::zero; + } + return; + } + case BondFeature::CONFORMER_BOND_LENGTH: { + RDKit::ROMol& mol = *graph.mol.get(); + if (mol.beginConformers() == mol.endConformers()) { + // Try to generate a conformer + RDKit::DGeomHelpers::EmbedParameters params; + params.enforceChirality = false; + params.ignoreSmoothingFailures = true; + params.useBasicKnowledge = true; + params.useExpTorsionAnglePrefs = true; + params.optimizerForceTol = 0.1; + int id = RDKit::DGeomHelpers::EmbedMolecule(mol, params); + if (id == -1) { + // Failed to generate a conformer + const uint32_t num_bonds = graph.num_bonds; + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + *data = FeatureValues::nan_value; + } + return; + } + assert(mol.beginConformers() != mol.endConformers()); + } + const RDKit::Conformer& conformer = mol.getConformer(); + const auto& positions = conformer.getPositions(); + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + const uint32_t begin_atom = graph.bonds[i].beginAtomIdx; + const uint32_t end_atom = graph.bonds[i].endAtomIdx; + const RDGeom::Point3D diff = (positions[end_atom] - positions[begin_atom]); + // Unfortunately, the length() function on Point3D is virtual, so compute it manually. + const double length = std::sqrt(diff.x * diff.x + diff.y * diff.y + diff.z * diff.z); + *data = FeatureValues::convertToFeatureType(length); + } + return; + } + case BondFeature::ESTIMATED_BOND_LENGTH: { + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + const uint32_t begin_atom = graph.bonds[i].beginAtomIdx; + const uint32_t end_atom = graph.bonds[i].endAtomIdx; + const int atomic_num1 = graph.atoms[begin_atom].atomicNum; + const bool atom1_valid = (atomic_num1 >= 1 && atomic_num1 <= 118); + const int atomic_num2 = graph.atoms[end_atom].atomicNum; + const bool atom2_valid = (atomic_num2 >= 1 && atomic_num2 <= 118); + assert(atom1_valid && atom2_valid); + if (!atom1_valid || !atom2_valid) { + *data = FeatureValues::nan_value; + continue; + } + + const auto type = graph.bonds[i].bondType; + if (type == RDKit::Bond::BondType::SINGLE) { + // All atoms have a single bond length + *data = FeatureValues::convertToFeatureType( + single_bond_lengths[atomic_num1 - 1] + single_bond_lengths[atomic_num2 - 1]); + continue; + } + if (type == RDKit::Bond::BondType::DOUBLE) { + const double length1 = (double_bond_lengths[atomic_num1 - 1] >= 0) ? 
+                    double_bond_lengths[atomic_num1 - 1] : single_bond_lengths[atomic_num1 - 1];
+                const double length2 = (double_bond_lengths[atomic_num2 - 1] >= 0) ?
+                    double_bond_lengths[atomic_num2 - 1] : single_bond_lengths[atomic_num2 - 1];
+                *data = FeatureValues<T>::convertToFeatureType(length1 + length2);
+                continue;
+            }
+            if (type == RDKit::Bond::BondType::TRIPLE) {
+                const double length1 = (triple_bond_lengths[atomic_num1 - 1] >= 0) ?
+                    triple_bond_lengths[atomic_num1 - 1] : single_bond_lengths[atomic_num1 - 1];
+                const double length2 = (triple_bond_lengths[atomic_num2 - 1] >= 0) ?
+                    triple_bond_lengths[atomic_num2 - 1] : single_bond_lengths[atomic_num2 - 1];
+                *data = FeatureValues<T>::convertToFeatureType(length1 + length2);
+                continue;
+            }
+            if (type != RDKit::Bond::BondType::AROMATIC) {
+                *data = FeatureValues<T>::nan_value;
+                continue;
+            }
+
+            // Aromatic case
+            double length1 = single_bond_lengths[atomic_num1 - 1];
+            double length2 = single_bond_lengths[atomic_num2 - 1];
+            if (double_bond_lengths[atomic_num1 - 1] >= 0) {
+                length1 = 0.5 * (length1 + double_bond_lengths[atomic_num1 - 1]);
+            }
+            if (double_bond_lengths[atomic_num2 - 1] >= 0) {
+                length2 = 0.5 * (length2 + double_bond_lengths[atomic_num2 - 1]);
+            }
+            *data = FeatureValues<T>::convertToFeatureType(length1 + length2);
+        }
+        return;
+    }
+    default:
+        // Missing implementation
+        assert(0);
+        for (uint32_t i = 0; i < num_bonds; ++i, data += stride) {
+            *data = FeatureValues<T>::nan_value;
+        }
+        return;
+    }
+}
+
+template void get_bond_float_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride);
+template void get_bond_float_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride);
+template void get_bond_float_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride);
diff --git a/graphium/graphium_cpp/float_features.h b/graphium/graphium_cpp/float_features.h
new file mode 100644
index 000000000..b839416c3
--- /dev/null
+++ b/graphium/graphium_cpp/float_features.h
@@ -0,0 +1,58 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "features.h" + +#include + +#include + +template +void get_atom_float_feature(const GraphData& graph, T* data, AtomFloatFeature feature, size_t stride, bool offset_carbon = true); + +extern template void get_atom_float_feature(const GraphData& graph, int16_t* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +extern template void get_atom_float_feature(const GraphData& graph, float* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +extern template void get_atom_float_feature(const GraphData& graph, double* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); + +template +void get_bond_float_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride); + +extern template void get_bond_float_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +extern template void get_bond_float_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +extern template void get_bond_float_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); + +// This table is from the Group column of graphium/features/periodic_table.csv +constexpr uint8_t atomicNumToGroupTable[] = { + 1, 18, 1, 2, 13, 14, 15, 16, 17, + 18, 1, 2, 13, 14, 15, 16, 17, 18, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 1, 2, 3, 19, 19, + 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, + 19, 19, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 1, 2, 3, + 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, + 19, 19, 19, 19, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, +}; +constexpr size_t groupCount = 19; + +// This table is from the Period column of graphium/features/periodic_table.csv +constexpr uint8_t atomicNumToPeriodTable[] = { + 1, 1, 2, 2, 2, 2, 2, 2, 2, + 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7 +}; +constexpr size_t periodCount = 7; diff --git a/graphium/graphium_cpp/graphium_cpp.cpp b/graphium/graphium_cpp/graphium_cpp.cpp new file mode 100644 index 000000000..970c72d0a --- /dev/null +++ b/graphium/graphium_cpp/graphium_cpp.cpp @@ -0,0 +1,93 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "features.h" +#include "labels.h" + +// C++ standard library headers +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// RDKit headers +#include +#include +#include +#include +#include +#include + +// PyBind and Torch headers for use by library to be imported by Python +#include +#include +#include +#include + +// RDKit::SmilesToMol uses std::string, so until we replace it, lets use std::string here. +// ("const char*" could avoid an extra allocation, if we do eventually replace use of SmilesToMol.) 
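+// parse_mol returns nullptr if RDKit fails to parse the SMILES string. When `ordered` is true,
+// the atoms are renumbered into RDKit's canonical rank order before explicit hydrogens are
+// (optionally) added.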
+std::unique_ptr<RDKit::RWMol> parse_mol(
+    const std::string& smiles_string,
+    bool explicit_H,
+    bool ordered) {
+
+    // Parse SMILES string with default options
+    RDKit::SmilesParserParams params;
+    std::unique_ptr<RDKit::RWMol> mol{ RDKit::SmilesToMol(smiles_string, params) };
+    if (!mol) {
+        return mol;
+    }
+
+    if (ordered) {
+        // Determine a canonical ordering of the atoms
+        const unsigned int num_atoms = mol->getNumAtoms();
+        std::vector<unsigned int> atom_order;
+        RDKit::Canon::rankMolAtoms(*mol, atom_order);
+        assert(atom_order.size() == num_atoms);
+
+        // Invert the order
+        std::vector<unsigned int> inverse_order(num_atoms);
+        for (unsigned int i = 0; i < num_atoms; ++i) {
+            inverse_order[atom_order[i]] = i;
+        }
+
+        // Reorder the atoms to the canonical order
+        mol.reset(static_cast<RDKit::RWMol*>(RDKit::MolOps::renumberAtoms(*mol, inverse_order)));
+    }
+    if (explicit_H) {
+        RDKit::MolOps::addHs(*mol);
+    }
+    else {
+        // Default params for SmilesToMol already calls removeHs,
+        // and calling it again shouldn't have any net effect.
+        //RDKit::MolOps::removeHs(*mol);
+    }
+    return mol;
+}
+
+// This is necessary to export Python functions in a Python module named graphium_cpp.
+PYBIND11_MODULE(graphium_cpp, m) {
+    m.doc() = "graphium C++ plugin"; // Python module docstring
+
+    // Functions in labels.cpp
+    m.def("load_num_cols_and_dtypes", &load_num_cols_and_dtypes, "Loads from a cache file, a list of integers representing the number of columns in each task, and a list of integers representing the torch ScalarType of the task's data.");
+    m.def("load_metadata_tensors", &load_metadata_tensors, "Loads from cache files for a specific stage, a torch tensor containing all SMILES strings concatenated, another with the offsets of all SMILES strings, two for the number of nodes and edges in each molecule, and optionally another representing the offsets of molecules in files.");
+    m.def("load_stats", &load_stats, "Loads from a cache file of a specific task, the stats for each column, for use in denormalization.");
+    m.def("concatenate_strings", &concatenate_strings, "Accepts a Numpy array of strings or Python list of strings and returns a PyTorch tensor of all of the characters and another tensor containing indices into the other tensor indicating where each string begins.");
+    m.def("prepare_and_save_data", &prepare_and_save_data, "Accepts a dict mapping dataset (task) names to dicts with \"smiles\", \"labels\", and \"label_offsets\" data, and returns the data that would be returned by load_metadata_tensors, load_stats, and load_num_cols_and_dtypes.");
+    m.def("load_labels_from_index", &load_labels_from_index, "Loads label data from disk, for a specific stage and molecule.");
+    m.def("extract_string", &extract_string, "Extracts a single string from a Tensor of concatenated strings.");
+
+    // Functions in features.cpp
+    m.def("atom_float_feature_names_to_tensor", &atom_float_feature_names_to_tensor, "Accepts feature names and returns a tensor representing them as integers");
+    m.def("atom_onehot_feature_names_to_tensor", &atom_onehot_feature_names_to_tensor, "Accepts feature names and returns a tensor representing them as integers");
+    m.def("bond_feature_names_to_tensor", &bond_feature_names_to_tensor, "Accepts feature names and returns a tensor representing them as integers");
+    m.def("positional_feature_options_to_tensor", &positional_feature_options_to_tensor, "Accepts feature names, levels, and options, and returns a tensor representing them as integers");
+    m.def("featurize_smiles", &featurize_smiles, "Accepts a SMILES string and returns tensors representing
the features"); +} diff --git a/graphium/graphium_cpp/graphormer.cpp b/graphium/graphium_cpp/graphormer.cpp new file mode 100644 index 000000000..65badd030 --- /dev/null +++ b/graphium/graphium_cpp/graphormer.cpp @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "graphormer.h" + +#include +#include +#include +#include + +template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances) { + + // Compute all pairs shortest paths. + // Because this is a sparse graph treated as having unweighted edges, + // BFS on each node is faster than Dijkstra's or Floyd-Warshall's. + + if (queue.capacity() == 0) { + queue.reserve(n); + } + + all_pairs_distances.resize(size_t(n) * n); + std::fill(all_pairs_distances.begin(), all_pairs_distances.end(), T(-1)); + + for (uint32_t start_index = 0; start_index < n; ++start_index) { + queue.resize(0); + size_t queue_head = 0; + queue.push_back({ start_index,0 }); + T* const distances = all_pairs_distances.data() + start_index * n; + while (queue.size() != queue_head) { + auto [current_node, current_distance] = queue[queue_head]; + ++queue_head; + + if (distances[current_node] != T(-1)) { + continue; + } + + distances[current_node] = T(current_distance); + + ++current_distance; + + const uint32_t* neighbor_start = neighbors + neighbor_starts[current_node]; + const uint32_t* neighbor_end = neighbors + neighbor_starts[current_node+1]; + for (; neighbor_start != neighbor_end; ++neighbor_start) { + queue.push_back({ *neighbor_start,current_distance }); + } + } + } +} + +// Explicit instantiations for float and double +template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); +template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); diff --git a/graphium/graphium_cpp/graphormer.h b/graphium/graphium_cpp/graphormer.h new file mode 100644 index 000000000..4a82c67be --- /dev/null +++ b/graphium/graphium_cpp/graphormer.h @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); + +extern template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); +extern template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); diff --git a/graphium/graphium_cpp/labels.cpp b/graphium/graphium_cpp/labels.cpp new file mode 100644 index 000000000..731f8e25b --- /dev/null +++ b/graphium/graphium_cpp/labels.cpp @@ -0,0 +1,1753 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#include "labels.h" + +#include "features.h" + +// C++ standard library headers +#include +#include +#include + +// RDKit headers +#include +#include +#include +#include + +// Numpy array headers +#include +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +#ifdef _WIN32 +// Windows file handling wrappers +#define WIN32_LEAN_AND_MEAN +#include + +using FileType = HANDLE; +const auto INVALID_FILE = INVALID_HANDLE_VALUE; + +static FileType fopen_read_wrapper(const std::filesystem::path& file_path) { + return CreateFileW( + file_path.wstring().c_str(), + GENERIC_READ, + FILE_SHARE_READ, + nullptr, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + nullptr); +} + +static FileType fopen_write_wrapper(const std::filesystem::path& file_path) { + return CreateFileW( + file_path.wstring().c_str(), + GENERIC_WRITE, + 0, + nullptr, + CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + nullptr); +} + +static size_t fread_wrapper(void* buffer, size_t bytes, FileType file) { + size_t total_bytes_read = 0; + while (bytes > 0) { + // NOTE: ReadFile should support reads up to (2^32 - 1) bytes, + // but might as well limit it to 1GB (2^30 bytes) at a time, + // just in case there are issues at or above 2GB. + const DWORD max_read_size = 1024 * 1024 * 1024; + const DWORD bytes_to_read = (bytes > max_read_size) ? max_read_size : (DWORD)bytes; + DWORD bytes_read; + BOOL success = ReadFile(file, buffer, bytes_to_read, &bytes_read, nullptr); + total_bytes_read += (success ? bytes_read : 0); + if (!success || bytes_read != bytes_to_read) { + return total_bytes_read; + } + bytes -= bytes_read; + } + return total_bytes_read; +} + +static size_t fwrite_wrapper(const void* buffer, size_t bytes, FileType file) { + size_t total_bytes_written = 0; + while (bytes > 0) { + // NOTE: ReadFile should support reads up to (2^32 - 1) bytes, + // but might as well limit it to 1GB (2^30 bytes) at a time, + // just in case there are issues at or above 2GB. + const DWORD max_write_size = 1024 * 1024 * 1024; + const DWORD bytes_to_write = (bytes > max_write_size) ? max_write_size : (DWORD)bytes; + DWORD bytes_written; + BOOL success = WriteFile(file, buffer, bytes_to_write, &bytes_written, nullptr); + total_bytes_written += (success ? 
bytes_written : 0); + if (!success || bytes_written != bytes_to_write) { + return total_bytes_written; + } + bytes -= bytes_written; + } + return total_bytes_written; +} + +static int fseek_wrapper(FileType file, int64_t file_pointer) { + LARGE_INTEGER file_pointer_union; + file_pointer_union.QuadPart = (LONGLONG)file_pointer; + BOOL success = SetFilePointerEx(file, file_pointer_union, nullptr, FILE_BEGIN); + return (success == 0); +} + +static void fclose_wrapper(FileType file) { + CloseHandle(file); +} + +#else +// Linux file handling wrappers +#include + +using FileType = FILE*; +const auto INVALID_FILE = (FILE*)nullptr; + +static FileType fopen_read_wrapper(const std::filesystem::path& file_path) { + return fopen(file_path.string().c_str(), "rb"); +} + +static FileType fopen_write_wrapper(const std::filesystem::path& file_path) { + return fopen(file_path.string().c_str(), "wb"); +} + +static size_t fread_wrapper(void* buffer, size_t bytes, FileType file) { + return fread(buffer, 1, bytes, file); +} + +static size_t fwrite_wrapper(const void* buffer, size_t bytes, FileType file) { + return fwrite(buffer, 1, bytes, file); +} + +static int fseek_wrapper(FileType file, int64_t file_pointer) { + // NOTE: If these files could ever be larger than 2GB each, fseek won't + // work on platforms where "long" is a 32-bit type (e.g. 32-bit Linux) + return fseek(file, (long)file_pointer, SEEK_SET); +} + +static void fclose_wrapper(FileType file) { + fclose(file); +} + +#endif // End of file handling wrappers + +struct InitNumpyArrayModule { + InitNumpyArrayModule() { + // This imports the numpy array module, and it must be + // called exactly once before numpy array functions are used. + if (_import_array() < 0) { + printf("ERROR: Failed to import numpy.core.multiarray from C++ in graphium_cpp module\n"); + } + } +}; +static void ensure_numpy_array_module_initialized() { + // Function scope static variables will be initialized upon the first call, + // and only once, in a threadsafe manner. + static InitNumpyArrayModule numpy_initializer; +} + +struct MolBriefData { + uint64_t unique_id[2]; + uint32_t num_nodes; + uint32_t num_edges; +}; + +static MolBriefData smiles_to_brief_data( + const std::string& smiles_string, + bool add_self_loop = false, + bool explicit_H = false) { + + // Don't add explicit_H here, in case it affects MolToInchiKey (though it really shouldn't) + std::unique_ptr mol{ parse_mol(smiles_string, false) }; + if (!mol) { + return MolBriefData{ {0,0}, 0, 0 }; + } + + const std::string inchiKeyString = MolToInchiKey(*mol, "/FixedH /SUU /RecMet /KET /15T"); + size_t n = inchiKeyString.size(); + // Format: AAAAAAAAAAAAAA-BBBBBBBBFV-P + // According to https://www.inchi-trust.org/technical-faq/ + assert(n == 27 && inchiKeyString[14] == '-' && inchiKeyString[25] == '-'); + // Convert from capital letter characters to 64-bit integers: + // 13 characters for first integer, 12 characters for 2nd integer. + // Neither should overflow a 64-bit unsigned integer. + uint64_t id0 = (n > 0) ? (inchiKeyString[0] - 'A') : 0; + for (size_t i = 1; i < 13 && i < n; ++i) { + id0 = 26*id0 + (inchiKeyString[i] - 'A'); + } + uint64_t id1 = (13 < n) ? 
(inchiKeyString[13] - 'A') : 0; + for (size_t i = 15; i < 25 && i < n; ++i) { + id1 = 26*id1 + (inchiKeyString[i] - 'A'); + } + if (26 < n) { + id1 = 26*id1 + (inchiKeyString[26] - 'A'); + } + + // Now handle explicit_H + if (explicit_H) { + RDKit::MolOps::addHs(*mol); + } + else { + // Default params for SmilesToMol already calls removeHs, + // and calling it again shouldn't have any net effect. + //RDKit::MolOps::removeHs(*mol); + } + + return MolBriefData{ + {id0, id1}, + mol->getNumAtoms(), + 2*mol->getNumBonds() + (add_self_loop ? mol->getNumAtoms() : 0) + }; +} + +enum class NormalizationMethod { + NONE, + NORMAL, + UNIT +}; +struct NormalizationOptions { + NormalizationMethod method = NormalizationMethod::NONE; + double min_clipping = -std::numeric_limits::infinity(); + double max_clipping = std::numeric_limits::infinity(); +}; + +constexpr size_t num_mols_per_file = 1024; + +static void get_mol_label_filename( + char filename[25], + uint64_t file_num) { + + size_t filename_index = 0; + while (file_num != 0) { + filename[filename_index] = '0' + (file_num % 10); + ++filename_index; + file_num /= 10; + } + while (filename_index < 7) { + filename[filename_index] = '0'; + ++filename_index; + } + std::reverse(filename, filename + filename_index); + filename[filename_index] = '.'; + filename[filename_index+1] = 't'; + filename[filename_index+2] = 'm'; + filename[filename_index+3] = 'p'; + filename[filename_index+4] = 0; +} + +struct Types { + size_t size; + int numpy_type; + c10::ScalarType torch_type; +}; +constexpr size_t num_supported_types = 3; +constexpr Types supported_types[num_supported_types] = { + {2, NPY_FLOAT16, c10::ScalarType::Half}, + {4, NPY_FLOAT32, c10::ScalarType::Float}, + {8, NPY_FLOAT64, c10::ScalarType::Double} +}; +static bool is_supported_numpy_type(int type) { + return (type == supported_types[0].numpy_type) || + (type == supported_types[1].numpy_type) || + (type == supported_types[2].numpy_type); +}; +static size_t numpy_type_index(int type) { + if (type == supported_types[0].numpy_type) { + return 0; + } + if (type == supported_types[1].numpy_type) { + return 1; + } + if (type == supported_types[2].numpy_type) { + return 2; + } + return num_supported_types; +}; +static size_t torch_type_index(c10::ScalarType type) { + if (type == supported_types[0].torch_type) { + return 0; + } + if (type == supported_types[1].torch_type) { + return 1; + } + if (type == supported_types[2].torch_type) { + return 2; + } + return num_supported_types; +}; + + +constexpr const char*const label_metadata_filename = "label_metadata.tmp"; +constexpr const char*const file_data_offsets_filename = "file_data_offsets.tmp"; +constexpr const char*const concat_smiles_filename = "concat_smiles.tmp"; +constexpr const char*const smiles_offsets_filename = "smiles_offsets.tmp"; +constexpr const char*const num_nodes_filename = "num_nodes.tmp"; +constexpr const char*const num_edges_filename = "num_edges.tmp"; + +static bool save_num_cols_and_dtypes( + const std::filesystem::path& common_path, + const std::vector& label_num_cols, + const std::vector& label_data_types) { + + const uint64_t num_labels = label_num_cols.size(); + if (num_labels != label_data_types.size()) { + return false; + } + std::filesystem::path file_path(common_path / label_metadata_filename); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return false; + } + size_t num_bytes_written = fwrite_wrapper(&num_labels, sizeof(num_labels), file); + num_bytes_written += 
fwrite_wrapper(label_num_cols.data(), sizeof(label_num_cols[0])*num_labels, file); + num_bytes_written += fwrite_wrapper(label_data_types.data(), sizeof(label_data_types[0])*num_labels, file); + fclose_wrapper(file); + if (num_bytes_written != sizeof(num_labels) + (sizeof(label_num_cols[0]) + sizeof(label_data_types[0]))*num_labels) { + return false; + } + return true; +} + +std::tuple< + std::vector, + std::vector +> load_num_cols_and_dtypes( + const std::string& processed_graph_data_path, + const std::string& data_hash) { + + std::vector label_num_cols; + std::vector label_data_types; + std::filesystem::path file_path( + std::filesystem::path(processed_graph_data_path) / data_hash / label_metadata_filename + ); + FileType file = fopen_read_wrapper(file_path); + if (file == INVALID_FILE) { + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); + } + uint64_t num_labels = 0; + size_t num_bytes_read = fread_wrapper(&num_labels, sizeof(num_labels), file); + // Trying to allocate 2^60 would fail, unless it overflows and then crashes + if (num_bytes_read != sizeof(num_labels) || num_labels == 0 || num_labels >= (uint64_t(1) << (64-4))) { + fclose_wrapper(file); + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); + } + label_num_cols.resize(num_labels, 0); + num_bytes_read = fread_wrapper(label_num_cols.data(), sizeof(label_num_cols[0])*num_labels, file); + if (num_bytes_read != sizeof(label_num_cols[0])*num_labels) { + fclose_wrapper(file); + label_num_cols.resize(0); + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); + } + label_data_types.resize(num_labels, -1); + num_bytes_read = fread_wrapper(label_data_types.data(), sizeof(label_data_types[0])*num_labels, file); + fclose_wrapper(file); + if (num_bytes_read != sizeof(label_data_types[0])*num_labels) { + label_num_cols.resize(0); + label_data_types.resize(0); + } + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); +} + +template +bool save_array_to_file( + const std::filesystem::path& directory, + const char*const filename, + const T* data, + const uint64_t n) { + + std::filesystem::path file_path(directory / filename); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return false; + } + size_t num_bytes_written = fwrite_wrapper(&n, sizeof(n), file); + num_bytes_written += fwrite_wrapper(data, sizeof(T)*n, file); + fclose_wrapper(file); + if (num_bytes_written != sizeof(n) + sizeof(T)*n) { + return false; + } + return true; +} + + +template +[[nodiscard]] uint64_t load_array_from_file( + const std::filesystem::path& directory, + const char*const filename, + std::unique_ptr& data) { + + data.reset(nullptr); + + std::filesystem::path file_path(directory / filename); + FileType file = fopen_read_wrapper(file_path); + if (file == INVALID_FILE) { + return 0; + } + uint64_t n; + size_t num_bytes_read = fread_wrapper(&n, sizeof(n), file); + // Trying to allocate 2^60 would fail, unless it overflows and then crashes + if (num_bytes_read != sizeof(n) || n == 0 || n >= (uint64_t(1) << (64-4))) { + fclose_wrapper(file); + return 0; + } + data.reset(new T[n]); + num_bytes_read = fread_wrapper(data.get(), sizeof(T)*n, file); + fclose_wrapper(file); + if (num_bytes_read != sizeof(T)*n) { + data.reset(nullptr); + return 0; + } + return n; +} + +std::vector load_metadata_tensors( + const std::string processed_graph_data_path, + const std::string stage, + const std::string data_hash) { + + 
std::filesystem::path base_path{processed_graph_data_path}; + std::filesystem::path directory = base_path / (stage + "_" + data_hash); + + std::unique_ptr concatenated_smiles; + uint64_t concatenated_smiles_size = + load_array_from_file(directory, concat_smiles_filename, concatenated_smiles); + + std::unique_ptr smiles_offsets; + uint64_t num_smiles_offsets = + load_array_from_file(directory, smiles_offsets_filename, smiles_offsets); + + std::unique_ptr num_nodes; + uint64_t num_num_nodes = + load_array_from_file(directory, num_nodes_filename, num_nodes); + + std::unique_ptr num_edges; + uint64_t num_num_edges = + load_array_from_file(directory, num_edges_filename, num_edges); + + std::unique_ptr mol_data_offsets; + uint64_t num_mol_data_offsets = + load_array_from_file(directory, file_data_offsets_filename, mol_data_offsets); + + if (num_num_nodes == 0 || num_num_edges != num_num_nodes || num_smiles_offsets != (num_num_nodes+1) || + concatenated_smiles_size == 0 || concatenated_smiles_size != uint64_t(smiles_offsets[num_num_edges]) || + (num_mol_data_offsets != num_num_nodes + (num_num_nodes + num_mols_per_file-1)/num_mols_per_file && num_mol_data_offsets != 0)) { + printf("ERROR: graphium_cpp.load_metadata_tensors failed to load valid metadata files\n"); + printf(" len(concat_smiles) is %zu\n", size_t(concatenated_smiles_size)); + printf(" len(smiles_offsets) is %zu\n", size_t(num_smiles_offsets)); + printf(" len(num_nodes) is %zu\n", size_t(num_num_nodes)); + printf(" len(num_edges) is %zu\n", size_t(num_num_edges)); + printf(" len(file_data_offsets) is %zu\n", size_t(num_mol_data_offsets)); + return std::vector(); + } + + // The above conditions should ensure that none of these arrays are empty, + // but assert in debug builds just in case. + assert(concatenated_smiles && smiles_offsets && num_nodes && num_edges); + + const int64_t concatenated_smiles_dims[1] = { int64_t(concatenated_smiles_size) }; + at::Tensor smiles_tensor = torch_tensor_from_array(std::move(concatenated_smiles), concatenated_smiles_dims, 1, c10::ScalarType::Char); + const int64_t smiles_offsets_dims[1] = { int64_t(num_num_nodes+1) }; + at::Tensor smiles_offsets_tensor = torch_tensor_from_array(std::move(smiles_offsets), smiles_offsets_dims, 1, c10::ScalarType::Long); + const int64_t num_nodes_dims[1] = { int64_t(num_num_nodes) }; + at::Tensor num_nodes_tensor = torch_tensor_from_array(std::move(num_nodes), num_nodes_dims, 1, c10::ScalarType::Int); + const int64_t num_edges_dims[1] = { int64_t(num_num_nodes) }; + at::Tensor num_edges_tensor = torch_tensor_from_array(std::move(num_edges), num_edges_dims, 1, c10::ScalarType::Int); + + std::vector stage_return_data; + stage_return_data.reserve((num_mol_data_offsets > 0) ? 
5 : 4); + + stage_return_data.push_back(std::move(smiles_tensor)); + stage_return_data.push_back(std::move(smiles_offsets_tensor)); + stage_return_data.push_back(std::move(num_nodes_tensor)); + stage_return_data.push_back(std::move(num_edges_tensor)); + + if (num_mol_data_offsets > 0) { + const int64_t data_offsets_dims[1] = { int64_t(num_mol_data_offsets) }; + at::Tensor data_offsets_tensor = torch_tensor_from_array(std::move(mol_data_offsets), data_offsets_dims, 1, c10::ScalarType::Long); + + stage_return_data.push_back(std::move(data_offsets_tensor)); + } + + return stage_return_data; +} + +std::vector load_stats( + const std::string processed_graph_data_path, + const std::string data_hash, + const std::string task_name) { + + std::filesystem::path base_path{processed_graph_data_path}; + std::filesystem::path directory = base_path / data_hash; + const std::string filename(task_name + "_stats.tmp"); + + std::unique_ptr task_stats; + uint64_t num_stat_floats = + load_array_from_file(directory, filename.c_str(), task_stats); + + if (num_stat_floats == 0 || num_stat_floats % 4 != 0) { + return std::vector(); + } + + const uint64_t num_cols = num_stat_floats / 4; + std::vector return_stats(4); + for (size_t stat_index = 0; stat_index < 4; ++stat_index) { + std::unique_ptr single_stat(new double[num_cols]); + for (size_t i = 0; i < num_cols; ++i) { + single_stat[i] = task_stats[4*i + stat_index]; + } + const int64_t stat_dims[1] = { int64_t(num_cols) }; + at::Tensor stat_tensor = torch_tensor_from_array(std::move(single_stat), stat_dims, 1, c10::ScalarType::Double); + return_stats.push_back(std::move(stat_tensor)); + } + + return return_stats; +} + +std::pair concatenate_strings(pybind11::handle handle) { + using return_type = std::pair; + + ensure_numpy_array_module_initialized(); + + at::Tensor concatenated_strings; + at::Tensor offsets; + + PyObject* obj_ptr = handle.ptr(); + if (PyArray_Check(obj_ptr)) { + PyArrayObject* numpy_array = reinterpret_cast(obj_ptr); + int type_num = PyArray_TYPE(numpy_array); + int ndims = PyArray_NDIM(numpy_array); + if (type_num != NPY_OBJECT || ndims != 1) { + return return_type(std::move(concatenated_strings), std::move(offsets)); + } + intptr_t n = PyArray_DIM(numpy_array, 0); + if (n <= 0) { + return return_type(std::move(concatenated_strings), std::move(offsets)); + } + + size_t total_characters = 0; + for (intptr_t i = 0; i < n; ++i) { + pybind11::handle string_handle(*(PyObject**)PyArray_GETPTR1(numpy_array, i)); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + // +1 is for null terminator + total_characters += string.size() + 1; + } + std::unique_ptr concatenated_chars(new char[total_characters]); + std::unique_ptr offsets_buffer(new int64_t[n+1]); + int64_t offset = 0; + for (intptr_t i = 0; i < n; ++i) { + offsets_buffer[i] = offset; + pybind11::handle string_handle(*(PyObject**)PyArray_GETPTR1(numpy_array, i)); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + memcpy(concatenated_chars.get(), string.c_str(), string.size()); + offset += string.size(); + concatenated_chars[offset] = 0; + ++offset; + } + offsets_buffer[n] = offset; + + const int64_t concatenated_strings_dims[1] = { int64_t(total_characters) }; + concatenated_strings = torch_tensor_from_array(std::move(concatenated_chars), 
concatenated_strings_dims, 1, c10::ScalarType::Char); + const int64_t offsets_dims[1] = { int64_t(n+1) }; + offsets = torch_tensor_from_array(std::move(offsets_buffer), offsets_dims, 1, c10::ScalarType::Long); + } + if (pybind11::isinstance(handle)) { + pybind11::list list = handle.cast(); + size_t n = list.size(); + + size_t total_characters = 0; + for (size_t i = 0; i < n; ++i) { + pybind11::handle string_handle(list[i]); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + // +1 is for null terminator + total_characters += string.size() + 1; + } + std::unique_ptr concatenated_chars(new char[total_characters]); + std::unique_ptr offsets_buffer(new int64_t[n+1]); + int64_t offset = 0; + for (size_t i = 0; i < n; ++i) { + offsets_buffer[i] = offset; + pybind11::handle string_handle(list[i]); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + memcpy(concatenated_chars.get(), string.c_str(), string.size()); + offset += string.size(); + concatenated_chars[offset] = 0; + ++offset; + } + offsets_buffer[n] = offset; + + const int64_t concatenated_strings_dims[1] = { int64_t(total_characters) }; + concatenated_strings = torch_tensor_from_array(std::move(concatenated_chars), concatenated_strings_dims, 1, c10::ScalarType::Char); + const int64_t offsets_dims[1] = { int64_t(n+1) }; + offsets = torch_tensor_from_array(std::move(offsets_buffer), offsets_dims, 1, c10::ScalarType::Long); + } + return return_type(std::move(concatenated_strings), std::move(offsets)); +} + +constexpr size_t num_stages = 3; +// NOTE: Computing stats below depends on that "train" is stage 0. 
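
The concatenate_strings function above packs a Python list or object-dtype numpy array of strings into two tensors: a flat int8 buffer of NUL-terminated strings stored back to back, and an int64 offsets tensor of length n+1, so string i occupies bytes [offsets[i], offsets[i+1]). A small decoding sketch under that assumption follows; the helper name and sample data are hypothetical, not part of the patch.

#include <cstdint>
#include <cstdio>

// Reads strings back out of the packed layout: each string carries its trailing NUL terminator.
static void print_packed_strings(const char* chars, const int64_t* offsets, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        printf("string %lld: %s\n", (long long)i, chars + offsets[i]);
    }
}

int main() {
    // Packed form of {"CO", "C1=CC=CC=C1"}: "CO\0C1=CC=CC=C1\0" (15 bytes total)
    const char chars[] = "CO\0C1=CC=CC=C1";  // the string literal supplies the final NUL
    const int64_t offsets[] = {0, 3, 15};
    print_packed_strings(chars, offsets, 2);
    return 0;
}
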
+const std::string stages[num_stages] = { + std::string("train"), + std::string("val"), + std::string("test") +}; + + +// Returns: +// stage -> [ +// unique mol smiles strings all concatenated, +// unique mol smiles string offsets (including one extra for the end), +// unique mol num_nodes, +// unique mol num_edges, +// mol_file_data_offsets +// ] +// task -> 4 stats tensors each +// task index -> label num columns +// task index -> label torch data type enum +std::tuple< + std::unordered_map>, + std::unordered_map>, + std::vector, + std::vector +> prepare_and_save_data( + const pybind11::list& task_names, + pybind11::dict& task_dataset_args, + const pybind11::dict& task_label_normalization, + const std::string processed_graph_data_path, + const std::string data_hash, + const pybind11::dict& task_train_indices, + const pybind11::dict& task_val_indices, + const pybind11::dict& task_test_indices, + bool add_self_loop, + bool explicit_H, + int max_threads) { + + ensure_numpy_array_module_initialized(); + + std::filesystem::path base_path{processed_graph_data_path}; + std::filesystem::create_directories(base_path); + std::filesystem::path common_path(base_path / data_hash); + std::filesystem::create_directories(common_path); + std::filesystem::path stage_paths[num_stages] = { + base_path / (stages[0] + "_" + data_hash), + base_path / (stages[1] + "_" + data_hash), + base_path / (stages[2] + "_" + data_hash) + }; + std::filesystem::create_directories(stage_paths[0]); + std::filesystem::create_directories(stage_paths[1]); + std::filesystem::create_directories(stage_paths[2]); + const pybind11::dict* stage_task_indices[num_stages] = { + &task_train_indices, + &task_val_indices, + &task_test_indices + }; + + const size_t num_tasks = task_names.size(); + std::vector return_label_num_cols(num_tasks, 0); + std::vector return_label_data_types(num_tasks, -1); + size_t total_num_cols = 0; + std::unique_ptr task_col_starts(new size_t[num_tasks+1]); + std::unique_ptr task_bytes_per_float(new size_t[num_tasks]); + std::unique_ptr task_normalization_options(new NormalizationOptions[num_tasks]); + std::unique_ptr smiles_numpy_arrays(new PyArrayObject*[num_tasks]); + std::unique_ptr labels_numpy_arrays(new PyArrayObject*[num_tasks]); + std::unique_ptr label_offsets_numpy_arrays(new PyArrayObject*[num_tasks]); + // Figure out the task bounds first, so that everything can be parallelized perfectly. 
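
// A hypothetical illustration of the task bounds computed by the loop below: three tasks with
// 3, 1, and 2 label columns would give task_col_starts = {0, 3, 4, 6} and total_num_cols = 6,
// so task t's labels occupy columns [task_col_starts[t], task_col_starts[t+1]) of the
// concatenated per-molecule label row, and task_bytes_per_float[t] records the width of each value.
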
+ size_t task_index = 0; + for (const auto& task : task_names) { + const size_t current_task_index = task_index; + task_col_starts[current_task_index] = total_num_cols; + task_bytes_per_float[current_task_index] = 0; + smiles_numpy_arrays[current_task_index] = nullptr; + labels_numpy_arrays[current_task_index] = nullptr; + label_offsets_numpy_arrays[current_task_index] = nullptr; + ++task_index; + if (!pybind11::isinstance(task)) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + pybind11::handle task_dataset_handle = pybind11::handle(PyDict_GetItemString(task_dataset_args.ptr(), task_name.c_str())); + if (!task_dataset_handle || !pybind11::isinstance(task_dataset_handle)) { + continue; + } + pybind11::dict dataset_dict = task_dataset_handle.cast(); + pybind11::handle smiles_handle = pybind11::handle(PyDict_GetItemString(dataset_dict.ptr(), "smiles")); + if (!smiles_handle) { + continue; + } + PyObject* smiles_obj_ptr = smiles_handle.ptr(); + if (!PyArray_Check(smiles_obj_ptr)) { + continue; + } + PyArrayObject* smiles_numpy_array = reinterpret_cast(smiles_obj_ptr); + int smiles_type_num = PyArray_TYPE(smiles_numpy_array); + int smiles_ndims = PyArray_NDIM(smiles_numpy_array); + if (smiles_type_num != NPY_OBJECT || smiles_ndims != 1) { + continue; + } + intptr_t num_smiles = PyArray_DIM(smiles_numpy_array, 0); + if (num_smiles <= 0) { + continue; + } + + // smiles array is okay + smiles_numpy_arrays[current_task_index] = smiles_numpy_array; + + // Check for labels. There might not be labels in inference case. + pybind11::handle labels_handle = pybind11::handle(PyDict_GetItemString(dataset_dict.ptr(), "labels")); + if (!labels_handle) { + continue; + } + pybind11::handle label_offsets_handle = pybind11::handle(PyDict_GetItemString(dataset_dict.ptr(), "label_offsets")); + PyObject* labels_obj_ptr = labels_handle.ptr(); + PyObject* label_offsets_obj_ptr = label_offsets_handle.ptr(); + const bool is_labels_numpy = PyArray_Check(labels_obj_ptr); + const bool is_labels_multi_row = label_offsets_obj_ptr && PyArray_Check(label_offsets_obj_ptr); + if (!is_labels_numpy) { + continue; + } + PyArrayObject* labels_numpy_array = reinterpret_cast(labels_obj_ptr); + PyArrayObject* label_offsets_numpy_array = is_labels_multi_row ? 
reinterpret_cast(label_offsets_obj_ptr) : nullptr; + int labels_type_num = PyArray_TYPE(labels_numpy_array); + int labels_ndims = PyArray_NDIM(labels_numpy_array); +#if GRAPHIUM_CPP_DEBUGGING + printf("\"%s\" labels numpy type %d, %d dims\n", task_name.c_str(), labels_type_num, labels_ndims); +#endif + if (!is_supported_numpy_type(labels_type_num) || labels_ndims != 2) { + continue; + } + if (is_labels_multi_row) { + int label_offsets_type_num = PyArray_TYPE(label_offsets_numpy_array); + int label_offsets_ndims = PyArray_NDIM(label_offsets_numpy_array); + // Only int64 is supported, for simplicity + if (label_offsets_type_num != NPY_INT64 || label_offsets_ndims != 1) { + continue; + } + } + intptr_t num_label_rows = PyArray_DIM(labels_numpy_array, 0); + intptr_t num_molecules = num_label_rows; + if (is_labels_multi_row) { + intptr_t num_offsets_rows = PyArray_DIM(label_offsets_numpy_array, 0); + if (num_offsets_rows == 0) { + continue; + } + // -1 is because last offset is the end offset + num_molecules = num_offsets_rows - 1; + + // Verify that the first offset is zero + if (*(const int64_t*)PyArray_GETPTR1(label_offsets_numpy_array, 0) != 0) { + continue; + } + // Verify that the last offset is the end offset + if (*(const int64_t*)PyArray_GETPTR1(label_offsets_numpy_array, num_molecules) != num_label_rows) { + continue; + } + } + intptr_t num_label_cols = PyArray_DIM(labels_numpy_array, 1); +#if GRAPHIUM_CPP_DEBUGGING + printf("\"%s\" labels[%zd][%zd] (%zd molecules)\n", task_name.c_str(), num_label_rows, num_label_cols, num_molecules); +#endif + if (num_smiles != num_molecules || num_label_cols <= 0) { + continue; + } + + const size_t supported_type_index = numpy_type_index(labels_type_num); + const size_t bytes_per_float = supported_types[supported_type_index].size; + labels_numpy_arrays[current_task_index] = labels_numpy_array; + label_offsets_numpy_arrays[current_task_index] = is_labels_multi_row ? 
label_offsets_numpy_array : nullptr; + return_label_num_cols[current_task_index] = num_label_cols; + return_label_data_types[current_task_index] = int(supported_types[supported_type_index].torch_type); + total_num_cols += size_t(num_label_cols); + task_bytes_per_float[current_task_index] = bytes_per_float; + + pybind11::handle task_normalization_handle = pybind11::handle(PyDict_GetItemString(task_label_normalization.ptr(), task_name.c_str())); + if (!task_normalization_handle || !pybind11::isinstance(task_normalization_handle)) { + continue; + } + pybind11::dict normalization_dict = task_normalization_handle.cast(); + pybind11::handle method_handle = pybind11::handle(PyDict_GetItemString(normalization_dict.ptr(), "method")); + pybind11::handle min_handle = pybind11::handle(PyDict_GetItemString(normalization_dict.ptr(), "min_clipping")); + pybind11::handle max_handle = pybind11::handle(PyDict_GetItemString(normalization_dict.ptr(), "max_clipping")); + if (method_handle && pybind11::isinstance(method_handle)) { + std::string method{pybind11::str(method_handle)}; + if (strcmp(method.c_str(), "normal") == 0) { + task_normalization_options[current_task_index].method = NormalizationMethod::NORMAL; + } + else if (strcmp(method.c_str(), "unit") == 0) { + task_normalization_options[current_task_index].method = NormalizationMethod::UNIT; + } + } + if (min_handle && pybind11::isinstance(min_handle)) { + task_normalization_options[current_task_index].min_clipping = double(int64_t(min_handle.cast())); + } + else if (min_handle && pybind11::isinstance(min_handle)) { + task_normalization_options[current_task_index].min_clipping = double(min_handle.cast()); + } + if (max_handle && pybind11::isinstance(max_handle)) { + task_normalization_options[current_task_index].max_clipping = double(int64_t(max_handle.cast())); + } + else if (max_handle && pybind11::isinstance(max_handle)) { + task_normalization_options[current_task_index].max_clipping = double(max_handle.cast()); + } + } + task_col_starts[num_tasks] = total_num_cols; + + if (total_num_cols > 0) { + save_num_cols_and_dtypes(common_path, return_label_num_cols, return_label_data_types); + } + + // Get the total number of molecules, by stage and task + size_t total_num_mols = 0; + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + const pybind11::dict& task_indices_dict = *stage_task_indices[stage_index]; + + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + pybind11::handle task = task_names[task_index]; + if (!smiles_numpy_arrays[task_index]) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + pybind11::handle task_indices_handle = pybind11::handle(PyDict_GetItemString(task_indices_dict.ptr(), task_name.c_str())); + if (!task_indices_handle || !pybind11::isinstance(task_indices_handle)) { + printf("Error: Task %s indices list isn't valid.\n", task_name.c_str()); + continue; + } + const pybind11::list task_indices_list = task_indices_handle.cast(); + const size_t current_num_mols = task_indices_list.size(); + if (current_num_mols == 0) { + printf("Error: Task %s indices list is empty.\n", task_name.c_str()); + } + total_num_mols += current_num_mols; + } + } + + // Get the mol indices for all stages and tasks + std::vector task_mol_indices; + task_mol_indices.reserve(total_num_mols); + std::vector task_mol_start(num_stages*num_tasks + 1); + // Unfortunately, reading strings from a numpy array isn't threadsafe, + // so we have to do that single-threaded first, too. 
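
For clarity, a hedged sketch of what a single entry of task_label_normalization is expected to look like, matching the keys parsed a few lines above ("method", "min_clipping", "max_clipping"). The helper name is hypothetical, not part of the patch, and an embedded Python interpreter is assumed to be running.

#include <pybind11/embed.h>

namespace py = pybind11;

// Builds one normalization entry of the shape consumed by prepare_and_save_data.
py::dict make_normalization_entry() {
    py::dict entry;
    entry["method"] = "normal";   // or "unit"; any other value leaves NormalizationMethod::NONE
    entry["min_clipping"] = -5;   // int or float values are both accepted
    entry["max_clipping"] = 5.0;
    return entry;
}
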
+ std::vector smiles_strings; + smiles_strings.reserve(total_num_mols); + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + const pybind11::dict& task_indices_dict = *stage_task_indices[stage_index]; + + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + // Update task_mol_start here, in case any indices aren't integers + // or any SMILES strings aren't strings below. + task_mol_start[stage_index*num_tasks + task_index] = task_mol_indices.size(); + + pybind11::handle task = task_names[task_index]; + if (!smiles_numpy_arrays[task_index]) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + pybind11::handle task_indices_handle = pybind11::handle(PyDict_GetItemString(task_indices_dict.ptr(), task_name.c_str())); + if (!task_indices_handle || !pybind11::isinstance(task_indices_handle)) { + continue; + } + + const pybind11::list task_indices_list = task_indices_handle.cast(); + const size_t current_num_mols = task_indices_list.size(); + + PyArrayObject*const smiles_numpy_array = smiles_numpy_arrays[task_index]; + const size_t smiles_array_size = PyArray_DIM(smiles_numpy_array, 0); + + for (size_t indices_index = 0; indices_index < current_num_mols; ++indices_index) { + const auto list_item = task_indices_list[indices_index]; + if (!pybind11::isinstance(list_item)) { + continue; + } + + size_t task_mol_index = size_t(list_item.cast()); + if (task_mol_index >= smiles_array_size) { + continue; + } + + pybind11::handle single_smiles_handle(*(PyObject**)PyArray_GETPTR1(smiles_numpy_array, task_mol_index)); + if (!pybind11::isinstance(single_smiles_handle)) { + continue; + } + + task_mol_indices.push_back(task_mol_index); + smiles_strings.push_back(std::string(pybind11::str(single_smiles_handle))); + } + + } + } + total_num_mols = task_mol_indices.size(); + task_mol_start[num_stages*num_tasks] = total_num_mols; + + struct MolKey { + uint64_t id0; + uint64_t id1; + uint32_t num_nodes; + uint32_t num_edges; + uint64_t task_index; + uint64_t task_mol_index; + uint64_t mol_index; + + bool operator<(const MolKey& other) const { + if (id0 != other.id0) { + return (id0 < other.id0); + } + if (id1 != other.id1) { + return (id1 < other.id1); + } + if (num_nodes != other.num_nodes) { + return (num_nodes < other.num_nodes); + } + if (num_edges != other.num_edges) { + return (num_edges < other.num_edges); + } + if (task_index != other.task_index) { + return (task_index < other.task_index); + } + return (task_mol_index < other.task_mol_index); + } + + // This is used for identifying keys of molecules with invalid SMILES strings. + // They show up as having no nodes, no edges, and ID 0. + bool isInvalid() const { + return id0 == 0 && id1 == 0 && num_nodes == 0 && num_edges == 0; + } + }; + + // Compute all InChI keys for all molecules, in parallel if applicable. + std::unique_ptr keys(new MolKey[total_num_mols]); + + // Determine the number of threads to use for computing MolKey values + const size_t num_mols_per_block = 512; + const size_t num_blocks = (total_num_mols + num_mols_per_block-1) / num_mols_per_block; + const size_t num_processors = std::thread::hardware_concurrency(); + size_t num_threads = (num_processors == 1 || num_blocks <= 4) ? 
1 : std::min(num_processors, num_blocks/2); + // max_threads of -1 means n-1 threads, to avoid starving other processes + if (max_threads < 0) { + max_threads += num_processors; + // Don't hit zero or remain negative, because that would skip applying the limit + if (max_threads < 1) { + max_threads = 1; + } + } + // max_threads of 0 means to not limit the number of threads + if (max_threads > 0 && num_threads > size_t(max_threads)) { + num_threads = size_t(max_threads); + } + + auto&& get_single_mol_key = [&task_mol_start,add_self_loop,explicit_H,&task_mol_indices,&smiles_strings,num_tasks](size_t mol_index) -> MolKey { + // Find which task this mol is in. If there could be many tasks, + // this could be a binary search, but for small numbers of tasks, + // a linear search is fine. + size_t task_index = 0; + while (task_mol_start[task_index+1] <= mol_index) { + ++task_index; + } + const size_t task_mol_index = task_mol_indices[mol_index]; + + const std::string& smiles_str = smiles_strings[mol_index]; + MolBriefData mol_data = smiles_to_brief_data(smiles_str, add_self_loop, explicit_H); + + return MolKey{mol_data.unique_id[0], mol_data.unique_id[1], mol_data.num_nodes, mol_data.num_edges, task_index % num_tasks, task_mol_index, mol_index}; + }; + if (num_threads == 1) { + for (size_t mol_index = 0; mol_index < total_num_mols; ++mol_index) { + keys[mol_index] = get_single_mol_key(mol_index); + } + } + else { + std::atomic next_block_index(0); + auto&& thread_functor = [&keys,&next_block_index,num_blocks,num_mols_per_block,total_num_mols,&get_single_mol_key]() { + while (true) { + const size_t block_index = next_block_index.fetch_add(1); + if (block_index >= num_blocks) { + return; + } + const size_t begin_index = block_index * num_mols_per_block; + const size_t end_index = std::min((block_index+1) * num_mols_per_block, total_num_mols); + for (size_t mol_index = begin_index; mol_index < end_index; ++mol_index) { + keys[mol_index] = get_single_mol_key(mol_index); + } + } + }; + std::vector threads; + for (size_t thread_index = 0; thread_index < num_threads; ++thread_index) { + threads.push_back(std::thread(thread_functor)); + } + for (size_t thread_index = 0; thread_index < num_threads; ++thread_index) { + threads[thread_index].join(); + } + } + + // Compute stats on the train stage only (stage 0), like how the python code did it. + // Normalization will be applied to all stages later. + // TODO: Does it matter that stats calculations will include all copies of molecules + // that occur multiple times in the same dataset? + constexpr size_t stat_min_offset = 0; + constexpr size_t stat_max_offset = 1; + constexpr size_t stat_mean_offset = 2; + constexpr size_t stat_std_offset = 3; + constexpr size_t num_stats = 4; + size_t stats_floats = num_stats*total_num_cols; + std::unique_ptr all_task_stats((stats_floats > 0) ? 
new double[stats_floats] : nullptr); + std::unordered_map> all_stats_return_data; + + if (total_num_cols > 0) { + std::unique_ptr all_task_num_non_nan(new intptr_t[total_num_cols]); + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + const size_t task_num_mols = task_mol_start[task_index+1] - task_mol_start[task_index]; + const size_t task_first_col = task_col_starts[task_index]; + const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col; + if (task_num_mols == 0 || task_num_cols == 0) { + continue; + } + // Initialize stats for accumulation + double*const task_stats = all_task_stats.get() + num_stats*task_first_col; + intptr_t*const task_num_non_nan = all_task_num_non_nan.get() + task_first_col; + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + task_stats[num_stats*task_col_index + stat_min_offset] = std::numeric_limits::infinity(); + task_stats[num_stats*task_col_index + stat_max_offset] = -std::numeric_limits::infinity(); + task_stats[num_stats*task_col_index + stat_mean_offset] = 0.0; + task_stats[num_stats*task_col_index + stat_std_offset] = 0.0; + task_num_non_nan[task_col_index] = 0; + } + + const size_t bytes_per_float = task_bytes_per_float[task_index]; + + auto&& update_stats_single_row = [task_stats, task_num_non_nan](const char* col_data, const size_t task_num_cols, const size_t bytes_per_float, const intptr_t col_stride) { + double* stats = task_stats; + intptr_t* num_non_nan = task_num_non_nan; + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index, col_data += col_stride, stats += num_stats, ++num_non_nan) { + // TODO: Move the type check outside the loop if it's a bottleneck + double value; + if (bytes_per_float == sizeof(double)) { + value = *(const double*)(col_data); + } + else if (bytes_per_float == sizeof(float)) { + value = *(const float*)(col_data); + } + else { + assert(bytes_per_float == sizeof(uint16_t)); + value = c10::detail::fp16_ieee_to_fp32_value(*(const uint16_t*)(col_data)); + } + if (value != value) { + // NaN value, so skip it + continue; + } + stats[stat_min_offset] = std::min(stats[stat_min_offset], value); + stats[stat_max_offset] = std::max(stats[stat_max_offset], value); + stats[stat_mean_offset] += value; + // TODO: If summing the squares isn't accurate enough for computing the variance, + // consider other approaches. + stats[stat_std_offset] += value*value; + ++(*num_non_nan); + } + }; + + PyArrayObject*const labels_numpy_array = labels_numpy_arrays[task_index]; + if (labels_numpy_array != nullptr) { + const char* raw_data = (const char*)PyArray_DATA(labels_numpy_array); + const intptr_t* strides = PyArray_STRIDES(labels_numpy_array); + const intptr_t num_label_rows = PyArray_DIM(labels_numpy_array, 0); + PyArrayObject*const label_offsets_numpy_array = label_offsets_numpy_arrays[task_index]; + const char* offsets_raw_data = label_offsets_numpy_array ? (const char*)PyArray_DATA(label_offsets_numpy_array) : nullptr; + const intptr_t offsets_stride = label_offsets_numpy_array ? PyArray_STRIDES(label_offsets_numpy_array)[0] : 0; + // The -1 is because there's an extra entry at the end for the end offset. + const intptr_t num_mols = label_offsets_numpy_array ? 
PyArray_DIM(label_offsets_numpy_array, 0) - 1 : num_label_rows; + // The normalization is computed on the subsample being kept + for (size_t task_key_index = 0; task_key_index < task_num_mols; ++task_key_index) { + const size_t task_mol_index = keys[task_mol_start[task_index] + task_key_index].task_mol_index; + if (task_mol_index >= size_t(num_mols)) { + printf("Error: In task %zu, mol index %zu is past limit of %zu\n", size_t(task_index), task_mol_index, size_t(num_mols)); + continue; + } + if (offsets_raw_data == nullptr) { + const char* row_data = raw_data + strides[0]*task_mol_index; + update_stats_single_row(row_data, task_num_cols, bytes_per_float, strides[1]); + } + else { + size_t begin_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*task_mol_index); + size_t end_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*(task_mol_index+1)); + const char* row_data = raw_data + strides[0]*begin_offset; + for (size_t row = begin_offset; row < end_offset; ++row, row_data += strides[0]) { + update_stats_single_row(row_data, task_num_cols, bytes_per_float, strides[1]); + } + } + } + } + +#if GRAPHIUM_CPP_DEBUGGING + printf("Task %zu normalization method %zu\n", size_t(task_index), size_t(task_normalization_options[task_index].method)); + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + printf("Task %zu col %zu, num non-nan = %zu, min = %e, max = %e\n", + size_t(task_index), task_col_index, + size_t(task_num_non_nan[task_col_index]), + task_stats[num_stats*task_col_index + stat_min_offset], + task_stats[num_stats*task_col_index + stat_max_offset]); + } +#endif + } + + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + const size_t task_first_col = task_col_starts[task_index]; + const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col; + if (task_num_cols == 0) { + continue; + } + + // Finish accumulation + double*const task_stats = all_task_stats.get() + num_stats*task_first_col; + intptr_t*const task_num_non_nan = all_task_num_non_nan.get() + task_first_col; + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + if (task_num_non_nan[task_col_index] == 0) { + task_stats[num_stats*task_col_index + stat_min_offset] = std::numeric_limits::quiet_NaN(); + task_stats[num_stats*task_col_index + stat_max_offset] = std::numeric_limits::quiet_NaN(); + task_stats[num_stats*task_col_index + stat_mean_offset] = std::numeric_limits::quiet_NaN(); + task_stats[num_stats*task_col_index + stat_std_offset] = std::numeric_limits::quiet_NaN(); + } + else { + if (task_normalization_options[task_index].min_clipping > task_stats[num_stats*task_col_index + stat_min_offset]) { + task_stats[num_stats*task_col_index + stat_min_offset] = task_normalization_options[task_index].min_clipping; + } + if (task_normalization_options[task_index].max_clipping < task_stats[num_stats*task_col_index + stat_max_offset]) { + task_stats[num_stats*task_col_index + stat_max_offset] = task_normalization_options[task_index].max_clipping; + } + const double n = double(task_num_non_nan[task_col_index]); + const double mean = task_stats[num_stats*task_col_index + stat_mean_offset] / n; + task_stats[num_stats*task_col_index + stat_mean_offset] = mean; + // sum((x[i] - m)^2)/(n-1) + // = sum(x[i]^2 -2mx[i] + m^2)/(n-1) + // = (sum(x[i]^2) - 2nm^2 + nm^2)/(n-1) + // = (sum(x[i]^2) - nm^2)/(n-1) + // except, for compatibility with numpy.nanstd, use n instead of n-1 + const double sum_sqaures = 
task_stats[num_stats*task_col_index + stat_std_offset]; + const double stdev = std::sqrt((sum_sqaures - n*mean*mean)/n); + task_stats[num_stats*task_col_index + stat_std_offset] = stdev; + } + } + + const std::string task_name{ pybind11::str(task_names[task_index]) }; +#if GRAPHIUM_CPP_DEBUGGING + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + printf("%s %zu %lld %e %e %e %e\n", + task_name.c_str(), task_col_index, (long long)task_num_non_nan[task_col_index], + task_stats[num_stats*task_col_index + stat_min_offset], + task_stats[num_stats*task_col_index + stat_max_offset], + task_stats[num_stats*task_col_index + stat_mean_offset], + task_stats[num_stats*task_col_index + stat_std_offset]); + } +#endif + const std::string stats_filename = task_name + "_stats.tmp"; + save_array_to_file(common_path, stats_filename.c_str(), task_stats, num_stats*task_num_cols); + + // Make copies for returning in a format similar to the load_stats function. + std::vector task_stats_out; + for (size_t stat_index = 0; stat_index < num_stats; ++stat_index) { + const int64_t task_stats_dims[1] = { int64_t(task_num_cols) }; + std::unique_ptr task_stats_copy(new double[task_num_cols]); + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + task_stats_copy[task_col_index] = task_stats[num_stats*task_col_index + stat_index]; + } + at::Tensor task_stats_tensor = torch_tensor_from_array(std::move(task_stats_copy), task_stats_dims, 1, c10::ScalarType::Double); + task_stats_out.push_back(std::move(task_stats_tensor)); + } + all_stats_return_data.insert(std::make_pair(std::move(task_name), std::move(task_stats_out))); + } + } + + // Sort train, val, and test separately, since they need to be stored separately. + // Don't sort until after accumulating stats, because the code above currently assumes that the tasks + // aren't interleaved. 
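
As a sanity check on the running-sum variance used above, the identity sum((x_i - m)^2) = sum(x_i^2) - n*m^2 means the streaming form matches a direct two-pass computation of the population standard deviation (ddof = 0, as with numpy.nanstd). A small self-contained check, not part of the patch:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<double> x = {1.0, 2.0, 4.0, 7.0};
    double sum = 0.0, sum_squares = 0.0;
    for (double v : x) { sum += v; sum_squares += v * v; }
    const double n = double(x.size());
    const double mean = sum / n;
    // Streaming form used in the accumulation loop above
    const double streaming_std = std::sqrt((sum_squares - n * mean * mean) / n);
    // Direct two-pass form
    double sum_sq_dev = 0.0;
    for (double v : x) { sum_sq_dev += (v - mean) * (v - mean); }
    const double direct_std = std::sqrt(sum_sq_dev / n);
    printf("streaming %.12f vs direct %.12f\n", streaming_std, direct_std);  // both ~2.291287847478
    return 0;
}
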
+ std::sort(keys.get(), keys.get() + task_mol_start[num_tasks]); + std::sort(keys.get() + task_mol_start[num_tasks], keys.get() + task_mol_start[2*num_tasks]); + std::sort(keys.get() + task_mol_start[2*num_tasks], keys.get() + total_num_mols); + + std::unordered_map> per_stage_return_data; + + // Deal with non-label data first + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + size_t concatenated_smiles_size = 0; + uint64_t num_unique_mols = 0; + const size_t stage_begin_index = task_mol_start[stage_index*num_tasks]; + const size_t stage_end_index = task_mol_start[(stage_index+1)*num_tasks]; + for (size_t sorted_index = stage_begin_index; sorted_index < stage_end_index; ) { + if (keys[sorted_index].isInvalid()) { + ++sorted_index; + continue; + } + ++num_unique_mols; + + // Add the length of the smiles string to the total length, + // and include the terminating zero + const size_t smiles_length = smiles_strings[keys[sorted_index].mol_index].size(); + concatenated_smiles_size += (smiles_length+1); + + const uint64_t id0 = keys[sorted_index].id0; + const uint64_t id1 = keys[sorted_index].id1; + ++sorted_index; + while (sorted_index < stage_end_index && keys[sorted_index].id0 == id0 && keys[sorted_index].id1 == id1) { + ++sorted_index; + } + } + + std::unique_ptr concatenated_smiles(new char[concatenated_smiles_size]); + std::unique_ptr smiles_offsets(new int64_t[num_unique_mols+1]); + std::unique_ptr num_nodes(new int32_t[num_unique_mols]); + std::unique_ptr num_edges(new int32_t[num_unique_mols]); + size_t unique_index = 0; + int64_t smiles_offset = 0; + for (size_t sorted_index = stage_begin_index; sorted_index < stage_end_index; ) { + if (keys[sorted_index].isInvalid()) { + ++sorted_index; + continue; + } + smiles_offsets[unique_index] = smiles_offset; + + const uint64_t id0 = keys[sorted_index].id0; + const uint64_t id1 = keys[sorted_index].id1; + num_nodes[unique_index] = keys[sorted_index].num_nodes; + num_edges[unique_index] = keys[sorted_index].num_edges; + + // Copy the string + const std::string& smiles_string = smiles_strings[keys[sorted_index].mol_index]; + const size_t smiles_length = smiles_string.size(); + memcpy(concatenated_smiles.get() + smiles_offset, smiles_string.c_str(), smiles_length); + smiles_offset += smiles_length; + // Don't forget the terminating zero + concatenated_smiles[smiles_offset] = 0; + ++smiles_offset; + + ++unique_index; + ++sorted_index; + while (sorted_index < stage_end_index && keys[sorted_index].id0 == id0 && keys[sorted_index].id1 == id1) { + ++sorted_index; + } + } + smiles_offsets[unique_index] = smiles_offset; + + save_array_to_file(stage_paths[stage_index], concat_smiles_filename, concatenated_smiles.get(), concatenated_smiles_size); + save_array_to_file(stage_paths[stage_index], smiles_offsets_filename, smiles_offsets.get(), num_unique_mols+1); + save_array_to_file(stage_paths[stage_index], num_nodes_filename, num_nodes.get(), num_unique_mols); + save_array_to_file(stage_paths[stage_index], num_edges_filename, num_edges.get(), num_unique_mols); + + const int64_t concatenated_smiles_dims[1] = { int64_t(concatenated_smiles_size) }; + at::Tensor smiles_tensor = torch_tensor_from_array(std::move(concatenated_smiles), concatenated_smiles_dims, 1, c10::ScalarType::Char); + const int64_t smiles_offsets_dims[1] = { int64_t(num_unique_mols+1) }; + at::Tensor smiles_offsets_tensor = torch_tensor_from_array(std::move(smiles_offsets), smiles_offsets_dims, 1, c10::ScalarType::Long); + const int64_t num_nodes_dims[1] = { 
int64_t(num_unique_mols) }; + at::Tensor num_nodes_tensor = torch_tensor_from_array(std::move(num_nodes), num_nodes_dims, 1, c10::ScalarType::Int); + const int64_t num_edges_dims[1] = { int64_t(num_unique_mols) }; + at::Tensor num_edges_tensor = torch_tensor_from_array(std::move(num_edges), num_edges_dims, 1, c10::ScalarType::Int); + + std::vector stage_return_data; + // Reserve space for one extra, for the data offsets tensor later + stage_return_data.reserve((total_num_cols > 0) ? 5 : 4); + stage_return_data.push_back(std::move(smiles_tensor)); + stage_return_data.push_back(std::move(smiles_offsets_tensor)); + stage_return_data.push_back(std::move(num_nodes_tensor)); + stage_return_data.push_back(std::move(num_edges_tensor)); + per_stage_return_data.insert(std::make_pair(stages[stage_index], std::move(stage_return_data))); + } + + if (total_num_cols == 0) { + // No label data, so all done + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } + + // mol_data_offsets will only need one entry for each unique molecule, + // plus one per file, but we can preallocate an upper bound. + std::vector mol_data_offsets; + size_t upper_bound_num_files = (task_mol_start[num_tasks] + num_mols_per_file-1) / num_mols_per_file; + mol_data_offsets.reserve(task_mol_start[num_tasks] + upper_bound_num_files); + + // temp_data is used for normalization + std::vector temp_data; + temp_data.reserve(total_num_cols*sizeof(double)); + + std::vector data; + data.reserve(num_mols_per_file*(total_num_cols*sizeof(double) + (1+2*num_tasks)*sizeof(uint64_t))); + + // Now, deal with label data + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + mol_data_offsets.resize(0); + assert(data.size() == 0); + uint64_t num_unique_mols = 0; + const size_t stage_begin_index = task_mol_start[stage_index*num_tasks]; + const size_t stage_end_index = task_mol_start[(stage_index+1)*num_tasks]; + for (size_t sorted_index = stage_begin_index; sorted_index < stage_end_index; ) { + if (keys[sorted_index].isInvalid()) { + ++sorted_index; + continue; + } + size_t data_offset = data.size(); + mol_data_offsets.push_back(data_offset); + + const size_t first_sorted_index = sorted_index; + const uint64_t id0 = keys[sorted_index].id0; + const uint64_t id1 = keys[sorted_index].id1; + + uint64_t prev_task_index = keys[sorted_index].task_index; + uint64_t mol_num_tasks = 1; + ++sorted_index; + while (sorted_index < stage_end_index && keys[sorted_index].id0 == id0 && keys[sorted_index].id1 == id1) { + // The same molecule can occur multiple times in a single dataset, + // but we only want to keep one copy for each task. + if (keys[sorted_index].task_index != prev_task_index) { + ++mol_num_tasks; + prev_task_index = keys[sorted_index].task_index; + } + ++sorted_index; + } + assert(mol_num_tasks <= num_tasks); + + // TODO: Double data capacity as needed if resizing is slow + assert(data.size() == data_offset); + data.resize(data_offset + sizeof(uint64_t)*(1+2*mol_num_tasks)); + + // Copy in the number of tasks for this molecule, followed by a list of the task indices and their end offsets. 
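
// As an illustration (layout reconstructed from this writer and from load_labels_from_index below,
// not stated elsewhere in the patch), a molecule present in two tasks is stored as:
//   [ uint64 mol_num_tasks = 2 ]
//   [ uint64 task_index_a ][ uint64 end_offset_a ]
//   [ uint64 task_index_b ][ uint64 end_offset_b ]
//   [ task_a label bytes  ][ task_b label bytes  ]
// where the end offsets are cumulative byte counts relative to the start of the label bytes
// (the base offset of (1 + 2*mol_num_tasks)*sizeof(uint64_t)), not to the start of the record.
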
+ memcpy(data.data() + data_offset, &mol_num_tasks, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + uint64_t task_offset = 0; + // Start with an invalid prev_task_index to pick up the first task + prev_task_index = uint64_t(int64_t(-1)); + for (size_t i = first_sorted_index; i < sorted_index; ++i) { + const uint64_t task_index = keys[i].task_index; + // The same molecule can occur multiple times in a single dataset, + // but we only want to keep one copy for each task. + if (task_index == prev_task_index) { + continue; + } + prev_task_index = task_index; + size_t num_cols = task_col_starts[task_index+1] - task_col_starts[task_index]; + PyArrayObject*const label_offsets_numpy_array = label_offsets_numpy_arrays[task_index]; + if (label_offsets_numpy_array != nullptr) { + const size_t task_mol_index = keys[i].task_mol_index; + const char* offsets_raw_data = (const char*)PyArray_DATA(label_offsets_numpy_array); + const intptr_t offsets_stride = PyArray_STRIDES(label_offsets_numpy_array)[0]; + const int64_t begin_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*task_mol_index); + const int64_t end_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*(task_mol_index+1)); + const size_t current_rows = size_t(end_offset - begin_offset); + num_cols *= current_rows; + } + task_offset += task_bytes_per_float[task_index]*num_cols; + memcpy(data.data() + data_offset, &task_index, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + memcpy(data.data() + data_offset, &task_offset, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + } + + // TODO: Double data capacity as needed if resizing is slow + assert(data.size() == data_offset); + data.resize(data_offset + task_offset); + + auto&& store_single_row = [&data_offset, &data, &temp_data]( + const char* col_data, + const size_t task_num_cols, + const intptr_t col_stride, + const size_t in_bytes_per_float, + const size_t out_bytes_per_float, + const NormalizationMethod normalization_method, + const double* task_stats) { + + if (size_t(col_stride) == in_bytes_per_float) { + memcpy(temp_data.data(), col_data, in_bytes_per_float*task_num_cols); + } + else { + for (size_t col = 0; col < task_num_cols; ++col) { + memcpy(temp_data.data() + col*in_bytes_per_float, col_data, in_bytes_per_float); + col_data += col_stride; + } + } + for (size_t col = 0; col < task_num_cols; ++col) { + double value; + if (in_bytes_per_float == sizeof(double)) { + value = ((const double*)(temp_data.data()))[col]; + } + else if (in_bytes_per_float == sizeof(float)) { + value = ((const float*)(temp_data.data()))[col]; + } + else { + assert(in_bytes_per_float == sizeof(uint16_t)); + value = c10::detail::fp16_ieee_to_fp32_value(((const uint16_t*)(temp_data.data()))[col]); + } + value = std::max(value, task_stats[stat_min_offset]); + value = std::min(value, task_stats[stat_max_offset]); + if (normalization_method == NormalizationMethod::NORMAL) { + if (task_stats[stat_std_offset] != 0) { + value = (value - task_stats[stat_mean_offset])/task_stats[stat_std_offset]; + } + else { + value = 0; + } + } + else if (normalization_method == NormalizationMethod::UNIT) { + // TODO: Cache 1/(max-min) or 0 to avoid check + if (task_stats[stat_max_offset] - task_stats[stat_min_offset] != 0) { + value = (value - task_stats[stat_min_offset])/(task_stats[stat_max_offset] - task_stats[stat_min_offset]); + } + else { + value = 0; + } + } + + // NOTE: The code below writes to temp_data, which is still being read from above, + // so this relies on that we're not writing to a 
larger data type than we're reading, + // else we'll overwrite data. + assert(out_bytes_per_float <= in_bytes_per_float); + if (out_bytes_per_float == sizeof(double)) { + ((double*)(temp_data.data()))[col] = value; + } + else if (out_bytes_per_float == sizeof(float)) { + ((float*)(temp_data.data()))[col] = float(value); + } + else { + assert(out_bytes_per_float == sizeof(uint16_t)); + ((uint16_t*)(temp_data.data()))[col] = c10::detail::fp16_ieee_from_fp32_value(value); + } + task_stats += num_stats; + } + + memcpy(data.data() + data_offset, temp_data.data(), out_bytes_per_float*task_num_cols); + data_offset += out_bytes_per_float*task_num_cols; + }; + + // Copy in the task data, with optional normalization + // Start with an invalid prev_task_index to pick up the first task + prev_task_index = uint64_t(int64_t(-1)); + for (size_t i = first_sorted_index; i < sorted_index; ++i) { + const uint64_t task_index = keys[i].task_index; + // The same molecule can occur multiple times in a single dataset, + // but we only want to keep one copy for each task. + if (task_index == prev_task_index) { + continue; + } + prev_task_index = task_index; + + const uint64_t task_mol_index = keys[i].task_mol_index; + + const size_t task_first_col = task_col_starts[task_index]; + const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col; + const NormalizationOptions& normalization = task_normalization_options[task_index]; + const double* task_stats = all_task_stats.get() + num_stats*task_first_col; + + const size_t bytes_per_float = task_bytes_per_float[task_index]; + + PyArrayObject*const labels_numpy_array = labels_numpy_arrays[task_index]; + if (labels_numpy_array != nullptr) { + const char* raw_data = (const char*)PyArray_DATA(labels_numpy_array); + const intptr_t* strides = PyArray_STRIDES(labels_numpy_array); + PyArrayObject*const label_offsets_numpy_array = label_offsets_numpy_arrays[task_index]; + const char* offsets_raw_data = label_offsets_numpy_array ? (const char*)PyArray_DATA(label_offsets_numpy_array) : nullptr; + const intptr_t offsets_stride = label_offsets_numpy_array ? 
PyArray_STRIDES(label_offsets_numpy_array)[0] : 0; + if (offsets_raw_data == nullptr) { + const char* row_data = raw_data + strides[0]*task_mol_index; + store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats); + } + else { + size_t begin_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*task_mol_index); + size_t end_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*(task_mol_index+1)); + const char* row_data = raw_data + strides[0]*begin_offset; + for (size_t row = begin_offset; row < end_offset; ++row, row_data += strides[0]) { + store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats); + } + } + } + } + + ++num_unique_mols; + if (num_unique_mols % num_mols_per_file == 0 || sorted_index == stage_end_index) { + // Write out the data to a file + + // First, construct the filename + char filename[20+4+1]; + size_t file_num = ((num_unique_mols-1) / num_mols_per_file); + get_mol_label_filename(filename, file_num); + + std::filesystem::path file_path(stage_paths[stage_index] / filename); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } +#if GRAPHIUM_CPP_DEBUGGING + printf("Writing file %s\n", file_path.string().c_str()); +#endif + size_t num_bytes_written = fwrite_wrapper(data.data(), data_offset, file); + fclose_wrapper(file); + if (num_bytes_written != data_offset) { + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } + data.resize(0); + + // One extra data offset to mark the end of each file. + // data_offset is automatically reset to 0 on the next iteration + // due to data.size() being 0 now. + mol_data_offsets.push_back(data_offset); + } + } + + // Write out the molecule data offsets to a separate file, + // so that only one file read is needed per molecule when data loading + // if the offsets are all loaded once and kept in memory. + // Note the one extra entry per file. 
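
// Worked example (illustrative numbers, not from the patch): with num_mols_per_file = 1024, each
// file contributes 1024 molecule offsets plus 1 end-of-file offset, so molecule index 2050 lives
// in file_num = 2050 / 1024 = 2 ("0000002.tmp"), and its record occupies bytes
// [offsets[2*1025 + 2], offsets[2*1025 + 3]) = [offsets[2052], offsets[2053]) within that file,
// matching the index_into_offsets arithmetic in load_labels_from_index below.
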
+#if GRAPHIUM_CPP_DEBUGGING + printf("Stage %s has %zu unique mols from %zu original\n", stages[stage_index].c_str(), size_t(num_unique_mols), size_t(stage_end_index - stage_begin_index)); +#endif + assert(mol_data_offsets.size() == num_unique_mols + (num_unique_mols + num_mols_per_file-1)/num_mols_per_file); + std::filesystem::path file_path(stage_paths[stage_index] / "mol_offsets.tmp"); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } + size_t num_bytes_written = fwrite_wrapper(&num_unique_mols, sizeof(num_unique_mols), file); + if (num_bytes_written != sizeof(num_unique_mols)) { + fclose_wrapper(file); + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } + size_t num_offsets = mol_data_offsets.size(); + size_t data_offsets_size = num_offsets*sizeof(mol_data_offsets[0]); + num_bytes_written = fwrite_wrapper(mol_data_offsets.data(), data_offsets_size, file); + fclose_wrapper(file); + if (num_bytes_written != data_offsets_size) { + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } + + static_assert(sizeof(int64_t) == sizeof(mol_data_offsets[0])); + save_array_to_file(stage_paths[stage_index], file_data_offsets_filename, mol_data_offsets.data(), num_offsets); + std::unique_ptr temp_data_offsets(new int64_t[num_offsets]); + memcpy(temp_data_offsets.get(), mol_data_offsets.data(), data_offsets_size); + const int64_t data_offsets_dims[1] = { int64_t(num_offsets) }; + at::Tensor data_offsets_tensor = torch_tensor_from_array(std::move(temp_data_offsets), data_offsets_dims, 1, c10::ScalarType::Long); + + per_stage_return_data[stages[stage_index]].push_back(std::move(data_offsets_tensor)); + mol_data_offsets.resize(0); + } + + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); +} + +void load_labels_from_index( + const std::string stage_directory, + const int64_t mol_index, + const at::Tensor& mol_file_data_offsets, + const pybind11::list& label_names, + const pybind11::list& label_num_cols, + const pybind11::list& label_data_types, + pybind11::dict& labels +) { + const std::filesystem::path stage_path{stage_directory}; + if (mol_index < 0) { + printf("Error: In load_labels_from_index, mol_index = %lld\n", (long long)mol_index); + return; + } + const uint64_t file_num = uint64_t(mol_index) / num_mols_per_file; + const size_t index_into_offsets = file_num*(num_mols_per_file+1) + (uint64_t(mol_index) % num_mols_per_file); + + const size_t num_data_offsets = (mol_file_data_offsets.scalar_type() == c10::ScalarType::Long && mol_file_data_offsets.ndimension() == 1) ? mol_file_data_offsets.size(0) : 0; + if (index_into_offsets+1 >= num_data_offsets) { + printf("Error: In load_labels_from_index, mol_index = %zu, index_into_offsets = %zu, num_data_offsets = %zu\n", + size_t(mol_index), size_t(index_into_offsets), size_t(num_data_offsets)); + return; + } + // NOTE: If TensorBase::data_ptr is ever removed, change it to TensorBase::const_data_ptr. + // Some torch version being used doesn't have const_data_ptr yet. 
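+    // Each molecule's record in a label file is parsed below as:
+    //   [num_tasks (uint64)] [num_tasks pairs of (task_index, task_end_offset), uint64 each]
+    //   [concatenated per-task label data]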
+ const int64_t* const data_offsets = mol_file_data_offsets.data_ptr(); + const int64_t file_begin_offset = data_offsets[index_into_offsets]; + const int64_t file_end_offset = data_offsets[index_into_offsets+1]; + if (file_end_offset < 0 || file_end_offset-file_begin_offset < 8) { + printf("Error: In load_labels_from_index, mol_index = %zu, file_begin_offset = %lld, file_end_offset = %lld\n", + size_t(mol_index), (long long)(index_into_offsets), (long long)(num_data_offsets)); + return; + } + const size_t file_read_size = size_t(file_end_offset - file_begin_offset); + + std::unique_ptr data(new char[file_read_size]); + + { + char filename[25]; + get_mol_label_filename(filename, file_num); + + const std::filesystem::path file_path{stage_path / filename}; + FileType file = fopen_read_wrapper(file_path); + if (file == INVALID_FILE) { + printf("Error: In load_labels_from_index, failed to open \"%s\" for molecule %zu\n", + file_path.string().c_str(), size_t(mol_index)); + return; + } + int seek_failed = fseek_wrapper(file, file_begin_offset); + if (seek_failed) { + printf("Error: In load_labels_from_index, failed to seek to offset %zu in \"%s\" for molecule %zu\n", + size_t(file_begin_offset), file_path.string().c_str(), size_t(mol_index)); + fclose_wrapper(file); + return; + } + size_t num_bytes_read = fread_wrapper(data.get(), file_read_size, file); + fclose_wrapper(file); + if (num_bytes_read != file_read_size) { + printf("Error: In load_labels_from_index, read only %zu/%zu bytes from \"%s\" for molecule %zu\n", + size_t(num_bytes_read), size_t(file_read_size), file_path.string().c_str(), size_t(mol_index)); + return; + } + } + + uint64_t mol_num_tasks = 0; + memcpy(&mol_num_tasks, data.get(), sizeof(uint64_t)); + size_t data_offset = sizeof(uint64_t); + if (mol_num_tasks == 0 || mol_num_tasks > label_names.size() || file_read_size < (1+2*mol_num_tasks)*sizeof(uint64_t)) { + printf("Error: In load_labels_from_index, mol_index = %zu, mol_num_tasks = %zu, file_read_size = %zu\n", + size_t(mol_index), size_t(mol_num_tasks), size_t(file_read_size)); + return; + } + const size_t base_offset = (1+2*mol_num_tasks)*sizeof(uint64_t); + const char* base_task_data = data.get() + base_offset; + uint64_t task_offset = 0; + for (size_t data_task_index = 0; data_task_index < mol_num_tasks; ++data_task_index) { + uint64_t task_index = 0; + memcpy(&task_index, data.get() + data_offset, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + if (task_index >= label_names.size() || task_index >= label_data_types.size() || task_index >= label_num_cols.size()) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu\n", + size_t(mol_index), size_t(task_index)); + return; + } + + uint64_t task_end_offset = 0; + memcpy(&task_end_offset, data.get() + data_offset, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + if (task_end_offset < task_offset || task_end_offset > file_read_size-base_offset) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_offset = %zu, task_end_offset = %zu, file_read_size = %zu, base_offset = %zu\n", + size_t(mol_index), size_t(task_offset), size_t(task_end_offset), size_t(file_read_size), size_t(base_offset)); + return; + } + + const size_t task_num_bytes = task_end_offset - task_offset; + if (!pybind11::isinstance(label_data_types[task_index]) || + !pybind11::isinstance(label_num_cols[task_index])) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu, label_data_type = \"%s\", label_num_cols = \"%s\"\n", + 
size_t(mol_index), size_t(task_index), + std::string(pybind11::str(label_data_types[task_index])).c_str(), + std::string(pybind11::str(label_num_cols[task_index])).c_str()); + return; + } + const c10::ScalarType torch_type = c10::ScalarType(size_t(label_data_types[task_index].cast())); + const size_t num_cols = size_t(label_num_cols[task_index].cast()); + if (num_cols == 0) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu, label_data_type = %zu, label_num_cols = %zu\n", + size_t(mol_index), size_t(task_index), + size_t(torch_type), num_cols); + return; + } + const size_t supported_type_index = torch_type_index(torch_type); + if (supported_type_index >= num_supported_types) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu, label_data_type = %zu, label_num_cols = %zu\n", + size_t(mol_index), size_t(task_index), + size_t(torch_type), num_cols); + } + const size_t bytes_per_float = supported_types[supported_type_index].size; + const size_t num_floats = task_num_bytes / bytes_per_float; + const size_t num_rows = num_floats / num_cols; + + if (num_floats != num_rows*num_cols) { + printf("Error: In load_labels_from_index, mol_index = %zu, task data bytes = %zu (not a multiple of %zu*%zu)\n", + size_t(mol_index), size_t(task_num_bytes), bytes_per_float, num_cols); + return; + } + + const std::string label_name{pybind11::str(label_names[task_index])}; + const bool is_graph_level = (std::strncmp(label_name.c_str(), "graph", 5) == 0); + if (is_graph_level && num_rows != 1) { + printf("Error: In load_labels_from_index, mol_index = %zu, num_rows = %zu for task \"%s\"\n", + size_t(mol_index), num_rows, label_name.c_str()); + return; + } + size_t num_label_dims = is_graph_level ? 1 : 2; + const int64_t label_dims[2] = { (is_graph_level ? int64_t(num_floats) : int64_t(num_rows)), int64_t(num_cols) }; + at::Tensor label_tensor; + + if (bytes_per_float == 2) { + std::unique_ptr label_data(new uint16_t[num_floats]); + memcpy(label_data.get(), base_task_data + task_offset, task_num_bytes); + label_tensor = torch_tensor_from_array(std::move(label_data), label_dims, num_label_dims, torch_type); + } + else if (bytes_per_float == 4) { + std::unique_ptr label_data(new float[num_floats]); + memcpy(label_data.get(), base_task_data + task_offset, task_num_bytes); + label_tensor = torch_tensor_from_array(std::move(label_data), label_dims, num_label_dims, torch_type); + } + else if (bytes_per_float == 8) { + std::unique_ptr label_data(new double[num_floats]); + memcpy(label_data.get(), base_task_data + task_offset, task_num_bytes); + label_tensor = torch_tensor_from_array(std::move(label_data), label_dims, num_label_dims, torch_type); + } + + PyDict_SetItem(labels.ptr(), label_names[task_index].ptr(), THPVariable_Wrap(std::move(label_tensor))); + + task_offset = task_end_offset; + } +} + +std::string extract_string( + const at::Tensor& concat_strings, + const at::Tensor& string_offsets, + const int64_t index) { + + const size_t data_size = (concat_strings.scalar_type() == c10::ScalarType::Char && concat_strings.ndimension() == 1) ? concat_strings.size(0) : 0; + const size_t num_data_offsets = (string_offsets.scalar_type() == c10::ScalarType::Long && string_offsets.ndimension() == 1) ? 
string_offsets.size(0) : 0; + if (index < 0 || size_t(index) >= num_data_offsets) { + return std::string(); + } + const char* const data = reinterpret_cast(concat_strings.data_ptr()); + const int64_t* const data_offsets = string_offsets.data_ptr(); + int64_t offset = data_offsets[index]; + int64_t end_offset = data_offsets[index+1]; + int64_t size = (end_offset - offset) - 1; + if (offset < 0 || size < 0 || end_offset > int64_t(data_size)) { + return std::string(); + } + return std::string(data + offset, size_t(size)); +} diff --git a/graphium/graphium_cpp/labels.h b/graphium/graphium_cpp/labels.h new file mode 100644 index 000000000..30498750d --- /dev/null +++ b/graphium/graphium_cpp/labels.h @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +// Torch tensor headers +#include +#include +#include + +// PyBind and Torch headers +#include +#include +#include + +// The following functions are in labels.cpp, and declared here so that +// graphium_cpp.cpp can expose them to Python via pybind. +std::tuple< + std::vector, + std::vector +> load_num_cols_and_dtypes( + const std::string& processed_graph_data_path, + const std::string& data_hash); + +std::vector load_metadata_tensors( + const std::string processed_graph_data_path, + const std::string stage, + const std::string data_hash); + +std::vector load_stats( + const std::string processed_graph_data_path, + const std::string data_hash, + const std::string task_name); + +std::pair concatenate_strings(pybind11::handle handle); + +std::tuple< + std::unordered_map>, + std::unordered_map>, + std::vector, + std::vector +> prepare_and_save_data( + const pybind11::list& task_names, + pybind11::dict& task_dataset_args, + const pybind11::dict& task_label_normalization, + const std::string processed_graph_data_path, + const std::string data_hash, + const pybind11::dict& task_train_indices, + const pybind11::dict& task_val_indices, + const pybind11::dict& task_test_indices, + bool add_self_loop = false, + bool explicit_H = false, + int max_threads = 0); + +void load_labels_from_index( + const std::string stage_directory, + const int64_t mol_index, + const at::Tensor& mol_file_data_offsets, + const pybind11::list& label_names, + const pybind11::list& label_num_cols, + const pybind11::list& label_data_types, + pybind11::dict& labels); + +std::string extract_string( + const at::Tensor& concat_strings, + const at::Tensor& string_offsets, + const int64_t index); diff --git a/graphium/graphium_cpp/one_hot.cpp b/graphium/graphium_cpp/one_hot.cpp new file mode 100644 index 000000000..47485569e --- /dev/null +++ b/graphium/graphium_cpp/one_hot.cpp @@ -0,0 +1,361 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "one_hot.h" +#include "features.h" +#include "float_features.h" + +#include +#include + +#include +#include +#include +#include + +template +class OneHotLookup { + size_t indices[NUM_IN]; +public: + constexpr OneHotLookup(const size_t list[MAX_OUT]) : indices() { + std::fill(indices, indices + NUM_IN, MAX_OUT); + for (size_t i = 0; i < MAX_OUT; ++i) { + indices[list[i]] = i; + } + } + constexpr size_t operator[](size_t i) const { + return (i < NUM_IN) ? 
indices[i] : MAX_OUT; + } +}; + +// This list of elements matches ATOM_LIST in older file graphium/features/nmp.py +constexpr size_t atomicNumList[] = { + 6 -1, // C + 7 -1, // N + 8 -1, // O + 16-1,// S + 9 -1, // F + 14-1,// Si + 15-1,// P + 17-1,// Cl + 35-1,// Br + 12-1,// Mg + 11-1,// Na + 20-1,// Ca + 26-1,// Fe + 33-1,// As + 13-1,// Al + 53-1,// I + 5 -1,// B + 23-1,// V + 19-1,// K + 81-1,// Tl + 70-1,// Yb + 51-1,// Sb + 50-1,// Sn + 47-1,// Ag + 46-1,// Pd + 27-1,// Co + 34-1,// Se + 22-1,// Ti + 30-1,// Zn + 1 -1,// H + 3 -1,// Li + 32-1,// Ge + 29-1,// Cu + 79-1,// Au + 28-1,// Ni + 48-1,// Cd + 49-1,// In + 25-1,// Mn + 40-1,// Zr + 24-1,// Cr + 78-1,// Pt + 80-1,// Hg + 82-1,// Pb +}; +constexpr size_t atomicNumCount = std::extent::value; +constexpr OneHotLookup<118, atomicNumCount> atomicNumLookup(atomicNumList); + +constexpr size_t degreeCount = 5; +constexpr size_t valenceCount = 7; + +// Reverse alphabetical order, excluding "OTHER", +// matching HYBRIDIZATION_LIST in older file graphium/features/nmp.py +constexpr size_t hybridizationList[] = { + RDKit::Atom::HybridizationType::UNSPECIFIED, + RDKit::Atom::HybridizationType::SP3D2, + RDKit::Atom::HybridizationType::SP3D, + RDKit::Atom::HybridizationType::SP3, + RDKit::Atom::HybridizationType::SP2D, + RDKit::Atom::HybridizationType::SP2, + RDKit::Atom::HybridizationType::SP, + RDKit::Atom::HybridizationType::S, +}; +constexpr size_t hybridizationCount = std::extent::value; +constexpr OneHotLookup<8, hybridizationCount> hybridizationLookup(hybridizationList); + +static const std::string chiralityRString("R"); + +enum ElementPhase { + GAS, + ARTIFICIAL, + LIQ, + SOLID +}; +// This table is from the Phase column of graphium/features/periodic_table.csv +constexpr ElementPhase atomicNumToPhase[] = { + GAS, GAS, + SOLID, SOLID, SOLID, SOLID, GAS, GAS, GAS, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, GAS, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, LIQ, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, ARTIFICIAL, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, ARTIFICIAL, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, LIQ, SOLID, SOLID, SOLID, SOLID, SOLID, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, +}; +constexpr size_t phaseCount = 4; + +enum ElementType { + NOBLE_GAS, + ALKALI_METAL, + METAL, HALOGEN, + LANTHANIDE, + ALKALINE_EARTH_METAL, + TRANSITION_METAL, + ACTINIDE, + METALLOID, + NONE, + TRANSACTINIDE, + NONMETAL, + + NUM_ELEMENT_TYPES +}; +// This table is from the Type column of graphium/features/periodic_table.csv +constexpr ElementType atomicNumToType[] = { + NONMETAL, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, METALLOID, NONMETAL, NONMETAL, NONMETAL, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, METAL, METALLOID, NONMETAL, NONMETAL, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, 
TRANSITION_METAL, METAL, METALLOID, METALLOID, NONMETAL, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, METAL, METAL, METALLOID, METALLOID, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, METAL, METAL, METAL, METALLOID, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, NONE, TRANSACTINIDE, NONE, TRANSACTINIDE, NONE, NOBLE_GAS +}; +constexpr size_t typeCount = ElementType::NUM_ELEMENT_TYPES; + +// This matches BOND_TYPES in older file graphium/features/nmp.py +constexpr size_t bondTypeList[] = { + RDKit::Bond::BondType::SINGLE, + RDKit::Bond::BondType::DOUBLE, + RDKit::Bond::BondType::TRIPLE, + RDKit::Bond::BondType::AROMATIC, +}; +constexpr size_t bondTypeCount = std::extent::value; +constexpr OneHotLookup<22, bondTypeCount> bondTypeLookup(bondTypeList); + +// This matches BOND_STEREO in older file graphium/features/nmp.py +constexpr size_t bondStereoList[] = { + RDKit::Bond::BondStereo::STEREONONE, + RDKit::Bond::BondStereo::STEREOANY, + RDKit::Bond::BondStereo::STEREOZ, + RDKit::Bond::BondStereo::STEREOE, + RDKit::Bond::BondStereo::STEREOCIS, + RDKit::Bond::BondStereo::STEREOTRANS, +}; +constexpr size_t bondStereoCount = std::extent::value; +constexpr OneHotLookup<6, bondStereoCount> bondStereoLookup(bondStereoList); + +size_t get_one_hot_atom_feature_size(AtomOneHotFeature feature) { + switch (feature) { + case AtomOneHotFeature::ATOMIC_NUM: return atomicNumCount + 1; + case AtomOneHotFeature::DEGREE: return degreeCount + 1; + case AtomOneHotFeature::VALENCE: return valenceCount + 1; + case AtomOneHotFeature::IMPLICIT_VALENCE: return valenceCount + 1; + case AtomOneHotFeature::HYBRIDIZATION: return hybridizationCount + 1; + // "R", anything else ("S" or no value), bool for if other property present + case AtomOneHotFeature::CHIRALITY: return 3; + case AtomOneHotFeature::PHASE: return phaseCount + 1; + case AtomOneHotFeature::TYPE: return typeCount + 1; + case AtomOneHotFeature::GROUP: return groupCount + 1; + case AtomOneHotFeature::PERIOD: return periodCount + 1; + default: + // Missing implementation + assert(0); + return 0; + } +} + +template +size_t get_one_hot_atom_feature(const GraphData& graph, T* data, AtomOneHotFeature feature, size_t stride) { + const size_t num_atoms = graph.num_atoms; + const RDKit::ROMol& mol = *graph.mol.get(); + const size_t feature_size = get_one_hot_atom_feature_size(feature); + const size_t total_feature_size = feature_size * num_atoms; + if (total_feature_size == 0) { + return feature_size; + } + { + T* current_data = data; + for (size_t i = 0; i < num_atoms; ++i) { + memset(current_data, 0, sizeof(data[0]) * feature_size); + current_data += stride; + } + } + switch (feature) { + case AtomOneHotFeature::ATOMIC_NUM: + for (size_t 
atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + data[atomicNumLookup[atomicNum-1]] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::DEGREE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto degree = mol.getAtomWithIdx(atomIndex)->getDegree(); + size_t dataIndex = (degree < degreeCount) ? degree : degreeCount; + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::VALENCE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto valence = mol.getAtomWithIdx(atomIndex)->getTotalValence(); + size_t dataIndex = (size_t(valence) < valenceCount) ? size_t(valence) : valenceCount; + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::IMPLICIT_VALENCE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto valence = mol.getAtomWithIdx(atomIndex)->getImplicitValence(); + size_t dataIndex = (size_t(valence) < valenceCount) ? size_t(valence) : valenceCount; + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::HYBRIDIZATION: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto hybridization = mol.getAtomWithIdx(atomIndex)->getHybridization(); + data[hybridizationLookup[hybridization]] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::CHIRALITY: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + std::string chirality; + const RDKit::Atom* atom = mol.getAtomWithIdx(atomIndex); + bool isPresent = atom->getPropIfPresent(RDKit::common_properties::_CIPCode, chirality); + data[(isPresent && chirality == chiralityRString) ? 0 : 1] = FeatureValues::one; + if (atom->hasProp(RDKit::common_properties::_ChiralityPossible)) { + data[2] = FeatureValues::one; + } + } + return feature_size; + case AtomOneHotFeature::PHASE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = phaseCount; + if (atomicNum - 1 < std::extent::value) { + ElementPhase phase = atomicNumToPhase[atomicNum - 1]; + // Group numbers are 1-based, but the array indices aren't. + dataIndex = phase - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::TYPE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = typeCount; + if (atomicNum - 1 < std::extent::value) { + ElementType type = atomicNumToType[atomicNum - 1]; + // Group numbers are 1-based, but the array indices aren't. + dataIndex = type - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::GROUP: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = groupCount; + if (atomicNum - 1 < std::extent::value) { + uint8_t group = atomicNumToGroupTable[atomicNum - 1]; + // Group numbers are 1-based, but the array indices aren't. 
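+                // Atomic numbers outside the table fall through with dataIndex == groupCount,
+                // the trailing "other" bucket.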
+ dataIndex = group - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::PERIOD: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = periodCount; + if (atomicNum - 1 < std::extent::value) { + uint8_t period = atomicNumToPeriodTable[atomicNum - 1]; + // Period numbers are 1-based, but the array indices aren't. + dataIndex = period - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + default: + // Missing implementation + assert(0); + return feature_size; + } +} + +// Explicit instantiations, so that the function can be templated +// but still be used from other cpp files. +template size_t get_one_hot_atom_feature(const GraphData& graph, int16_t* data, AtomOneHotFeature feature, size_t stride); +template size_t get_one_hot_atom_feature(const GraphData& graph, float* data, AtomOneHotFeature feature, size_t stride); +template size_t get_one_hot_atom_feature(const GraphData& graph, double* data, AtomOneHotFeature feature, size_t stride); + + +size_t get_one_hot_bond_feature_size(BondFeature feature) { + switch (feature) { + case BondFeature::TYPE_ONE_HOT: return bondTypeCount + 1; + case BondFeature::STEREO_ONE_HOT: return bondStereoCount + 1; + default: + break; + } + // Missing implementation + assert(0); + return 0; +} + +template +size_t get_one_hot_bond_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride) { + const size_t num_bonds = graph.num_bonds; + const size_t feature_size = get_one_hot_bond_feature_size(feature); + const size_t total_feature_size = feature_size * num_bonds; + if (total_feature_size == 0) { + return 0; + } + { + T* current_data = data; + for (size_t i = 0; i < num_bonds; ++i) { + memset(current_data, 0, sizeof(data[0]) * feature_size); + current_data += stride; + } + } + switch (feature) { + case BondFeature::TYPE_ONE_HOT: + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + auto type = graph.bonds[i].bondType; + data[bondTypeLookup[type]] = FeatureValues::one; + } + return feature_size; + case BondFeature::STEREO_ONE_HOT: + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + auto stereo = graph.bonds[i].stereo; + data[bondStereoLookup[stereo]] = FeatureValues::one; + } + return feature_size; + default: + // Missing implementation + assert(0); + return feature_size; + } +} + +// Explicit instantiations, so that the function can be templated +// but still be used from other cpp files. +template size_t get_one_hot_bond_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +template size_t get_one_hot_bond_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +template size_t get_one_hot_bond_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); diff --git a/graphium/graphium_cpp/one_hot.h b/graphium/graphium_cpp/one_hot.h new file mode 100644 index 000000000..475b87a8e --- /dev/null +++ b/graphium/graphium_cpp/one_hot.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "features.h" + +#include + +#include + +size_t get_one_hot_atom_feature_size(AtomOneHotFeature feature); + +template +size_t get_one_hot_atom_feature(const GraphData& graph, T* data, AtomOneHotFeature feature, size_t stride); + +extern template size_t get_one_hot_atom_feature(const GraphData& graph, int16_t* data, AtomOneHotFeature feature, size_t stride); +extern template size_t get_one_hot_atom_feature(const GraphData& graph, float* data, AtomOneHotFeature feature, size_t stride); +extern template size_t get_one_hot_atom_feature(const GraphData& graph, double* data, AtomOneHotFeature feature, size_t stride); + +size_t get_one_hot_bond_feature_size(BondFeature feature); + +template +size_t get_one_hot_bond_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride); + +extern template size_t get_one_hot_bond_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +extern template size_t get_one_hot_bond_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +extern template size_t get_one_hot_bond_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); + diff --git a/graphium/graphium_cpp/random_walk.cpp b/graphium/graphium_cpp/random_walk.cpp new file mode 100644 index 000000000..e4dc3116b --- /dev/null +++ b/graphium/graphium_cpp/random_walk.cpp @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "random_walk.h" + +#include +#include +#include +#include +#include + +template +void multiply_dense_by_sparse(uint32_t n, T* out_matrix, const T* in_matrix, const uint32_t* neighbor_starts, const uint32_t* neighbors, const T* col_major_weights) { + for (uint32_t row = 0; row < n; ++row) { + T* out_row_start = out_matrix + row * n; + const T* in_row_start = in_matrix + row * n; + for (uint32_t col = 0; col < n; ++col) { + T sum = T(0); + // The adjacency is symmetric, so rows and cols are swappable there, + // but the weights might not be, so for fast access, we want column major weights. + const uint32_t* neighbors_start = neighbors + neighbor_starts[col]; + const uint32_t* neighbors_end = neighbors + neighbor_starts[col+1]; + const T* weights_start = col_major_weights + neighbor_starts[col]; + for (; neighbors_start != neighbors_end; ++neighbors_start, ++weights_start) { + sum += *weights_start * in_row_start[*neighbors_start]; + } + out_row_start[col] = sum; + } + } +} + +// The adjacency (neighbor_starts and neighbors) must be symmetric. +// powers must be in increasing sorted order. 
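+// compute_rwse raises the degree-normalized transition matrix D^-1 * A to each requested power
+// by repeated dense-by-sparse multiplication. With PROBABILITIES, only the diagonal (the
+// return probabilities used as node-level RWSE) is kept, optionally scaled by
+// pow(target_power, space_dim/2); with MATRIX, the full n x n matrix is kept for each power.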
+template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim) { + + // Cast one n to size_t to avoid integer overflow if n >= 65536 + if (option == RandomWalkDataOption::PROBABILITIES) { + output.resize(num_powers * size_t(n)); + } + else { + output.resize(num_powers * size_t(n) * n); + } + + if (num_powers == 0) { + return; + } + if (n == 1) { + // Special case: All ones for single node, matching original code + for (uint32_t i = 0; i < output.size(); ++i) { + output[i] = T(1); + } + return; + } + + // Initialize this to represent column major D^-1 * adj + std::vector col_major_weights; + col_major_weights.resize(neighbor_starts[n]); + for (uint32_t col = 0, i = 0; col < n; ++col) { + const uint32_t* neighbor_start = neighbors + neighbor_starts[col]; + const uint32_t* neighbor_end = neighbors + neighbor_starts[col+1]; + for (; neighbor_start != neighbor_end; ++neighbor_start, ++i) { + const uint32_t neighbor = *neighbor_start; + uint32_t neighbor_degree = neighbor_starts[neighbor + 1] - neighbor_starts[neighbor]; + T degree_inv = (neighbor_degree == 0) ? T(0) : T(1) / T(neighbor_degree); + col_major_weights[i] = degree_inv; + } + } + + // Space for 2 matrices, to alternate between them + std::vector matrix; + matrix.resize(2 * size_t(n) * n, T(0)); + T* matrix0 = matrix.data(); + T* matrix1 = matrix.data() + size_t(n) * n; + uint64_t current_power = 0; + // Initialize current matrix to identity matrix + for (size_t i = 0, diag_index = 0; i < n; ++i, diag_index += (n+1)) { + matrix0[diag_index] = T(1); + } + + for (uint32_t power_index = 0; power_index < num_powers; ++power_index) { + const uint64_t target_power = powers[power_index]; + assert(target_power >= current_power); + while (target_power > current_power) { + std::swap(matrix0, matrix1); + multiply_dense_by_sparse(n, matrix0, matrix1, neighbor_starts, neighbors, col_major_weights.data()); + ++current_power; + } + + // Copy results to output + if (option == RandomWalkDataOption::PROBABILITIES) { + const T scale_factor = (space_dim == 0) ? T(1) : T(std::pow(T(target_power), T(0.5) * T(space_dim))); + // Just copy the diagonal values + for (size_t i = 0, diag_index = 0; i < n; ++i, diag_index += (n + 1)) { + output[i * num_powers + power_index] = scale_factor * matrix0[diag_index]; + } + } + else { + // Copy transition probabilities, making sure the dimensions are correct, because matrix0 isn't symmetric. 
+ // Least significant dimension is num_powers + // Middle dimension is the columns across a single row of matrix0 + // Most significant dimension is the rows of the matrix0 + const size_t row_stride = num_powers * size_t(n); + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + output[row * row_stride + col * num_powers + power_index] = matrix0[i]; + } + } + } + } +} + +// Explicit instantiations for float and double +template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); +template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); diff --git a/graphium/graphium_cpp/random_walk.h b/graphium/graphium_cpp/random_walk.h new file mode 100644 index 000000000..1617a7be2 --- /dev/null +++ b/graphium/graphium_cpp/random_walk.h @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +enum class RandomWalkDataOption { + PROBABILITIES, + MATRIX +}; + +template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim = 0); + +extern template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); +extern template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); diff --git a/graphium/graphium_cpp/setup.py b/graphium/graphium_cpp/setup.py new file mode 100755 index 000000000..c1fb1e3fb --- /dev/null +++ b/graphium/graphium_cpp/setup.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Setup script that builds graphium_cpp. +At time of writing, this has only been tested with GCC 10.5.0. +To build, git clone pybind11 into this directory, then run: +rm -r build/* +export PYTHONPATH=$PYTHONPATH:./pybind11 +pip install . 
+""" + +from distutils.core import setup +from pybind11.setup_helpers import Pybind11Extension, build_ext +import torch, rdkit, os +import numpy + +torch_dir = torch.__path__[0] +rdkit_lib_index = rdkit.__path__[0].split("/").index("lib") +rdkit_prefix = "/".join(rdkit.__path__[0].split("/")[:rdkit_lib_index]) + +ext_modules = [ + Pybind11Extension( + "graphium_cpp", + sources=[ + "graphium_cpp.cpp", + "features.cpp", + "labels.cpp", + "commute.cpp", + "electrostatic.cpp", + "float_features.cpp", + "graphormer.cpp", + "one_hot.cpp", + "random_walk.cpp", + "spectral.cpp", + ], + language="c++", + cxx_std=20, + include_dirs=[ + os.path.join(torch_dir, "include"), + os.path.join(torch_dir, "include/torch/csrc/api/include"), + os.path.join(rdkit_prefix, "include/rdkit"), + os.path.join(rdkit_prefix, "include/boost"), + numpy.get_include(), + ], + libraries=[ + "RDKitAlignment", + "RDKitDataStructs", + "RDKitDistGeometry", + "RDKitDistGeomHelpers", + "RDKitEigenSolvers", + "RDKitForceField", + "RDKitForceFieldHelpers", + "RDKitGenericGroups", + "RDKitGraphMol", + "RDKitInchi", + "RDKitRDInchiLib", + "RDKitRDBoost", + "RDKitRDGeneral", + "RDKitRDGeometryLib", + "RDKitRingDecomposerLib", + "RDKitSmilesParse", + "RDKitSubstructMatch", + "torch_cpu", + "torch_python", + ], + library_dirs=[os.path.join(rdkit_prefix, "lib"), os.path.join(torch_dir, "lib")], + extra_compile_args=[ + "-O3", + "-Wall", + "-Wmissing-field-initializers", + "-Wmaybe-uninitialized", + "-Wuninitialized", + ], + ) +] + +setup( + name="graphium_cpp", + version="0.1", + author="N. Dickson", + author_email="ndickson@nvidia.com", + license="Apache 2.0", + description="C++ extension for graphium", + ext_modules=ext_modules, + cmdclass={"build_ext": build_ext}, +) diff --git a/graphium/graphium_cpp/spectral.cpp b/graphium/graphium_cpp/spectral.cpp new file mode 100644 index 000000000..9ce9e6ed1 --- /dev/null +++ b/graphium/graphium_cpp/spectral.cpp @@ -0,0 +1,317 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "spectral.h" + +#include +#include +#include +#include + +#include "features.h" +#include + +size_t find_components( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + std::vector& components) { + + int32_t num_components = (n <= 1) ? 1 : 0; + std::vector queue; + if (n > 1) { + // First, find which nodes are in which component. 
+ components.resize(n, -1); + queue.reserve(n); + for (uint32_t starti = 0; starti < n; ++starti) { + if (components[starti] >= 0) { + continue; + } + const int32_t component = num_components; + ++num_components; + queue.push_back(starti); + components[starti] = component; + while (queue.size() != 0) { + uint32_t current = queue[queue.size()-1]; + queue.resize(queue.size()-1); + const uint32_t* neighbor_begin = neighbors + row_starts[current]; + const uint32_t* neighbor_end = neighbors + row_starts[current+1]; + for ( ; neighbor_begin != neighbor_end; ++neighbor_begin) { + uint32_t neighbor = *neighbor_begin; + if (neighbor > starti && components[neighbor] < 0) { + components[neighbor] = component; + queue.push_back(neighbor); + } + } + } + } + } + return size_t(num_components); +} + +template +void compute_laplacian_eigendecomp_single(const uint32_t n, LaplacianData& data, Normalization normalization) { + T* matrix = data.matrix_temp.data(); + std::unique_ptr matrix_alloc(new T[n * n]); + std::copy(matrix, matrix + n * n, matrix_alloc.get()); + + int64_t dims[2] = { n, n }; + at::Tensor torch_matrix = torch_tensor_from_array(std::move(matrix_alloc), dims, 2, c10::ScalarType::Double); + + // Using linalg_eigh should ensure we get all real eigenvalues and eigenvectors. + // Arbitrarily choose lower-triangular portion (L) + auto tuple = at::linalg_eigh(torch_matrix, c10::string_view("L",1)); + at::Tensor eigenvalue_tensor = std::move(std::get<0>(tuple)); + at::Tensor eigenvector_tensor = std::move(std::get<1>(tuple)); + assert(eigenvalue_tensor.ndimension() == 1); + assert(eigenvector_tensor.ndimension() == 2); + assert(eigenvalue_tensor.size(0) == n); + assert(eigenvector_tensor.size(0) == n); + assert(eigenvector_tensor.size(1) == n); + + // Copy eigenvectors first, because normalization values are in eigenvalues_temp + data.vectors.clear(); + data.vectors.resize(size_t(n) * n, 0); + T* vectors = data.vectors.data(); + if (eigenvector_tensor.scalar_type() == c10::ScalarType::Double) { + const double* const eigenvector_data = eigenvector_tensor.data_ptr(); + for (size_t i = 0; i < size_t(n) * n; ++i) { + vectors[i] = T(eigenvector_data[i]); + } + + if (normalization == Normalization::INVERSE) { + // Convert symmetric case eigenvectors to asymmetric case eigenvectors + + // Scale each row by the factor in eigenvalues_temp + for (size_t row = 0, i = 0; row < n; ++row) { + const T factor = data.eigenvalues_temp[row]; + for (size_t col = 0; col < n; ++col, ++i) { + vectors[i] *= factor; + } + + // Clear to zero for the summing below + data.eigenvalues_temp[row] = 0; + } + + // Find each column length + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + const T v = vectors[i]; + data.eigenvalues_temp[col] += v*v; + } + } + for (size_t col = 0; col < n; ++col) { + data.eigenvalues_temp[col] = T(1)/std::sqrt(data.eigenvalues_temp[col]); + } + + // Normalize each column + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + vectors[i] *= data.eigenvalues_temp[col]; + } + } + } + } + else { + assert(0); + } + + // Copy eigenvalues + data.eigenvalues_temp.resize(n); + if (eigenvalue_tensor.scalar_type() == c10::ScalarType::Double) { + const double* const eigenvalue_data = eigenvalue_tensor.data_ptr(); + for (size_t i = 0; i < n; ++i) { + // No adjustment needed to eigenvalues between symmetric and asymmetric + data.eigenvalues_temp[i] = T(eigenvalue_data[i]); + } + } + else { + assert(0); + } + + // Find the sorted 
order of the eigenvalues + data.order_temp.resize(n); + std::iota(data.order_temp.begin(), data.order_temp.end(), 0); + std::stable_sort(data.order_temp.begin(), data.order_temp.end(), + [&data](uint32_t i, uint32_t j) -> bool { + return data.eigenvalues_temp[i] < data.eigenvalues_temp[j]; + } + ); + + // Copy the eigenvalues into the sorted order + data.eigenvalues.resize(n); + for (size_t i = 0; i < n; ++i) { + data.eigenvalues[i] = data.eigenvalues_temp[data.order_temp[i]]; + } + + // Copy the eigenvectors into the sorted order + std::swap(data.matrix_temp, data.vectors); + for (size_t row = 0, i = 0; row < n; ++row) { + const size_t source_row = data.order_temp[row]; + const size_t source_row_start = source_row * n; + for (size_t col = 0; col < n; ++col, ++i) { + data.vectors[i] = data.matrix_temp[source_row_start + col]; + } + } +} + +template +void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const std::vector* components, + const T* weights) { + + // Compute the weight row sums, if applicable, for the diagonal of the laplacian + if (weights != nullptr) { + data.eigenvalues_temp.clear(); + data.eigenvalues_temp.resize(n, 0); + for (uint32_t i = 0; i < n; ++i) { + const T* weights_begin = weights + row_starts[i]; + const T* weights_end = weights + row_starts[i + 1]; + T sum = T(0); + for (; weights_begin != weights_end; ++weights_begin) { + sum += *weights_begin; + } + data.eigenvalues_temp[i] = sum; + } + } + data.normalization = normalization; + + // Prepare the laplacian matrix of the graph + data.matrix_temp.clear(); + data.matrix_temp.resize(size_t(n) * n, 0); + T* matrix = data.matrix_temp.data(); + if (normalization == Normalization::NONE) { + for (uint32_t i = 0, outi = 0; i < n; ++i, outi += n) { + const uint32_t* neighbor_begin = neighbors + row_starts[i]; + const uint32_t* neighbor_end = neighbors + row_starts[i + 1]; + if (weights == nullptr) { + const uint32_t degree = row_starts[i + 1] - row_starts[i]; + matrix[outi + i] = T(degree); + for (; neighbor_begin < neighbor_end; ++neighbor_begin) { + uint32_t neighbor = *neighbor_begin; + matrix[outi + neighbor] = T(-1); + } + } + else { + matrix[outi + i] = data.eigenvalues_temp[i]; + const T* weights_begin = weights + row_starts[i]; + for (; neighbor_begin < neighbor_end; ++neighbor_begin, ++weights_begin) { + uint32_t neighbor = *neighbor_begin; + matrix[outi + neighbor] = -(*weights_begin); + } + } + } + } + else { + // The diagonalization of the asymmetric normalization can be computed from the + // diagonalization of the symmetric normalization, which is faster, so always use symmetric. + + // Find the normalization factor for each node (row or col) + // These values in eigenvalues_temp are also used inside compute_laplacian_eigendecomp_single + for (uint32_t node = 0; node < n; ++node) { + const uint32_t row_degree = row_starts[node + 1] - row_starts[node]; + const T denominator = (weights == nullptr) ? 
T(row_degree) : data.eigenvalues_temp[node]; + data.eigenvalues_temp[node] = T(1) / std::sqrt(denominator); + } + + for (uint32_t i = 0, outi = 0; i < n; ++i, outi += n) { + const uint32_t* neighbor_begin = neighbors + row_starts[i]; + const uint32_t* neighbor_end = neighbors + row_starts[i + 1]; + if (neighbor_begin == neighbor_end) { + continue; + } + + // Diagonal is always exactly 1 when normalized (after skipping zero-degree nodes) + matrix[outi + i] = T(1); + + const T row_factor = data.eigenvalues_temp[i]; + for (; neighbor_begin < neighbor_end; ++neighbor_begin) { + uint32_t neighbor = *neighbor_begin; + const T col_factor = data.eigenvalues_temp[neighbor]; + matrix[outi + neighbor] = -row_factor * col_factor; + } + } + } + + if (num_components == 1) { + compute_laplacian_eigendecomp_single(n, data, normalization); + return; + } + + // There are multiple components. + // To match the original code, handle them separately and + // pack them into the output. + + // data.eigenvalues is length n for the single component case, + // but to be able to handle this, it needs to be larger, so go with n by n + data.eigenvalues.clear(); + data.eigenvalues.resize(size_t(n) * n, 0); + data.vectors.clear(); + data.vectors.resize(size_t(n) * n, 0); + + LaplacianData sub_data; + std::vector queue; + for (int32_t component = 0; component < num_components; ++component) { + // Reuse queue for the indices + queue.resize(0); + for (uint32_t i = 0; i < n; ++i) { + if ((*components)[i] == component) { + queue.push_back(i); + } + } + + // Extract the sub-matrix + const uint32_t sub_n = queue.size(); + sub_data.matrix_temp.resize(size_t(sub_n) * sub_n); + T* sub_matrix = sub_data.matrix_temp.data(); + for (uint32_t row_index = 0; row_index < sub_n; ++row_index) { + const uint32_t row = queue[row_index]; + const T*const source_row = matrix + row*size_t(n); + for (uint32_t col_index = 0; col_index < sub_n; ++col_index) { + const uint32_t col = queue[col_index]; + *sub_matrix = source_row[col]; + ++sub_matrix; + } + } + + // Find its eigenvalues and eigenvectors + compute_laplacian_eigendecomp_single(sub_n, sub_data, normalization); + + // Copy the eigenvalues to the output. The excess is already zeroed out. + // Unlike the eigenvectors, below, might as well switch to using columns + // for the eigenvalues, because the caller can handle this case more + // easily with the single component case this way. + for (uint32_t row_index = 0; row_index < sub_n; ++row_index) { + const uint32_t row = queue[row_index]; + T*const dest_row = data.eigenvalues.data() + row*size_t(n); + for (uint32_t col_index = 0; col_index < sub_n; ++col_index) { + // Destination data within the row is left justified, + // NOT distributed based on the component. + dest_row[col_index] = sub_data.eigenvalues[col_index]; + } + } + + // Copy the (row) eigenvectors to the output. The excess is already zeroed out. + // The caller changes them to column eigenvectors. + for (uint32_t row_index = 0; row_index < sub_n; ++row_index) { + // Destination data is top-aligned, NOT distributed + // based on the component. + T*const dest_row = data.vectors.data() + row_index*size_t(n); + const T*const source_row = sub_data.vectors.data() + row_index*size_t(sub_n); + for (uint32_t col_index = 0; col_index < sub_n; ++col_index) { + // Columns ARE distributed based on the component. 
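+                // (`queue` holds the original node indices belonging to this component.)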
+ const uint32_t col = queue[col_index]; + dest_row[col] = source_row[col_index]; + } + } + } +} + +template void compute_laplacian_eigendecomp(const uint32_t n, const uint32_t* row_starts, const uint32_t* neighbors, Normalization normalization, LaplacianData& data, size_t num_components, const std::vector* components, const float* weights); +template void compute_laplacian_eigendecomp(const uint32_t n, const uint32_t* row_starts, const uint32_t* neighbors, Normalization normalization, LaplacianData& data, size_t num_components, const std::vector* components, const double* weights); diff --git a/graphium/graphium_cpp/spectral.h b/graphium/graphium_cpp/spectral.h new file mode 100644 index 000000000..3f3b6e41b --- /dev/null +++ b/graphium/graphium_cpp/spectral.h @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "features.h" + +#include +#include + +template +struct LaplacianData { + Normalization normalization; + + std::vector vectors; + std::vector eigenvalues; + + std::vector matrix_temp; + std::vector eigenvalues_temp; + std::vector order_temp; +}; + +size_t find_components( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + std::vector& components); + +// This outputs the eigenvalues in data.eigenvalues and the eigenvectors in data.vectors +template +void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const std::vector* components, + const T* weights = nullptr); + +extern template void compute_laplacian_eigendecomp(const uint32_t n, const uint32_t* row_starts, const uint32_t* neighbors, Normalization normalization, LaplacianData& data, size_t num_components, const std::vector* components, const float* weights); +extern template void compute_laplacian_eigendecomp(const uint32_t n, const uint32_t* row_starts, const uint32_t* neighbors, Normalization normalization, LaplacianData& data, size_t num_components, const std::vector* components, const double* weights); diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index 44bf7e3df..63ccb3fc6 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -132,7 +132,6 @@ "GraphOGBDataModule": Datamodules.GraphOGBDataModule, "MultitaskFromSmilesDataModule": Datamodules.MultitaskFromSmilesDataModule, "ADMETBenchmarkDataModule": Datamodules.ADMETBenchmarkDataModule, - "FakeDataModule": Datamodules.FakeDataModule, } GRAPHIUM_PRETRAINED_MODELS_DICT = { diff --git a/profiling/configs_profiling.yaml b/profiling/configs_profiling.yaml index 0ff4f6c94..bde4bdb5f 100644 --- a/profiling/configs_profiling.yaml +++ b/profiling/configs_profiling.yaml @@ -11,8 +11,6 @@ datamodule: smiles_col: SMILES # Featurization - featurization_n_jobs: -1 - featurization_progress: True featurization: atom_property_list_onehot: [atomic-number, valence] atom_property_list_float: [mass, electronegativity] diff --git a/profiling/profile_mol_to_graph.py b/profiling/profile_mol_to_graph.py index 423f487cf..e8bf19315 100644 --- a/profiling/profile_mol_to_graph.py +++ b/profiling/profile_mol_to_graph.py @@ -16,7 +16,7 @@ import pickle from graphium.data.utils import load_micro_zinc -from graphium.features.featurizer import mol_to_pyggraph, mol_to_adj_and_features, mol_to_graph_dict +from graphium.features.featurizer import mol_to_pyggraph # Check out 
this profiling tool: https://kirillstrelkov.medium.com/python-profiling-with-vscode-3a17c0407833 @@ -67,10 +67,7 @@ def main(): graphs = [] for s in tqdm(smiles): - mol = dm.to_mol( - s - ) # Doesn't need `ordered=True` because this is just to test the speed of the featurizer - graphs.append(mol_to_graph_dict(mol, **featurizer)) + graphs.append(mol_to_pyggraph(s, **featurizer)) print(graphs[0]) diff --git a/tests/config_test_ipu_dataloader.yaml b/tests/config_test_ipu_dataloader.yaml index f0f55d197..3f63bfd3d 100644 --- a/tests/config_test_ipu_dataloader.yaml +++ b/tests/config_test_ipu_dataloader.yaml @@ -61,9 +61,6 @@ datamodule: weights_type: null # This may not always be provided task_level: graph # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 0 - featurization_progress: True featurization: atom_property_list_onehot: [atomic-number, valence] atom_property_list_float: [mass, electronegativity, in-ring] diff --git a/tests/config_test_ipu_dataloader_multitask.yaml b/tests/config_test_ipu_dataloader_multitask.yaml index 8b8fbf417..563222d8d 100644 --- a/tests/config_test_ipu_dataloader_multitask.yaml +++ b/tests/config_test_ipu_dataloader_multitask.yaml @@ -51,7 +51,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -95,10 +94,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" # processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -133,7 +128,6 @@ datamodule: num_workers: -1 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
- featurization_backend: "loky" architecture: diff --git a/tests/data/config_micro_ZINC.yaml b/tests/data/config_micro_ZINC.yaml index 88fc4a841..d2e94318f 100644 --- a/tests/data/config_micro_ZINC.yaml +++ b/tests/data/config_micro_ZINC.yaml @@ -11,8 +11,6 @@ datamodule: smiles_col: SMILES # Featurization - featurization_n_jobs: -1 - featurization_progress: True featurization: atom_property_list_onehot: [atomic-number, valence] atom_property_list_float: [mass, electronegativity, in-ring] diff --git a/tests/test_collate.py b/tests/test_collate.py index 3cb453b32..6524596d6 100644 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -28,12 +28,12 @@ class test_Collate(ut.TestCase): def test_collate_labels(self): # Create fake labels - labels_size_dict = { - "graph_label1": [1], - "graph_label2": [3], - "node_label2": [5], - "edge_label3": [5, 2], - "node_label4": [5, 1], + labels_num_cols_dict = { + "graph_label1": 1, + "graph_label2": 3, + "node_label2": 1, + "edge_label3": 2, + "node_label4": 1, } labels_dtype_dict = { "graph_label1": torch.float32, @@ -57,9 +57,16 @@ def test_collate_labels(self): pyg_labels[key] = val + 17 * 2 fake_labels.append(pyg_labels) + num_nodes = [g.num_nodes for g in fake_labels] + num_edges = [g.num_edges for g in fake_labels] + # Collate labels and check for the right shapes and dtypes collated_labels = collate_labels( - deepcopy(fake_labels), deepcopy(labels_size_dict), deepcopy(labels_dtype_dict) + deepcopy(fake_labels), + deepcopy(labels_num_cols_dict), + deepcopy(labels_dtype_dict), + num_nodes, + num_edges, ) self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) # , 1 self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) # , 1 @@ -108,15 +115,19 @@ def test_collate_labels(self): label4_true[missing_labels["node_label4"]] = float("nan") # Collate labels and check for the right shapes - labels_size_dict = { - "graph_label1": [1], - "graph_label2": [3], - "node_label2": [5], - "edge_label3": [5, 2], - "node_label4": [5, 1], + labels_num_cols_dict = { + "graph_label1": 1, + "graph_label2": 3, + "node_label2": 1, + "edge_label3": 2, + "node_label4": 1, } collated_labels = collate_labels( - deepcopy(fake_labels), deepcopy(labels_size_dict), deepcopy(labels_dtype_dict) + deepcopy(fake_labels), + deepcopy(labels_num_cols_dict), + deepcopy(labels_dtype_dict), + num_nodes, + num_edges, ) self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) # , 1 self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) # , 1 @@ -138,9 +149,14 @@ def test_collate_labels(self): collated_labels["node_label4"].numpy(), label4_true.flatten(0, 1).numpy() ) # Now test the `graphium_collate_fn` function when only labels are given - fake_labels2 = [{"labels": this_label} for this_label in fake_labels] + fake_labels2 = [ + {"labels": this_label, "num_nodes": this_label.num_nodes, "num_edges": this_label.num_edges} + for this_label in fake_labels + ] collated_labels = graphium_collate_fn( - deepcopy(fake_labels2), labels_size_dict=labels_size_dict, labels_dtype_dict=labels_dtype_dict + deepcopy(fake_labels2), + labels_num_cols_dict=labels_num_cols_dict, + labels_dtype_dict=labels_dtype_dict, )["labels"] self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index a9ed6b045..a2d0c162c 100644 
--- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -22,6 +22,8 @@ from graphium.utils.fs import rm, exists, get_size from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule +import graphium_cpp + TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" @@ -45,23 +47,22 @@ def test_ogb_datamodule(self): task_specific_args = {} task_specific_args["task_1"] = {"task_level": "graph", "dataset_name": dataset_name} dm_args = {} - dm_args["processed_graph_data_path"] = None dm_args["featurization"] = featurization_args dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 0 - dm_args["featurization_progress"] = True - dm_args["featurization_backend"] = "loky" - dm_args["featurization_batch_size"] = 50 - ds = GraphOGBDataModule(task_specific_args, **dm_args) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) - ds.prepare_data(save_smiles_and_ids=False) + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) + + ds.prepare_data() # Check the keys in the dataset - ds.setup(save_smiles_and_ids=False) + ds.setup() assert set(ds.train_ds[0].keys()) == {"features", "labels"} # Delete the cache if already exist @@ -69,13 +70,13 @@ def test_ogb_datamodule(self): rm(TEMP_CACHE_DATA_PATH, recursive=True) # Reset the datamodule - ds = GraphOGBDataModule(task_specific_args, **dm_args) + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) - ds.prepare_data(save_smiles_and_ids=True) + ds.prepare_data() # Check the keys in the dataset - ds.setup(save_smiles_and_ids=True) - assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} + ds.setup() + assert set(ds.train_ds[0].keys()) == {"features", "labels"} # test module assert ds.num_edge_feats == 5 @@ -84,100 +85,7 @@ def test_ogb_datamodule(self): # test batch loader batch = next(iter(ds.train_dataloader())) - assert len(batch["smiles"]) == 16 assert len(batch["labels"]["graph_task_1"]) == 16 - assert len(batch["mol_ids"]) == 16 - - def test_none_filtering(self): - # Create the objects to filter - list_of_num = [ii for ii in range(100)] - list_of_str = [str(ii) for ii in list_of_num] - tuple_of_num = tuple(list_of_num) - array_of_num = np.asarray(list_of_num) - array_of_str = np.asarray(list_of_str) - tensor_of_num = torch.as_tensor(array_of_num) - arrays_of_num = np.stack([list_of_num, list_of_num, list_of_num], axis=1) - arrays_of_str = np.stack([list_of_str, list_of_str, list_of_str], axis=1) - tensors_of_num = torch.as_tensor(arrays_of_num) - dic = {"str": list_of_str, "num": list_of_num} - df = pd.DataFrame(dic) - df_shuffled = df.sample(frac=1) - series_num = df["num"] - series_num_shuffled = df_shuffled["num"] - - # Create different indexes to use for filtering - all_idx_none = [[3, 17, 88], [22, 33, 44, 55, 66, 77, 88], [], np.arange(len(list_of_num))] - - # Loop all the indexes and filter the objects. 
- for ii, idx_none in enumerate(all_idx_none): - msg = f"Failed for ii={ii}" - - # Create the true filtered sequences - filtered_num = [ii for ii in range(100) if ii not in idx_none] - filtered_str = [str(ii) for ii in filtered_num] - assert len(filtered_num) == len(list_of_num) - len(idx_none) - assert len(filtered_str) == len(list_of_str) - len(idx_none) - - # Filter the sequences from the Datamodule function - ( - list_of_num_2, - list_of_str_2, - tuple_of_num_2, - array_of_num_2, - array_of_str_2, - tensor_of_num_2, - df_2, - df_shuffled_2, - dic_2, - arrays_of_num_2, - arrays_of_str_2, - tensors_of_num_2, - series_num_2, - series_num_shuffled_2, - ) = graphium.data.MultitaskFromSmilesDataModule._filter_none_molecules( - idx_none, - list_of_num, - list_of_str, - tuple_of_num, - array_of_num, - array_of_str, - tensor_of_num, - df, - df_shuffled, - dic, - arrays_of_num, - arrays_of_str, - tensors_of_num, - series_num, - series_num_shuffled, - ) - - df_shuffled_2 = df_shuffled_2.sort_values(by="num", axis=0) - series_num_shuffled_2 = series_num_shuffled_2.sort_values(axis=0) - - # Assert the filtering is done correctly - self.assertListEqual(list_of_num_2, filtered_num, msg=msg) - self.assertListEqual(list_of_str_2, filtered_str, msg=msg) - self.assertListEqual(list(tuple_of_num_2), filtered_num, msg=msg) - self.assertListEqual(array_of_num_2.tolist(), filtered_num, msg=msg) - self.assertListEqual(array_of_str_2.tolist(), filtered_str, msg=msg) - self.assertListEqual(tensor_of_num_2.tolist(), filtered_num, msg=msg) - for jj in range(arrays_of_num.shape[1]): - self.assertListEqual(arrays_of_num_2[:, jj].tolist(), filtered_num, msg=msg) - self.assertListEqual(arrays_of_str_2[:, jj].tolist(), filtered_str, msg=msg) - self.assertListEqual(tensors_of_num_2[:, jj].tolist(), filtered_num, msg=msg) - self.assertListEqual(dic_2["num"], filtered_num, msg=msg) - self.assertListEqual(dic_2["str"], filtered_str, msg=msg) - self.assertListEqual(df_2["num"].tolist(), filtered_num, msg=msg) - self.assertListEqual(df_2["str"].tolist(), filtered_str, msg=msg) - self.assertListEqual(series_num_2.tolist(), filtered_num, msg=msg) - - # When the dataframe is shuffled, the lists are different because the filtering - # is done on the row indexes, not the dataframe indexes. 
- bool_to_check = (len(idx_none) == 0) or (len(idx_none) == len(df_shuffled)) - self.assertIs(df_shuffled_2["num"].tolist() == filtered_num, bool_to_check, msg=msg) - self.assertIs(df_shuffled_2["str"].tolist() == filtered_str, bool_to_check, msg=msg) - self.assertIs(series_num_shuffled_2.tolist() == filtered_num, bool_to_check, msg=msg) def test_caching(self): # other datasets are too large to be tested @@ -201,10 +109,6 @@ def test_caching(self): dm_args["batch_size_inference"] = 16 dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 0 - dm_args["featurization_progress"] = True - dm_args["featurization_backend"] = "loky" - dm_args["featurization_batch_size"] = 50 # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): @@ -214,10 +118,10 @@ def test_caching(self): assert not exists(TEMP_CACHE_DATA_PATH) ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) # assert not ds.load_data_from_cache(verbose=False) - ds.prepare_data(save_smiles_and_ids=False) + ds.prepare_data() # Check the keys in the dataset - ds.setup(save_smiles_and_ids=False) + ds.setup() assert set(ds.train_ds[0].keys()) == {"features", "labels"} # ds_batch = next(iter(ds.train_dataloader())) @@ -227,23 +131,9 @@ def test_caching(self): # Test loading cached data assert exists(TEMP_CACHE_DATA_PATH) - cached_ds_from_ram = GraphOGBDataModule( - task_specific_args, - processed_graph_data_path=TEMP_CACHE_DATA_PATH, - dataloading_from="ram", - **dm_args, - ) - cached_ds_from_ram.prepare_data() - cached_ds_from_ram.setup() - cached_train_loader_from_ram = cached_ds_from_ram.get_dataloader( - cached_ds_from_ram.train_ds, shuffle=False, stage="train" - ) - batch_from_ram = next(iter(cached_train_loader_from_ram)) - cached_ds_from_disk = GraphOGBDataModule( task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, - dataloading_from="disk", **dm_args, ) cached_ds_from_disk.prepare_data() @@ -254,59 +144,31 @@ def test_caching(self): batch_from_disk = next(iter(cached_train_loader_from_disk)) # Features are the same - np.testing.assert_array_almost_equal( - batch["features"].edge_index, batch_from_ram["features"].edge_index - ) np.testing.assert_array_almost_equal( batch["features"].edge_index, batch_from_disk["features"].edge_index ) - assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes - np.testing.assert_array_almost_equal( - batch["features"].edge_weight, batch_from_ram["features"].edge_weight - ) np.testing.assert_array_almost_equal( batch["features"].edge_weight, batch_from_disk["features"].edge_weight ) - np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_ram["features"].feat) np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_disk["features"].feat) - np.testing.assert_array_almost_equal( - batch["features"].edge_feat, batch_from_ram["features"].edge_feat - ) np.testing.assert_array_almost_equal( batch["features"].edge_feat, batch_from_disk["features"].edge_feat ) - np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_ram["features"].batch) np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_disk["features"].batch) - np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_ram["features"].ptr) np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_disk["features"].ptr) # Labels are the same - 
np.testing.assert_array_almost_equal( - batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1 - ) np.testing.assert_array_almost_equal( batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1 ) - np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_ram["labels"].x) - np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_disk["labels"].x) - - np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index) - np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index) - - np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_ram["labels"].batch) - np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_disk["labels"].batch) - - np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_ram["labels"].ptr) - np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_disk["labels"].ptr) - # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): rm(TEMP_CACHE_DATA_PATH, recursive=True) @@ -314,10 +176,10 @@ def test_caching(self): # Reset the datamodule ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) - ds.prepare_data(save_smiles_and_ids=True) + ds.prepare_data() - ds.setup(save_smiles_and_ids=True) - assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} + ds.setup() + assert set(ds.train_ds[0].keys()) == {"features", "labels"} # test module assert ds.num_edge_feats == 5 @@ -326,9 +188,7 @@ def test_caching(self): # test batch loader batch = next(iter(ds.train_dataloader())) - assert len(batch["smiles"]) == 16 assert len(batch["labels"]["graph_task_1"]) == 16 - assert len(batch["mol_ids"]) == 16 # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): @@ -369,15 +229,18 @@ def test_datamodule_with_none_molecules(self): bad_smiles = (df["SMILES1"] == "XXX") & (df["SMILES2"] == "XXX") & (df["SMILES3"] == "XXX") num_bad_smiles = sum(bad_smiles) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + # Test the datamodule datamodule = MultitaskFromSmilesDataModule( task_specific_args=task_specific_args, + processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization_args=featurization_args, - featurization_n_jobs=0, - featurization_batch_size=1, ) datamodule.prepare_data() - datamodule.setup(save_smiles_and_ids=True) + datamodule.setup() # Check that the number of molecules is correct smiles = df["SMILES1"].tolist() + df["SMILES2"].tolist() + df["SMILES3"].tolist() @@ -400,33 +263,36 @@ def test_datamodule_with_none_molecules(self): df = df.set_index("idx_smiles") # Convert the smilies from the train_ds to a list, and check the content - train_smiles = [d["smiles"] for d in datamodule.train_ds] + train_smiles = [ + graphium_cpp.extract_string( + datamodule.train_ds.smiles_tensor, datamodule.train_ds.smiles_offsets_tensor, idx + ) + for idx in range(len(datamodule.train_ds)) + ] # Check that the set of smiles are the same - train_smiles_flat = list(set([item for sublist in train_smiles for item in sublist])) + train_smiles_flat = list(set(train_smiles)) train_smiles_flat.sort() index_smiles_filt = list(set([smiles for smiles in index_smiles if smiles != "XXX"])) index_smiles_filt.sort() self.assertListEqual(train_smiles_flat, index_smiles_filt) - # Check that the smiles are correct for each datapoint in the dataset + # Check that the smiles is correct 
for each datapoint in the dataset for smiles in train_smiles: - self.assertEqual(len(set(smiles)), 1) # Check that all smiles are the same - this_smiles = smiles[0] - true_smiles = df.loc[this_smiles][["SMILES1", "SMILES2", "SMILES3"]] - num_true_smiles = sum(true_smiles != "XXX") - self.assertEqual(len(smiles), num_true_smiles) # Check that the number of smiles is correct + assert isinstance(smiles, str) + true_smiles = df.loc[smiles][["SMILES1", "SMILES2", "SMILES3"]] self.assertEqual( - this_smiles, true_smiles[true_smiles != "XXX"].values[0] - ) # Check that the smiles are correct + smiles, true_smiles[true_smiles != "XXX"].values[0] + ) # Check that the smiles is correct # Convert the labels from the train_ds to a dataframe - train_labels = [{task: val[0] for task, val in d["labels"].items()} for d in datamodule.train_ds] + train_labels = [datamodule.train_ds[idx]["labels"] for idx in range(len(datamodule.train_ds))] + train_labels = [{k: v[0].item() for k, v in label} for label in train_labels] train_labels_df = pd.DataFrame(train_labels) train_labels_df = train_labels_df.rename( columns={"graph_task_1": "graph_SA", "graph_task_2": "graph_logp", "graph_task_3": "graph_score"} ) - train_labels_df["smiles"] = [s[0] for s in datamodule.train_ds.smiles] + train_labels_df["smiles"] = train_smiles train_labels_df = train_labels_df.set_index("smiles") train_labels_df = train_labels_df.sort_index() @@ -450,7 +316,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -463,7 +333,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -476,7 +350,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -489,7 +367,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -526,9 +408,13 @@ def test_splits_file(self): } } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) - 
ds.prepare_data(save_smiles_and_ids=True) - ds.setup(save_smiles_and_ids=True) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) + ds.prepare_data() + ds.setup() self.assertEqual(len(ds.train_ds), len(split_train)) self.assertEqual(len(ds.val_ds), len(split_val)) @@ -555,19 +441,30 @@ def test_splits_file(self): } } - ds2 = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) - ds2.prepare_data(save_smiles_and_ids=True) - ds2.setup(save_smiles_and_ids=True) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds2 = MultitaskFromSmilesDataModule( + task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH + ) + ds2.prepare_data() + ds2.setup() self.assertEqual(len(ds2.train_ds), len(split_train)) self.assertEqual(len(ds2.val_ds), len(split_val)) self.assertEqual(len(ds2.test_ds), len(split_test)) # Check that the splits are the same - self.assertEqual(len(ds.train_ds.smiles), len(split_train)) - np.testing.assert_array_equal(ds.train_ds.smiles, ds2.train_ds.smiles) - np.testing.assert_array_equal(ds.val_ds.smiles, ds2.val_ds.smiles) - np.testing.assert_array_equal(ds.test_ds.smiles, ds2.test_ds.smiles) + self.assertEqual(len(ds.train_ds.smiles_offsets_tensor), len(split_train) + 1) + np.testing.assert_array_equal(ds.train_ds.smiles_tensor, ds2.train_ds.smiles_tensor) + np.testing.assert_array_equal(ds.val_ds.smiles_tensor, ds2.val_ds.smiles_tensor) + np.testing.assert_array_equal(ds.test_ds.smiles_tensor, ds2.test_ds.smiles_tensor) + np.testing.assert_array_equal( + ds.train_ds.smiles_offsets_tensor, ds2.train_ds.smiles_offsets_tensor + ) + np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor) + np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor) if __name__ == "__main__": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 0a377c35a..feb83a10a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -15,10 +15,95 @@ import unittest as ut from graphium.data import load_micro_zinc -from graphium.data.dataset import SingleTaskDataset, MultitaskDataset +from graphium.data.datamodule import MultitaskFromSmilesDataModule +from graphium.data.dataset import MultitaskDataset +from graphium.features import mol_to_pyggraph from graphium.data.smiles_transform import smiles_to_unique_mol_ids from graphium.data.utils import get_keys +import graphium_cpp + +import numpy as np +import os.path as osp + +TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" + + +def dataframes_to_dataset(dataframes_dict, case_num): + task_names = [key for key in dataframes_dict.keys()] + + task_dataset_args = {} + task_train_indices = {} + task_val_indices = {} + task_test_indices = {} + for task in task_names: + ( + smiles, + labels, + label_offsets, + sample_idx, + extras, + ) = MultitaskFromSmilesDataModule._extract_smiles_labels( + df=dataframes_dict[task], + task_level="graph", + smiles_col="SMILES", + label_cols=task, + idx_col=None, + weights_col=None, + weights_type=None, + ) + num_molecules = len(smiles) + task_dataset_args[task] = { + "smiles": smiles, + "labels": labels, + "label_offsets": label_offsets, + "extras": extras, + } + + task_train_indices[task] = np.arange(num_molecules).tolist() + task_val_indices[task] = [] + 
task_test_indices[task] = [] + + fake_data_hash = "a1b2c3testdataset" + str(case_num) + + # The rest of the data preparation and caching is done in graphium_cpp.prepare_and_save_data + normalizations = {task: {} for task in task_names} # No normalization + stage_data, all_stats, label_num_cols, label_dtypes = graphium_cpp.prepare_and_save_data( + task_names, + task_dataset_args, + normalizations, + TEMP_CACHE_DATA_PATH, + fake_data_hash, + task_train_indices, + task_val_indices, + task_test_indices, + False, # add_self_loop + False, # explicit_H + 0, # preprocessing_n_jobs + ) + + stage_data = stage_data["train"] + + data_offsets = None + if MultitaskFromSmilesDataModule.data_offsets_tensor_index() < len(stage_data): + data_offsets = stage_data[MultitaskFromSmilesDataModule.data_offsets_tensor_index()] + + multitask_dataset = MultitaskDataset( + about="test_dataset case" + str(case_num), + data_path=osp.join(TEMP_CACHE_DATA_PATH, "train_" + fake_data_hash), + featurize_smiles=mol_to_pyggraph, + task_names=task_names, + label_num_cols=label_num_cols, + label_dtypes=label_dtypes, + mol_file_data_offsets=data_offsets, + concat_smiles_tensor=stage_data[MultitaskFromSmilesDataModule.concat_smiles_tensor_index()], + smiles_offsets_tensor=stage_data[MultitaskFromSmilesDataModule.smiles_offsets_tensor_index()], + num_nodes_tensor=stage_data[MultitaskFromSmilesDataModule.num_nodes_tensor_index()], + num_edges_tensor=stage_data[MultitaskFromSmilesDataModule.num_edges_tensor_index()], + ) + + return multitask_dataset + class test_Multitask_Dataset(ut.TestCase): # Then we can choose different rows and columns for the tests as we see fit. @@ -42,50 +127,44 @@ def test_multitask_dataset_case_1(self): df_micro_zinc_logp = df[["SMILES", "logp"]] df_micro_zinc_score = df[["SMILES", "score"]] - # We need to turn these dataframes into single-task datasets. + # We need to prepare the data for these dataframes. # We don't need to do featurization yet. - ds_micro_zinc_SA = SingleTaskDataset( - smiles=df_micro_zinc_SA.loc[:, "SMILES"].tolist(), labels=df_micro_zinc_SA.loc[:, "SA"].tolist() - ) - - ds_micro_zinc_logp = SingleTaskDataset( - smiles=df_micro_zinc_logp.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_logp.loc[:, "logp"].tolist(), - ) - ds_micro_zinc_score = SingleTaskDataset( - smiles=df_micro_zinc_score.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_score.loc[:, "score"].tolist(), - ) - - # Create the multitask dataset - datasets_dict = {"SA": ds_micro_zinc_SA, "logp": ds_micro_zinc_logp, "score": ds_micro_zinc_score} - multitask_microzinc = MultitaskDataset( - datasets_dict, save_smiles_and_ids=True - ) # Can optionally have features + dataframes = { + "SA": df_micro_zinc_SA, + "logp": df_micro_zinc_logp, + "score": df_micro_zinc_score, + } + multitask_dataset = dataframes_to_dataset(dataframes, 1) # Check: The number of unique molecules equals the number of datapoints in the multitask dataset. - self.assertEqual(num_unique_mols, multitask_microzinc.__len__()) + self.assertEqual(num_unique_mols, multitask_dataset.__len__()) # Check that for each task, you have the same label values as the initial DF. 
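The smiles_tensor / smiles_offsets_tensor pair read below through graphium_cpp.extract_string packs every SMILES string of a split into one byte tensor, with an offsets tensor of length N+1 marking where each string starts (the splits test later asserts exactly that N+1 length). A self-contained sketch of that layout in plain torch, independent of the actual graphium_cpp implementation:

    import torch

    smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]  # hypothetical molecules

    # Concatenate the UTF-8 bytes of all strings; offsets[i] is where string i starts,
    # so offsets has len(smiles_list) + 1 entries and offsets[-1] is the total byte count.
    data = b"".join(s.encode("utf-8") for s in smiles_list)
    smiles_tensor = torch.tensor(list(data), dtype=torch.uint8)
    offsets = [0]
    for s in smiles_list:
        offsets.append(offsets[-1] + len(s.encode("utf-8")))
    smiles_offsets_tensor = torch.tensor(offsets, dtype=torch.int64)

    def extract_string(smiles_tensor, smiles_offsets_tensor, idx):
        # Mirrors how the tests use graphium_cpp.extract_string: slice out one string.
        begin = smiles_offsets_tensor[idx].item()
        end = smiles_offsets_tensor[idx + 1].item()
        return bytes(smiles_tensor[begin:end].tolist()).decode("utf-8")

    assert extract_string(smiles_tensor, smiles_offsets_tensor, 1) == "c1ccccc1"
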
- for idx in range(multitask_microzinc.__len__()): + for idx in range(multitask_dataset.__len__()): smiles = df[["SMILES"]].iloc[idx].values[0] - # label = df[['SA']].iloc[idx] - label_SA = ds_micro_zinc_SA.labels[idx] - label_logp = ds_micro_zinc_logp.labels[idx] - label_score = ds_micro_zinc_score.labels[idx] - - # Search for the mol id in the multitask dataset - mol_ids = smiles_to_unique_mol_ids([smiles]) - mol_id = mol_ids[0] + + label_SA = df_micro_zinc_SA["SA"][idx] + label_logp = df_micro_zinc_logp["logp"][idx] + label_score = df_micro_zinc_score["score"][idx] + + # Search for the smiles string in the multitask dataset found_idx = -1 - for i, id in enumerate(multitask_microzinc.mol_ids): - if mol_id == id: + for i in range(multitask_dataset.__len__()): + if ( + graphium_cpp.extract_string( + multitask_dataset.smiles_tensor, multitask_dataset.smiles_offsets_tensor, i + ) + == smiles + ): found_idx = i + break + + item = multitask_dataset[found_idx]["labels"] # Compare labels - self.assertEqual(label_SA, multitask_microzinc.labels[found_idx]["SA"]) - self.assertEqual(label_logp, multitask_microzinc.labels[found_idx]["logp"]) - self.assertEqual(label_score, multitask_microzinc.labels[found_idx]["score"]) + self.assertEqual(label_SA, item["SA"]) + self.assertEqual(label_logp, item["logp"]) + self.assertEqual(label_score, item["score"]) def test_multitask_dataset_case_2(self): """Case: Different tasks, but with no intersection in the smiles (each task has a unique set of smiles) @@ -100,36 +179,18 @@ def test_multitask_dataset_case_2(self): df_rows_score = df.iloc[400:750] # 350 data points total_data_points = 750 - # Here we split the data according to the task we care about. - df_micro_zinc_SA = df_rows_SA[["SMILES", "SA"]] - df_micro_zinc_logp = df_rows_logp[["SMILES", "logp"]] - df_micro_zinc_score = df_rows_score[["SMILES", "score"]] - - # We need to turn these dataframes into single-task datasets. - # We don't need to do featurization yet. - ds_micro_zinc_SA = SingleTaskDataset( - smiles=df_micro_zinc_SA.loc[:, "SMILES"].tolist(), labels=df_micro_zinc_SA.loc[:, "SA"].tolist() - ) - ds_micro_zinc_logp = SingleTaskDataset( - smiles=df_micro_zinc_logp.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_logp.loc[:, "logp"].tolist(), - ) - ds_micro_zinc_score = SingleTaskDataset( - smiles=df_micro_zinc_score.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_score.loc[:, "score"].tolist(), - ) - - # Create the multitask dataset - datasets_dict = {"SA": ds_micro_zinc_SA, "logp": ds_micro_zinc_logp, "score": ds_micro_zinc_score} - multitask_microzinc = MultitaskDataset( - datasets_dict, save_smiles_and_ids=True - ) # Can optionally have features + dataframes = { + "SA": df_rows_SA, + "logp": df_rows_logp, + "score": df_rows_score, + } + multitask_microzinc = dataframes_to_dataset(dataframes, 2) # The total dataset has as many molecules as there are smiles in all tasks put together self.assertEqual(total_data_points, multitask_microzinc.__len__()) # For each task, only the smiles related to that task have values, and the value is what's expected from the initial DF. 
- for idx in range(len(ds_micro_zinc_SA)): + for idx in range(len(multitask_microzinc)): smiles = df[["SMILES"]].iloc[idx].values[0] task = "task" @@ -141,28 +202,33 @@ def test_multitask_dataset_case_2(self): task = "score" # Labels of that molecule - label_SA = df[["SA"]].iloc[idx].values[0] - label_logp = df[["logp"]].iloc[idx].values[0] - label_score = df[["score"]].iloc[idx].values[0] + label_df = df[[task]].iloc[idx].values[0] - # Search for that molecule in the multitask dataset - mol_ids = smiles_to_unique_mol_ids([smiles]) - mol_id = mol_ids[0] + # Search for the smiles string in the multitask dataset found_idx = -1 - for i, id in enumerate(multitask_microzinc.mol_ids): - if mol_id == id: + for i in range(multitask_microzinc.__len__()): + if ( + graphium_cpp.extract_string( + multitask_microzinc.smiles_tensor, multitask_microzinc.smiles_offsets_tensor, i + ) + == smiles + ): found_idx = i - multitask_microzinc_labels = get_keys(multitask_microzinc.labels[found_idx]) + break + + item = multitask_microzinc[found_idx]["labels"] + multitask_microzinc_labels = item.keys() + + assert task in multitask_microzinc_labels + self.assertEqual(label_df, item[task]) + if task == "SA": - self.assertEqual(label_SA, multitask_microzinc.labels[found_idx]["SA"]) self.assertFalse("score" in multitask_microzinc_labels) self.assertFalse("logp" in multitask_microzinc_labels) elif task == "logp": - self.assertEqual(label_logp, multitask_microzinc.labels[found_idx]["logp"]) self.assertFalse("score" in multitask_microzinc_labels) self.assertFalse("SA" in multitask_microzinc_labels) elif task == "score": - self.assertEqual(label_score, multitask_microzinc.labels[found_idx]["score"]) self.assertFalse("SA" in multitask_microzinc_labels) self.assertFalse("logp" in multitask_microzinc_labels) @@ -180,30 +246,12 @@ def test_multitask_dataset_case_3(self): df_rows_score = df.iloc[3:5] total_data_points = 5 - # Here we split the data according to the task we care about. - df_micro_zinc_SA = df_rows_SA[["SMILES", "SA"]] - df_micro_zinc_logp = df_rows_logp[["SMILES", "logp"]] - df_micro_zinc_score = df_rows_score[["SMILES", "score"]] - - # We need to turn these dataframes into single-task datasets. - # We don't need to do featurization yet. - ds_micro_zinc_SA = SingleTaskDataset( - smiles=df_micro_zinc_SA.loc[:, "SMILES"].tolist(), labels=df_micro_zinc_SA.loc[:, "SA"].tolist() - ) - ds_micro_zinc_logp = SingleTaskDataset( - smiles=df_micro_zinc_logp.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_logp.loc[:, "logp"].tolist(), - ) - ds_micro_zinc_score = SingleTaskDataset( - smiles=df_micro_zinc_score.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_score.loc[:, "score"].tolist(), - ) - - # Create the multitask dataset - datasets_dict = {"SA": ds_micro_zinc_SA, "logp": ds_micro_zinc_logp, "score": ds_micro_zinc_score} - multitask_microzinc = MultitaskDataset( - datasets_dict, save_smiles_and_ids=True - ) # Can optionally have features + dataframes = { + "SA": df_rows_SA, + "logp": df_rows_logp, + "score": df_rows_score, + } + multitask_microzinc = dataframes_to_dataset(dataframes, 3) # The multitask dataset has as many molecules as there are unique smiles across the single task datasets. 
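Cases 1-3 above all reduce to the same counting rule, checked next: the merged multitask dataset has one entry per unique SMILES across the tasks, so disjoint task sets add up while overlapping ones collapse. A tiny illustration with hypothetical SMILES, not drawn from the micro-ZINC data used by the tests:

    # Hypothetical task-to-SMILES mapping.
    task_smiles = {
        "SA":    ["CCO", "CCN", "CCC"],
        "logp":  ["CCO", "c1ccccc1"],   # "CCO" overlaps with the SA task
        "score": ["CCN"],               # "CCN" overlaps with the SA task
    }
    unique_molecules = set(s for smiles in task_smiles.values() for s in smiles)
    expected_dataset_len = len(unique_molecules)  # 4 entries, not 6
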
self.assertEqual(total_data_points, multitask_microzinc.__len__()) diff --git a/tests/test_featurizer.py b/tests/test_featurizer.py index e8f666365..3336feae3 100644 --- a/tests/test_featurizer.py +++ b/tests/test_featurizer.py @@ -22,13 +22,9 @@ from rdkit import Chem import datamol as dm -from graphium.features.featurizer import ( - get_mol_atomic_features_onehot, - get_mol_atomic_features_float, - get_mol_edge_features, - mol_to_adj_and_features, - mol_to_pyggraph, -) +from graphium.features.featurizer import mol_to_pyggraph + +import graphium_cpp class test_featurizer(ut.TestCase): @@ -99,155 +95,120 @@ class test_featurizer(ut.TestCase): def test_get_mol_atomic_features_onehot(self): props = deepcopy(self.atomic_onehot_props) - bad_props = ["bob"] + # bad_props = ["bob"] all_smiles = self.smiles + self.smiles_noble - for s in all_smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) + for mol in all_smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + + rdmol = dm.to_mol(mol) for ii in range(len(props)): this_props = props[:ii] err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_atomic_features_onehot(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertEqual(val.shape[0], mol.GetNumAtoms(), msg=err_msg3) - self.assertGreater(val.shape[1], 1, msg=err_msg3) - self.assertTrue(np.all((val == 0) | (val == 1)), msg=err_msg3) + this_props_encoded = graphium_cpp.atom_onehot_feature_names_to_tensor(this_props) + features = mol_to_pyggraph(mol, atom_property_list_onehot=this_props_encoded, mask_nan=None) + val = features["feat"] + self.assertEqual(val.size(0), rdmol.GetNumAtoms(), msg=err_msg2) + self.assertGreaterEqual(val.size(1), 2 * len(this_props), msg=err_msg2) + self.assertTrue(((val == 0) | (val == 1)).numpy().all(), msg=err_msg2) - with self.assertRaises(ValueError, msg=err_msg): - get_mol_atomic_features_onehot(mol, property_list=bad_props) + # with self.assertRaises(ValueError, msg=err_msg): + # get_mol_atomic_features_onehot(mol, property_list=bad_props) def test_get_mol_atomic_features_float(self): props = deepcopy(self.atomic_float_props) - bad_props = ["bob"] + # bad_props = ["bob"] all_smiles = self.smiles + self.smiles_noble - for s in all_smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) + for mol in all_smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + rdmol = dm.to_mol(mol) for ii in range(len(props)): this_props = props[:ii] err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_atomic_features_float(mol, property_list=this_props, mask_nan=None) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertListEqual(list(val.shape), [mol.GetNumAtoms()], msg=err_msg3) + this_props_encoded = graphium_cpp.atom_float_feature_names_to_tensor(this_props) + features = mol_to_pyggraph(mol, atom_property_list_float=this_props_encoded, mask_nan=None) + val = features["feat"] + self.assertEqual(val.size(0), rdmol.GetNumAtoms(), msg=err_msg2) + self.assertEqual(val.size(1), len(this_props), msg=err_msg2) - with self.assertRaises(ValueError, msg=err_msg): - get_mol_atomic_features_float(mol, property_list=bad_props) + # with self.assertRaises(ValueError, msg=err_msg): + # get_mol_atomic_features_float(mol, 
property_list=bad_props) def test_get_mol_atomic_features_float_nan_mask(self): - for s in self.smiles_noble: - mol = dm.to_mol(s) - + props_encoded = graphium_cpp.atom_float_feature_names_to_tensor(self.atomic_float_props) + for mol in self.smiles_noble: # Nothing happens when `mask_nan = None`, nans are still in the property array - prop_dict = get_mol_atomic_features_float( - mol, property_list=self.atomic_float_props, mask_nan=None + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan=None, on_error="raise" ) - prop_array = np.concatenate(list(prop_dict.values()), axis=0) + prop_array = features["feat"] nans = np.isnan(prop_array) # Capture a raised error when `mask_nan = "raise"` with self.assertRaises(ValueError): - prop_dict = get_mol_atomic_features_float( - mol, property_list=self.atomic_float_props, mask_nan="raise" + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan="raise", on_error="raise" ) + print(f"Failed to raise error for nans on {mol}") # Not sure how to Capture a logged warning when `mask_nan = "warn"` # Here, I'm testing a behaviour similar to `mask_nan = None` - prop_dict = get_mol_atomic_features_float( - mol, property_list=self.atomic_float_props, mask_nan="warn" + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan="warn", on_error="raise" ) - prop_array = np.concatenate(list(prop_dict.values()), axis=0) - self.assertEqual(len(self.atomic_float_props), len(prop_dict)) - self.assertTrue(any(np.isnan(prop_array))) + prop_array = features["feat"] + self.assertEqual(len(self.atomic_float_props), prop_array.size(1)) + self.assertTrue(np.isnan(prop_array.numpy()).any()) # NaNs are replaced by `42` when `mask_nan=42` - prop_dict = get_mol_atomic_features_float(mol, property_list=self.atomic_float_props, mask_nan=42) - prop_array = np.concatenate(list(prop_dict.values()), axis=0) - self.assertEqual(len(self.atomic_float_props), len(prop_dict)) - self.assertFalse(any(np.isnan(prop_array))) - self.assertTrue(all(prop_array[nans] == 42)) + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan=42, on_error="raise" + ) + prop_array = features["feat"] + self.assertEqual(len(self.atomic_float_props), prop_array.size(1)) + self.assertFalse(np.isnan(prop_array.numpy()).any()) + self.assertTrue((prop_array[nans] == 42).all()) def test_get_mol_edge_features(self): props = deepcopy(self.edge_props) - bad_props = ["bob"] + # bad_props = ["bob"] all_smiles = self.smiles + self.smiles_noble - for s in all_smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) + for mol in all_smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + rdmol = dm.to_mol(mol) for ii in range(len(props)): this_props = props[: ii + 1] err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_edge_features(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertEqual(val.shape[0], mol.GetNumBonds(), msg=err_msg3) - - if mol.GetNumBonds() > 0: - with self.assertRaises(ValueError, msg=err_msg): - get_mol_edge_features(mol, property_list=bad_props) - - def test_mol_to_adj_and_features(self): - np.random.seed(42) + this_props_encoded = graphium_cpp.bond_feature_names_to_tensor(this_props) + features = mol_to_pyggraph(mol, edge_property_list=this_props_encoded, mask_nan=None) + val = 
features["edge_feat"] + self.assertEqual(val.shape[0], 2 * rdmol.GetNumBonds(), msg=err_msg2) + if rdmol.GetNumBonds() > 0: + self.assertGreaterEqual(val.shape[1], len(this_props), msg=err_msg2) - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - mol_Hs = Chem.AddHs(mol) # type: ignore - mol_No_Hs = Chem.RemoveHs(mol) # type: ignore - - for explicit_H in [True, False]: - this_mol = mol_Hs if explicit_H else mol_No_Hs - for ii in np.arange(0, 5, 0.2): - num_props = int(round(ii)) - err_msg2 = err_msg + f"\n\t\texplicit_H: {explicit_H}\n\t\tii: {ii}" - - adj, ndata, edata, _, _ = mol_to_adj_and_features( - mol=mol, - atom_property_list_onehot=np.random.choice( - self.atomic_onehot_props, size=num_props, replace=False - ), - atom_property_list_float=np.random.choice( - self.atomic_float_props, size=num_props, replace=False - ), - edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), - add_self_loop=False, - explicit_H=explicit_H, - use_bonds_weights=False, - ) - - self.assertEqual(adj.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - if num_props > 0: - self.assertEqual(ndata.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - if this_mol.GetNumBonds() > 0: - self.assertEqual(edata.shape[0], this_mol.GetNumBonds(), msg=err_msg2) - self.assertGreaterEqual(edata.shape[1], num_props, msg=err_msg2) - self.assertGreaterEqual(ndata.shape[1], num_props, msg=err_msg2) + # if mol.GetNumBonds() > 0: + # with self.assertRaises(ValueError, msg=err_msg): + # get_mol_edge_features(mol, property_list=bad_props) def test_mol_to_pyggraph(self): np.random.seed(42) + single_atom_prop_encoded = graphium_cpp.atom_float_feature_names_to_tensor(["atomic-number"]) + single_bond_prop_encoded = graphium_cpp.bond_feature_names_to_tensor(["bond-type-float"]) - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - mol_Hs = Chem.AddHs(mol) # type: ignore - mol_No_Hs = Chem.RemoveHs(mol) # type: ignore + for mol in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + rdmol = dm.to_mol(mol) graph = mol_to_pyggraph( mol=mol, - atom_property_list_onehot=[], - atom_property_list_float=["atomic-number"], - edge_property_list=["bond-type-float"], + atom_property_list_float=single_atom_prop_encoded, + edge_property_list=single_bond_prop_encoded, add_self_loop=False, explicit_H=False, use_bonds_weights=False, @@ -255,29 +216,32 @@ def test_mol_to_pyggraph(self): ) # Check the number of nodes and edges - self.assertListEqual(list(graph["feat"].shape), [mol.GetNumAtoms(), 1], msg=err_msg) - self.assertListEqual(list(graph["edge_feat"].shape), [2 * mol.GetNumBonds(), 1], msg=err_msg) + self.assertListEqual(list(graph["feat"].shape), [rdmol.GetNumAtoms(), 1], msg=err_msg) + self.assertListEqual(list(graph["edge_feat"].shape), [2 * rdmol.GetNumBonds(), 1], msg=err_msg) # Check the node features feat = graph["feat"].to_dense().numpy() * 5 + 6 # Undo the scaling - atom_nums = np.asarray([atom.GetAtomicNum() for atom in mol.GetAtoms()]) + atom_nums = np.asarray([atom.GetAtomicNum() for atom in rdmol.GetAtoms()]) np.testing.assert_array_almost_equal(feat[:, 0], atom_nums, decimal=5, err_msg=err_msg) # Check the edge features edge_feat = graph["edge_feat"].to_dense().numpy() - bond_types = np.asarray([bond.GetBondTypeAsDouble() for bond in mol.GetBonds()]).repeat(2) + bond_types = np.asarray([bond.GetBondTypeAsDouble() for bond in rdmol.GetBonds()]).repeat(2) 
np.testing.assert_array_almost_equal(edge_feat[:, 0], bond_types, decimal=5, err_msg=err_msg) # Check the edge indices - if mol.GetNumBonds() > 0: + if rdmol.GetNumBonds() > 0: edge_index = graph["edge_index"].to_dense().numpy() true_edge_index = [] - for bond in mol.GetBonds(): + for bond in rdmol.GetBonds(): true_edge_index.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) true_edge_index.append([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) true_edge_index = np.asarray(true_edge_index).T np.testing.assert_array_equal(edge_index, true_edge_index, err_msg=err_msg) + mol_Hs = Chem.AddHs(rdmol) # type: ignore + mol_No_Hs = Chem.RemoveHs(rdmol) # type: ignore + # Loop over many possible combinations of properties for explicit_H in [True, False]: this_mol = mol_Hs if explicit_H else mol_No_Hs @@ -287,13 +251,15 @@ def test_mol_to_pyggraph(self): graph = mol_to_pyggraph( mol=mol, - atom_property_list_onehot=np.random.choice( - self.atomic_onehot_props, size=num_props, replace=False + atom_property_list_onehot=graphium_cpp.atom_onehot_feature_names_to_tensor( + np.random.choice(self.atomic_onehot_props, size=num_props, replace=False) + ), + atom_property_list_float=graphium_cpp.atom_float_feature_names_to_tensor( + np.random.choice(self.atomic_float_props, size=num_props, replace=False) ), - atom_property_list_float=np.random.choice( - self.atomic_float_props, size=num_props, replace=False + edge_property_list=graphium_cpp.bond_feature_names_to_tensor( + np.random.choice(self.edge_props, size=num_props, replace=False) ), - edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), add_self_loop=False, explicit_H=explicit_H, use_bonds_weights=False, diff --git a/tests/test_multitask_datamodule.py b/tests/test_multitask_datamodule.py index 81b5188df..81a51459f 100644 --- a/tests/test_multitask_datamodule.py +++ b/tests/test_multitask_datamodule.py @@ -22,6 +22,8 @@ import numpy as np import graphium +TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" + class test_Multitask_DataModule(ut.TestCase): def setUp(self): @@ -109,12 +111,9 @@ def test_multitask_fromsmiles_dm( # Task-independent arguments dm_args["featurization"] = featurization_args - dm_args["featurization_n_jobs"] = 16 - dm_args["featurization_progress"] = True - dm_args["featurization_backend"] = "loky" dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["processed_graph_data_path"] = None + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 @@ -175,6 +174,8 @@ def test_multitask_fromsmiles_from_config(self): dm_args["task_specific_args"]["logp"]["df_path"] = None dm_args["task_specific_args"]["score"]["df_path"] = None + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH + dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) # assert dm.num_node_feats == 50 @@ -205,6 +206,7 @@ def test_multitask_fromsmiles_from_config_csv(self): config = graphium.load_config(name="zinc_default_multitask_pyg") dm_args = OmegaConf.to_container(config.datamodule.args, resolve=True) + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) dm.prepare_data() @@ -232,6 +234,7 @@ def test_multitask_fromsmiles_from_config_parquet(self): config = graphium.load_config(name="fake_multilevel_multitask_pyg") dm_args = OmegaConf.to_container(config.datamodule.args, resolve=True) + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm = 
graphium.data.MultitaskFromSmilesDataModule(**dm_args) dm.prepare_data() @@ -260,6 +263,7 @@ def test_multitask_with_missing_fromsmiles_from_config_parquet(self): config = graphium.load_config(name="fake_and_missing_multilevel_multitask_pyg") dm_args = OmegaConf.to_container(config.datamodule.args, resolve=True) + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) dm.prepare_data() @@ -288,23 +292,25 @@ def test_extract_graph_level_singletask(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") num_graphs = len(df) label_cols = ["graph_label"] - output = graphium.data.datamodule.extract_labels(df, "graph", label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, "graph", label_cols) assert isinstance(output, np.ndarray) assert len(output.shape) == 2 assert output.shape[0] == num_graphs assert output.shape[1] == 1 + assert output_offsets is None def test_extract_graph_level_multitask(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") num_graphs = len(df) label_cols = ["graph_label", "graph_label"] - output = graphium.data.datamodule.extract_labels(df, "graph", label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, "graph", label_cols) assert isinstance(output, np.ndarray) assert len(output.shape) == 2 assert output.shape[0] == num_graphs assert output.shape[1] == len(label_cols) + assert output_offsets is None def test_extract_graph_level_multitask_missing_cols(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") @@ -316,7 +322,7 @@ def test_extract_graph_level_multitask_missing_cols(self): for missing_col in label_cols[:replace]: df[missing_col].iloc[drop_index] = None - output = graphium.data.datamodule.extract_labels(df, "graph", label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, "graph", label_cols) assert isinstance(output, np.ndarray) assert len(output.shape) == 2 @@ -325,17 +331,24 @@ def test_extract_graph_level_multitask_missing_cols(self): def test_non_graph_level_extract_labels(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") + num_graphs = len(df) for level in ["node", "edge", "nodepair"]: label_cols = [f"{level}_label_{suffix}" for suffix in ["list", "np"]] - output = graphium.data.datamodule.extract_labels(df, level, label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, level, label_cols) - assert isinstance(output, list) - assert len(output[0].shape) == 2 - assert output[0].shape[1] == len(label_cols) + assert isinstance(output, np.ndarray) + assert len(output.shape) == 2 + assert output.shape[1] == len(label_cols) + assert output_offsets is not None + assert isinstance(output_offsets, np.ndarray) + assert len(output_offsets.shape) == 1 + assert output_offsets.shape[0] == (num_graphs + 1) + assert output.shape[0] == output_offsets[-1] def test_non_graph_level_extract_labels_missing_cols(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") + num_graphs = len(df) for level in ["node", "edge", "nodepair"]: label_cols = [f"{level}_label_{suffix}" for suffix in ["list", "np"]] @@ -344,16 +357,28 @@ def test_non_graph_level_extract_labels_missing_cols(self): for missing_col in label_cols[:replace]: df.loc[drop_index, missing_col] = None - output = graphium.data.datamodule.extract_labels(df, level, label_cols) + output, output_offsets = 
graphium.data.datamodule.extract_labels(df, level, label_cols) - for idx in drop_index: - assert len(output[idx].shape) == 2 - assert output[idx].shape[1] == len(label_cols) + assert isinstance(output, np.ndarray) + assert len(output.shape) == 2 + assert output.shape[1] == len(label_cols) + assert output_offsets is not None + assert isinstance(output_offsets, np.ndarray) + assert len(output_offsets.shape) == 1 + assert output_offsets.shape[0] == (num_graphs + 1) + assert output.shape[0] == output_offsets[-1] - # Check that number of labels is adjusted correctly - if replace == 1: - non_missing_col = label_cols[1] - assert output[idx].shape[0] == len(df[non_missing_col][idx]) + for idx in drop_index: + begin_idx = output_offsets[idx] + end_idx = output_offsets[idx + 1] + values = output[begin_idx:end_idx] + assert len(values.shape) == 2 + assert values.shape[1] == len(label_cols) + + # All removed entries must be nan + assert np.all(np.isnan(values[:, :replace])) + # All kept entries should be non-nan in this case + assert not np.any(np.isnan(values[:, replace:])) def test_tdc_admet_benchmark_data_module(self): """ diff --git a/tests/test_pe_nodepair.py b/tests/test_pe_nodepair.py index f90ce728b..b849b28b3 100644 --- a/tests/test_pe_nodepair.py +++ b/tests/test_pe_nodepair.py @@ -1,88 +1,113 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. 
-------------------------------------------------------------------------------- """ """ -Unit tests for the positional encodings in graphium/features/* +Unit tests for the positional encodings in graphium/graphium_cpp/*.cpp """ import numpy as np -import networkx as nx +import torch import unittest as ut -from graphium.features.electrostatic import compute_electrostatic_interactions -from graphium.features.commute import compute_commute_distances -from graphium.features.graphormer import compute_graphormer_distances +import graphium +import graphium_cpp class test_positional_encodings(ut.TestCase): # Test graphs - adj_dict = {} + smiles_dict = {} + shape_dict = {} max_dict = {} # 6-ring - adj = np.asarray( - [ - [0, 1, 0, 0, 0, 1], - [1, 0, 1, 0, 0, 0], - [0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1], - [1, 0, 0, 0, 1, 0], - ] - ) - adj_dict["6-ring"] = adj + smiles = "C1CCCCC1" + smiles_dict["6-ring"] = smiles + shape_dict["6-ring"] = [6, 6] max_dict["6-ring"] = 3 # 5-path - G = nx.path_graph(5) - adj = nx.to_numpy_array(G) - adj_dict["5-path"] = adj + smiles = "CCCCC" + smiles_dict["5-path"] = smiles + shape_dict["5-path"] = [5, 5] max_dict["5-path"] = 4 # 4-clique - adj = 1 - np.eye(4) - adj_dict["4-clique"] = adj + smiles = "C12C3C1C23" + smiles_dict["4-clique"] = smiles + shape_dict["4-clique"] = [4, 4] max_dict["4-clique"] = 1 # 4-barbell - H = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(H) - adj_dict["4-barbell"] = adj + smiles = "C12C3C1C23C12C3C1C23" + smiles_dict["4-barbell"] = smiles + shape_dict["4-barbell"] = [8, 8] max_dict["4-barbell"] = 3 + features = { + "electrostatic": {"pos_level": "nodepair", "pos_type": "electrostatic", "normalization": "none"}, + "graphormer": {"pos_level": "nodepair", "pos_type": "graphormer", "normalization": "none"}, + "commute": {"pos_level": "nodepair", "pos_type": "commute", "normalization": "none"}, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor(features) + + def get_tensors(self, smiles): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + self.pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors + def test_dimensions(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_electrostatic_interactions(adj, cache={}) - self.assertEqual(pe.shape, adj.shape) + for key, smiles in self.smiles_dict.items(): + tensors = self.get_tensors(smiles) + + pe = tensors[4] # electrostatic + self.assertEqual(list(pe.shape), self.shape_dict[key]) - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) + pe = tensors[5] # graphormer + self.assertEqual(list(pe.shape), self.shape_dict[key]) - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) + pe = tensors[6] # commute + self.assertEqual(list(pe.shape), self.shape_dict[key]) def test_symmetry(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) + for _, smiles in self.smiles_dict.items(): + tensors = self.get_tensors(smiles) + + pe = 
tensors[5] # graphormer np.testing.assert_array_almost_equal(pe, pe.T) - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) + pe = tensors[6] # commute np.testing.assert_array_almost_equal(pe, pe.T) def test_max_dist(self): - for key, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) + for key, smiles in self.smiles_dict.items(): + tensors = self.get_tensors(smiles) + + pe = tensors[5] # graphormer np.testing.assert_array_almost_equal(pe.max(), self.max_dict[key]) diff --git a/tests/test_pe_rw.py b/tests/test_pe_rw.py index 938df28da..aebd6a577 100644 --- a/tests/test_pe_rw.py +++ b/tests/test_pe_rw.py @@ -1,53 +1,86 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ """ -Unit tests for the positional encodings in graphium/features/* +Unit tests for the positional encodings in graphium/features/random_walk.cpp """ import numpy as np -import networkx as nx +import torch import unittest as ut -from graphium.features.rw import compute_rwse +import graphium +import graphium_cpp class test_pe_spectral(ut.TestCase): - def test_caching_and_outputs(self): + def test_outputs(self): # 4-barbell - G = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(G) - num_nodes = adj.shape[0] - cache = {} + smiles = "C12C3C1C23C12C3C1C23" + num_nodes = 8 ksteps1 = [4, 6] ksteps2 = [2] ksteps3 = [6, 7] - pe1, _, cache = compute_rwse( - adj.astype(np.float32), ksteps1, num_nodes, cache, pos_type="rw_transition_probs" + # The feature names only depend on pos_type and pos_level, so the two + # rw_return_probs features can't have the same pos_level. 
+ features = { + "rw_transition_probs": { + "pos_level": "nodepair", + "pos_type": "rw_transition_probs", + "normalization": "none", + "ksteps": ksteps1, + }, + "rw_return_probs_0": { + "pos_level": "node", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps2, + }, + "rw_return_probs_1": { + "pos_level": "nodepair", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps3, + }, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features ) - pe2, _, cache = compute_rwse( - adj.astype(np.float32), ksteps2, num_nodes, cache, pos_type="rw_return_probs" + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value ) - pe3, _, cache = compute_rwse( - adj.astype(np.float32), ksteps3, num_nodes, cache, pos_type="rw_return_probs" - ) + pe1 = tensors[4] + pe2 = tensors[5] + pe3 = tensors[6] - self.assertTrue(all([k in cache["ksteps"] for k in ksteps1 + ksteps2 + ksteps3])) self.assertTrue(pe1.shape, np.zeros((num_nodes, num_nodes, len(ksteps1)))) self.assertTrue(pe2.shape, np.zeros((num_nodes, len(ksteps2)))) self.assertTrue(pe3.shape, np.zeros((num_nodes, len(ksteps3)))) diff --git a/tests/test_pe_spectral.py b/tests/test_pe_spectral.py index 400eb9630..5c66e6f8b 100644 --- a/tests/test_pe_spectral.py +++ b/tests/test_pe_spectral.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. 
-------------------------------------------------------------------------------- """ @@ -17,40 +17,75 @@ """ import numpy as np -import networkx as nx +import torch import unittest as ut -from graphium.features.spectral import compute_laplacian_pe +import graphium +import graphium_cpp + + +def get_pe_tensors(smiles, pos_encoding_tensor): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors class test_pe_spectral(ut.TestCase): - # 2 disconnected 3 cliques - adj1 = np.zeros((6, 6)) - adj_3clq = 1 - np.eye(3) - adj1[:3, :3] = adj_3clq - adj1[3:, 3:] = adj_3clq + def test_for_connected_vs_disconnected_graph(self): + # 2 disconnected 3 cliques + smiles1 = "C1CC1.C1CC1" - # 3-clique - adj2 = 1 - np.eye(6) + # 6-clique (have to use S instead of C, because RDKit doesn't accept a carbon having 6 explicit bonds) + smiles2 = "S1234S567S189S251S368S4791" - def test_for_connected_vs_disconnected_graph(self): + num_atoms = 6 num_pos = 3 - # test if pe works identically on connected vs disconnected graphs - eigvals_pe1, _, _, cache = compute_laplacian_pe(self.adj1, num_pos, cache={}) - eigvals_pe1 = np.real(eigvals_pe1).astype(np.float32) - _, eigvecs_pe1, _, _ = compute_laplacian_pe(self.adj1, num_pos, cache=cache) + features = { + "laplacian_eigval": { + "pos_level": "node", + "pos_type": "laplacian_eigval", + "normalization": "none", + "num_pos": num_pos, + "disconnected_comp": True, + }, + "laplacian_eigvec": { + "pos_level": "node", + "pos_type": "laplacian_eigvec", + "normalization": "none", + "num_pos": num_pos, + "disconnected_comp": True, + }, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features + ) - # We expect to cache 4 objects in when running the functon for the first time - self.assertEqual(len(cache.keys()), 4) - - eigvals_pe2, _, _, _ = compute_laplacian_pe(self.adj2, num_pos, cache={}) - eigvals_pe2 = np.real(eigvals_pe2).astype(np.float32) - _, eigvecs_pe2, _, _ = compute_laplacian_pe(self.adj2, num_pos, cache={}) + # test if pe works identically on connected vs disconnected graphs + tensors1 = get_pe_tensors(smiles1, pos_encoding_tensor) + eigvals_pe1 = tensors1[4] + eigvecs_pe1 = tensors1[5] + tensors2 = get_pe_tensors(smiles2, pos_encoding_tensor) + eigvals_pe2 = tensors2[4] + eigvecs_pe2 = tensors2[5] np.testing.assert_array_almost_equal(2 * eigvals_pe1, eigvals_pe2) - self.assertListEqual(list(eigvals_pe2.shape), [self.adj2.shape[0], num_pos]) - self.assertListEqual(list(eigvecs_pe2.shape), [self.adj2.shape[0], num_pos]) + self.assertListEqual(list(eigvals_pe2.shape), [num_atoms, num_pos]) + self.assertListEqual(list(eigvecs_pe2.shape), [num_atoms, num_pos]) if __name__ == "__main__": diff --git a/tests/test_pos_transfer_funcs.py b/tests/test_pos_transfer_funcs.py index 5062cbe46..188c6b0e3 100644 --- a/tests/test_pos_transfer_funcs.py +++ b/tests/test_pos_transfer_funcs.py @@ -1,51 +1,166 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore 
Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ - """ Unit tests for the positional encodings in graphium/features/* """ import numpy as np -import networkx as nx +import torch import unittest as ut +import math + +import graphium +import graphium_cpp + -from graphium.features.spectral import compute_laplacian_pe -from graphium.features.transfer_pos_level import ( - node_to_edge, - node_to_nodepair, - edge_to_nodepair, - nodepair_to_node, - nodepair_to_edge, - graph_to_node, -) +def get_tensors(smiles, pos_encoding_tensor): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors class test_pos_transfer_funcs(ut.TestCase): - # 4-barbell - G = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(G) - num_nodes, num_feat = 8, 5 - node_pe = np.random.rand(num_nodes, num_feat) - - def test_different_pathways_from_node_to_edge(self): - edge_pe1, _ = node_to_edge(self.node_pe, self.adj, {}) - nodepair_pe1 = node_to_nodepair(self.node_pe, self.num_nodes) - edge_pe2, _ = nodepair_to_edge(nodepair_pe1, self.adj, {}) - nodepair_pe2, _ = edge_to_nodepair(edge_pe1, self.adj, self.num_nodes, {}) - edge_pe3, _ = nodepair_to_edge(nodepair_pe2, self.adj, {}) - np.testing.assert_array_almost_equal(edge_pe1, edge_pe2) - np.testing.assert_array_almost_equal(edge_pe1, edge_pe3) + + def test_different_transfers(self): + smiles = "CCCC" + + ksteps = [2, 4] + features = { + "a": { + "pos_level": "node", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps, + }, + "b": { + "pos_level": "edge", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps, + }, + "c": { + "pos_level": "nodepair", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps, + }, + "e": {"pos_level": "node", "pos_type": "graphormer", "normalization": "none"}, + "f": {"pos_level": "edge", "pos_type": "graphormer", "normalization": "none"}, + "d": {"pos_level": "nodepair", "pos_type": "graphormer", "normalization": "none"}, + } + + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features + ) + + tensors = get_tensors(smiles, pos_encoding_tensor) + node_probs = tensors[4] + edge_probs = tensors[5] + nodepair_probs = tensors[6] + node_dists = tensors[7] + edge_dists = tensors[8] + nodepair_dists = tensors[9] + + print(f"node_probs =\n{node_probs}\n") + print(f"edge_probs 
=\n{edge_probs}\n") + print(f"nodepair_probs =\n{nodepair_probs}\n") + print(f"node_dists =\n{node_dists}\n") + print(f"edge_dists =\n{edge_dists}\n") + print(f"nodepair_dists =\n{nodepair_dists}\n") + + expected_node_probs = [ + [0.5, 0.375], + [0.75, 0.6875], + [0.75, 0.6875], + [0.5, 0.375], + ] + # sum for each node value and absolute difference for each node value, for each half-edge + expected_edge_probs = [ + [1.25, 1.0625, 0.25, 0.3125], + [1.25, 1.0625, 0.25, 0.3125], + [1.5, 1.375, 0.0, 0.0], + [1.5, 1.375, 0.0, 0.0], + [1.25, 1.0625, 0.25, 0.3125], + [1.25, 1.0625, 0.25, 0.3125], + ] + # sum for each node value and absolute difference for each node value, for each node pair + expected_nodepair_probs = [ + [ + [1.0000, 0.7500, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.0000, 0.7500, 0.0000, 0.0000], + ], + [ + [1.2500, 1.0625, 0.2500, 0.3125], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + ], + [ + [1.2500, 1.0625, 0.2500, 0.3125], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + ], + [ + [1.0000, 0.7500, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.0000, 0.7500, 0.0000, 0.0000], + ], + ] + self.assertEqual(node_probs.tolist(), expected_node_probs) + self.assertEqual(edge_probs.tolist(), expected_edge_probs) + self.assertEqual(nodepair_probs.tolist(), expected_nodepair_probs) + + expected_nodepair_dists = [ + [0.0, 1.0, 2.0, 3.0], + [1.0, 0.0, 1.0, 2.0], + [2.0, 1.0, 0.0, 1.0], + [3.0, 2.0, 1.0, 0.0], + ] + # Select half-edge node pairs + expected_edge_dists = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0]] + # Minimum of column, minimum of row, mean of column, mean of row, + # stdev of column, stdev of row, for each node + # stdev here uses n for normalization instead of n-1 + stdev_a = math.sqrt((1.5 * 1.5 + 0.5 * 0.5 + 0.5 * 0.5 + 1.5 * 1.5) / 4) + stdev_b = math.sqrt((1.0 * 1.0 + 1.0 * 1.0) / 4) + expected_node_dists = [ + [0.0, 0.0, 1.5, 1.5, stdev_a, stdev_a], + [0.0, 0.0, 1.0, 1.0, stdev_b, stdev_b], + [0.0, 0.0, 1.0, 1.0, stdev_b, stdev_b], + [0.0, 0.0, 1.5, 1.5, stdev_a, stdev_a], + ] + np.testing.assert_array_almost_equal(node_dists.tolist(), expected_node_dists) + self.assertEqual(edge_dists.tolist(), expected_edge_dists) + self.assertEqual(nodepair_dists.tolist(), expected_nodepair_dists) if __name__ == "__main__": diff --git a/tests/test_positional_encoders.py b/tests/test_positional_encoders.py index 166929ba2..66148487f 100644 --- a/tests/test_positional_encoders.py +++ b/tests/test_positional_encoders.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. 
-------------------------------------------------------------------------------- """ @@ -18,19 +18,40 @@ import numpy as np import unittest as ut -from copy import deepcopy from rdkit import Chem import datamol as dm import torch -from scipy.sparse import coo_matrix +from torch_geometric.data import Data + +import graphium +import graphium_cpp -from graphium.features.featurizer import GraphDict -from graphium.features.positional_encoding import graph_positional_encoder from graphium.nn.encoders import laplace_pos_encoder, mlp_encoder, signnet_pos_encoder + # TODO: Test the MLP_encoder and signnet_pos_encoder +def get_pe_tensors(smiles, pos_encoding_tensor): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors + + class test_positional_encoder(ut.TestCase): smiles = [ "C", @@ -44,22 +65,34 @@ class test_positional_encoder(ut.TestCase): adjs = [Chem.rdmolops.GetAdjacencyMatrix(mol) for mol in mols] def test_laplacian_eigvec_eigval(self): - for ii, adj in enumerate(deepcopy(self.adjs)): + for ii, mol in enumerate(self.smiles): + adj = self.adjs[ii] for num_pos in [1, 2, 4]: # Can't test too much eigs because of multiplicities for disconnected_comp in [True, False]: err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" - # returns a dictionary of computed pe - pos_kwargs = { - "pos_type": "laplacian_eigvec", - "num_pos": num_pos, - "disconnected_comp": disconnected_comp, - "pos_level": "node", + features = { + "laplacian_eigval": { + "pos_type": "laplacian_eigval", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, + "laplacian_eigvec": { + "pos_type": "laplacian_eigvec", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, } - num_nodes = adj.shape[0] - eigvecs, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) - pos_kwargs["pos_type"] = "laplacian_eigval" - eigvals, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) + ( + pos_encoding_names, + pos_encoding_tensor, + ) = graphium_cpp.positional_feature_options_to_tensor(features) + + tensors = get_pe_tensors(mol, pos_encoding_tensor) + eigvals = tensors[4] + eigvecs = tensors[5] self.assertEqual(list(eigvecs.shape), [adj.shape[0], num_pos], msg=err_msg) self.assertEqual(list(eigvals.shape), [adj.shape[0], num_pos], msg=err_msg) @@ -74,7 +107,10 @@ def test_laplacian_eigvec_eigval(self): true_num_pos = min(num_pos, len(true_eigvals)) true_eigvals, true_eigvecs = true_eigvals[:true_num_pos], true_eigvecs[:, :true_num_pos] - if not ("." in self.smiles[ii]): + if not ("." 
in mol): + print( + f"About to test eigvecs for smiles {mol}, num_pos {num_pos}, disconnected_comp {disconnected_comp}" + ) np.testing.assert_array_almost_equal( np.abs(true_eigvecs), np.abs(eigvecs[:, :true_num_pos]), @@ -88,13 +124,22 @@ def test_laplacian_eigvec_eigval(self): # didn't actually check the exact computation result because the code was adapted def test_rwse(self): - for ii, adj in enumerate(deepcopy(self.adjs)): + for ii, mol in enumerate(self.smiles): + adj = self.adjs[ii] for ksteps in [1, 2, 4]: err_msg = f"adj_id={ii}, ksteps={ksteps}" num_nodes = adj.shape[0] pos_kwargs = {"pos_type": "rw_return_probs", "ksteps": ksteps, "pos_level": "node"} - rwse_embed, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) + features = { + "rw_return_probs": pos_kwargs, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features + ) + tensors = get_pe_tensors(mol, pos_encoding_tensor) + rwse_embed = tensors[4] + self.assertEqual(list(rwse_embed.shape), [num_nodes, ksteps], msg=err_msg) # TODO: work in progress @@ -105,23 +150,32 @@ def test_rwse(self): """ def test_laplacian_eigvec_with_encoder(self): - for ii, adj in enumerate(deepcopy(self.adjs)): + for ii, mol in enumerate(self.smiles): for num_pos in [2, 4, 8]: # Can't test too much eigs because of multiplicities for disconnected_comp in [True, False]: for model_type in ["Transformer", "DeepSet", "MLP"]: err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" - # returns a dictionary of computed pe - pos_kwargs = { - "pos_type": "laplacian_eigvec", - "num_pos": num_pos, - "disconnected_comp": disconnected_comp, - "pos_level": "node", + features = { + "laplacian_eigval": { + "pos_type": "laplacian_eigval", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, + "laplacian_eigvec": { + "pos_type": "laplacian_eigvec", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, } - num_nodes = adj.shape[0] - eigvecs, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) - pos_kwargs["pos_type"] = "laplacian_eigval" - eigvals, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) + ( + pos_encoding_names, + pos_encoding_tensor, + ) = graphium_cpp.positional_feature_options_to_tensor(features) + + tensors = get_pe_tensors(mol, pos_encoding_tensor) input_keys = ["laplacian_eigvec", "laplacian_eigval"] in_dim = num_pos @@ -129,16 +183,17 @@ def test_laplacian_eigvec_with_encoder(self): out_dim = 64 num_layers = 1 - eigvecs = torch.from_numpy(eigvecs) - eigvals = torch.from_numpy(eigvals) - - g = GraphDict( - { - "adj": coo_matrix(adj), - "data": {"laplacian_eigval": eigvals, "laplacian_eigvec": eigvecs}, - } + num_nodes = tensors[2].size(0) + data_dict = { + # "feat": tensors[2], + # "edge_feat": tensors[3], + "laplacian_eigval": tensors[4].float(), + "laplacian_eigvec": tensors[5].float(), + } + # Create the PyG graph object `Data` + data = Data( + edge_index=tensors[0], edge_weight=tensors[1], num_nodes=num_nodes, **data_dict ) - batch = g.make_pyg_graph() encoder = laplace_pos_encoder.LapPENodeEncoder( input_keys=input_keys, @@ -153,7 +208,7 @@ def test_laplacian_eigvec_with_encoder(self): first_normalization=None, ) - hidden_embed = encoder(batch, key_prefix=None) + hidden_embed = encoder(data, key_prefix=None) assert "node" in hidden_embed.keys() self.assertEqual(list(hidden_embed["node"].shape), [num_nodes, out_dim], msg=err_msg) 
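
Editorial note on the hunks above (tests/test_pe_rw.py, tests/test_pe_spectral.py, tests/test_pos_transfer_funcs.py, tests/test_positional_encoders.py): they all drive the new C++ featurizer through the same long positional call and then index into the returned tensor list. The following is a minimal sketch, not part of the patch; the helper name and the reading of the `7` dtype code are assumptions, while the argument order and the tensor layout are taken directly from the calls shown in this diff.

import torch
import graphium_cpp

def featurize_with_pos_encodings(smiles, features):
    # Convert the per-feature option dicts into the tensor form the C++
    # featurizer expects; also returns the feature names, in order.
    names, pos_tensor = graphium_cpp.positional_feature_options_to_tensor(features)
    empty_int64 = torch.tensor(data=[], dtype=torch.int64)
    tensors, _, _ = graphium_cpp.featurize_smiles(
        smiles,
        empty_int64,  # atom_property_list_onehot
        empty_int64,  # atom_property_list_float
        False,        # has_conformer
        empty_int64,  # edge_property_list
        pos_tensor,   # positional encoding options
        True,         # duplicate_edges
        False,        # add_self_loop
        False,        # explicit_H
        False,        # use_bonds_weights
        True,         # offset_carbon
        7,            # dtype code used throughout these tests (assumed: torch float64)
        0,            # mask_nan_style_int
        0,            # mask_nan_value
    )
    # Layout observed in the tests above: tensors[0] is edge_index, tensors[1]
    # edge_weight, tensors[2] node features, tensors[3] edge features, and
    # tensors[4:] hold one tensor per requested positional encoding, in the
    # order the features were declared.
    return names, tensors

For example, the spectral test amounts to names, tensors = featurize_with_pos_encodings("C1CC1.C1CC1", features) followed by eigvals, eigvecs = tensors[4], tensors[5].
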
diff --git a/tests/test_positional_encodings.py b/tests/test_positional_encodings.py deleted file mode 100644 index 89bf355a4..000000000 --- a/tests/test_positional_encodings.py +++ /dev/null @@ -1,92 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -""" -Unit tests for the positional encodings in graphium/features/* -""" - -import numpy as np -import networkx as nx -import unittest as ut - -# from graphium.features.spectral import compute_laplacian_positional_eigvecs # TODO: add tests -# from graphium.features.rw import compute_rwse # TODO: add tests -from graphium.features.electrostatic import compute_electrostatic_interactions -from graphium.features.commute import compute_commute_distances -from graphium.features.graphormer import compute_graphormer_distances - - -class test_positional_encodings(ut.TestCase): - # Test graphs - adj_dict = {} - max_dict = {} - - # 6-ring - adj = np.asarray( - [ - [0, 1, 0, 0, 0, 1], - [1, 0, 1, 0, 0, 0], - [0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1], - [1, 0, 0, 0, 1, 0], - ] - ) - adj_dict["6-ring"] = adj - max_dict["6-ring"] = 3 - - # 5-path - G = nx.path_graph(5) - adj = nx.to_numpy_array(G) - adj_dict["5-path"] = adj - max_dict["5-path"] = 4 - - # 4-clique - adj = 1 - np.eye(4) - adj_dict["4-clique"] = adj - max_dict["4-clique"] = 1 - - # 4-barbell - H = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(H) - adj_dict["4-barbell"] = adj - max_dict["4-barbell"] = 3 - - def test_dimensions(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_electrostatic_interactions(adj, cache={}) - self.assertEqual(pe.shape, adj.shape) - - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) - - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) - - def test_symmetry(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - np.testing.assert_array_almost_equal(pe, pe.T) - - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) - np.testing.assert_array_almost_equal(pe, pe.T) - - def test_max_dist(self): - for key, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - np.testing.assert_array_almost_equal(pe.max(), self.max_dict[key]) - - -if __name__ == "__main__": - ut.main() diff --git a/tests/test_training.py b/tests/test_training.py index aeec93689..b737cc478 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -48,7 +48,7 @@ def setup_class(cls): print("Data has been successfully downloaded.") - def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_type: str) -> None: + def call_cli_with_overrides(self, acc_type: str, acc_prec: str) -> None: overrides = [ f"accelerator={acc_type}", "tasks=toymix", @@ -75,7 +75,6 @@ def call_cli_with_overrides(self, acc_type: str, acc_prec: str, 
load_type: str) "+datamodule.args.task_specific_args.zinc.sample_size=1000", "trainer.trainer.check_val_every_n_epoch=1", f"trainer.trainer.precision={acc_prec}", - f"datamodule.args.dataloading_from={load_type}", ] if acc_type == "ipu": overrides.append("accelerator.ipu_config=['useIpuModel(True)']") @@ -92,14 +91,12 @@ def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_type: str) # Restore the original sys.argv sys.argv = original_argv - @pytest.mark.parametrize("load_type", ["RAM", "disk"]) - def test_cpu_cli_training(self, load_type): - self.call_cli_with_overrides("cpu", "32", load_type) + def test_cpu_cli_training(self): + self.call_cli_with_overrides("cpu", "32") @pytest.mark.ipu @pytest.mark.skip - @pytest.mark.parametrize("load_type", ["RAM", "disk"]) - def test_ipu_cli_training(self, load_type): + def test_ipu_cli_training(self): with patch("poptorch.ipuHardwareIsAvailable", return_value=True): with patch("lightning_graphcore.accelerator._IPU_AVAILABLE", new=True): import poptorch @@ -108,4 +105,4 @@ def test_ipu_cli_training(self, load_type): from lightning_graphcore.accelerator import _IPU_AVAILABLE assert _IPU_AVAILABLE is True - self.call_cli_with_overrides("ipu", "16-true", load_type) + self.call_cli_with_overrides("ipu", "16-true")
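
Usage note on the tests/test_training.py hunk: with the dataloading_from override and its load_type parametrization removed, the CPU CLI test collapses to a single case. Assuming the marker conventions shown above, the remaining training tests can be run on their own with pytest tests/test_training.py -m "not ipu", which exercises test_cpu_cli_training while deselecting the @pytest.mark.ipu / @pytest.mark.skip IPU variant.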