Skip to content

Commit

Permalink
Merge branch 'IBM:main' into reduce_lr
Browse files Browse the repository at this point in the history
  • Loading branch information
fmartiescofet authored Dec 18, 2024
2 parents 2fb9383 + 30dfdf1 commit 5858feb
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 6 deletions.
5 changes: 4 additions & 1 deletion contribution_process.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ If you want to contribute to this project, there are many valuable ways in doing
1. Use / test TerraTorch and create an [Issue](https://github.com/IBM/terratorch/issues) if something is not working properly or if you have an idea for a feature request.
1. Pick an [Issue](https://github.com/IBM/terratorch/issues) and start contributing

Contributions are welcome as pull requests on a [fork](https://github.com/IBM/terratorch/fork) of this project. Ideally, pull requests are backed by an [Issue](https://github.com/IBM/terratorch/issues). You can also tag the [code owners](https://github.com/IBM/terratorch/blob/main/CODEOWNERS) in the issue before you start, so we can talk about the details (in case you can't join one of the community calls).
Contributions are welcome as pull requests on a [fork](https://github.com/IBM/terratorch/fork) of this project. Ideally, pull requests are backed by an [Issue](https://github.com/IBM/terratorch/issues). You can also tag the [code owners](https://github.com/IBM/terratorch/blob/main/CODEOWNERS) in the issue before you start, so we can talk about the details (in case you can't join one of the community calls).

After or during implementation on your branch, please create a PR to main. During development, please mark this PR as DRAFT and prefix with '[WIP]'
If you want us to merge the PR, remove 'draft' and '[WIP]'. Before that, please make sure that all tests are passing. Unit tests are automatically run on GitHub on the branch as well. The TerraTorch committers will review your code and will run integrations tests on our GPU cluster before we merge to main.
150 changes: 150 additions & 0 deletions tests/resources/configs/manufactured-finetune_prithvi_eo_v2_300.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 2
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoPixelwiseRegressionDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
#- class_path: albumentations.HorizontalFlip
# init_args:
# p: 0.5
#- class_path: albumentations.Rotate
# init_args:
# limit: 30
# border_mode: 0 # cv2.BORDER_CONSTANT
# value: 0
# # mask_value: 1
# p: 0.5
- class_path: ToTensorV2
dataset_bands:
- 0
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
- 1
- 2
- 3
- 4
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
val_label_data_root: tests/resources/inputs
test_data_root: tests/resources/inputs
test_label_data_root: tests/resources/inputs
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means:
- 547.36707
- 898.5121
- 1020.9082
- 2665.5352
- 2340.584
- 1610.1407
stds:
- 411.4701
- 558.54065
- 815.94025
- 812.4403
- 1113.7145
- 1067.641
no_label_replace: -1
no_data_replace: 0

model:
class_path: terratorch.tasks.PixelwiseRegressionTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: false
backbone: prithvi_eo_v2_300
# backbone_pretrained_cfg_overlay:
# file: tests/prithvi_vit_300.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 64
num_frames: 1
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
loss: rmse
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

150 changes: 150 additions & 0 deletions tests/resources/configs/manufactured-finetune_prithvi_eo_v2_600.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 2
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoPixelwiseRegressionDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
#- class_path: albumentations.HorizontalFlip
# init_args:
# p: 0.5
#- class_path: albumentations.Rotate
# init_args:
# limit: 30
# border_mode: 0 # cv2.BORDER_CONSTANT
# value: 0
# # mask_value: 1
# p: 0.5
- class_path: ToTensorV2
dataset_bands:
- 0
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
- 1
- 2
- 3
- 4
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
val_label_data_root: tests/resources/inputs
test_data_root: tests/resources/inputs
test_label_data_root: tests/resources/inputs
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means:
- 547.36707
- 898.5121
- 1020.9082
- 2665.5352
- 2340.584
- 1610.1407
stds:
- 411.4701
- 558.54065
- 815.94025
- 812.4403
- 1113.7145
- 1067.641
no_label_replace: -1
no_data_replace: 0

model:
class_path: terratorch.tasks.PixelwiseRegressionTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: false
backbone: prithvi_eo_v2_600
# backbone_pretrained_cfg_overlay:
# file: tests/prithvi_vit_300.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 64
num_frames: 1
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
loss: rmse
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ model:
model_args:
decoder: UperNetDecoder
pretrained: false
backbone: prithvi_vit_300
backbone: prithvi_eo_v2_300
# backbone_pretrained_cfg_overlay:
# file: tests/prithvi_vit_300.pt
backbone_drop_path_rate: 0.3
Expand Down
9 changes: 6 additions & 3 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,23 @@ def input_386():
return torch.ones((1, NUM_CHANNELS, 386, 386))


@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])

@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("prefix", ["", "timm_"])
def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix):
backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False)
Expand All @@ -62,12 +63,14 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
backbone(input_224_multitemporal)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
def test_vit_models_non_divisible_input(model_name, input_non_divisible):
#padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none'
backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant')
backbone(input_non_divisible)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
@pytest.mark.parametrize("patch_size", [8, 16])
@pytest.mark.parametrize("patch_size_time", [1, 2, 4])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setup_and_cleanup(model_name):
if os.path.isdir(os.path.join("tests", "all_ecos_random")):
shutil.rmtree(os.path.join("tests", "all_ecos_random"))

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_vit_100"])
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_eo_v2_600"])
@pytest.mark.parametrize("case", ["fit", "test", "validate"])
def test_finetune_multiple_backbones(model_name, case):
command_list = [case, "-c", f"tests/resources/configs/manufactured-finetune_{model_name}.yaml"]
Expand Down

0 comments on commit 5858feb

Please sign in to comment.