Skip to content

Commit

Permalink
Expose out_channels arguments for get_unetr function (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored Jan 19, 2025
1 parent 87e3d85 commit 84dc87d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def get_unetr(
image_encoder: torch.nn.Module,
decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
device: Optional[Union[str, torch.device]] = None,
out_channels: int = 3,
) -> torch.nn.Module:
"""Get UNETR model for automatic instance segmentation.
Expand All @@ -754,6 +755,7 @@ def get_unetr(
This is used as encoder by the UNETR too.
decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
device: The device.
out_channels: The number of output channels.
Returns:
The UNETR model.
Expand All @@ -763,7 +765,7 @@ def get_unetr(
unetr = UNETR(
backbone="sam",
encoder=image_encoder,
out_channels=3,
out_channels=out_channels,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
Expand Down

0 comments on commit 84dc87d

Please sign in to comment.