diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..87650d1f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered (optional)** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context (optional)** +Add any other context or screenshots about the feature request here. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/issue-report.md b/.github/ISSUE_TEMPLATE/issue-report.md new file mode 100644 index 00000000..20ec527f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue-report.md @@ -0,0 +1,34 @@ +--- +name: Issue report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the issue** +A clear and concise description of what the issue is. + +**To Reproduce (optional, but appreciated)** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Screenshots or log output (optional)** +If applicable, add screenshots or log output to help explain your problem. +
<details><summary>Log Output</summary> +<pre> +Paste the log output here. +</pre> +</details>
+ +**Expected behavior (optional)** +A clear and concise description of what you expected to happen. + +**Deployment information (optional)** +Describe what you've deployed and how: + - TerraTorch version: [e.g. 1.5.3] + - Installation source: [e.g. git, pip] diff --git a/.github/ISSUE_TEMPLATE/report-a-vulnerability.md b/.github/ISSUE_TEMPLATE/report-a-vulnerability.md new file mode 100644 index 00000000..408ab606 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/report-a-vulnerability.md @@ -0,0 +1,17 @@ +--- +name: "Report a vulnerability" +about: "Report a vulnerability" +title: '' +labels: '' +assignees: '' + +--- + +### Summary +_Short summary of the problem. Make the impact and severity as clear as possible. For example: An unsafe deserialization vulnerability allows any unauthenticated user to execute arbitrary code on the server._ + +### Details +_Give all details on the vulnerability. Pointing to the incriminated source code is very helpful for the maintainer._ + +### Impact +_What kind of vulnerability is it? Who is impacted?_ \ No newline at end of file diff --git a/.github/workflows/crosshair.yaml b/.github/workflows/crosshair.yaml new file mode 100644 index 00000000..b6b670e8 --- /dev/null +++ b/.github/workflows/crosshair.yaml @@ -0,0 +1,33 @@ +# This workflow will install Python dependencies and run CrossHair dynamic analysis with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Dynamic Code Analysis + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v3 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install crosshair-tool + pip install -r requirements/required.txt + - name: Analysing the code with CrossHair + run: | + crosshair check ./terratorch diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 00000000..ba63bdd4 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,28 @@ +name: Static Code Analysis (Pylint) + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint nbqa + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') + nbqa pylint $(git ls-files '*.ipynb') diff --git a/.gitignore b/.gitignore index b93078b8..b07ad6b9 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ dist/* **/*.pt *.ipynb_checkpoints **/*pth +.venv/* +venv/* \ No newline at end of file diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..508e0046 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @CarlosGomes98 @Joao-L-S-Almeida @PedroConrado @biancazadrozny @romeokienzler \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..742740b5 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,70 @@ +## TerraTorch Community Code of Conduct 1.0 + + +### Our Pledge + +As contributors and maintainers of this project, and
in the interest of fostering +an open and welcoming community, we pledge to respect all people who contribute +through reporting issues, posting feature requests, updating documentation, +submitting pull requests or patches, and other activities. + +We are committed to making participation in this project a harassment-free experience for +everyone, regardless of level of experience, gender, gender identity and expression, +sexual orientation, disability, personal appearance, body size, race, ethnicity, age, +religion, or nationality. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, +inclusive, and healthy community. + +### Our Standards + +Examples of behavior that contributes to a positive environment for our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +### Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior +and will take appropriate and fair corrective action in response to any behavior that they deem +inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, commits, +code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, +and will communicate reasons for moderation decisions when appropriate. + +### Scope + +This Code of Conduct applies within all community spaces including Matrix, issue trackers, wikis, +blogs, Twitter, and any other communication channels used by our community, and also applies when +an individual is officially representing the community in public spaces. Examples of representing +our community include using an official e-mail address, posting via an official social media account, +or acting as an appointed representative at an online or offline event. + +### Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community +leaders responsible for enforcement, the CLAIMED Project Management Committee, via the Matrix channel +#claimed-pmc:matrix.org. All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter of any incident. + + +### Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..047225fa --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,32 @@ + + +# Contributing + +Welcome to TerraTorch!
If you are interested in contributing to the [TerraTorch code repo](README.md), +then check out the [Contribution Process](contribution_process.md) and +the [Code of Conduct](CODE_OF_CONDUCT.md). + +The [TerraTorch code repo](https://github.com/IBM/terratorch/blob/main/README.md) contains information on how the community +is organized and other information that is pertinent to contributing. + +### Getting Started + +It's encouraged that you look under the [Issues](https://github.com/IBM/terratorch/issues) tab for contribution opportunities. + +Please also take the opportunity to join our [weekly community call](https://teams.microsoft.com/l/meetup-join/19%3ameeting_MWJhMThhMTMtMjc3MS00YjAyLWI3NTMtYTI0NDQ3NWY3ZGU2%40thread.v2/0?context=%7b%22Tid%22%3a%22fcf67057-50c9-4ad4-98f3-ffca64add9e9%22%2c%22Oid%22%3a%227f7ab87a-680c-4c93-acc5-fbd7ec80823a%22%7d) every Thursday 7 AM PST, 10 AM EST, **2 PM UTC**, 4 PM CET, 7:30 PM IST, 11 PM JST, 12 AM AEST diff --git a/VULNERABILITIES.md b/VULNERABILITIES.md new file mode 100644 index 00000000..737a02be --- /dev/null +++ b/VULNERABILITIES.md @@ -0,0 +1,4 @@ +# VULNERABILITIES reporting process + +Vulnerabilities can be reported on the GitHub Issue Tracker. +We are working on a more generic channel and will update this information accordingly. diff --git a/contribution_process.md b/contribution_process.md new file mode 100644 index 00000000..26db0271 --- /dev/null +++ b/contribution_process.md @@ -0,0 +1,9 @@ +# Contribution Process + +If you want to contribute to this project, there are many valuable ways to do so: + +1. Join the weekly community calls as indicated in [CONTRIBUTING.md](CONTRIBUTING.md) +1. Use / test TerraTorch and create an [Issue](https://github.com/IBM/terratorch/issues) if something is not working properly or if you have an idea for a feature request. +1. Pick an [Issue](https://github.com/IBM/terratorch/issues) and start contributing + +Contributions are welcome as pull requests on a [fork](https://github.com/IBM/terratorch/fork) of this project. Ideally, pull requests are backed by an [Issue](https://github.com/IBM/terratorch/issues). You can also tag the [code owners](https://github.com/IBM/terratorch/blob/main/CODEOWNERS) in the issue before you start, so we can talk about the details (in case you can't join one of the community calls). \ No newline at end of file diff --git a/docs/models.md b/docs/models.md index 33b27008..1f5627a6 100644 --- a/docs/models.md +++ b/docs/models.md @@ -61,6 +61,7 @@ By passing a list of bands being used to the constructor, we automatically filte ## Model Factory ### :::terratorch.models.PrithviModelFactory +### :::terratorch.models.SMPModelFactory # Adding new model types Adding new model types is as simple as creating a new factory that produces models.
See for instance the example below for a potential `SMPModelFactory` diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index e0d523e1..171710f6 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -123,7 +123,7 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # output_dir = self.output_dir if not os.path.exists(output_dir): - os.mkdir(output_dir) + os.makedirs(output_dir, exist_ok=True) for pred_batch, filename_batch in predictions: for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False): @@ -467,4 +467,4 @@ def inference(self, file_path: Path) -> torch.Tensor: prediction, file_name = self.inference_on_dir( tmpdir, ) - return prediction.squeeze(0) + return prediction.squeeze(0) \ No newline at end of file diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index f4947334..96624d1b 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -20,7 +20,7 @@ from torch import Tensor from torchgeo.datasets import NonGeoDataset -from terratorch.datasets.utils import HLSBands, filter_valid_files, to_tensor +from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files class GenericPixelWiseDataset(NonGeoDataset, ABC): @@ -136,7 +136,7 @@ def __init__( self.filter_indices = None # If no transform is given, apply only to transform to torch tensor - self.transform = transform if transform else lambda **batch: to_tensor(batch) + self.transform = transform if transform else default_transform # self.transform = transform if transform else ToTensorV2() def __len__(self) -> int: @@ -186,10 +186,6 @@ def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | bands.extend(expanded_element) else: bands.append(element) - # check the expansion didnt result in duplicate elements - if len(set(bands)) != len(bands): - msg = "Duplicate indices detected. Indices must be unique." 
- raise Exception(msg) return bands diff --git a/terratorch/datasets/generic_scalar_label_dataset.py b/terratorch/datasets/generic_scalar_label_dataset.py index bd82e3b0..f3255fc2 100644 --- a/terratorch/datasets/generic_scalar_label_dataset.py +++ b/terratorch/datasets/generic_scalar_label_dataset.py @@ -26,7 +26,7 @@ from torchgeo.datasets.utils import rasterio_loader from torchvision.datasets import ImageFolder -from terratorch.datasets.utils import HLSBands, filter_valid_files, to_tensor +from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC): @@ -128,7 +128,7 @@ def is_valid_file(x): else: self.filter_indices = None # If no transform is given, apply only to transform to torch tensor - self.transforms = transform if transform else lambda **batch: to_tensor(batch) + self.transforms = transform if transform else default_transform # self.transform = transform if transform else ToTensorV2() def __len__(self) -> int: diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py index 0dee447e..0d4065a8 100644 --- a/terratorch/datasets/utils.py +++ b/terratorch/datasets/utils.py @@ -34,6 +34,9 @@ def try_convert_to_hls_bands_enum(cls, x: Any): except ValueError: return x +def default_transform(**batch): + return to_tensor(batch) + def filter_valid_files( files, valid_files: Iterator[str] | None = None, ignore_extensions: bool = False, allow_substring: bool = True diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index a9448762..03691307 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -103,7 +103,10 @@ def checkpoint_filter_wrapper_fn(state_dict, model): kwargs = {k: v for k, v in kwargs.items() if k != "out_indices"} model.feature_info = FeatureInfo(model.feature_info, out_indices) model.encode_decode_forward = model.forward - model.forward = model.forward_features + def forward_filter_indices(*args, **kwargs): + features = model.forward_features(*args, **kwargs) + return [features[i] for i in out_indices] + model.forward = forward_filter_indices model.model_bands = model_bands model.pretrained_bands = pretrained_bands diff --git a/terratorch/models/prithvi_model_factory.py b/terratorch/models/prithvi_model_factory.py index 5347bca5..73ce819b 100644 --- a/terratorch/models/prithvi_model_factory.py +++ b/terratorch/models/prithvi_model_factory.py @@ -2,6 +2,7 @@ from collections.abc import Callable +import segmentation_models_pytorch as smp import timm import torch from torch import nn @@ -17,6 +18,7 @@ ) from terratorch.models.pixel_wise_model import PixelWiseModel from terratorch.models.scalar_output_model import ScalarOutputModel +from terratorch.models.smp_model_factory import make_smp_encoder, register_custom_encoder PIXEL_WISE_TASKS = ["segmentation", "regression"] SCALAR_TASKS = ["classification"] @@ -26,6 +28,7 @@ class DecoderNotFoundError(Exception): pass + @register_factory class PrithviModelFactory(ModelFactory): def build_model( @@ -34,7 +37,8 @@ def build_model( backbone: str | nn.Module, decoder: str | nn.Module, bands: list[HLSBands | int], - in_channels: int | None = None, # this should be removed, can be derived from bands. But it is a breaking change + in_channels: int + | None = None, # this should be removed, can be derived from bands. 
But it is a breaking change num_classes: int | None = None, pretrained: bool = True, # noqa: FBT001, FBT002 num_frames: int = 1, @@ -76,11 +80,6 @@ def build_model( Returns: nn.Module: Full model with encoder, decoder and head. """ - if not torch.cuda.is_available(): - self.CPU_ONLY = True - else: - self.CPU_ONLY = False - bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands] if in_channels is None: in_channels = len(bands) @@ -96,6 +95,10 @@ def build_model( raise NotImplementedError(msg) backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_") + # These params are used in case we need a SMP decoder + # but should not be used for timm encoder + output_stride = backbone_kwargs.pop("output_stride", None) + out_channels = backbone_kwargs.pop("out_channels", None) backbone: nn.Module = timm.create_model( backbone, @@ -106,14 +109,24 @@ def build_model( features_only=True, **backbone_kwargs, ) - # allow decoder to be a module passed directly - decoder_cls = _get_decoder(decoder) decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_") - # TODO: remove this - decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs) - # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs) + if decoder.startswith("smp_"): + decoder: nn.Module = _get_smp_decoder( + decoder, + backbone_kwargs, + decoder_kwargs, + out_channels, + in_channels, + num_classes, + output_stride, + ) + else: + # allow decoder to be a module passed directly + decoder_cls = _get_decoder(decoder) + decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs) + # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs) head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_") if num_classes: @@ -148,6 +161,46 @@ def build_model( ) +class SMPDecoderForPrithviWrapper(nn.Module): + """ + A wrapper for SMP decoders designed to handle single or multiple embeddings with specified indices. + + Attributes: + decoder (nn.Module): The SMP decoder module being wrapped. + channels (int): The number of output channels of the decoder. + in_index (Union[int, List[int]]): Index or indices of the embeddings to pass to the decoder. + + Methods: + forward(x: List[torch.Tensor]) -> torch.Tensor: + Forward pass for embeddings with specified indices. + """ + + def __init__(self, decoder, num_channels, in_index=-1) -> None: + """ + Args: + decoder (nn.Module): The SMP decoder module to be wrapped. + num_channels (int): The number of output channels of the decoder. + in_index (Union[int, List[int]], optional): Index or indices of the input embeddings to pass to the decoder. + Defaults to -1. + """ + super().__init__() + self.decoder = decoder + self.channels = num_channels + self.in_index = in_index + + @property + def output_embed_dim(self): + return self.channels + + def forward(self, x): + if isinstance(self.in_index, int): + selected_inputs = [x[self.in_index]] + else: + selected_inputs = [x[i] for i in self.in_index] + + return self.decoder(*selected_inputs) + + def _build_appropriate_model( task: str, backbone: nn.Module, @@ -178,6 +231,82 @@ def _build_appropriate_model( ) +def _get_smp_decoder( + decoder: str, + backbone_kwargs: dict, + decoder_kwargs: dict, + out_channels: list[int] | int, + in_channels: int, + num_classes: int, + output_stride: int, +): + """ + Creates and configures a decoder from the Segmentation Models Pytorch (SMP) library. 
+ + This function constructs a decoder module based on the specified parameters and wraps it in a + custom wrapper that allows handling single or multiple embeddings. It also ensures that the + appropriate encoder parameters are passed and registered correctly. + + Args: + decoder (str): The name of the SMP decoder to use. + backbone_kwargs (dict): Dictionary of parameters for configuring the backbone. + decoder_kwargs (dict): Dictionary of parameters specific to the decoder. + out_channels (Union[list[int], int]): The number of output channels for each layer of the decoder. + in_channels (int): The number of input channels. + num_classes (int): The number of output classes for the model. + output_stride (int): The output stride of the decoder. + + Returns: + SMPDecoderForPrithviWrapper: A wrapped decoder module configured based on the provided parameters. + + Raises: + ValueError: If the specified decoder is not supported by SMP. + """ + decoder = decoder.removeprefix("smp_") + decoder_module = getattr(smp, decoder, None) + if decoder_module is None: + msg = f"Decoder {decoder} is not supported in SMP." + raise ValueError(msg) + + # Little hack to make SMP model accept our encoder. + # passes a dummy encoder to be changed later. + # this is needed to pass encoder params. + aux_kwargs, decoder_kwargs = _extract_prefix_keys(decoder_kwargs, "aux_") + smp_kwargs, decoder_kwargs = _extract_prefix_keys(decoder_kwargs, "smp_") + backbone_kwargs["out_channels"] = out_channels + backbone_kwargs["output_stride"] = output_stride + aux_kwargs = None if aux_kwargs == {} else aux_kwargs + + dummy_encoder = make_smp_encoder() + + register_custom_encoder(dummy_encoder, backbone_kwargs, None) + + dummy_encoder = dummy_encoder( + depth=smp_kwargs["encoder_depth"], + output_stride=backbone_kwargs["output_stride"], + out_channels=backbone_kwargs["out_channels"], + ) + + model_args = { + "encoder_name": "SMPEncoderWrapperWithPFFIM", + "encoder_weights": None, + "in_channels": in_channels, + "classes": num_classes, + **smp_kwargs, + } + + # Creates model with dummy encoder and decoder. + model = decoder_module(**model_args, aux_params=aux_kwargs) + + smp_decoder = SMPDecoderForPrithviWrapper( + decoder=model.decoder, + num_channels=out_channels[-1], + in_index=decoder_kwargs["in_index"], + ) + + return smp_decoder + + def _get_decoder(decoder: str | nn.Module) -> nn.Module: if isinstance(decoder, nn.Module): return decoder @@ -197,7 +326,7 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict: remaining_dict = {} for k, v in d.items(): if k.startswith(prefix): - extracted_dict[k.split(prefix)[1]] = v + extracted_dict[k[len(prefix) :]] = v else: remaining_dict[k] = v diff --git a/terratorch/models/smp_model_factory.py b/terratorch/models/smp_model_factory.py index ee58c0d1..46d8637e 100644 --- a/terratorch/models/smp_model_factory.py +++ b/terratorch/models/smp_model_factory.py @@ -1,83 +1,268 @@ # Copyright contributors to the Terratorch project -""" -This is just an example of a possible structure to include SMP models -Right now it always returns a UNET, but could easily be extended to many of the models provided by SMP. 
-""" +import importlib +from collections.abc import Callable import segmentation_models_pytorch as smp +import torch +import torch.nn.functional as F # noqa: N812 +from segmentation_models_pytorch.encoders import encoders as smp_encoders from torch import nn +from terratorch.datasets import HLSBands from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory +class SMPModelWrapper(Model, nn.Module): + """ + Wrapper class for SMP models. + + This class provides additional functionalities on top of SMP models. + + Attributes: + rescale (bool): Whether to rescale the output to match the input dimensions. + smp_model (nn.Module): The base SMP model being wrapped. + final_act (nn.Module): The final activation function to be applied on the output. + squeeze_single_class (bool): Whether to squeeze the output if there is a single output class. + + Methods: + forward(x: torch.Tensor) -> ModelOutput: + Forward pass through the model, optionally rescaling the output. + freeze_encoder() -> None: + Freezes the parameters of the encoder part of the model. + freeze_decoder() -> None: + Freezes the parameters of the decoder part of the model. + """ + + def __init__(self, smp_model, rescale=True, relu=False, squeeze_single_class=False) -> None: # noqa: FBT002 + super().__init__() + """ + Args: + smp_model (nn.Module): The base SMP model to be wrapped. + rescale (bool, optional): Whether to rescale the output to match the input dimensions. Defaults to True. + relu (bool, optional): Whether to apply ReLU activation on the output. + If False, Identity activation is used. Defaults to False. + squeeze_single_class (bool, optional): Whether to squeeze the output if there is a single output class. + Defaults to False. + """ + self.rescale = rescale + self.smp_model = smp_model + self.final_act = nn.ReLU() if relu else nn.Identity() + self.squeeze_single_class = squeeze_single_class + + def forward(self, x): + input_size = x.shape[-2:] + smp_output = self.smp_model(x) + smp_output = self.final_act(smp_output) + + # TODO: support auxiliary head labels + if isinstance(smp_output, tuple): + smp_output, labels = smp_output + + if smp_output.shape[1] == 1 and self.squeeze_single_class: + smp_output = smp_output.squeeze(1) + + if self.rescale and smp_output.shape[-2:] != input_size: + smp_output = F.interpolate(smp_output, size=input_size, mode="bilinear") + return ModelOutput(smp_output) + + def freeze_encoder(self): + freeze_module(self.smp_model.encoder) + + def freeze_decoder(self): + freeze_module(self.smp_model.decoder) + + @register_factory class SMPModelFactory(ModelFactory): def build_model( self, task: str, backbone: str, - decoder: str, - in_channels: int, - pretrained: str | bool | None = True, + model: str, + bands: list[HLSBands | int], + in_channels: int | None = None, num_classes: int = 1, - regression_relu: bool = False, + pretrained: str | bool | None = True, # noqa: FBT002 + prepare_features_for_image_model: Callable | None = None, + regression_relu: bool = False, # noqa: FBT001, FBT002 **kwargs, ) -> Model: - """Factory to create model based on SMP. + """ + Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization. - Args: - task (str): Must be "segmentation". - backbone (str): Name of backbone. - decoder (str): Decoder architecture. Currently only supports "unet". - in_channels (int): Number of input channels. - pretrained(str | bool): Which weights to use for the backbone. If true, will use "imagenet". 
If false or None, random weights. Defaults to True. - num_classes (int): Number of classes. - regression_relu (bool). Whether to apply a ReLU if task is regression. Defaults to False. + This factory handles the instantiation of segmentation and regression models using specified + encoders and decoders from the SMP library, along with custom modifications and extensions such + as auxiliary decoders or modified encoders. + + Args: + task (str): Specifies the task for which the model is being built. The only currently + supported task is "segmentation". + backbone (str): Specifies the backbone model to be used. + model (str): Specifies the SMP decoder architecture to be used for constructing the + segmentation model. + bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model + will operate on. These are expected to be from terratorch.datasets.HLSBands. + in_channels (int, optional): Specifies the number of input channels. Defaults to None. + num_classes (int, optional): The number of output classes for the model. + pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the + backbone. Can also specify a path to weights. Defaults to True. + prepare_features_for_image_model (Callable, optional): Function applied to the backbone + features before they are passed on; used when registering a custom (non-SMP) encoder. + regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks. + **kwargs: Additional arguments that might be passed to further customize the backbone, decoder, + or any auxiliary heads. These should be prefixed appropriately (backbone_, smp_, aux_). + + Raises: + ValueError: If the specified decoder is not supported by SMP. + Exception: If the specified task is not "segmentation". Returns: - Model: SMP model wrapped in SMPModelWrapper. + nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified + parameters and tasks. """ - if task not in ["segmentation", "regression"]: - msg = f"SMP models can only perform pixel wise tasks, but got task {task}" + if task != "segmentation": + msg = f"SMP models can only perform segmentation, but got task {task}" raise Exception(msg) - # backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_") + + bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands] + if in_channels is None: + in_channels = len(bands) + + # Gets the decoder module. + model_module = getattr(smp, model, None) + if model_module is None: + msg = f"Decoder {model} is not supported in SMP." + raise ValueError(msg) + + backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_") # Encoder params should be prefixed backbone_ + smp_kwargs = _extract_prefix_keys(backbone_kwargs, "smp_") # SMP model params should be prefixed smp_ + aux_params = _extract_prefix_keys(backbone_kwargs, "aux_") # Auxiliary head params should be prefixed aux_ + aux_params = None if aux_params == {} else aux_params + if isinstance(pretrained, bool): if pretrained: pretrained = "imagenet" else: pretrained = None - if decoder == "unet": - model = smp.Unet( - encoder_name=backbone, encoder_weights=pretrained, in_channels=in_channels, classes=num_classes - ) + + # If the encoder is not currently supported by SMP (custom encoder). + if backbone not in smp_encoders: + # These params must be included in the config file with the appropriate prefix.
+ required_params = { + "encoder_depth": smp_kwargs, + "out_channels": backbone_kwargs, + "output_stride": backbone_kwargs, + } + + for param, config_dict in required_params.items(): + if param not in config_dict: + msg = f"Config must include the '{param}' parameter" + raise ValueError(msg) + + # Using new encoder. + backbone_class = make_smp_encoder(backbone) + backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model + # Registering custom encoder into SMP. + register_custom_encoder(backbone_class, backbone_kwargs, pretrained) + + model_args = { + "encoder_name": "SMPEncoderWrapperWithPFFIM", + "encoder_weights": pretrained, + "in_channels": in_channels, + "classes": num_classes, + **smp_kwargs, + } + # Using SMP encoder. else: - msg = "Only unet decoder implemented" - raise NotImplementedError(msg) + model_args = { + "encoder_name": backbone, + "encoder_weights": pretrained, + "in_channels": in_channels, + "classes": num_classes, + **smp_kwargs, + } + + model = model_module(**model_args, aux_params=aux_params) + return SMPModelWrapper( model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression" ) -class SMPModelWrapper(Model, nn.Module): - def __init__(self, smp_model, relu=False, squeeze_single_class=False) -> None: - super().__init__() - self.smp_model = smp_model - self.final_act = nn.ReLU() if relu else nn.Identity() - self.squeeze_single_class = squeeze_single_class +# Registers a custom encoder into SMP. +def register_custom_encoder(encoder, params, pretrained): + smp_encoders["SMPEncoderWrapperWithPFFIM"] = { + "encoder": encoder, + "params": params, + "pretrained_settings": pretrained, + } - def forward(self, *args, **kwargs): - smp_output = self.smp_model(*args, **kwargs) - smp_output = self.final_act(smp_output) - if smp_output.shape[1] == 1 and self.squeeze_single_class: - smp_output = smp_output.squeeze(1) - return ModelOutput(smp_output) - def freeze_encoder(self): - raise NotImplementedError() +def make_smp_encoder(encoder=None): + if isinstance(encoder, str): + base_class = _get_class_from_string(encoder) + else: + base_class = nn.Module - def freeze_decoder(self): - raise NotImplementedError() + # Wrapper needed to include SMP params and PFFIM + class SMPEncoderWrapperWithPFFIM(base_class): + def __init__( + self, + depth: int, + output_stride: int, + out_channels: list[int], + prepare_features_for_image_model: Callable | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self._depth = depth + self._output_stride = output_stride + self._out_channels = out_channels + self.model = None + + if prepare_features_for_image_model: + self.prepare_features_for_image_model = prepare_features_for_image_model + elif not hasattr(super(), "prepare_features_for_image_model"): + self.prepare_features_for_image_model = lambda x: x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.model: + features = self.model(x) + if hasattr(self.model, "prepare_features_for_image_model"): + return self.model.prepare_features_for_image_model(features) + + features = super().forward(x) + return self.prepare_features_for_image_model(features) + + @property + def out_channels(self): + if hasattr(super(), "out_channels"): + return super().out_channels() + + return self._out_channels + + @property + def output_stride(self): + if hasattr(super(), "output_stride"): + return super().output_stride() + + return min(self._output_stride, 2**self._depth) + + def set_in_channels(self, in_channels, 
pretrained): + if hasattr(super(), "set_in_channels"): + return super().set_in_channels(in_channels, pretrained) + else: + pass + + def make_dilated(self, output_stride): + if hasattr(super(), "make_dilated"): + return super().make_dilated(output_stride) + else: + pass + + return SMPEncoderWrapperWithPFFIM def _extract_prefix_keys(d: dict, prefix: str) -> dict: @@ -92,3 +277,28 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict: del d[k] return extracted_dict + + +def _get_class_from_string(class_path): + try: + module_path, name = class_path.rsplit(".", 1) + except ValueError as vr: + msg = "Path must contain a '.' separating module from the class name" + raise ValueError(msg) from vr + + try: + module = importlib.import_module(module_path) + except ImportError as ie: + msg = f"Could not import module '{module_path}'." + raise ImportError(msg) from ie + + try: + return getattr(module, name) + except AttributeError as ae: + msg = f"The class '{name}' was not found in the module '{module_path}'." + raise AttributeError(msg) from ae + + +def freeze_module(module: nn.Module): + for param in module.parameters(): + param.requires_grad_(False) diff --git a/tests/manufactured-finetune_prithvi_swin_B.yaml b/tests/manufactured-finetune_prithvi_swin_B.yaml index b7498d1b..03cd7ea7 100644 --- a/tests/manufactured-finetune_prithvi_swin_B.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml index 8697cd63..61685471 100644 --- a/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml index 91a72a3c..edce7f91 100644 --- a/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_B_string.yaml b/tests/manufactured-finetune_prithvi_swin_B_string.yaml index a7aa84c2..cac4d4f7 100644 --- a/tests/manufactured-finetune_prithvi_swin_B_string.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B_string.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_L.yaml b/tests/manufactured-finetune_prithvi_swin_L.yaml index 8619ffbf..3908d55e 100644 --- a/tests/manufactured-finetune_prithvi_swin_L.yaml +++ b/tests/manufactured-finetune_prithvi_swin_L.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_vit_100.yaml b/tests/manufactured-finetune_prithvi_vit_100.yaml index 8ee70a9c..7ebf0559 100644 --- a/tests/manufactured-finetune_prithvi_vit_100.yaml +++ 
b/tests/manufactured-finetune_prithvi_vit_100.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 @@ -111,7 +111,6 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_frames: 1 head_dropout: 0.5708022831486758 head_final_act: torch.nn.ReLU head_learned_upscale_layers: 2 diff --git a/tests/manufactured-finetune_prithvi_vit_300.yaml b/tests/manufactured-finetune_prithvi_vit_300.yaml index 1994f0d1..cac7291d 100644 --- a/tests/manufactured-finetune_prithvi_vit_300.yaml +++ b/tests/manufactured-finetune_prithvi_vit_300.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 @@ -111,7 +111,6 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_frames: 1 head_dropout: 0.5708022831486758 head_final_act: torch.nn.ReLU head_learned_upscale_layers: 2 diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 3d953a98..d9af91d5 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -1,11 +1,13 @@ # Copyright contributors to the Terratorch project +import importlib +import os + import pytest import timm import torch -import importlib + import terratorch # noqa: F401 -import os NUM_CHANNELS = 6 NUM_FRAMES = 3 @@ -52,6 +54,14 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) backbone(input_224_multitemporal) -#def test_swin_models_accept_non_divisible_by_patch_size(input_386): -# backbone = timm.create_model("prithvi_swin_90_us", pretrained=False, num_frames=NUM_FRAMES) -# backbone(input_386) +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) +def test_out_indices(model_name, input_224): + out_indices = [2, 4, 8, 10] + backbone = timm.create_model(model_name, pretrained=False, features_only=True, out_indices=out_indices) + assert backbone.feature_info.out_indices == out_indices + + output = backbone(input_224) + full_output = backbone.forward_features(input_224) + + for filtered_index, full_index in enumerate(out_indices): + assert torch.allclose(full_output[full_index], output[filtered_index]) diff --git a/tests/test_finetune.py b/tests/test_finetune.py index c1e47b28..11840839 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -1,51 +1,40 @@ +import os +import shutil + import pytest import timm import torch -import importlib -import terratorch -import subprocess -import os from terratorch.cli_tools import build_lightning_cli -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) -def test_finetune_multiple_backbones(model_name): +@pytest.fixture(autouse=True) +def setup_and_cleanup(model_name): model_instance = timm.create_model(model_name) - pretrained_bands = [0, 1, 2, 3, 4, 5] - model_bands = [0, 1, 2, 3, 4, 5] state_dict = model_instance.state_dict() - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) + torch.save(state_dict, os.path.join("tests", model_name + ".pt")) - # Running the terratorch CLI + yield # everything after this runs after each test + + os.remove(os.path.join("tests", model_name + ".pt")) + shutil.rmtree(os.path.join("tests", "all_ecos_random")) + +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) +def test_finetune_multiple_backbones(model_name): command_list = 
["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"] _ = build_lightning_cli(command_list) + @pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) def test_finetune_bands_intervals(model_name): - - model_instance = timm.create_model(model_name) - - state_dict = model_instance.state_dict() - - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) - - # Running the terratorch CLI command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"] _ = build_lightning_cli(command_list) + @pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) def test_finetune_bands_str(model_name): - - model_instance = timm.create_model(model_name) - - state_dict = model_instance.state_dict() - - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) - - # Running the terratorch CLI command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"] _ = build_lightning_cli(command_list) diff --git a/tests/test_smp_model_factory.py b/tests/test_smp_model_factory.py new file mode 100644 index 00000000..11b0a67b --- /dev/null +++ b/tests/test_smp_model_factory.py @@ -0,0 +1,80 @@ +# Copyright contributors to the Terratorch project + +import os + +import pytest +import torch + +from terratorch.models import SMPModelFactory +from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS + +# from terratorch.models.backbones.prithvi_vit import default_cfgs as vit_default_cfgs + +NUM_CHANNELS = 6 +NUM_CLASSES = 2 +EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224) +EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224) +EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES) + + +@pytest.fixture(scope="session") +def model_factory() -> SMPModelFactory: + return SMPModelFactory() + + +@pytest.fixture(scope="session") +def model_input() -> torch.Tensor: + return torch.ones((1, NUM_CHANNELS, 224, 224)) + + +@pytest.mark.parametrize("backbone", ["timm-regnetx_002"]) +@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"]) +def test_create_segmentation_model(backbone, model, model_factory: SMPModelFactory, model_input): + model = model_factory.build_model( + "segmentation", + backbone=backbone, + model=model, + in_channels=NUM_CHANNELS, + bands=PRETRAINED_BANDS, + pretrained=False, + num_classes=NUM_CLASSES, + ) + model.eval() + + with torch.no_grad(): + assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE + + +@pytest.mark.parametrize("backbone", ["timm-regnetx_002"]) +@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"]) +def test_create_segmentation_model_no_in_channels(backbone, model, model_factory: SMPModelFactory, model_input): + model = model_factory.build_model( + "segmentation", + backbone=backbone, + model=model, + bands=PRETRAINED_BANDS, + pretrained=False, + num_classes=NUM_CLASSES, + ) + model.eval() + + with torch.no_grad(): + assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE + + +@pytest.mark.parametrize("backbone", ["timm-regnetx_002"]) +@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"]) +def test_create_model_with_extra_bands(backbone, model, model_factory: SMPModelFactory): + model = model_factory.build_model( + "segmentation", + backbone=backbone, + model=model, + in_channels=NUM_CHANNELS + 1, + bands=[*PRETRAINED_BANDS, 7], # add an extra band + pretrained=False, + num_classes=NUM_CLASSES, + ) + model.eval() + model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224)) + with torch.no_grad(): + assert model(model_input).output.shape == 
EXPECTED_SEGMENTATION_OUTPUT_SHAPE
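
Reviewer note (not part of the patch): the `_extract_prefix_keys` change above, from `k.split(prefix)[1]` to `k[len(prefix):]`, matters whenever the stripped key still contains the prefix substring. A minimal sketch with a hypothetical key name:

```python
# Hypothetical key: the "decoder_" routing prefix wrapping a parameter that is
# itself called "decoder_channels" (as in SMP's Unet decoder).
prefix = "decoder_"
key = "decoder_decoder_channels"

# Old behaviour: str.split removes *every* occurrence of the prefix substring,
# so element [1] is an empty string and the parameter name is lost.
print(key.split(prefix)[1])   # -> ""

# New behaviour: slice off only the leading prefix, keeping the rest intact.
print(key[len(prefix):])      # -> "decoder_channels"
```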
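A minimal usage sketch of the new `SMPModelFactory`, mirroring `tests/test_smp_model_factory.py` above; the arguments and expected shapes are taken from those tests, nothing else is assumed:

```python
import torch

from terratorch.models import SMPModelFactory
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS

factory = SMPModelFactory()
# "model" selects the SMP architecture (e.g. Unet, DeepLabV3);
# "backbone" is any encoder known to SMP, here a timm RegNet.
model = factory.build_model(
    "segmentation",
    backbone="timm-regnetx_002",
    model="Unet",
    bands=PRETRAINED_BANDS,  # six HLS bands, so in_channels is derived as 6
    pretrained=False,
    num_classes=2,
)
model.eval()

with torch.no_grad():
    out = model(torch.ones((1, 6, 224, 224)))
print(out.output.shape)  # torch.Size([1, 2, 224, 224])
```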
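Finally, reading `_get_smp_decoder` in the `prithvi_model_factory.py` hunk, the new `smp_`-prefixed decoder path expects its configuration split across the `backbone_`, `decoder_`, and nested `smp_`/`aux_` prefixes. The following is a hypothetical, untested sketch of that call shape; only the routing rule and the key names are taken from the code, and every value is a placeholder:

```python
from terratorch.models import PrithviModelFactory
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_vit_100",
    decoder="smp_Unet",               # the "smp_" prefix routes to _get_smp_decoder
    bands=PRETRAINED_BANDS,
    num_classes=2,
    pretrained=False,
    backbone_out_channels=[768] * 5,  # placeholder; popped from backbone_kwargs for the dummy encoder
    backbone_output_stride=32,        # placeholder; popped before timm.create_model is called
    decoder_smp_encoder_depth=4,      # placeholder; becomes smp_kwargs["encoder_depth"]
    decoder_in_index=[0, 1, 2, 3],    # placeholder; the embeddings SMPDecoderForPrithviWrapper selects
)
```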