Feature pretraining #5

Merged: 9 commits, Nov 23, 2023
6 changes: 4 additions & 2 deletions chebai/cli.py
@@ -3,10 +3,11 @@
from lightning.pytorch.cli import LightningCLI
from chebai.trainer.InnerCVTrainer import InnerCVTrainer


class ChebaiCLI(LightningCLI):

def __init__(self, *args, **kwargs):
super().__init__(trainer_class = InnerCVTrainer, *args, **kwargs)
super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs)

def add_arguments_to_parser(self, parser):
for kind in ("train", "val", "test"):
@@ -15,7 +16,7 @@ def add_arguments_to_parser(self, parser):
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)
#parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why
# parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
@@ -28,5 +29,6 @@ def subcommands() -> Dict[str, Set[str]]:
"cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
}


def cli():
r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
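For illustration, here is one iteration of the linking loop above written out explicitly. This is a sketch only: the set of `average` values comes from a line outside the visible hunk, and "micro" is an assumed example value.

# Sketch: unrolled equivalent of the loop for kind = "train" and an
# assumed average value "micro". At parse time, LightningCLI copies
# model.init_args.out_dim into the metric's num_labels argument.
parser.link_arguments(
    "model.init_args.out_dim",
    "model.init_args.train_metrics.init_args.metrics.micro-f1.init_args.num_labels",
)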
2 changes: 2 additions & 0 deletions chebai/models/electra.py
@@ -24,6 +24,7 @@
)
from chebai.preprocessing.reader import MASK_TOKEN_INDEX, CLS_TOKEN
from chebai.preprocessing.datasets.chebi import extract_class_hierarchy
from chebai.loss.pretraining import ElectraPreLoss # noqa
import torch
import csv

@@ -404,3 +405,4 @@ def __call__(self, target, input):
memberships, target.unsqueeze(-1).expand(-1, -1, 20)
)
return loss
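The imported ElectraPreLoss itself is defined in chebai/loss/pretraining.py, which this diff does not show. As a rough sketch, assuming it follows the standard ELECTRA objective (generator masked-LM loss plus discriminator replaced-token-detection loss), it could look like:

import torch

class ElectraPreLossSketch(torch.nn.Module):
    """Hypothetical stand-in for chebai.loss.pretraining.ElectraPreLoss."""

    def __init__(self):
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss()

    def forward(self, gen_logits, gen_targets, disc_logits, disc_targets):
        # sum of the generator's masked-token cross-entropy and the
        # discriminator's replaced-token detection loss
        return self.ce(gen_logits, gen_targets) + self.ce(disc_logits, disc_targets)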

4 changes: 2 additions & 2 deletions chebai/preprocessing/datasets/base.py
@@ -60,11 +60,11 @@ def full_identifier(self):

@property
def processed_dir(self):
return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "processed", *self.identifier)
return os.path.join("data", self._name, "processed", *self.identifier)

@property
def raw_dir(self):
return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "raw")
return os.path.join("data", self._name, "raw")

@property
def _name(self):
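Together with the overridden properties added to chebi.py below, the net effect is that only ChEBI datasets carry a version directory (illustrative layout, names assumed):

# data/<name>/raw                    # version-agnostic base layout
# data/<name>/chebi_v<version>/raw   # ChEBI subclasses, see chebi.py below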
139 changes: 99 additions & 40 deletions chebai/preprocessing/datasets/chebi.py
@@ -15,7 +15,8 @@
import pickle
import random

from sklearn.model_selection import train_test_split
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

import fastobo
import networkx as nx
import pandas as pd
@@ -117,17 +118,19 @@ def __init__(self, chebi_version_train: int = None, **kwargs):
def select_classes(self, g, split_name, *args, **kwargs):
raise NotImplementedError

def save(self, g, split, split_name: str):
def graph_to_raw_dataset(self, g, split_name=None):
"""Preparation step before creating splits, uses graph created by extract_class_hierarchy()
split_name is only relevant, if a separate train_version is set"""
smiles = nx.get_node_attributes(g, "smiles")
names = nx.get_node_attributes(g, "name")

print("build labels")
print(f"Process {split_name}")
print(f"Process graph")

molecules, smiles_list = zip(
*(
(n, smiles)
for n, smiles in ((n, smiles.get(n)) for n in split)
for n, smiles in ((n, smiles.get(n)) for n in smiles.keys())
if smiles
)
)
@@ -142,6 +145,10 @@ def save(self, g, split, split_name: str):
data = pd.DataFrame(data)
data = data[~data["SMILES"].isnull()]
data = data[data.iloc[:, 3:].any(axis=1)]
return data

def save(self, data: pd.DataFrame, split_name: str):

pickle.dump(data, open(os.path.join(self.raw_dir, split_name), "wb"))
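A note on the frame layout implied by the slicing above (actual column names are outside this diff): the first three columns presumably hold the ChEBI identifier, name and SMILES string, and every column from index 3 onward is one boolean label per class, which is why get_splits() below strips row[3:] before stratifying.

# Assumed layout of the raw DataFrame, inferred from iloc[:, 3:]:
#   col 0      col 1   col 2     cols 3..N
#   CHEBI id   name    SMILES    one boolean column per target class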

@staticmethod
@@ -192,44 +199,90 @@ def setup_processed(self):
self._setup_pruned_test_set()
self.reader.save_token_cache()

def get_splits(self, g):
fixed_nodes = list(g.nodes)
def get_splits(self, df: pd.DataFrame):
print("Split dataset")
random.shuffle(fixed_nodes)

train_split, test_split = train_test_split(
fixed_nodes, train_size=self.train_split, shuffle=True
)
df_list = df.values.tolist()
df_list = [row[3:] for row in df_list]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)

train_split, test_split = next(msss.split(df_list, df_list))
df_train = df.iloc[train_split]
df_test = df.iloc[test_split]
if self.use_inner_cross_validation:
return train_split, test_split
return df_train, df_test

df_test_list = df_test.values.tolist()
df_test_list = [row[3:] for row in df_test_list]
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
test_split, validation_split = next(msss.split(df_test_list, df_test_list))

test_split, validation_split = train_test_split(
test_split, train_size=self.train_split, shuffle=True
)
return train_split, test_split, validation_split
df_validation = df_test.iloc[validation_split]
df_test = df_test.iloc[test_split]
return df_train, df_test, df_validation
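For reference, a minimal self-contained example of the stratified shuffle split used here, on synthetic data (requires the iterative-stratification package):

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

y = np.random.default_rng(0).integers(0, 2, size=(100, 5))  # 100 samples, 5 binary labels
X = np.arange(100).reshape(-1, 1)  # dummy features; only the indices matter

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)
train_idx, test_idx = next(msss.split(X, y))
# each label keeps roughly the same positive rate in both partitions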

def get_splits_given_test(self, g, test_split):
def get_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
""" Use test set from another chebi version the model does not train on, avoid overlap"""
fixed_nodes = list(g.nodes)
print(f"Split dataset for chebi_v{self.chebi_version_train}")
for node in test_split:
if node in fixed_nodes:
fixed_nodes.remove(node)
random.shuffle(fixed_nodes)
df_trainval = df
test_smiles = test_df['SMILES'].tolist()
# iterating a DataFrame directly yields column names, not rows, so build
# the mask from the SMILES column instead
mask = ~df_trainval['SMILES'].isin(test_smiles)
df_trainval = df_trainval[mask]

if self.use_inner_cross_validation:
return fixed_nodes
return df_trainval

# assume that size of validation split should relate to train split as in get_splits()
validation_split, train_split = train_test_split(
fixed_nodes, train_size=(1 - self.train_split) ** 2, shuffle=True
)
return train_split, validation_split
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=(1 - self.train_split) ** 2, random_state=0)

df_trainval_list = df_trainval.values.tolist()
df_trainval_list = [row[3:] for row in df_trainval_list]
train_split, validation_split = next(msss.split(df_trainval_list, df_trainval_list))

df_validation = df_trainval.iloc[validation_split]
df_train = df_trainval.iloc[train_split]
return df_train, df_validation
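A small worked example of the overlap filter above, with made-up SMILES:

import pandas as pd

df = pd.DataFrame({"SMILES": ["C", "CC", "CCC"], "label": [1, 0, 1]})
test_df = pd.DataFrame({"SMILES": ["CC"]})
# keep only molecules whose SMILES do not occur in the fixed test set
print(df[~df["SMILES"].isin(test_df["SMILES"])])  # rows "C" and "CCC" remain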

@property
def processed_dir(self):
return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "processed", *self.identifier)

@property
def raw_dir(self):
return os.path.join("data", self._name, f'chebi_v{self.chebi_version}', "raw")

@property
def processed_file_names_dict(self) -> dict:
train_v_str = f'_v{self.chebi_version_train}' if self.chebi_version_train else ''
res = {'test': f"test{train_v_str}.pt"}
if self.use_inner_cross_validation:
res['train_val'] = f'trainval{train_v_str}.pt' # for cv, split train/val on runtime
res['train_val'] = f'trainval{train_v_str}.pt' # for cv, split train/val on runtime
else:
res['train'] = f"train{train_v_str}.pt"
res['validation'] = f"validation{train_v_str}.pt"
@@ -238,10 +291,10 @@ def processed_file_names_dict(self) -> dict:
@property
def raw_file_names_dict(self) -> dict:
train_v_str = f'_v{self.chebi_version_train}' if self.chebi_version_train else ''
res = {'test': f"test.pkl"} # no extra raw test version for chebi_version_train - use default test set and only
# adapt processed file
res = {'test': f"test.pkl"} # no extra raw test version for chebi_version_train - use default test set and only
# adapt processed file
if self.use_inner_cross_validation:
res['train_val'] = f'trainval{train_v_str}.pkl' # for cv, split train/val on runtime
res['train_val'] = f'trainval{train_v_str}.pkl' # for cv, split train/val on runtime
else:
res['train'] = f"train{train_v_str}.pkl"
res['validation'] = f"validation{train_v_str}.pkl"
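For example, with chebi_version_train = 148 (a made-up version number) and inner cross-validation disabled, the two dictionaries resolve to:

# raw:       train_v148.pkl, validation_v148.pkl, test.pkl
# processed: train_v148.pt,  validation_v148.pt,  test_v148.pt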
@@ -274,12 +327,13 @@ def prepare_data(self, *args, **kwargs):
open(chebi_path, "wb").write(r.content)
g = extract_class_hierarchy(chebi_path)
splits = {}
full_data = self.graph_to_raw_dataset(g)
if self.use_inner_cross_validation:
splits['train_val'], splits['test'] = self.get_splits(g)
splits['train_val'], splits['test'] = self.get_splits(full_data)
else:
splits['train'], splits['test'], splits['validation'] = self.get_splits(g)
splits['train'], splits['test'], splits['validation'] = self.get_splits(full_data)
for label, split in splits.items():
self.save(g, split, self.raw_file_names_dict[label])
self.save(split, self.raw_file_names_dict[label])
else:
# missing test set -> create
if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])):
@@ -290,8 +344,9 @@ def prepare_data(self, *args, **kwargs):
r = requests.get(url, allow_redirects=True)
open(chebi_path, "wb").write(r.content)
g = extract_class_hierarchy(chebi_path)
_, test_split, _ = self.get_splits(g)
self.save(g, test_split, self.raw_file_names_dict['test'])
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test'])
_, test_split, _ = self.get_splits(df)
self.save(test_split, self.raw_file_names_dict['test'])
else:
# load test_split from file
with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file:
@@ -304,12 +359,14 @@ def prepare_data(self, *args, **kwargs):
open(chebi_path, "wb").write(r.content)
g = extract_class_hierarchy(chebi_path)
if self.use_inner_cross_validation:
train_val_data = self.get_splits_given_test(g, test_split)
self.save(g, train_val_data, self.raw_file_names_dict['train_val'])
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val'])
train_val_df = self.get_splits_given_test(df, test_split)
self.save(train_val_df, self.raw_file_names_dict['train_val'])
else:
train_split, val_split = self.get_splits_given_test(g, test_split)
self.save(g, train_split, self.raw_file_names_dict['train'])
self.save(g, val_split, self.raw_file_names_dict['validation'])
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train'])
train_split, val_split = self.get_splits_given_test(df, test_split)
self.save(train_split, self.raw_file_names_dict['train'])
self.save(val_split, self.raw_file_names_dict['validation'])


class JCIExtendedBase(_ChEBIDataExtractor):
@@ -363,6 +420,7 @@ def select_classes(self, g, split_name, *args, **kwargs):
fout.writelines(str(node) + "\n" for node in nodes)
return nodes


class ChEBIOverXDeepSMILES(ChEBIOverX):
READER = dr.DeepChemDataReader

@@ -380,6 +438,7 @@ class ChEBIOver50(ChEBIOverX):
def label_number(self):
return 1332


class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100):
pass

16 changes: 10 additions & 6 deletions chebai/preprocessing/datasets/pubchem.py
@@ -31,6 +31,8 @@ class PubChem(XYBaseDataModule):

def __init__(self, *args, k=100000, **kwargs):
self._k = k
self.pubchem_url = "https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Monthly/2023-11-01/Extras/CID-SMILES.gz"

super(PubChem, self).__init__(*args, **kwargs)

@property
@@ -60,18 +62,17 @@ def _load_dict(input_file_path):
yield dict(features=smiles, labels=None, ident=ident)

def download(self):
url = f"https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Monthly/2021-10-01/Extras/CID-SMILES.gz"
if self._k == PubChem.FULL:
if not os.path.isfile(os.path.join(self.raw_dir, "smiles.txt")):
print("Download from", url)
r = requests.get(url, allow_redirects=True)
print("Download from", self.pubchem_url)
r = requests.get(self.pubchem_url, allow_redirects=True)
with tempfile.NamedTemporaryFile() as tf:
tf.write(r.content)
print("Unpacking...")
tf.seek(0)
with gzip.open(tf, "rb") as f_in:
with open(
os.path.join(self.raw_dir, "smiles.txt"), "wb"
os.path.join(self.raw_dir, "smiles.txt"), "wb"
) as f_out:
shutil.copyfileobj(f_in, f_out)
else:
@@ -114,8 +115,8 @@ def processed_file_names(self):
def prepare_data(self, *args, **kwargs):
print("Check for raw data in", self.raw_dir)
if any(
not os.path.isfile(os.path.join(self.raw_dir, f))
for f in self.raw_file_names
not os.path.isfile(os.path.join(self.raw_dir, f))
for f in self.raw_file_names
):
print("Downloading data. This may take some time...")
self.download()
Expand Down Expand Up @@ -269,3 +270,6 @@ class PubToxAndChebi100(PubToxAndChebiX):

class PubToxAndChebi50(PubToxAndChebiX):
CHEBI_X = ChEBIOver50

class PubChemDeepSMILES(PubChem):
READER = dr.DeepChemDataReader
12 changes: 4 additions & 8 deletions chebai/trainer/InnerCVTrainer.py
@@ -11,7 +11,7 @@
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning_utilities.core.rank_zero import WarningCache

from sklearn import model_selection
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn

from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -28,17 +28,17 @@ def __init__(self, *args, **kwargs):
# instantiation custom logger connector
self._logger_connector.on_trainer_init(self.logger, 1)


def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
if n_splits < 2:
self.fit(datamodule=datamodule, *args, **kwargs)
else:
datamodule.prepare_data()
datamodule.setup()

kfold = model_selection.KFold(n_splits=n_splits)
kfold = MultilabelStratifiedKFold(n_splits=n_splits)

for fold, (train_ids, val_ids) in enumerate(kfold.split(datamodule.train_val_data)):
for fold, (train_ids, val_ids) in enumerate(
kfold.split(datamodule.train_val_data, [data['labels'] for data in datamodule.train_val_data])):
train_dataloader = datamodule.train_dataloader(ids=train_ids)
val_dataloader = datamodule.val_dataloader(ids=val_ids)
init_kwargs = self.init_kwargs
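A self-contained sketch of the k-fold pattern used by cv_fit, with synthetic data standing in for datamodule.train_val_data:

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

y = np.random.default_rng(0).integers(0, 2, size=(20, 4))  # 20 samples, 4 labels
X = np.arange(20).reshape(-1, 1)

mskf = MultilabelStratifiedKFold(n_splits=5)
for fold, (train_ids, val_ids) in enumerate(mskf.split(X, y)):
    print(f"fold {fold}: {len(train_ids)} train / {len(val_ids)} val")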
@@ -78,7 +78,6 @@ class ModelCheckpointCVSupport(ModelCheckpoint):
def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
"""Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
if self.dirpath is not None:
print(f'Eliminating existing dirpath {self.dirpath} at ModelCheckpoint setup')
self.dirpath = None
dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
@@ -126,6 +125,3 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:

print(f'Now using checkpoint path {ckpt_path}')
return ckpt_path



1 change: 1 addition & 0 deletions configs/data/chebi100_deepSMILES.yml
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100DeepSMILES
1 change: 1 addition & 0 deletions configs/data/pubchem_deepSMILES.yml
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.pubchem.PubChemDeepSMILES
1 change: 1 addition & 0 deletions configs/loss/electra_pre_loss.yml
@@ -0,0 +1 @@
class_path: chebai.loss.pretraining.ElectraPreLoss
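These one-line configs plug into the ChebaiCLI defined above. A pretraining run could combine them along the following lines; the exact flags and remaining config files are assumptions, not part of this diff:

# hypothetical invocation, flag names assumed:
# python -m chebai.cli fit \
#     --data configs/data/pubchem_deepSMILES.yml \
#     --loss configs/loss/electra_pre_loss.yml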