diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 00000000..96812e26 --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,24 @@ +# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file +# mostly from https://github.com/microsoft/torchgeo/blob/main/.github/dependabot.yml +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "daily" + groups: + # torchvision pins torch, must update in unison + torch: + patterns: + - "torch" + - "torchvision" + ignore: + # setuptools releases new versions almost daily + - dependency-name: "setuptools" + update-types: ["version-update:semver-patch"] + # segmentation-models-pytorch pins timm, must update in unison + - dependency-name: "timm" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 505db965..d18c6bef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ dependencies = [ "h5py>=3.10.0", "geobench>=1.0.0", "mlflow>=2.12.1", - "lightning<=2.2.5" + # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 + "lightning>=2, <=2.2.5" ] [project.optional-dependencies] diff --git a/requirements/required.txt b/requirements/required.txt index 80114f18..06dd4ef4 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -10,3 +10,4 @@ lightly==1.4.25 h5py==3.10.0 geobench==1.0.0 mlflow==2.12.1 +lightning==2.2.5 \ No newline at end of file diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 4812d5f3..bb48b94e 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -6,8 +6,7 @@ import subprocess import os -from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn as checkpoint_filter_fn_vit -from terratorch.models.backbones.prithvi_swin import checkpoint_filter_fn as checkpoint_filter_fn_swin +from terratorch.cli_tools import build_lightning_cli @pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) def test_finetune_multiple_backbones(model_name): @@ -21,11 +20,24 @@ def test_finetune_multiple_backbones(model_name): torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) # Running the terratorch CLI - command_str = f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml" + command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"] + _ = build_lightning_cli(command_list) - command_out = subprocess.run(command_str, shell=True) +""" +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) +def test_finetune_multiple_backbones(model_name): - assert not command_out.returncode + model_instance = timm.create_model(model_name) + pretrained_bands = [0, 1, 2, 3, 4, 5] + model_bands = [0, 1, 2, 3, 4, 5] + state_dict = model_instance.state_dict() + + torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) + # Running the terratorch CLI + command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml" + command_out = subprocess.run(command_str, shell=True) + assert not command_out.returncode + """