From 8cc0e9dd368891a594e92d5809e1b9aa25c8b8ee Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 22 Jan 2025 11:44:41 +0100 Subject: [PATCH] feat: Improve Pydantic configuration discrimination --- .../config/configuration_factories.py | 32 +++++++++++++++++-- .../config/support/supported_algorithms.py | 6 +++- tests/config/test_configuration_factories.py | 24 +++++++++++++- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/src/careamics/config/configuration_factories.py b/src/careamics/config/configuration_factories.py index 899abcc5..45759e7b 100644 --- a/src/careamics/config/configuration_factories.py +++ b/src/careamics/config/configuration_factories.py @@ -1,8 +1,8 @@ """Convenience functions to create configurations for training and inference.""" -from typing import Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, Union -from pydantic import TypeAdapter +from pydantic import Discriminator, Tag, TypeAdapter from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm from careamics.config.architectures import UNetModel @@ -12,6 +12,7 @@ from careamics.config.n2n_configuration import N2NConfiguration from careamics.config.n2v_configuration import N2VConfiguration from careamics.config.support import ( + SupportedAlgorithm, SupportedArchitecture, SupportedPixelManipulation, SupportedTransform, @@ -26,6 +27,24 @@ ) +def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str: + """Discriminate algorithm-specific configurations based on the algorithm. + + Parameters + ---------- + value : Any + Value to discriminate. + + Returns + ------- + str + Discriminator value. + """ + if isinstance(value, dict): + return value["algorithm_config"]["algorithm"] + return value.algorithm_config.algorithm + + def configuration_factory( configuration: dict[str, Any] ) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]: @@ -43,7 +62,14 @@ def configuration_factory( Configuration for training CAREamics. """ adapter: TypeAdapter = TypeAdapter( - Union[N2VConfiguration, N2NConfiguration, CAREConfiguration] + Annotated[ + Union[ + Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)], + Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)], + Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)], + ], + Discriminator(_algorithm_config_discriminator), + ] ) return adapter.validate_python(configuration) diff --git a/src/careamics/config/support/supported_algorithms.py b/src/careamics/config/support/supported_algorithms.py index dc5752bd..15b30274 100644 --- a/src/careamics/config/support/supported_algorithms.py +++ b/src/careamics/config/support/supported_algorithms.py @@ -6,7 +6,11 @@ class SupportedAlgorithm(str, BaseEnum): - """Algorithms available in CAREamics.""" + """Algorithms available in CAREamics. + + These definitions are the same as the keyword `name` of the algorithm + configurations. + """ N2V = "n2v" """Noise2Void algorithm, a self-supervised approach based on blind denoising.""" diff --git a/tests/config/test_configuration_factories.py b/tests/config/test_configuration_factories.py index d47b1f38..3d7b637a 100644 --- a/tests/config/test_configuration_factories.py +++ b/tests/config/test_configuration_factories.py @@ -14,6 +14,7 @@ create_n2v_configuration, ) from careamics.config.configuration_factories import ( + _algorithm_config_discriminator, _create_configuration, _create_supervised_configuration, _create_unet_configuration, @@ -23,6 +24,7 @@ ) from careamics.config.data import N2VDataConfig from careamics.config.support import ( + SupportedAlgorithm, SupportedPixelManipulation, SupportedStructAxis, SupportedTransform, @@ -34,13 +36,33 @@ ) +def test_algorithm_discriminator_n2v(minimum_n2v_configuration): + """Test that the N2V configuration is discriminated correctly.""" + tag = _algorithm_config_discriminator(minimum_n2v_configuration) + assert tag == SupportedAlgorithm.N2V.value + + +@pytest.mark.parametrize( + "algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value] +) +def test_algorithm_discriminator_supervised( + minimum_supervised_configuration, algorithm +): + """Test that the supervised configuration is discriminated correctly.""" + minimum_supervised_configuration["algorithm_config"]["algorithm"] = algorithm + tag = _algorithm_config_discriminator(minimum_supervised_configuration) + assert tag == algorithm + + def test_careamics_config_n2v(minimum_n2v_configuration): """Test that the N2V configuration is created correctly.""" configuration = configuration_factory(minimum_n2v_configuration) assert isinstance(configuration, N2VConfiguration) -@pytest.mark.parametrize("algorithm", ["n2n", "care"]) +@pytest.mark.parametrize( + "algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value] +) def test_careamics_config_supervised(minimum_supervised_configuration, algorithm): """Test that the supervised configuration is created correctly.""" min_config = minimum_supervised_configuration