Skip to content

Commit

Permalink
Minor comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored Jan 15, 2025
1 parent 9c5a8c6 commit 8757482
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,14 @@ def get_sam_model(
raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
sam.to(device=device)
if "quantize" in peft_kwargs: # TODO: add doc
sam.to(device=device)

# In case the model checkpoints have some issues when it is initialized with different parameters than default.
if flexible_load_checkpoint:
sam = _handle_checkpoint_loading(sam, model_state)
sam = _handle_checkpoint_loading(sam, model_state, peft_kwargs)
else:
sam.load_state_dict(model_state, strict=False)
sam.load_state_dict(model_state)

sam.to(device=device)

Expand All @@ -404,10 +405,17 @@ def get_sam_model(
return predictor


def _handle_checkpoint_loading(sam, model_state):
def _handle_checkpoint_loading(sam, model_state, peft_kwargs):
# Whether to handle the mismatch issues in a bit more elegant way.
# eg. while training for multi-class semantic segmentation in the mask encoder,
# parameters are updated - leading to "size mismatch" errors
# TODO: add other docs

if peft_kwargs and "quantize" in peft_kwargs:
# TODO: add docs
# TODO: add checks and raise warning for users to know about what's happening
# and brief details about strict=False behaviour
sam.load_state_dict(reference_state, strict=False)

new_state_dict = {} # for loading matching parameters
mismatched_layers = [] # for tracking mismatching parameters
Expand Down

0 comments on commit 8757482

Please sign in to comment.