Add QLoRA (#820)
Add QLoRA implementation
anwai98 authored Jan 6, 2025
1 parent 9073486 commit f0fdb0a
Showing 4 changed files with 62 additions and 23 deletions.
17 changes: 15 additions & 2 deletions finetuning/livecell_finetuning.py
@@ -56,6 +56,19 @@ def finetune_livecell(args):
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}

# NOTE: memory requirements for all finetuning settings (compared on A100 80GB)
# vit_b
# freeze_encoder: ~ 33.89 GB
# QLoRA: ~48.54 GB
# LoRA: ~48.62 GB
# FFT: ~49.56 GB

# vit_h
# freeze_encoder: ~36.05 GB
# QLoRA: ~ 65.68 GB
# LoRA: ~ 67.14 GB
# FFT: ~72.34 GB

# Run training.
sam_training.train_sam(
name=checkpoint_name,
@@ -72,7 +85,7 @@
save_root=args.save_root,
scheduler_kwargs=scheduler_kwargs,
save_every_kth_epoch=args.save_every_kth_epoch,
peft_kwargs={"rank": args.lora_rank} if args.lora_rank is not None else None,
peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
)

if args.export_path is not None:
@@ -87,7 +100,7 @@
def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
"--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/data/livecell/",
help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
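
For orientation, this script's change wires QLoRA through the existing peft_kwargs argument: LoRA is requested via a rank, and quantization is just one more key in the same dict (the commit hard-codes "quantize": True whenever a rank is given). Below is a minimal, self-contained sketch of that toggle logic; the --lora_rank and --quantize flags are illustrative and not part of the commit.

import argparse

def build_peft_kwargs(lora_rank, quantize):
    # Mirrors the pattern above: no rank means no PEFT wrapper at all.
    if lora_rank is None:
        return None
    peft_kwargs = {"rank": lora_rank}
    if quantize:
        # QLoRA = LoRA adapters on top of a 4-bit quantized image encoder.
        peft_kwargs["quantize"] = True
    return peft_kwargs

parser = argparse.ArgumentParser()
parser.add_argument("--lora_rank", type=int, default=None)
parser.add_argument("--quantize", action="store_true")
args = parser.parse_args()
print(build_peft_kwargs(args.lora_rank, args.quantize))  # e.g. {'rank': 4, 'quantize': True}
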
39 changes: 34 additions & 5 deletions micro_sam/models/peft_sam.py
@@ -6,6 +6,12 @@

from segment_anything.modeling import Sam

try:
import bitsandbytes as bnb
_have_bnb = True
except ImportError:
_have_bnb = False


class LoRASurgery(nn.Module):
"""Operates on the attention layers for performing low-rank adaptation.
@@ -22,7 +28,7 @@ class LoRASurgery(nn.Module):
Args:
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing lora.
block: The chosen attention blocks for implementing LoRA.
"""
def __init__(self, rank: int, block: nn.Module):
super().__init__()
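
As a quick sense check on what the rank buys: for a projection of width d, LoRA trains two matrices of shape (d, r) and (r, d) per adapted projection instead of a full d-by-d update. A tiny arithmetic sketch, assuming d = 768 (the ViT-B embedding dimension) and r = 4:

d, r = 768, 4
full_update = d * d        # parameters in a full weight update
lora_update = 2 * d * r    # W_a (d -> r) plus W_b (r -> d)
print(full_update, lora_update, full_update // lora_update)  # 589824 6144 96
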
@@ -50,8 +56,14 @@ def forward(self, x):
qkv = self.qkv_proj(x) # B, N, N, 3 * org_C
new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
qkv[:, :, :, :self.dim] += new_q
qkv[:, :, :, -self.dim:] += new_v
qkv = torch.cat(
[
qkv[:, :, :, :self.dim] + new_q, # replacing new q values
qkv[:, :, :, self.dim:-self.dim], # leaving the middle part as identical
qkv[:, :, :, -self.dim:] + new_v # replacing new v values
], dim=-1
)

return qkv
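
The rewritten forward replaces in-place slice updates on the projected qkv tensor with an out-of-place torch.cat. The commit does not spell out the motivation, but avoiding in-place writes into the base projection's output is the safe choice once that projection may come from a quantized layer or is otherwise still needed by autograd. A standalone sketch of the same composition rule with made-up shapes, just to show the q/k/v slice layout:

import torch

B, N, dim, rank, alpha = 2, 14, 32, 4, 1.0
x = torch.randn(B, N, N, dim)

qkv_proj = torch.nn.Linear(dim, 3 * dim, bias=False)  # stands in for the frozen base projection
w_a_q = torch.nn.Linear(dim, rank, bias=False)
w_b_q = torch.nn.Linear(rank, dim, bias=False)
w_a_v = torch.nn.Linear(dim, rank, bias=False)
w_b_v = torch.nn.Linear(rank, dim, bias=False)

qkv = qkv_proj(x)                      # (B, N, N, 3 * dim)
new_q = alpha * w_b_q(w_a_q(x))
new_v = alpha * w_b_v(w_a_v(x))

# Out-of-place recomposition: add the low-rank updates to the q and v slices,
# leave the k slice untouched, and concatenate along the channel dimension.
qkv = torch.cat(
    [qkv[..., :dim] + new_q, qkv[..., dim:-dim], qkv[..., -dim:] + new_v], dim=-1
)
print(qkv.shape)  # torch.Size([2, 14, 14, 96])
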


@@ -289,6 +301,7 @@ class PEFT_Sam(nn.Module):
rank: The rank for low-rank adaptation.
peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
attention_layers_to_update: Which specific layers we apply PEFT methods to.
quantize: Whether to quantize the model for lower precision training.
"""

def __init__(
@@ -297,12 +310,12 @@ def __init__(
rank: int,
peft_module: nn.Module = LoRASurgery,
attention_layers_to_update: Union[List[int]] = None,
quantize: bool = False,
**module_kwargs
):
super().__init__()

assert rank > 0

assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
"Invalid PEFT module"
)
@@ -315,7 +328,23 @@ def __init__(
self.peft_module = peft_module
self.peft_blocks = []

# let's freeze all the pretrained image encoder layers first
# Whether to quantize the linear layers to 4 bit precision.
# NOTE: This is currently supported for CUDA-supported devices only.
if quantize:
if not _have_bnb:
raise ModuleNotFoundError("Please install 'bitsandbytes'.")

for name, module in model.image_encoder.named_modules():
if isinstance(module, torch.nn.Linear):
*parent_path, layer_name = name.split(".")
parent_module = model.image_encoder

for sub_module in parent_path:
parent_module = getattr(parent_module, sub_module)

setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))

# Let's freeze all the pretrained image encoder layers first
for param in model.image_encoder.parameters():
param.requires_grad = False

8 changes: 4 additions & 4 deletions micro_sam/training/sam_trainer.py
@@ -218,10 +218,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
with torch.no_grad():
net_mean_model_iou = torch.mean(batched_iou_predictions)

loss += net_loss
mask_loss += net_mask_loss
iou_regression_loss += net_iou_regression_loss
mean_model_iou += net_mean_model_iou
loss = loss + net_loss
mask_loss = mask_loss + net_mask_loss
iou_regression_loss = iou_regression_loss + net_iou_regression_loss
mean_model_iou = mean_model_iou + net_mean_model_iou

if i < (num_subiter - 1): # We need not update the prompts for the last iteration.
# Determine the next prompts based on current predictions.
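
The commit does not state why the loss accumulation switches from += to out-of-place addition; the usual motivation is that in-place ops can clobber tensors autograd still needs for the backward pass. A small, generic PyTorch illustration of that failure mode and of the safe form (not code from the trainer):

import torch

w = torch.ones(3, requires_grad=True)

# In-place: exp() saves its output for backward, so mutating that output
# in place invalidates the graph and backward() raises a RuntimeError.
y = torch.exp(w)
y += 1
try:
    y.sum().backward()
except RuntimeError as err:
    print("in-place accumulation failed:", err)

# Out-of-place: `y = y + 1` creates a fresh tensor and leaves exp()'s output intact.
y = torch.exp(w)
y = y + 1
y.sum().backward()
print("out-of-place grad:", w.grad)  # tensor([2.7183, 2.7183, 2.7183])
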
21 changes: 9 additions & 12 deletions micro_sam/training/training.py
@@ -193,22 +193,18 @@ def train_sam(
"""Run training for a SAM model.
Args:
name: The name of the model to be trained.
The checkpoint and logs wil have this name.
name: The name of the model to be trained. The checkpoint and logs will have this name.
model_type: The type of the SAM model.
train_loader: The dataloader for training.
val_loader: The dataloader for validation.
n_epochs: The number of epochs to train for.
early_stopping: Enable early stopping after this number of epochs
without improvement.
early_stopping: Enable early stopping after this number of epochs without improvement.
n_objects_per_batch: The number of objects per batch used to compute
the loss for interactive segmentation. If None all objects will be used,
if given objects will be randomly sub-sampled.
checkpoint_path: Path to checkpoint for initializing the SAM model.
with_segmentation_decoder: Whether to train additional UNETR decoder
for automatic instance segmentation.
freeze: Specify parts of the model that should be frozen, namely:
image_encoder, prompt_encoder and mask_decoder
with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
device: The device to use for training.
lr: The learning rate.
@@ -226,10 +222,10 @@
optimizer_class: The optimizer class.
By default, torch.optim.AdamW is used.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
ignore_warnings: Whether to ignore raised warnings.
verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
By default, 50 batches of labels are verified from the dataloaders.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
ignore_warnings: Whether to ignore raised warnings.
"""
with _filter_warnings(ignore_warnings):

@@ -269,10 +265,11 @@ def train_sam(
if not param_name.startswith("encoder"):
joint_model_params.append(params)

optimizer = optimizer_class(joint_model_params, lr=lr)

model_params = joint_model_params
else:
optimizer = optimizer_class(model.parameters(), lr=lr)
model_params = model.parameters()

optimizer = optimizer_class(model_params, lr=lr)

if scheduler_kwargs is None:
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}
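
The last hunk deduplicates the optimizer construction: each branch now only decides which parameter iterable to use, and the optimizer is built once afterwards. The same shape reduced to a toy model; the "encoder" prefix filter mirrors the joint-training branch above.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
joint_training = False  # stand-in for the with_segmentation_decoder switch

if joint_training:
    # Keep everything except parameters whose name starts with "encoder".
    model_params = [p for name, p in model.named_parameters() if not name.startswith("encoder")]
else:
    model_params = model.parameters()

optimizer = torch.optim.AdamW(model_params, lr=1e-5)  # single construction path
print(optimizer)
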
