diff --git a/README.md b/README.md
index 97c921405..c409ad09c 100644
--- a/README.md
+++ b/README.md
@@ -35,8 +35,6 @@ Visit https://graphium-docs.datamol.io/.

## Installation for developers

-### For CPU and GPU developers
-
Use [`mamba`](https://github.com/mamba-org/mamba), a faster and better alternative to `conda`.

If you are using a GPU, we recommend enforcing the CUDA version that you need with `CONDA_OVERRIDE_CUDA=XX.X`.
@@ -53,18 +51,6 @@ mamba activate graphium
pip install --no-deps -e .
```

-### For IPU developers
-```bash
-# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
-./install_ipu.sh .graphium_ipu
-```
-
-The above step needs to be done once. After that, enable the SDK and the environment as follows:
-
-```bash
-source enable_ipu.sh .graphium_ipu
-```
-
## Training a model

To learn how to train a model, we invite you to look at the documentation, or the jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/master/docs/tutorials/model_training).
@@ -148,6 +134,39 @@ graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.

+### Finetuning
+
+After pretraining a model and saving a model checkpoint, the model can be finetuned on a new task:
+
+```bash
+graphium-train +finetuning [example-custom OR example-tdc] finetuning.pretrained_model=[model_identifier]
+```
+
+The `[model_identifier]` identifies the pretrained model among those maintained in `GRAPHIUM_PRETRAINED_MODELS_DICT` in `graphium/utils/spaces.py`, where it maps to the location of the checkpoint of the pretrained model.
+
+We have provided two example yaml configs under `expts/hydra-configs/finetuning` for finetuning on a custom dataset (`example-custom.yaml`) or on a task from the TDC benchmark collection (`example-tdc.yaml`).
+
+When using `example-custom.yaml` to finetune on a custom dataset, we need to provide the location of the data (`constants.data_path=[path_to_data]`) and the type of task (`constants.task_type=[cls OR reg]`).
+
+When using `example-tdc.yaml` to finetune on a TDC task, we only need to provide the task name (`constants.task=[task_name]`); the task type is inferred automatically.
+
+Custom datasets for finetuning consist of two files, `raw.csv` and `split.csv`. `raw.csv` contains two columns: `smiles` with the SMILES strings and `target` with the corresponding targets. In `split.csv`, the three columns `train`, `val`, and `test` contain the indices of the corresponding rows in `raw.csv`. Examples can be found under `expts/data/finetuning_example-reg` (regression) and `expts/data/finetuning_example-cls` (binary classification).
+
+### Fingerprinting
+
+Alternatively, we can also obtain molecular embeddings (fingerprints) from a pretrained model:
+```bash
+graphium fps create [example-custom OR example-tdc] pretrained.model=[model_identifier] pretrained.layers=[layer_identifiers]
+```
+
+We have provided two example yaml configs under `expts/hydra-configs/fingerprinting` for extracting fingerprints for a custom dataset (`example-custom.yaml`) or for a dataset from the TDC benchmark collection (`example-tdc.yaml`).
+
+After specifying the `[model_identifier]`, we need to provide, via `[layer_identifiers]`, a list of layers of that model from which to read out embeddings (this requires knowledge of the architecture of the pretrained model).
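+
+As a minimal sketch, assuming the `dummy-pretrained-model` checkpoint and the layer names used in the example configs under `expts/hydra-configs/fingerprinting`, such a call could look as follows (exact shell quoting of the list override may differ on your system):
+
+```bash
+graphium fps create example-custom \
+    pretrained.model=dummy-pretrained-model \
+    "pretrained.layers=[graph_output_nn-graph:0,task_heads-zinc:0]"
+```
+
+The layer identifiers follow the `module:layer` format, e.g. `graph_output_nn-graph:0` reads out the embeddings after layer `0` of the graph output NN of the pretrained model.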
+
+When using `example-custom.yaml`, the location of the SMILES strings to be embedded needs to be passed via `datamodule.df_path=[path_to_data]`. The data can be passed as a csv/parquet file with a column `smiles`, similar to `expts/data/finetuning_example-reg/raw.csv`.
+
+When extracting fingerprints for a TDC task using `example-tdc.yaml`, we need to specify `datamodule.benchmark` and `datamodule.task` instead of `datamodule.df_path`.
+
## License

Under the Apache-2.0 license. See [LICENSE](LICENSE).
diff --git a/docs/api/graphium.finetuning.md b/docs/api/graphium.finetuning.md
index 7e2b7f444..fb8c5e418 100644
--- a/docs/api/graphium.finetuning.md
+++ b/docs/api/graphium.finetuning.md
@@ -10,4 +10,4 @@ Module for finetuning models and doing linear probing (fingerprinting).

::: graphium.finetuning.finetuning_architecture.FinetuningHead

-::: graphium.finetuning.fingerprinting.Fingerprinter
+::: graphium.fingerprinting.fingerprinter.Fingerprinter
diff --git a/docs/cli/graphium-train.md b/docs/cli/graphium-train.md
index 0b421be67..b51f7e50a 100644
--- a/docs/cli/graphium-train.md
+++ b/docs/cli/graphium-train.md
@@ -24,7 +24,7 @@ graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accele
automatically selects the correct configs to run the experiment on GPU.
Finally, you can also run a fine-tuning loop:
```bash
-graphium-train +finetuning=admet
+graphium-train +finetuning=example-tdc
```

To use a config file you built from scratch you can run
diff --git a/docs/cli/graphium.md b/docs/cli/graphium.md
index d90aa8aad..b2d816fad 100644
--- a/docs/cli/graphium.md
+++ b/docs/cli/graphium.md
@@ -103,7 +103,7 @@ $ graphium finetune [OPTIONS] COMMAND [ARGS]...

**Commands**:

-* `admet`: Utility CLI to easily fine-tune a model on...
+* `tdc`: Utility CLI to easily fine-tune a model on...
* `fingerprint`: Endpoint for getting fingerprints from a...

### `graphium finetune admet`

@@ -135,7 +135,7 @@ Endpoint for getting fingerprints from a pretrained model.

The pretrained model should be a `.ckpt` path or pre-specified, named model within Graphium.
The fingerprint layer specification should be of the format `module:layer`. If specified as a list, the fingerprints from all the specified layers will be concatenated.
-See the docs of the `graphium.finetuning.fingerprinting.Fingerprinter` class for more info.
+See the docs of the `graphium.fingerprinting.fingerprinter.Fingerprinter` class for more info.
**Usage**: diff --git a/expts/data/finetuning_example-cls/raw.csv b/expts/data/finetuning_example-cls/raw.csv new file mode 100644 index 000000000..e5df1c58c --- /dev/null +++ b/expts/data/finetuning_example-cls/raw.csv @@ -0,0 +1,201 @@ +,Drug_ID,smiles,target +0,644675,CC(=O)N(c1ccc2oc(=O)sc2c1)S(=O)(=O)c1cccs1,0 +1,644890,COc1ccccc1C(c1nnnn1C(C)(C)C)N1CCN(Cc2ccncc2)CC1,1 +2,645164,CCC(c1nnnn1CC1CCCO1)N(CCN1CCOCC1)Cc1cc2cc(C)ccc2[nH]c1=O,0 +3,6602688,Br.N=c1n(CCN2CCOCC2)c2ccccc2n1CC(=O)c1ccc(Cl)c(Cl)c1,1 +4,645448,CCC(C)(C)NC(=O)c1ccc2c(c1)N(CC(=O)OC)C(=O)C(C)(C)O2,0 +5,645569,CCc1cc2c(=O)[nH]cnc2s1,0 +6,645818,COc1cccc2c(=O)c(C(=O)NCc3cccs3)c[nH]c12,1 +7,645911,CCc1nnc(SCc2ccc(OC(C)C)cc2)n1N,0 +8,645965,Cc1ccc2c(c1)nnn2C1CCN(CC(=O)N2c3ccccc3CC2C)CC1,1 +9,646164,CCOC(=O)CSC1=C(C#N)C(C)C2=C(CCCC2=O)N1,0 +10,646293,Cc1ccc2cc(C)c3nnc(SCC(=O)NCc4ccco4)n3c2c1,0 +11,646353,CCOC(=O)N1CCN(S(=O)(=O)Cc2ccccc2)CC1,0 +12,646472,CCOC(=O)c1cc2sc(C)cc2n1CC(=O)N1CCN(C(=O)c2ccco2)CC1,0 +13,646515,CCOC(=O)C(NC(C)=O)C(OC(C)=O)c1cccc(N(CCO)CCO)c1,0 +14,646597,CCC(=O)Nc1cc2c(cc1C(=O)c1ccccc1)OCCO2,0 +15,6602690,CN(C)CCNC(=O)C(C(=O)c1ccc(F)cc1)n1ccccc1=O.Cl,0 +16,646768,Cc1cc(C)nc(N2CCC(C(=O)NCCc3ccc(F)cc3)CC2)n1,0 +17,646780,COc1cc(-c2nnc(-c3ccc(N4CCOCC4)cc3)o2)cc(OC)c1OC,0 +18,646897,CN(C)C=C1C(=O)N(C2CCCCC2)C(=O)N(C2CCCCC2)C1=O,0 +19,646955,COC(=O)CN(c1ccccn1)S(=O)(=O)c1ccccc1,0 +20,6398903,Cc1ccc(C)c(/C(O)=C2/C(=O)C(=O)N(CCN3CCOCC3)C2c2ccco2)c1,0 +21,647114,Cn1c(SC2=CS(=O)(=O)c3ccccc32)nc2ccccc21,0 +22,647205,CC(C(=O)NC1CCCC1)N(C(=O)c1snc(C(N)=O)c1N)c1ccc2c(c1)OCCO2,1 +23,647430,Cc1cc(C)n(-c2nc3c(c(=O)[nH]c(=O)n3C)n2C(C)C)n1,0 +24,647727,O=C(Nc1ccc(S(=O)(=O)N2CCCC2)cc1)c1ccc(CN2CCOCC2)cc1,1 +25,6602522,Cl.O=C(CN1CCN(c2ncccn2)CC1)NCCC1=CCCCC1,1 +26,647937,CC(=O)NC1(c2cccc(F)c2)CCN(CC(=O)NC2CCCCC2)CC1,0 +27,647996,Cc1cccc(Nc2nc3c(c(=O)n(C)c(=O)n3C)n2CC(O)CO)c1,0 +28,648175,CCn1c(Cc2ccccc2)nnc1SCC(=O)NC(C)(C)C,0 +29,648282,COc1cc2cc(CN(CCCO)S(=O)(=O)c3ccccc3Cl)c(=O)[nH]c2cc1OC,0 +30,648407,CCc1nnc(NC(=O)CSc2nnc(COc3ccccc3)n2Cc2ccccc2)s1,0 +31,648481,CSc1nc2nc3c(c(=O)n2[nH]1)CN(Cc1ccccc1)CC3,0 +32,648708,Cc1nn(C)c(C)c1CNC(=O)c1cnn2c1NC(c1ccccc1)CC2C(F)(F)F,0 +33,648836,Cc1ccc(C(c2nnnn2CC2CCCO2)N2CCN(C(=O)c3ccco3)CC2)cc1,0 +34,648878,CCN(CC)c1ccc2c(Cl)c(Br)c(=O)oc2c1,0 +35,648947,COc1ccc(OC)c(NC(=O)C(CC(=O)O)NCc2ccco2)c1,0 +36,649015,O=C1CC(c2cccs2)c2cc3c(cc2N1)OCO3,1 +37,649453,CCn1c(COc2ccccc2)nnc1SCC(=O)Nc1cc(OC)ccc1OC,0 +38,649754,Nc1ncc(-c2ccccc2)n1CC1CCCO1,1 +39,649786,CCCCn1c(SCC(=O)N2CCCC2)nc2c1c(=O)n(C)c(=O)n2C,0 +40,649878,O=C(O)CCn1nnc(-c2cccs2)n1,0 +41,650002,Cn1c(=O)[nH]c(=O)c2c1nc(NCCCO)n2CCCc1ccccc1,0 +42,650100,COc1ccccc1OCCn1cc(C(=O)c2ccco2)c2ccccc21,0 +43,650250,Cc1ccc(-c2csc(N3CCC(NS(=O)(=O)c4ccc5c(c4)OCCO5)CC3)n2)cc1,0 +44,650341,COc1ccc(S(=O)(=O)N2CCC(NC(=O)Nc3ccc(C)cc3)CC2)cc1,0 +45,650486,Cn1c(CNC(=O)Nc2ccccc2)nnc1SCc1ccccc1,0 +46,650558,O=C(CSc1nnc(CNc2ccccc2)o1)N1CCCc2ccccc21,0 +47,650691,CCCCNC(=O)NS(=O)(=O)c1ccc(C(=O)OC(C)C)o1,0 +48,6602999,CCOC(=O)C1Cc2c([nH]c3ccccc23)CN1.Cl,1 +49,650985,COc1ccc(CCNc2c(C)c(C)nc3ncnn23)cc1,1 +50,5768893,COCCN1C(=O)C(=O)/C(=C(/O)c2cccc(OC)c2)C1c1ccco1,0 +51,651076,CCC(=O)Nc1cccc(NC(=O)CSc2nnnn2Cc2ccccc2)c1,0 +52,651205,CCc1c(C)nc2c(C#N)c(C)[nH]n2c1=O,0 +53,651338,CCC(=O)N(Cc1ccco1)c1nc(-c2ccccc2)cs1,0 +54,651587,O=C(OCCCN1CCCCC1)c1ccc(O)cc1,1 +55,651589,Cc1cc(NC(=O)CCC(=O)N2CCC3(CC2)OCCO3)no1,0 +56,651769,O=C(CNC(=O)c1ccco1)OCc1c(F)cccc1Cl,0 +57,652002,O=C(CSc1nnc(-c2cccnc2)o1)Oc1ccccc1,0 +58,652521,CN(C)S(=O)(=O)c1ccc(C(=O)Nc2ccc(CN3CCCC3)cc2)cc1,1 
+59,652549,CC(C)C(C(=O)NCC1CCCO1)N(Cc1ccco1)C(=O)CNS(=O)(=O)c1ccc(F)cc1,0 +60,652700,COCCNC(=O)COC(=O)c1nsc(Cl)c1Cl,0 +61,652799,CCOc1ccc(-c2nnn(CC(=O)Nc3cc(OC)ccc3OC)n2)cc1OCC,0 +62,6603138,Cl.O=C(CN1CCN(C2CCCCC2)CC1)NCCC1=CCCCC1,1 +63,653279,O=C(O)C1C2C=CC3(CN(Cc4ccccn4)C(=O)C13)O2,0 +64,653412,Cn1c(SCC(=O)NCc2ccc3c(c2)OCO3)nnc1-c1cc2ccccc2cc1O,1 +65,6603457,CCC(c1nnnn1CCOC)N1CCN(C(=O)c2ccco2)CC1.Cl,0 +66,653646,O=S(=O)(c1ccccc1)N1CCN(c2cc(-c3ccccc3)nc3ncnn23)CC1,0 +67,653695,O=C(CSc1n[nH]c(-c2ccccc2O)n1)N1CCCc2ccccc21,1 +68,653778,COc1ccccc1-n1c(SCC(=O)N2CCCC2)nc2cccnc21,0 +69,653799,CCc1nnc2sc(-c3ccc(NC(=O)c4ccco4)cc3)nn12,0 +70,653914,COc1ccc(-c2nnn(CC(=O)N(CC(=O)NCCC(C)C)Cc3cccs3)n2)cc1OC,1 +71,654078,CCn1c(SCc2ccc(C#N)cc2)nnc1-c1ccc(S(=O)(=O)N2CCCCC2)cc1,0 +72,654182,CCOC(=O)Cc1cc(=O)n2[nH]c(C)c(-c3ccccc3)c2n1,0 +73,6398932,CCOc1cccc(/C(O)=C2/C(=O)C(=O)N(Cc3ccco3)C2c2ccncc2)c1,0 +74,654363,O=C(NC1CCCCC1)C(c1cccs1)N(Cc1cccs1)C(=O)c1ccco1,1 +75,654435,O=C(CSc1nc2ccccc2o1)Nc1nc2ccccc2s1,0 +76,5373216,COc1n[nH]c2nncnc12,0 +77,654546,CCc1ccc(N2CC(C(=O)NC3=NCCS3)CC2=O)cc1,0 +78,654623,COc1ccc(CNC(=O)CN(CC2CCCO2)C(=O)CNS(=O)(=O)c2ccccc2)cc1,0 +79,654635,COc1cc(C2C(C(=O)c3ccc(C)o3)=C(O)C(=O)N2CCc2ccccc2)ccc1O,0 +80,654761,Cc1ccc(C)n1C(Cc1ccccc1)C(=O)O,0 +81,655183,COc1ccccc1CN(Cc1cc2cc(C)cc(C)c2[nH]c1=O)Cc1nnnn1CC1CCCO1,0 +82,655265,COc1ccc(OCc2nnc(SCC(=O)O)n2N)cc1,0 +83,5768421,COCCCN1C(=O)C(=O)/C(=C(/O)c2ccc(OC(C)C)c(C)c2)C1c1ccncc1,0 +84,655401,CCN(C1CCCCC1)S(=O)(=O)c1ccc(S(=O)(=O)NCc2ccncc2)cc1,1 +85,655439,CCOC(=O)C1=C(C)NC(=O)NC1c1ccoc1,0 +86,655857,c1ccc(Cn2nnc3c(N4CCc5ccccc5C4)ncnc32)cc1,0 +87,655866,COc1ccc(C(=O)NC2CC3CCCC(C2)N3CC(C)C)cc1OC,1 +88,655948,Cc1cc(N2CCN(c3nc4ccccc4s3)CC2)n2ncnc2n1,0 +89,656017,CC(C)C(=O)Nc1cc2c(cc1C(=O)c1ccccc1)OCCO2,0 +90,656027,COc1ccc(CCN2C(=O)C(O)=C(C(=O)c3ccco3)C2c2cccs2)cc1OC,0 +91,656095,COc1ccc(C2C(C(=O)N3CCOCC3)=C(C)NC3=C2C(=O)CC(C)(C)C3)c(OC)c1,0 +92,656157,O=c1oc(-c2ccco2)nc2c1cnn2-c1ccccc1,0 +93,656183,CC(C)OC(=O)NCCOC(=O)Nc1cccc(Cl)c1,0 +94,656257,Cc1ccccc1OCC1Cn2c(nc3c2c(=O)[nH]c(=O)n3C)O1,0 +95,656272,Cc1ccc(-c2[nH]n3c(=O)c4c(nc3c2C)CCCC4)cc1,1 +96,656290,Cc1cccc(NC(=O)Cn2c(=O)oc3ccccc32)c1,1 +97,6603060,I.OCCNC1=NCCN1,0 +98,6449251,COc1cc2c(cc1OC)/C(=C/C(=O)N1CCOCC1)NC(C)(C)C2.Cl,1 +99,135449532,Cc1cc(=O)[nH]c(-n2nc(C)cc2C)n1,0 +100,5940036,CCOC(=O)C1=C(N)n2c(s/c(=C\c3ccco3)c2=O)=C(C(=O)OCC)C1c1ccco1,0 +101,208296,O=c1nc(-c2ccccc2)cn[nH]1,0 +102,658411,COC(=O)c1ccc(Oc2cc(C)nc(-n3nc(C)cc3C)n2)cc1,0 +103,658723,COC(=O)c1[nH]c2ccc(Br)cc2c1NC(=O)CN1CCN(C2CCCCC2)CC1,1 +104,658813,CCc1c(C)c(C#N)c2nc3ccccc3n2c1Nc1c(C)n(C)n(-c2ccccc2)c1=O,0 +105,135415833,O=c1c(Cc2ccccc2)c(O)nc2n1CCS2,0 +106,658879,CCC(NC(=O)Nc1cc(OC)c(OC)c(OC)c1)(C(F)(F)F)C(F)(F)F,0 +107,659040,O=c1[nH]c(=S)[nH]nc1Cc1ccccc1,0 +108,135435901,COc1ccc(C2CC(=O)C(C=NCCN3CCOCC3)=C(O)C2)cc1,0 +109,16411130,Cc1nc(/N=C(\N)Nc2ccccc2)nc2ccccc12,0 +110,659321,CCOc1ccc(CSC(CC(=O)O)C(=O)O)cc1,0 +111,2838016,CC1=CC=CN2CC(O)CN=C12.Cl,0 +112,6603569,Cc1ccc2c(c1)[C@@H]1CN(C)CC[C@@H]1N2S(=O)(=O)c1ccc(F)cc1.Cl,1 +113,1922089,Cc1cc(C)c(-n2c(O)c(C=NCCN3CCOCC3)c(=O)[nH]c2=O)c(C)c1,0 +114,659756,O=C1CCCN1CC(CN1CCOCC1)Sc1nnnn1-c1ccccc1,0 +115,660120,Cc1cc(-c2cc(-c3ccc(Cl)cc3)nc(N)c2C#N)co1,0 +116,660285,COc1ccc(S(=O)(=O)N2CCC(N3CCCCC3)CC2)cc1,1 +117,660304,O=C(COc1ccccc1)Nc1ccc(-c2nnc(-c3ccco3)o2)cc1,0 +118,5389248,COc1ccc(C2/C(=C(/O)c3ccc(Cl)cc3)C(=O)C(=O)N2CCN2CCOCC2)c(OC)c1,0 +119,5389254,Cc1oc2cc(O)ccc2c(=O)c1-c1cnn(-c2ccccc2)c1,0 +120,660546,Cc1cccc(OCC(=O)N2CCC(N3CCCCCC3)CC2)c1,1 
+121,135420605,CCCCc1c(O)nc(SCCN(C)C)n(-c2ccccc2)c1=O,0 +122,660831,COC1(OC)N=C(NC(=O)Nc2ccccc2)C2(C#N)C(c3ccccc3)C12C#N,0 +123,660995,Cc1ccc(S(=O)(=O)NCCSc2nnnn2C)cc1,0 +124,661065,CNc1oc(-c2cccs2)nc1C#N,0 +125,661098,CC(=O)Nc1ccc(Nc2ncnc3c2cnn3-c2ccccc2)cc1,0 +126,661170,O=C(Cc1ccc(Cl)cc1)Nc1cccc(-c2nnc(-c3ccco3)o2)c1,0 +127,661178,CCCOc1ccc(CSC(CC(=O)O)C(=O)O)cc1,0 +128,661187,c1ccc(-n2ncc3c(NCCCN4CCOCC4)ncnc32)cc1,0 +129,661203,CC1CCc2cccc3c2N1c1cc(C#N)c(C#N)cc1O3,0 +130,661217,N#Cc1nc(-c2cccs2)oc1NCc1ccccc1,1 +131,661296,CCc1c(C(=O)O)[nH]c2ccc(Br)cc12,0 +132,661300,S=c1nc(-c2ccccc2)[nH]n1-c1ccccc1,0 +133,661349,CC(=O)c1c(C(C)=O)c(C)n(NC(=O)c2ccncc2)c1C,0 +134,661355,Cn1c(=O)n(CCC(=O)O)c2ccccc21,0 +135,661406,CC1(C)CC(=O)C(CCCN2C(=O)c3ccccc3C2=O)C(=O)C1,0 +136,6603015,CN(C)CC(O)COc1cccc(OCC(O)CN(C)C)c1.Cl,0 +137,6603014,CN(C)CC(O)COc1ccc(C(C)(C)c2ccc(OCC(O)CN(C)C)cc2)cc1.Cl,0 +138,661455,Oc1ccccc1CNn1cnnc1,0 +139,661513,CCn1c(SCC(=O)Nc2ccccc2C(=O)OC)nnc1-c1ccc(N)cc1,0 +140,661518,CCc1cc2c(=O)c(-c3nc4ccccc4[nH]3)coc2cc1OC(=O)N1CCOCC1,0 +141,661528,Cc1ccc(C2Nc3ccccc3C(=O)N2Cc2ccco2)cc1,0 +142,661552,Cc1cn2c(-c3ccncc3)nnc2s1,0 +143,661761,CCOC(=O)c1[nH]c2cc3c(cc2c1NC(=O)CN1CCc2ccccc2C1)OCO3,1 +144,5389368,COc1ccc(/C(O)=C2\C(=O)C(=O)N(c3cc(C)on3)C2c2ccc(OC)c(OC)c2)cc1OC,0 +145,5389389,COc1cccc(/C(O)=C2/C(=O)C(=O)N(c3cc(C)on3)C2c2cccs2)c1,0 +146,5389423,CCN(CC)CCN1C(=O)C(=O)/C(=C(/O)c2cccc(OC)c2)C1c1ccccn1,0 +147,661999,COc1cccc(C2C(C(=O)c3cc4ccccc4o3)=C(O)C(=O)N2c2cc(C)on2)c1OC,0 +148,662011,CCN1CCCC1Cn1cnc2c([nH]c3ccc(C)cc32)c1=O,1 +149,6881185,COCCCNC(=O)c1c(N)n(/N=C/c2ccccn2)c2nc3ccccc3nc12,0 +150,6881246,CCO/C(C)=N/n1c2nc3ccccc3nc2c2c(=O)n(CC(C)C)c(C)nc21,0 +151,662144,CCOC(=O)c1c(C)n(C)c2ccc(OC)c(NC(=O)CN3CCN(Cc4ccccc4)CC3)c12,1 +152,662340,CCOC(=O)c1cc2c(=O)n3cccc(C)c3nc2n(CCCOC)c1=NC(=O)c1cccnc1,0 +153,5389504,CCOc1ccc(/C(O)=C2/C(=O)C(=O)N(CCOCCO)C2c2cccnc2)cc1,0 +154,662407,NC(=O)C1CCN(C(=O)CN2C(=O)c3ccccc3S2(=O)=O)CC1,0 +155,662515,CCOC(=O)c1[nH]c2cc(OC)c(OC)cc2c1NC(=O)c1nonc1C,0 +156,6603499,COc1ccc(C(=O)OC(C)CN2CCN(C)CC2)cc1OC.Cl,1 +157,9614332,C[n+]1cccc(CNC(=O)/C=N/O)c1.[I-],0 +158,662647,COc1ccccc1-c1nnc2n1N=C(c1ccc(O)c(O)c1)CS2,0 +159,662710,COc1cccc(-c2nnc3sc(-c4ccc(C)cc4)nn23)c1,0 +160,662745,CCOCCCn1cnc2c([nH]c3cc(OC)ccc32)c1=O,0 +161,200556,Cl.OC1(c2ccc(F)cc2)CCNC1,0 +162,662794,CC(C)(C)OC(=O)NCCc1nnc(SCC(=O)Nc2cccc(Cl)c2)o1,0 +163,662799,O=C1C2ON(c3ccccc3)C(c3ccncc3)C2C(=O)N1c1ccccc1,0 +164,662838,CCCn1nc(NC(=O)CC(C)C)c2cc3ccccc3nc21,0 +165,662878,CCOc1ccccc1NC(=O)CSc1nnc(-c2cnccn2)n1C,1 +166,662996,CCOc1ccc(-c2nnc3n2N=C(c2ccc(OC)cc2)CS3)cc1,0 +167,663008,CC(C)COP(=O)(c1ccc(N(C)C)cc1)C(O)c1ccccc1F,0 +168,6603621,Cl.NCCCCCc1nnc(SCc2ccccc2Cl)o1,1 +169,663121,Nc1c(S(=O)(=O)c2ccccc2)c2nc3ccccc3nc2n1Cc1ccco1,1 +170,9615342,Cn1c[n+](C)cc1/C=N/O.[I-],0 +171,663125,Oc1ccc(-c2[nH]ncc2-c2ccc(Cl)cc2)c(O)c1,1 +172,663143,CCOc1ccc(C2=Nn3c(nnc3-c3ccccc3OC)SC2)cc1,0 +173,663146,COc1ccc2c(c1)[nH]c1c(N3CCN(Cc4ccc5c(c4)OCO5)CC3)ncnc12,1 +174,663168,COc1cccc(-c2nnc3n2N=C(C(C)(C)C)CS3)c1,0 +175,663337,CCOC(=O)c1[nH]c2cc(OC)c(OC)cc2c1NC(=O)c1ccc2c(c1)OCO2,0 +176,663340,COc1ccc(CCn2c(=N)c(C(=O)NCc3ccco3)cc3c(=O)n4ccccc4nc32)cc1,1 +177,663539,Cc1nc(SCC(=O)Nc2ccc3c(c2)OCCO3)c2oc3ccccc3c2n1,0 +178,663581,COC(=O)[C@@H](NC(=O)Nc1ccc(C(C)=O)cc1)C(C)C,0 +179,5389740,CCCN1C(=O)C2(/C(=C(\O)c3ccc4c(c3)OCCO4)C(=O)C(=O)N2CCCOC)c2ccccc21,0 +180,663736,COC(=O)[C@H](Cc1ccccc1)NC(=O)N1CCN(Cc2ccccc2)CC1,1 +181,663792,CCOCCCn1c(=N)c(C(=O)NCc2ccc3c(c2)OCO3)cc2c(=O)n3cccc(C)c3nc21,1 
+182,54676164,CCOC(=O)C1=C(O)C(=O)N(c2ccc(S(N)(=O)=O)cc2)C1c1ccc(OC)cc1,0 +183,664033,CC1(C)CCCN(C(=O)c2coc(=O)c(Br)c2)C1,0 +184,664154,COC(=O)c1[nH]c2cc(C)ccc2c1NC(=O)CN1CCCc2ccccc21,1 +185,5389802,Cc1nc2ccccn2c1/C(O)=C1\C(=O)C(=O)N(CCCn2ccnc2)C1c1ccncc1,0 +186,664250,CC(C)(C)OC(=O)N1CCCC1C(=O)NCCc1ccccc1,0 +187,135513628,CCOC(=O)/C(C(N)=NCCCO)=C(\O)OCC,0 +188,664461,O=C1COc2ccc(OCc3ccc(F)cc3)cc21,0 +189,6603365,Cl.c1ccc2c(c1)oc1c(NCCCn3ccnc3)ncnc12,1 +190,5389869,COc1ccc(-c2c(C)oc3c(CN4CCN(CCO)CC4)c(O)ccc3c2=O)cc1OC,0 +191,5389875,CC1Cc2cc(/C(O)=C3/C(=O)C(=O)N(CCN4CCOCC4)C3c3cccc(Cl)c3)ccc2O1,0 +192,5389878,COc1ccc(C2/C(=C(/O)c3ccc4c(c3)CC(C)O4)C(=O)C(=O)N2CCN2CCOCC2)cc1,0 +193,664733,Cc1nc2c3cnn(-c4ccc(C)c(C)c4)c3ncn2n1,0 +194,664737,CCn1c(=N)c(S(=O)(=O)c2ccc(F)cc2)cc2c(=O)n3ccccc3nc21,1 +195,664759,O=C(CN1CCN(Cc2ccccc2)CC1)c1ccc(Br)cc1,1 +196,5389891,C/C(Cl)=C\Cn1c(N2CCCC2)nc2c1c(=O)[nH]c(=O)n2C,0 +197,664983,CCc1ccc(N2CC(C)Cn3c2nc2c3c(=O)n(CCN3CCOCC3)c(=O)n2C)cc1,0 +198,6603255,Br.CC1(C(=O)CSc2nc3ccccc3s2)CCC(=O)O1,0 +199,665081,Cn1c(-c2ccc(CN3CCCCC3)o2)nc2ccccc21,0 diff --git a/expts/data/finetuning_example-cls/split.csv b/expts/data/finetuning_example-cls/split.csv new file mode 100644 index 000000000..2b47272b1 --- /dev/null +++ b/expts/data/finetuning_example-cls/split.csv @@ -0,0 +1,121 @@ +,train,val,test +0,0,120,160 +1,1,121,161 +2,2,122,162 +3,3,123,163 +4,4,124,164 +5,5,125,165 +6,6,126,166 +7,7,127,167 +8,8,128,168 +9,9,129,169 +10,10,130,170 +11,11,131,171 +12,12,132,172 +13,13,133,173 +14,14,134,174 +15,15,135,175 +16,16,136,176 +17,17,137,177 +18,18,138,178 +19,19,139,179 +20,20,140,180 +21,21,141,181 +22,22,142,182 +23,23,143,183 +24,24,144,184 +25,25,145,185 +26,26,146,186 +27,27,147,187 +28,28,148,188 +29,29,149,189 +30,30,150,190 +31,31,151,191 +32,32,152,192 +33,33,153,193 +34,34,154,194 +35,35,155,195 +36,36,156,196 +37,37,157,197 +38,38,158,198 +39,39,159,199 +40,40,, +41,41,, +42,42,, +43,43,, +44,44,, +45,45,, +46,46,, +47,47,, +48,48,, +49,49,, +50,50,, +51,51,, +52,52,, +53,53,, +54,54,, +55,55,, +56,56,, +57,57,, +58,58,, +59,59,, +60,60,, +61,61,, +62,62,, +63,63,, +64,64,, +65,65,, +66,66,, +67,67,, +68,68,, +69,69,, +70,70,, +71,71,, +72,72,, +73,73,, +74,74,, +75,75,, +76,76,, +77,77,, +78,78,, +79,79,, +80,80,, +81,81,, +82,82,, +83,83,, +84,84,, +85,85,, +86,86,, +87,87,, +88,88,, +89,89,, +90,90,, +91,91,, +92,92,, +93,93,, +94,94,, +95,95,, +96,96,, +97,97,, +98,98,, +99,99,, +100,100,, +101,101,, +102,102,, +103,103,, +104,104,, +105,105,, +106,106,, +107,107,, +108,108,, +109,109,, +110,110,, +111,111,, +112,112,, +113,113,, +114,114,, +115,115,, +116,116,, +117,117,, +118,118,, +119,119,, diff --git a/expts/data/finetuning_example-reg/raw.csv b/expts/data/finetuning_example-reg/raw.csv new file mode 100644 index 000000000..de8e39137 --- /dev/null +++ b/expts/data/finetuning_example-reg/raw.csv @@ -0,0 +1,161 @@ +,smiles,target +0,CCCS(=O)(=O)Nc1ccc(F)c(C(=O)c2c[nH]c3ncc(-c4ccc(Cl)cc4)cc23)c1F,-1.59345982 +1,CCCc1cc(N2CCc3c(nc(C4CC4)n3C)C2)n2ncnc2n1,1.677187053 +2,COc1ccccc1C(=O)Nc1ccc2cc[nH]c2c1,0.505149978 +3,Cn1cc(Nc2nc(N)nc(-c3cccc(-n4ccc5cc(C6CC6)cc(F)c5c4=O)c3CO)n2)cn1,0.850401148 +4,C=CC(=O)N(C)CCOc1c(N)ncnc1-c1cc(F)cc(NC(=O)c2ccc(C3CC3)cc2F)c1C,0.838502776 +5,C=CC(=O)N1CCC(CNc2ncnc(N)c2-c2ccc(Oc3ccccc3)cc2)CC1,0.752995432 +6,C=CC(=O)N1CCC[C@@H](n2nc(-c3ccc(Oc4ccccc4)cc3)c3c(N)ncnc32)C1,0.62603 +7,C=CC(=O)N1CCC[C@H](n2c(=O)n(-c3ccc(Oc4ccccc4)cc3)c3c(N)ncnc32)C1,0.983851719 +8,Cc1cccc(/C=N/Nc2cc(N3CCOCC3)nc(OCCc3ccccn3)n2)c1,-0.548213564 
+9,C=CC(=O)Nc1cccc(Nc2nc(Nc3ccc(Oc4ccnc(C(=O)NC)c4)cc3)ncc2F)c1,-0.403402904 +10,C=CC(=O)Nc1cccc(Oc2nc(Nc3ccc(N4CCN(C)CC4)c(F)c3)nc3[nH]ccc23)c1,0.039414119 +11,C=CC(=O)Nc1cccc(Oc2nc(Nc3ccc(N4CCN(C)CC4)cc3)nc3ccoc23)c1,0.146128036 +12,C=CC(=O)Nc1cccc(Oc2nc(Nc3ccc(N4CCN(C)CC4)cc3)nc3ccsc23)c1,-0.049635146 +13,CNC(=O)C1(Cc2ccc(-c3cccnc3)cc2)CCN(Cc2cccc(F)c2)C1,0.829753919 +14,COc1nn(C)cc1C(=O)Nc1cccc(-c2cnc3n2CCC3)c1,0.866877814 +15,CNC(=O)C1(Cc2ccc(-c3ccncc3)cc2)CCN(Cc2cccc(F)c2)C1,1.10503305 +16,COc1nn(C)cc1C(=O)Nc1cccc2cnccc12,1.504729052 +17,O=C(Nc1ccc(CN2CCCCC2)cc1)c1ccnc2[nH]cnc12,1.767526899 +18,CCN(C/C=C\c1ccc(C2CCCCC2)c(Cl)c1)C1CCCCC1,-0.853871964 +19,CC#CC(=O)N1CC[C@@H](n2c(=O)n(-c3ccc(Oc4ccccc4)cc3)c3c(N)ncnc32)C1,0.935859798 +20,CC#CC(=O)N[C@H]1CCCN(c2c(F)cc(C(N)=O)c3[nH]c(C)c(C)c23)C1,0.163856803 +21,O=C(Nc1ccc2ccccc2n1)c1ccc(N2CCOC2=O)cc1,0.449015316 +22,O=C(Nc1cccc(-c2nncn2C2CC2)c1)c1cc(-n2cnc(C3CC3)c2)ccn1,0.262925469 +23,O=C(Nc1cccc(N2CCNC2=O)c1)C1NCC12CCCC2,1.718202645 +24,Fc1ccc(-c2ccc3c(c2)[nH]c2ccncc23)cn1,0.653405491 +25,CC(=O)N1CCN(c2c(Cl)cccc2NC(=O)COc2ccccc2Cl)CC1,0.256958153 +26,CC(=O)N1CCN(c2nc(C(F)(F)F)nc3sc(C)c(C)c23)CC1,0.13481437 +27,CNc1cc(Nc2cccn(-c3ccccn3)c2=O)nn2c(C(=O)N[C@@H]3C[C@@H]3F)cnc12,0.991226076 +28,CCN1CCN(S(=O)(=O)Cc2ccc(Cl)c(Cl)c2)CC1,1.249124949 +29,CNc1nc(C)cc(C(=O)Nc2ccc3[nH]ncc3c2)n1,1.493625323 +30,Fc1ccccc1-c1c[nH]nc1C1CCCN1Cc1ccc2ncccc2c1,0.545801757 +31,O=C(Nc1cnccc1-c1ccc(Cl)cc1)c1ccnc(NC(=O)C2CC2)c1,0.369957607 +32,O=C(Nc1nc2cccc(-c3ccc(CN4CCS(=O)(=O)CC4)cc3)n2n1)C1CC1,1.681349797 +33,C[C@@H]1CCN(C(=O)CC#N)C[C@@H]1N(C)c1ncnc2[nH]ccc12,1.991815076 +34,C[C@@H]1c2nnn(-c3ncc(F)cn3)c2CCN1C(=O)c1cccc(C(F)(F)F)c1Cl,0.969089603 +35,CC(=O)Nc1ccc(C(=O)N2CCCCC2c2nc(N)ncc2-c2ccc(Cl)cc2)cc1,0.596926814 +36,CC(=O)Nc1ccc(C(=O)N2CCCCC2c2nc(N)ncc2-c2cccc(Cl)c2)cc1,0.520483533 +37,C[C@H]1CN(C2COC2)CCN1c1ccc(Nc2cc(-c3ccnc(N4CCn5c(cc6c5CC(C)(C)C6)C4=O)c3CO)cn(C)c2=O)nc1,1.209139536 +38,CC(=O)Nc1ccc(O)cc1,1.887859133 +39,Cc1cccnc1Nc1cccc(C2CCCN(CC(=O)Nc3nccs3)C2)n1,0.180699201 +40,COCC(=O)N1CCC(Cc2ccccc2-c2cccc(F)c2)(C(=O)NC(C)C)CC1,0.596597096 +41,Cc1[nH]nc2c1C1(CCCCC1)CC(=O)N2,1.323066376 +42,CC(=O)Nc1ncc(C(=O)O)s1,1.81935965 +43,N#Cc1cc(F)c(NS(=O)(=O)c2c[nH]c3cc(Cl)ccc23)cc1F,-0.026872146 +44,Cc1c(C(=O)NCCCN(C)C)sc2ncnc(Nc3ccc(F)cc3OC(C)C)c12,0.741821047 +45,Cc1c(Cl)ccc2cc3n(c12)[C@@H](C)CNC3=O,0.328379603 +46,Cc1cn2nc(-c3cc(=O)n4cc(N5CCNC6(CC6)C5)ccc4n3)cc(C)c2n1,0.985246791 +47,CCOc1cc(CC(=O)N[C@@H](CC(C)C)c2ccccc2N2CCCCC2)ccc1C(=O)O,0.106870544 +48,COCCNC(=O)c1ccnc(C2CCNCC2)c1,2.0 +49,Cc1c[nH]c(=O)n1-c1ccc(C(=O)Nc2ccc3ccccc3n2)cc1,0.639386869 +50,Cc1cnc(C(=O)NCCc2ccc(S(=O)(=O)NC(=O)NC3CCCCC3)cc2)cn1,0.01745073 +51,Cc1c[nH]c2nccc(Oc3c(F)cc(Nc4cc(Cl)nc(N)n4)cc3F)c12,-1.22184875 +52,CC(C)(C)C(=O)N1CCC(Cc2ccc(-c3cccs3)cc2)(C(=O)N2CCCC2)CC1,-0.1837587 +53,O=C(c1ccc(Oc2ccccc2)cc1Cl)c1c[nH]c2ncnc(N[C@@H]3CC[C@@H](CO)OC3)c12,-0.467245621 +54,CC(C)(C)c1ccc(-c2nc3n(c(=O)c2C#N)CCS3)cc1,0.731266349 +55,CC(C)(C)c1ccc(C(O)CN2CCC(O)(c3ccc4c(c3)OCO4)CC2)cc1,0.779018972 +56,O=C1CCCC[C@@H]2[C@H](C[C@@H](Cc3ccccc3F)N2C(=O)c2cccc3ncccc23)N1,1.277998644 +57,CCc1c[nH]c2ncnc(N3CCC(CN4CCN(C)CC4)CC3)c12,1.728012707 +58,CC(C)(Oc1ccc(-c2cnc(N)c(-c3ccc(Cl)cc3)c2)cc1)C(=O)O,-1.158015195 +59,NC(=O)c1cnc(N2CCc3[nH]nc(C(F)(F)F)c3C2)c(Cl)c1,0.986009932 +60,NC1CC(NC(=O)c2ccc(-c3cn[nH]c3)cn2)C12CCC2,1.709702344 +61,NC1CCC(C(=O)N2CCC(c3c[nH]c4ncccc34)CC2)C1,1.917395215 +62,NC1CCCC(C(=O)N2CCC(c3c[nH]c4ncccc34)CC2)C1,1.752816431 +63,NC1CCCC(C(=O)Nc2ccc3[nH]ncc3c2)C1,2.0 
+64,NC1CCCC(C(=O)Nc2cccc(N3CCNC3=O)c2)C1,1.763113391 +65,NC1CCCC1C(=O)N1CCC(c2c[nH]c3ncccc23)CC1,1.71701274 +66,O=C1NCCN(C(=O)c2ccc3nccn3c2)C1c1ccccc1C(F)(F)F,1.851001366 +67,O=C1NCCN(C(=O)c2ccncc2)C1c1ccccc1Cl,1.876996793 +68,NCC1CCCC1NC(=O)c1cc(N2CCNC2=O)ccc1F,1.87495702 +69,O=C1NCCSc2c1sc1ccc(O)cc21,0.891593204 +70,NCCN1CCN(C/C=C/C(=O)N2CCC[C@@H](n3nc(-c4ccc(Oc5ccccc5)cc4)c4c(N)ncnc43)C2)CC1,0.699837726 +71,CCc1nc(C)cn2nc(-c3cc(=O)n4cc(C5CCN(C)CC5)cc(C)c4n3)cc12,0.868232868 +72,O=S(=O)(c1cccc2cnccc12)N1CCCNCC1,1.892077899 +73,Nc1c(F)ccc2cnc(-n3ccc4ccncc43)cc12,0.713910354 +74,CCc1nc2c(C)cc(N3CCN(CC(=O)N4CC(O)C4)CC3)cn2c1N(C)c1nc(-c2ccc(F)cc2)c(C#N)s1,-0.204728421 +75,Cc1nc(Nc2nccs2)cc(C2CN(c3ncccn3)C2)n1,0.444669231 +76,CC(C)NC(=O)COc1cccc(-c2nc(Nc3ccc4[nH]ncc4c3)c3ccccc3n2)c1,-1.384078213 +77,Cc1cc(F)ccc1C1C(=O)NCCN1C(=O)c1ccc2nccn2c1,1.78096503 +78,Cc1cc(N2CCCC2c2cc(CCC(=O)O)cc(C)n2)ncn1,2.0 +79,CC(C)[C@H](CO)Nc1nc(Nc2cc(N)cc(Cl)c2)c2ncn(C(C)C)c2n1,-0.061980903 +80,CCn1c(-c2nonc2N)nc2cnc(Oc3cccc(NC(=O)c4ccc(OCCN5CCOCC5)cc4)c3)cc21,-0.449771647 +81,Nc1ncc(-c2cccc(C(F)(F)F)c2)c(C2CCCCN2C(=O)c2ccccc2)n1,-0.04431225 +82,CCn1c(=O)oc2cc(NC(=O)c3ccc(C(C)(C)C)cc3)ccc21,0.21005085 +83,Nc1ncc(C(=O)NC2CN(C(=O)C3CC3)C2)c2ccc(-c3cccc(F)c3)nc12,0.623352682 +84,CCn1c(CO)nn(-c2cc(O[C@@H](C)C(F)(F)F)c(C(=O)Nc3c(F)cccc3Cl)cc2F)c1=O,0.975753389 +85,Nc1ncnc2c1c(-c1ccc(Oc3ccccc3)cc1)nn2[C@@H]1CCCNC1,0.956648579 +86,Nc1ncnc2c1c(-c1cnc3[nH]ccc3c1)nn2C1CCCC1,0.862608364 +87,CCn1cc(C(=O)O)c(=O)c2c(N)c(F)c(NC3CCCCC3)cc21,0.427486109 +88,Cc1nccc(-c2cn(Cc3ccccc3)c3cnccc23)n1,1.122707254 +89,Cc1cc(Nc2cnccn2)cc(C2CCCN(CC(=O)N3CCCC3)C2)n1,1.642246824 +90,Cc1ncsc1C(=O)N1CCCCC1c1nc(N)ncc1-c1cccc(C(F)(F)F)c1,0.615318657 +91,CN(C(=O)c1cc(N2CCNC2=O)ccc1F)C1CCNC1,1.885728632 +92,OC[C@H](Nc1cncc(-c2ccc3[nH]ncc3n2)c1)c1ccccc1,1.017826038 +93,CN(C)C(=O)C1(Cc2ccccc2-c2cccc(F)c2)CCN(C(=O)C2CC=CCC2)CC1,-0.104025268 +94,CN(C)C(=O)C1(Cc2ccccc2-c2ccccc2)CCN(C(=O)C2CCCO2)CC1,0.902546779 +95,CN(C)C(=O)C1(Cc2ccccc2-c2ccccc2)CCN(C(=O)c2cnn(C)c2)CC1,1.041353202 +96,O=C(CCNC(=O)c1ccc(OC(F)(F)F)cc1)N[C@@H]1CCCc2ccccc21,0.526080692 +97,CC(O)(C#Cc1ccc2c(c1)N(c1nc(N)ncc1Cl)CC2)c1nccs1,-0.122628654 +98,COc1ccc(CCCN2CCN(c3cnn(C)c3)C(=O)C2)cc1F,1.343408594 +99,CC/C(=C(\c1ccccc1)c1ccc(OCCN(C)C)cc1)c1ccccc1,-0.580044252 +100,CC1(C)CC(Oc2ccc(-c3ccc(-c4cn[nH]c4)cc3O)nn2)CC(C)(C)N1,1.199947058 +101,Cc1nnc(-c2ccc(N3CCC(Oc4cc(F)ccc4Cl)CC3)nn2)o1,0.161068385 +102,Cc1nnc(-c2ccc(N3CCC(Oc4ccccc4C(F)(F)F)CC3)nn2)s1,-0.906578315 +103,c1ccc(-c2ccc(CN3CCCCCCC3)cc2)cc1,-0.614393726 +104,COc1ccc(CNC(=O)c2sc3nc(C)cc(C)c3c2N)cc1,-0.356547324 +105,Cc1nnc(CN(C)CC(C)Oc2ccc(Cl)c(Cl)c2)n1C,0.926290987 +106,CC1(C)Cc2cc(NC(=O)c3cnn4cccnc34)c(OCC3CC3)nc2O1,-0.199970641 +107,O=C(CN1CCCC(c2cccc(Cc3cccc(F)c3)n2)C1)N1CCCC1,1.062130535 +108,CC1(CNC(=O)c2cncc(C3CCNCC3)n2)CCCO1,1.988550039 +109,Cc1noc(C(C)C)c1C(=O)N1CC(C)OC(c2ccccc2)C1,1.180469962 +110,CC1CN(C(=O)c2cccnc2N2CCOCC2)CC(c2ccccc2)O1,1.573185017 +111,CCC(=O)N1CCN(c2ccc(Cl)cc2NC(=O)COc2ccccc2)CC1,0.250420002 +112,COc1ccc(S(=O)(=O)N2CCC(N3CCC(C)CC3)CC2)cc1,1.729799023 +113,COc1ccc([C@H]2CN(C(C)=O)[C@@H]3CCCN(Cc4cccc(F)c4)[C@H]23)cc1,0.595385981 +114,Cc1oc2ccccc2c1CNc1nnc(-c2ccncc2)o1,0.717254313 +115,c1sc(NCC2CCCO2)nc1C12CC3CC(CC(C3)C1)C2,-1.096910013 +116,CCC1=C(C)CN(C(=O)NCCc2ccc(S(=O)(=O)NC(=O)N[C@H]3CC[C@H](C)CC3)cc2)C1=O,-1.180456064 +117,CN(Cc1ccccc1)C1(C(=O)N2CCNC(=O)CC2)Cc2ccccc2C1,0.904931827 +118,CCCCNC(=O)NS(=O)(=O)c1ccc(C)cc1,0.439332694 +119,O=C(NC1CCNCC1)c1ccc2[nH]ncc2c1,1.842696589 
+120,Cn1c(C2CC2)nc2c1CCN(c1ncnc3ccsc13)C2,1.237065953 +121,Cc1ccc(OCC2(O)CCN(CC3(O)CCN(c4ccccc4C)CC3)CC2)cc1,-0.040481623 +122,CCCNC(=O)NS(=O)(=O)c1ccc(Cl)cc1,0.969835093 +123,Cc1ccc(Oc2ccc(Cl)cc2NC(=O)CN(C)CC(=O)N(C)C)cc1,-0.053056729 +124,COc1ccccc1-c1cc(NC(=O)c2cccc(N3CCNC3=O)c2)[nH]n1,-0.397940009 +125,CN1C(N)=N[C@](C)(c2cc(NC(=O)c3ccc(F)cn3)ccc2F)CS1(=O)=O,1.560468571 +126,C=CC(=O)N1C[C@H](Nc2ncnc3[nH]ccc23)CC[C@@H]1C,1.828227965 +127,CC#CC(=O)N1CC[C@@H](n2cc(-c3ccc(Oc4c(F)cccc4F)cc3)c3c(N)n[nH]c(=O)c32)C1,-0.073143291 +128,CC(=O)NCCNc1cc(Cl)nn2c(-c3cccc(S(=O)(=O)N(C)C)c3)c(C)nc12,0.681693392 +129,CC(C)(O)CCn1cc2cc(NC(=O)c3cccc(C(F)(F)F)n3)c(C(C)(C)O)cc2n1,1.015611205 +130,CC1CC(N)CCN1C(=O)c1cc(N2CCNC2=O)c(F)cc1F,1.833504094 +131,CCOc1cc2nn(CCC(C)(C)O)cc2cc1NC(=O)c1cccc(C(F)F)n1,0.099680641 +132,CCS(=O)(=O)N1CC(CC#N)(n2cc(-c3ncnc4[nH]ccc34)cn2)C1,1.378597739 +133,CN(C(=O)c1cc(N2CCNC2=O)ccc1F)C1CCC(N)CC1,1.828717885 +134,CN1CCN(S(=O)(=O)c2ccc(-c3cnc(N)c(C(=O)Nc4cccnc4)n3)cc2)CC1,1.078638038 +135,CNC(=O)C1(Cc2ccc(-c3cccnc3)cc2)CCN(Cc2ccc(F)cc2)C1,1.021602716 +136,CNC(=O)C1(Cc2ccc(-c3cccnc3)cc2)CCN(Cc2ccccc2Cl)C1,0.387567779 +137,CNC(=O)c1cccc(NC(=O)N2CCC(Oc3ccccc3Cl)CC2)c1,1.068760828 +138,COCCc1noc(CN2CC(c3ccccc3)(c3ccccc3)CCC2=O)n1,1.167051359 +139,COc1cc2c(cc1OC)CC(=O)N(CCCN(C)C[C@H]1Cc3cc(OC)c(OC)cc31)CC2,1.730984039 +140,COc1cc2ncnc(Nc3cccc(O)c3)c2cc1OC,0.994317153 +141,COc1ccc(Cl)cc1C(=O)NCCc1ccc(S(=O)(=O)NC(=O)NC2CCCCC2)cc1,-0.91721463 +142,COc1ccc(Nc2c(C#N)cnc3cc(OC)c(OC)cc23)cc1Cl,-0.033858267 +143,COc1ccc(OCC(=O)Nc2cccc(Cl)c2N2CCN(C(C)=O)CC2)cc1,0.700790221 +144,COc1nc2sc(C(=O)NC3CC3)c(N)c2c(C)c1Cl,0.245512668 +145,Cc1cc(C)c2c(N)c(C(=O)NCc3ccc(Cl)cc3)sc2n1,-0.370590401 +146,Cc1cc(N2CCCC2)nc(C2CCCN(C)C2)n1,2.0 +147,Cc1ccc(C(=O)N2CCC(Cc3ccccc3-c3ccccc3)(C(=O)N(C)C)CC2)s1,-0.293282218 +148,Cc1ncc(CN2CCC(Nc3ccc4nnnn4n3)C(C)C2)s1,1.652285029 +149,Cc1ncsc1C(=O)N1CCCCC1c1nc(N(C)C)ncc1-c1cccc(Cl)c1,-0.083019953 +150,Cn1cc(-c2cn3nccc3c(-c3cnn([C@]4(CC#N)C[C@@H](C#N)C4)c3)n2)cn1,1.559583476 +151,N#CC[C@H](C1CCCC1)n1cc(-c2ncnc3[nH]ccc23)cn1,0.630986911 +152,NC(=O)C1(Cc2ccc(-c3ccncc3)cc2)CCN(C(=O)Cc2cccc(F)c2)CC1,1.006508828 +153,O=C(NC12CCCC1NCC2)c1ccc(-c2cn[nH]c2)cc1,2.0 +154,O=C(NCCc1ccccc1)c1ccc(NC(=O)N2CCCCc3ccccc32)cc1,0.075911761 +155,O=C1CN(c2ccc(Nc3nccc(C(F)(F)F)n3)cn2)CCN1,1.543819805 +156,OCC1CCCCN1Cc1ccc(-c2ccccc2)cc1,1.368007805 +157,OCC1CCCCN1Cc1ccc(Cl)c(Cl)c1,1.390069186 +158,[2H]C([2H])([2H])NC(=O)c1nnc(NC(=O)C2CC2)cc1Nc1cccc(-c2ncn(C)n2)c1OC,1.186566481 +159,c1ccc(Oc2cccc(CN(CCN3CCOCC3)Cc3cccnc3)c2)cc1,-0.222573178 diff --git a/expts/data/finetuning_example-reg/split.csv b/expts/data/finetuning_example-reg/split.csv new file mode 100644 index 000000000..e88c74b5e --- /dev/null +++ b/expts/data/finetuning_example-reg/split.csv @@ -0,0 +1,110 @@ +,train,val,test +0,0,60.0,126.0 +1,1,5.0,127.0 +2,2,35.0,128.0 +3,3,23.0,129.0 +4,4,15.0,130.0 +5,6,68.0,131.0 +6,7,12.0,132.0 +7,8,45.0,133.0 +8,9,119.0,134.0 +9,10,113.0,135.0 +10,11,41.0,136.0 +11,13,88.0,137.0 +12,16,30.0,138.0 +13,17,74.0,139.0 +14,18,54.0,140.0 +15,19,73.0,141.0 +16,20,14.0,142.0 +17,21,,143.0 +18,22,,144.0 +19,24,,145.0 +20,25,,146.0 +21,26,,147.0 +22,27,,148.0 +23,28,,149.0 +24,29,,150.0 +25,31,,151.0 +26,32,,152.0 +27,33,,153.0 +28,34,,154.0 +29,36,,155.0 +30,37,,156.0 +31,38,,157.0 +32,39,,158.0 +33,40,,159.0 +34,42,, +35,43,, +36,44,, +37,46,, +38,47,, +39,48,, +40,49,, +41,50,, +42,51,, +43,52,, +44,53,, +45,55,, +46,56,, +47,57,, +48,58,, +49,59,, +50,61,, +51,62,, +52,63,, +53,64,, +54,65,, +55,66,, 
+56,67,, +57,69,, +58,70,, +59,71,, +60,72,, +61,75,, +62,76,, +63,77,, +64,78,, +65,79,, +66,80,, +67,81,, +68,82,, +69,83,, +70,84,, +71,85,, +72,86,, +73,87,, +74,89,, +75,90,, +76,91,, +77,92,, +78,93,, +79,94,, +80,95,, +81,96,, +82,97,, +83,98,, +84,99,, +85,100,, +86,101,, +87,102,, +88,103,, +89,104,, +90,105,, +91,106,, +92,107,, +93,108,, +94,109,, +95,110,, +96,111,, +97,112,, +98,114,, +99,115,, +100,116,, +101,117,, +102,118,, +103,120,, +104,121,, +105,122,, +106,123,, +107,124,, +108,125,, diff --git a/expts/hydra-configs/finetuning/admet.yaml b/expts/hydra-configs/finetuning/admet.yaml deleted file mode 100644 index 7360707df..000000000 --- a/expts/hydra-configs/finetuning/admet.yaml +++ /dev/null @@ -1,91 +0,0 @@ -# @package _global_ - -# == Fine-tuning configs in Graphium == -# -# A fine-tuning config is a appendum to a (pre-)training config. -# Since many things (e.g. the architecture), will stay constant between (pre-)training and fine-tuning, -# this config should be as minimal as possible to avoid unnecessary duplication. It only specifies -# what to override with regards to the config used for (pre-)training. -# -# Given the following training command: -# >>> graphium-train --cfg /path/to/train.yaml -# -# Fine-tuning now is as easy as: -# >>> graphium-train --cfg /path/to/train.yaml +finetune=admet -# -# NOTE: This config can be used for each of the benchmarks in the TDC ADMET benchmark suite. -# The only thing that needs to be changed is the `constants.task` key. - - -## == Overrides == - -defaults: - # This file contains all metrics and loss function info for all ADMET tasks. - # This config is filtered at runtime based on the `constants.task` key. - - override /tasks/loss_metrics_datamodule: admet - -constants: - - # For now, we assume a model is always fine-tuned on a single task at a time. - # You can override this value with any of the benchmark names in the TDC benchmark suite. - # See also https://tdcommons.ai/benchmark/admet_group/overview/ - task: lipophilicity_astrazeneca - - name: finetuning_${constants.task}_gcn - wandb: - name: ${constants.name} - project: ${constants.task} - entity: multitask-gnn - save_dir: logs/${constants.task} - seed: 42 - max_epochs: 100 - data_dir: expts/data/admet/${constants.task} - raise_train_error: true - -predictor: - optim_kwargs: - lr: 4.e-5 - -# == Fine-tuning config == - -finetuning: - - # For now, we assume a model is always fine-tuned on a single task at a time. - # You can override this value with any of the benchmark names in the TDC benchmark suite. 
- # See also https://tdcommons.ai/benchmark/admet_group/overview/ - task: ${constants.task} - level: graph - - # Pretrained model - pretrained_model: dummy-pretrained-model - finetuning_module: task_heads # gnn - sub_module_from_pretrained: zinc # optional - new_sub_module: ${constants.task} # optional - - # keep_modules_after_finetuning_module: # optional - # graph_output_nn/graph: {} - # task_heads/zinc: - # new_sub_module: lipophilicity_astrazeneca - # out_dim: 1 - - - # Changes to finetuning_module - drop_depth: 1 - new_out_dim: 8 - added_depth: 2 - - # Training - unfreeze_pretrained_depth: 0 - epoch_unfreeze_all: none - - # Optional finetuning head appended to model after finetuning_module - finetuning_head: - task: ${constants.task} - previous_module: task_heads - incoming_level: graph - model_type: mlp - in_dim: 8 - out_dim: 1 - hidden_dims: 8 - depth: 2 - last_layer_is_readout: true diff --git a/expts/hydra-configs/finetuning/admet_baseline.yaml b/expts/hydra-configs/finetuning/admet_baseline.yaml deleted file mode 100644 index 5b6aca0ce..000000000 --- a/expts/hydra-configs/finetuning/admet_baseline.yaml +++ /dev/null @@ -1,70 +0,0 @@ -# @package _global_ - -defaults: - - override /tasks/loss_metrics_datamodule: admet - -constants: - task: tbd - name: finetune_${constants.task} - wandb: - name: ${constants.name} - project: finetuning - entity: recursion - seed: 42 - max_epochs: 100 - data_dir: ../data/graphium/admet/${constants.task} - datacache_path: ../datacache/admet/${constants.task} - raise_train_error: true - metric: ${get_metric_name:${constants.task}} - -datamodule: - args: - batch_size_training: 32 - persistent_workers: false - num_workers: 4 - -trainer: - model_checkpoint: - # save_top_k: 1 - # monitor: graph_${constants.task}/${constants.metric}/val - # mode: ${get_metric_mode:${constants.task}} - # save_last: true - # filename: best - dirpath: model_checkpoints/finetuning/${constants.task}/${now:%Y-%m-%d_%H-%M-%S.%f}/ - every_n_epochs: 200 - trainer: - precision: 32 - check_val_every_n_epoch: 1 - # early_stopping: - # monitor: graph_${constants.task}/${constants.metric}/val - # mode: ${get_metric_mode:${constants.task}} - # min_delta: 0.001 - # patience: 10 - accumulate_grad_batches: none - # test_from_checkpoint: best.ckpt - # test_from_checkpoint: ${trainer.model_checkpoint.dirpath}/best.ckpt - -predictor: - optim_kwargs: - lr: 0.000005 - - -# == Fine-tuning config == - -finetuning: - task: ${constants.task} - level: graph - pretrained_model: tbd - finetuning_module: graph_output_nn - sub_module_from_pretrained: graph - new_sub_module: graph - - keep_modules_after_finetuning_module: # optional - task_heads-pcqm4m_g25: - new_sub_module: ${constants.task} - hidden_dims: 256 - depth: 2 - last_activation: ${get_last_activation:${constants.task}} - out_dim: 1 - - epoch_unfreeze_all: tbd \ No newline at end of file diff --git a/expts/hydra-configs/finetuning/example-custom.yaml b/expts/hydra-configs/finetuning/example-custom.yaml new file mode 100644 index 000000000..4b9e20197 --- /dev/null +++ b/expts/hydra-configs/finetuning/example-custom.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +defaults: + - override /tasks/loss_metrics_datamodule: finetune + +constants: + benchmark: custom + task: finetuning_example-cls # finetuning_example-cls OR finetuning_example-reg + task_type: cls # cls OR reg + data_path: expts/data + # wandb: + # name: finetune_${constants.task} + # project: tbd + # entity: tbd + # tags: + # - finetuning + # - ${constants.task} + # - 
${finetuning.pretrained_model} + seed: 42 + max_epochs: 20 + raise_train_error: true + model_dropout: 0. + +datamodule: + args: + batch_size_training: 256 + batch_size_inference: 256 + dataloading_from: ram + persistent_workers: true + num_workers: 8 + + task_specific_args: + finetune: + df: null + df_path: ${constants.data_path}/${constants.task}/raw.csv + splits_path: ${constants.data_path}/${constants.task}/split.csv + smiles_col: smiles + label_cols: target + task_level: graph + epoch_sampling_fraction: 1.0 + +trainer: + model_checkpoint: + save_top_k: 0 + dirpath: none + every_n_epochs: 200 + save_last: false + trainer: + precision: 32 + check_val_every_n_epoch: 1 + accumulate_grad_batches: 1 + +predictor: + optim_kwargs: + lr: 0.00001 + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: ${constants.max_epochs} + warmup_epochs: 3 + verbose: False + + + +# == Fine-tuning config == + +finetuning: + task: finetune + level: graph + pretrained_model: dummy-pretrained-model + finetuning_module: graph_output_nn + sub_module_from_pretrained: graph + new_sub_module: graph + drop_depth: 1 + added_depth: 1 + new_out_dim: 256 + + keep_modules_after_finetuning_module: + task_heads-zinc: + new_sub_module: finetune + hidden_dims: ${finetuning.new_out_dim} + depth: 1 + dropout: 0. + last_activation: none + out_dim: 1 + + epoch_unfreeze_all: 0 + always_freeze_modules: [] \ No newline at end of file diff --git a/expts/hydra-configs/finetuning/example-tdc.yaml b/expts/hydra-configs/finetuning/example-tdc.yaml new file mode 100644 index 000000000..d5bfaf98f --- /dev/null +++ b/expts/hydra-configs/finetuning/example-tdc.yaml @@ -0,0 +1,76 @@ +# @package _global_ + +defaults: + - override /tasks/loss_metrics_datamodule: tdc +constants: + task: bbb_martins + # wandb: + # name: finetune_${constants.task} + # project: tbd + # entity: tbd + # tags: + # - finetuning + # - ${constants.task} + # - ${finetuning.pretrained_model} + seed: 42 + max_epochs: 20 + raise_train_error: true + metric: ${get_metric_name:${constants.task}} + model_dropout: 0. + +datamodule: + args: + batch_size_training: 256 + batch_size_inference: 256 + dataloading_from: ram + persistent_workers: true + num_workers: 2 + split_type: default + tdc_train_val_seed: 1 + +trainer: + model_checkpoint: + save_top_k: 0 + dirpath: none + every_n_epochs: 200 + save_last: false + trainer: + precision: 32 + check_val_every_n_epoch: 1 + accumulate_grad_batches: 1 + +predictor: + optim_kwargs: + lr: 0.00001 + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: ${constants.max_epochs} + warmup_epochs: 3 + verbose: False + + + +# == Fine-tuning config == + +finetuning: + task: ${constants.task} + level: graph + pretrained_model: dummy-pretrained-model + finetuning_module: graph_output_nn + sub_module_from_pretrained: graph + new_sub_module: graph + drop_depth: 1 + added_depth: 1 + new_out_dim: 256 + + keep_modules_after_finetuning_module: + task_heads-zinc: + new_sub_module: ${constants.task} + hidden_dims: ${finetuning.new_out_dim} + depth: 1 + dropout: 0. 
+ last_activation: none + out_dim: 1 + + epoch_unfreeze_all: 0 + always_freeze_modules: [] \ No newline at end of file diff --git a/expts/hydra-configs/fingerprinting/example-custom.yaml b/expts/hydra-configs/fingerprinting/example-custom.yaml new file mode 100644 index 000000000..d6f26a3c5 --- /dev/null +++ b/expts/hydra-configs/fingerprinting/example-custom.yaml @@ -0,0 +1,16 @@ +pretrained: + model: dummy-pretrained-model + layers: + - graph_output_nn-graph:0 + - task_heads-zinc:0 + +datamodule: + df_path: ./expts/data/finetuning_example-reg/raw.csv + benchmark: null + task: null + split_val: 0.0 + split_test: 1.0 + device: cpu # cpu or cuda + num_workers: 0 + fps_cache_dir: ./expts/data/fps/finetuning_example-reg + mol_cache_dir: ${datamodule.fps_cache_dir} diff --git a/expts/hydra-configs/fingerprinting/example-tdc.yaml b/expts/hydra-configs/fingerprinting/example-tdc.yaml new file mode 100644 index 000000000..1f603e9c6 --- /dev/null +++ b/expts/hydra-configs/fingerprinting/example-tdc.yaml @@ -0,0 +1,15 @@ +pretrained: + model: dummy-pretrained-model + layers: + - graph_output_nn-graph:0 + - task_heads-zinc:0 + +datamodule: + df_path: null + benchmark: tdc + task: herg + data_seed: 1 + device: cpu # cpu or cuda + num_workers: 0 + fps_cache_dir: ./expts/data/fps/${datamodule.benchmark}/${datamodule.task} + mol_cache_dir: ${datamodule.fps_cache_dir} \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/finetune.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/finetune.yaml new file mode 100644 index 000000000..0f93b26de --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/finetune.yaml @@ -0,0 +1,77 @@ +# @package _global_ + +#Task-specific +predictor: + metrics_on_progress_bar: + reg: ["mae"] + cls: ["auroc"] + loss_fun: + reg: mae + cls: bce_logits + random_seed: ${constants.seed} + optim_kwargs: + lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: &max_epochs 10 + warmup_epochs: 10 + verbose: False + target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label + multitask_handling: flatten # flatten, mean-per-label + +# Task-specific +metrics: + reg: + - name: mae + metric: mae + target_nan_mask: null + multitask_handling: flatten + threshold_kwargs: null + - name: spearman + metric: spearmanr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: pearson + metric: pearsonr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2_score + metric: r2_score + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + cls: + - name: auroc + metric: auroc + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: auprc + metric: averageprecision + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: accuracy + metric: accuracy + multitask_handling: mean-per-label + target_to_int: True + average: micro + threshold_kwargs: &threshold_05 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: True + +datamodule: + args: + task_specific_args: + finetune: + df: null + df_path: expts/data/finetuning_example-reg/raw.csv + smiles_col: smiles + label_cols: target + task_level: graph + splits_path: expts/data/finetuning_example-reg/split.csv + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git 
a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml similarity index 89% rename from expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml rename to expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml index 89176f2b6..d4a6e296d 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml @@ -26,29 +26,29 @@ predictor: herg: ["auroc"] ames: ["auroc"] dili: ["auroc"] - ld50_zhu: ["auroc"] + ld50_zhu: ["mae"] loss_fun: caco2_wang: mae - hia_hou: bce - pgp_broccatelli: bce - bioavailability_ma: bce + hia_hou: bce_logits + pgp_broccatelli: bce_logits + bioavailability_ma: bce_logits lipophilicity_astrazeneca: mae solubility_aqsoldb: mae - bbb_martins: bce + bbb_martins: bce_logits ppbr_az: mae vdss_lombardo: mae - cyp2d6_veith: bce - cyp3a4_veith: bce - cyp2c9_veith: bce - cyp2d6_substrate_carbonmangels: bce - cyp3a4_substrate_carbonmangels: bce - cyp2c9_substrate_carbonmangels: bce + cyp2d6_veith: bce_logits + cyp3a4_veith: bce_logits + cyp2c9_veith: bce_logits + cyp2d6_substrate_carbonmangels: bce_logits + cyp3a4_substrate_carbonmangels: bce_logits + cyp2c9_substrate_carbonmangels: bce_logits half_life_obach: mae clearance_microsome_az: mae clearance_hepatocyte_az: mae - herg: bce - ames: bce - dili: bce + herg: bce_logits + ames: bce_logits + dili: bce_logits ld50_zhu: mae random_seed: ${constants.seed} optim_kwargs: @@ -134,7 +134,7 @@ metrics: ld50_zhu: *regression_metrics datamodule: - module_type: "ADMETBenchmarkDataModule" + module_type: "TDCBenchmarkDataModule" args: # TDC specific tdc_benchmark_names: null diff --git a/expts/hydra-configs/tasks/task_heads/admet.yaml b/expts/hydra-configs/tasks/task_heads/tdc.yaml similarity index 100% rename from expts/hydra-configs/tasks/task_heads/admet.yaml rename to expts/hydra-configs/tasks/task_heads/tdc.yaml diff --git a/expts/hydra-configs/tasks/admet.yaml b/expts/hydra-configs/tasks/tdc.yaml similarity index 77% rename from expts/hydra-configs/tasks/admet.yaml rename to expts/hydra-configs/tasks/tdc.yaml index 30dec61e0..f7fef1b57 100644 --- a/expts/hydra-configs/tasks/admet.yaml +++ b/expts/hydra-configs/tasks/tdc.yaml @@ -3,5 +3,5 @@ # want to override both. 
defaults: - - task_heads: admet - - loss_metrics_datamodule: admet \ No newline at end of file + - task_heads: tdc + - loss_metrics_datamodule: tdc \ No newline at end of file diff --git a/graphium/cli/__init__.py b/graphium/cli/__init__.py index e190d9ac4..1e60140fb 100644 --- a/graphium/cli/__init__.py +++ b/graphium/cli/__init__.py @@ -1,4 +1,5 @@ from .data import data_app from .parameters import param_app from .finetune_utils import finetune_app +from .fingerprint import fp_app from .main import app diff --git a/graphium/cli/finetune_utils.py b/graphium/cli/finetune_utils.py index 5566e7961..9500a8371 100644 --- a/graphium/cli/finetune_utils.py +++ b/graphium/cli/finetune_utils.py @@ -3,7 +3,6 @@ import fsspec import numpy as np import torch -import tqdm import typer import yaml from datamol.utils import fs @@ -13,7 +12,7 @@ from omegaconf import OmegaConf from graphium.config._loader import load_accelerator, load_datamodule -from graphium.finetuning.fingerprinting import Fingerprinter +from graphium.fingerprinting.fingerprinter import Fingerprinter from graphium.utils import fs from graphium.trainer.predictor import PredictorModule @@ -24,7 +23,7 @@ app.add_typer(finetune_app, name="finetune") -@finetune_app.command(name="admet") +@finetune_app.command(name="tdc") def benchmark_tdc_admet_cli( overrides: List[str], name: Optional[List[str]] = None, @@ -52,7 +51,7 @@ def benchmark_tdc_admet_cli( # Use the Compose API to construct the config for n in name: - overrides += ["+finetuning=admet", f"constants.task={n}"] + overrides += ["+finetuning=tdc", f"constants.task={n}"] with initialize(version_base=None, config_path="../../expts/hydra-configs"): cfg = compose( @@ -138,14 +137,14 @@ def get_fingerprints_from_model( def get_tdc_task_specific(task: str, output: Literal["name", "mode", "last_activation"]): if output == "last_activation": - config_arch_path = "expts/hydra-configs/tasks/task_heads/admet.yaml" + config_arch_path = "expts/hydra-configs/tasks/task_heads/tdc.yaml" with open(config_arch_path, "r") as yaml_file: config_tdc_arch = yaml.load(yaml_file, Loader=yaml.FullLoader) return config_tdc_arch["architecture"]["task_heads"][task]["last_activation"] else: - config_metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml" + config_metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml" with open(config_metrics_path, "r") as yaml_file: config_tdc_task_metric = yaml.load(yaml_file, Loader=yaml.FullLoader) diff --git a/graphium/cli/fingerprint.py b/graphium/cli/fingerprint.py new file mode 100644 index 000000000..0b0319fe9 --- /dev/null +++ b/graphium/cli/fingerprint.py @@ -0,0 +1,57 @@ +from typing import Any, List, Dict + +from loguru import logger + +from omegaconf import OmegaConf + +import wandb + +from graphium.fingerprinting.data import FingerprintDatamodule + +import typer +from hydra import initialize, compose + +from graphium.cli.main import app + +fp_app = typer.Typer(help="Automated fingerprinting from pretrained models.") +app.add_typer(fp_app, name="fps") + +@fp_app.command(name="create", help="Create fingerprints for pretrained model.") +def smiles_to_fps(cfg_name: str, overrides: List[str]) -> Dict[str, Any]: + with initialize(version_base=None, config_path="../../expts/hydra-configs/fingerprinting"): + cfg = compose( + config_name=cfg_name, + overrides=overrides, + ) + cfg = OmegaConf.to_container(cfg, resolve=True) + + if "wandb" in cfg.keys(): + wandb_cfg = cfg.get("wandb") + wandb.init(**wandb_cfg) + + pretrained_models = 
cfg.get("pretrained") + + # Allow alternative definition of `pretrained_models` with the single model specifier and desired layers + if "layers" in pretrained_models.keys(): + assert "model" in pretrained_models.keys(), "this workflow allows easier definition of fingerprinting sweeps" + model, layers = pretrained_models.pop("model"), pretrained_models.pop("layers") + pretrained_models[model] = layers + + data_kwargs = cfg.get("datamodule") + + datamodule = FingerprintDatamodule( + pretrained_models=pretrained_models, + **data_kwargs, + ) + + datamodule.prepare_data() + + logger.info(f"Fingerprints saved in {datamodule.fps_cache_dir}/fps.pt.") + try: + wandb.run.finish() + except: + pass + + +if __name__ == "__main__": + smiles_to_fps(cfg_name="example-tdc", overrides=[]) \ No newline at end of file diff --git a/graphium/cli/fingerprints.py b/graphium/cli/fingerprints.py deleted file mode 100644 index 62b078eb9..000000000 --- a/graphium/cli/fingerprints.py +++ /dev/null @@ -1,6 +0,0 @@ -from .main import app - - -@app.command(name="fp") -def get_fingerprints_from_model(): - ... diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py index 4cea1ee8c..a83a520b0 100644 --- a/graphium/cli/train_finetune_test.py +++ b/graphium/cli/train_finetune_test.py @@ -20,6 +20,7 @@ from graphium.config._loader import ( load_accelerator, load_architecture, + load_mup, load_datamodule, load_metrics, load_predictor, @@ -152,6 +153,8 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None: predictor = PredictorModule.load_pretrained_model( name_or_path=get_checkpoint_path(cfg), device=accelerator_type ) + mup_base_path = cfg["architecture"].pop("mup_base_path", None) + predictor = load_mup(mup_base_path, predictor) else: ## Architecture diff --git a/graphium/config/dummy_finetuning_from_gnn.yaml b/graphium/config/dummy_finetuning_from_gnn.yaml index 75848c40f..becdada2c 100644 --- a/graphium/config/dummy_finetuning_from_gnn.yaml +++ b/graphium/config/dummy_finetuning_from_gnn.yaml @@ -56,6 +56,7 @@ finetuning: constants: seed: 42 max_epochs: 5 + model_dropout: 0. accelerator: float32_matmul_precision: medium @@ -120,7 +121,7 @@ datamodule: ### FROM FINETUNING ### - module_type: "ADMETBenchmarkDataModule" + module_type: "TDCBenchmarkDataModule" args: processed_graph_data_path: datacache/processed_graph_data/dummy_finetuning_from_gnn # TDC specific diff --git a/graphium/config/dummy_finetuning_from_task_head.yaml b/graphium/config/dummy_finetuning_from_task_head.yaml index 373bc6e7e..77ee6852b 100644 --- a/graphium/config/dummy_finetuning_from_task_head.yaml +++ b/graphium/config/dummy_finetuning_from_task_head.yaml @@ -62,6 +62,7 @@ finetuning: constants: seed: 42 max_epochs: 5 + model_dropout: 0. 
accelerator: float32_matmul_precision: medium @@ -126,7 +127,7 @@ datamodule: ### FROM FINETUNING ### - module_type: "ADMETBenchmarkDataModule" + module_type: "TDCBenchmarkDataModule" args: processed_graph_data_path: datacache/processed_graph_data/dummy_finetuning_task_head # TDC specific diff --git a/graphium/data/__init__.py b/graphium/data/__init__.py index b18cda421..d6becddf3 100644 --- a/graphium/data/__init__.py +++ b/graphium/data/__init__.py @@ -5,6 +5,6 @@ from .datamodule import GraphOGBDataModule from .datamodule import MultitaskFromSmilesDataModule -from .datamodule import ADMETBenchmarkDataModule +from .datamodule import TDCBenchmarkDataModule from .dataset import MultitaskDataset diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 52781d845..6c1a0ce37 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -671,6 +671,7 @@ def __init__( idx_col: Optional[str] = None, mol_ids_col: Optional[str] = None, sample_size: Union[int, float, Type[None]] = None, + split_type: str = "random", split_val: float = 0.2, split_test: float = 0.2, seed: int = None, @@ -713,6 +714,7 @@ def __init__( self.idx_col = idx_col self.mol_ids_col = mol_ids_col self.sample_size = sample_size + self.split_type = split_type self.split_val = split_val self.split_test = split_test self.seed = seed @@ -1129,6 +1131,7 @@ def prepare_data(self): train_indices, val_indices, test_indices = self._get_split_indices( num_molecules, split_val=self.task_dataset_processing_params[task].split_val, + split_type=self.task_dataset_processing_params[task].split_type, split_test=self.task_dataset_processing_params[task].split_test, sample_idx=sample_idx, split_seed=self.task_dataset_processing_params[task].seed, @@ -1613,6 +1616,7 @@ def _get_split_indices( dataset_size: int, split_val: float, split_test: float, + split_type: str = "random", sample_idx: Optional[Iterable[int]] = None, split_seed: int = None, splits_path: Union[str, os.PathLike, Dict[str, Iterable[int]]] = None, @@ -1634,31 +1638,7 @@ def _get_split_indices( if sample_idx is None: sample_idx = np.arange(dataset_size) - if splits_path is None: - # Random splitting - if split_test + split_val > 0: - train_indices, val_test_indices = train_test_split( - sample_idx, - test_size=split_val + split_test, - random_state=split_seed, - ) - sub_split_test = split_test / (split_test + split_val) - else: - train_indices = sample_idx - val_test_indices = np.array([]) - sub_split_test = 0 - - if split_test > 0: - val_indices, test_indices = train_test_split( - val_test_indices, - test_size=sub_split_test, - random_state=split_seed, - ) - else: - val_indices = val_test_indices - test_indices = np.array([]) - - else: + if splits_path is not None: train, val, test = split_names if isinstance(splits_path, (Dict, pd.DataFrame)): # Split from a dataframe @@ -1685,6 +1665,70 @@ def _get_split_indices( test_indices = np.asarray(splits[test]).astype("int") test_indices = test_indices[~np.isnan(test_indices)].tolist() + elif split_type == "scaffold" and split_test != 1.: + # Scaffold splitting + try: + import splito + except ImportError as error: + raise RuntimeError( + f"To do the splitting, `splito` needs to be installed. 
" + f"Please install it with `pip install splito`" + ) from error + + # Split data into scaffolds + splitter = splito.ScaffoldSplit( + smiles=self.smiles, + test_size=split_test, + random_state=split_seed, + ) + train_val_indices, test_indices = next(splitter.split(X=self.smiles)) + train_val_smiles = [self.smiles[i] for i in train_val_indices] + + sub_split_val = split_val / (1 - split_test) + + splitter = splito.ScaffoldSplit( + smiles=train_val_smiles, + test_size=sub_split_val, + random_state=split_seed, + ) + train_indices, val_indices = next(splitter.split(X=train_val_smiles)) + + else: + if split_type != "random": + logger.warning(f"Unkown split {split_type}. Defaulting to `random`.") + + # Random splitting + if split_test + split_val > 0: + if split_test == 1.: + train_indices = np.array([]) + val_test_indices = sample_idx + sub_split_test = 1. + else: + train_indices, val_test_indices = train_test_split( + sample_idx, + test_size=split_val + split_test, + random_state=split_seed, + ) + sub_split_test = split_test / (split_test + split_val) + else: + train_indices = sample_idx + val_test_indices = np.array([]) + sub_split_test = 0 + + if split_test > 0: + if split_test == 1.: + val_indices = np.array([]) + test_indices = val_test_indices + else: + val_indices, test_indices = train_test_split( + val_test_indices, + test_size=sub_split_test, + random_state=split_seed, + ) + else: + val_indices = val_test_indices + test_indices = np.array([]) + # Filter train, val and test indices _, train_idx, _ = np.intersect1d(sample_idx, train_indices, return_indices=True) train_indices = train_idx.tolist() @@ -2022,7 +2066,7 @@ def _get_ogb_metadata(self): return ogb_metadata -class ADMETBenchmarkDataModule(MultitaskFromSmilesDataModule): +class TDCBenchmarkDataModule(MultitaskFromSmilesDataModule): """ Wrapper to use the ADMET benchmark group from the TDC (Therapeutics Data Commons). diff --git a/graphium/finetuning/finetuning.py b/graphium/finetuning/finetuning.py index 59902a921..4100218c5 100644 --- a/graphium/finetuning/finetuning.py +++ b/graphium/finetuning/finetuning.py @@ -17,10 +17,10 @@ from collections import OrderedDict import torch.nn as nn -import lightning.pytorch as pl +import pytorch_lightning as pl from torch.optim.optimizer import Optimizer -from lightning.pytorch.callbacks import BaseFinetuning +from pytorch_lightning.callbacks import BaseFinetuning class GraphFinetuning(BaseFinetuning): @@ -29,7 +29,8 @@ def __init__( finetuning_module: str, added_depth: int = 0, unfreeze_pretrained_depth: Optional[int] = None, - epoch_unfreeze_all: int = 0, + epoch_unfreeze_all: Optional[int] = 0, + always_freeze_modules: Optional[Union[List, str]] = None, train_bn: bool = False, ): """ @@ -41,6 +42,7 @@ def __init__( added_depth: Number of layers of finetuning module that have been modified rel. to pretrained model unfreeze_pretrained_depth: Number of additional layers to unfreeze before layers modified rel. 
             epoch_unfreeze_all: Epoch to unfreeze entire model
+            always_freeze_modules: Modules that always stay frozen during finetuning
             train_bn: Boolean value indicating if batchnorm layers stay in training mode
         """
 
@@ -51,6 +53,11 @@ def __init__(
         if unfreeze_pretrained_depth is not None:
             self.training_depth += unfreeze_pretrained_depth
         self.epoch_unfreeze_all = epoch_unfreeze_all
+        self.always_freeze_modules = always_freeze_modules
+        if self.always_freeze_modules == 'none':
+            self.always_freeze_modules = None
+        if isinstance(self.always_freeze_modules, str):
+            self.always_freeze_modules = [self.always_freeze_modules]
         self.train_bn = train_bn
 
     def freeze_before_training(self, pl_module: pl.LightningModule):
@@ -105,3 +112,7 @@ def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer
         """
         if epoch == self.epoch_unfreeze_all:
             self.unfreeze_and_add_param_group(modules=pl_module, optimizer=optimizer, train_bn=self.train_bn)
+
+        if self.always_freeze_modules is not None:
+            for module_name in self.always_freeze_modules:
+                self.freeze_module(pl_module, module_name, pl_module.model.pretrained_model.net._module_map)
\ No newline at end of file
diff --git a/graphium/finetuning/finetuning_architecture.py b/graphium/finetuning/finetuning_architecture.py
index 4b0de1607..e0b3fbe92 100644
--- a/graphium/finetuning/finetuning_architecture.py
+++ b/graphium/finetuning/finetuning_architecture.py
@@ -21,6 +21,7 @@ from graphium.nn.utils import MupMixin
 
 from graphium.trainer.predictor import PredictorModule
+from graphium.utils.spaces import FINETUNING_HEADS_DICT
 
 
 class FullGraphFinetuningNetwork(nn.Module, MupMixin):
@@ -308,8 +309,6 @@ def __init__(self, finetuning_head_kwargs: Dict[str, Any]):
 
         """
 
-        from graphium.utils.spaces import FINETUNING_HEADS_DICT  # Avoiding circular imports with `spaces.py`
-
         super().__init__()
         self.task = finetuning_head_kwargs.pop("task", None)
         self.previous_module = finetuning_head_kwargs.pop("previous_module", "task_heads")
@@ -346,4 +345,4 @@ def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool =
         """
 
         # For the post-nn network, all the dimension are divided
-        return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
+        return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
\ No newline at end of file
diff --git a/graphium/finetuning/utils.py b/graphium/finetuning/utils.py
index 7b9f7df74..2ef440bc8 100644
--- a/graphium/finetuning/utils.py
+++ b/graphium/finetuning/utils.py
@@ -22,16 +22,34 @@ import graphium
 
 
+def filter_cfg_for_custom_task(config: Dict[str, Any], task: str, task_type: str):
+    """
+    Filter a base config for the task type (regression vs. 
classification) + """ + + cfg = deepcopy(config) + + # Filter the relevant config sections + if "predictor" in cfg and "metrics_on_progress_bar" in cfg["predictor"]: + cfg["predictor"]["metrics_on_progress_bar"] = {task: cfg["predictor"]["metrics_on_progress_bar"][task_type]} + if "predictor" in cfg and "loss_fun" in cfg["predictor"]: + cfg["predictor"]["loss_fun"] = {task: cfg["predictor"]["loss_fun"][task_type]} + if "metrics" in cfg: + cfg["metrics"] = {task: cfg["metrics"][task_type]} + + return cfg + + def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], names: Union[List[str], str]): """ Filter a base config for the full TDC ADMET benchmarking group to only have settings related to a subset of the endpoints """ - if config["datamodule"]["module_type"] != "ADMETBenchmarkDataModule": + if config["datamodule"]["module_type"] != "TDCBenchmarkDataModule": # NOTE (cwognum): For now, this implies we only support the ADMET benchmark from TDC. # It is easy to extend this in the future to support more datasets. - raise ValueError("You can only use this method for the `ADMETBenchmarkDataModule`") + raise ValueError("You can only use this method for the `TDCBenchmarkDataModule`") if isinstance(names, str): names = [names] @@ -61,17 +79,45 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): """ Function combining information from configuration and pretrained model for finetuning. """ - task = cfg["finetuning"]["task"] - # Filter the config based on the task name - # NOTE (cwognum): This prevents the need for having many different files for each of the tasks - # with lots and lots of config repetition. - cfg = filter_cfg_based_on_admet_benchmark_name(cfg, task) + benchmark = cfg["constants"].pop("benchmark", None) + task_type = cfg["constants"].pop("task_type", None) + task = cfg["finetuning"].get("task", "task") + + if benchmark == "custom" and task_type is not None: + cfg = filter_cfg_for_custom_task(cfg, task, task_type) + else: + # Filter the config based on the task name + # NOTE (cwognum): This prevents the need for having many different files for each of the tasks + # with lots and lots of config repetition. 
+ cfg = filter_cfg_based_on_admet_benchmark_name(cfg, task) + cfg_finetune = cfg["finetuning"] # Load pretrained model pretrained_model = cfg_finetune["pretrained_model"] - pretrained_predictor = PredictorModule.load_pretrained_model(pretrained_model, device="cpu") + if isinstance(pretrained_model, dict): + mode = pretrained_model.get('mode') + size = pretrained_model.get('size') + model = pretrained_model.get('model') + pretraining_seed = pretrained_model.get('pretraining_seed') + if mode == 'width': + size = size[:4] + elif mode == 'depth': + size = size[:2] + elif mode == 'molecule': + size = size[:3] + elif mode == 'label': + size = size[:3] + elif mode == 'ablation': + size = f"_{size}" + pretrained_model_name = f"{mode}{size}_{model}_s{pretraining_seed}" + + else: + pretrained_model_name = pretrained_model + + cfg_finetune["pretrained_model"] = pretrained_model_name + pretrained_predictor = PredictorModule.load_pretrained_model(pretrained_model_name, device="cpu") # Inherit shared configuration from pretrained # Architecture @@ -123,6 +169,10 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): # Update config new_module_kwargs.update(upd_kwargs) + depth, hidden_dims = new_module_kwargs['depth'], new_module_kwargs['hidden_dims'] + if isinstance(hidden_dims, list): + if len(hidden_dims) != depth: + new_module_kwargs['hidden_dims'] = (depth - 1) * [hidden_dims[0]] if sub_module_from_pretrained is None: cfg_arch[finetuning_module] = new_module_kwargs @@ -151,6 +201,13 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): # Change architecture to FullGraphFinetuningNetwork cfg_arch["model_type"] = "FullGraphFinetuningNetwork" + def _change_dropout(c): + if isinstance(c, dict): + return {k: (_change_dropout(v) if k != 'dropout' else cfg['constants']['model_dropout']) for k, v in c.items()} + return c + + cfg_arch = _change_dropout(cfg_arch) + cfg["architecture"] = cfg_arch pretrained_overwriting_kwargs = deepcopy(cfg["finetuning"]) @@ -160,6 +217,7 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): "finetuning_head", "unfreeze_pretrained_depth", "epoch_unfreeze_all", + "always_freeze_modules", ] for key in drop_keys: @@ -213,4 +271,4 @@ def update_cfg_arch_for_module( ) cfg_arch[module_name].update({new_sub_module: cfg_arch_from_pretrained[module_name][sub_module]}) - return cfg_arch + return cfg_arch \ No newline at end of file diff --git a/graphium/fingerprinting/__init__.py b/graphium/fingerprinting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphium/fingerprinting/data.py b/graphium/fingerprinting/data.py new file mode 100644 index 000000000..76856a6a9 --- /dev/null +++ b/graphium/fingerprinting/data.py @@ -0,0 +1,252 @@ +from typing import Any, List, Dict, Literal, Union + +import os + +import torch + +import pandas as pd +import numpy as np + +from pytorch_lightning import LightningDataModule + +from torch.utils.data import Dataset, DataLoader + +from graphium.data.datamodule import BaseDataModule, MultitaskFromSmilesDataModule, TDCBenchmarkDataModule, DatasetProcessingParams +from graphium.trainer.predictor import PredictorModule +from graphium.fingerprinting.fingerprinter import Fingerprinter + + +class FingerprintDataset(Dataset): + """ + Dataset class for fingerprints useful for probing experiments. + + Parameters: + labels: Labels for the dataset. + fingerprints: Dictionary of fingerprints, where keys specify model and layer of extraction. + smiles: List of SMILES strings. 
+    """
+    def __init__(
+        self,
+        labels: torch.Tensor,
+        fingerprints: Dict[str, torch.Tensor],
+        smiles: List[str] = None,
+    ):
+        self.labels = labels
+        self.fingerprints = fingerprints
+        self.smiles = smiles
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __getitem__(self, index):
+        fp_list = []
+        for val in self.fingerprints.values():
+            fp_list.append(val[index])
+
+        if self.smiles is not None:
+            return fp_list, self.labels[index], self.smiles[index]
+        else:
+            return fp_list, self.labels[index]
+
+
+class FingerprintDatamodule(LightningDataModule):
+    """
+    DataModule class for extracting fingerprints from one (or multiple) pretrained model(s).
+
+    Parameters:
+        pretrained_models: Dictionary of pretrained models (keys) and lists of layers (values) to use, respectively.
+        task: Task to extract fingerprints for.
+        benchmark: Benchmark to extract fingerprints for.
+        df_path: Path to the DataFrame containing the SMILES strings.
+        batch_size: Batch size for fingerprint extraction (i.e., the forward passes of the pretrained models).
+        split_type: Type of split to use for the dataset.
+        splits_path: Path to the splits file.
+        split_val: Fraction of validation data.
+        split_test: Fraction of test data.
+        data_seed: Seed for data splitting.
+        num_workers: Number of workers for data loading.
+        device: Device to use for fingerprint extraction.
+        mol_cache_dir: Directory to cache the molecules in.
+        fps_cache_dir: Directory to cache the fingerprints in.
+
+    """
+    def __init__(
+        self,
+        pretrained_models: Dict[str, List[str]],
+        task: str = "herg",
+        benchmark: Literal["tdc", None] = "tdc",
+        df_path: str = None,
+        batch_size: int = 64,
+        split_type: Literal["random", "scaffold"] = "random",
+        splits_path: str = None,
+        split_val: float = 0.1,
+        split_test: float = 0.1,
+        data_seed: int = 42,
+        num_workers: int = 0,
+        device: str = "cpu",
+        mol_cache_dir: str = "./expts/data/cache",
+        fps_cache_dir: str = "./expts/data/cache",
+    ):
+        super().__init__()
+
+        assert benchmark is not None or df_path is not None, "Either benchmark or df_path must be provided"
+
+        self.pretrained_models = pretrained_models
+        self.task = task
+        self.benchmark = benchmark
+        self.df_path = df_path
+        self.batch_size = batch_size
+        self.split_type = split_type
+        self.splits_path = splits_path
+        self.split_val = split_val
+        self.split_test = split_test
+        self.data_seed = data_seed
+        self.num_workers = num_workers
+        self.device = device
+        self.mol_cache_dir = mol_cache_dir
+        self.fps_cache_dir = fps_cache_dir
+        if benchmark is not None:
+            # Check if benchmark naming is already implied in config
+            if f"{benchmark}/{task}" not in mol_cache_dir:
+                self.mol_cache_dir = f"{mol_cache_dir}/{benchmark}/{task}"
+            if f"{benchmark}/{task}" not in fps_cache_dir:
+                self.fps_cache_dir = f"{fps_cache_dir}/{benchmark}/{task}"
+
+        self.train_dataset = None
+        self.valid_dataset = None
+        self.test_dataset = None
+
+        self.splits = []
+
+    def prepare_data(self) -> None:
+        if self.fps_cache_dir is not None and os.path.exists(f"{self.fps_cache_dir}/fps.pt"):
+            self.smiles, self.labels, self.fps_dict = torch.load(f"{self.fps_cache_dir}/fps.pt").values()
+            self.splits = list(self.smiles.keys())
+
+        else:
+            # Check which splits are needed
+            self.splits = []
+            add_all = self.benchmark is not None or self.splits_path is not None
+            if add_all or self.split_val + self.split_test < 1:
+                self.splits.append("train")
+            if add_all or self.split_val > 0:
+                self.splits.append("valid")
+            if add_all or self.split_test > 0:
+                self.splits.append("test")
+
+            
self.data = { + "smiles": {split: [] for split in self.splits}, + "labels": {split: [] for split in self.splits}, + "fps": {split: {} for split in self.splits}, + } + + for model, layers in self.pretrained_models.items(): + predictor = PredictorModule.load_pretrained_model(model, device=self.device) + predictor.featurization.pop("max_num_atoms", None) + + # Featurization + if self.benchmark is None: + assert self.df_path is not None, "df_path must be provided if not using an integrated benchmark" + + # Add a dummy task column (filled with NaN values) in case no such column is provided + base_datamodule = BaseDataModule() + smiles_df = base_datamodule._read_table(self.df_path) + task_cols = [col for col in smiles_df if col.startswith("task_")] + if len(task_cols) == 0: + df_path, file_type = ".".join(self.df_path.split(".")[:-1]), self.df_path.split(".")[-1] + + smiles_df["task_dummy"] = np.nan + + if file_type == "parquet": + smiles_df.to_parquet(f"{df_path}_with_dummy_task_col.{file_type}", index=False) + else: + smiles_df.to_csv(f"{df_path}_with_dummy_task_col.{file_type}", index=False) + + self.df_path = f"{df_path}_with_dummy_task_col.{file_type}" + + task_specific_args = { + "fingerprinting": DatasetProcessingParams( + df_path=self.df_path, + smiles_col="smiles", + label_cols="task_*", + task_level="graph", + splits_path=self.splits_path, + split_type=self.split_type, + split_val=self.split_val, + split_test=self.split_test, + seed=self.data_seed, + ) + } + label_key = "graph_fingerprinting" + + datamodule = MultitaskFromSmilesDataModule( + task_specific_args=task_specific_args, + batch_size_inference=128, + featurization=predictor.featurization, + featurization_n_jobs=0, + processed_graph_data_path=f"{self.mol_cache_dir}/mols/", + ) + + elif self.benchmark == "tdc": + datamodule = TDCBenchmarkDataModule( + tdc_benchmark_names=[self.task], + tdc_train_val_seed=self.data_seed, + batch_size_inference=128, + featurization=predictor.featurization, + featurization_n_jobs=self.num_workers, + processed_graph_data_path=f"{self.mol_cache_dir}/mols/", + ) + label_key = f"graph_{self.task}" + + else: + raise ValueError(f"Invalid benchmark: {self.benchmark}") + + datamodule.prepare_data() + datamodule.setup() + + loader_dict = {} + if "train" in self.splits: + datamodule.train_ds.return_smiles = True + loader_dict["train"] = datamodule.get_dataloader(datamodule.train_ds, shuffle=False, stage="predict") + if "valid" in self.splits: + datamodule.val_ds.return_smiles = True + loader_dict["valid"] = datamodule.get_dataloader(datamodule.val_ds, shuffle=False, stage="predict") + if "test" in self.splits: + datamodule.test_ds.return_smiles = True + loader_dict["test"] = datamodule.get_dataloader(datamodule.test_ds, shuffle=False, stage="predict") + + for split, loader in loader_dict.items(): + if len(self.data["smiles"][split]) == 0: + for batch in loader: + self.data["smiles"][split] += [item for item in batch["smiles"]] + self.data["labels"][split] += batch["labels"][label_key] + + with Fingerprinter(predictor, layers, out_type="torch") as fp: + fps = fp.get_fingerprints_for_dataset(loader, store_dict=True) + for fp_name, fp in fps.items(): + self.data["fps"][split][f"{model}/{fp_name}"] = fp + + os.makedirs(self.fps_cache_dir, exist_ok=True) + torch.save(self.data, f"{self.fps_cache_dir}/fps.pt") + + def setup(self, stage: str) -> None: + # Creating datasets + if stage == "fit": + self.train_dataset = FingerprintDataset(self.labels["train"], self.fps_dict["train"]) + self.valid_dataset = 
FingerprintDataset(self.labels["valid"], self.fps_dict["valid"]) + else: + self.test_dataset = FingerprintDataset(self.labels["test"], self.fps_dict["test"]) + + def get_fp_dims(self): + fp_dict = next(iter(self.fps_dict.values())) + + return [fp.size(1) for fp in fp_dict.values()] + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True) + + def val_dataloader(self): + return DataLoader(self.valid_dataset, batch_size=len(self.valid_dataset), shuffle=False) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=len(self.test_dataset), shuffle=False) \ No newline at end of file diff --git a/graphium/finetuning/fingerprinting.py b/graphium/fingerprinting/fingerprinter.py similarity index 79% rename from graphium/finetuning/fingerprinting.py rename to graphium/fingerprinting/fingerprinter.py index 8bfdb5d94..f29023145 100644 --- a/graphium/finetuning/fingerprinting.py +++ b/graphium/fingerprinting/fingerprinter.py @@ -138,8 +138,9 @@ def setup(self): self.network._enable_readout_cache(list(self._spec.keys())) return self - def get_fingerprints_for_batch(self, batch): + def get_fingerprints_for_batch(self, batch, store_dict: bool=False): """Get the fingerprints for a single batch""" + self.network.eval() if not self.network._cache_readouts: raise RuntimeError( @@ -152,18 +153,31 @@ def get_fingerprints_for_batch(self, batch): with torch.inference_mode(): if self.predictor is not None: batch["features"] = self.predictor._convert_features_dtype(batch["features"]) + device = next(iter(self.network.parameters())).device + for key, val in batch["features"].items(): + if isinstance(val, torch.Tensor): + batch["features"][key] = val.to(device) self.network(batch["features"]) - readout_list = [] - for module_name, layers in self._spec.items(): - readout_list.extend( - [self.network._module_map[module_name]._readout_cache[layer].cpu() for layer in layers] - ) - - feats = torch.cat(readout_list, dim=-1) - return self._convert_output_type(feats) + if store_dict: + readout_dict = {} + for module_name, layers in self._spec.items(): + for layer in layers: + readout_dict[f"{module_name}:{layer}"] = self._convert_output_type(self.network._module_map[module_name]._readout_cache[layer].cpu()) - def get_fingerprints_for_dataset(self, dataloader): + return readout_dict + + else: + readout_list = [] + for module_name, layers in self._spec.items(): + readout_list.extend( + [self.network._module_map[module_name]._readout_cache[layer].cpu() for layer in layers] + ) + + feats = torch.cat(readout_list, dim=-1) + return self._convert_output_type(feats) + + def get_fingerprints_for_dataset(self, dataloader, store_dict: bool=False): """Return the fingerprints for an entire dataset""" original_out_type = self._out_type @@ -171,13 +185,29 @@ def get_fingerprints_for_dataset(self, dataloader): fps = [] for batch in tqdm.tqdm(dataloader, desc="Fingerprinting batches"): - feats = self.get_fingerprints_for_batch(batch) + feats = self.get_fingerprints_for_batch(batch, store_dict=store_dict) fps.append(feats) - fps = torch.cat(fps, dim=0) - self._out_type = original_out_type - return self._convert_output_type(fps) + + if store_dict: + fps_dict = fps[0] + for key, value in fps_dict.items(): + fps_dict[key] = [value] + for item in fps[1:]: + for key, value in item.items(): + fps_dict[key].extend([value]) + + self._out_type = original_out_type + for key, value in fps_dict.items(): + fps_dict[key] = 
self._convert_output_type(torch.cat(value, dim=0)) + + return fps_dict + + else: + fps = torch.cat(fps, dim=0) + + return self._convert_output_type(fps) def teardown(self): """Restore the network to its original state""" @@ -202,4 +232,4 @@ def _convert_output_type(self, feats: torch.Tensor): """Small utility function to convert output types""" if self._out_type == "numpy": feats = feats.numpy() - return feats + return feats \ No newline at end of file diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index c7b6a7ac9..6753f2ed4 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -131,7 +131,7 @@ DATAMODULE_DICT = { "GraphOGBDataModule": Datamodules.GraphOGBDataModule, "MultitaskFromSmilesDataModule": Datamodules.MultitaskFromSmilesDataModule, - "ADMETBenchmarkDataModule": Datamodules.ADMETBenchmarkDataModule, + "TDCBenchmarkDataModule": Datamodules.TDCBenchmarkDataModule, } GRAPHIUM_PRETRAINED_MODELS_DICT = { diff --git a/notebooks/finetuning-on-tdc-admet-benchmark.ipynb b/notebooks/finetuning-on-tdc-admet-benchmark.ipynb index 43eb47081..69ec83bd5 100644 --- a/notebooks/finetuning-on-tdc-admet-benchmark.ipynb +++ b/notebooks/finetuning-on-tdc-admet-benchmark.ipynb @@ -47,7 +47,7 @@ "\n", "[TDC](https://tdcommons.ai/) hosts a variety of ML-ready datasets and benchmarks for ML for drug discovery. The [TDC ADMET benchmarking group](https://tdcommons.ai/benchmark/admet_group/overview/) is a popular collection of benchmarks for evaluating new _foundation models_ (see e.g. [MolE](https://arxiv.org/abs/2211.02657)) due to the variety and relevance of the included tasks.\n", "\n", - "The ADMET benchmarking group is integrated in `graphium` through the `ADMETBenchmarkDataModule` data-module. This notebook shows how to easily fine-tune and test a model using that data-module. \n", + "The ADMET benchmarking group is integrated in `graphium` through the `TDCBenchmarkDataModule` data-module. This notebook shows how to easily fine-tune and test a model using that data-module. \n", "\n", "
\n", " NOTE: This notebook is still work in progress. While the fine-tuning logic is unfinished, the notebook does demo how one could use the data-module to easily loop over each of the datasets in the benchmarking group and get the prescribed train-test split. Once the fine-tuning logic is finalized, we will finish this notebook and officially provide it as a tutorial within Graphium. \n", @@ -59,7 +59,20 @@ "execution_count": 3, "id": "4d5af838", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '../expts/configs/config_tdc_admet_demo.yaml'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# First, let's read the yaml configuration file\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m../expts/configs/config_tdc_admet_demo.yaml\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m file:\n\u001b[1;32m 3\u001b[0m config \u001b[38;5;241m=\u001b[39m yaml\u001b[38;5;241m.\u001b[39mload(file, Loader\u001b[38;5;241m=\u001b[39myaml\u001b[38;5;241m.\u001b[39mFullLoader)\n", + "File \u001b[0;32m~/miniconda3/envs/graphium3/lib/python3.12/site-packages/IPython/core/interactiveshell.py:324\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. 
If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 322\u001b[0m )\n\u001b[0;32m--> 324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../expts/configs/config_tdc_admet_demo.yaml'" + ] + } + ], "source": [ "# First, let's read the yaml configuration file\n", "with open(\"../expts/configs/config_tdc_admet_demo.yaml\", \"r\") as file:\n", @@ -125,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "9538abfb", "metadata": {}, "outputs": [], @@ -173,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "1ee4586e", "metadata": {}, "outputs": [], @@ -184,8 +197,8 @@ " have settings related to a subset of the endpoints\n", " \"\"\"\n", " \n", - " if config[\"datamodule\"][\"module_type\"] != \"ADMETBenchmarkDataModule\":\n", - " raise ValueError(\"You can only use this method for the `ADMETBenchmarkDataModule`\")\n", + " if config[\"datamodule\"][\"module_type\"] != \"TDCBenchmarkDataModule\":\n", + " raise ValueError(\"You can only use this method for the `TDCBenchmarkDataModule`\")\n", " \n", " if isinstance(names, str):\n", " names = [names]\n", @@ -896,7 +909,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tests/test_multitask_datamodule.py b/tests/test_multitask_datamodule.py index 81a51459f..664334561 100644 --- a/tests/test_multitask_datamodule.py +++ b/tests/test_multitask_datamodule.py @@ -394,7 +394,7 @@ def test_tdc_admet_benchmark_data_module(self): raise # Make sure we can initialize the module and run the main endpoints - data_module = graphium.data.ADMETBenchmarkDataModule() + data_module = graphium.data.TDCBenchmarkDataModule() data_module.prepare_data() data_module.setup()
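
For reference, the `split_type` argument added to `DatasetProcessingParams` and threaded through `_get_split_indices` above can be exercised roughly as follows. This is a minimal sketch and not part of the patch: the task name, CSV path, and label column are placeholders, the scaffold branch assumes `splito` is installed, and an unrecognized `split_type` falls back to random splitting with a warning.

```python
# Sketch only: request scaffold splits for a custom SMILES table.
# "my_task", "path/to/raw.csv", and "my_label" are placeholders.
from graphium.data import MultitaskFromSmilesDataModule
from graphium.data.datamodule import DatasetProcessingParams

task_specific_args = {
    "my_task": DatasetProcessingParams(
        df_path="path/to/raw.csv",
        smiles_col="smiles",
        label_cols=["my_label"],
        task_level="graph",
        split_type="scaffold",  # new argument introduced in this patch; needs `pip install splito`
        split_val=0.1,
        split_test=0.1,
        seed=42,
    )
}

datamodule = MultitaskFromSmilesDataModule(task_specific_args=task_specific_args)
datamodule.prepare_data()
datamodule.setup()
```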
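The new `always_freeze_modules` option on `GraphFinetuning` composes with `epoch_unfreeze_all` roughly as in the sketch below, assuming placeholder module names; valid names come from the pretrained network's module map.

```python
# Sketch only: callback configuration for the new always_freeze_modules option.
# "task_heads" and "pre_nn" are placeholder module names.
from graphium.finetuning.finetuning import GraphFinetuning

finetuning_callback = GraphFinetuning(
    finetuning_module="task_heads",
    added_depth=1,
    epoch_unfreeze_all=10,             # unfreeze the rest of the model at this epoch
    always_freeze_modules=["pre_nn"],  # stays frozen even after epoch_unfreeze_all; "none" disables this
)
# trainer = pytorch_lightning.Trainer(callbacks=[finetuning_callback], ...)
```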
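Finally, the renamed `TDCBenchmarkDataModule` and the relocated `Fingerprinter` (with its new `store_dict` flag) fit together as sketched below, mirroring how `FingerprintDatamodule.prepare_data` wires them up in this patch. The checkpoint identifier and the layer spec are placeholders, not values defined here.

```python
# Sketch only: extract per-layer fingerprints for a TDC task from a pretrained checkpoint.
# "dummy-pretrained-model" and "gnn:15" are placeholders; use a key from
# GRAPHIUM_PRETRAINED_MODELS_DICT and a module:layer spec that exists in that model.
from graphium.data import TDCBenchmarkDataModule
from graphium.trainer.predictor import PredictorModule
from graphium.fingerprinting.fingerprinter import Fingerprinter

predictor = PredictorModule.load_pretrained_model("dummy-pretrained-model", device="cpu")

datamodule = TDCBenchmarkDataModule(
    tdc_benchmark_names=["herg"],
    featurization=predictor.featurization,
)
datamodule.prepare_data()
datamodule.setup()

loader = datamodule.get_dataloader(datamodule.test_ds, shuffle=False, stage="predict")

with Fingerprinter(predictor, ["gnn:15"], out_type="torch") as fp:
    fps = fp.get_fingerprints_for_dataset(loader, store_dict=True)
```

With `store_dict=True` the result is a dictionary keyed by `module:layer`; the default behaviour, a single concatenated tensor, is unchanged.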