extends smp_model_factory class #56
Conversation
Signed-off-by: Pedro Henrique Conrado <[email protected]>
Force-pushed de97585 to fb68d27 (Compare)
Thanks for the contribution @PedroConrado. Could you provide some basic unit tests, just to check that your modules don't crash with the current terratorch requirements? Could you also provide the versions you need for the dependencies you have used (such as segmentation-models-pytorch)?
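For illustration, a minimal smoke test could look like the following (hedged: the build_model signature, the backbone name, and the ModelOutput.output attribute are assumptions based on this PR, not a confirmed API):

import torch

from terratorch.models import SMPModelFactory


def test_smp_model_factory_does_not_crash():
    factory = SMPModelFactory()
    # Hypothetical arguments; adjust to the factory's actual signature.
    model = factory.build_model(
        task="segmentation",
        backbone="prithvi_vit_100",  # assumed backbone name
        decoder="Unet",
        in_channels=6,
        bands=[0, 1, 2, 3, 4, 5],
        num_classes=2,
    )
    out = model(torch.randn(1, 6, 224, 224))
    assert out.output.shape == (1, 2, 224, 224)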
import torch.nn.functional as F  # noqa: N812

class PrithviModelWrapper(nn.Module):
This SMP factory makes sense for instantiating models that are fully from SMP. For Prithvi models we already have the PrithviModelFactory, so I think it could be confusing, from a user perspective, to have another factory that creates Prithvi models.
Do you know if it is possible to isolate only the instantiation of decoders from SMP, so that the PrithviModelFactory can also use them?
If that is not possible, or would be too hacky an implementation, and the whole model actually needs to be an smp.SegmentationModel, then I think it could be acceptable to have the SMP factory also instantiate Prithvi encoders. However, I would try to use the existing mechanism of instantiating through timm.create_model rather than how it is done here.
What do you think? Especially regarding the isolation of SMP decoders only.
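A minimal sketch of what isolating an SMP decoder could look like, assuming SMP 0.3.x, where UnetDecoder lives under segmentation_models_pytorch.decoders.unet.decoder and its forward accepts unpacked feature maps; the channel counts here are made up:

import torch
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

# Instantiate only the decoder; the encoder could come from PrithviModelFactory.
decoder = UnetDecoder(
    encoder_channels=(3, 64, 128, 256, 512),  # made-up per-stage channels
    decoder_channels=(256, 128, 64, 32),
    n_blocks=4,
)

# Dummy multi-scale features standing in for encoder outputs.
feats = [torch.randn(1, c, 224 // 2**i, 224 // 2**i)
         for i, c in enumerate((3, 64, 128, 256, 512))]
out = decoder(*feats)  # shape (1, 32, 224, 224)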
    def __init__(self, prithvi_class, **kwargs) -> None:
        super().__init__()

        self.config = kwargs
        self.prithvi_class = str(prithvi_class)
        if "MMSegSwinTransformer" in self.prithvi_class:
            self.model = prithvi_class(**kwargs)

            # Default swin prepare_features_for_image_model, can be changed later.
            def prepare_features_for_image_model(x):
                x = list(x)
                outs = [i for i in x if not isinstance(i, tuple)]
                return [
                    layer_output.reshape(
                        -1,
                        int(math.sqrt(layer_output.shape[1])),
                        int(math.sqrt(layer_output.shape[1])),
                        layer_output.shape[2],
                    ).permute(0, 3, 1, 2).contiguous()
                    for layer_output in outs
                ]

            self.model.prepare_features_for_image_model = prepare_features_for_image_model
        elif "TemporalViTEncoder" in self.prithvi_class:
            self.model = prithvi_class(**kwargs, encoder_only=True)
        else:
            self.model = prithvi_class(**kwargs)

    def forward(self, x):
        return self.model.forward_features(x)

    def prepare_features_for_image_model(self, x):
        return self.model.prepare_features_for_image_model(x)

    def channels(self):
        return self.config["num_heads"] * [self.config["embed_dim"]]
As before, I would try to rely on timm.create_model for this, so that in the future we can support any other model we add.
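A hedged sketch of the timm route, assuming the Prithvi backbones are registered in the timm registry under names like "prithvi_vit_100" and that features_only is supported for them:

import timm

backbone = timm.create_model(
    "prithvi_vit_100",   # assumed registry name
    pretrained=False,
    features_only=True,  # expose intermediate feature maps
)
# feature_info would replace the hand-written channels() method above.
channels = backbone.feature_info.channels()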
smp_output = self.smp_model(x)
smp_output = self.final_act(smp_output)

#TODO: support auxiliary head labels
missing TODO to complete?
# Can load model from local checkpoint.
if "checkpoint_path" in backbone_kwargs:

    checkpoint_path = backbone_kwargs.pop("checkpoint_path")
    print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")

    if self.CPU_ONLY:
        model_dict = torch.load(checkpoint_path, map_location="cpu")
    else:
        model_dict = torch.load(checkpoint_path)

if backbone.startswith("prithvi"):  # Using Prithvi encoder (ViT or Swin).
    backbone_class = self._make_smp_encoder(PrithviModelWrapper)
    if backbone.startswith("prithvi_swin"):
        backbone_kwargs["prithvi_class"] = MMSegSwinTransformer
    elif backbone.startswith("prithvi_vit"):
        backbone_kwargs["prithvi_class"] = TemporalViTEncoder
    else:
        msg = f"Prithvi backbone {backbone} not found."
        raise NotImplementedError(msg)
# Using new encoder (not Prithvi or SMP).
Using timm.create_model should handle all these complexities. checkpoint_path should also be something that we don't take ourselves, but that is passed through to this constructor (see the info box here).
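For reference, timm.create_model accepts a checkpoint_path argument and handles loading itself, so the manual torch.load branching above could disappear (backbone name and path are placeholders):

import timm

model = timm.create_model(
    "prithvi_vit_100",                         # assumed registry name
    pretrained=False,
    checkpoint_path="/path/to/checkpoint.pt",  # placeholder local checkpoint
)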
if aux_params:
    model = decoder_module(**model_args, aux_params=aux_params)
else:
    model = decoder_module(**model_args)
Is this if necessary? Isn't the default value for aux_params None anyway, so we can just pass it either way?
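In other words, assuming the SMP constructors keep aux_params=None as their default, the branch could collapse to a single call:

model = decoder_module(**model_args, aux_params=aux_params)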
# Loads state dict from checkpoint.
if model_dict:
    if hasattr(model, "prithvi_class") and "TemporalViTEncoder" in model.prithvi_class:
        model_dict = checkpoint_filter_fn(model_dict, model=model.encoder, pretrained_bands=bands, model_bands=bands)
        model.encoder.load_state_dict(model_dict, strict=False)
This should be handled by timm.create_model and not called explicitly.
if "prepare_features_for_image_model" in kwargs: | ||
path = kwargs['prepare_features_for_image_model'] | ||
self.prepare_features_for_image_model = _get_callable_from_path(path) |
Does this ever happen? I think at this stage you should expect prepare_features_for_image_model to already be a callable; LightningCLI should allow for that.
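A hedged sketch of the simplified handling if a callable is expected (names follow this PR's kwargs):

pffim = kwargs.pop("prepare_features_for_image_model", None)
if pffim is not None:
    if not callable(pffim):
        msg = "prepare_features_for_image_model must be a callable"
        raise TypeError(msg)
    self.prepare_features_for_image_model = pffim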
elif hasattr(super(), "prepare_features_for_image_model"):
    self.prepare_features_for_image_model = super().prepare_features_for_image_model
# No PFFIM
Shouldn't inheritance take care of this? Why is it necessary?
# Creates wrapper for SMP encoder with PFFIM.
# Wrapper needed to include SMP params and PFFIM.
class SMPEncoderWrapperWithPFFIM(BaseClass, nn.Module):
From what I see, this class adapts Prithvi-like encoders to play nicely with SMP. The idea in my other comment would basically be to instead make SMP-like decoders play nicely with PrithviModelFactory, if possible.
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import encoders as ENCODERS
import importlib
import math

import torch
from torch import nn

from terratorch.datasets import HLSBands
from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn
from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer

import torch.nn.functional as F  # noqa: N812
Make sure imports are sorted; ruff should do it.
    self.model = prithvi_class(**kwargs, encoder_only=True)
else:
    self.model = prithvi_class(**kwargs)

empty space
def prepare_features_for_image_model(self, x):
    return self.model.prepare_features_for_image_model(x)

empty space
def channels(self):
    return self.config["num_heads"] * [self.config["embed_dim"]]

empty space
    relu=False,
    squeeze_single_class=False
) -> None:

empty space
#TODO: support auxiliary head labels
if isinstance(smp_output, tuple):
    smp_output, labels = smp_output

empty space
msg = f"SMP models can only perform pixel wise tasks, but got task {task}" | ||
raise Exception(msg) | ||
# backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_") | ||
|
empty space
# Gets decoder module.
decoder_module = getattr(smp, decoder, None)
if decoder_module is None:
    raise ValueError(f"Decoder {decoder} is not supported in SMP.")
Expression in exception shouldn't use an f-string (ruff EM102).
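i.e. bind the message to a variable first, as the rest of the file already does:

msg = f"Decoder {decoder} is not supported in SMP."
raise ValueError(msg)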
# Gets class either from string or from Module reference.
def _make_smp_encoder(self, Encoder):
    if isinstance(Encoder, str):
        BaseClass = _get_class_from_string(Encoder)
variable should be lower case
    pass

def make_dilated(self, output_stride):
    if hasattr(super(), 'make_dilated'):
double quotes preferred
""" | ||
if task not in ["segmentation", "regression"]: | ||
self.CPU_ONLY = not torch.cuda.is_available() |
This class shouldn't have the responsibility for doing this; Lightning does it.
if self.CPU_ONLY:
    model_dict = torch.load(checkpoint_path, map_location="cpu")
else:
    model_dict = torch.load(checkpoint_path)
We shouldn't need this because timm handles model instantiation, but for future reference, you can just always use torch.load(checkpoint_path, map_location="cpu"), even if the model will be loaded onto GPU :)
…actory
Signed-off-by: Pedro Henrique Conrado <[email protected]>
Force-pushed 0da4809 to 93c523e (Compare)
terratorch/models/__init__.py
Outdated
@@ -6,6 +6,8 @@
from terratorch.models.smp_model_factory import SMPModelFactory
from terratorch.models.timm_model_factory import TimmModelFactory

from terratorch.models.smp_model_factory import get_smp_decoder
I don't think we should expose this directly. Is there a good reason for importing this here?
terratorch/models/__init__.py
Outdated
@@ -15,4 +17,5 @@
    "TimmModelFactory",
    "AuxiliaryHead",
    "AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
    "get_smp_decoder",
same as above
terratorch/models/__init__.py
Outdated
from terratorch.models.smp_model_factory import SMPModelFactory
from terratorch.models.timm_model_factory import TimmModelFactory

unsorted imports (ruff can sort imports)
smp_kwargs, kwargs = _extract_prefix_keys(kwargs, "smp_")
aux_kwargs, kwargs = _extract_prefix_keys(kwargs, "aux_")
Can we make use of the args passed through decoder_ instead of needing a new prefix for smp and aux?
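A hedged sketch of that consolidation, assuming _extract_prefix_keys keeps its (matched, remaining) return shape; aux_params would then just be one more decoder_ key:

decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_")
aux_params = decoder_kwargs.pop("aux_params", None)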
def forward_multiple_embeds(self, *x):
    return self.decoder(*x)

def forward_single_embed(self, x):
    return self.decoder(x[-1])
Instead of this, can we remain consistent with the current Prithvi decoders and accept an in_index argument? This should be a single int or a list of ints: the indices of x which we select and pass forward to the decoder.
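A hedged sketch of that in_index handling (attribute name assumed), which would replace both forward variants:

def forward(self, x):
    # in_index: an int or list of ints selecting which embeddings to pass on.
    if isinstance(self.in_index, int):
        return self.decoder(x[self.in_index])
    return self.decoder(*[x[i] for i in self.in_index])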
raise ValueError(msg)

# Using new encoder.
backbone_class = self._make_smp_encoder(backbone)
Is this working? _make_smp_encoder seems to be a module-level function rather than a method of this class.
smp_decoder = SMPDecoderForPrithviWrapper(decoder=model.decoder, num_channels=out_channels[-1])
if "multiple_embed" in head_kwargs:
    smp_decoder.forward = smp_decoder.forward_multiple_embeds
else:
    smp_decoder.forward = smp_decoder.forward_single_embed
See the suggestion above to allow an in_index parameter, which would make this monkey patching unnecessary.
# Gets class either from string or from Module reference.
def _make_smp_encoder(encoder=None):
should this be a method of the factory above?
self.smp_model = smp_model
self.final_act = nn.ReLU() if relu else nn.Identity()
self.squeeze_single_class = squeeze_single_class

def get_smp_decoder(
This function and the class SMPDecoderForPrithviWrapper seem to only be used for creating decoders that match the format required by PrithviModelFactory, right? Then, would it make sense to move them into the PrithviModelFactory class instead?
This function is also used by _get_smp_encoder to create the dummy smp encoder.
Signed-off-by: Pedro Henrique Conrado <[email protected]>
Force-pushed b4e1a2a to 97b4688 (Compare)
Looks good!
* extends smp_model_factory class
* extends smp model factory and adds functionalities in prithvi model factory
* extends SMPModelFactory
* Extends SMPModelFactory and smp decoder in PrithviModelFactory
* adds SMPModelFactory tests and SMPModelFactory to model.md
* adds SMPModelFactory tests and SMPModelFactory to docs/model.md

Signed-off-by: Pedro Henrique Conrado <[email protected]>
No description provided.