Minor changes and documentation
WenkelF committed Nov 5, 2024
1 parent 697e3d1 commit d9aa407
Showing 3 changed files with 98 additions and 62 deletions.
37 changes: 18 additions & 19 deletions README.md
@@ -35,8 +35,6 @@ Visit https://graphium-docs.datamol.io/.

## Installation for developers

### For CPU and GPU developers

Use [`mamba`](https://github.com/mamba-org/mamba), a faster and better alternative to `conda`.

If you are using a GPU, we recommend enforcing the CUDA version that you need with `CONDA_OVERRIDE_CUDA=XX.X`.
@@ -53,18 +51,6 @@ mamba activate graphium
pip install --no-deps -e .
```

### For IPU developers
```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```

The above step needs to be done once. After that, enable the SDK and the environment as follows:

```bash
source enable_ipu.sh .graphium_ipu
```

## Training a model

To learn how to train a model, we invite you to look at the documentation or the Jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/master/docs/tutorials/model_training).
@@ -150,23 +136,36 @@ Thanks to the modular nature of `hydra` you can reuse many of our config setting

### Finetuning

After pretraining a model and saving a model checkpoint, the model can be finetuned to a new task:
After pretraining a model and saving a model checkpoint, the model can be finetuned to a new task

```bash
graphium-train +finetuning custom finetuning.pretrained_model=[model_identifier] constants.data_path=[path_to_data] constants.task=[name_of_task] constants.task_type=[cls OR reg]
graphium-train +finetuning [example-custom OR example-tdc] finetuning.pretrained_model=[model_identifier]
```

The `[model_identifier]` identifies the pretrained model among those maintained in `GRAPHIUM_PRETRAINED_MODELS_DICT` in `graphium/utils/spaces.py`, which maps each identifier to the location of the corresponding pretrained checkpoint.
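
For example, the available identifiers can be listed programmatically; this is a minimal sketch that assumes the dictionary is importable exactly as named above:

```python
# Minimal sketch: list the available pretrained-model identifiers, assuming
# GRAPHIUM_PRETRAINED_MODELS_DICT is importable exactly as described above.
from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT

for identifier, checkpoint_location in GRAPHIUM_PRETRAINED_MODELS_DICT.items():
    print(f"{identifier} -> {checkpoint_location}")
```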

The custom dataset to finetune from consists of two files `raw.csv` and `split.csv` that are provided in `[path_to_data]/[name_of_task]`. The `raw.csv` contains two columns, namely `smiles` with the smiles strings, and `target` with the corresponding targets. In `split.csv`, three columns `train`, `val`, `test` contain the indices of the rows in `raw.csv`. Examples can be found under `expts/data/finetuning_example-reg` (regression) and `expts/data/finetuning_example-cls` (binary classification).
We have provided two example yaml configs under `expts/hydra-configs/finetuning` for finetuning on a custom dataset (`example-custom.yaml`) or for a task from the TDC benchmark collection (`example-tdc.yaml`).

When using `example-custom.yaml` to finetune on a custom dataset, we need to provide the location of the data (`constants.data_path=[path_to_data]`) and the type of task (`constants.task_type=[cls OR reg]`).

When using `example-tdc.yaml` to finetune on a TDC task, we only need to provide the task name (`constants.task=[task_name]`); the task type is inferred automatically.

Custom datasets to finetune from consist of two files, `raw.csv` and `split.csv`. The `raw.csv` file contains two columns: `smiles` with the SMILES strings and `target` with the corresponding targets. In `split.csv`, the three columns `train`, `val`, and `test` contain the indices of the rows in `raw.csv`. Examples can be found under `expts/data/finetuning_example-reg` (regression) and `expts/data/finetuning_example-cls` (binary classification).
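
The layout of these two files can be sketched as follows; the molecules, targets, and split indices below are made up for illustration:

```python
# Illustrative sketch of the custom-dataset layout described above; the
# molecules, targets, and split indices are made up for demonstration.
import pandas as pd

# raw.csv: a `smiles` column and a `target` column
raw = pd.DataFrame({
    "smiles": ["CCO", "c1ccccc1", "CC(=O)O", "CCN", "CCCl", "c1ccncc1"],
    "target": [0.51, 1.20, 0.87, 0.33, 0.95, 1.02],
})
raw.to_csv("raw.csv", index=False)

# split.csv: `train`, `val`, `test` columns holding row indices into raw.csv
split = pd.DataFrame({
    "train": [0, 1],
    "val": [2, 3],
    "test": [4, 5],
})
split.to_csv("split.csv", index=False)
```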

### Fingerprinting

Alternatively, we can obtain molecular embeddings (fingerprints) from a pretrained model:
```bash
graphium fps create custom pretrained.model=[model_identifier] pretrained.layers=[layer_identifiers] datamodelu.df_path=[path_to_data]
graphium fps create [example-custom OR example-tdc] pretrained.model=[model_identifier] pretrained.layers=[layer_identifiers]
```

After specifiying the `[model_identifier]`, we need to provide a list of layers from that model where we want to read out embeddings via `[layer_identifiers]`. An example can be found in `expts/hydra-configs/fingerprinting/custom.yaml`. In addition, the location of the smiles to be embedded needs to be passed as `[path_to_data]`. The data can be passed as a csv file with a column `smiles`, similar to `expts/data/finetuning_example-reg/raw.csv`.
We have provided two example yaml configs under `expts/hydra-configs/fingerprinting` for extracting fingerprints for a custom dataset (`example-custom.yaml`) or for a dataset from the TDC benchmark collection (`example-tdc.yaml`).

After specifying the `[model_identifier]`, we need to provide a list of layers from that model where we want to read out embeddings via `[layer_identifiers]` (which requires knowledge of the architecture of the pretrained model).

When using `example-custom.yaml`, the location of the SMILES to be embedded needs to be passed via `datamodule.df_path=[path_to_data]`. The data can be passed as a CSV or Parquet file with a `smiles` column, similar to `expts/data/finetuning_example-reg/raw.csv`.

When extracting fingerprints for a TDC task using `example-tdc.yaml`, we need to specify `datamodule.benchmark` and `datamodule.task` instead of `datamodule.df_path`.
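
For programmatic access, the following is a minimal sketch of how the `FingerprintDatamodule` from `graphium/fingerprinting/data.py` (shown further down in this commit) might be instantiated; the model identifier and layer names are placeholders, and constructor arguments not visible in the diff are assumed to keep their defaults:

```python
# Hypothetical sketch of programmatic fingerprint extraction using the
# FingerprintDatamodule defined in graphium/fingerprinting/data.py (see the
# diff below). The model identifier and layer names are placeholders, and
# constructor arguments not shown in the diff are assumed to have defaults.
from graphium.fingerprinting.data import FingerprintDatamodule

datamodule = FingerprintDatamodule(
    pretrained_models={
        # pretrained model identifier -> layers to read embeddings from
        "some_pretrained_model": ["gnn.layers.0", "gnn.layers.1"],
    },
    benchmark=None,  # use a custom file instead of a TDC benchmark
    df_path="expts/data/finetuning_example-reg/raw.csv",
    batch_size=64,
)
datamodule.prepare_data()  # runs the pretrained model and caches the fingerprints
```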

## License

17 changes: 9 additions & 8 deletions graphium/finetuning/finetuning.py
@@ -30,7 +30,7 @@ def __init__(
added_depth: int = 0,
unfreeze_pretrained_depth: Optional[int] = None,
epoch_unfreeze_all: Optional[int] = 0,
freeze_always: Optional[Union[List, str]] = None,
always_freeze_modules: Optional[Union[List, str]] = None,
train_bn: bool = False,
):
"""
@@ -42,6 +42,7 @@ def __init__(
added_depth: Number of layers of finetuning module that have been modified rel. to pretrained model
unfreeze_pretrained_depth: Number of additional layers to unfreeze before layers modified rel. to pretrained model
epoch_unfreeze_all: Epoch to unfreeze entire model
always_freeze_modules: Modules that always stay frozen during finetuning
train_bn: Boolean value indicating if batchnorm layers stay in training mode
"""
@@ -52,11 +53,11 @@ def __init__(
if unfreeze_pretrained_depth is not None:
self.training_depth += unfreeze_pretrained_depth
self.epoch_unfreeze_all = epoch_unfreeze_all
self.freeze_always = freeze_always
if self.freeze_always == 'none':
self.freeze_always = None
if isinstance(self.freeze_always, str):
self.freeze_always = [self.freeze_always]
self.always_freeze_modules = always_freeze_modules
if self.always_freeze_modules == 'none':
self.always_freeze_modules = None
if isinstance(self.always_freeze_modules, str):
self.always_freeze_modules = [self.always_freeze_modules]
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule):
@@ -112,6 +113,6 @@ def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer
if epoch == self.epoch_unfreeze_all:
self.unfreeze_and_add_param_group(modules=pl_module, optimizer=optimizer, train_bn=self.train_bn)

if self.freeze_always is not None:
for module_name in self.freeze_always:
if self.always_freeze_modules is not None:
for module_name in self.always_freeze_modules:
self.freeze_module(pl_module, module_name, pl_module.model.pretrained_model.net._module_map)
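
Below is a hedged sketch of keyword arguments reflecting the rename of `freeze_always` to `always_freeze_modules`; the depth values and module name are placeholders, and constructor arguments outside this hunk are omitted:

```python
# Hypothetical sketch: keyword arguments for the finetuning callback above,
# reflecting the rename of `freeze_always` to `always_freeze_modules`.
# The depth values and module name are placeholders; arguments not shown in
# this hunk are omitted.
finetuning_kwargs = dict(
    added_depth=1,                          # layers modified rel. to the pretrained model
    unfreeze_pretrained_depth=1,            # extra pretrained layers to unfreeze early
    epoch_unfreeze_all=5,                   # unfreeze the entire model at this epoch
    always_freeze_modules=["pe_encoders"],  # modules that never get unfrozen
    train_bn=False,                         # batchnorm layers do not stay in training mode
)
# A single module name or the string "none" is also accepted, per the
# normalization in __init__ above.
```
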
106 changes: 71 additions & 35 deletions graphium/fingerprinting/data.py
@@ -11,41 +11,73 @@

from torch.utils.data import Dataset, DataLoader

from graphium.data.datamodule import MultitaskFromSmilesDataModule, ADMETBenchmarkDataModule
from graphium.data.datamodule import BaseDataModule, MultitaskFromSmilesDataModule, TDCBenchmarkDataModule, DatasetProcessingParams
from graphium.trainer.predictor import PredictorModule
from graphium.fingerprinting.fingerprinter import Fingerprinter


class FingerprintDataset(Dataset):
"""
Dataset class for fingerprints useful for probing experiments.
Parameters:
labels: Labels for the dataset.
fingerprints: Dictionary of fingerprints, where keys specify model and layer of extraction.
smiles: List of SMILES strings.
"""
def __init__(
self,
smiles: List[str],
labels: torch.Tensor,
fingerprints: Dict[str, torch.Tensor],
smiles: List[str] = None,
):
self.smiles = smiles
self.labels = labels
self.fingerprints = fingerprints
self.smiles = smiles

def __len__(self):
return len(self.smiles)
return len(self.labels)

def __getitem__(self, index):
fp_list = []
for val in self.fingerprints.values():
fp_list.append(val[index])
return fp_list, self.labels[index]

if self.smiles is not None:
return fp_list, self.labels[index], self.smiles[index]
else:
return fp_list, self.labels[index]


class FingerprintDatamodule(LightningDataModule):
"""
DataModule class for extracting fingerprints from one (or multiple) pretrained model(s).
Parameters:
pretrained_models: Dictionary of pretrained models (keys) and the lists of layers (values) to use.
task: Task to extract fingerprints for.
benchmark: Benchmark to extract fingerprints for.
df_path: Path to the DataFrame containing the SMILES strings.
batch_size: Batch size for fingerprint extraction (i.e., the forward passes of the pretrained models).
split_type: Type of split to use for the dataset.
splits_path: Path to the splits file.
split_val: Fraction of validation data.
split_test: Fraction of test data.
data_seed: Seed for data splitting.
num_workers: Number of workers for data loading.
device: Device to use for fingerprint extraction.
mol_cache_dir: Directory to cache the molecules in.
fps_cache_dir: Directory to cache the fingerprints in.
"""
def __init__(
self,
pretrained_models: Dict[str, List[str]],
task: str = "herg",
benchmark: Literal["tdc", None] = "tdc",
df_path: str = None,
batch_size: int = 64,
split_type: str = "random",
split_type: Literal["random", "scaffold"] = "random",
splits_path: str = None,
split_val: float = 0.1,
split_test: float = 0.1,
@@ -103,12 +135,9 @@ def prepare_data(self) -> None:
self.splits.append("test")

self.data = {
split: {
"smiles": [],
"labels": [],
"fps": {},
}
for split in self.splits
"smiles": {split: [] for split in self.splits},
"labels": {split: [] for split in self.splits},
"fps": {split: {} for split in self.splits},
}

for model, layers in self.pretrained_models.items():
@@ -120,26 +149,33 @@ def prepare_data(self) -> None:
assert self.df_path is not None, "df_path must be provided if not using an integrated benchmark"

# Add a dummy task column (filled with NaN values) in case no such column is provided
smiles_df = pd.read_csv(self.df_path)
base_datamodule = BaseDataModule()
smiles_df = base_datamodule._read_table(self.df_path)
task_cols = [col for col in smiles_df if col.startswith("task_")]
if len(task_cols) == 0:
df_path = ".".join(self.df_path.split(".")[:-1])
df_path, file_type = ".".join(self.df_path.split(".")[:-1]), self.df_path.split(".")[-1]

smiles_df["task_dummy"] = np.nan
smiles_df.to_csv(f"{df_path}_with_dummy_task_col.csv", index=False)
self.df_path = f"{df_path}_with_dummy_task_col.csv"

if file_type == "parquet":
smiles_df.to_parquet(f"{df_path}_with_dummy_task_col.{file_type}", index=False)
else:
smiles_df.to_csv(f"{df_path}_with_dummy_task_col.{file_type}", index=False)

self.df_path = f"{df_path}_with_dummy_task_col.{file_type}"

task_specific_args = {
"fingerprinting": {
"df_path": self.df_path,
"smiles_col": "smiles",
"label_cols": "task_*",
"task_level": "graph",
"splits_path": self.splits_path,
"split_type": self.split_type,
"split_val": self.split_val,
"split_test": self.split_test,
"seed": self.data_seed,
}
"fingerprinting": DatasetProcessingParams(
df_path=self.df_path,
smiles_col="smiles",
label_cols="task_*",
task_level="graph",
splits_path=self.splits_path,
split_type=self.split_type,
split_val=self.split_val,
split_test=self.split_test,
seed=self.data_seed,
)
}
label_key = "graph_fingerprinting"

@@ -152,7 +188,7 @@ def prepare_data(self) -> None:
)

elif self.benchmark == "tdc":
datamodule = ADMETBenchmarkDataModule(
datamodule = TDCBenchmarkDataModule(
tdc_benchmark_names=[self.task],
tdc_train_val_seed=self.data_seed,
batch_size_inference=128,
@@ -180,26 +216,26 @@ def prepare_data(self) -> None:
loader_dict["test"] = datamodule.get_dataloader(datamodule.test_ds, shuffle=False, stage="predict")

for split, loader in loader_dict.items():
if len(self.data[split]["smiles"]) == 0:
if len(self.data["smiles"][split]) == 0:
for batch in loader:
self.data[split]["smiles"] += [item for item in batch["smiles"]]
self.data[split]["labels"] += batch["labels"][label_key]
self.data["smiles"][split] += [item for item in batch["smiles"]]
self.data["labels"][split] += batch["labels"][label_key]

with Fingerprinter(predictor, layers, out_type="torch") as fp:
fps = fp.get_fingerprints_for_dataset(loader, store_dict=True)
for fp_name, fp in fps.items():
self.data[split]["fps"][f"{model}/{fp_name}"] = fp
self.data["fps"][split][f"{model}/{fp_name}"] = fp

os.makedirs(self.fps_cache_dir, exist_ok=True)
torch.save(self.data, f"{self.fps_cache_dir}/fps.pt")

def setup(self, stage: str) -> None:
# Creating datasets
if stage == "fit":
self.train_dataset = FingerprintDataset(self.smiles["train"], self.labels["train"], self.fps_dict["train"])
self.valid_dataset = FingerprintDataset(self.smiles["valid"], self.labels["valid"], self.fps_dict["valid"])
self.train_dataset = FingerprintDataset(self.labels["train"], self.fps_dict["train"])
self.valid_dataset = FingerprintDataset(self.labels["valid"], self.fps_dict["valid"])
else:
self.test_dataset = FingerprintDataset(self.smiles["test"], self.labels["test"], self.fps_dict["test"])
self.test_dataset = FingerprintDataset(self.labels["test"], self.fps_dict["test"])

def get_fp_dims(self):
fp_dict = next(iter(self.fps_dict.values()))
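
As a hedged illustration of the updated `FingerprintDataset` interface (labels first, optional `smiles`), a minimal sketch with made-up tensors:

```python
# Minimal sketch of the updated FingerprintDataset interface: labels come first,
# `smiles` is now optional, and each item is (list_of_fingerprints, label[, smiles]).
# The tensors and layer names below are made up for illustration.
import torch
from graphium.fingerprinting.data import FingerprintDataset

labels = torch.randn(4, 1)
fingerprints = {
    "model_a/gnn.layers.0": torch.randn(4, 64),
    "model_a/gnn.layers.1": torch.randn(4, 128),
}

# Without SMILES: __getitem__ returns (fingerprint_list, label)
dataset = FingerprintDataset(labels=labels, fingerprints=fingerprints)
fps, label = dataset[0]

# With SMILES: __getitem__ additionally returns the SMILES string
dataset_with_smiles = FingerprintDataset(
    labels=labels,
    fingerprints=fingerprints,
    smiles=["CCO", "c1ccccc1", "CC(=O)O", "CCN"],
)
fps, label, smi = dataset_with_smiles[0]
```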
