Skip to content

Commit

Permalink
ensure tests run on cpu, save models to tmp path
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Gomes <[email protected]>
  • Loading branch information
CarlosGomes98 committed Aug 2, 2024
1 parent f8260cf commit 86151ec
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion tests/manufactured-finetune_prithvi_swin_B.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/manufactured-finetune_prithvi_swin_B_string.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/manufactured-finetune_prithvi_swin_L.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
3 changes: 1 addition & 2 deletions tests/manufactured-finetune_prithvi_vit_100.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down Expand Up @@ -111,7 +111,6 @@ model:
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
Expand Down
3 changes: 1 addition & 2 deletions tests/manufactured-finetune_prithvi_vit_300.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down Expand Up @@ -111,7 +111,6 @@ model:
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
Expand Down
26 changes: 14 additions & 12 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,64 @@
import importlib
import os
import subprocess

import pytest
import timm
import torch

import terratorch
from terratorch.cli_tools import build_lightning_cli


@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name, tmp_path):
    """Smoke-test fine-tuning each supported backbone via the terratorch CLI.

    Creates the backbone with timm, saves its state dict to a temporary
    directory, then runs a `fit` with the manufactured config for that model.
    """
    model_instance = timm.create_model(model_name)

    state_dict = model_instance.state_dict()

    # Save to tmp_path so test runs do not litter the repo working tree.
    # NOTE(review): the manufactured yaml must point at this checkpoint path —
    # confirm the configs were updated to read from tmp_path as well.
    torch.save(state_dict, os.path.join(tmp_path, model_name + ".pt"))

    # Running the terratorch CLI
    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
    _ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_intervals(model_name, tmp_path):
    """Fine-tune with a config that specifies bands as intervals."""
    model_instance = timm.create_model(model_name)

    state_dict = model_instance.state_dict()

    torch.save(state_dict, os.path.join(tmp_path, model_name + ".pt"))

    # Running the terratorch CLI
    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
    _ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name, tmp_path):
    """Fine-tune with a config that specifies bands as strings."""
    model_instance = timm.create_model(model_name)

    state_dict = model_instance.state_dict()

    torch.save(state_dict, os.path.join(tmp_path, model_name + ".pt"))

    # Running the terratorch CLI
    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
    _ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_metrics_from_file(model_name, tmp_path):
    """Fine-tune with a config whose metrics are loaded from a file.

    BUG FIX: this test was previously named `test_finetune_bands_str`,
    duplicating the test above — pytest only collects the last definition
    of a name, so one of the two tests silently never ran.
    """
    model_instance = timm.create_model(model_name)

    state_dict = model_instance.state_dict()

    torch.save(state_dict, os.path.join(tmp_path, model_name + ".pt"))

    # Running the terratorch CLI
    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"]
Expand Down

0 comments on commit 86151ec

Please sign in to comment.