Skip to content

Commit

Permalink
Merge pull request #33 from IBM/finetune_tests
Browse files Browse the repository at this point in the history
Using the Lightning CLI interface to run tests
  • Loading branch information
CarlosGomes98 authored Jul 4, 2024
2 parents 1ebeeb5 + 11974fe commit 308d540
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
24 changes: 24 additions & 0 deletions .github/dependabot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
# mostly from https://github.com/microsoft/torchgeo/blob/main/.github/dependabot.yml
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "daily"
groups:
# torchvision pins torch, must update in unison
torch:
patterns:
- "torch"
- "torchvision"
ignore:
# setuptools releases new versions almost daily
- dependency-name: "setuptools"
update-types: ["version-update:semver-patch"]
# segmentation-models-pytorch pins timm, must update in unison
- dependency-name: "timm"
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ dependencies = [
"h5py>=3.10.0",
"geobench>=1.0.0",
"mlflow>=2.12.1",
"lightning<=2.2.5"
# broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977
"lightning>=2, <=2.2.5"
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ lightly==1.4.25
h5py==3.10.0
geobench==1.0.0
mlflow==2.12.1
lightning==2.2.5
22 changes: 17 additions & 5 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import subprocess
import os

from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn as checkpoint_filter_fn_vit
from terratorch.models.backbones.prithvi_swin import checkpoint_filter_fn as checkpoint_filter_fn_swin
from terratorch.cli_tools import build_lightning_cli

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):
Expand All @@ -21,11 +20,24 @@ def test_finetune_multiple_backbones(model_name):
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_str = f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml"
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
_ = build_lightning_cli(command_list)

command_out = subprocess.run(command_str, shell=True)
"""
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):
assert not command_out.returncode
model_instance = timm.create_model(model_name)
pretrained_bands = [0, 1, 2, 3, 4, 5]
model_bands = [0, 1, 2, 3, 4, 5]
state_dict = model_instance.state_dict()
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
# Running the terratorch CLI
command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml"
command_out = subprocess.run(command_str, shell=True)
assert not command_out.returncode
"""

0 comments on commit 308d540

Please sign in to comment.