Move quantization after LoRA surgery #828

Closed
wants to merge 4 commits
17 changes: 9 additions & 8 deletions finetuning/livecell_finetuning.py
@@ -56,18 +56,19 @@ def finetune_livecell(args):
     train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
     scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}
 
-    # NOTE: memory req. for all vit_b models (compared on A100 80GB)
+    # NOTE 1: memory req. for all vit_b models (compared on A100 80GB).
+    # NOTE 2: all lora mentions are with rank 16.
     # vit_b
     # freeze_encoder: ~ 33.89 GB
-    # QLoRA: ~48.54 GB
-    # LoRA: ~48.62 GB
-    # FFT: ~49.56 GB
+    # QLoRA: ~35.34 GB
+    # LoRA: ~48.92 GB
+    # FFT: ~49.84 GB
 
     # vit_h
-    # freeze_encoder: ~36.05 GB
-    # QLoRA: ~ 65.68 GB
-    # LoRA: ~ 67.14 GB
-    # FFT: ~72.34 GB
+    # freeze_encoder: ~36.33 GB
+    # QLoRA: ~ 36.41 GB
+    # LoRA: ~ 67.52 GB
+    # FFT: ~72.79 GB

     # Run training.
     sam_training.train_sam(
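For context, peak-memory numbers like those in the comment above are typically collected with PyTorch's CUDA memory statistics. Below is a minimal, generic sketch (not part of this PR); `run_training_step` is a hypothetical placeholder for one forward/backward pass of the chosen configuration.

```python
import torch


def report_peak_gpu_memory(run_training_step, device: str = "cuda") -> float:
    """Run one training step and report the peak GPU memory it allocated, in GB."""
    torch.cuda.reset_peak_memory_stats(device)
    run_training_step()  # placeholder: forward + backward + optimizer step
    torch.cuda.synchronize(device)
    peak_gb = torch.cuda.max_memory_allocated(device) / 1024 ** 3
    print(f"Peak GPU memory: ~{peak_gb:.2f} GB")
    return peak_gb
```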
40 changes: 16 additions & 24 deletions micro_sam/models/peft_sam.py
@@ -30,17 +30,26 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing LoRA.
     """
-    def __init__(self, rank: int, block: nn.Module):
+    def __init__(self, rank: int, block: nn.Module, quantize: bool = False):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
-        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
+        self.alpha = 1  # NOTE: From our experiments, 'alpha' as 1 gives the best performance.
         self.rank = rank
 
-        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
-        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
-        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
-        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
+        # Whether to quantize the linear layers to 4 bit precision.
+        # NOTE: This is currently supported for CUDA-supported devices only.
+        if quantize:
+            if not _have_bnb:
+                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
+            linear_layer_class = bnb.nn.Linear4bit
+        else:
+            linear_layer_class = nn.Linear
+
+        self.w_a_linear_q = linear_layer_class(self.dim, self.rank, bias=False)
+        self.w_b_linear_q = linear_layer_class(self.rank, self.dim, bias=False)
+        self.w_a_linear_v = linear_layer_class(self.dim, self.rank, bias=False)
+        self.w_b_linear_v = linear_layer_class(self.rank, self.dim, bias=False)
 
         self.reset_parameters()
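For readers unfamiliar with the class, the low-rank layers built above act on top of the frozen `qkv` projection. The stand-alone sketch below illustrates the general LoRA update for the query and value slices; it is an illustration of the technique, not micro_sam's actual `LoRASurgery` forward pass. In the diff above, the analogous `w_a_linear_*`/`w_b_linear_*` layers are simply created as `bnb.nn.Linear4bit` instead of `nn.Linear` when `quantize=True`.

```python
import torch
import torch.nn as nn


class LoRAQKVSketch(nn.Module):
    """Illustrative LoRA-over-qkv module (not micro_sam's LoRASurgery)."""

    def __init__(self, qkv_proj: nn.Linear, rank: int, alpha: float = 1.0):
        super().__init__()
        self.qkv_proj = qkv_proj  # frozen pretrained projection, out_features == 3 * dim
        self.dim = qkv_proj.in_features
        self.alpha = alpha
        self.w_a_q = nn.Linear(self.dim, rank, bias=False)
        self.w_b_q = nn.Linear(rank, self.dim, bias=False)
        self.w_a_v = nn.Linear(self.dim, rank, bias=False)
        self.w_b_v = nn.Linear(rank, self.dim, bias=False)
        nn.init.zeros_(self.w_b_q.weight)  # B starts at zero, so the update is initially a no-op
        nn.init.zeros_(self.w_b_v.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(x)  # (..., 3 * dim): stacked query, key, value
        new_q = self.alpha * self.w_b_q(self.w_a_q(x))
        new_v = self.alpha * self.w_b_v(self.w_a_v(x))
        return torch.cat(
            [
                qkv[..., :self.dim] + new_q,        # query gets the low-rank update
                qkv[..., self.dim:2 * self.dim],    # key is left unchanged
                qkv[..., 2 * self.dim:] + new_v,    # value gets the low-rank update
            ],
            dim=-1,
        )
```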

@@ -301,7 +310,7 @@ class PEFT_Sam(nn.Module):
         rank: The rank for low-rank adaptation.
         peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
         attention_layers_to_update: Which specific layers we apply PEFT methods to.
-        quantize: Whether to quantize the model for lower precision training.
         module_kwargs: Additional arguments supported by the peft modules.
     """

def __init__(
@@ -310,7 +319,6 @@ def __init__(
         rank: int,
         peft_module: nn.Module = LoRASurgery,
         attention_layers_to_update: Union[List[int]] = None,
-        quantize: bool = False,
         **module_kwargs
     ):
         super().__init__()
@@ -328,22 +336,6 @@ def __init__(
         self.peft_module = peft_module
         self.peft_blocks = []
 
-        # Whether to quantize the linear layers to 4 bit precision.
-        # NOTE: This is currently supported for CUDA-supported devices only.
-        if quantize:
-            if not _have_bnb:
-                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
-
-            for name, module in model.image_encoder.named_modules():
-                if isinstance(module, torch.nn.Linear):
-                    *parent_path, layer_name = name.split(".")
-                    parent_module = model.image_encoder
-
-                    for sub_module in parent_path:
-                        parent_module = getattr(parent_module, sub_module)
-
-                    setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
-
         # Let's freeze all the pretrained image encoder layers first
         for param in model.image_encoder.parameters():
             param.requires_grad = False
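With this refactor, `quantize` is no longer an argument of `PEFT_Sam` itself; it is forwarded to `LoRASurgery` through `**module_kwargs`. A hedged usage sketch, assuming `sam` is a Segment Anything `Sam` instance loaded elsewhere, bitsandbytes is installed, and a CUDA device is available:

```python
from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# 'sam' is assumed to be a segment-anything 'Sam' model loaded beforehand.
peft_sam = PEFT_Sam(
    sam,
    rank=16,                  # the rank used for the memory comparison above
    peft_module=LoRASurgery,
    quantize=True,            # forwarded via **module_kwargs to each LoRASurgery block
)
```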
37 changes: 37 additions & 0 deletions micro_sam/util.py
@@ -375,9 +375,46 @@ def get_sam_model(
         if abbreviated_model_type == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 
+        _quantize = peft_kwargs.pop("quantize", None)
+
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
+        if _quantize:
+            import bitsandbytes as bnb
+            for name, module in sam.image_encoder.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    *parent_path, layer_name = name.split(".")
+                    parent_module = sam.image_encoder
+
+                    for sub_module in parent_path:
+                        parent_module = getattr(parent_module, sub_module)
+
+                    # Extract weight and bias from the state_dict
+                    weight_data = model_state.pop(f"image_encoder.{'.'.join(parent_path)}.{layer_name}.weight")
+                    bias_data = model_state.pop(f"image_encoder.{'.'.join(parent_path)}.{layer_name}.bias", None)
+
+                    layer_state_dict = {
+                        k.split(f"image_encoder.{'.'.join(parent_path)}.")[1]: v
+                        for k, v in model_state.items() if k.startswith(f"image_encoder.{'.'.join(parent_path)}.{layer_name}")
+                    }
+
+                    # Recreate the Linear4bit layer and load weights
+                    linear_q4bit = bnb.nn.Linear4bit(
+                        module.in_features,
+                        module.out_features,
+                        bias=True,
+                    )
+
+                    # Assign the quantized weights to the new layer
+                    linear_q4bit.weight = bnb.nn.Params4bit.from_prequantized(quantized_stats=layer_state_dict, data=weight_data)
+                    if bias_data is not None:
+                        linear_q4bit.bias = torch.nn.Parameter(bias_data)
+
+                    # Replace the original linear layer with the quantized one
+                    setattr(parent_module, layer_name, linear_q4bit)
     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
     if flexible_load_checkpoint:
         sam = _handle_checkpoint_loading(sam, model_state)
     else:
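Downstream, the quantized layers are rebuilt when such a checkpoint is loaded through `get_sam_model` with `quantize` included in the PEFT keyword arguments. The call below is an assumption based on the `peft_kwargs` handling in this hunk; the exact parameter names of `get_sam_model` are not shown here, and the checkpoint path is a placeholder.

```python
from micro_sam.util import get_sam_model

# Hypothetical invocation: 'model_type' and 'checkpoint_path' are assumed parameter names,
# and 'quantize' is popped from peft_kwargs inside get_sam_model (see the hunk above).
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="path/to/qlora_checkpoint.pt",
    peft_kwargs={"rank": 16, "quantize": True},
)
```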