diff --git a/examples/scripts/create_images.py b/examples/scripts/create_images.py new file mode 100644 index 00000000..034e85f9 --- /dev/null +++ b/examples/scripts/create_images.py @@ -0,0 +1,52 @@ +import os +import random +import numpy as np +import tifffile as tiff +from argparse import ArgumentParser +from osgeo import gdal +from osgeo import osr + +parser = ArgumentParser() +parser.add_argument("--input_file") +parser.add_argument("--output_dir") +parser.add_argument("--n_copies", type=int, default=2) + +args = parser.parse_args() +input_file = args.input_file +output_dir = args.output_dir +n_copies = args.n_copies + +pad_limit = 4 + +# config +GDAL_DATA_TYPE = gdal.GDT_Int32 +GEOTIFF_DRIVER_NAME = r'GTiff' +NO_DATA = 15 +SPATIAL_REFERENCE_SYSTEM_WKID = 4326 + +for c in range(n_copies): + + # Fixed pad width keeps the fixtures reproducible; use random.randint(1, pad_limit) for variable padding. + pad = 3 + filename = os.path.split(input_file)[-1] + output_file = os.path.join(output_dir, filename.replace(".tif", f"_{c}.tif")) + # Read the source image (H, W, C) and zero-pad both spatial dimensions by `pad` on each side. + imarray = tiff.imread(input_file) + im_shape = imarray.shape + im_shape_ext = tuple([i+2*pad for i in list(im_shape[:-1])]) + (im_shape[-1],) + output = np.zeros(im_shape_ext) + output[pad:-pad, pad:-pad, :] = imarray + + # Write the padded array to a GeoTIFF, band by band, and georeference it. + driver = gdal.GetDriverByName(GEOTIFF_DRIVER_NAME) + output_raster = driver.Create(output_file, + output.shape[1], + output.shape[0], + output.shape[-1], + eType = GDAL_DATA_TYPE) + for band_index in range(output.shape[-1]): + band = output_raster.GetRasterBand(band_index + 1) + band.SetNoDataValue(NO_DATA) + band.WriteArray(output[:, :, band_index]) + srs = osr.SpatialReference() + srs.ImportFromEPSG(SPATIAL_REFERENCE_SYSTEM_WKID) + output_raster.SetProjection(srs.ExportToWkt()) + output_raster.FlushCache() + output_raster = None + diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 136c6513..7ebfd022 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -10,6 +10,7 @@ from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights from terratorch.datasets.utils import generate_bands_intervals from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE +from terratorch.models.utils import pad_images logger = logging.getLogger(__name__) @@ -153,20 +154,6 @@ def checkpoint_filter_fn_mae( return state_dict - -def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor: - p = patch_size - # h, w = imgs.shape[3], imgs.shape[4] - t, h, w = imgs.shape[-3:] - h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Ensure padding is within bounds - if h_pad > 0 or w_pad > 0: - imgs = torch.stack([ - nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding) - for img in imgs # Apply per image to avoid NotImplementedError from torch.nn.functional.pad - ]) - return imgs - - def _create_prithvi( variant: str, pretrained: bool = False, # noqa: FBT001, FBT002 diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index 91eea253..a24328ed 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -1,6 +1,5 @@ # Copyright contributors to the Terratorch project - import logging import warnings @@ -13,7 +12,8 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor) -> bool: # check all dimensions are the same except for channel dimension if len(model_patch_embed.shape) != len(checkpoint_patch_embed.shape): - return False + return False + model_shape = [model_patch_embed.shape[i] for i in range(len(model_patch_embed.shape)) if i != 1] checkpoint_shape =
[checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1] return model_shape == checkpoint_shape @@ -82,5 +82,5 @@ def select_patch_embed_weights( ) state_dict[patch_embed_proj_weight_key] = temp_weight - + return state_dict diff --git a/terratorch/models/clay_model_factory.py b/terratorch/models/clay_model_factory.py index 82d1f183..391b93f6 100644 --- a/terratorch/models/clay_model_factory.py +++ b/terratorch/models/clay_model_factory.py @@ -1,6 +1,7 @@ import importlib import sys from collections.abc import Callable +import logging import timm import torch @@ -122,6 +123,26 @@ def build_model( backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_") + # Getting some necessary parameters + # Patch size + if "patch_size" in backbone_kwargs: + patch_size = backbone_kwargs["patch_size"] + else: + # If the configs for the model are right and images have the proper + # sizes, it can still work, but there is no way to fix possible + # errors during execution if information about patch size is not + # explicitly provided. + patch_size = None + + if "img_size" in backbone_kwargs: + img_size = backbone_kwargs["img_size"] + else: + # If the configs for the model are right and images have the proper + # sizes, it can still work, but there is no way to fix possible + # errors during execution if information about img_size is not + # provided in order to perform cropping when necessary. + img_size = None + # Trying to find the model on HuggingFace. try: backbone: nn.Module = timm.create_model( @@ -157,7 +178,7 @@ def build_model( head_kwargs["num_classes"] = num_classes if aux_decoders is None: return _build_appropriate_model( - task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale + task, backbone, decoder, head_kwargs, prepare_features_for_image_model, patch_size=patch_size, img_size=img_size, rescale=rescale ) to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = [] @@ -186,6 +207,8 @@ def build_model( decoder, head_kwargs, prepare_features_for_image_model, + patch_size=patch_size, + img_size=img_size, rescale=rescale, auxiliary_heads=to_be_aux_decoders, ) @@ -197,6 +220,8 @@ def _build_appropriate_model( decoder: nn.Module, head_kwargs: dict, prepare_features_for_image_model: Callable, + patch_size:int=None, + img_size:int=None, rescale: bool = True, # noqa: FBT001, FBT002 auxiliary_heads: dict | None = None, ): @@ -206,6 +231,8 @@ def _build_appropriate_model( backbone, decoder, head_kwargs, + patch_size=patch_size, + img_size=img_size, rescale=rescale, auxiliary_heads=auxiliary_heads, ) @@ -215,6 +242,8 @@ def _build_appropriate_model( backbone, decoder, head_kwargs, + patch_size=patch_size, + img_size=img_size, auxiliary_heads=auxiliary_heads, ) diff --git a/terratorch/models/encoder_decoder_factory.py b/terratorch/models/encoder_decoder_factory.py index 04727265..2bad57e4 100644 --- a/terratorch/models/encoder_decoder_factory.py +++ b/terratorch/models/encoder_decoder_factory.py @@ -2,7 +2,7 @@ import warnings - +import logging from torch import nn from terratorch.models.model import ( @@ -65,6 +65,8 @@ def _check_all_args_used(kwargs): msg = f"arguments {kwargs} were passed but not used." 
raise ValueError(msg) +def _get_argument_from_instance(model, name): + return getattr(model._timm_module.patch_embed, name)[-1] @MODEL_FACTORY_REGISTRY.register class EncoderDecoderFactory(ModelFactory): @@ -128,6 +130,26 @@ def build_model( backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_") backbone = _get_backbone(backbone, **backbone_kwargs) + # Getting some necessary parameters + # Patch size + if "patch_size" in backbone_kwargs: + patch_size = backbone_kwargs["patch_size"] + else: + # If the configs for the model are right and images have the proper + # sizes, it can still work, but there is no way to fix possible + # errors during execution if information about patch size is not + # explicitly provided. + patch_size = None + + if "img_size" in backbone_kwargs: + img_size = backbone_kwargs["img_size"] + else: + # If the configs for the model are right and images have the proper + # sizes, it can still work, but there is no way to fix possible + # errors during execution if information about img_size is not + # provided in order to perform cropping when necessary. + img_size = None + if peft_config is not None: if not backbone_kwargs.get("pretrained", False): msg = ( @@ -166,6 +188,8 @@ def build_model( backbone, decoder, head_kwargs, + patch_size=patch_size, + img_size=img_size, necks=neck_list, decoder_includes_head=decoder_includes_head, rescale=rescale, @@ -191,6 +215,8 @@ def build_model( backbone, decoder, head_kwargs, + patch_size=patch_size, + img_size=img_size, necks=neck_list, decoder_includes_head=decoder_includes_head, rescale=rescale, @@ -203,6 +229,8 @@ def _build_appropriate_model( backbone: nn.Module, decoder: nn.Module, head_kwargs: dict, + patch_size: int, + img_size:int, decoder_includes_head: bool = False, necks: list[Neck] | None = None, rescale: bool = True, # noqa: FBT001, FBT002 @@ -218,6 +246,8 @@ def _build_appropriate_model( backbone, decoder, head_kwargs, + patch_size=patch_size, + img_size=img_size, decoder_includes_head=decoder_includes_head, neck=neck_module, rescale=rescale, @@ -229,6 +259,8 @@ def _build_appropriate_model( backbone, decoder, head_kwargs, + patch_size=patch_size, + img_size=img_size, decoder_includes_head=decoder_includes_head, neck=neck_module, auxiliary_heads=auxiliary_heads, diff --git a/terratorch/models/pixel_wise_model.py b/terratorch/models/pixel_wise_model.py index 6b9145c8..bc01f173 100644 --- a/terratorch/models/pixel_wise_model.py +++ b/terratorch/models/pixel_wise_model.py @@ -1,13 +1,14 @@ # Copyright contributors to the Terratorch project - +import logging import torch import torch.nn.functional as F # noqa: N812 +import torchvision.transforms as transforms from segmentation_models_pytorch.base import SegmentationModel from torch import nn from terratorch.models.heads import RegressionHead, SegmentationHead from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput - +from terratorch.models.utils import pad_images def freeze_module(module: nn.Module): for param in module.parameters(): @@ -26,6 +27,8 @@ def __init__( encoder: nn.Module, decoder: nn.Module, head_kwargs: dict, + patch_size: int = None, + img_size:tuple = None, decoder_includes_head: bool = False, auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None, neck: nn.Module | None = None, @@ -69,6 +72,8 @@ def __init__( self.neck = neck self.rescale = rescale + self.patch_size = patch_size + self.img_size = (img_size, img_size) def freeze_encoder(self): 
freeze_module(self.encoder) @@ -77,9 +82,31 @@ def freeze_decoder(self): freeze_module(self.decoder) freeze_module(self.head) - # TODO: do this properly - def check_input_shape(self, x: torch.Tensor) -> bool: # noqa: ARG002 - return True + def check_input_shape(self, x: torch.Tensor) -> torch.Tensor: + + if self.patch_size: + x_shape = x.shape[2:] + if all([i % self.patch_size == 0 for i in x_shape]): + return x + else: + x = pad_images(x, self.patch_size, "constant") + + return x + else: + # If patch size is not provided, the user should guarantee the + # dataset is properly configured to work with the model being used. + return x + + def _crop_image_when_necessary(self, x: torch.Tensor, size: tuple) -> torch.Tensor: + + if all(self.img_size): + + x_cropped = transforms.CenterCrop(self.img_size)(x) + return x_cropped + else: + logging.getLogger("terratorch").info("Cropping could be necessary to adjust images, so define `img_size` in your config file \ + if you get a shape mismatch.") + return x @staticmethod def _check_for_single_channel_and_squeeze(x): @@ -89,7 +116,7 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: """Sequentially pass `x` through model`s encoder, decoder and heads""" - self.check_input_shape(x) + if isinstance(x, torch.Tensor): input_size = x.shape[-2:] elif hasattr(kwargs, 'image_size'): input_size = kwargs['image_size'] elif 'pixel_values' in kwargs: input_size = list(x.values())[0].shape[-2:] else: ValueError('Could not infer input shape.') + + # TODO make this verification optional to avoid unnecessary repetition + x = self.check_input_shape(x) features = self.encoder(x, **kwargs) ## only for backwards compatibility with pre-neck times. @@ -114,6 +144,7 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: if self.rescale and mask.shape[-2:] != input_size: mask = F.interpolate(mask, size=input_size, mode="bilinear") mask = self._check_for_single_channel_and_squeeze(mask) + aux_outputs = {} for name, decoder in self.aux_heads.items(): aux_output = decoder([f.clone() for f in features]) @@ -121,6 +152,8 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear") aux_output = self._check_for_single_channel_and_squeeze(aux_output) aux_outputs[name] = aux_output + + mask = self._crop_image_when_necessary(mask, input_size) return ModelOutput(output=mask, auxiliary_heads=aux_outputs) def _get_head(self, task: str, input_embed_dim: int, head_kwargs):
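# --- Reviewer sketch, not part of the patch: a minimal illustration of the pad-then-crop flow
# the PixelWiseModel changes above implement, using pad_images from terratorch/models/utils.py
# (added further down in this PR) and torchvision's CenterCrop. The tensor shapes below are
# invented purely for illustration.
import torch
import torchvision.transforms as transforms
from terratorch.models.utils import pad_images

patch_size = 16
x = torch.randn(2, 6, 218, 211)                     # H and W are not multiples of 16
padded = pad_images(x, patch_size, "constant")      # zero-padded up to (2, 6, 224, 224)
mask = padded[:, :1]                                # stand-in for a model output at the padded resolution
cropped = transforms.CenterCrop((218, 211))(mask)   # what _crop_image_when_necessary does when img_size is set
assert cropped.shape[-2:] == x.shape[-2:]
# In the model the crop size comes from the configured img_size (backbone_img_size), not from the
# incoming batch, which is why the nondivisible test config below sets backbone_img_size explicitly.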
diff --git a/terratorch/models/scalar_output_model.py b/terratorch/models/scalar_output_model.py index 76cf653e..d6bf6188 100644 --- a/terratorch/models/scalar_output_model.py +++ b/terratorch/models/scalar_output_model.py @@ -3,9 +3,10 @@ import torch from segmentation_models_pytorch.base import SegmentationModel from torch import nn - +import logging +import torchvision.transforms as transforms from terratorch.models.heads import ClassificationHead from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput +from terratorch.models.utils import pad_images import pdb def freeze_module(module: nn.Module): for param in module.parameters(): @@ -25,6 +26,8 @@ def __init__( encoder: nn.Module, decoder: nn.Module, head_kwargs: dict, + patch_size:int = None, + img_size:tuple = None, decoder_includes_head: bool = False, auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None, neck: nn.Module | None = None, @@ -63,6 +66,8 @@ def __init__( self.aux_heads = nn.ModuleDict(aux_heads) self.neck = neck + self.patch_size = patch_size + self.img_size = (img_size, img_size) def freeze_encoder(self): freeze_module(self.encoder) @@ -71,15 +76,40 @@ def freeze_decoder(self): freeze_module(self.decoder) freeze_module(self.head) - # TODO: do this properly - def check_input_shape(self, x: torch.Tensor) -> bool: # noqa: ARG002 - return True + def check_input_shape(self, x: torch.Tensor) -> torch.Tensor: + + if self.patch_size: + x_shape = x.shape[2:] + if all([i % self.patch_size == 0 for i in x_shape]): + return x + else: + x = pad_images(x, self.patch_size, "constant") + return x + else: + # If patch size is not provided, the user should guarantee the + # dataset is properly configured to work with the model being used. + return x + + def _crop_image_when_necessary(self, x: torch.Tensor, size: tuple) -> torch.Tensor: + + if all(self.img_size): + return transforms.CenterCrop(self.img_size)(x) + else: + logging.getLogger("terratorch").info("Cropping could be necessary to adjust images, so define `img_size` in your config file \ + if you get a shape mismatch.") + return x def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: """Sequentially pass `x` through model`s encoder, decoder and heads""" - self.check_input_shape(x) + + x = self.check_input_shape(x) features = self.encoder(x, **kwargs) + # Collecting information about the size of the input tensor in order to + # use it to possibly crop the image when necessary. + input_size = x.shape[-2:] + ## only for backwards compatibility with pre-neck times. if self.neck: prepare = self.neck @@ -91,10 +121,12 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: decoder_output = self.decoder([f.clone() for f in features]) mask = self.head(decoder_output) + aux_outputs = {} for name, decoder in self.aux_heads.items(): aux_output = decoder([f.clone() for f in features]) aux_outputs[name] = aux_output + return ModelOutput(output=mask, auxiliary_heads=aux_outputs) def _get_head(self, task: str, input_embed_dim: int, head_kwargs: dict): diff --git a/terratorch/models/utils.py b/terratorch/models/utils.py index cf0e3537..5704eb69 100644 --- a/terratorch/models/utils.py +++ b/terratorch/models/utils.py @@ -1,3 +1,6 @@ +from torch import nn, Tensor +import torch + class DecoderNotFoundError(Exception): pass @@ -11,3 +14,16 @@ def extract_prefix_keys(d: dict, prefix: str) -> dict: remaining_dict[k] = v return extracted_dict, remaining_dict + +def pad_images(imgs: Tensor, patch_size: int, padding: str) -> Tensor: + p = patch_size + t, h, w = imgs.shape[-3:] + h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Pad only up to the next multiple of the patch size + if h_pad > 0 or w_pad > 0: + imgs = torch.stack([ + nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding) + for img in imgs # Apply per image to avoid NotImplementedError from torch.nn.functional.pad + ]) + return imgs + diff --git a/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_300.yaml b/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_300.yaml index 3e44a1c5..a60517da 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_300.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_300.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true @@
-100,6 +100,7 @@ model: # backbone_pretrained_cfg_overlay: # file: tests/prithvi_vit_300.pt backbone_drop_path_rate: 0.3 + backbone_patch_size: 16 # backbone_window_size: 8 decoder_channels: 64 num_frames: 1 diff --git a/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_600.yaml b/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_600.yaml index 292d229c..f6652815 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_600.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_eo_v2_600.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true diff --git a/tests/resources/configs/manufactured-finetune_prithvi_pixelwise_nondivisible.yaml b/tests/resources/configs/manufactured-finetune_prithvi_pixelwise_nondivisible.yaml new file mode 100644 index 00000000..7fc0b834 --- /dev/null +++ b/tests/resources/configs/manufactured-finetune_prithvi_pixelwise_nondivisible.yaml @@ -0,0 +1,151 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: cpu + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 1 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + #- class_path: albumentations.HorizontalFlip + # init_args: + # p: 0.5 + #- class_path: albumentations.Rotate + # init_args: + # limit: 30 + # border_mode: 0 # cv2.BORDER_CONSTANT + # value: 0 + # # mask_value: 1 + # p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - 0 + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + - 1 + - 2 + - 3 + - 4 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/resources/inputs_extra + train_label_data_root: tests/resources/inputs_extra + val_data_root: tests/resources/inputs_extra + val_label_data_root: tests/resources/inputs_extra + test_data_root: tests/resources/inputs_extra + test_label_data_root: tests/resources/inputs_extra + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: false + backbone: prithvi_eo_v2_300 + # backbone_pretrained_cfg_overlay: + # file: tests/prithvi_vit_300.pt + backbone_drop_path_rate: 0.3 + backbone_img_size: 224 + # backbone_window_size: 8 + decoder_channels: 64 + num_frames: 1 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # 
head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/resources/configs/manufactured-finetune_prithvi_pixelwise_pad.yaml b/tests/resources/configs/manufactured-finetune_prithvi_pixelwise_pad.yaml new file mode 100644 index 00000000..7e8ef8b7 --- /dev/null +++ b/tests/resources/configs/manufactured-finetune_prithvi_pixelwise_pad.yaml @@ -0,0 +1,151 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: cpu + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 1 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + #- class_path: albumentations.HorizontalFlip + # init_args: + # p: 0.5 + #- class_path: albumentations.Rotate + # init_args: + # limit: 30 + # border_mode: 0 # cv2.BORDER_CONSTANT + # value: 0 + # # mask_value: 1 + # p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - 0 + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + - 1 + - 2 + - 3 + - 4 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/resources/inputs + train_label_data_root: tests/resources/inputs + val_data_root: tests/resources/inputs + val_label_data_root: tests/resources/inputs + test_data_root: tests/resources/inputs + test_label_data_root: tests/resources/inputs + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: false + backbone: prithvi_eo_v2_300 + # backbone_pretrained_cfg_overlay: + # file: tests/prithvi_vit_300.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + backbone_patch_size: 13 + decoder_channels: 64 + num_frames: 1 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this 
block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml index 065caa02..cea8a0ea 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 3 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_band_interval.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_band_interval.yaml index a9d4145e..9f5fc50c 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_band_interval.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_band_interval.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml index 9005547b..95907310 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_string.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_string.yaml index 73813b6d..746175a2 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_string.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_string.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_L.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_L.yaml index 16729210..453a8b4a 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_L.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_L.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true @@ -31,16 +31,16 @@ data: batch_size: 2 num_workers: 4 train_transform: - #- class_path: albumentations.HorizontalFlip - # init_args: - # p: 0.5 - #- class_path: albumentations.Rotate - # init_args: - # limit: 30 - # border_mode: 0 # cv2.BORDER_CONSTANT - # value: 0 - # # mask_value: 1 - # p: 0.5 + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 - class_path: ToTensorV2 
dataset_bands: - 0 diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml b/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml index bb652415..3a696132 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true @@ -97,7 +97,7 @@ model: decoder: UperNetDecoder pretrained: false backbone: prithvi_vit_100 - #backbone_pretrained_cfg_overlay: + #backbone_pretrained_cfg_overlay: #file: tests/all_ecos_random/version_0/checkpoints/epoch=0_state_dict.ckpt #tests/prithvi_vit_100.pt backbone_drop_path_rate: 0.3 num_frames: 1 diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml b/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml index 3e44a1c5..37294615 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 2 + max_epochs: 1 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true diff --git a/tests/resources/inputs_extra/regression_test_input_0.tif b/tests/resources/inputs_extra/regression_test_input_0.tif new file mode 100644 index 00000000..aceab23e Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_input_0.tif differ diff --git a/tests/resources/inputs_extra/regression_test_input_1.tif b/tests/resources/inputs_extra/regression_test_input_1.tif new file mode 100644 index 00000000..aceab23e Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_input_1.tif differ diff --git a/tests/resources/inputs_extra/regression_test_input_2.tif b/tests/resources/inputs_extra/regression_test_input_2.tif new file mode 100644 index 00000000..aceab23e Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_input_2.tif differ diff --git a/tests/resources/inputs_extra/regression_test_input_3.tif b/tests/resources/inputs_extra/regression_test_input_3.tif new file mode 100644 index 00000000..aceab23e Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_input_3.tif differ diff --git a/tests/resources/inputs_extra/regression_test_label_0.tif b/tests/resources/inputs_extra/regression_test_label_0.tif new file mode 100644 index 00000000..521be614 Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_label_0.tif differ diff --git a/tests/resources/inputs_extra/regression_test_label_1.tif b/tests/resources/inputs_extra/regression_test_label_1.tif new file mode 100644 index 00000000..521be614 Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_label_1.tif differ diff --git a/tests/resources/inputs_extra/regression_test_label_2.tif b/tests/resources/inputs_extra/regression_test_label_2.tif new file mode 100644 index 00000000..521be614 Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_label_2.tif differ diff --git a/tests/resources/inputs_extra/regression_test_label_3.tif b/tests/resources/inputs_extra/regression_test_label_3.tif new file mode 100644 index 00000000..521be614 Binary files /dev/null and b/tests/resources/inputs_extra/regression_test_label_3.tif differ diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 
9c06e8da..badccfb1 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -46,6 +46,22 @@ def test_finetune_bands_str(model_name, case): gc.collect() +@pytest.mark.parametrize("case", ["fit", "test", "validate"]) +def test_finetune_pad(case): + command_list = [case, "-c", "tests/resources/configs/manufactured-finetune_prithvi_pixelwise_pad.yaml"] + _ = build_lightning_cli(command_list) + + gc.collect() + +@pytest.mark.parametrize("case", ["fit", "test", "validate"]) +def test_finetune_nondivisible(case): + command_list = [case, "-c", "tests/resources/configs/manufactured-finetune_prithvi_pixelwise_nondivisible.yaml"] + _ = build_lightning_cli(command_list) + + gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) def test_finetune_metrics_from_file(model_name):
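# --- Reviewer sketch, not part of the patch: the padding arithmetic exercised by the new pad /
# nondivisible test configs. Assuming 224x224 inputs (the actual fixture sizes are an assumption)
# and the backbone_patch_size: 13 from manufactured-finetune_prithvi_pixelwise_pad.yaml:
import torch
from terratorch.models.utils import pad_images

p = 13
x = torch.randn(2, 6, 224, 224)
h_pad = (p - x.shape[-2] % p) % p     # 224 % 13 == 3, so 10 rows/columns of padding are added
y = pad_images(x, p, "constant")
print(h_pad, y.shape)                 # 10 torch.Size([2, 6, 234, 234])
# 234 is divisible by 13, so the patch embedding no longer fails on the odd patch size; the
# regression output is later interpolated (and, if img_size is set, center-cropped) back to the
# original resolution.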