Move quantization inside lora surgery to fix backprop issues
anwai98 committed Jan 12, 2025
1 parent daaa316 commit de6f955
Showing 2 changed files with 25 additions and 37 deletions.
17 changes: 9 additions & 8 deletions finetuning/livecell_finetuning.py
@@ -56,18 +56,19 @@ def finetune_livecell(args):
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}

# NOTE: memory req. for all vit_b models (compared on A100 80GB)
# NOTE 1: memory req. for all vit_b models (compared on A100 80GB).
# NOTE 2: all lora mentions are with rank 16.
# vit_b
# freeze_encoder: ~ 33.89 GB
# QLoRA: ~48.54 GB
# LoRA: ~48.62 GB
# FFT: ~49.56 GB
# QLoRA: ~35.34 GB
# LoRA: ~48.92 GB
# FFT: ~49.84 GB

# vit_h
# freeze_encoder: ~36.05 GB
# QLoRA: ~ 65.68 GB
# LoRA: ~ 67.14 GB
# FFT: ~72.34 GB
# freeze_encoder: ~36.33 GB
# QLoRA: ~ 36.41 GB
# LoRA: ~ 67.52 GB
# FFT: ~72.79 GB

# Run training.
sam_training.train_sam(
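The NOTE comments in this file compare GPU memory usage for the different fine-tuning setups on an A100 80GB. One way to log comparable peak-allocation numbers with PyTorch is sketched below; this is an illustration only, not part of this commit, and assumes a CUDA device and that the quoted figures refer to peak allocated memory.

import torch

def log_peak_gpu_memory(tag: str) -> None:
    # Peak memory allocated by tensors since the last reset, reported in GB.
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{tag}: ~{peak_gb:.2f} GB")
    torch.cuda.reset_peak_memory_stats()

# Example (hypothetical): call after a few training iterations, e.g.
# log_peak_gpu_memory("vit_b QLoRA")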
45 changes: 16 additions & 29 deletions micro_sam/models/peft_sam.py
@@ -30,17 +30,26 @@ class LoRASurgery(nn.Module):
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing LoRA.
"""
def __init__(self, rank: int, block: nn.Module):
def __init__(self, rank: int, block: nn.Module, quantize: bool = False):
super().__init__()
self.qkv_proj = block.attn.qkv
self.dim = self.qkv_proj.in_features
self.alpha = 1 # From our experiments, 'alpha' as 1 gives the best performance.
self.alpha = 1 # NOTE: From our experiments, 'alpha' as 1 gives the best performance.
self.rank = rank

self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
# Whether to quantize the linear layers to 4 bit precision.
# NOTE: This is currently supported for CUDA-supported devices only.
if quantize:
if not _have_bnb:
raise ModuleNotFoundError("Please install 'bitsandbytes'.")
linear_layer_class = bnb.nn.Linear4bit
else:
linear_layer_class = nn.Linear

self.w_a_linear_q = linear_layer_class(self.dim, self.rank, bias=False)
self.w_b_linear_q = linear_layer_class(self.rank, self.dim, bias=False)
self.w_a_linear_v = linear_layer_class(self.dim, self.rank, bias=False)
self.w_b_linear_v = linear_layer_class(self.rank, self.dim, bias=False)

self.reset_parameters()
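The quantization branch above refers to bnb and _have_bnb, which are defined at module level and not shown in this hunk. A typical optional-dependency guard would look like the following; this is an assumed sketch of that pattern, not the file's actual code.

# Hypothetical guard for the optional bitsandbytes dependency; the real
# definitions of `bnb` and `_have_bnb` live elsewhere in peft_sam.py.
try:
    import bitsandbytes as bnb
    _have_bnb = True
except ImportError:
    bnb = None
    _have_bnb = False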

@@ -301,7 +310,7 @@ class PEFT_Sam(nn.Module):
rank: The rank for low-rank adaptation.
peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
attention_layers_to_update: Which specific layers we apply PEFT methods to.
quantize: Whether to quantize the model for lower precision training.
module_kwargs: Additional arguments supported by the peft modules.
"""

def __init__(
@@ -310,7 +319,6 @@ def __init__(
rank: int,
peft_module: nn.Module = LoRASurgery,
attention_layers_to_update: Union[List[int]] = None,
quantize: bool = False,
**module_kwargs
):
super().__init__()
@@ -346,27 +354,6 @@ def __init__(
else:
self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

# Whether to quantize the linear layers to 4 bit precision.
# NOTE: This is currently supported for CUDA-supported devices only.
if quantize:
if not _have_bnb:
raise ModuleNotFoundError("Please install 'bitsandbytes'.")

for name, module in model.image_encoder.named_modules():
if isinstance(module, torch.nn.Linear):
*parent_path, layer_name = name.split(".")

# We avoid quantizing the MLP layers and the qkv projection layer in the attention block.
if "mlp" in parent_path or "qkv_proj" == layer_name:
continue

parent_module = model.image_encoder

for sub_module in parent_path:
parent_module = getattr(parent_module, sub_module)

setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))

self.peft_blocks = nn.ModuleList(self.peft_blocks)

self.sam = model
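With the quantization handling moved into LoRASurgery, a caller now requests 4-bit layers through the PEFT module's keyword arguments instead of a top-level quantize flag on PEFT_Sam. A minimal usage sketch follows; it is an assumption about typical usage rather than part of this commit, and the checkpoint path is a placeholder.

from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# Load a base SAM model; the checkpoint path below is a placeholder.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# `quantize=True` is forwarded via **module_kwargs to each LoRASurgery block,
# which creates its low-rank projections as bnb.nn.Linear4bit layers (QLoRA).
qlora_sam = PEFT_Sam(sam, rank=16, peft_module=LoRASurgery, quantize=True)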
