From 3de39929f82a00e086fe97458fcb07adcdb19d48 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 24 Nov 2023 17:01:32 +0100 Subject: [PATCH 01/17] change logger class --- configs/training/default_trainer.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml index 44af7f64..68882295 100644 --- a/configs/training/default_trainer.yml +++ b/configs/training/default_trainer.yml @@ -2,7 +2,8 @@ min_epochs: 100 max_epochs: 100 default_root_dir: &default_root_dir logs logger: - class_path: lightning.pytorch.loggers.CSVLogger + class_path: lightning.pytorch.loggers.WandbLogger init_args: save_dir: *default_root_dir + project: 'chebai' callbacks: default_callbacks.yml \ No newline at end of file From b8949864566f1c1dbdbc02441985000d443f011b Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 24 Nov 2023 17:31:00 +0100 Subject: [PATCH 02/17] set entity --- configs/training/default_trainer.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml index 68882295..4def2ed3 100644 --- a/configs/training/default_trainer.yml +++ b/configs/training/default_trainer.yml @@ -6,4 +6,5 @@ logger: init_args: save_dir: *default_root_dir project: 'chebai' + entity: 'chebai' callbacks: default_callbacks.yml \ No newline at end of file From c0668136ebf5f6381af805fd6f22fc8a1fb08289 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 09:55:31 +0100 Subject: [PATCH 03/17] process results for classification model --- chebai/result/classification.py | 49 +++++++++++ chebai/result/pretraining.py | 12 +-- demo_process_results.ipynb | 146 +++++++++++++++++++++++++++++--- 3 files changed, 189 insertions(+), 18 deletions(-) create mode 100644 chebai/result/classification.py diff --git a/chebai/result/classification.py b/chebai/result/classification.py new file mode 100644 index 00000000..d0487aea --- /dev/null +++ b/chebai/result/classification.py @@ -0,0 +1,49 @@ +from torchmetrics.classification import MultilabelF1Score + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import os +import chebai.models.electra as electra +from chebai.loss.pretraining import ElectraPreLoss +import torch +import tqdm + + +def visualise_f1(logs_path): + df = pd.read_csv(os.path.join(logs_path, 'metrics.csv')) + df_loss = df.melt(id_vars='epoch', value_vars=['val_ep_macro-f1', 'val_micro-f1', 'train_micro-f1', + 'train_ep_macro-f1']) + lineplt = sns.lineplot(df_loss, x='epoch', y='value', hue='variable') + plt.savefig(os.path.join(logs_path, 'f1_plot.png')) + plt.show() + +# get predictions from model +def evaluate_model(logs_base_path, model_filename, data_module): + model = electra.Electra.load_from_checkpoint( + os.path.join(logs_base_path, 'best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt', model_filename)) + assert isinstance(model, electra.Electra) + collate = data_module.reader.COLLATER() + test_file = 'test.pt' + data_path = os.path.join(data_module.processed_dir, test_file) + data_list = torch.load(data_path) + preds_list = [] + labels_list = [] + + for row in tqdm.tqdm(data_list): + processable_data = model._process_batch(collate([row]), 0) + model_output = model(processable_data, **processable_data['model_kwargs']) + preds, labels = model._get_prediction_and_labels(processable_data, processable_data["labels"], model_output) + preds_list.append(preds) + labels_list.append(labels) + + test_preds = torch.cat(preds_list) + test_labels = torch.cat(labels_list) + print(test_preds.shape) + print(test_labels.shape) + test_loss = ElectraPreLoss() + print(f'Loss on test set: {test_loss(test_preds, test_labels)}') + f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro') + f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro') + print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}') + print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}') diff --git a/chebai/result/pretraining.py b/chebai/result/pretraining.py index 5392e61e..e8203cec 100644 --- a/chebai/result/pretraining.py +++ b/chebai/result/pretraining.py @@ -1,5 +1,3 @@ -from torchmetrics.classification import MultilabelF1Score - from chebai.result.base import ResultProcessor import pandas as pd import seaborn as sns @@ -7,22 +5,20 @@ import os import chebai.models.electra as electra from chebai.loss.pretraining import ElectraPreLoss -from chebai.preprocessing.datasets.pubchem import PubChemDeepSMILES import torch import tqdm def visualise_loss(logs_path): - df = pd.read_csv(os.path.join('logs_server', 'pubchem_pretraining', 'version_6', 'metrics.csv')) + df = pd.read_csv(os.path.join(logs_path, 'metrics.csv')) df_loss = df.melt(id_vars='epoch', value_vars=['val_loss_epoch', 'train_loss_epoch']) lineplt = sns.lineplot(df_loss, x='epoch', y='value', hue='variable') - plt.savefig(os.path.join(logs_path, 'loss')) + plt.savefig(os.path.join(logs_path, 'f1_plot.png')) plt.show() # get predictions from model -def evaluate_model(logs_base_path, model_filename): - data_module = PubChemDeepSMILES(chebi_version=227) - model = electra.ElectraPre.load_from_checkpoint(os.path.join(logs_base_path, 'checkpoints', model_filename)) +def evaluate_model(logs_base_path, model_filename, data_module): + model = electra.ElectraPre.load_from_checkpoint(os.path.join(logs_base_path, 'best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt', model_filename)) assert isinstance(model, electra.ElectraPre) collate = data_module.reader.COLLATER() test_file = 'test.pt' diff --git a/demo_process_results.ipynb b/demo_process_results.ipynb index 678c0dc6..8d1637a6 100644 --- a/demo_process_results.ipynb +++ b/demo_process_results.ipynb @@ -2,12 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2023-11-24T09:39:01.737866400Z", - "start_time": "2023-11-24T09:38:54.246379800Z" + "end_time": "2023-11-29T08:17:25.832642900Z", + "start_time": "2023-11-29T08:17:25.816890700Z" } }, "outputs": [], @@ -26,6 +26,7 @@ "from torchmetrics.classification import MultilabelF1Score\n", "import numpy as np\n", "from chebai.result import pretraining as eval_pre\n", + "from chebai.result import classification\n", "from chebai.preprocessing.datasets.pubchem import PubChemDeepSMILES" ] }, @@ -86,7 +87,7 @@ "logs_path = os.path.join('logs_server', 'pubchem_pretraining', 'version_6')\n", "checkpoint_name = 'best_epoch=88_val_loss=0.7713_val_micro-f1=0.00.ckpt'\n", "eval_pre.visualise_loss(logs_path)\n", - "eval_pre.evaluate_model(logs_path, checkpoint_name)\n", + "eval_pre.evaluate_model(logs_path, checkpoint_name, PubChemDeepSMILES(chebi_version=227))\n", "#todo: run on server" ], "metadata": { @@ -97,6 +98,134 @@ } } }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " val_micro-f1 epoch\n", + "3339 0.744479 0.0\n", + "6680 0.781085 1.0\n", + "10021 0.800341 2.0\n", + "13362 0.815271 3.0\n", + "16703 0.824292 4.0\n", + "20044 0.831470 5.0\n", + "23385 0.835933 6.0\n", + "26726 0.841620 7.0\n", + "30067 0.844065 8.0\n", + "33408 0.849987 9.0\n", + "36749 0.852032 10.0\n", + "40090 0.853078 11.0\n", + "43431 0.857751 12.0\n", + "46772 0.860396 13.0\n", + "50113 0.860696 14.0\n", + "53454 0.862808 15.0\n", + "56795 0.864307 16.0\n", + "60136 0.866988 17.0\n", + "63477 0.867551 18.0\n", + "66818 0.868336 19.0\n", + "70159 0.871684 20.0\n", + "73500 0.871839 21.0\n", + "76841 0.873981 22.0\n", + "80182 0.873941 23.0\n", + "83523 0.875990 24.0\n", + "86864 0.878374 25.0\n", + "90205 0.876602 26.0\n", + "93546 0.877337 27.0\n", + "96887 0.879710 28.0\n", + "100228 0.880704 29.0\n", + "103569 0.881911 30.0\n", + "106910 0.881255 31.0\n", + "110251 0.882252 32.0\n", + "113592 0.884888 33.0\n", + "116933 0.884645 34.0\n", + "120274 0.885552 35.0\n", + "123615 0.885834 36.0\n", + "126956 0.886494 37.0\n", + "130297 0.887407 38.0\n", + "133638 0.886909 39.0\n", + "136979 0.889023 40.0\n", + "140320 0.889971 41.0\n", + "143661 0.889554 42.0\n", + "147002 0.890000 43.0\n", + "150343 0.889108 44.0\n", + "153684 0.890725 45.0\n", + "157025 0.892501 46.0\n", + "160366 0.891453 47.0\n", + "163707 0.891795 48.0\n", + "167048 0.892051 49.0\n", + "170389 0.893556 50.0\n", + "173730 0.893885 51.0\n", + "177071 0.894064 52.0\n", + "180412 0.894290 53.0\n", + "183753 0.894865 54.0\n", + "187094 0.895441 55.0\n", + "190435 0.894336 56.0\n", + "193776 0.896212 57.0\n", + "197117 0.897314 58.0\n", + "200458 0.898514 59.0\n", + "203799 0.896542 60.0\n", + "207140 0.896189 61.0\n", + "210481 0.897382 62.0\n", + "213822 0.897519 63.0\n", + "217163 0.898333 64.0\n", + "220504 0.899299 65.0\n", + "223845 0.898777 66.0\n", + "227186 0.900293 67.0\n", + "230527 0.899204 68.0\n", + "233868 0.898939 69.0\n", + "237209 0.898965 70.0\n", + "240550 0.899227 71.0\n", + "243891 0.899370 72.0\n", + "247232 0.898636 73.0\n", + "250573 0.899938 74.0\n", + "253914 0.900668 75.0\n", + "257255 0.900551 76.0\n", + "260596 0.900878 77.0\n", + "263937 0.900096 78.0\n", + "267278 0.900799 79.0\n", + "270619 0.903599 80.0\n", + "273960 0.902017 81.0\n", + "277301 0.901503 82.0\n", + "280642 0.903193 83.0\n", + "283983 0.902805 84.0\n", + "287324 0.904095 85.0\n", + "290665 0.902085 86.0\n", + "294006 0.904408 87.0\n", + "297347 0.902922 88.0\n", + "300688 0.904071 89.0\n", + "304029 0.903767 90.0\n", + "307370 0.902834 91.0\n", + "310711 0.903752 92.0\n", + "314052 0.902673 93.0\n", + "317393 0.903318 94.0\n", + "320734 0.904315 95.0\n", + "324075 0.904089 96.0\n", + "327416 0.903087 97.0\n", + "330757 0.903740 98.0\n", + "334098 0.903542 99.0\n" + ] + } + ], + "source": [ + "logs_path = os.path.join('logs_server', 'chebi100_bce_unweighted_deepsmiles', 'version_12')\n", + "#classification.visualise_f1(logs_path)\n", + "df = pd.read_csv(os.path.join(logs_path, 'metrics.csv'))\n", + "df2 = df[~df['val_micro-f1'].isna()]\n", + "df2 = df2[['val_micro-f1', 'epoch']]\n", + "print(df2.to_string())" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-29T08:33:48.374202Z", + "start_time": "2023-11-29T08:33:48.261436600Z" + } + } + }, { "cell_type": "code", "execution_count": 8, @@ -574,18 +703,15 @@ "- model_v200, test set from Martin:\n", " - Macro-F1 on test set with 854 classes: 0.607030\n", " - Micro-F1 on test set with 854 classes: 0.903165\n", - "- model_v200, test set from me (?) (i dont know where it is from, but it was there and lead to bad results)\n", - "- Macro-F1 on test set with 854 classes: 0.003233\n", - "- Micro-F1 on test set with 854 classes: 0.155634\n", "- model_v200, test set from Martin (only using classes also present in chebi_v148):\n", " - Macro-F1 on test set with 701 classes: 0.623063\n", " - Micro-F1 on test set with 701 classes: 0.905059\n", "- model_v227, full test set\n", " - Macro-F1 on test set with 940 classes: 0.593245\n", " - Micro-F1 on test set with 940 classes: 0.909714\n", - " - evaluate_model(model_v227, data_module_v227)\n", - "Macro-F1 on test set with 854 classes: 0.649757\n", - "Micro-F1 on test set with 854 classes: 0.910626\n", + "- model_v227, test set with classes from v200\n", + " - Macro-F1 on test set with 854 classes: 0.649757\n", + " - Micro-F1 on test set with 854 classes: 0.910626\n", "- model_v148, test set from ?\n", " - Macro-F1 on test set: 0.510064\n", " - Micro-F1 on test set: 0.854736\n", From 27868eebec3d9f68a79b972f66ffc0388927712f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 10:28:41 +0100 Subject: [PATCH 04/17] adapt selfies reader --- chebai/preprocessing/reader.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 92e3935d..bdceb1d8 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -87,7 +87,8 @@ def name(cls): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) dirname = os.path.dirname(__file__) - with open(os.path.join(dirname, "bin", "tokens.txt"), "r") as pk: + self.tokens_path = os.path.join(dirname, "bin", self.name(), "tokens.txt") + with open(self.tokens_path, "r") as pk: self.cache = [x.strip() for x in pk] def _get_token_index(self, token): @@ -104,8 +105,8 @@ def _read_data(self, raw_data): def save_token_cache(self): """write contents of self.cache into tokens.txt""" dirname = os.path.dirname(__file__) - with open(os.path.join(dirname, "bin", "tokens.txt"), "w") as pk: - print(f'saving tokens to {os.path.join(dirname, "bin", "tokens.txt")}...') + with open(self.tokens_path, "w") as pk: + print(f'saving {len(self.cache)} tokens to {self.tokens_path}...') print(f'first 10 tokens: {self.cache[:10]}') pk.writelines([f'{c}\n' for c in self.cache]) @@ -159,30 +160,34 @@ def __init__(self, *args, data_path=None, max_len=1800, vsize=4000, **kwargs): data_path, max_len=max_len ) + + def _get_raw_data(self, row): return self.tokenizer(row["features"])["input_ids"] -class SelfiesReader(DataReader): +class SelfiesReader(ChemDataReader): COLLATER = RaggedCollater def __init__(self, *args, data_path=None, max_len=1800, vsize=4000, **kwargs): super().__init__(*args, **kwargs) - with open("chebai/preprocessing/bin/selfies.txt", "rt") as pk: - self.cache = [l.strip() for l in pk] + self.error_count = 0 @classmethod def name(cls): return "selfies" - def _get_raw_data(self, row): + def _read_data(self, raw_data): try: - splits = sf.split_selfies(sf.encoder(row["features"].strip(), strict=True)) - except Exception as e: - print(e) - return - else: - return [self.cache.index(x) + EMBEDDING_OFFSET for x in splits] + tokenized = sf.split_selfies(sf.encoder(raw_data, strict=True)) + tokenized = [self._get_token_index(v) for v in tokenized] + except ValueError as e: + print(f'could not process {raw_data}') + print(f'\t{e}') + self.error_count += 1 + print(f'\terror count: {self.error_count}') + tokenized = [] + return tokenized class OrdReader(DataReader): From ac4017b908c3f50cb66b8cb509bb4c62af878de8 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 10:31:19 +0100 Subject: [PATCH 05/17] add dir creation --- chebai/preprocessing/reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index bdceb1d8..d95dcab7 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -87,6 +87,7 @@ def name(cls): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) dirname = os.path.dirname(__file__) + os.makedirs(os.path.join(dirname, "bin", self.name()), exist_ok=True) self.tokens_path = os.path.join(dirname, "bin", self.name(), "tokens.txt") with open(self.tokens_path, "r") as pk: self.cache = [x.strip() for x in pk] From 17a8760bde06754145a2692f44f98ec95c66b039 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 10:33:36 +0100 Subject: [PATCH 06/17] add file creation --- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index d95dcab7..48be7bea 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -89,7 +89,7 @@ def __init__(self, *args, **kwargs): dirname = os.path.dirname(__file__) os.makedirs(os.path.join(dirname, "bin", self.name()), exist_ok=True) self.tokens_path = os.path.join(dirname, "bin", self.name(), "tokens.txt") - with open(self.tokens_path, "r") as pk: + with open(self.tokens_path, "r+") as pk: self.cache = [x.strip() for x in pk] def _get_token_index(self, token): From 5f50d8ff437b6c7ae3ee4c84b6adeb92a0d8f832 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 10:39:16 +0100 Subject: [PATCH 07/17] fix selfies reader exception handling --- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 48be7bea..4bb5887a 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -182,7 +182,7 @@ def _read_data(self, raw_data): try: tokenized = sf.split_selfies(sf.encoder(raw_data, strict=True)) tokenized = [self._get_token_index(v) for v in tokenized] - except ValueError as e: + except Exception as e: print(f'could not process {raw_data}') print(f'\t{e}') self.error_count += 1 From fe6c3f096376f5c718b270724550ad3711eff0c4 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 11:10:06 +0100 Subject: [PATCH 08/17] fix selfies reader file creation --- chebai/preprocessing/reader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 4bb5887a..0db6d1a7 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -88,8 +88,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) dirname = os.path.dirname(__file__) os.makedirs(os.path.join(dirname, "bin", self.name()), exist_ok=True) + self.tokens_path = os.path.join(dirname, "bin", self.name(), "tokens.txt") - with open(self.tokens_path, "r+") as pk: + with open(self.tokens_path, "a+") as pk: self.cache = [x.strip() for x in pk] def _get_token_index(self, token): @@ -183,7 +184,7 @@ def _read_data(self, raw_data): tokenized = sf.split_selfies(sf.encoder(raw_data, strict=True)) tokenized = [self._get_token_index(v) for v in tokenized] except Exception as e: - print(f'could not process {raw_data}') + print(f'could not process {raw_data} (type: {type(raw_data)}') print(f'\t{e}') self.error_count += 1 print(f'\terror count: {self.error_count}') From 8bf268a8e8e8d62d6b48f0a6653224ce1e99d6c7 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 11:15:14 +0100 Subject: [PATCH 09/17] fix selfies reader --- chebai/preprocessing/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 0db6d1a7..99abdabb 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -181,10 +181,10 @@ def name(cls): def _read_data(self, raw_data): try: - tokenized = sf.split_selfies(sf.encoder(raw_data, strict=True)) + tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True)) tokenized = [self._get_token_index(v) for v in tokenized] except Exception as e: - print(f'could not process {raw_data} (type: {type(raw_data)}') + print(f'could not process {raw_data}') print(f'\t{e}') self.error_count += 1 print(f'\terror count: {self.error_count}') From 59a57b2cfb8732b56b01559b45817546dd2e76a0 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 29 Nov 2023 11:44:32 +0100 Subject: [PATCH 10/17] fix electra model --- chebai/models/electra.py | 36 +++++++++++++++++----------------- chebai/preprocessing/reader.py | 2 -- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 47069c61..f61052c9 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -24,7 +24,7 @@ ) from chebai.preprocessing.reader import MASK_TOKEN_INDEX, CLS_TOKEN from chebai.preprocessing.datasets.chebi import extract_class_hierarchy -from chebai.loss.pretraining import ElectraPreLoss # noqa +from chebai.loss.pretraining import ElectraPreLoss # noqa import torch import csv @@ -53,13 +53,14 @@ def _process_labels_in_batch(self, batch): def forward(self, data, **kwargs): features = data["features"] + features = features.to(self.device).long() # this has been added for selfies, i neither know why it is needed now, nor why it wasnt needed before self.batch_size = batch_size = features.shape[0] max_seq_len = features.shape[1] mask = kwargs["mask"] with torch.no_grad(): dis_tar = ( - torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1) + torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1) ).int() disc_tar_one_hot = torch.eq( torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None] @@ -67,7 +68,7 @@ def forward(self, data, **kwargs): gen_tar = features[disc_tar_one_hot] gen_tar_one_hot = torch.eq( torch.arange(self.generator_config.vocab_size, device=self.device)[ - None, : + None, : ], gen_tar[:, None], ) @@ -100,7 +101,7 @@ def _get_prediction_and_labels(self, batch, labels, output): def filter_dict(d, filter_key): return { - str(k)[len(filter_key) :]: v + str(k)[len(filter_key):]: v for k, v in d.items() if str(k).startswith(filter_key) } @@ -121,10 +122,10 @@ def _process_batch(self, batch, batch_idx): batch_first=True, ) cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN + torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( + -1 + ) + * CLS_TOKEN ) return dict( features=torch.cat((cls_tokens, batch.x), dim=1), @@ -139,7 +140,7 @@ def as_pretrained(self): return self.electra.electra def __init__( - self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs + self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs ): # Remove this property in order to prevent it from being stored as a # hyper parameter @@ -257,10 +258,10 @@ def _process_batch(self, batch, batch_idx): batch_first=True, ) cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN + torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( + -1 + ) + * CLS_TOKEN ) return dict( features=torch.cat((cls_tokens, batch.x), dim=1), @@ -295,7 +296,7 @@ def __init__(self, cone_dimensions=20, **kwargs): model_dict = torch.load(fin, map_location=self.device) if model_prefix: state_dict = { - str(k)[len(model_prefix) :]: v + str(k)[len(model_prefix):]: v for k, v in model_dict["state_dict"].items() if str(k).startswith(model_prefix) } @@ -356,7 +357,7 @@ def forward(self, data, **kwargs): def softabs(x, eps=0.01): - return (x**2 + eps) ** 0.5 - eps**0.5 + return (x ** 2 + eps) ** 0.5 - eps ** 0.5 def anglify(x): @@ -383,8 +384,8 @@ def in_cone_parts(vectors, cone_axes, cone_arcs): dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang) return dis """ - a = cone_axes - cone_arcs**2 - b = cone_axes + cone_arcs**2 + a = cone_axes - cone_arcs ** 2 + b = cone_axes + cone_arcs ** 2 bigger_than_a = torch.sigmoid(vectors - a) smaller_than_b = torch.sigmoid(b - vectors) return bigger_than_a * smaller_than_b @@ -410,4 +411,3 @@ def __call__(self, target, input): memberships, target.unsqueeze(-1).expand(-1, -1, 20) ) return loss - diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 99abdabb..1b144142 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -162,8 +162,6 @@ def __init__(self, *args, data_path=None, max_len=1800, vsize=4000, **kwargs): data_path, max_len=max_len ) - - def _get_raw_data(self, row): return self.tokenizer(row["features"])["input_ids"] From 0868b6778f42b074c588b70fdbf5db21c05e6a14 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 30 Nov 2023 13:21:58 +0100 Subject: [PATCH 11/17] fix handling of errors in reader --- chebai/preprocessing/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 1b144142..a2b0c8cf 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -134,7 +134,7 @@ def _read_data(self, raw_data): print(f'\t{e}') self.error_count += 1 print(f'\terror count: {self.error_count}') - tokenized = [] + tokenized = None return tokenized @@ -186,7 +186,7 @@ def _read_data(self, raw_data): print(f'\t{e}') self.error_count += 1 print(f'\terror count: {self.error_count}') - tokenized = [] + tokenized = None return tokenized From 5a3dd3de2590148f56a65cb5ad2a807d8f246a06 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 30 Nov 2023 18:55:47 +0100 Subject: [PATCH 12/17] set constraints in selfies reader --- chebai/preprocessing/reader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index a2b0c8cf..ff7c3d57 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -172,6 +172,7 @@ class SelfiesReader(ChemDataReader): def __init__(self, *args, data_path=None, max_len=1800, vsize=4000, **kwargs): super().__init__(*args, **kwargs) self.error_count = 0 + sf.set_semantic_constraints("hypervalent") @classmethod def name(cls): @@ -183,10 +184,12 @@ def _read_data(self, raw_data): tokenized = [self._get_token_index(v) for v in tokenized] except Exception as e: print(f'could not process {raw_data}') - print(f'\t{e}') + #print(f'\t{e}') self.error_count += 1 print(f'\terror count: {self.error_count}') tokenized = None + #if self.error_count > 20: + # raise Exception('Too many errors') return tokenized From b30c97ec458b430c3c8f9ddd4821a542a4b768f5 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 30 Nov 2023 19:42:06 +0100 Subject: [PATCH 13/17] fix filtering for none-value in setup_preprocessed() --- chebai/preprocessing/datasets/base.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 2912881d..04c1c913 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -26,7 +26,7 @@ def __init__( balance_after_filter: typing.Optional[float] = None, num_workers: int = 1, chebi_version: int = 200, - inner_k_folds: int = -1, # use inner cross-validation if > 1 + inner_k_folds: int = -1, # use inner cross-validation if > 1 **kwargs, ): super().__init__(**kwargs) @@ -46,7 +46,7 @@ def __init__( self.chebi_version = chebi_version assert(type(inner_k_folds) is int) self.inner_k_folds = inner_k_folds - self.use_inner_cross_validation = inner_k_folds > 1 # only use cv if there are at least 2 folds + self.use_inner_cross_validation = inner_k_folds > 1 # only use cv if there are at least 2 folds os.makedirs(self.raw_dir, exist_ok=True) os.makedirs(self.processed_dir, exist_ok=True) @@ -130,6 +130,9 @@ def _load_data_from_file(self, path): for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] + # filter for missing features in resulting data + data = [val for val in data if val['features'] is not None] + return data def train_dataloader(self, *args, **kwargs) -> DataLoader: @@ -160,6 +163,11 @@ def setup(self, **kwargs): if self.use_inner_cross_validation: self.train_val_data = torch.load(os.path.join(self.processed_dir, self.processed_file_names_dict['train_val'])) + def teardown(self, stage: str) -> None: + # cant save hyperparams at setup because logger is not initialised yet + # not sure if this has an effect + self.save_hyperparameters() + def setup_processed(self): raise NotImplementedError From 9e0bbd58c80c898065ea438440cb9ceb7916e3a3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 4 Dec 2023 09:28:14 +0100 Subject: [PATCH 14/17] add custom logger --- chebai/cli.py | 3 +- chebai/loggers/custom.py | 57 ++++++++++++++++++++++++++ chebai/trainer/InnerCVTrainer.py | 46 +++++++++------------ configs/training/default_callbacks.yml | 7 ++-- configs/training/default_trainer.yml | 3 +- 5 files changed, 85 insertions(+), 31 deletions(-) create mode 100644 chebai/loggers/custom.py diff --git a/chebai/cli.py b/chebai/cli.py index 0c091db9..5ca33f6c 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -32,4 +32,5 @@ def subcommands() -> Dict[str, Set[str]]: def cli(): - r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"}) + r = ChebaiCLI(save_config_kwargs={"config_filename": "lightning_config.yaml"}, + parser_kwargs={"parser_mode": "omegaconf"}) diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py new file mode 100644 index 00000000..e62129c7 --- /dev/null +++ b/chebai/loggers/custom.py @@ -0,0 +1,57 @@ +from datetime import datetime +from typing import Optional, Union, Literal + +import wandb +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.loggers import WandbLogger +import os + + +class CustomLogger(WandbLogger): + """Adds support for custom naming of runs and cross-validation""" + + def __init__(self, save_dir: _PATH, name: str = "logs", version: Optional[Union[int, str]] = None, prefix: str = "", + fold: Optional[int] = None, project: Optional[str] = None, entity: Optional[str] = None, + offline: bool = False, + log_model: Union[Literal["all"], bool] = False, **kwargs): + if version is None: + version = f'{datetime.now():%y%m%d-%H%M}' + self._version = version + self._name = name + self._fold = fold + super().__init__(name=self.name, save_dir=save_dir, version=None, prefix=prefix, + log_model=log_model, entity=entity, project=project, offline=offline, **kwargs) + + @property + def name(self) -> Optional[str]: + name = f'{self._name}_{self.version}' + if self._fold is not None: + name += f'_fold{self._fold}' + return name + + @property + def version(self) -> Optional[str]: + return self._version + + @property + def root_dir(self) -> Optional[str]: + return os.path.join(self.save_dir, self.name) + + @property + def log_dir(self) -> str: + version = self.version if isinstance(self.version, str) else f"version_{self.version}" + if self._fold is None: + return os.path.join(self.root_dir, version) + return os.path.join(self.root_dir, version, f'fold_{self._fold}') + + def set_fold(self, fold: int): + if fold != self._fold: + self._fold = fold + # start new experiment + wandb.finish() + self._experiment = None + _ = self.experiment + + @property + def fold(self): + return self._fold diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py index 1150d639..d9c78e1e 100644 --- a/chebai/trainer/InnerCVTrainer.py +++ b/chebai/trainer/InnerCVTrainer.py @@ -6,7 +6,7 @@ from lightning.fabric.utilities.types import _PATH from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE -from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger +from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger, WandbLogger from lightning.pytorch.callbacks import ModelCheckpoint from lightning.fabric.plugins.environments import SLURMEnvironment from lightning_utilities.core.rank_zero import WarningCache @@ -14,6 +14,7 @@ from iterstrat.ml_stratifiers import MultilabelStratifiedKFold from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn +from chebai.loggers.custom import CustomLogger from chebai.preprocessing.datasets.base import XYBaseDataModule log = logging.getLogger(__name__) @@ -42,35 +43,28 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar train_dataloader = datamodule.train_dataloader(ids=train_ids) val_dataloader = datamodule.val_dataloader(ids=val_ids) init_kwargs = self.init_kwargs - new_logger = CSVLoggerCVSupport(save_dir=self.logger.save_dir, name=self.logger.name, - version=self.logger.version, fold=fold) - init_kwargs['logger'] = new_logger new_trainer = Trainer(*self.init_args, **init_kwargs) - print(f'Logging this fold at {new_trainer.logger.log_dir}') + logger = new_trainer.logger + if isinstance(logger, CustomLogger): + logger.set_fold(fold) + print(f'Logging this fold at {logger.experiment.dir}') + else: + rank_zero_warn(f"Using k-fold cross-validation without an adapted logger class") new_trainer.fit(train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, *args, **kwargs) - -# extend CSVLogger to include fold number in log path -class CSVLoggerCVSupport(CSVLogger): - - def __init__(self, save_dir: _PATH, name: str = "lightning_logs", version: Optional[Union[int, str]] = None, - prefix: str = "", flush_logs_every_n_steps: int = 100, fold: int = None): - super().__init__(save_dir, name, version, prefix, flush_logs_every_n_steps) - self.fold = fold - @property - def log_dir(self) -> str: - """The log directory for this run. + def log_dir(self) -> Optional[str]: + if len(self.loggers) > 0: + logger = self.loggers[0] + if isinstance(logger, WandbLogger): + dirpath = logger.experiment.dir + else: + dirpath = self.loggers[0].log_dir + else: + dirpath = self.default_root_dir - By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the - constructor's version parameter instead of ``None`` or an int. - Additionally: Save data for each fold separately - """ - # create a pseudo standard path - version = self.version if isinstance(self.version, str) else f"version_{self.version}" - if self.fold is None: - return os.path.join(self.root_dir, version) - return os.path.join(self.root_dir, version, f'fold_{self.fold}') + dirpath = self.strategy.broadcast(dirpath) + return dirpath class ModelCheckpointCVSupport(ModelCheckpoint): @@ -114,7 +108,7 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH: version = trainer.loggers[0].version version = version if isinstance(version, str) else f"version_{version}" cv_logger = trainer.loggers[0] - if isinstance(cv_logger, CSVLoggerCVSupport) and cv_logger.fold is not None: + if isinstance(cv_logger, CustomLogger) and cv_logger.fold is not None: # log_dir includes fold ckpt_path = os.path.join(cv_logger.log_dir, "checkpoints") else: diff --git a/configs/training/default_callbacks.yml b/configs/training/default_callbacks.yml index 23ceda2e..46188782 100644 --- a/configs/training/default_callbacks.yml +++ b/configs/training/default_callbacks.yml @@ -1,13 +1,14 @@ -- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpoint +- class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: monitor: val_micro-f1 mode: 'max' filename: 'best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}' every_n_epochs: 1 save_top_k: 5 -- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpointCVSupport +- class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: filename: 'per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}' every_n_epochs: 5 save_top_k: -1 -- class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1 \ No newline at end of file +- class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1 +#class_path: chebai.callbacks.save_config_callback.CustomSaveConfigCallback diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml index 4def2ed3..dcbdb40d 100644 --- a/configs/training/default_trainer.yml +++ b/configs/training/default_trainer.yml @@ -2,9 +2,10 @@ min_epochs: 100 max_epochs: 100 default_root_dir: &default_root_dir logs logger: - class_path: lightning.pytorch.loggers.WandbLogger + class_path: chebai.loggers.custom.CustomLogger init_args: save_dir: *default_root_dir project: 'chebai' entity: 'chebai' + log_model: 'all' callbacks: default_callbacks.yml \ No newline at end of file From a21e3f898b93368b346f650329122f0d99218e42 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 5 Dec 2023 14:04:27 +0100 Subject: [PATCH 15/17] update checkpoints for wandb logging, refactor --- chebai/callbacks/model_checkpoint.py | 57 +++++++++++++++++++++++ chebai/trainer/InnerCVTrainer.py | 62 +------------------------- configs/training/default_callbacks.yml | 8 ++-- 3 files changed, 63 insertions(+), 64 deletions(-) create mode 100644 chebai/callbacks/model_checkpoint.py diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py new file mode 100644 index 00000000..cf461384 --- /dev/null +++ b/chebai/callbacks/model_checkpoint.py @@ -0,0 +1,57 @@ +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from lightning.fabric.utilities.types import _PATH + + +class CustomModelCheckpoint(ModelCheckpoint): + """Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the + same directory as the other logs""" + + def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None: + """Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir""" + if self.dirpath is not None: + self.dirpath = None + dirpath = self.__resolve_ckpt_dir(trainer) + dirpath = trainer.strategy.broadcast(dirpath) + self.dirpath = dirpath + if trainer.is_global_zero and stage == "fit": + self.__warn_if_dir_not_empty(self.dirpath) + + def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: + """Same as in parent class, duplicated because method in parent class is not accessible""" + if self.save_top_k != 0 and _is_dir(self._fs, dirpath, strict=True) and len(self._fs.ls(dirpath)) > 0: + rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") + + def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH: + """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to + determine where to save checkpoints. The path for saving weights is set in this priority: + + 1. The ``ModelCheckpoint``'s ``dirpath`` if passed in + 2. The ``Logger``'s ``log_dir`` if the trainer has loggers + 3. The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers + + The path gets extended with subdirectory "checkpoints". + + """ + print(f'Resolving checkpoint dir (custom)') + if self.dirpath is not None: + # short circuit if dirpath was passed to ModelCheckpoint + return self.dirpath + if len(trainer.loggers) > 0: + if trainer.loggers[0].save_dir is not None: + save_dir = trainer.loggers[0].save_dir + else: + save_dir = trainer.default_root_dir + name = trainer.loggers[0].name + version = trainer.loggers[0].version + version = version if isinstance(version, str) else f"version_{version}" + logger = trainer.loggers[0] + if isinstance(logger, WandbLogger): + ckpt_path = os.path.join(logger.experiment.dir, "checkpoints") + else: + ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") + else: + # if no loggers, use default_root_dir + ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") + + print(f'Now using checkpoint path {ckpt_path}') + return ckpt_path diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py index d9c78e1e..30f48e9c 100644 --- a/chebai/trainer/InnerCVTrainer.py +++ b/chebai/trainer/InnerCVTrainer.py @@ -1,15 +1,11 @@ import logging import os -from typing import Optional, Union, Iterable +from typing import Optional from lightning import Trainer, LightningModule from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector -from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE -from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger, WandbLogger +from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning_utilities.core.rank_zero import WarningCache from iterstrat.ml_stratifiers import MultilabelStratifiedKFold from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn @@ -65,57 +61,3 @@ def log_dir(self) -> Optional[str]: dirpath = self.strategy.broadcast(dirpath) return dirpath - - -class ModelCheckpointCVSupport(ModelCheckpoint): - - def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None: - """Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir""" - if self.dirpath is not None: - self.dirpath = None - dirpath = self.__resolve_ckpt_dir(trainer) - dirpath = trainer.strategy.broadcast(dirpath) - self.dirpath = dirpath - if trainer.is_global_zero and stage == "fit": - self.__warn_if_dir_not_empty(self.dirpath) - - def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: - """Same as in parent class, duplicated because method in parent class is not accessible""" - if self.save_top_k != 0 and _is_dir(self._fs, dirpath, strict=True) and len(self._fs.ls(dirpath)) > 0: - rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") - - def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH: - """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to - determine where to save checkpoints. The path for saving weights is set in this priority: - - 1. The ``ModelCheckpoint``'s ``dirpath`` if passed in - 2. The ``Logger``'s ``log_dir`` if the trainer has loggers - 3. The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers - - The path gets extended with subdirectory "checkpoints". - - """ - print(f'Resolving checkpoint dir (with cross-validation)') - if self.dirpath is not None: - # short circuit if dirpath was passed to ModelCheckpoint - return self.dirpath - if len(trainer.loggers) > 0: - if trainer.loggers[0].save_dir is not None: - save_dir = trainer.loggers[0].save_dir - else: - save_dir = trainer.default_root_dir - name = trainer.loggers[0].name - version = trainer.loggers[0].version - version = version if isinstance(version, str) else f"version_{version}" - cv_logger = trainer.loggers[0] - if isinstance(cv_logger, CustomLogger) and cv_logger.fold is not None: - # log_dir includes fold - ckpt_path = os.path.join(cv_logger.log_dir, "checkpoints") - else: - ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") - else: - # if no loggers, use default_root_dir - ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") - - print(f'Now using checkpoint path {ckpt_path}') - return ckpt_path diff --git a/configs/training/default_callbacks.yml b/configs/training/default_callbacks.yml index 46188782..3d84c6ff 100644 --- a/configs/training/default_callbacks.yml +++ b/configs/training/default_callbacks.yml @@ -1,13 +1,13 @@ -- class_path: lightning.pytorch.callbacks.ModelCheckpoint +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint init_args: monitor: val_micro-f1 mode: 'max' - filename: 'best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}' + filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_ep_macro-f1:.4f}_{val_micro-f1:.4f}' every_n_epochs: 1 save_top_k: 5 -- class_path: lightning.pytorch.callbacks.ModelCheckpoint +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint init_args: - filename: 'per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}' + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_ep_macro-f1:.4f}_{val_micro-f1:.4f}' every_n_epochs: 5 save_top_k: -1 - class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1 From e0f72aea6a8235f299ee5bdfac6eb868fe2bdd8e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 5 Dec 2023 17:16:36 +0100 Subject: [PATCH 16/17] fix data preparation --- .gitignore | 3 - chebai/callbacks/model_checkpoint.py | 6 +- chebai/preprocessing/datasets/chebi.py | 133 +++++++++---------------- chebai/trainer/InnerCVTrainer.py | 2 +- 4 files changed, 55 insertions(+), 89 deletions(-) diff --git a/.gitignore b/.gitignore index cd34d6b4..4c14111e 100644 --- a/.gitignore +++ b/.gitignore @@ -161,6 +161,3 @@ cython_debug/ #.idea/ configs/ -# the notebook I put in the wrong folder -chebai/preprocessing/datasets/demo_old_chebi.ipynb -demo_examine_pretraining_data.ipynb \ No newline at end of file diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py index cf461384..b5740438 100644 --- a/chebai/callbacks/model_checkpoint.py +++ b/chebai/callbacks/model_checkpoint.py @@ -1,6 +1,10 @@ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.fabric.utilities.types import _PATH - +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch import Trainer, LightningModule +import os +from lightning.fabric.utilities.cloud_io import _is_dir +from lightning.pytorch.utilities.rank_zero import rank_zero_info class CustomModelCheckpoint(ModelCheckpoint): """Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 22d7623a..01510972 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -119,12 +119,11 @@ def select_classes(self, g, split_name, *args, **kwargs): raise NotImplementedError def graph_to_raw_dataset(self, g, split_name=None): - """Preparation step before creating splits, uses graph created by extract_class_hierarchy() + """Preparation step before creating splits, uses graph created by extract_class_hierarchy(), split_name is only relevant, if a separate train_version is set""" smiles = nx.get_node_attributes(g, "smiles") names = nx.get_node_attributes(g, "name") - print("build labels") print(f"Process graph") molecules, smiles_list = zip( @@ -199,68 +198,50 @@ def setup_processed(self): self._setup_pruned_test_set() self.reader.save_token_cache() - def get_splits(self, df: pd.DataFrame): - print("Split dataset") + def get_test_split(self, df: pd.DataFrame): + print("Split dataset into train (including val) / test") df_list = df.values.tolist() df_list = [row[3:] for row in df_list] - msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0) + test_size = 1 - self.train_split - (1 - self.train_split) ** 2 + msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0) train_split = [] test_split = [] for (train_split, test_split) in msss.split( - df_list, df_list, + df_list, df_list, ): train_split = train_split test_split = test_split break df_train = df.iloc[train_split] df_test = df.iloc[test_split] - if self.use_inner_cross_validation: - return df_train, df_test + return df_train, df_test - df_test_list = df_test.values.tolist() - df_test_list = [row[3:] for row in df_test_list] - validation_split = [] - test_split = [] - msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0) - for (test_split, validation_split) in msss.split( - df_test_list, df_test_list - ): - test_split = test_split - validation_split = validation_split - break + def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame): + """ Use test set (e.g., loaded from another chebi version or generated in get_test_split), avoid overlap""" + print(f"Split dataset into train / val with given test set") - df_validation = df_test.iloc[validation_split] - df_test = df_test.iloc[test_split] - return df_train, df_test, df_validation - - def get_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame): - """ Use test set from another chebi version the model does not train on, avoid overlap""" - print(f"Split dataset for chebi_v{self.chebi_version_train}") df_trainval = df test_smiles = test_df['SMILES'].tolist() - mask = [] - for row in df_trainval: - if row['SMILES'] in test_smiles: - mask.append(False) - else: - mask.append(True) + mask = [smiles not in test_smiles for smiles in df_trainval['SMILES']] df_trainval = df_trainval[mask] + if self.use_inner_cross_validation: return df_trainval - # assume that size of validation split should relate to train split as in get_splits() - msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=self.train_split ** 2, random_state=0) + # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) + test_size = ((1 - self.train_split) ** 2) / self.train_split + msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0) - df_trainval_list = df_trainval.tolist() + df_trainval_list = df_trainval.values.tolist() df_trainval_list = [row[3:] for row in df_trainval_list] train_split = [] validation_split = [] for (train_split, validation_split) in msss.split( - df_trainval_list, df_trainval_list + df_trainval_list, df_trainval_list ): train_split = train_split validation_split = validation_split @@ -309,6 +290,16 @@ def processed_file_names(self): def raw_file_names(self): return list(self.raw_file_names_dict.values()) + def _load_chebi(self, version: int): + chebi_name = f'chebi.obo' if version == self.chebi_version else f'chebi_v{version}.obo' + chebi_path = os.path.join(self.raw_dir, chebi_name) + if not os.path.isfile(chebi_path): + print(f"Load ChEBI ontology (v_{version})") + url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo" + r = requests.get(url, allow_redirects=True) + open(chebi_path, "wb").write(r.content) + return chebi_path + def prepare_data(self, *args, **kwargs): print("Check for raw data in", self.raw_dir) if any( @@ -317,56 +308,30 @@ def prepare_data(self, *args, **kwargs): ): os.makedirs(self.raw_dir, exist_ok=True) print("Missing raw data. Go fetch...") - if self.chebi_version_train is None: - # load chebi_v{chebi_version}, create splits - chebi_path = os.path.join(self.raw_dir, f"chebi.obo") - if not os.path.isfile(chebi_path): - print("Load ChEBI ontology") - url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) + # missing test set -> create + if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])): + chebi_path = self._load_chebi(self.chebi_version) g = extract_class_hierarchy(chebi_path) - splits = {} - full_data = self.graph_to_raw_dataset(g) - if self.use_inner_cross_validation: - splits['train_val'], splits['test'] = self.get_splits(full_data) - else: - splits['train'], splits['test'], splits['validation'] = self.get_splits(full_data) - for label, split in splits.items(): - self.save(split, self.raw_file_names_dict[label]) + df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test']) + _, test_df = self.get_test_split(df) + self.save(test_df, self.raw_file_names_dict['test']) + # load test_split from file else: - # missing test set -> create - if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])): - chebi_path = os.path.join(self.raw_dir, f"chebi.obo") - if not os.path.isfile(chebi_path): - print("Load ChEBI ontology") - url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) - g = extract_class_hierarchy(chebi_path) - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test']) - _, test_split, _ = self.get_splits(df) - self.save(df, self.raw_file_names_dict['test']) - else: - # load test_split from file - with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file: - test_split = [row[0] for row in pickle.load(input_file).values] - chebi_path = os.path.join(self.raw_dir, f"chebi_v{self.chebi_version_train}.obo") - if not os.path.isfile(chebi_path): - print(f"Load ChEBI ontology (v_{self.chebi_version_train})") - url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version_train}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) - g = extract_class_hierarchy(chebi_path) - if self.use_inner_cross_validation: - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val']) - train_val_df = self.get_splits_given_test(df, test_split) - self.save(train_val_df, self.raw_file_names_dict['train_val']) - else: - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train']) - train_split, val_split = self.get_splits_given_test(df, test_split) - self.save(train_split, self.raw_file_names_dict['train']) - self.save(val_split, self.raw_file_names_dict['validation']) + with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file: + test_df = pickle.load(input_file) + # create train/val split based on test set + chebi_path = self._load_chebi( + self.chebi_version_train if self.chebi_version_train is not None else self.chebi_version) + g = extract_class_hierarchy(chebi_path) + if self.use_inner_cross_validation: + df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val']) + train_val_df = self.get_train_val_splits_given_test(df, test_df) + self.save(train_val_df, self.raw_file_names_dict['train_val']) + else: + df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train']) + train_split, val_split = self.get_train_val_splits_given_test(df, test_df) + self.save(train_split, self.raw_file_names_dict['train']) + self.save(val_split, self.raw_file_names_dict['validation']) class JCIExtendedBase(_ChEBIDataExtractor): diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py index 30f48e9c..ad06fe97 100644 --- a/chebai/trainer/InnerCVTrainer.py +++ b/chebai/trainer/InnerCVTrainer.py @@ -39,7 +39,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar train_dataloader = datamodule.train_dataloader(ids=train_ids) val_dataloader = datamodule.val_dataloader(ids=val_ids) init_kwargs = self.init_kwargs - new_trainer = Trainer(*self.init_args, **init_kwargs) + new_trainer = InnerCVTrainer(*self.init_args, **init_kwargs) logger = new_trainer.logger if isinstance(logger, CustomLogger): logger.set_fold(fold) From 9c75553cbb84f1ca19c2aa980dbb562223008759 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 5 Dec 2023 17:26:09 +0100 Subject: [PATCH 17/17] fix wandb run naming with cv --- chebai/loggers/custom.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py index e62129c7..e282b7a6 100644 --- a/chebai/loggers/custom.py +++ b/chebai/loggers/custom.py @@ -49,6 +49,8 @@ def set_fold(self, fold: int): self._fold = fold # start new experiment wandb.finish() + self._wandb_init['name'] = self.name + self._name = self.name self._experiment = None _ = self.experiment