Commit

Merge branch 'feature-wandb-integration' into features-sfluegel
# Conflicts:
#	chebai/preprocessing/reader.py
#	chebai/trainer/InnerCVTrainer.py
sfluegel committed Dec 5, 2023
2 parents 21937f6 + 9c75553 commit 9be20cb
Showing 14 changed files with 435 additions and 221 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -161,6 +161,3 @@ cython_debug/
#.idea/

configs/
# the notebook I put in the wrong folder
chebai/preprocessing/datasets/demo_old_chebi.ipynb
demo_examine_pretraining_data.ipynb
61 changes: 61 additions & 0 deletions chebai/callbacks/model_checkpoint.py
@@ -0,0 +1,61 @@
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer, LightningModule
import os
from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

class CustomModelCheckpoint(ModelCheckpoint):
"""Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
same directory as the other logs"""

def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
"""Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
if self.dirpath is not None:
self.dirpath = None
dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
self.dirpath = dirpath
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
"""Same as in parent class, duplicated because method in parent class is not accessible"""
if self.save_top_k != 0 and _is_dir(self._fs, dirpath, strict=True) and len(self._fs.ls(dirpath)) > 0:
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
"""Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to
determine where to save checkpoints. The path for saving weights is set in this priority:
1. The ``ModelCheckpoint``'s ``dirpath`` if passed in
2. The ``Logger``'s ``log_dir`` if the trainer has loggers
3. The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers
The path gets extended with subdirectory "checkpoints".
"""
        print("Resolving checkpoint dir (custom)")
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return self.dirpath
if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
save_dir = trainer.loggers[0].save_dir
else:
save_dir = trainer.default_root_dir
name = trainer.loggers[0].name
version = trainer.loggers[0].version
version = version if isinstance(version, str) else f"version_{version}"
logger = trainer.loggers[0]
if isinstance(logger, WandbLogger):
ckpt_path = os.path.join(logger.experiment.dir, "checkpoints")
else:
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

print(f'Now using checkpoint path {ckpt_path}')
return ckpt_path
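
For orientation, here is a minimal usage sketch (names and hyperparameter values are assumptions, not part of this commit) of how the callback might be wired into a Lightning Trainer; with a WandbLogger-based logger, checkpoints then end up under the W&B run directory:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from chebai.callbacks.model_checkpoint import CustomModelCheckpoint

# Hypothetical wiring; monitor and save_top_k are illustrative values only
logger = WandbLogger(project="chebai", save_dir="logs")
checkpoint_cb = CustomModelCheckpoint(monitor="val_loss", save_top_k=3)
trainer = Trainer(logger=logger, callbacks=[checkpoint_cb])
# __resolve_ckpt_dir picks <wandb run dir>/checkpoints here; without any logger,
# it would fall back to <default_root_dir>/checkpoints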
3 changes: 2 additions & 1 deletion chebai/cli.py
@@ -33,4 +33,5 @@ def subcommands() -> Dict[str, Set[str]]:


def cli():
r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
r = ChebaiCLI(save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"})
59 changes: 59 additions & 0 deletions chebai/loggers/custom.py
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import Optional, Union, Literal

import wandb
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
import os


class CustomLogger(WandbLogger):
"""Adds support for custom naming of runs and cross-validation"""

def __init__(self, save_dir: _PATH, name: str = "logs", version: Optional[Union[int, str]] = None, prefix: str = "",
fold: Optional[int] = None, project: Optional[str] = None, entity: Optional[str] = None,
offline: bool = False,
log_model: Union[Literal["all"], bool] = False, **kwargs):
if version is None:
version = f'{datetime.now():%y%m%d-%H%M}'
self._version = version
self._name = name
self._fold = fold
super().__init__(name=self.name, save_dir=save_dir, version=None, prefix=prefix,
log_model=log_model, entity=entity, project=project, offline=offline, **kwargs)

@property
def name(self) -> Optional[str]:
name = f'{self._name}_{self.version}'
if self._fold is not None:
name += f'_fold{self._fold}'
return name

@property
def version(self) -> Optional[str]:
return self._version

@property
def root_dir(self) -> Optional[str]:
return os.path.join(self.save_dir, self.name)

@property
def log_dir(self) -> str:
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
if self._fold is None:
return os.path.join(self.root_dir, version)
return os.path.join(self.root_dir, version, f'fold_{self._fold}')

def set_fold(self, fold: int):
if fold != self._fold:
self._fold = fold
# start new experiment
wandb.finish()
self._wandb_init['name'] = self.name
self._name = self.name
self._experiment = None
_ = self.experiment

@property
def fold(self):
return self._fold
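
A rough sketch of how set_fold might drive an inner cross-validation loop (fold count and constructor arguments are assumed for illustration):

from chebai.loggers.custom import CustomLogger

logger = CustomLogger(save_dir="logs", name="cv_run", project="chebai")  # assumed values
for fold in range(5):
    logger.set_fold(fold)  # finishes the current W&B run and starts a new one for this fold
    # build a fresh model/Trainer here and call trainer.fit(...); logs go into fold_<fold> subdirectories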
36 changes: 18 additions & 18 deletions chebai/models/electra.py
@@ -24,7 +24,7 @@
)
from chebai.preprocessing.reader import MASK_TOKEN_INDEX, CLS_TOKEN
from chebai.preprocessing.datasets.chebi import extract_class_hierarchy
from chebai.loss.pretraining import ElectraPreLoss # noqa
from chebai.loss.pretraining import ElectraPreLoss # noqa
import torch
import csv

@@ -53,21 +53,22 @@ def _process_labels_in_batch(self, batch):

def forward(self, data, **kwargs):
features = data["features"]
        features = features.to(self.device).long()  # added for SELFIES; it is unclear why this cast is needed now but was not needed before
self.batch_size = batch_size = features.shape[0]
max_seq_len = features.shape[1]

mask = kwargs["mask"]
with torch.no_grad():
dis_tar = (
torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
).int()
disc_tar_one_hot = torch.eq(
torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None]
)
gen_tar = features[disc_tar_one_hot]
gen_tar_one_hot = torch.eq(
torch.arange(self.generator_config.vocab_size, device=self.device)[
None, :
None, :
],
gen_tar[:, None],
)
@@ -100,7 +101,7 @@ def _get_prediction_and_labels(self, batch, labels, output):

def filter_dict(d, filter_key):
return {
str(k)[len(filter_key) :]: v
str(k)[len(filter_key):]: v
for k, v in d.items()
if str(k).startswith(filter_key)
}
@@ -121,10 +122,10 @@ def _process_batch(self, batch, batch_idx):
batch_first=True,
)
cls_tokens = (
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
)
return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
Expand All @@ -139,7 +140,7 @@ def as_pretrained(self):
return self.electra.electra

def __init__(
self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
):
# Remove this property in order to prevent it from being stored as a
# hyper parameter
@@ -257,10 +258,10 @@ def _process_batch(self, batch, batch_idx):
batch_first=True,
)
cls_tokens = (
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
)
return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
@@ -295,7 +296,7 @@ def __init__(self, cone_dimensions=20, **kwargs):
model_dict = torch.load(fin, map_location=self.device)
if model_prefix:
state_dict = {
str(k)[len(model_prefix) :]: v
str(k)[len(model_prefix):]: v
for k, v in model_dict["state_dict"].items()
if str(k).startswith(model_prefix)
}
@@ -356,7 +357,7 @@ def forward(self, data, **kwargs):


def softabs(x, eps=0.01):
return (x**2 + eps) ** 0.5 - eps**0.5
return (x ** 2 + eps) ** 0.5 - eps ** 0.5


def anglify(x):
@@ -383,8 +384,8 @@ def in_cone_parts(vectors, cone_axes, cone_arcs):
dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang)
return dis
"""
a = cone_axes - cone_arcs**2
b = cone_axes + cone_arcs**2
a = cone_axes - cone_arcs ** 2
b = cone_axes + cone_arcs ** 2
bigger_than_a = torch.sigmoid(vectors - a)
smaller_than_b = torch.sigmoid(b - vectors)
return bigger_than_a * smaller_than_b
@@ -410,4 +411,3 @@ def __call__(self, target, input):
memberships, target.unsqueeze(-1).expand(-1, -1, 20)
)
return loss
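
As a side note, the random discriminator-target selection in the forward pass shown above (one position per sequence, encoded as a boolean one-hot mask) can be reproduced with a small self-contained snippet; the shapes below are toy values:

import torch

batch_size, max_seq_len = 2, 5                      # toy dimensions
mask = torch.ones(batch_size, max_seq_len)          # pretend every position is valid
# draw one random valid position per sequence as the discriminator target
dis_tar = (torch.rand(batch_size) * mask.sum(dim=-1)).int()
disc_tar_one_hot = torch.eq(torch.arange(max_seq_len)[None, :], dis_tar[:, None])
print(disc_tar_one_hot)  # bool tensor of shape (2, 5) with exactly one True per row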

12 changes: 10 additions & 2 deletions chebai/preprocessing/datasets/base.py
@@ -26,7 +26,7 @@ def __init__(
balance_after_filter: typing.Optional[float] = None,
num_workers: int = 1,
chebi_version: int = 200,
inner_k_folds: int = -1, # use inner cross-validation if > 1
inner_k_folds: int = -1, # use inner cross-validation if > 1
**kwargs,
):
super().__init__(**kwargs)
@@ -46,7 +46,7 @@ def __init__(
self.chebi_version = chebi_version
        assert isinstance(inner_k_folds, int)
self.inner_k_folds = inner_k_folds
self.use_inner_cross_validation = inner_k_folds > 1 # only use cv if there are at least 2 folds
self.use_inner_cross_validation = inner_k_folds > 1 # only use cv if there are at least 2 folds
os.makedirs(self.raw_dir, exist_ok=True)
os.makedirs(self.processed_dir, exist_ok=True)

@@ -130,6 +130,9 @@ def _load_data_from_file(self, path):
for d in tqdm.tqdm(self._load_dict(path), total=lines)
if d["features"] is not None
]
# filter for missing features in resulting data
data = [val for val in data if val['features'] is not None]

return data

def train_dataloader(self, *args, **kwargs) -> DataLoader:
@@ -160,6 +163,11 @@ def setup(self, **kwargs):
if self.use_inner_cross_validation:
self.train_val_data = torch.load(os.path.join(self.processed_dir, self.processed_file_names_dict['train_val']))

def teardown(self, stage: str) -> None:
        # can't save hyperparams at setup because the logger is not initialised yet
# not sure if this has an effect
self.save_hyperparameters()

def setup_processed(self):
raise NotImplementedError
