From 5dbdf29fd2f9595478e27cca91044de2bd2736b0 Mon Sep 17 00:00:00 2001
From: Carolin
Date: Sat, 11 Jan 2025 21:47:53 +0100
Subject: [PATCH 1/3] update quantization

---
 micro_sam/models/peft_sam.py | 32 +++++++++++++++----------------
 micro_sam/util.py            | 37 ++++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index d72295c6..1f70969b 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -328,22 +328,6 @@ def __init__(
         self.peft_module = peft_module
         self.peft_blocks = []
 
-        # Whether to quantize the linear layers to 4 bit precision.
-        # NOTE: This is currently supported for CUDA-supported devices only.
-        if quantize:
-            if not _have_bnb:
-                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
-
-            for name, module in model.image_encoder.named_modules():
-                if isinstance(module, torch.nn.Linear):
-                    *parent_path, layer_name = name.split(".")
-                    parent_module = model.image_encoder
-
-                    for sub_module in parent_path:
-                        parent_module = getattr(parent_module, sub_module)
-
-                    setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
-
         # Let's freeze all the pretrained image encoder layers first
         for param in model.image_encoder.parameters():
             param.requires_grad = False
@@ -362,6 +346,22 @@ def __init__(
             else:
                 self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
 
+        # Whether to quantize the linear layers to 4 bit precision.
+        # NOTE: This is currently supported for CUDA-supported devices only.
+        if quantize:
+            if not _have_bnb:
+                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
+
+            for name, module in model.image_encoder.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    *parent_path, layer_name = name.split(".")
+                    parent_module = model.image_encoder
+
+                    for sub_module in parent_path:
+                        parent_module = getattr(parent_module, sub_module)
+
+                    setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
+
         self.peft_blocks = nn.ModuleList(self.peft_blocks)
 
         self.sam = model
diff --git a/micro_sam/util.py b/micro_sam/util.py
index ba12e550..9d84d55f 100644
--- a/micro_sam/util.py
+++ b/micro_sam/util.py
@@ -375,9 +375,46 @@ def get_sam_model(
         if abbreviated_model_type == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 
+        _quantize = peft_kwargs.pop("quantize", None)
+
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
+        if _quantize:
+            import bitsandbytes as bnb
+            for name, module in sam.image_encoder.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    *parent_path, layer_name = name.split(".")
+                    parent_module = sam.image_encoder
+
+                    for sub_module in parent_path:
+                        parent_module = getattr(parent_module, sub_module)
+
+                    # Extract weight and bias from the state_dict
+                    weight_data = model_state.pop(f"image_encoder.{'.'.join(parent_path)}.{layer_name}.weight")
+                    bias_data = model_state.pop(f"image_encoder.{'.'.join(parent_path)}.{layer_name}.bias", None)
+
+                    layer_state_dict = {
+                        k.split(f"image_encoder.{'.'.join(parent_path)}.")[1]: v
+                        for k, v in model_state.items() if k.startswith(f"image_encoder.{'.'.join(parent_path)}.{layer_name}")
+                    }
+
+                    # Recreate the Linear4bit layer and load weights
+                    linear_q4bit = bnb.nn.Linear4bit(
+                        module.in_features,
+                        module.out_features,
+                        bias=True,
+                    )
+
+                    # Assign the quantized weights to the new layer
+                    linear_q4bit.weight = bnb.nn.Params4bit.from_prequantized(quantized_stats=layer_state_dict, data=weight_data)
+                    if bias_data is not None:
+                        linear_q4bit.bias = torch.nn.Parameter(bias_data)
+
+                    # Replace the original linear layer with the quantized one
+                    setattr(parent_module, layer_name, linear_q4bit)
+
     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
     if flexible_load_checkpoint:
         sam = _handle_checkpoint_loading(sam, model_state)
     else:

From 8ee6829e3d410f5a6d3e2dcf62f4a4a341ffe671 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sun, 12 Jan 2025 00:19:29 +0100
Subject: [PATCH 2/3] Minor corner fix to avoid quantization in other parts

---
 micro_sam/models/peft_sam.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 1f70969b..59485b49 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -355,6 +355,11 @@ def __init__(
             for name, module in model.image_encoder.named_modules():
                 if isinstance(module, torch.nn.Linear):
                     *parent_path, layer_name = name.split(".")
+
+                    # We avoid quantizing the MLP layers and the qkv projection layer in the attention block.
+                    if "mlp" in parent_path or "qkv_proj" == layer_name:
+                        continue
+
                     parent_module = model.image_encoder
 
                     for sub_module in parent_path:

From de6f955dc359017fb6e5afdf20aaf72e75798f3b Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sun, 12 Jan 2025 01:20:24 +0100
Subject: [PATCH 3/3] Move quantization inside lora surgery to fix backprop issues

---
 finetuning/livecell_finetuning.py | 17 ++++++------
 micro_sam/models/peft_sam.py      | 45 +++++++++++--------------------
 2 files changed, 25 insertions(+), 37 deletions(-)

diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index 96d7143c..2b4cfc16 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -56,18 +56,19 @@ def finetune_livecell(args):
     train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
     scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}
 
-    # NOTE: memory req. for all vit_b models (compared on A100 80GB)
+    # NOTE 1: memory req. for all vit_b models (compared on A100 80GB).
+    # NOTE 2: all lora mentions are with rank 16.
 
     # vit_b
     # freeze_encoder: ~ 33.89 GB
-    # QLoRA: ~48.54 GB
-    # LoRA: ~48.62 GB
-    # FFT: ~49.56 GB
+    # QLoRA: ~35.34 GB
+    # LoRA: ~48.92 GB
+    # FFT: ~49.84 GB
 
     # vit_h
-    # freeze_encoder: ~36.05 GB
-    # QLoRA: ~ 65.68 GB
-    # LoRA: ~ 67.14 GB
-    # FFT: ~72.34 GB
+    # freeze_encoder: ~36.33 GB
+    # QLoRA: ~ 36.41 GB
+    # LoRA: ~ 67.52 GB
+    # FFT: ~72.79 GB
 
     # Run training.
     sam_training.train_sam(
diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 59485b49..606f5beb 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -30,17 +30,26 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing LoRA.
     """
-    def __init__(self, rank: int, block: nn.Module):
+    def __init__(self, rank: int, block: nn.Module, quantize: bool = False):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
-        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
+        self.alpha = 1  # NOTE: From our experiments, 'alpha' as 1 gives the best performance.
         self.rank = rank
 
-        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
-        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
-        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
-        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
+        # Whether to quantize the linear layers to 4 bit precision.
+        # NOTE: This is currently supported for CUDA-supported devices only.
+        if quantize:
+            if not _have_bnb:
+                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
+            linear_layer_class = bnb.nn.Linear4bit
+        else:
+            linear_layer_class = nn.Linear
+
+        self.w_a_linear_q = linear_layer_class(self.dim, self.rank, bias=False)
+        self.w_b_linear_q = linear_layer_class(self.rank, self.dim, bias=False)
+        self.w_a_linear_v = linear_layer_class(self.dim, self.rank, bias=False)
+        self.w_b_linear_v = linear_layer_class(self.rank, self.dim, bias=False)
 
         self.reset_parameters()
 
@@ -301,7 +310,7 @@ class PEFT_Sam(nn.Module):
         rank: The rank for low-rank adaptation.
         peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
         attention_layers_to_update: Which specific layers we apply PEFT methods to.
-        quantize: Whether to quantize the model for lower precision training.
+        module_kwargs: Additional arguments supported by the peft modules.
     """
 
     def __init__(
         self,
         model: Sam,
         rank: int,
         peft_module: nn.Module = LoRASurgery,
         attention_layers_to_update: Union[List[int]] = None,
-        quantize: bool = False,
         **module_kwargs
     ):
         super().__init__()
@@ -346,22 +354,6 @@ def __init__(
             else:
                 self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
 
-        # Whether to quantize the linear layers to 4 bit precision.
-        # NOTE: This is currently supported for CUDA-supported devices only.
-        if quantize:
-            if not _have_bnb:
-                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
-
-            for name, module in model.image_encoder.named_modules():
-                if isinstance(module, torch.nn.Linear):
-                    *parent_path, layer_name = name.split(".")
-
-                    # We avoid quantizing the MLP layers and the qkv projection layer in the attention block.
-                    if "mlp" in parent_path or "qkv_proj" == layer_name:
-                        continue
-
-                    parent_module = model.image_encoder
-
-                    for sub_module in parent_path:
-                        parent_module = getattr(parent_module, sub_module)
-
-                    setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
-
         self.peft_blocks = nn.ModuleList(self.peft_blocks)
 
         self.sam = model
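
Usage sketch (not part of the patch series above): after patch 3/3, quantization is requested through LoRASurgery itself, which PEFT_Sam reaches via **module_kwargs. The snippet below is a minimal sketch of wiring this up for QLoRA, assuming a CUDA device with bitsandbytes installed; the base SAM model is built with the upstream segment_anything registry purely for illustration (in micro_sam the model would normally come from its own utilities), and the import path follows the file layout in this diff.

    # QLoRA wiring sketch based on patch 3/3 (assumes CUDA + bitsandbytes; base weights are random here).
    from segment_anything import sam_model_registry
    from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

    sam = sam_model_registry["vit_b"]()  # base SAM; pass a checkpoint path in real use
    qlora_sam = PEFT_Sam(
        sam,
        rank=16,                  # the rank used for the memory comparison in livecell_finetuning.py
        peft_module=LoRASurgery,
        quantize=True,            # forwarded via **module_kwargs to LoRASurgery
    )
    qlora_sam = qlora_sam.to("cuda")  # bitsandbytes 4-bit layers are quantized when moved to the GPU

With this setup only the low-rank q/v projection updates are trainable (the image encoder stays frozen), and the 4-bit linear layers account for the reduced memory footprint reported for QLoRA in the livecell_finetuning.py comments.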