Adding padding at the input when necessary #342

Open · wants to merge 28 commits into base: main
Commits (28)
842941b
Adding padding at the input when necessary
Joao-L-S-Almeida Dec 23, 2024
6528861
patch_size as a explicit argument for PixelWiseModel
Joao-L-S-Almeida Jan 2, 2025
8376a5e
logging
Joao-L-S-Almeida Jan 2, 2025
8fa3bba
Cropping image
Joao-L-S-Almeida Jan 2, 2025
6fb8c95
cropping image for scaler model
Joao-L-S-Almeida Jan 2, 2025
5f37ba7
patch_size could be None
Joao-L-S-Almeida Jan 3, 2025
5cb27dc
Adapting the Clay factory to support patch_size and minor adjusts
Joao-L-S-Almeida Jan 3, 2025
ba43134
Trying to reduce the cost of these tests
Joao-L-S-Almeida Jan 3, 2025
9c26eab
pad_images must be in utils.py
Joao-L-S-Almeida Jan 3, 2025
c4fd736
Cropping images could be a necessary operation
Joao-L-S-Almeida Jan 6, 2025
b70a368
The cropping must be placed before the head in case of scalar models
Joao-L-S-Almeida Jan 6, 2025
ecca3aa
Creating extra images for tests
Joao-L-S-Almeida Jan 6, 2025
e09cd79
Minor changes
Joao-L-S-Almeida Jan 6, 2025
6fdf1b7
img_size also could be necessary
Joao-L-S-Almeida Jan 6, 2025
6178b47
conditional cropping
Joao-L-S-Almeida Jan 6, 2025
8cb6d26
config for testing nondivisible images
Joao-L-S-Almeida Jan 6, 2025
cb62e56
minor adjusts
Joao-L-S-Almeida Jan 6, 2025
a0cac1c
minor adjusts
Joao-L-S-Almeida Jan 6, 2025
ca881b4
Input files to be used for testing the padding for non-divisible images
Joao-L-S-Almeida Jan 6, 2025
62fa305
minor changes
Joao-L-S-Almeida Jan 6, 2025
1c409e8
more tests
Joao-L-S-Almeida Jan 6, 2025
0d79f8e
merging
Joao-L-S-Almeida Jan 6, 2025
42c3d98
merging
Joao-L-S-Almeida Jan 6, 2025
fd1599f
merging
Joao-L-S-Almeida Jan 6, 2025
41af8f7
merging
Joao-L-S-Almeida Jan 6, 2025
dde31bb
merging
Joao-L-S-Almeida Jan 6, 2025
e04e53e
argument not used
Joao-L-S-Almeida Jan 6, 2025
8cb2ff9
merging
Joao-L-S-Almeida Jan 8, 2025
52 changes: 52 additions & 0 deletions examples/scripts/create_images.py
@@ -0,0 +1,52 @@
from argparse import ArgumentParser
import os
import random

import numpy as np
import tifffile as tiff
from osgeo import gdal, osr

parser = ArgumentParser()
parser.add_argument("--input_file")
parser.add_argument("--output_dir")
parser.add_argument("--n_copies", type=int, default=2)

args = parser.parse_args()
input_file = args.input_file
output_dir = args.output_dir
n_copies = args.n_copies

# Maximum padding (in pixels) added around each copy
pad_limit = 4

# GDAL config
GDAL_DATA_TYPE = gdal.GDT_Int32
GEOTIFF_DRIVER_NAME = "GTiff"
NO_DATA = 15
SPATIAL_REFERENCE_SYSTEM_WKID = 4326

for c in range(n_copies):
    pad = random.randint(1, pad_limit)
    filename = os.path.split(input_file)[-1]
    output_file = os.path.join(output_dir, filename.replace(".tif", f"_{c}.tif"))

    # Embed the original image in a zero-filled canvas `pad` pixels larger on
    # each side, so the copy's H and W are no longer divisible by the patch size.
    imarray = tiff.imread(input_file)
    im_shape = imarray.shape
    im_shape_ext = tuple(i + 2 * pad for i in im_shape[:-1]) + (im_shape[-1],)
    output = np.zeros(im_shape_ext)
    output[pad:-pad, pad:-pad, :] = imarray

    # Write the padded array out as a GeoTIFF.
    driver = gdal.GetDriverByName(GEOTIFF_DRIVER_NAME)
    output_raster = driver.Create(output_file,
                                  output.shape[1],
                                  output.shape[0],
                                  output.shape[-1],
                                  eType=GDAL_DATA_TYPE)
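The hunk as captured ends at driver.Create, leaving the raster unwritten. A plausible completion, assuming the usual GDAL band-write pattern implied by the NO_DATA and SPATIAL_REFERENCE_SYSTEM_WKID constants above (hypothetical, not recovered from the PR):

    # Hypothetical completion (not in the captured diff): stamp the CRS and
    # write each channel of the padded array into its own band.
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(SPATIAL_REFERENCE_SYSTEM_WKID)
    output_raster.SetProjection(srs.ExportToWkt())
    for b in range(output.shape[-1]):
        band = output_raster.GetRasterBand(b + 1)
        band.SetNoDataValue(NO_DATA)
        band.WriteArray(output[:, :, b])
    output_raster.FlushCache()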

15 changes: 1 addition & 14 deletions terratorch/models/backbones/prithvi_vit.py
@@ -10,6 +10,7 @@
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.datasets.utils import generate_bands_intervals
from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE
from terratorch.models.utils import pad_images

logger = logging.getLogger(__name__)

@@ -153,20 +154,6 @@ def checkpoint_filter_fn_mae(

return state_dict


def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor:
p = patch_size
# h, w = imgs.shape[3], imgs.shape[4]
t, h, w = imgs.shape[-3:]
h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Ensure padding is within bounds
if h_pad > 0 or w_pad > 0:
imgs = torch.stack([
nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding)
for img in imgs # Apply per image to avoid NotImplementedError from torch.nn.functional.pad
])
return imgs


def _create_prithvi(
variant: str,
pretrained: bool = False, # noqa: FBT001, FBT002
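For reference, the relocated helper pads only the bottom/right of the spatial dimensions up to the next multiple of the patch size. A minimal sketch of its effect, assuming the move to terratorch.models.utils preserves the signature shown in the removed hunk:

import torch

from terratorch.models.utils import pad_images

imgs = torch.randn(2, 6, 1, 50, 50)  # (B, C, T, H, W); 50 is not divisible by 16
padded = pad_images(imgs, 16, "constant")
print(padded.shape)  # torch.Size([2, 6, 1, 64, 64]): 50 is padded up to 4 * 16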
6 changes: 3 additions & 3 deletions terratorch/models/backbones/select_patch_embed_weights.py
@@ -1,6 +1,5 @@
# Copyright contributors to the Terratorch project


import logging
import warnings

@@ -13,7 +12,8 @@
def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor) -> bool:
# check all dimensions are the same except for channel dimension
if len(model_patch_embed.shape) != len(checkpoint_patch_embed.shape):
return False
return False

model_shape = [model_patch_embed.shape[i] for i in range(len(model_patch_embed.shape)) if i != 1]
checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
return model_shape == checkpoint_shape
@@ -82,5 +82,5 @@ def select_patch_embed_weights(
)

state_dict[patch_embed_proj_weight_key] = temp_weight

return state_dict
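The indentation fix above does not change the logic: compatibility ignores only the channel dimension (dim 1). A quick sketch, assuming the function stays importable from this module:

import torch

from terratorch.models.backbones.select_patch_embed_weights import patch_embed_weights_are_compatible

model_w = torch.zeros(1024, 6, 16, 16)   # (embed_dim, in_chans, pH, pW)
ckpt_w = torch.zeros(1024, 3, 16, 16)    # differs only in in_chans -> compatible
assert patch_embed_weights_are_compatible(model_w, ckpt_w)

ckpt_bad = torch.zeros(1024, 3, 14, 14)  # patch size differs -> incompatible
assert not patch_embed_weights_are_compatible(model_w, ckpt_bad)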
31 changes: 30 additions & 1 deletion terratorch/models/clay_model_factory.py
@@ -1,6 +1,7 @@
import importlib
import sys
from collections.abc import Callable
import logging

import timm
import torch
@@ -122,6 +123,26 @@ def build_model(

backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")

# Get parameters needed for input padding and output cropping.
if "patch_size" in backbone_kwargs:
    patch_size = backbone_kwargs["patch_size"]
else:
    # The model may still work when the config and the input image sizes
    # are already consistent, but without an explicit patch_size any shape
    # mismatch found during execution cannot be corrected.
    patch_size = None

if "img_size" in backbone_kwargs:
    img_size = backbone_kwargs["img_size"]
else:
    # Likewise, img_size is needed to crop outputs back to the expected
    # size whenever padding was applied; without it no corrective
    # cropping can be performed.
    img_size = None

# Trying to find the model on HuggingFace.
try:
backbone: nn.Module = timm.create_model(
@@ -157,7 +178,7 @@
head_kwargs["num_classes"] = num_classes
if aux_decoders is None:
return _build_appropriate_model(
task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
task, backbone, decoder, head_kwargs, prepare_features_for_image_model, patch_size=patch_size, img_size=img_size, rescale=rescale
)

to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
@@ -186,6 +207,8 @@
decoder,
head_kwargs,
prepare_features_for_image_model,
patch_size=patch_size,
img_size=img_size,
rescale=rescale,
auxiliary_heads=to_be_aux_decoders,
)
@@ -197,6 +220,8 @@ def _build_appropriate_model(
decoder: nn.Module,
head_kwargs: dict,
prepare_features_for_image_model: Callable,
patch_size: int | None = None,
img_size: int | None = None,
rescale: bool = True, # noqa: FBT001, FBT002
auxiliary_heads: dict | None = None,
):
@@ -206,6 +231,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
img_size=img_size,
rescale=rescale,
auxiliary_heads=auxiliary_heads,
)
@@ -215,6 +242,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
img_size=img_size,
auxiliary_heads=auxiliary_heads,
)

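The patch_size/img_size lookup above relies on extract_prefix_keys having already stripped the backbone_ prefix. A small sketch of that mechanic (the import path is an assumption, not confirmed by the diff):

from terratorch.models.model import extract_prefix_keys  # assumed location

kwargs = {"backbone_patch_size": 8, "backbone_img_size": 224, "decoder_channels": 256}
backbone_kwargs, rest = extract_prefix_keys(kwargs, "backbone_")
# backbone_kwargs == {"patch_size": 8, "img_size": 224}
# rest == {"decoder_channels": 256}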
34 changes: 33 additions & 1 deletion terratorch/models/encoder_decoder_factory.py
@@ -2,7 +2,7 @@


import warnings

import logging
from torch import nn

from terratorch.models.model import (
@@ -65,6 +65,8 @@ def _check_all_args_used(kwargs):
msg = f"arguments {kwargs} were passed but not used."
raise ValueError(msg)

def _get_argument_from_instance(model, name):
return getattr(model._timm_module.patch_embed, name)[-1]

@MODEL_FACTORY_REGISTRY.register
class EncoderDecoderFactory(ModelFactory):
@@ -128,6 +130,26 @@ def build_model(
backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
backbone = _get_backbone(backbone, **backbone_kwargs)

# Get parameters needed for input padding and output cropping.
if "patch_size" in backbone_kwargs:
    patch_size = backbone_kwargs["patch_size"]
else:
    # The model may still work when the config and the input image sizes
    # are already consistent, but without an explicit patch_size any shape
    # mismatch found during execution cannot be corrected.
    patch_size = None

if "img_size" in backbone_kwargs:
    img_size = backbone_kwargs["img_size"]
else:
    # Likewise, img_size is needed to crop outputs back to the expected
    # size whenever padding was applied; without it no corrective
    # cropping can be performed.
    img_size = None

if peft_config is not None:
if not backbone_kwargs.get("pretrained", False):
msg = (
@@ -166,6 +188,8 @@
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
img_size=img_size,
necks=neck_list,
decoder_includes_head=decoder_includes_head,
rescale=rescale,
@@ -191,6 +215,8 @@
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
img_size=img_size,
necks=neck_list,
decoder_includes_head=decoder_includes_head,
rescale=rescale,
@@ -203,6 +229,8 @@ def _build_appropriate_model(
backbone: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
patch_size: int | None,
img_size: int | None,
decoder_includes_head: bool = False,
necks: list[Neck] | None = None,
rescale: bool = True, # noqa: FBT001, FBT002
@@ -218,6 +246,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
img_size=img_size,
decoder_includes_head=decoder_includes_head,
neck=neck_module,
rescale=rescale,
@@ -229,6 +259,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
img_size=img_size,
decoder_includes_head=decoder_includes_head,
neck=neck_module,
auxiliary_heads=auxiliary_heads,
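From the user's side, the new parameters travel through the factory as prefixed backbone kwargs. A hedged usage sketch (the backbone and decoder names are illustrative, and the exact set of required kwargs may differ):

from terratorch.models import EncoderDecoderFactory  # assumed public import path

factory = EncoderDecoderFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_300",  # illustrative backbone name
    decoder="FCNDecoder",          # illustrative decoder name
    backbone_patch_size=16,        # forwarded as patch_size for input padding
    backbone_img_size=224,         # forwarded as img_size for output cropping
    num_classes=2,
)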
45 changes: 39 additions & 6 deletions terratorch/models/pixel_wise_model.py
@@ -1,13 +1,14 @@
# Copyright contributors to the Terratorch project

import logging
import torch
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from segmentation_models_pytorch.base import SegmentationModel
from torch import nn

from terratorch.models.heads import RegressionHead, SegmentationHead
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput

from terratorch.models.utils import pad_images

def freeze_module(module: nn.Module):
for param in module.parameters():
@@ -26,6 +27,8 @@ def __init__(
encoder: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
patch_size: int | None = None,
img_size: int | None = None,
decoder_includes_head: bool = False,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
@@ -69,6 +72,8 @@ def __init__(

self.neck = neck
self.rescale = rescale
self.patch_size = patch_size
self.img_size = (img_size, img_size)

def freeze_encoder(self):
freeze_module(self.encoder)
@@ -77,9 +82,31 @@ def freeze_decoder(self):
freeze_module(self.decoder)
freeze_module(self.head)

# TODO: do this properly
def check_input_shape(self, x: torch.Tensor) -> bool: # noqa: ARG002
return True
def check_input_shape(self, x: torch.Tensor) -> torch.Tensor:

    if self.patch_size:
        # Only the spatial dimensions (H, W) must be divisible by the
        # patch size; pad_images pads just those two.
        if all(i % self.patch_size == 0 for i in x.shape[-2:]):
            return x
        return pad_images(x, self.patch_size, "constant")
    else:
        # If patch_size is not provided, the user must guarantee that the
        # dataset is properly configured for the model being used.
        return x

def _crop_image_when_necessary(self, x: torch.Tensor, size: tuple) -> torch.Tensor:

    if all(self.img_size):
        # Crop the (possibly padded) output back to the configured size.
        return transforms.CenterCrop(self.img_size)(x)
    else:
        logging.getLogger("terratorch").info(
            "Cropping may be needed to adjust the output size; define `img_size` "
            "in your config file if you get a shape mismatch."
        )
        return x

@staticmethod
def _check_for_single_channel_and_squeeze(x):
Expand All @@ -89,7 +116,7 @@ def _check_for_single_channel_and_squeeze(x):

def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
"""Sequentially pass `x` through model`s encoder, decoder and heads"""
self.check_input_shape(x)

if isinstance(x, torch.Tensor):
input_size = x.shape[-2:]
elif "image_size" in kwargs:
@@ -99,6 +126,9 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
input_size = list(x.values())[0].shape[-2:]
else:
raise ValueError('Could not infer input shape.')

# TODO make this verification optional to avoid unnecessary repetition
x = self.check_input_shape(x)
features = self.encoder(x, **kwargs)

## only for backwards compatibility with pre-neck times.
@@ -114,13 +144,16 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
if self.rescale and mask.shape[-2:] != input_size:
mask = F.interpolate(mask, size=input_size, mode="bilinear")
mask = self._check_for_single_channel_and_squeeze(mask)

aux_outputs = {}
for name, decoder in self.aux_heads.items():
aux_output = decoder([f.clone() for f in features])
if self.rescale and aux_output.shape[-2:] != input_size:
aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
aux_output = self._check_for_single_channel_and_squeeze(aux_output)
aux_outputs[name] = aux_output

mask = self._crop_image_when_necessary(mask, input_size)
return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
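Putting the two hooks together, PixelWiseModel now pads non-divisible inputs before the encoder and center-crops the decoded mask afterwards. A standalone sketch of the shape arithmetic (plain torch, with an identity stand-in for the network):

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

patch_size, img_size = 16, 224
x = torch.randn(1, 6, 230, 230)      # 230 is not divisible by 16

# check_input_shape: pad bottom/right to the next multiple of patch_size.
h, w = x.shape[-2:]
h_pad = (patch_size - h % patch_size) % patch_size
w_pad = (patch_size - w % patch_size) % patch_size
x_padded = F.pad(x, (0, w_pad, 0, h_pad), mode="constant")
print(x_padded.shape)                # torch.Size([1, 6, 240, 240])

mask = x_padded                      # identity stand-in for encoder/decoder/head

# _crop_image_when_necessary: center-crop back to the configured img_size.
mask = transforms.CenterCrop((img_size, img_size))(mask)
print(mask.shape)                    # torch.Size([1, 6, 224, 224])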

def _get_head(self, task: str, input_embed_dim: int, head_kwargs):