diff --git a/README.md b/README.md
index 61c8132c..68015f4d 100644
--- a/README.md
+++ b/README.md
@@ -26,15 +26,23 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co
 python -m chebai train --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --ckpt_path=[path-to-model-with-ontology-pretraining]
 ```
 
-## Features on branch `features-sfluegel`
-### Cross-validation
+## Predicting classes given SMILES strings
+
+```
+python -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+```
+The input file should contain a list of line-separated SMILES strings. This generates a CSV file that contains
+one row for each SMILES string and one column for each class.
+
+
+## Cross-validation
 Use inner cross-validation by not splitting between test and validation sets at dataset creation, but using k-fold cross-validation at runtime. This creates k models with separate metrics and checkpoints.
 For training with `k`-fold cross-validation, use the `cv-fit` subcommand and the options
 ```
 --data.init_args.inner_k_folds=k --n_splits=k
 ```
 
-### Chebi versions
+## Chebi versions
 Change the chebi version used for all sets (default: 200):
 ```
 --data.init_args.chebi_version=VERSION
diff --git a/chebai/cli.py b/chebai/cli.py
index 0c091db9..13659e40 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -28,6 +28,7 @@ def subcommands() -> Dict[str, Set[str]]:
         "test": {"model", "dataloaders", "datamodule"},
         "predict": {"model", "dataloaders", "datamodule"},
         "cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
+        "predict_from_file": {"model"},
     }
 
 
diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index 92e3935d..0f6cc9d7 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -84,10 +84,14 @@ class ChemDataReader(DataReader):
     def name(cls):
         return "smiles_token"
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, token_path=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        dirname = os.path.dirname(__file__)
-        with open(os.path.join(dirname, "bin", "tokens.txt"), "r") as pk:
+        if token_path is None:
+            # fall back to the token list bundled with the package
+            dirname = os.path.dirname(__file__)
+            token_path = os.path.join(dirname, "bin", "tokens.txt")
+        self.token_path = token_path
+        with open(self.token_path, "r") as pk:
             self.cache = [x.strip() for x in pk]
 
     def _get_token_index(self, token):
@@ -104,8 +108,7 @@ def _read_data(self, raw_data):
 
     def save_token_cache(self):
-        """write contents of self.cache into tokens.txt"""
-        dirname = os.path.dirname(__file__)
-        with open(os.path.join(dirname, "bin", "tokens.txt"), "w") as pk:
-            print(f'saving tokens to {os.path.join(dirname, "bin", "tokens.txt")}...')
+        """write contents of self.cache into self.token_path"""
+        with open(self.token_path, "w") as pk:
+            print(f'saving tokens to {self.token_path}...')
             print(f'first 10 tokens: {self.cache[:10]}')
             pk.writelines([f'{c}\n' for c in self.cache])
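For context, the new `token_path` parameter lets callers swap in their own vocabulary. A minimal sketch, assuming `DataReader.__init__` needs no further arguments; `my_tokens.txt` is a hypothetical line-separated token file, not part of the repository:
```
from chebai.preprocessing.reader import ChemDataReader

# default: use the token list bundled at chebai/preprocessing/bin/tokens.txt
reader = ChemDataReader()

# hypothetical custom vocabulary: tokens are read from, and saved back to, this path
custom_reader = ChemDataReader(token_path="my_tokens.txt")
custom_reader.save_token_cache()  # writes to my_tokens.txt instead of the bundled file
```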
diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py
index 1150d639..559e51df 100644
--- a/chebai/trainer/InnerCVTrainer.py
+++ b/chebai/trainer/InnerCVTrainer.py
@@ -1,20 +1,21 @@
 import logging
 import os
-from typing import Optional, Union, Iterable
+from typing import Optional, Union
 
+import pandas as pd
 from lightning import Trainer, LightningModule
 from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector
-from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning_utilities.core.rank_zero import WarningCache
-
+from lightning.pytorch.loggers import CSVLogger
 from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
 from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
+from torch.nn.utils.rnn import pad_sequence
+import torch
 
 log = logging.getLogger(__name__)
@@ -49,6 +50,32 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
             print(f'Logging this fold at {new_trainer.logger.log_dir}')
             new_trainer.fit(train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, *args, **kwargs)
 
+    def predict_from_file(self, model: LightningModule, checkpoint_path: _PATH, input_path: _PATH,
+                          save_to: _PATH = 'predictions.csv', classes_path: Optional[_PATH] = None):
+        """Load the weights at checkpoint_path and predict classes for all SMILES strings in input_path."""
+        loaded_model = model.__class__.load_from_checkpoint(checkpoint_path)
+        with open(input_path, 'r') as input_file:
+            smiles_strings = [inp.strip() for inp in input_file.readlines()]
+        predictions = self._predict_smiles(loaded_model, smiles_strings)
+        predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
+        if classes_path is not None:
+            # label the columns with the class names from classes_path
+            with open(classes_path, 'r') as f:
+                predictions_df.columns = [cls.strip() for cls in f.readlines()]
+        predictions_df.index = smiles_strings
+        predictions_df.to_csv(save_to)
+
+    def _predict_smiles(self, model: LightningModule, smiles: list[str]):
+        reader = ChemDataReader()
+        parsed_smiles = [reader._read_data(s) for s in smiles]
+        # pad the token sequences to a common length and prepend the CLS token
+        x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True).to(model.device)
+        cls_tokens = (torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) * CLS_TOKEN)
+        features = torch.cat((cls_tokens, x), dim=1)
+        model_output = model({'features': features})
+        preds = torch.sigmoid(model_output['logits'])
+        return preds
+
 
 # extend CSVLogger to include fold number in log path
 class CSVLoggerCVSupport(CSVLogger):
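For illustration, the expected input and output formats of `predict_from_file` as a minimal sketch; the file names, config paths, and SMILES strings below are placeholders, and it assumes a trained checkpoint is available for the CLI call in between:
```
import pandas as pd

# input: one SMILES string per line
with open("smiles.txt", "w") as f:
    f.write("CCO\nc1ccccc1\nCC(=O)O\n")

# then run (all paths are placeholders):
#   python -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=model.ckpt \
#       --input_path=smiles.txt --classes_path=classes.txt --save_to=predictions.csv

# output: rows indexed by SMILES string, one column per class, sigmoid scores as values
predictions = pd.read_csv("predictions.csv", index_col=0)
print(predictions.loc["CCO"].nlargest(5))  # five highest-scoring classes for ethanol
```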