diff --git a/payu/experiment.py b/payu/experiment.py index a2bc89ef..826d2009 100644 --- a/payu/experiment.py +++ b/payu/experiment.py @@ -23,6 +23,7 @@ # Extensions import yaml +from packaging import version # Local from payu import envmod @@ -387,7 +388,31 @@ def build_model(self): for model in self.models: model.build_model() + def check_payu_version(self): + """Check current payu version is greater than minimum required + payu version, if configured""" + # TODO: Move this function to a setup file if setup is moved to + # a separate file? + if "payu_minimum_version" not in self.config: + # Skip version check + return + + minimum_version = str(self.config['payu_minimum_version']) + + # Get the current version of the package + current_version = payu.__version__ + + # Compare versions + if version.parse(current_version) < version.parse(minimum_version): + raise RuntimeError( + f"Payu version {current_version} does not meet the configured " + f"minimum version. A version >= {minimum_version} is " + "required to run this configuration." + ) + def setup(self, force_archive=False): + # Check version + self.check_payu_version() # Confirm that no output path already exists if os.path.exists(self.output_path): diff --git a/pyproject.toml b/pyproject.toml index e1c7582c..ded28992 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,8 @@ dependencies = [ "tenacity >=8.0.0", "cftime", "GitPython >=3.1.40", - "ruamel.yaml >=0.18.5" + "ruamel.yaml >=0.18.5", + "packaging" ] [project.optional-dependencies] diff --git a/test/test_setup.py b/test/test_setup.py index 5f7a8b8c..afb98fdf 100644 --- a/test/test_setup.py +++ b/test/test_setup.py @@ -4,6 +4,7 @@ import pdb import pytest import shutil +from unittest.mock import patch import yaml import payu @@ -136,3 +137,68 @@ def test_setup(): for i in range(1, 4): assert((workdir/'input_00{i}.bin'.format(i=i)).stat().st_size == 1000**2 + i) + + +@pytest.mark.parametrize( + "current_version, min_version", + [ + ("2.0.0", "1.0.0"), + ("v0.11.2", "v0.11.1"), + ("1.0.0", "1.0.0"), + ("1.0.0+4.gabc1234", "1.0.0"), + ("1.0.0+0.gxyz987.dirty", "1.0.0"), + ("1.1.5", 1.1) + ] +) +def test_check_payu_version_pass(current_version, min_version): + # Mock the payu version + with patch('payu.__version__', current_version): + # Avoid running Experiment init method + with patch.object(payu.experiment.Experiment, '__init__', + lambda x: None): + expt = payu.experiment.Experiment() + + # Mock config.yaml + expt.config = { + "payu_minimum_version": min_version + } + expt.check_payu_version() + + +@pytest.mark.parametrize( + "current_version, min_version", + [ + ("1.0.0", "2.0.0"), + ("v0.11", "v0.11.1"), + ("1.0.0+4.gabc1234", "1.0.1"), + ("1.0.0+0.gxyz987.dirty", "v1.2"), + ] +) +def test_check_payu_version_fail(current_version, min_version): + with patch('payu.__version__', current_version): + with patch.object(payu.experiment.Experiment, '__init__', + lambda x: None): + expt = payu.experiment.Experiment() + + expt.config = { + "payu_minimum_version": min_version + } + + with pytest.raises(RuntimeError): + expt.check_payu_version() + + +@pytest.mark.parametrize( + "current_version", ["1.0.0", "1.0.0+4.gabc1234"] +) +def test_check_payu_version_pass_with_no_minimum_version(current_version): + with patch('payu.__version__', current_version): + with patch.object(payu.experiment.Experiment, '__init__', + lambda x: None): + expt = payu.experiment.Experiment() + + # Leave version out of config.yaml + expt.config = {} + + # Check runs without an error + expt.check_payu_version()