refac: Algorithm specific configurations (#344)
### Description

> [!IMPORTANT]  
> **TLDR**: Split Pydantic configurations to simplify validation.

Following #320, this PR
adds algorithm-specific main configurations, as well as algorithm and
data sub-configurations where necessary (e.g. N2V), in order to simplify
the code base. Additionally, it fixes mypy's handling of Pydantic models
by enabling the Pydantic mypy plugin. Finally, the `CustomModel` capability
was removed, fixing errors in the mypy report.

#### Pydantic mypy plugin

Out of the box, mypy does not handle Pydantic models well. Adding
the [Pydantic mypy
plugin](https://docs.pydantic.dev/latest/integrations/mypy/#configuring-the-plugin)
surfaces new errors that were previously hidden from us (see [report
here](https://results.pre-commit.ci/run/github/607105530/1729457186.StZU4dBMTrq2a55jjzwmGw)).

#### Custom models

The custom models were added without a real use case, and the
examples were mostly simple toys. Any real use case will likely
require more than plugging a model into the CAREamist API, and will
probably start with using one's own Lightning module in the Lightning
API.

In addition, the presence of unknown Pydantic models creates issues with
mypy (see previous point) and adds complexity to the code for a feature
we are not sure will be used. This PR removes the `CustomModel` class
and all mechanisms to register models.

#### Splitting configurations

Another pain point in the code base
(#320) is that having a
single `Configuration` for vastly different algorithms (and their
constraints) generates complex validation code: validation across
sub-configurations, many `if`/`else` statements, and long files with tons
of `field_validator`s.

An elegant solution is to define a general configuration and sub-class
it with algorithm-specific Pydantic models. Then, a "factory" model can
leverage Pydantic model selection to choose the correct
algorithm-specific configuration based on the parameters. Here is an
illustration:

```python
class Configuration(BaseModel):
    # some example parameters
    version: Literal["0.1.0"] = "0.1.0"
    algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm]
    data_config: Union[DataConfig, N2VDataConfig]

class N2VConfiguration(Configuration):
    algorithm_config: UNetBasedAlgorithm
    data_config: N2VDataConfig

class HDN(Configuration):
    algorithm_config: VAEBasedAlgorithm
    data_config: DataConfig
```
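Concretely, Pydantic v2 can perform this model selection through a discriminated union. A minimal sketch with toy stand-in models (not the actual CAREamics classes), assuming Pydantic v2:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class DataConfig(BaseModel):
    # toy stand-in: the literal "transforms" field acts as the discriminator
    transforms: Literal["standard"] = "standard"


class N2VDataConfig(BaseModel):
    transforms: Literal["n2v"] = "n2v"
    roi_size: int = 11


# Pydantic inspects the discriminator field and picks the matching model
AnyDataConfig = TypeAdapter(
    Annotated[Union[DataConfig, N2VDataConfig], Field(discriminator="transforms")]
)

cfg = AnyDataConfig.validate_python({"transforms": "n2v", "roi_size": 7})
assert isinstance(cfg, N2VDataConfig)
```

With the discriminator declared, an invalid combination (e.g. `{"transforms": "n2v"}` with fields only valid for the standard config) fails validation inside the selected subclass, which is exactly where the validation logic now lives.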

In this scenario, all the validation related to `N2VManipulate` lives
in `N2VDataConfig`, and `N2VConfiguration` can freely assume that the
correct transforms are present. Similarly, `DataConfig` does not need to
care about `N2VManipulate`, and other subclasses can exclude it from the
allowed types in their transforms. In the same spirit, the algorithm
config is always `UNet` for N2V, and `VAE` for HDN. In this PR, we
followed that idea with algorithm-specific configurations (both for the
main configuration and the algorithm sub-configuration).

The main issue raised by this approach is that `Configuration` can no
longer be used to instantiate the correct configuration, but rather
`configuration_factory` should be used:

```python
my_cfg = {
    ...
}

# old way
cfg = Configuration(**my_cfg)

# new approach
cfg = configuration_factory(my_cfg)
```

Here, the model choice (`N2VConfiguration`, `CAREConfiguration`, etc.)
is delegated to Pydantic via `configuration_factory`.
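The dispatch idea behind such a factory can be sketched in plain Python (hypothetical names and toy stand-in classes; the real `configuration_factory` delegates the choice to Pydantic rather than to a hand-written mapping):

```python
from typing import Any


class Configuration:
    """Toy stand-in for the general configuration."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)


class N2VConfiguration(Configuration):
    """Toy stand-in for the N2V-specific configuration."""


class CAREConfiguration(Configuration):
    """Toy stand-in for the CARE-specific configuration."""


# hypothetical mapping from algorithm name to configuration class
_CONFIGURATIONS: dict[str, type[Configuration]] = {
    "n2v": N2VConfiguration,
    "care": CAREConfiguration,
}


def configuration_factory(cfg: dict[str, Any]) -> Configuration:
    """Instantiate the configuration subclass matching the algorithm name."""
    algorithm = cfg["algorithm_config"]["algorithm"]
    return _CONFIGURATIONS[algorithm](**cfg)


cfg = configuration_factory({"algorithm_config": {"algorithm": "n2v"}})
assert isinstance(cfg, N2VConfiguration)
```

The benefit of letting Pydantic do this instead is that the selection and the validation happen in one step, with proper error messages when no subclass matches.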

Similar factories have been created for the algorithm and data
sub-configurations in order to allow correct instantiation in the
Lightning API.

### Changes Made

- **Added**:
    - `N2VConfiguration`: main configuration for N2V
    - `N2NConfiguration`: idem for N2N
    - `CAREConfiguration`: idem for CARE
    - `N2VDataConfig`: data sub-configuration for N2V (enforces `N2VManipulate`)
    - `N2VAlgorithm`: algorithm sub-configuration for N2V
    - `N2NAlgorithm`: idem for N2N
    - `CAREAlgorithm`: idem for CARE
    - `ConfigurationFactory`: a Pydantic model instantiating the correct configuration
    - `AlgorithmFactory`: idem for algorithm sub-configurations
    - `DataFactory`: idem for data sub-configurations
    - `configuration_io.py`: read/save methods for configurations, refactored from `configuration_factory.py`
- **Modified**:
    - All calls to `Configuration`, and `Configuration` itself.
    - Renamed the UNet and VAE algorithm configurations.
- **Removed**: the `references` module, all calls to `has_n2v_manipulate`, and
anything to do with `CustomModel`.

> [!WARNING]  
> Testing could be improved further, and there are probably some aspects
> that I am missing. But I'd rather get reviews already!

### Related Issues

- Resolves #320

### Breaking changes

- The configuration, algorithm sub-configuration, and data sub-configuration
should now be instantiated via the factories.
- The renaming affects the algorithm sub-configurations (`UNetAlgorithm`,
`VAEAlgorithm`).
- `CustomModel` is now completely removed.
- Calls to `has_n2v_manipulate` no longer work; prefer `isinstance(cfg,
N2VConfiguration)`.

> [!WARNING]  
> Any current work on adding new algorithms and configurations will be
> largely impacted by this PR. @CatEek
---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
jdeschamps authored Jan 14, 2025
1 parent ec1ff2a commit 2bdd021
Showing 75 changed files with 2,551 additions and 2,655 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -34,6 +34,7 @@ repos:
args: ["--config-file", "mypy.ini"]
additional_dependencies:
- numpy<2.0.0
- pydantic
- types-PyYAML
- types-setuptools

1 change: 1 addition & 0 deletions mypy.ini
@@ -1,5 +1,6 @@
[mypy]
ignore_missing_imports = True
plugins = pydantic.mypy

[mypy-careamics.lvae_training.*]
follow_imports = skip
19 changes: 17 additions & 2 deletions src/careamics/__init__.py
@@ -7,7 +7,22 @@
except PackageNotFoundError:
__version__ = "uninstalled"

__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
__all__ = [
"CAREamist",
"Configuration",
"algorithm_factory",
"configuration_factory",
"data_factory",
"load_configuration",
"save_configuration",
]

from .careamist import CAREamist
from .config import Configuration, load_configuration, save_configuration
from .config import (
Configuration,
algorithm_factory,
configuration_factory,
data_factory,
load_configuration,
save_configuration,
)
7 changes: 4 additions & 3 deletions src/careamics/careamist.py
@@ -13,7 +13,7 @@
)
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
@@ -137,7 +137,7 @@ def __init__(
self.cfg = source

# instantiate model
if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
)
@@ -157,7 +157,8 @@ def __init__(
self.cfg = load_configuration(source)

# instantiate model
if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
# TODO call model factory here
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
) # type: ignore
54 changes: 38 additions & 16 deletions src/careamics/config/__init__.py
@@ -1,41 +1,63 @@
"""Configuration module."""
"""CAREamics Pydantic configuration models.

To maintain clarity at the module level, we use the following naming conventions:
`*_model` is specific to sub-configurations (e.g. architecture, data, algorithm),
while `*_configuration` is reserved for the main configuration models, including the
`Configuration` base class and its algorithm-specific child classes.
"""

__all__ = [
"CAREAlgorithm",
"CAREConfiguration",
"CheckpointModel",
"Configuration",
"CustomModel",
"DataConfig",
"FCNAlgorithmConfig",
"GaussianMixtureNMConfig",
"GeneralDataConfig",
"InferenceConfig",
"LVAELossConfig",
"MultiChannelNMConfig",
"N2NAlgorithm",
"N2NConfiguration",
"N2VAlgorithm",
"N2VConfiguration",
"N2VDataConfig",
"TrainingConfig",
"VAEAlgorithmConfig",
"clear_custom_models",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
"algorithm_factory",
"configuration_factory",
"create_care_configuration",
"create_n2n_configuration",
"create_n2v_configuration",
"data_factory",
"load_configuration",
"register_model",
"save_configuration",
]
from .architectures import CustomModel, clear_custom_models, register_model

from .algorithms import (
CAREAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
UNetBasedAlgorithm,
VAEBasedAlgorithm,
)
from .callback_model import CheckpointModel
from .configuration_factory import (
from .care_configuration import CAREConfiguration
from .configuration import Configuration
from .configuration_factories import (
algorithm_factory,
configuration_factory,
create_care_configuration,
create_n2n_configuration,
create_n2v_configuration,
data_factory,
)
from .configuration_model import (
Configuration,
load_configuration,
save_configuration,
)
from .data_model import DataConfig
from .fcn_algorithm_model import FCNAlgorithmConfig
from .configuration_io import load_configuration, save_configuration
from .data import DataConfig, GeneralDataConfig, N2VDataConfig
from .inference_model import InferenceConfig
from .loss_model import LVAELossConfig
from .n2n_configuration import N2NConfiguration
from .n2v_configuration import N2VConfiguration
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
from .training_model import TrainingConfig
from .vae_algorithm_model import VAEAlgorithmConfig
15 changes: 15 additions & 0 deletions src/careamics/config/algorithms/__init__.py
@@ -0,0 +1,15 @@
"""Algorithm configurations."""

__all__ = [
"CAREAlgorithm",
"N2NAlgorithm",
"N2VAlgorithm",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
]

from .care_algorithm_model import CAREAlgorithm
from .n2n_algorithm_model import N2NAlgorithm
from .n2v_algorithm_model import N2VAlgorithm
from .unet_algorithm_model import UNetBasedAlgorithm
from .vae_algorithm_model import VAEBasedAlgorithm
50 changes: 50 additions & 0 deletions src/careamics/config/algorithms/care_algorithm_model.py
@@ -0,0 +1,50 @@
"""CARE algorithm configuration."""

from typing import Literal

from pydantic import field_validator

from careamics.config.architectures import UNetModel

from .unet_algorithm_model import UNetBasedAlgorithm


class CAREAlgorithm(UNetBasedAlgorithm):
    """CARE algorithm configuration.

    Attributes
    ----------
    algorithm : "care"
        CARE Algorithm name.
    loss : {"mae", "mse"}
        CARE-compatible loss function.
    """

    algorithm: Literal["care"] = "care"
    """CARE Algorithm name."""

    loss: Literal["mae", "mse"] = "mae"
    """CARE-compatible loss function."""

    @field_validator("model")
    @classmethod
    def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
        """Validate that the model does not use the n2v2 architecture.

        Parameters
        ----------
        value : UNetModel
            Model to validate.

        Returns
        -------
        UNetModel
            The validated model.
        """
        if value.n2v2:
            raise ValueError(
                "The CARE algorithm does not support the `n2v2` attribute. "
                "Set it to `False`."
            )

        return value
42 changes: 42 additions & 0 deletions src/careamics/config/algorithms/n2n_algorithm_model.py
@@ -0,0 +1,42 @@
"""N2N Algorithm configuration."""

from typing import Literal

from pydantic import field_validator

from careamics.config.architectures import UNetModel

from .unet_algorithm_model import UNetBasedAlgorithm


class N2NAlgorithm(UNetBasedAlgorithm):
    """N2N Algorithm configuration."""

    algorithm: Literal["n2n"] = "n2n"
    """N2N Algorithm name."""

    loss: Literal["mae", "mse"] = "mae"
    """N2N-compatible loss function."""

    @field_validator("model")
    @classmethod
    def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
        """Validate that the model does not use the n2v2 architecture.

        Parameters
        ----------
        value : UNetModel
            Model to validate.

        Returns
        -------
        UNetModel
            The validated model.
        """
        if value.n2v2:
            raise ValueError(
                "The N2N algorithm does not support the `n2v2` attribute. "
                "Set it to `False`."
            )

        return value
35 changes: 35 additions & 0 deletions src/careamics/config/algorithms/n2v_algorithm_model.py
@@ -0,0 +1,35 @@
"""N2V Algorithm configuration."""

from typing import Literal

from pydantic import model_validator
from typing_extensions import Self

from .unet_algorithm_model import UNetBasedAlgorithm


class N2VAlgorithm(UNetBasedAlgorithm):
    """N2V Algorithm configuration."""

    algorithm: Literal["n2v"] = "n2v"
    """N2V Algorithm name."""

    loss: Literal["n2v"] = "n2v"
    """N2V loss function."""

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model for N2V.

        Returns
        -------
        Self
            The validated model.
        """
        if self.model.in_channels != self.model.num_classes:
            raise ValueError(
                "N2V requires the same number of input and output channels. Make "
                "sure that `in_channels` and `num_classes` are the same."
            )

        return self
88 changes: 88 additions & 0 deletions src/careamics/config/algorithms/unet_algorithm_model.py
@@ -0,0 +1,88 @@
"""UNet-based algorithm Pydantic model."""

from pprint import pformat
from typing import Literal

from pydantic import BaseModel, ConfigDict

from careamics.config.architectures import UNetModel
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel


class UNetBasedAlgorithm(BaseModel):
    """General UNet-based algorithm configuration.

    This Pydantic model validates the parameters governing the components of the
    training algorithm: which algorithm, loss function, model architecture,
    optimizer, and learning rate scheduler to use.

    Currently, we only support N2V, CARE, and N2N algorithms. In order to train
    these algorithms, use the corresponding configuration child classes (e.g.
    `N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses).

    Attributes
    ----------
    algorithm : {"n2v", "care", "n2n"}
        Algorithm to use.
    loss : {"n2v", "mae", "mse"}
        Loss function to use.
    model : UNetModel
        Model architecture to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
    lr_scheduler : LrSchedulerModel, optional
        Learning rate scheduler to use.

    Raises
    ------
    ValueError
        Algorithm parameter type validation errors.
    ValueError
        If the algorithm, loss, and model are not compatible.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    algorithm: Literal["n2v", "care", "n2n"]
    """Algorithm name, as defined in SupportedAlgorithm."""

    loss: Literal["n2v", "mae", "mse"]
    """Loss function to use, as defined in SupportedLoss."""

    model: UNetModel
    """UNet model configuration."""

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
    """Learning rate scheduler to use, defined in SupportedLrScheduler."""

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    @classmethod
    def get_compatible_algorithms(cls) -> list[str]:
        """Get the list of compatible algorithms.

        Returns
        -------
        list of str
            List of compatible algorithms.
        """
        return ["n2v", "care", "n2n"]