diff --git a/chebai/cli.py b/chebai/cli.py
index 392d3225..7b4fb26f 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -3,10 +3,11 @@
 from lightning.pytorch.cli import LightningCLI
 
 from chebai.trainer.InnerCVTrainer import InnerCVTrainer
 
+
 class ChebaiCLI(LightningCLI):
     def __init__(self, *args, **kwargs):
-        super().__init__(trainer_class = InnerCVTrainer, *args, **kwargs)
+        super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs)
 
     def add_arguments_to_parser(self, parser):
         for kind in ("train", "val", "test"):
@@ -15,7 +16,7 @@ def add_arguments_to_parser(self, parser):
                 "model.init_args.out_dim",
                 f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
             )
-        #parser.link_arguments('n_splits', 'data.init_args.inner_k_folds')  # doesn't work but I don't know why
+        # parser.link_arguments('n_splits', 'data.init_args.inner_k_folds')  # doesn't work but I don't know why
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
@@ -28,5 +29,6 @@ def subcommands() -> Dict[str, Set[str]]:
             "cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
         }
 
+
 def cli():
     r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
diff --git a/chebai/models/electra.py b/chebai/models/electra.py
index cfd7d3d5..12a5ef60 100644
--- a/chebai/models/electra.py
+++ b/chebai/models/electra.py
@@ -24,6 +24,7 @@
 )
 from chebai.preprocessing.reader import MASK_TOKEN_INDEX, CLS_TOKEN
 from chebai.preprocessing.datasets.chebi import extract_class_hierarchy
+from chebai.loss.pretraining import ElectraPreLoss  # noqa
 import torch
 import csv
 
@@ -404,3 +405,4 @@ def __call__(self, target, input):
             memberships, target.unsqueeze(-1).expand(-1, -1, 20)
         )
         return loss
+
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index e94495f0..2912881d 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -60,11 +60,11 @@ def full_identifier(self):
 
     @property
     def processed_dir(self):
-        return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "processed", *self.identifier)
+        return os.path.join("data", self._name, "processed", *self.identifier)
 
     @property
     def raw_dir(self):
-        return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "raw")
+        return os.path.join("data", self._name, "raw")
 
     @property
     def _name(self):
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 02c50d16..349a3824 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -15,7 +15,8 @@
 import pickle
 import random
 
-from sklearn.model_selection import train_test_split
+from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
+
 import fastobo
 import networkx as nx
 import pandas as pd
@@ -117,17 +118,19 @@ def __init__(self, chebi_version_train: int = None, **kwargs):
     def select_classes(self, g, split_name, *args, **kwargs):
         raise NotImplementedError
 
-    def save(self, g, split, split_name: str):
+    def graph_to_raw_dataset(self, g, split_name=None):
+        """Preparation step before creating splits, uses graph created by extract_class_hierarchy()
+        split_name is only relevant, if a separate train_version is set"""
         smiles = nx.get_node_attributes(g, "smiles")
         names = nx.get_node_attributes(g, "name")
 
         print("build labels")
-        print(f"Process {split_name}")
+        print(f"Process graph")
         molecules, smiles_list = zip(
             *(
                 (n, smiles)
-                for n, smiles in ((n, smiles.get(n)) for n in split)
+                for n, smiles in ((n, smiles.get(n)) for n in smiles.keys())
                 if smiles
             )
         )
 
@@ -142,6 +145,10 @@ def save(self, g, split, split_name: str):
         data = pd.DataFrame(data)
         data = data[~data["SMILES"].isnull()]
         data = data[data.iloc[:, 3:].any(axis=1)]
+        return data
+
+    def save(self, data: pd.DataFrame, split_name: str):
+        pickle.dump(data, open(os.path.join(self.raw_dir, split_name), "wb"))
 
     @staticmethod
@@ -192,44 +199,90 @@ def setup_processed(self):
         self._setup_pruned_test_set()
         self.reader.save_token_cache()
 
-    def get_splits(self, g):
-        fixed_nodes = list(g.nodes)
+    def get_splits(self, df: pd.DataFrame):
         print("Split dataset")
-        random.shuffle(fixed_nodes)
-        train_split, test_split = train_test_split(
-            fixed_nodes, train_size=self.train_split, shuffle=True
-        )
+        df_list = df.values.tolist()
+        df_list = [row[3:] for row in df_list]
+
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
+
+        train_split = []
+        test_split = []
+        for (train_split, test_split) in msss.split(
+            df_list, df_list,
+        ):
+            break
+        df_train = df.iloc[train_split]
+        df_test = df.iloc[test_split]
         if self.use_inner_cross_validation:
-            return train_split, test_split
+            return df_train, df_test
+
+        df_test_list = df_test.values.tolist()
+        df_test_list = [row[3:] for row in df_test_list]
+        validation_split = []
+        test_split = []
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
+        for (test_split, validation_split) in msss.split(
+            df_test_list, df_test_list
+        ):
+            break
 
-        test_split, validation_split = train_test_split(
-            test_split, train_size=self.train_split, shuffle=True
-        )
-        return train_split, test_split, validation_split
+        df_validation = df_test.iloc[validation_split]
+        df_test = df_test.iloc[test_split]
+        return df_train, df_test, df_validation
 
-    def get_splits_given_test(self, g, test_split):
+    def get_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
         """ Use test set from another chebi version the model does not train on, avoid overlap"""
-        fixed_nodes = list(g.nodes)
         print(f"Split dataset for chebi_v{self.chebi_version_train}")
-        for node in test_split:
-            if node in fixed_nodes:
-                fixed_nodes.remove(node)
-        random.shuffle(fixed_nodes)
+        df_trainval = df
+        test_smiles = test_df['SMILES'].tolist()
+        mask = []
+        for _, row in df_trainval.iterrows():
+            if row['SMILES'] in test_smiles:
+                mask.append(False)
+            else:
+                mask.append(True)
+        df_trainval = df_trainval[mask]
+
         if self.use_inner_cross_validation:
-            return fixed_nodes
+            return df_trainval
+
         # assume that size of validation split should relate to train split as in get_splits()
-        validation_split, train_split = train_test_split(
-            fixed_nodes, train_size=(1 - self.train_split) ** 2, shuffle=True
-        )
-        return train_split, validation_split
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=(1 - self.train_split) ** 2, random_state=0)
+
+        df_trainval_list = df_trainval.values.tolist()
+        df_trainval_list = [row[3:] for row in df_trainval_list]
+        train_split = []
+        validation_split = []
+        for (train_split, validation_split) in msss.split(
+            df_trainval_list, df_trainval_list
+        ):
+            break
+
+        df_validation = df_trainval.iloc[validation_split]
+        df_train = df_trainval.iloc[train_split]
+        return df_train, df_validation
+
+    @property
+    def processed_dir(self):
+        return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "processed", *self.identifier)
+
+    @property
+    def raw_dir(self):
+        return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "raw")
 
     @property
     def processed_file_names_dict(self) -> dict:
         train_v_str = f'_v{self.chebi_version_train}' if self.chebi_version_train else ''
         res = {'test': f"test{train_v_str}.pt"}
         if self.use_inner_cross_validation:
-            res['train_val'] = f'trainval{train_v_str}.pt' # for cv, split train/val on runtime
+            res['train_val'] = f'trainval{train_v_str}.pt'  # for cv, split train/val on runtime
         else:
             res['train'] = f"train{train_v_str}.pt"
             res['validation'] = f"validation{train_v_str}.pt"
@@ -238,10 +291,10 @@ def processed_file_names_dict(self) -> dict:
     @property
     def raw_file_names_dict(self) -> dict:
         train_v_str = f'_v{self.chebi_version_train}' if self.chebi_version_train else ''
-        res = {'test': f"test.pkl"} # no extra raw test version for chebi_version_train - use default test set and only
-        # adapt processed file
+        res = {'test': f"test.pkl"}  # no extra raw test version for chebi_version_train - use default test set and only
+        # adapt processed file
         if self.use_inner_cross_validation:
-            res['train_val'] = f'trainval{train_v_str}.pkl' # for cv, split train/val on runtime
+            res['train_val'] = f'trainval{train_v_str}.pkl'  # for cv, split train/val on runtime
         else:
             res['train'] = f"train{train_v_str}.pkl"
             res['validation'] = f"validation{train_v_str}.pkl"
@@ -274,12 +327,13 @@ def prepare_data(self, *args, **kwargs):
             open(chebi_path, "wb").write(r.content)
             g = extract_class_hierarchy(chebi_path)
             splits = {}
+            full_data = self.graph_to_raw_dataset(g)
             if self.use_inner_cross_validation:
-                splits['train_val'], splits['test'] = self.get_splits(g)
+                splits['train_val'], splits['test'] = self.get_splits(full_data)
             else:
-                splits['train'], splits['test'], splits['validation'] = self.get_splits(g)
+                splits['train'], splits['test'], splits['validation'] = self.get_splits(full_data)
             for label, split in splits.items():
-                self.save(g, split, self.raw_file_names_dict[label])
+                self.save(split, self.raw_file_names_dict[label])
         else:
             # missing test set -> create
             if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])):
@@ -290,8 +344,9 @@ def prepare_data(self, *args, **kwargs):
                 r = requests.get(url, allow_redirects=True)
                 open(chebi_path, "wb").write(r.content)
                 g = extract_class_hierarchy(chebi_path)
-                _, test_split, _ = self.get_splits(g)
-                self.save(g, test_split, self.raw_file_names_dict['test'])
+                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test'])
+                _, test_split, _ = self.get_splits(df)
+                self.save(test_split, self.raw_file_names_dict['test'])
             else:
                 # load test_split from file
                 with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file:
@@ -304,12 +359,14 @@ def prepare_data(self, *args, **kwargs):
                 open(chebi_path, "wb").write(r.content)
             g = extract_class_hierarchy(chebi_path)
             if self.use_inner_cross_validation:
-                train_val_data = self.get_splits_given_test(g, test_split)
-                self.save(g, train_val_data, self.raw_file_names_dict['train_val'])
+                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val'])
+                train_val_df = self.get_splits_given_test(df, test_split)
+                self.save(train_val_df, self.raw_file_names_dict['train_val'])
             else:
-                train_split, val_split = self.get_splits_given_test(g, test_split)
-                self.save(g, train_split, self.raw_file_names_dict['train'])
-                self.save(g, val_split, self.raw_file_names_dict['validation'])
+                df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train'])
+                train_split, val_split = self.get_splits_given_test(df, test_split)
+                self.save(train_split, self.raw_file_names_dict['train'])
+                self.save(val_split, self.raw_file_names_dict['validation'])
 
 
 class JCIExtendedBase(_ChEBIDataExtractor):
@@ -363,6 +420,7 @@ def select_classes(self, g, split_name, *args, **kwargs):
             fout.writelines(str(node) + "\n" for node in nodes)
         return nodes
 
+
 class ChEBIOverXDeepSMILES(ChEBIOverX):
     READER = dr.DeepChemDataReader
 
@@ -380,6 +438,7 @@ class ChEBIOver50(ChEBIOverX):
     def label_number(self):
         return 1332
 
+
 class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100):
     pass
 
diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py
index cb1839b4..b3bb0423 100644
--- a/chebai/preprocessing/datasets/pubchem.py
+++ b/chebai/preprocessing/datasets/pubchem.py
@@ -31,6 +31,8 @@ class PubChem(XYBaseDataModule):
 
     def __init__(self, *args, k=100000, **kwargs):
         self._k = k
+        self.pubchem_url = "https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Monthly/2023-11-01/Extras/CID-SMILES.gz"
+
         super(PubChem, self).__init__(*args, **kwargs)
 
     @property
@@ -60,18 +62,17 @@ def _load_dict(input_file_path):
             yield dict(features=smiles, labels=None, ident=ident)
 
     def download(self):
-        url = f"https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Monthly/2021-10-01/Extras/CID-SMILES.gz"
         if self._k == PubChem.FULL:
             if not os.path.isfile(os.path.join(self.raw_dir, "smiles.txt")):
-                print("Download from", url)
-                r = requests.get(url, allow_redirects=True)
+                print("Download from", self.pubchem_url)
+                r = requests.get(self.pubchem_url, allow_redirects=True)
                 with tempfile.NamedTemporaryFile() as tf:
                     tf.write(r.content)
                     print("Unpacking...")
                     tf.seek(0)
                     with gzip.open(tf, "rb") as f_in:
                         with open(
-                                os.path.join(self.raw_dir, "smiles.txt"), "wb"
+                            os.path.join(self.raw_dir, "smiles.txt"), "wb"
                         ) as f_out:
                             shutil.copyfileobj(f_in, f_out)
             else:
@@ -114,8 +115,8 @@ def processed_file_names(self):
     def prepare_data(self, *args, **kwargs):
         print("Check for raw data in", self.raw_dir)
         if any(
-                not os.path.isfile(os.path.join(self.raw_dir, f))
-                for f in self.raw_file_names
+            not os.path.isfile(os.path.join(self.raw_dir, f))
+            for f in self.raw_file_names
         ):
             print("Downloading data. This may take some time...")
             self.download()
@@ -269,3 +270,6 @@ class PubToxAndChebi100(PubToxAndChebiX):
 
 class PubToxAndChebi50(PubToxAndChebiX):
     CHEBI_X = ChEBIOver50
+
+class PubChemDeepSMILES(PubChem):
+    READER = dr.DeepChemDataReader
diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py
index ecbe136f..1150d639 100644
--- a/chebai/trainer/InnerCVTrainer.py
+++ b/chebai/trainer/InnerCVTrainer.py
@@ -11,7 +11,7 @@
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning_utilities.core.rank_zero import WarningCache
 
-from sklearn import model_selection
+from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
 from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -28,7 +28,6 @@ def __init__(self, *args, **kwargs):
         # instantiation custom logger connector
         self._logger_connector.on_trainer_init(self.logger, 1)
 
-
     def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
         if n_splits < 2:
             self.fit(datamodule=datamodule, *args, **kwargs)
@@ -36,9 +35,10 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
             datamodule.prepare_data()
             datamodule.setup()
 
-            kfold = model_selection.KFold(n_splits=n_splits)
+            kfold = MultilabelStratifiedKFold(n_splits=n_splits)
 
-            for fold, (train_ids, val_ids) in enumerate(kfold.split(datamodule.train_val_data)):
+            for fold, (train_ids, val_ids) in enumerate(
+                    kfold.split(datamodule.train_val_data, [data['labels'] for data in datamodule.train_val_data])):
                 train_dataloader = datamodule.train_dataloader(ids=train_ids)
                 val_dataloader = datamodule.val_dataloader(ids=val_ids)
                 init_kwargs = self.init_kwargs
@@ -78,7 +78,6 @@ class ModelCheckpointCVSupport(ModelCheckpoint):
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
         """Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
         if self.dirpath is not None:
-            print(f'Eliminating existing dirpath {self.dirpath} at ModelCheckpoint setup')
             self.dirpath = None
         dirpath = self.__resolve_ckpt_dir(trainer)
         dirpath = trainer.strategy.broadcast(dirpath)
@@ -126,6 +125,3 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
 
         print(f'Now using checkpoint path {ckpt_path}')
         return ckpt_path
-
-
-
diff --git a/configs/data/chebi100_deepSMILES.yml b/configs/data/chebi100_deepSMILES.yml
new file mode 100644
index 00000000..943f0e17
--- /dev/null
+++ b/configs/data/chebi100_deepSMILES.yml
@@ -0,0 +1 @@
+class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100DeepSMILES
\ No newline at end of file
diff --git a/configs/data/pubchem_deepSMILES.yml b/configs/data/pubchem_deepSMILES.yml
new file mode 100644
index 00000000..42f21199
--- /dev/null
+++ b/configs/data/pubchem_deepSMILES.yml
@@ -0,0 +1 @@
+class_path: chebai.preprocessing.datasets.pubchem.PubChemDeepSMILES
diff --git a/configs/loss/electra_pre_loss.yml b/configs/loss/electra_pre_loss.yml
new file mode 100644
index 00000000..06520b2f
--- /dev/null
+++ b/configs/loss/electra_pre_loss.yml
@@ -0,0 +1 @@
+class_path: chebai.loss.pretraining.ElectraPreLoss
diff --git a/configs/model/electra_pretraining.yml b/configs/model/electra_pretraining.yml
new file mode 100644
index 00000000..7b78e48d
--- /dev/null
+++ b/configs/model/electra_pretraining.yml
@@ -0,0 +1,18 @@
+class_path: chebai.models.ElectraPre
+init_args:
+  out_dim: null
+  optimizer_kwargs:
+    lr: 1e-4
+  config:
+    generator:
+      vocab_size: 1400
+      max_position_embeddings: 1800
+      num_attention_heads: 8
+      num_hidden_layers: 6
+      type_vocab_size: 1
+    discriminator:
+      vocab_size: 1400
+      max_position_embeddings: 1800
+      num_attention_heads: 8
+      num_hidden_layers: 6
+      type_vocab_size: 1
\ No newline at end of file
diff --git a/configs/training/default_callbacks.yml b/configs/training/default_callbacks.yml
index 4f27c4d7..7e0eb5fd 100644
--- a/configs/training/default_callbacks.yml
+++ b/configs/training/default_callbacks.yml
@@ -1,6 +1,7 @@
 - class_path: chebai.trainer.InnerCVTrainer.ModelCheckpoint
   init_args:
     monitor: val_micro-f1
+    mode: 'max'
     filename: 'best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}'
     every_n_epochs: 1
     save_top_k: 5
diff --git a/configs/training/pretraining_callbacks.yml b/configs/training/pretraining_callbacks.yml
new file mode 100644
index 00000000..d998302d
--- /dev/null
+++ b/configs/training/pretraining_callbacks.yml
@@ -0,0 +1,12 @@
+- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpointCVSupport
+  init_args:
+    monitor: val_loss
+    mode: 'min'
+    filename: 'best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}'
+    every_n_epochs: 1
+    save_top_k: 5
+- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpointCVSupport
+  init_args:
+    filename: 'per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}'
+    every_n_epochs: 5
+    save_top_k: -1
\ No newline at end of file
diff --git a/configs/training/pretraining_trainer.yml b/configs/training/pretraining_trainer.yml
new file mode 100644
index 00000000..9dd05a3a
--- /dev/null
+++ b/configs/training/pretraining_trainer.yml
@@ -0,0 +1,10 @@
+min_epochs: 100
+max_epochs: 100
+
+default_root_dir: &default_root_dir logs
+logger:
+  class_path: lightning.pytorch.loggers.CSVLogger
+  init_args:
+    save_dir: *default_root_dir
+
+callbacks: pretraining_callbacks.yml
\ No newline at end of file
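
Note on the splitting change: the patch replaces sklearn's random train_test_split/KFold with label-stratified splits from the iterative-stratification package (iterstrat), both in _ChEBIDataExtractor.get_splits and in InnerCVTrainer.cv_fit (MultilabelStratifiedKFold). The snippet below is an illustrative sketch only, not part of the patch: the toy DataFrame, column names, and split ratio are invented, but it shows the two-stage MultilabelStratifiedShuffleSplit that get_splits() now performs on the label columns (everything after the first three metadata columns, matching the row[3:] slicing in the patch).

# Illustrative sketch (not part of the diff); toy data and column names are invented.
import numpy as np
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

rng = np.random.default_rng(0)
n = 200
# mimic the raw frame layout assumed by the patch: three metadata columns, then one column per class label
df = pd.DataFrame({
    "id": range(n),
    "name": [f"mol_{i}" for i in range(n)],
    "SMILES": ["C"] * n,
    "label_a": rng.integers(0, 2, n),
    "label_b": rng.integers(0, 2, n),
})
labels = df.iloc[:, 3:].values  # stratify on the label matrix only
train_frac = 0.85  # assumed value of self.train_split

# stage 1: train vs. the rest, preserving the per-label distribution
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1 - train_frac, random_state=0)
train_idx, rest_idx = next(msss.split(labels, labels))
df_train, df_rest = df.iloc[train_idx], df.iloc[rest_idx]

# stage 2: split the held-out part again into test and validation
msss2 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1 - train_frac, random_state=0)
test_idx, val_idx = next(msss2.split(labels[rest_idx], labels[rest_idx]))
df_test, df_validation = df_rest.iloc[test_idx], df_rest.iloc[val_idx]

print(len(df_train), len(df_test), len(df_validation))

For the inner cross-validation path, cv_fit uses MultilabelStratifiedKFold(n_splits=...) in the same way, passing the encoded labels of train_val_data as the stratification target.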