Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read std mean from file #61

Merged
merged 7 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""
This module contains generic data modules for instantiation at runtime.
"""

import os
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any

import numpy as np
import albumentations as A
import kornia.augmentation as K
import torch
Expand All @@ -23,6 +23,20 @@ def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list

def load_from_file_or_attribute(value:list[float] | str):
Joao-L-S-Almeida marked this conversation as resolved.
Show resolved Hide resolved

if type(value) == list:
return value
elif type(str): # It can be the path for a file
if os.path.isfile(value):
try:
content = np.genfromtxt(value).tolist()
except:
raise Exception(f"File must be txt, but received {value}")
else:
raise Exception("It seems that {value} does not exist or is not a file.")
Joao-L-S-Almeida marked this conversation as resolved.
Show resolved Hide resolved

return content

# def collate_fn_list_dicts(batch):
# metadata = []
Expand Down Expand Up @@ -79,8 +93,8 @@ def __init__(
test_data_root: Path,
img_grep: str,
label_grep: str,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_label_data_root: Path | None = None,
Expand Down Expand Up @@ -198,6 +212,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
Expand Down Expand Up @@ -317,8 +334,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
predict_data_root: Path | None = None,
img_grep: str | None = "*",
label_grep: str | None = "*",
Expand Down Expand Up @@ -430,6 +447,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)
self.no_data_replace = no_data_replace
self.no_label_replace = no_label_replace
Expand Down
24 changes: 22 additions & 2 deletions terratorch/datamodules/generic_scalar_label_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ def wrap_in_compose_is_list(transform_list):
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list


def load_from_file_or_attribute(value: list[float]|str):

if isinstance(value, list):
return value
elif isinstance(value, str): # It can be the path for a file
if os.path.isfile(value):
try:
content = np.genfromtxt(value).tolist()
except:
raise Exception(f"File must be txt, but received {value}")
else:
raise Exception(f"The input {value} does not exist or is not a file.")

return content


class Normalize(Callable):
def __init__(self, means, stds):
super().__init__()
Expand Down Expand Up @@ -68,8 +84,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_split: Path | None = None,
Expand Down Expand Up @@ -166,6 +182,10 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )

means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
Expand Down
124 changes: 124 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 5
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoPixelwiseRegressionDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.Rotate
init_args:
limit: 30
border_mode: 0 # cv2.BORDER_CONSTANT
value: 0
# mask_value: 1
p: 0.5
- class_path: ToTensorV2
dataset_bands:
- [0, 11]
output_bands:
- [1, 3]
- [4, 6]
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means: tests/means.txt
stds: tests/stds.txt
no_label_replace: -1
no_data_replace: 0

model:
class_path: terratorch.tasks.PixelwiseRegressionTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
loss: rmse
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

7 changes: 7 additions & 0 deletions tests/means.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
411.4701
558.54065
815.94025
812.4403
1113.7145
1067.641

7 changes: 7 additions & 0 deletions tests/stds.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
547.36707
898.5121
1020.9082
2665.5352
2340.584
1610.1407

13 changes: 13 additions & 0 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,17 @@ def test_finetune_bands_str(model_name):
# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
_ = build_lightning_cli(command_list)

@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):

model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"]
_ = build_lightning_cli(command_list)

Loading