
extends smp_model_factory class #56

Merged
merged 6 commits into IBM:main from extend/smp_model_factory
Aug 2, 2024

Conversation

PedroConrado
Collaborator

No description provided.

Signed-off-by: Pedro Henrique Conrado <[email protected]>
@PedroConrado PedroConrado force-pushed the extend/smp_model_factory branch from de97585 to fb68d27 Compare July 19, 2024 18:11
@Joao-L-S-Almeida Joao-L-S-Almeida self-requested a review July 19, 2024 18:18
@Joao-L-S-Almeida Joao-L-S-Almeida self-assigned this Jul 19, 2024
@Joao-L-S-Almeida
Member

Thanks for the contribution @PedroConrado. Could you provide some basic unit tests, just to check that your modules do not crash under the current terratorch requirements? Could you also provide the versions you need for the dependencies you have used (such as segmentation-models-pytorch)?
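A minimal sketch of the kind of smoke test being requested; the build_model entry point and its arguments are assumptions mirroring the other terratorch factories, not necessarily this PR's exact API:

import torch

def test_smp_model_factory_smoke():
    # SMPModelFactory is the class added in this PR; the arguments below
    # are assumed, following the convention of the other factories.
    from terratorch.models.smp_model_factory import SMPModelFactory

    factory = SMPModelFactory()
    model = factory.build_model(
        task="segmentation",
        backbone="resnet18",  # plain SMP encoder, no Prithvi weights needed
        decoder="Unet",
        bands=[0, 1, 2],
        num_classes=2,
    )
    out = model(torch.randn(1, 3, 224, 224))
    assert out is not None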


import torch.nn.functional as F # noqa: N812

class PrithviModelWrapper(nn.Module):
Contributor

This SMP factory makes sense for instantiating models that are fully from SMP.

For Prithvi models, we already have the PrithviModelFactory, so I think it could be confusing, from a user perspective, to have another factory that creates Prithvi models.

Do you know if it is possible to isolate only the instantiation of decoders from SMP, so that the PrithviModelFactory can also use them?

If that is not possible, or would be too hacky an implementation, and the whole model actually needs to live within an smp.SegmentationModel, then I think it could be acceptable to have the SMP factory also instantiate Prithvi encoders. However, I would try to use the existing mechanism of instantiating through timm.create_model rather than how it is done here.

What do you think, especially regarding isolating only the SMP decoders?
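For reference, a hedged sketch of the timm route being suggested; the registry name and input shape below are assumptions, not something this PR defines:

import timm
import torch

# terratorch registers Prithvi backbones in timm's model registry, so the
# factory could reuse timm.create_model instead of constructing encoder
# classes directly.
backbone = timm.create_model("prithvi_vit_100", pretrained=False)  # assumed registry name
feats = backbone.forward_features(torch.randn(1, 6, 224, 224))  # 6 HLS bands assumed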

Comment on lines 20 to 54
def __init__(self, prithvi_class, **kwargs) -> None:
super().__init__()

self.config = kwargs
self.prithvi_class = str(prithvi_class)
if "MMSegSwinTransformer" in self.prithvi_class:
self.model = prithvi_class(**kwargs)
        # Default Swin prepare_features_for_image_model; can be changed later.
def prepare_features_for_image_model(x):
x = list(x)
outs = [i for i in x if not isinstance(i, tuple)]
return [
layer_output.reshape(
-1,
int(math.sqrt(layer_output.shape[1])),
int(math.sqrt(layer_output.shape[1])),
layer_output.shape[2],
).permute(0,3,1,2).contiguous()
for layer_output in outs
]

self.model.prepare_features_for_image_model = prepare_features_for_image_model
elif "TemporalViTEncoder" in self.prithvi_class:
self.model = prithvi_class(**kwargs, encoder_only=True)
else:
self.model = prithvi_class(**kwargs)

def forward(self, x):
return self.model.forward_features(x)

def prepare_features_for_image_model(self, x):
return self.model.prepare_features_for_image_model(x)

def channels(self):
return self.config["num_heads"]*[self.config["embed_dim"]]
Contributor

As before, I would try to rely on timm.create_model for this, so that in the future we can support any other model we add.

smp_output = self.smp_model(x)
smp_output = self.final_act(smp_output)

#TODO: support auxiliary head labels
Contributor

missing TODO to complete?

Comment on lines 180 to 200
# Can load model from local checkpoint.
if "checkpoint_path" in backbone_kwargs:

checkpoint_path = backbone_kwargs.pop("checkpoint_path")
print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")

if self.CPU_ONLY:
model_dict = torch.load(checkpoint_path, map_location="cpu")
else:
model_dict = torch.load(checkpoint_path)

if backbone.startswith("prithvi"): # Using Prithvi encoder (ViT or Swin).
backbone_class = self._make_smp_encoder(PrithviModelWrapper)
if backbone.startswith("prithvi_swin"):
backbone_kwargs['prithvi_class'] = MMSegSwinTransformer
elif backbone.startswith("prithvi_vit"):
backbone_kwargs['prithvi_class'] = TemporalViTEncoder
else:
msg = f"Prithvi Backbone not found."
raise NotImplementedError(msg)
# Using new encoder (not Prithvi or SMP).
Contributor

Using timm.create_model should handle all these complexities. checkpoint_path should also be something that we don't take ourselves, but that is passed on to this constructor (see the info box here).

Comment on lines 224 to 227
if aux_params:
    model = decoder_module(**model_args, aux_params=aux_params)
else:
    model = decoder_module(**model_args)
Contributor

Is this if necessary? Isn't the default value for aux_params None anyway, so we can just pass it either way?
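A sketch of the simplification being suggested; segmentation-models-pytorch model constructors default aux_params to None, so passing it through unconditionally should be equivalent:

# aux_params=None already means "no auxiliary head" in smp, so the
# if/else collapses to a single call:
model = decoder_module(**model_args, aux_params=aux_params)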

Comment on lines 229 to 235
# Loads state dict from checkpoint.
if model_dict:
if hasattr(model, "prithvi_class") and "TemporalViTEncoder" in model.prithvi_class:
model_dict = checkpoint_filter_fn(model_dict, model=model.encoder, pretrained_bands=bands, model_bands=bands)
model.encoder.load_state_dict(model_dict, strict=False)


Contributor

This should be handled by timm.create_model and not called explicitly

Comment on lines 275 to 277
if "prepare_features_for_image_model" in kwargs:
path = kwargs['prepare_features_for_image_model']
self.prepare_features_for_image_model = _get_callable_from_path(path)
Contributor

Does this ever happen? I think at this stage you should expect prepare_features_for_image_model to already be a callable; LightningCLI should allow for that.

Comment on lines 279 to 281
elif hasattr(super(), 'prepare_features_for_image_model'):
self.prepare_features_for_image_model = super().prepare_features_for_image_model
# No PFFIM
Contributor

Shouldn't inheritance take care of this? Why is it necessary?


# Creates Wrapper for SMP Encoder with PFFIM.
# Wrapper needed to include SMP params and PFFIM
class SMPEncoderWrapperWithPFFIM(BaseClass, nn.Module):
Contributor

From what I see, this class adapts Prithvi-like encoders to play nicely with SMP. The idea in my other comment is basically the reverse: whether it would be possible to instead make SMP-like decoders play nicely with PrithviModelFactory.

Comment on lines 3 to 17
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import encoders as ENCODERS
import importlib
import math

import torch
from torch import nn

from terratorch.datasets import HLSBands
from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn
from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer

import torch.nn.functional as F # noqa: N812
Contributor

Make sure imports are sorted; ruff should do it.

self.model = prithvi_class(**kwargs, encoder_only=True)
else:
self.model = prithvi_class(**kwargs)

Contributor

empty space


def prepare_features_for_image_model(self, x):
return self.model.prepare_features_for_image_model(x)

Contributor

empty space


def channels(self):
return self.config["num_heads"]*[self.config["embed_dim"]]

Contributor

empty space

relu=False,
squeeze_single_class=False
) -> None:

Contributor

empty space

#TODO: support auxiliary head labels
if isinstance(smp_output, tuple):
smp_output, labels = smp_output

Contributor

empty space

msg = f"SMP models can only perform pixel wise tasks, but got task {task}"
raise Exception(msg)
# backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")

Contributor

empty space

# Gets decoder module.
decoder_module = getattr(smp, decoder, None)
if decoder_module is None:
raise ValueError(f"Decoder {decoder} is not supported in SMP.")
Contributor

Expression in exception shouldn't use an f-string (ruff).
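The ruff-preferred form assigns the message to a variable first, the pattern the rest of the file already uses:

# instead of: raise ValueError(f"Decoder {decoder} is not supported in SMP.")
msg = f"Decoder {decoder} is not supported in SMP."
raise ValueError(msg)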

# Gets class either from string or from Module reference.
def _make_smp_encoder(self, Encoder):
if isinstance(Encoder, str):
BaseClass = _get_class_from_string(Encoder)
Contributor

Variable should be lowercase.

pass

def make_dilated(self, output_stride):
if hasattr(super(), 'make_dilated'):
Contributor

double quotes preferred

"""
if task not in ["segmentation", "regression"]:
self.CPU_ONLY = not torch.cuda.is_available()
Contributor

This class shouldn't have the responsibility for doing this; Lightning does it.

Comment on lines 186 to 189
if self.CPU_ONLY:
model_dict = torch.load(checkpoint_path, map_location="cpu")
else:
model_dict = torch.load(checkpoint_path)
Contributor

We shouldn't need this because timm handles model instantiation, but for future reference, you can just always use line 187, even if the model will be loaded onto GPU :)
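For future reference, a small self-contained illustration of the point; the checkpoint path is a placeholder:

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in module for illustration
# map_location="cpu" is safe even when training on GPU; the weights move
# when the module itself is transferred afterwards.
state_dict = torch.load("checkpoint.pt", map_location="cpu")  # placeholder path
model.load_state_dict(state_dict)
if torch.cuda.is_available():
    model = model.to("cuda")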

Signed-off-by: Pedro Henrique Conrado <[email protected]>
@PedroConrado PedroConrado force-pushed the extend/smp_model_factory branch from 0da4809 to 93c523e Compare July 28, 2024 17:18
@@ -6,6 +6,8 @@
from terratorch.models.smp_model_factory import SMPModelFactory
from terratorch.models.timm_model_factory import TimmModelFactory

from terratorch.models.smp_model_factory import get_smp_decoder
Contributor

I don't think we should expose this directly. Is there a good reason for importing this here?

@@ -15,4 +17,5 @@
"TimmModelFactory",
"AuxiliaryHead",
"AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
"get_smp_decoder",
Contributor

same as above

Comment on lines 6 to 8
from terratorch.models.smp_model_factory import SMPModelFactory
from terratorch.models.timm_model_factory import TimmModelFactory

Contributor

unsorted imports (ruff can sort imports)

Comment on lines 104 to 105
smp_kwargs, kwargs = _extract_prefix_keys(kwargs, "smp_")
aux_kwargs, kwargs = _extract_prefix_keys(kwargs, "aux_")
Contributor

Can we make use of the args passed through decoder_ instead of needing a new prefix for smp and aux?
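For context, an illustration of what the prefix convention does; this helper mirrors what _extract_prefix_keys appears to do, not its exact implementation:

def extract_prefix_keys(kwargs: dict, prefix: str) -> tuple[dict, dict]:
    # split kwargs into those carrying the prefix (with it stripped) and the rest
    extracted = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    remaining = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return extracted, remaining

smp_kwargs, rest = extract_prefix_keys({"smp_encoder_depth": 5, "lr": 1e-4}, "smp_")
# smp_kwargs == {"encoder_depth": 5}, rest == {"lr": 1e-4}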

Comment on lines 45 to 49
def forward_multiple_embeds(self, *x):
return self.decoder(*x)

def forward_single_embed(self, x):
return self.decoder(x[-1])
Contributor

Instead of this, can we remain consistent with current Prithvi decoders and accept an in_index argument? This should be a single int or a list of ints: the indices of x which we select and pass forward to the decoder, as sketched below.
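A minimal sketch of the in_index idea; the wrapper name is hypothetical, not a class from this PR:

from torch import nn

class DecoderWithInIndex(nn.Module):
    def __init__(self, decoder: nn.Module, in_index: int | list[int] = -1):
        super().__init__()
        self.decoder = decoder
        self.in_index = in_index  # which backbone embeddings to forward

    def forward(self, x: list):
        if isinstance(self.in_index, int):
            return self.decoder(x[self.in_index])
        # a list of indices selects several feature maps and passes them all
        return self.decoder(*[x[i] for i in self.in_index])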

raise ValueError(msg)

# Using new encoder.
backbone_class = self._make_smp_encoder(backbone)
Contributor

Is this working? _make_smp_encoder seems to be a module-level function rather than a method of this class.

Comment on lines 302 to 306
smp_decoder = SMPDecoderForPrithviWrapper(decoder=model.decoder, num_channels=out_channels[-1])
if "multiple_embed" in head_kwargs:
smp_decoder.forward = smp_decoder.forward_multiple_embeds
else:
smp_decoder.forward = smp_decoder.forward_single_embed
Contributor

See the suggestion above to allow an in_index parameter, which would make this monkey patching unnecessary.



# Gets class either from string or from Module reference.
def _make_smp_encoder(encoder=None):
Contributor

Should this be a method of the factory above?

self.smp_model = smp_model
self.final_act = nn.ReLU() if relu else nn.Identity()
self.squeeze_single_class = squeeze_single_class
def get_smp_decoder(
Contributor

This function and the class SMPDecoderForPrithviWrapper seem to only be used for creating decoders that match the format required by PrithviModelFactory, right? Then, would it make sense to move them to the PrithviModelFactory class instead?

Collaborator Author

This function is also used by _get_smp_encoder to create the dummy smp encoder.

@PedroConrado PedroConrado force-pushed the extend/smp_model_factory branch from b4e1a2a to 97b4688 Compare July 29, 2024 23:46
Contributor
@CarlosGomes98 CarlosGomes98 left a comment

Looks good!

@CarlosGomes98 CarlosGomes98 merged commit 6c483a8 into IBM:main Aug 2, 2024
1 of 3 checks passed
PedroConrado added a commit to PedroConrado/terratorch that referenced this pull request Aug 2, 2024
* extends smp_model_factory class

Signed-off-by: Pedro Henrique Conrado <[email protected]>

* extends smp model factory and adds functionalities in prithvi model factory

Signed-off-by: Pedro Henrique Conrado <[email protected]>

* extends SMPModelFactory

Signed-off-by: Pedro Henrique Conrado <[email protected]>

* Extends SMPModelFactory and smp decoder in PrithviModelFactory

Signed-off-by: Pedro Henrique Conrado <[email protected]>

* adds SMPModelFactory tests and SMPModelFactory to model.md

Signed-off-by: Pedro Henrique Conrado <[email protected]>

* adds SMPModelFactory tests and SMPModelFactory to docs/model.md

Signed-off-by: Pedro Henrique Conrado <[email protected]>

---------

Signed-off-by: Pedro Henrique Conrado <[email protected]>