Skip to content

Commit

Permalink
merge changes from #366
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jan 22, 2025
2 parents 70a81e3 + 8cc0e9d commit 6eb9acb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
31 changes: 29 additions & 2 deletions src/careamics/config/configuration_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

from typing import Annotated, Any, Literal, Optional, Union

from pydantic import Field, TypeAdapter
from pydantic import Discriminator, Field, Tag, TypeAdapter

from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
from careamics.config.architectures import UNetModel
from careamics.config.care_configuration import CAREConfiguration
from careamics.config.configuration import Configuration
from careamics.config.data import DataConfig
from careamics.config.n2n_configuration import N2NConfiguration
from careamics.config.n2v_configuration import N2VConfiguration
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
SupportedPixelManipulation,
SupportedTransform,
Expand All @@ -24,6 +26,24 @@
)


def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str:
"""Discriminate algorithm-specific configurations based on the algorithm.
Parameters
----------
value : Any
Value to discriminate.
Returns
-------
str
Discriminator value.
"""
if isinstance(value, dict):
return value["algorithm_config"]["algorithm"]
return value.algorithm_config.algorithm


def configuration_factory(
configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
Expand All @@ -41,7 +61,14 @@ def configuration_factory(
Configuration for training CAREamics.
"""
adapter: TypeAdapter = TypeAdapter(
Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]
Annotated[
Union[
Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
],
Discriminator(_algorithm_config_discriminator),
]
)
return adapter.validate_python(configuration)

Expand Down
6 changes: 5 additions & 1 deletion src/careamics/config/support/supported_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@


class SupportedAlgorithm(str, BaseEnum):
"""Algorithms available in CAREamics."""
"""Algorithms available in CAREamics.
These definitions are the same as the keyword `name` of the algorithm
configurations.
"""

N2V = "n2v"
"""Noise2Void algorithm, a self-supervised approach based on blind denoising."""
Expand Down
24 changes: 23 additions & 1 deletion tests/config/test_configuration_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
create_n2v_configuration,
)
from careamics.config.configuration_factories import (
_algorithm_config_discriminator,
_create_supervised_config_dict,
_create_unet_configuration,
_list_spatial_augmentations,
configuration_factory,
)
from careamics.config.support import (
SupportedAlgorithm,
SupportedPixelManipulation,
SupportedStructAxis,
SupportedTransform,
Expand All @@ -30,13 +32,33 @@
)


def test_algorithm_discriminator_n2v(minimum_n2v_configuration):
"""Test that the N2V configuration is discriminated correctly."""
tag = _algorithm_config_discriminator(minimum_n2v_configuration)
assert tag == SupportedAlgorithm.N2V.value


@pytest.mark.parametrize(
"algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value]
)
def test_algorithm_discriminator_supervised(
minimum_supervised_configuration, algorithm
):
"""Test that the supervised configuration is discriminated correctly."""
minimum_supervised_configuration["algorithm_config"]["algorithm"] = algorithm
tag = _algorithm_config_discriminator(minimum_supervised_configuration)
assert tag == algorithm


def test_careamics_config_n2v(minimum_n2v_configuration):
"""Test that the N2V configuration is created correctly."""
configuration = configuration_factory(minimum_n2v_configuration)
assert isinstance(configuration, N2VConfiguration)


@pytest.mark.parametrize("algorithm", ["n2n", "care"])
@pytest.mark.parametrize(
"algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value]
)
def test_careamics_config_supervised(minimum_supervised_configuration, algorithm):
"""Test that the supervised configuration is created correctly."""
min_config = minimum_supervised_configuration
Expand Down

0 comments on commit 6eb9acb

Please sign in to comment.