Fix QLoRA weights and bias initialisation (#833)
Fixes the following:
- Loading pretrained weights into quantized layers.
- Converting the QLoRA-style finetuned model to a LoRA-style model to run inference.

---------

Co-authored-by: Anwai Archit <[email protected]>
3 people authored Jan 21, 2025
1 parent a47f2f2 commit 552dc55
Showing 2 changed files with 70 additions and 5 deletions.
18 changes: 17 additions & 1 deletion micro_sam/models/peft_sam.py
@@ -342,7 +342,23 @@ def __init__(
for sub_module in parent_path:
    parent_module = getattr(parent_module, sub_module)

setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
# Create the new Linear4bit layer
linear_q = bnb.nn.Linear4bit(
    module.in_features,
    module.out_features,
    bias=False if module.bias is None else True,
)
# Assign weights and bias to the new layer
new_weight = bnb.nn.Params4bit(
    data=module.weight,
    requires_grad=False,
)
linear_q.weight = new_weight
if module.bias is not None:
    linear_q.bias = torch.nn.Parameter(module.bias)

# Replace the original linear layer with the quantized one
setattr(parent_module, layer_name, linear_q)

# Let's freeze all the pretrained image encoder layers first
for param in model.image_encoder.parameters():
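For intuition, here is a minimal standalone sketch of the same replacement pattern that the diff above introduces. It is not part of the commit; it assumes `bitsandbytes` is installed and a CUDA device is available, since `Params4bit` quantizes its data when the layer is moved to the GPU.

```python
import torch
import bitsandbytes as bnb

# A pretrained full-precision layer standing in for one of SAM's linear projections.
linear = torch.nn.Linear(64, 32)

# Re-create it as a 4-bit layer and carry over the pretrained weights and bias.
linear_q = bnb.nn.Linear4bit(linear.in_features, linear.out_features, bias=linear.bias is not None)
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
if linear.bias is not None:
    linear_q.bias = torch.nn.Parameter(linear.bias)

# The 4-bit quantization of the stored weights happens when the layer is moved to the GPU.
linear_q = linear_q.to("cuda")
x = torch.randn(1, 64, device="cuda")
print(linear_q(x).shape)  # torch.Size([1, 32])
```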
57 changes: 53 additions & 4 deletions micro_sam/util.py
@@ -389,17 +389,18 @@ def get_sam_model(
# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
# Overwrites the SAM model by freezing the backbone and allowing PEFT.
if peft_kwargs and isinstance(peft_kwargs, dict):
    # NOTE: We remove the 'quantize' parameter, if found, as we do not quantize at inference time.
    peft_kwargs.pop("quantize", None)

    if abbreviated_model_type == "vit_t":
        raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

    sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

# In case the model checkpoint has issues when it is initialized with parameters different from the defaults.
if flexible_load_checkpoint:
    sam = _handle_checkpoint_loading(sam, model_state)
else:
    sam.load_state_dict(model_state)

sam.to(device=device)

predictor = SamPredictor(sam)
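For context, a minimal usage sketch of this inference path (not part of the commit; the paths are hypothetical, and it assumes the model was finetuned with rank-4 LoRA layers and that `PEFT_Sam` accepts a `rank` keyword):

```python
from micro_sam.util import get_sam_model

predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="finetuned_qlora_model.pt",  # hypothetical checkpoint path
    # 'quantize' is popped internally, so QLoRA- and LoRA-style kwargs resolve
    # to the same unquantized LoRA backbone at inference time.
    peft_kwargs={"rank": 4, "quantize": True},
)
```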
@@ -456,13 +457,13 @@ def _handle_checkpoint_loading(sam, model_state):
def export_custom_sam_model(
    checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike],
) -> None:
    """Export a finetuned segment anything model to the standard model format.
    """Export a finetuned Segment Anything Model to the standard model format.

    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.

    Args:
        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
        model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
        save_path: Where to save the exported model.
    """
    _, state = get_sam_model(
@@ -476,6 +477,54 @@ def export_custom_sam_model(
    torch.save(model_state, save_path)
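A short usage sketch for the export (the call matches the signature shown above; the paths are hypothetical):

```python
from micro_sam.util import export_custom_sam_model

export_custom_sam_model(
    checkpoint_path="checkpoints/vit_b/lora_sam/best.pt",  # hypothetical finetuned checkpoint
    model_type="vit_b",
    save_path="finetuned_vit_b_lora.pth",  # hypothetical output path
)
```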


def export_custom_qlora_model(
    checkpoint_path: Union[str, os.PathLike],
    finetuned_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
) -> None:
    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.

    Args:
        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
        finetuned_path: The path to the new finetuned model, using QLoRA.
        model_type: The Segment Anything Model type corresponding to the checkpoint.
        save_path: Where to save the exported model.
    """
    # Step 1: Get the base SAM model, which finetuning was started from.
    _, sam = get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
    )

    # Step 2: Load the QLoRA-style finetuned model.
    ft_state, ft_model_state = _load_checkpoint(finetuned_path)

    # Step 3: Get the LoRA weights from QLoRA and retain all original parameters from the base SAM model.
    updated_model_state = {}

    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
    for k, v in ft_model_state.items():
        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
            updated_model_state[k] = v

    # - Next, we get all the remaining parameters from the base SAM model.
    for k, v in sam.state_dict().items():
        if k.find("attn.qkv.") != -1:
            k = k.replace("qkv", "qkv.qkv_proj")
            updated_model_state[k] = v
        else:
            updated_model_state[k] = v

    # - Finally, we replace the old model state with the new one (to retain the other relevant metadata).
    ft_state['model_state'] = updated_model_state

    # Step 4: Store the new "state" to "save_path".
    torch.save(ft_state, save_path)
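A sketch of the intended round-trip for the new function (not part of the commit; the paths and the `rank` value are hypothetical, and the `peft_kwargs` usage follows the docstring above):

```python
from micro_sam.util import export_custom_qlora_model, get_sam_model

# Convert the QLoRA-style finetuned checkpoint into a LoRA-style one.
export_custom_qlora_model(
    checkpoint_path="checkpoints/sam_vit_b_01ec64.pth",  # base foundation model (hypothetical path)
    finetuned_path="checkpoints/vit_b/qlora_sam/best.pt",  # QLoRA-style finetuned model (hypothetical path)
    model_type="vit_b",
    save_path="exported_vit_b_lora.pt",
)

# The exported checkpoint can then be run with the plain LoRA backbone.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="exported_vit_b_lora.pt",
    peft_kwargs={"rank": 4},  # assumes rank-4 LoRA was used for finetuning
)
```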


def get_model_names() -> Iterable:
    model_registry = models()
    model_names = model_registry.registry.keys()
