diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index 4bf7f455..918e1ba7 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -9,7 +9,7 @@ from typing import Optional, Iterable, Any, Union from openff.units import Quantity -from ..settings import Settings +from ..settings import Settings, SettingsBaseModel from ..tokenization import GufeTokenizable, GufeKey from ..chemicalsystem import ChemicalSystem from ..mapping import ComponentMapping @@ -89,13 +89,14 @@ def __init__(self, settings: Settings): Parameters ---------- settings : Settings - The full settings for this ``Protocol`` instance. + The full settings for this ``Protocol`` instance. Must be passed an instance of Settings or a + subclass which is specialised for a particular Protocol """ - self._settings = settings + self._settings = settings.frozen_copy() @property def settings(self) -> Settings: - """The full settings for this ``Protocol`` instance.""" + """A read-only view of the settings for this ``Protocol`` instance.""" return self._settings @classmethod diff --git a/gufe/settings/models.py b/gufe/settings/models.py index fcc38356..1ba97110 100644 --- a/gufe/settings/models.py +++ b/gufe/settings/models.py @@ -16,6 +16,7 @@ Extra, Field, PositiveFloat, + PrivateAttr, validator, ) except ImportError: @@ -23,18 +24,77 @@ Extra, Field, PositiveFloat, + PrivateAttr, validator, ) class SettingsBaseModel(DefaultModel): """Settings and modifications we want for all settings classes.""" + _is_frozen: bool = PrivateAttr(default_factory=lambda: False) class Config: extra = Extra.forbid arbitrary_types_allowed = False smart_union = True + def frozen_copy(self): + """A copy of this Settings object which cannot be modified + + This is intended to be used by Protocols to make their stored Settings + read-only + """ + copied = self.copy(deep=True) + + def freeze_model(model): + submodels = ( + mod for field in model.__fields__ + if isinstance(mod := getattr(model, field), SettingsBaseModel) + ) + for mod in submodels: + freeze_model(mod) + + if not model._is_frozen: + model._is_frozen = True + + freeze_model(copied) + return copied + + def unfrozen_copy(self): + """A copy of this Settings object, which can be modified + + Settings objects become frozen when within a Protocol. If you *really* + need to reverse this, this method is how. + """ + copied = self.copy(deep=True) + + def unfreeze_model(model): + submodels = ( + mod for field in model.__fields__ + if isinstance(mod := getattr(model, field), SettingsBaseModel) + ) + for mod in submodels: + unfreeze_model(mod) + + model._is_frozen = False + + unfreeze_model(copied) + + return copied + + @property + def is_frozen(self): + """If this Settings object is frozen and cannot be modified""" + return self._is_frozen + + def __setattr__(self, name, value): + if name != "_is_frozen" and self._is_frozen: + raise AttributeError( + f"Cannot set '{name}': Settings are immutable once attached" + " to a Protocol and cannot be modified. Modify Settings " + "*before* creating the Protocol.") + return super().__setattr__(name, value) + class ThermoSettings(SettingsBaseModel): """Settings for thermodynamic parameters. diff --git a/gufe/tests/conftest.py b/gufe/tests/conftest.py index 034fac04..62487740 100644 --- a/gufe/tests/conftest.py +++ b/gufe/tests/conftest.py @@ -244,7 +244,7 @@ def absolute_transformation(solvated_ligand, solvated_complex): return gufe.Transformation( solvated_ligand, solvated_complex, - protocol=DummyProtocol(settings=None), + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), mapping=None, ) @@ -253,7 +253,7 @@ def absolute_transformation(solvated_ligand, solvated_complex): def complex_equilibrium(solvated_complex): return gufe.NonTransformation( solvated_complex, - protocol=DummyProtocol(settings=None) + protocol=DummyProtocol(settings=DummyProtocol.default_settings()) ) @@ -292,7 +292,7 @@ def benzene_variants_star_map( ] = gufe.Transformation( solvated_ligands["benzene"], solvated_ligands[ligand.name], - protocol=DummyProtocol(settings=None), + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), mapping=None, ) @@ -316,7 +316,7 @@ def benzene_variants_star_map( ] = gufe.Transformation( solvated_complexes["benzene"], solvated_complexes[ligand.name], - protocol=DummyProtocol(settings=None), + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), mapping=None, ) diff --git a/gufe/tests/test_alchemicalnetwork.py b/gufe/tests/test_alchemicalnetwork.py index 420c2243..3d49c51d 100644 --- a/gufe/tests/test_alchemicalnetwork.py +++ b/gufe/tests/test_alchemicalnetwork.py @@ -13,8 +13,8 @@ class TestAlchemicalNetwork(GufeTokenizableTestsMixin): cls = AlchemicalNetwork - key = "AlchemicalNetwork-7f0cf9403891eb7deaae860515ea1c63" - repr = "" + key = "AlchemicalNetwork-d1035e11493ca60ff7bac5171eddfee3" + repr = "" @pytest.fixture def instance(self, benzene_variants_star_map): diff --git a/gufe/tests/test_models.py b/gufe/tests/test_models.py index 1d10741b..4fd67bf0 100644 --- a/gufe/tests/test_models.py +++ b/gufe/tests/test_models.py @@ -9,7 +9,11 @@ from openff.units import unit import pytest -from gufe.settings.models import Settings, OpenMMSystemGeneratorFFSettings +from gufe.settings.models import ( + OpenMMSystemGeneratorFFSettings, + Settings, + ThermoSettings, +) def test_model_schema(): @@ -53,3 +57,57 @@ def test_invalid_constraint(value, good): else: with pytest.raises(ValueError): _ = OpenMMSystemGeneratorFFSettings(constraints=value) + + +class TestFreezing: + def test_default_not_frozen(self): + s = Settings.get_defaults() + # make a frozen copy to check this doesn't alter the original + s2 = s.frozen_copy() + + s.thermo_settings.temperature = 199 * unit.kelvin + assert s.thermo_settings.temperature == 199 * unit.kelvin + + def test_freezing(self): + s = Settings.get_defaults() + + s2 = s.frozen_copy() + + with pytest.raises(AttributeError, match="immutable"): + s2.thermo_settings.temperature = 199 * unit.kelvin + + def test_unfreezing(self): + s = Settings.get_defaults() + + s2 = s.frozen_copy() + + with pytest.raises(AttributeError, match="immutable"): + s2.thermo_settings.temperature = 199 * unit.kelvin + + assert s2.is_frozen + + s3 = s2.unfrozen_copy() + + s3.thermo_settings.temperature = 199 * unit.kelvin + assert s3.thermo_settings.temperature == 199 * unit.kelvin + + def test_frozen_equality(self): + # the frozen-ness of Settings doesn't alter its contents + # therefore a frozen/unfrozen Settings which are otherwise identical + # should be considered equal + s = Settings.get_defaults() + s2 = s.frozen_copy() + + assert s == s2 + + def test_set_subsection(self): + # check that attempting to set a subsection of settings still respects + # frozen state of parent object + s = Settings.get_defaults().frozen_copy() + + assert s.is_frozen + + ts = ThermoSettings(temperature=301 * unit.kelvin) + + with pytest.raises(AttributeError, match="immutable"): + s.thermo_settings = ts diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index 548bd9f8..44d70e48 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -306,7 +306,7 @@ def test_dag_execute_failure(self, protocol_dag_broken): assert len(succeeded_units) > 0 def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmpdir): - protocol = BrokenProtocol(settings=None) + protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, stateB=vacuum_ligand, name="a broken dummy run", mapping=None, @@ -507,7 +507,7 @@ def _defaults(cls): @classmethod def _default_settings(cls): - return {} + return settings.Settings.get_defaults() def _create( self, @@ -719,3 +719,23 @@ def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmpdir): keep_scratch=True, raise_error=False, n_retries=-1) + + +def test_settings_readonly(): + # checks that settings aren't editable once inside a Protocol + p = DummyProtocol(DummyProtocol.default_settings()) + + before = p.settings.n_repeats + + with pytest.raises(AttributeError, match="immutable"): + p.settings.n_repeats = before + 1 + + assert p.settings.n_repeats == before + + # also check child settings + before = p.settings.thermo_settings.temperature + + with pytest.raises(AttributeError, match="immutable"): + p.settings.thermo_settings.temperature = 400.0 * unit.kelvin + + assert p.settings.thermo_settings.temperature == before diff --git a/gufe/tests/test_transformation.py b/gufe/tests/test_transformation.py index 7e98d094..e0b59bdf 100644 --- a/gufe/tests/test_transformation.py +++ b/gufe/tests/test_transformation.py @@ -99,10 +99,12 @@ def test_equality(self, absolute_transformation, solvated_ligand, solvated_compl ) assert absolute_transformation != opposite + s = DummyProtocol.default_settings() + s.n_repeats = 99 different_protocol_settings = Transformation( solvated_ligand, solvated_complex, - protocol=DummyProtocol(settings={"lol": True}), + protocol=DummyProtocol(settings=s), ) assert absolute_transformation != different_protocol_settings @@ -187,9 +189,10 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir): assert len(protocolresult.data) == 2 def test_equality(self, complex_equilibrium, solvated_ligand, solvated_complex): - + s = DummyProtocol.default_settings() + s.n_repeats = 4031 different_protocol_settings = NonTransformation( - solvated_complex, protocol=DummyProtocol(settings={"lol": True}) + solvated_complex, protocol=DummyProtocol(settings=s) ) assert complex_equilibrium != different_protocol_settings