From f0fdb0a3137e65ae88b68e6be53d1acc784c16a5 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Mon, 6 Jan 2025 16:00:09 +0100
Subject: [PATCH] Add QLoRA (#820)

Add QLoRA implementation
---
 finetuning/livecell_finetuning.py | 17 ++++++++++++--
 micro_sam/models/peft_sam.py      | 39 +++++++++++++++++++++++++++----
 micro_sam/training/sam_trainer.py |  8 +++----
 micro_sam/training/training.py    | 21 +++++++----------
 4 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index f32986f57..96d7143ca 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -56,6 +56,19 @@ def finetune_livecell(args):
     train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
     scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}
 
+    # NOTE: Memory requirements per finetuning setting (compared on an A100 80GB):
+    # vit_b
+    # freeze_encoder: ~33.89 GB
+    # QLoRA: ~48.54 GB
+    # LoRA: ~48.62 GB
+    # FFT: ~49.56 GB
+
+    # vit_h
+    # freeze_encoder: ~36.05 GB
+    # QLoRA: ~65.68 GB
+    # LoRA: ~67.14 GB
+    # FFT: ~72.34 GB
+
     # Run training.
     sam_training.train_sam(
         name=checkpoint_name,
@@ -72,7 +85,7 @@ def finetune_livecell(args):
         save_root=args.save_root,
         scheduler_kwargs=scheduler_kwargs,
         save_every_kth_epoch=args.save_every_kth_epoch,
-        peft_kwargs={"rank": args.lora_rank} if args.lora_rank is not None else None,
+        peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
     )
 
     if args.export_path is not None:
@@ -87,7 +100,7 @@ def main():
     parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
     parser.add_argument(
-        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
+        "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/data/livecell/",
         help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
     )
     parser.add_argument(
diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 06ea7e4d5..d72295c61 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -6,6 +6,12 @@
 
 from segment_anything.modeling import Sam
 
+try:
+    import bitsandbytes as bnb
+    _have_bnb = True
+except ImportError:
+    _have_bnb = False
+
 
 class LoRASurgery(nn.Module):
     """Operates on the attention layers for performing low-rank adaptation.
@@ -22,7 +28,7 @@ class LoRASurgery(nn.Module):
 
     Args:
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
-        block: The chosen attention blocks for implementing lora.
+        block: The chosen attention blocks for implementing LoRA.
     """
     def __init__(self, rank: int, block: nn.Module):
         super().__init__()
@@ -50,8 +56,14 @@ def forward(self, x):
         qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
         new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
         new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
-        qkv[:, :, :, :self.dim] += new_q
-        qkv[:, :, :, -self.dim:] += new_v
+        qkv = torch.cat(
+            [
+                qkv[:, :, :, :self.dim] + new_q,  # Add the low-rank update to the q values.
+                qkv[:, :, :, self.dim:-self.dim],  # Keep the middle (k) slice unchanged.
+                qkv[:, :, :, -self.dim:] + new_v  # Add the low-rank update to the v values.
+            ], dim=-1
+        )
+
         return qkv
 
 
@@ -289,6 +301,7 @@ class PEFT_Sam(nn.Module):
         rank: The rank for low-rank adaptation.
         peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
         attention_layers_to_update: Which specific layers we apply PEFT methods to.
+        quantize: Whether to quantize the model for lower precision training.
     """
 
     def __init__(
@@ -297,12 +310,12 @@ def __init__(
         rank: int,
         peft_module: nn.Module = LoRASurgery,
         attention_layers_to_update: Union[List[int]] = None,
+        quantize: bool = False,
         **module_kwargs
     ):
         super().__init__()
 
         assert rank > 0
-
         assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
             "Invalid PEFT module"
         )
@@ -315,7 +328,23 @@ def __init__(
         self.peft_module = peft_module
         self.peft_blocks = []
 
-        # let's freeze all the pretrained image encoder layers first
+        # Whether to quantize the linear layers to 4-bit precision.
+        # NOTE: This is currently supported on CUDA devices only.
+        if quantize:
+            if not _have_bnb:
+                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
+
+            for name, module in model.image_encoder.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    *parent_path, layer_name = name.split(".")
+                    parent_module = model.image_encoder
+
+                    for sub_module in parent_path:
+                        parent_module = getattr(parent_module, sub_module)
+
+                    setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
+
+        # Let's freeze all the pretrained image encoder layers first
         for param in model.image_encoder.parameters():
             param.requires_grad = False
 
diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 020413e29..453df4c3b 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -218,10 +218,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
             with torch.no_grad():
                 net_mean_model_iou = torch.mean(batched_iou_predictions)
 
-            loss += net_loss
-            mask_loss += net_mask_loss
-            iou_regression_loss += net_iou_regression_loss
-            mean_model_iou += net_mean_model_iou
+            loss = loss + net_loss
+            mask_loss = mask_loss + net_mask_loss
+            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
+            mean_model_iou = mean_model_iou + net_mean_model_iou
 
             if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                 # Determine the next prompts based on current predictions.
diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py
index 79485b8d0..c576e9dfa 100644
--- a/micro_sam/training/training.py
+++ b/micro_sam/training/training.py
@@ -193,22 +193,18 @@ def train_sam(
     """Run training for a SAM model.
 
     Args:
-        name: The name of the model to be trained.
-            The checkpoint and logs wil have this name.
+        name: The name of the model to be trained. The checkpoint and logs will have this name.
         model_type: The type of the SAM model.
         train_loader: The dataloader for training.
         val_loader: The dataloader for validation.
         n_epochs: The number of epochs to train for.
-        early_stopping: Enable early stopping after this number of epochs
-            without improvement.
+        early_stopping: Enable early stopping after this number of epochs without improvement.
         n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation.
            If None all objects will be used, if given objects will be randomly sub-sampled.
         checkpoint_path: Path to checkpoint for initializing the SAM model.
-        with_segmentation_decoder: Whether to train additional UNETR decoder
-            for automatic instance segmentation.
-        freeze: Specify parts of the model that should be frozen, namely:
-            image_encoder, prompt_encoder and mask_decoder
+        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
+        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
            By default nothing is frozen and the full model is updated.
         device: The device to use for training.
         lr: The learning rate.
@@ -226,10 +222,10 @@ def train_sam(
         optimizer_class: The optimizer class.
             By default, torch.optim.AdamW is used.
         peft_kwargs: Keyword arguments for the PEFT wrapper class.
+        ignore_warnings: Whether to ignore raised warnings.
         verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
             By default, 50 batches of labels are verified from the dataloaders.
         model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
-        ignore_warnings: Whether to ignore raised warnings.
     """
 
     with _filter_warnings(ignore_warnings):
@@ -269,10 +265,11 @@ def train_sam(
                 if not param_name.startswith("encoder"):
                     joint_model_params.append(params)
 
-            optimizer = optimizer_class(joint_model_params, lr=lr)
-
+            model_params = joint_model_params
         else:
-            optimizer = optimizer_class(model.parameters(), lr=lr)
+            model_params = model.parameters()
+
+        optimizer = optimizer_class(model_params, lr=lr)
 
         if scheduler_kwargs is None:
             scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}
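
The sketch below shows, in isolation, the 4-bit swap that PEFT_Sam performs when quantize=True is passed: every torch.nn.Linear inside the image encoder is replaced by a bitsandbytes Linear4bit layer before the pretrained encoder parameters are frozen. It is a minimal illustration of the replacement logic in the micro_sam/models/peft_sam.py hunk above; the helper name replace_linear_with_4bit and the toy encoder are stand-ins for illustration only and not part of micro_sam.

import torch.nn as nn

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None


def replace_linear_with_4bit(encoder: nn.Module) -> nn.Module:
    """Swap every nn.Linear inside `encoder` for a 4-bit bitsandbytes layer (CUDA only)."""
    if bnb is None:
        raise ModuleNotFoundError("Please install 'bitsandbytes'.")

    # Snapshot the module list first, since the loop below mutates the model.
    for name, module in list(encoder.named_modules()):
        if isinstance(module, nn.Linear):
            # Walk down to the direct parent module so its attribute can be overwritten.
            *parent_path, layer_name = name.split(".")
            parent_module = encoder
            for sub_module in parent_path:
                parent_module = getattr(parent_module, sub_module)
            setattr(
                parent_module, layer_name,
                bnb.nn.Linear4bit(module.in_features, module.out_features),
            )
    return encoder


if __name__ == "__main__":
    # Toy stand-in for model.image_encoder, purely for illustration.
    toy_encoder = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
    replace_linear_with_4bit(toy_encoder)
    print(toy_encoder)  # Both Linear layers are now reported as Linear4bit.

From the training side the option is enabled by passing peft_kwargs={"rank": args.lora_rank, "quantize": True} to sam_training.train_sam, as in the finetuning/livecell_finetuning.py hunk above. The LoRA layers themselves are attached afterwards as regular nn.Linear modules, so only the frozen base weights end up in 4-bit storage, which is what the memory comparison added to livecell_finetuning.py measures.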