From 64ab95ac74195774d6c7e055d660330d2e843114 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 26 Jul 2024 14:26:00 -0300 Subject: [PATCH 1/7] Metrics as std and mean could be informed via file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 8dba1c20..f02f7d5b 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -3,11 +3,11 @@ """ This module contains generic data modules for instantiation at runtime. """ - +import os from collections.abc import Callable, Iterable from pathlib import Path from typing import Any - +import numpy as np import albumentations as A import kornia.augmentation as K import torch @@ -23,6 +23,20 @@ def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list +def load_from_file_or_attribute(value:list[float] | str): + + if type(value) == list: + return value + elif type(str): # It can be the path for a file + if os.path.isfile(value): + try: + content = np.genfromtxt(value).tolist() + except: + raise Exception(f"File must be txt, but received {value}") + else: + raise Exception("It seems that {value} does not exist or is not a file.") + + return content # def collate_fn_list_dicts(batch): # metadata = [] @@ -79,8 +93,8 @@ def __init__( test_data_root: Path, img_grep: str, label_grep: str, - means: list[float], - stds: list[float], + means: list[float] | str, + stds: list[float] | str, num_classes: int, predict_data_root: Path | None = None, train_label_data_root: Path | None = None, @@ -198,6 +212,9 @@ def __init__( # K.Normalize(means, stds), # data_keys=["image"], # ) + means = load_from_file_or_attribute(means) + stds = load_from_file_or_attribute(stds) + self.aug = Normalize(means, stds) # self.aug = Normalize(means, stds) From d0c792eb4ff93a1000ad668bc8856a6298ab8542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 26 Jul 2024 15:17:25 -0300 Subject: [PATCH 2/7] Reading metrics (std and mean) from file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 7 ++++-- .../generic_scalar_label_data_module.py | 23 +++++++++++++++++-- tests/test_finetune.py | 13 +++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index f02f7d5b..a67666f2 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -334,8 +334,8 @@ def __init__( train_data_root: Path, val_data_root: Path, test_data_root: Path, - means: list[float], - stds: list[float], + means: list[float] | str, + stds: list[float] | str, predict_data_root: Path | None = None, img_grep: str | None = "*", label_grep: str | None = "*", @@ -447,6 +447,9 @@ def __init__( # K.Normalize(means, stds), # data_keys=["image"], # ) + means = load_from_file_or_attribute(means) + stds = load_from_file_or_attribute(stds) + self.aug = Normalize(means, stds) self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace diff --git a/terratorch/datamodules/generic_scalar_label_data_module.py b/terratorch/datamodules/generic_scalar_label_data_module.py index 5cd7470a..d289afff 100644 --- a/terratorch/datamodules/generic_scalar_label_data_module.py +++ b/terratorch/datamodules/generic_scalar_label_data_module.py @@ -27,6 +27,21 @@ def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list +def load_from_file_or_attribute(value:list[float] | str): + + if type(value) == list: + return value + elif type(str): # It can be the path for a file + if os.path.isfile(value): + try: + content = np.genfromtxt(value).tolist() + except: + raise Exception(f"File must be txt, but received {value}") + else: + raise Exception("It seems that {value} does not exist or is not a file.") + + return content + class Normalize(Callable): def __init__(self, means, stds): @@ -68,8 +83,8 @@ def __init__( train_data_root: Path, val_data_root: Path, test_data_root: Path, - means: list[float], - stds: list[float], + means: list[float] | str, + stds: list[float] | str, num_classes: int, predict_data_root: Path | None = None, train_split: Path | None = None, @@ -166,6 +181,10 @@ def __init__( # K.Normalize(means, stds), # data_keys=["image"], # ) + + means = load_from_file_or_attribute(means) + stds = load_from_file_or_attribute(stds) + self.aug = Normalize(means, stds) # self.aug = Normalize(means, stds) diff --git a/tests/test_finetune.py b/tests/test_finetune.py index cd4f5356..8592e639 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -48,4 +48,17 @@ def test_finetune_bands_str(model_name): # Running the terratorch CLI command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"] _ = build_lightning_cli(command_list) + +@pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) +def test_finetune_bands_str(model_name): + + model_instance = timm.create_model(model_name) + + state_dict = model_instance.state_dict() + + torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) + + # Running the terratorch CLI + command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"] + _ = build_lightning_cli(command_list) From ad2e445632660eed3816003a5df23bbebfa6d701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 26 Jul 2024 15:18:30 -0300 Subject: [PATCH 3/7] Auxiliary files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- ...tune_prithvi_swin_B_metrics_from_file.yaml | 124 ++++++++++++++++++ tests/means.txt | 7 + tests/stds.txt | 7 + 3 files changed, 138 insertions(+) create mode 100644 tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml create mode 100644 tests/means.txt create mode 100644 tests/stds.txt diff --git a/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml new file mode 100644 index 00000000..91a72a3c --- /dev/null +++ b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml @@ -0,0 +1,124 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - [0, 11] + output_bands: + - [1, 3] + - [4, 6] + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/ + train_label_data_root: tests/ + val_data_root: tests/ + val_label_data_root: tests/ + test_data_root: tests/ + test_label_data_root: tests/ + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: tests/means.txt + stds: tests/stds.txt + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: true + backbone: prithvi_swin_B + backbone_pretrained_cfg_overlay: + file: tests/prithvi_swin_B.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/means.txt b/tests/means.txt new file mode 100644 index 00000000..56900a7f --- /dev/null +++ b/tests/means.txt @@ -0,0 +1,7 @@ +411.4701 +558.54065 +815.94025 +812.4403 +1113.7145 +1067.641 + diff --git a/tests/stds.txt b/tests/stds.txt new file mode 100644 index 00000000..ad602006 --- /dev/null +++ b/tests/stds.txt @@ -0,0 +1,7 @@ +547.36707 +898.5121 +1020.9082 +2665.5352 +2340.584 +1610.1407 + From 70f3707eb90a56230de35fa21a25b869aecd4643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 29 Jul 2024 13:51:59 -0300 Subject: [PATCH 4/7] Minor issues solved MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datamodules/generic_scalar_label_data_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terratorch/datamodules/generic_scalar_label_data_module.py b/terratorch/datamodules/generic_scalar_label_data_module.py index d289afff..27ff5790 100644 --- a/terratorch/datamodules/generic_scalar_label_data_module.py +++ b/terratorch/datamodules/generic_scalar_label_data_module.py @@ -29,16 +29,16 @@ def wrap_in_compose_is_list(transform_list): def load_from_file_or_attribute(value:list[float] | str): - if type(value) == list: + if isinstance(value, list): return value - elif type(str): # It can be the path for a file + elif isinstance(value, str): # It can be the path for a file if os.path.isfile(value): try: content = np.genfromtxt(value).tolist() except: raise Exception(f"File must be txt, but received {value}") else: - raise Exception("It seems that {value} does not exist or is not a file.") + raise Exception(f"The input {value} does not exist or is not a file.") return content From 9f840ea4d44ef597d1a5002323dfbdde46f56bee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 29 Jul 2024 13:54:54 -0300 Subject: [PATCH 5/7] space removed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datamodules/generic_scalar_label_data_module.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/terratorch/datamodules/generic_scalar_label_data_module.py b/terratorch/datamodules/generic_scalar_label_data_module.py index 27ff5790..c32fbd7a 100644 --- a/terratorch/datamodules/generic_scalar_label_data_module.py +++ b/terratorch/datamodules/generic_scalar_label_data_module.py @@ -27,11 +27,12 @@ def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list -def load_from_file_or_attribute(value:list[float] | str): + +def load_from_file_or_attribute(value: list[float]|str): if isinstance(value, list): return value - elif isinstance(value, str): # It can be the path for a file + elif isinstance(value, str): # It can be the path for a file if os.path.isfile(value): try: content = np.genfromtxt(value).tolist() @@ -40,7 +41,7 @@ def load_from_file_or_attribute(value:list[float] | str): else: raise Exception(f"The input {value} does not exist or is not a file.") - return content + return content class Normalize(Callable): From 3aefe1ba7535711b7531979fc2f4877d23a1a6f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 30 Jul 2024 09:49:36 -0300 Subject: [PATCH 6/7] This function must be shared MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 16 +--------------- terratorch/io/file.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index a67666f2..434f7488 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -17,26 +17,12 @@ from torchgeo.transforms import AugmentationSequential from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands - +from terratorch.io.file import load_from_file_or_attribute def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list -def load_from_file_or_attribute(value:list[float] | str): - - if type(value) == list: - return value - elif type(str): # It can be the path for a file - if os.path.isfile(value): - try: - content = np.genfromtxt(value).tolist() - except: - raise Exception(f"File must be txt, but received {value}") - else: - raise Exception("It seems that {value} does not exist or is not a file.") - - return content # def collate_fn_list_dicts(batch): # metadata = [] diff --git a/terratorch/io/file.py b/terratorch/io/file.py index bb8ef9fc..6cab0acd 100644 --- a/terratorch/io/file.py +++ b/terratorch/io/file.py @@ -1,6 +1,7 @@ import os import importlib from torch import nn +import numpy as np def open_generic_torch_model(model: type | str = None, model_kwargs: dict = None, @@ -51,3 +52,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N ) return model + +def load_from_file_or_attribute(value: list[float]|str): + + if isinstance(value, list): + return value + elif isinstance(value, str): # It can be the path for a file + if os.path.isfile(value): + try: + print(value) + content = np.genfromtxt(value).tolist() + except: + raise Exception(f"File must be txt, but received {value}") + else: + raise Exception(f"The input {value} does not exist or is not a file.") + + return content + + From fae25318370378306202882360b02f7785791f62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 30 Jul 2024 11:16:24 -0300 Subject: [PATCH 7/7] This function should not be here MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_scalar_label_data_module.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/terratorch/datamodules/generic_scalar_label_data_module.py b/terratorch/datamodules/generic_scalar_label_data_module.py index c32fbd7a..71f75533 100644 --- a/terratorch/datamodules/generic_scalar_label_data_module.py +++ b/terratorch/datamodules/generic_scalar_label_data_module.py @@ -22,28 +22,12 @@ HLSBands, ) +from terratorch.io.file import load_from_file_or_attribute def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list - -def load_from_file_or_attribute(value: list[float]|str): - - if isinstance(value, list): - return value - elif isinstance(value, str): # It can be the path for a file - if os.path.isfile(value): - try: - content = np.genfromtxt(value).tolist() - except: - raise Exception(f"File must be txt, but received {value}") - else: - raise Exception(f"The input {value} does not exist or is not a file.") - - return content - - class Normalize(Callable): def __init__(self, means, stds): super().__init__()