diff --git a/main_pretrain.py b/main_pretrain.py
index e83d5591e..ae67b1732 100644
--- a/main_pretrain.py
+++ b/main_pretrain.py
@@ -69,7 +69,7 @@ def main(cfg: DictConfig):
     assert cfg.method in METHODS, f"Choose from {METHODS.keys()}"
 
     if cfg.data.num_large_crops != 2:
-        assert cfg.method in ["wmse", "mae"]
+        assert cfg.method in ["wmse", "mae", "frossl"]
 
     model = METHODS[cfg.method](cfg)
     make_contiguous(model)
diff --git a/scripts/pretrain/cifar/frossl.yaml b/scripts/pretrain/cifar/frossl.yaml
new file mode 100644
index 000000000..e11bae3d0
--- /dev/null
+++ b/scripts/pretrain/cifar/frossl.yaml
@@ -0,0 +1,54 @@
+defaults:
+  - _self_
+  - augmentations: symmetric.yaml
+  - wandb: private.yaml
+  - override hydra/hydra_logging: disabled
+  - override hydra/job_logging: disabled
+
+# disable hydra outputs
+hydra:
+  output_subdir: null
+  run:
+    dir: .
+
+name: "frossl-cifar10" # change here for cifar100
+method: "frossl"
+backbone:
+  name: "resnet18"
+method_kwargs:
+  proj_hidden_dim: 2048
+  proj_output_dim: 1024
+  invariance_weight: 1.4
+
+data:
+  dataset: cifar10 # change here for cifar100
+  train_path: "./datasets"
+  val_path: "./datasets"
+  format: "image_folder"
+  num_workers: 8
+optimizer:
+  name: "lars"
+  batch_size: 256
+  lr: 0.3
+  classifier_lr: 0.1
+  weight_decay: 1e-4
+  kwargs:
+    clip_lr: True
+    eta: 0.02
+    exclude_bias_n_norm: True
+scheduler:
+  name: "warmup_cosine"
+checkpoint:
+  enabled: True
+  dir: "trained_models"
+  frequency: 1
+auto_resume:
+  enabled: True
+
+# overwrite PL stuff
+max_epochs: 1000
+devices: [0]
+sync_batchnorm: True
+accelerator: "gpu"
+strategy: "ddp"
+precision: 16-mixed
\ No newline at end of file
diff --git a/scripts/pretrain/imagenet-100/frossl.yaml b/scripts/pretrain/imagenet-100/frossl.yaml
new file mode 100644
index 000000000..bc0c3187b
--- /dev/null
+++ b/scripts/pretrain/imagenet-100/frossl.yaml
@@ -0,0 +1,53 @@
+defaults:
+  - _self_
+  - augmentations: symmetric.yaml
+  - wandb: private.yaml
+  - override hydra/hydra_logging: disabled
+  - override hydra/job_logging: disabled
+
+# disable hydra outputs
+hydra:
+  output_subdir: null
+  run:
+    dir: .
+
+name: "frossl-imagenet100"
+method: "frossl"
+backbone:
+  name: "resnet18"
+method_kwargs:
+  proj_hidden_dim: 2048
+  proj_output_dim: 1024
+  invariance_weight: 2.0
+data:
+  dataset: imagenet100
+  train_path: "./datasets/imagenet100/train"
+  val_path: "./datasets/imagenet100/val"
+  format: "dali"
+  num_workers: 16
+optimizer:
+  name: "lars"
+  batch_size: 256
+  lr: 0.3
+  classifier_lr: 0.1
+  weight_decay: 1e-4
+  kwargs:
+    clip_lr: True
+    eta: 0.02
+    exclude_bias_n_norm: True
+scheduler:
+  name: "warmup_cosine"
+checkpoint:
+  enabled: True
+  dir: "trained_models"
+  frequency: 1
+auto_resume:
+  enabled: False
+
+# overwrite PL stuff
+max_epochs: 400
+devices: [0, 1]
+sync_batchnorm: True
+accelerator: "gpu"
+strategy: "ddp"
+precision: 16-mixed
\ No newline at end of file
diff --git a/scripts/pretrain/imagenet/frossl.yaml b/scripts/pretrain/imagenet/frossl.yaml
new file mode 100644
index 000000000..200b8a5e9
--- /dev/null
+++ b/scripts/pretrain/imagenet/frossl.yaml
@@ -0,0 +1,54 @@
+defaults:
+  - _self_
+  - augmentations: vicreg.yaml
+  - wandb: private.yaml
+  - override hydra/hydra_logging: disabled
+  - override hydra/job_logging: disabled
+
+# disable hydra outputs
+hydra:
+  output_subdir: null
+  run:
+    dir: .
+
+name: "frossl-imagenet"
+method: "frossl"
+backbone:
+  name: "resnet18"
+method_kwargs:
+  proj_hidden_dim: 2048
+  proj_output_dim: 1024
+  invariance_weight: 2.0
+
+data:
+  dataset: imagenet
+  train_path: "./datasets/imagenet/train"
+  val_path: "./datasets/imagenet/val"
+  format: "dali"
+  num_workers: 8
+optimizer:
+  name: "lars"
+  batch_size: 256
+  lr: 0.3
+  classifier_lr: 0.1
+  weight_decay: 1e-4
+  kwargs:
+    clip_lr: True
+    eta: 0.02
+    exclude_bias_n_norm: True
+scheduler:
+  name: "warmup_cosine"
+checkpoint:
+  enabled: True
+  dir: "trained_models"
+  frequency: 1
+auto_resume:
+  enabled: True
+
+# overwrite PL stuff
+max_epochs: 100
+devices: [0, 1]
+sync_batchnorm: True
+accelerator: "gpu"
+strategy: "ddp"
+precision: 16-mixed
\ No newline at end of file
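The three configs above share the projector and LARS settings and differ mainly in dataset paths, `invariance_weight`, and schedule length. Runs are launched through solo-learn's usual Hydra entrypoint, e.g. `python3 main_pretrain.py --config-path scripts/pretrain/cifar --config-name frossl.yaml`. As a quick sanity check before a full run, the YAML can also be loaded directly; this snippet is illustrative only and not part of the patch:

```python
# Illustrative only (not part of the patch): load the new config as plain YAML
# to catch key typos before a full Hydra run. Assumes the repository root as
# the working directory.
from omegaconf import OmegaConf

cfg = OmegaConf.load("scripts/pretrain/cifar/frossl.yaml")
assert cfg.method == "frossl"
assert cfg.method_kwargs.proj_output_dim == 1024
print(OmegaConf.to_yaml(cfg.method_kwargs))
```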
diff --git a/solo/losses/__init__.py b/solo/losses/__init__.py
index 23a365d63..0b0c7275b 100644
--- a/solo/losses/__init__.py
+++ b/solo/losses/__init__.py
@@ -21,6 +21,7 @@
 from solo.losses.byol import byol_loss_func
 from solo.losses.deepclusterv2 import deepclusterv2_loss_func
 from solo.losses.dino import DINOLoss
+from solo.losses.frossl import frossl_loss_func
 from solo.losses.mae import mae_loss_func
 from solo.losses.mocov2plus import mocov2plus_loss_func
 from solo.losses.mocov3 import mocov3_loss_func
@@ -38,6 +39,7 @@
     "byol_loss_func",
     "deepclusterv2_loss_func",
     "DINOLoss",
+    "frossl_loss_func",
     "mae_loss_func",
     "mocov2plus_loss_func",
     "mocov3_loss_func",
diff --git a/solo/losses/frossl.py b/solo/losses/frossl.py
new file mode 100644
index 000000000..e37bb4d58
--- /dev/null
+++ b/solo/losses/frossl.py
@@ -0,0 +1,89 @@
+# Copyright 2024 solo-learn development team.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
+# Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies
+# or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import torch
+import torch.nn.functional as F
+
+
+def calculate_frobenius_regularization_term(z: torch.Tensor) -> torch.Tensor:
+    V, N, D = z.shape
+
+    # use whichever of the covariance (D x D) or Gram (N x N) matrix is smaller;
+    # they share nonzero eigenvalues, so the result is the same
+    if N > D:
+        cov = torch.matmul(z.transpose(1, 2), z)  # V x D x D
+    else:
+        cov = torch.matmul(z, z.transpose(1, 2))  # V x N x N
+
+    # divide each view covariance by its trace
+    trace = torch.diagonal(cov, dim1=1, dim2=2)  # V x min(N, D)
+    trace = torch.sum(trace, dim=1)  # V
+    cov = cov / trace.unsqueeze(-1).unsqueeze(-1)
+
+    # REGULARIZATION TERM - sum the log-Frobenius norm of each view covariance matrix
+    fro_norm_per_view = torch.linalg.norm(cov, dim=(1, 2), ord="fro")  # V
+    # the square inside the log becomes a factor of 2 outside it
+    regularization_term = -torch.sum(2 * torch.log(fro_norm_per_view))
+
+    return regularization_term
+
+
+def calculate_invariance_term(z: torch.Tensor) -> torch.Tensor:
+    V, N, D = z.shape
+
+    # INVARIANCE - align each view to the average view
+    average_z = torch.mean(z, dim=0)  # N x D, samples are averaged across views
+    average_z = average_z.repeat(V, 1, 1)  # V x N x D
+    invariance_loss_term = F.mse_loss(z, average_z)
+
+    return invariance_loss_term
+
+
+def frossl_loss_func(
+    z: torch.Tensor, invariance_weight: float = 1.0, logger=None
+) -> torch.Tensor:
+    """Implements FroSSL (https://arxiv.org/pdf/2310.02903).
+
+    Heavily adapted from https://github.com/OFSkean/FroSSL. The main difference is that this
+    implementation stacks the views and operates on all of them at once, rather than one at a
+    time. This saves ~2 seconds (about 5% improvement) per batch with N=2, D=1024 on an A5000
+    GPU. For a simpler, albeit slower, implementation that operates on one view at a time,
+    please see the original implementation.
+
+    Args:
+        z (torch.Tensor): V x N x D tensor of projected features, where z[v, n] holds the
+            projected features of view v of image n.
+        invariance_weight (float): weight for the invariance loss term. Default is 1.0.
+
+    Returns:
+        torch.Tensor: FroSSL loss.
+    """
+    V, N, D = z.shape
+
+    z = F.normalize(z, dim=1)  # V x N x D
+
+    regularization_term = calculate_frobenius_regularization_term(z)
+    regularization_term = -1 * regularization_term  # flip sign so minimizing the loss maximizes entropy
+
+    invariance_tradeoff = V * D * invariance_weight
+    invariance_term = calculate_invariance_term(z)
+    invariance_term = invariance_tradeoff * invariance_term
+
+    if logger is not None:
+        logger("frossl_regularization_loss", -regularization_term, sync_dist=True)
+        logger("frossl_invariance_loss", invariance_term, sync_dist=True)
+
+    total_loss = regularization_term + invariance_term
+    return total_loss
\ No newline at end of file
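A note on the `N > D` branch above: the D x D covariance `zᵀz` and the N x N Gram matrix `zzᵀ` share the same nonzero eigenvalues, hence the same trace and the same Frobenius norm, so the loss can always form the smaller of the two. A standalone numerical check (illustrative only, not part of the patch):

```python
# Illustrative only: verify that the trace-normalized Frobenius norm is
# identical whether computed from the covariance or the Gram matrix.
import torch

torch.manual_seed(0)
z = torch.randn(2, 8, 5)  # V=2 views, N=8 samples, D=5 dims

cov = torch.matmul(z.transpose(1, 2), z)   # V x D x D
gram = torch.matmul(z, z.transpose(1, 2))  # V x N x N

for m in (cov, gram):
    # divide each view matrix by its trace, as in the loss
    tr = torch.diagonal(m, dim1=1, dim2=2).sum(dim=1)
    m /= tr.unsqueeze(-1).unsqueeze(-1)

print(torch.linalg.norm(cov, dim=(1, 2), ord="fro"))
print(torch.linalg.norm(gram, dim=(1, 2), ord="fro"))  # matches the line above
```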
diff --git a/solo/methods/__init__.py b/solo/methods/__init__.py
index 720182625..5a3499883 100644
--- a/solo/methods/__init__.py
+++ b/solo/methods/__init__.py
@@ -22,6 +22,7 @@
 from solo.methods.byol import BYOL
 from solo.methods.deepclusterv2 import DeepClusterV2
 from solo.methods.dino import DINO
+from solo.methods.frossl import FroSSL
 from solo.methods.linear import LinearModel
 from solo.methods.mae import MAE
 from solo.methods.mocov2plus import MoCoV2Plus
@@ -49,6 +50,7 @@
     "byol": BYOL,
     "deepclusterv2": DeepClusterV2,
     "dino": DINO,
+    "frossl": FroSSL,
     "mae": MAE,
     "mocov2plus": MoCoV2Plus,
     "mocov3": MoCoV3,
diff --git a/solo/methods/frossl.py b/solo/methods/frossl.py
new file mode 100644
index 000000000..be32ea56e
--- /dev/null
+++ b/solo/methods/frossl.py
@@ -0,0 +1,141 @@
+# Copyright 2024 solo-learn development team.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
+# Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies
+# or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+from typing import Any, List, Sequence, Dict
+
+import omegaconf
+import torch
+import torch.nn as nn
+from solo.methods.base import BaseMethod
+from solo.utils.misc import omegaconf_select, gather
+from solo.losses.frossl import frossl_loss_func
+
+
+class FroSSL(BaseMethod):
+    def __init__(self, cfg: omegaconf.DictConfig):
+        """Implements FroSSL (https://arxiv.org/pdf/2310.02903).
+        Heavily adapted from https://github.com/OFSkean/FroSSL.
+
+        Extra cfg settings:
+            method_kwargs:
+                proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
+                proj_output_dim (int): number of dimensions of projected features.
+                invariance_weight (float): weight of the invariance loss term.
+        """
+
+        super().__init__(cfg)
+
+        self.invariance_weight: float = cfg.method_kwargs.invariance_weight
+
+        proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim
+        proj_output_dim: int = cfg.method_kwargs.proj_output_dim
+
+        # projector
+        self.projector = nn.Sequential(
+            nn.Linear(self.features_dim, proj_hidden_dim),
+            nn.BatchNorm1d(proj_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(proj_hidden_dim, proj_hidden_dim),
+            nn.BatchNorm1d(proj_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(proj_hidden_dim, proj_output_dim),
+        )
+
+    @staticmethod
+    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
+        """Adds method specific default values/checks for config.
+
+        Args:
+            cfg (omegaconf.DictConfig): DictConfig object.
+
+        Returns:
+            omegaconf.DictConfig: same as the argument, used to avoid errors.
+        """
+
+        cfg = super(FroSSL, FroSSL).add_and_assert_specific_cfg(cfg)
+
+        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim")
+        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim")
+
+        cfg.method_kwargs.invariance_weight = omegaconf_select(
+            cfg, "method_kwargs.invariance_weight", 1.0
+        )
+        return cfg
+
+    @property
+    def learnable_params(self) -> List[dict]:
+        """Adds projector parameters to parent's learnable parameters.
+
+        Returns:
+            List[dict]: list of learnable parameters.
+        """
+
+        extra_learnable_params = [{"name": "projector", "params": self.projector.parameters()}]
+        return super().learnable_params + extra_learnable_params
+
+    def forward(self, X) -> Dict[str, Any]:
+        """Performs the forward pass of the backbone and the projector.
+
+        Args:
+            X (torch.Tensor): a batch of images in the tensor format.
+
+        Returns:
+            Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
+        """
+
+        out = super().forward(X)
+        z = self.projector(out["feats"])
+        out.update({"z": z})
+        return out
+
+    def multicrop_forward(self, X: torch.Tensor) -> Dict[str, Any]:
+        """Performs the forward pass for the multicrop views.
+
+        Args:
+            X (torch.Tensor): batch of images in tensor format.
+
+        Returns:
+            Dict[str, Any]: a dict containing the outputs of the parent
+                and the projected features.
+        """
+
+        out = super().multicrop_forward(X)
+        z = self.projector(out["feats"])
+        out.update({"z": z})
+        return out
+
+    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
+        """Training step for FroSSL reusing BaseMethod training step.
+
+        Args:
+            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
+                [X] is a list of size num_crops containing batches of images.
+            batch_idx (int): index of the batch.
+
+        Returns:
+            torch.Tensor: total loss composed of FroSSL loss and classification loss.
+        """
+
+        out = super().training_step(batch, batch_idx)
+        class_loss = out["loss"]
+
+        z = torch.stack(out["z"], dim=0)  # V x N_per_gpu x D
+        z = gather(z, dim=1)  # V x N_total x D
+
+        frossl_loss = frossl_loss_func(
+            z, invariance_weight=self.invariance_weight, logger=self.log
+        )
+        self.log("train_frossl_loss", frossl_loss, on_epoch=True, sync_dist=True)
+
+        return frossl_loss + class_loss
\ No newline at end of file
diff --git a/tests/losses/test_frossl.py b/tests/losses/test_frossl.py
new file mode 100644
index 000000000..13110c30e
--- /dev/null
+++ b/tests/losses/test_frossl.py
@@ -0,0 +1,68 @@
+# Copyright 2024 solo-learn development team.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
+# Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies
+# or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import torch
+from solo.losses import frossl_loss_func
+
+
+def test_frossl_loss_D_greaterthan_N():
+    b, f = 32, 128
+    z1 = torch.randn(b, f).requires_grad_()
+    z2 = torch.randn(b, f).requires_grad_()
+
+    z = torch.stack([z1, z2], dim=0)
+    loss = frossl_loss_func(z, invariance_weight=1.4)
+    initial_loss = loss.item()
+    assert initial_loss != 0
+
+    # a few manual gradient-descent steps should decrease the loss
+    for _ in range(20):
+        z = torch.stack([z1, z2], dim=0)
+
+        loss = frossl_loss_func(z, invariance_weight=1.4)
+        loss.backward()
+        z1.data.add_(-0.5 * z1.grad)
+        z2.data.add_(-0.5 * z2.grad)
+
+        z1.grad = z2.grad = None
+
+    assert loss < initial_loss
+
+
+def test_frossl_loss_N_greaterthan_D():
+    b, f = 128, 32
+    z1 = torch.randn(b, f).requires_grad_()
+    z2 = torch.randn(b, f).requires_grad_()
+
+    z = torch.stack([z1, z2], dim=0)
+    loss = frossl_loss_func(z, invariance_weight=1.4)
+    initial_loss = loss.item()
+    assert initial_loss != 0
+
+    # a few manual gradient-descent steps should decrease the loss
+    for _ in range(20):
+        z = torch.stack([z1, z2], dim=0)
+
+        loss = frossl_loss_func(z, invariance_weight=1.4)
+        loss.backward()
+        z1.data.add_(-0.5 * z1.grad)
+        z2.data.add_(-0.5 * z2.grad)
+
+        z1.grad = z2.grad = None
+
+    assert loss < initial_loss
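Both tests build `z` exactly the way `FroSSL.training_step` does: per-view projector outputs are stacked along a new leading view axis before the loss is called. A standalone sketch of that shape convention (illustrative only, not part of the patch):

```python
# Illustrative only: the V x N x D stacking convention shared by the tests
# above and FroSSL.training_step. In DDP, training_step additionally gathers
# the batch axis (dim=1) across GPUs before computing the loss.
import torch
from solo.losses import frossl_loss_func

batch_size, proj_output_dim = 256, 1024
view1 = torch.randn(batch_size, proj_output_dim)  # projector output for view 1
view2 = torch.randn(batch_size, proj_output_dim)  # projector output for view 2

z = torch.stack([view1, view2], dim=0)  # 2 x 256 x 1024
loss = frossl_loss_func(z, invariance_weight=1.4)
print(loss)  # scalar loss tensor
```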
diff --git a/tests/methods/test_frossl.py b/tests/methods/test_frossl.py
new file mode 100644
index 000000000..0b7687079
--- /dev/null
+++ b/tests/methods/test_frossl.py
@@ -0,0 +1,131 @@
+# Copyright 2024 solo-learn development team.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
+# Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies
+# or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import torch
+from solo.methods import FroSSL
+
+from .utils import gen_base_cfg, gen_batch, gen_trainer, prepare_dummy_dataloaders
+
+
+def test_frossl():
+    method_kwargs = {
+        "proj_hidden_dim": 2048,
+        "proj_output_dim": 256,
+        "invariance_weight": 1.4,
+    }
+
+    cfg = gen_base_cfg("frossl", batch_size=2, num_classes=100, momentum=True)
+    cfg.method_kwargs = method_kwargs
+    model = FroSSL(cfg)
+
+    # test arguments
+    model.add_and_assert_specific_cfg(cfg)
+
+    # test parameters
+    assert model.learnable_params is not None
+
+    # test forward
+    batch, _ = gen_batch(cfg.optimizer.batch_size, cfg.data.num_classes, "imagenet100")
+    out = model(batch[1][0])
+    assert (
+        "logits" in out
+        and isinstance(out["logits"], torch.Tensor)
+        and out["logits"].size() == (cfg.optimizer.batch_size, cfg.data.num_classes)
+    )
+    assert (
+        "feats" in out
+        and isinstance(out["feats"], torch.Tensor)
+        and out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim)
+    )
+    assert (
+        "z" in out
+        and isinstance(out["z"], torch.Tensor)
+        and out["z"].size() == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"])
+    )
+
+    multicrop_out = model.multicrop_forward(batch[1][0])
+    assert (
+        "feats" in multicrop_out
+        and isinstance(multicrop_out["feats"], torch.Tensor)
+        and multicrop_out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim)
+    )
+    assert (
+        "z" in multicrop_out
+        and isinstance(multicrop_out["z"], torch.Tensor)
+        and multicrop_out["z"].size()
+        == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"])
+    )
+
+    # imagenet
+    model = FroSSL(cfg)
+
+    trainer = gen_trainer(cfg)
+    train_dl, val_dl = prepare_dummy_dataloaders(
+        "imagenet100",
+        num_large_crops=cfg.data.num_large_crops,
+        num_small_crops=0,
+        num_classes=cfg.data.num_classes,
+        batch_size=cfg.optimizer.batch_size,
+    )
+    trainer.fit(model, train_dl, val_dl)
+
+    # cifar
+    cfg.data.dataset = "cifar10"
+    cfg.data.num_classes = 10
+    model = FroSSL(cfg)
+
+    trainer = gen_trainer(cfg)
+    train_dl, val_dl = prepare_dummy_dataloaders(
+        "cifar10",
+        num_large_crops=cfg.data.num_large_crops,
+        num_small_crops=0,
+        num_classes=cfg.data.num_classes,
+        batch_size=cfg.optimizer.batch_size,
+    )
+    trainer.fit(model, train_dl, val_dl)
+
+    # multicrop
+    cfg.data.num_small_crops = 6
+    model = FroSSL(cfg)
+
+    trainer = gen_trainer(cfg)
+    train_dl, val_dl = prepare_dummy_dataloaders(
+        "imagenet100",
+        num_large_crops=cfg.data.num_large_crops,
+        num_small_crops=cfg.data.num_small_crops,
+        num_classes=cfg.data.num_classes,
+        batch_size=cfg.optimizer.batch_size,
+    )
+    trainer.fit(model, train_dl, val_dl)
+
+    # 8 large views
+    cfg.data.num_large_crops = 8
+    cfg.data.num_small_crops = 0
+    model = FroSSL(cfg)
+
+    trainer = gen_trainer(cfg)
+    train_dl, val_dl = prepare_dummy_dataloaders(
+        "imagenet100",
+        num_large_crops=cfg.data.num_large_crops,
+        num_small_crops=cfg.data.num_small_crops,
+        num_classes=cfg.data.num_classes,
+        batch_size=cfg.optimizer.batch_size,
+    )
+    trainer.fit(model, train_dl, val_dl)
\ No newline at end of file
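The final case above exercises eight large views; the loss itself is view-count agnostic, since it only assumes a V x N x D stack. A minimal sketch of the multi-view call (illustrative only, not part of the patch):

```python
# Illustrative only: the loss accepts any number of views V, matching the
# 8-large-views case exercised by the test above.
import torch
from solo.losses import frossl_loss_func

views = [torch.randn(64, 128) for _ in range(8)]  # V=8 views, N=64, D=128
z = torch.stack(views, dim=0)                     # 8 x 64 x 128
loss = frossl_loss_func(z)                        # invariance term scales with V * D
print(loss)
```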