handle quantization parameters in state dict
caroteu committed Jan 15, 2025
1 parent 8757482 commit f423834
Showing 2 changed files with 31 additions and 7 deletions.
micro_sam/instance_segmentation.py (13 changes: 11 additions & 2 deletions)
@@ -770,11 +770,20 @@ def get_unetr(
         resize_input=True,
     )
     if decoder_state is not None:
+        strict = True
         unetr_state_dict = unetr.state_dict()
-        for k, v in unetr_state_dict.items():
+        qlora_keys = []
+        for k, _ in unetr_state_dict.items():
             if not k.startswith("encoder"):
                 unetr_state_dict[k] = decoder_state[k]
-        unetr.load_state_dict(unetr_state_dict, strict=False)
+            # handle problematic QLoRA parameters
+            if k.find("weight") != -1 and not k.endswith("weight"):
+                qlora_keys.append(k)
+        if qlora_keys:
+            # make sure the keys of both dictionaries match before setting strict to False
+            assert unetr.state_dict().keys() == unetr_state_dict.keys(), "Unexpected or missing keys in state dict."
+            strict = False
+        unetr.load_state_dict(unetr_state_dict, strict=strict)
 
     unetr.to(device)
     return unetr
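The key filter added here relies on quantized layers storing auxiliary tensors under names that are derived from the weight name but do not end in "weight". A minimal, self-contained sketch of the same check (the state-dict keys below are illustrative examples of the pattern, not keys taken from micro_sam):

# Sketch of the qlora_keys filter above; the key names are made-up examples of the
# kind of entries quantized (e.g. bitsandbytes-style) layers can add to a state dict.
example_keys = [
    "encoder.blocks.0.attn.qkv.weight",            # regular weight, loaded strictly
    "decoder.base.block.conv1.weight",             # regular weight, loaded strictly
    "encoder.blocks.0.attn.qkv.weight.absmax",     # auxiliary quantization entry
    "encoder.blocks.0.attn.qkv.weight.quant_map",  # auxiliary quantization entry
]

qlora_keys = [k for k in example_keys if k.find("weight") != -1 and not k.endswith("weight")]
print(qlora_keys)  # only the auxiliary quantization entries are flagged
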
micro_sam/util.py (25 changes: 20 additions & 5 deletions)
@@ -376,8 +376,12 @@ def get_sam_model(
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
-        if "quantize" in peft_kwargs:  # TODO: add doc
+        if "quantize" in peft_kwargs and peft_kwargs["quantize"]:
+            # Quantization happens when loading the model to device. This needs to be done before loading the state to
+            # avoid a mismatch between the shape of the quantized and unquantized weights.
             sam.to(device=device)
+            # Handle QLoRA weight parameters, that get recognized as unexpected keys in load_state_dict
+            flexible_load_checkpoint = True
 
     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
     if flexible_load_checkpoint:
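As a usage sketch of the code path above: quantization is requested through peft_kwargs, which are forwarded to PEFT_Sam. Apart from the "quantize" flag handled in this commit, the parameter values below (the LoRA rank and the checkpoint path) are assumptions for illustration, not something this diff specifies:

from micro_sam.util import get_sam_model

# "quantize" triggers the quantize-before-load handling shown above; "rank" and the
# checkpoint path are illustrative assumptions about the surrounding API.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="./finetuned_qlora_sam.pt",
    peft_kwargs={"rank": 4, "quantize": True},
)
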
@@ -410,12 +414,23 @@ def _handle_checkpoint_loading(sam, model_state, peft_kwargs):
     # eg. while training for multi-class semantic segmentation in the mask encoder,
     # parameters are updated - leading to "size mismatch" errors
     # TODO: add other docs
+    strict = True
+
     if peft_kwargs and "quantize" in peft_kwargs:
-        # TODO: add docs
-        # TODO: add checks and raise warning for users to know about what's happening
-        # and brief details about strict=False behaviour
-        sam.load_state_dict(reference_state, strict=False)
+        # Handle QLoRA weight parameters, that get recognized as unexpected keys in
+        # load_state_dict although there is no mismatch in parameters, by setting strict to False.
+        problematic_keys = []
+        for k, v in model_state.items():
+            if k.find("weight") != -1 and not k.endswith("weight"):
+                problematic_keys.append(k)
+        if problematic_keys:
+            warnings.warn(f"Parameters with problematic behaviour: {problematic_keys}")
+            strict = False
+            # make sure there is no real mismatch
+            assert model_state.keys() == sam.state_dict().keys(), "Unexpected or missing keys in the model state."
+
+        sam.load_state_dict(model_state, strict=strict)
         return sam
 
     new_state_dict = {}  # for loading matching parameters
     mismatched_layers = []  # for tracking mismatching parameters
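The strict flag matters because torch.nn.Module.load_state_dict reports these extra quantization entries as unexpected keys and, with strict=True, raises an error even though every actual parameter matches. A minimal demonstration with plain PyTorch (not micro_sam code; the extra key name is illustrative):

import torch

model = torch.nn.Linear(4, 2)
state = model.state_dict()
state["weight.absmax"] = torch.zeros(1)  # extra quantization-style entry (illustrative name)

try:
    model.load_state_dict(state)  # strict=True is the default
except RuntimeError as err:
    print("strict loading fails:", err)

# With strict=False the matching parameters are loaded and the extra key is only reported.
missing, unexpected = model.load_state_dict(state, strict=False)
print("unexpected keys:", unexpected)  # ['weight.absmax']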
