-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refac: Algorithm specific configurations (#344)
### Description > [!IMPORTANT] > **TLDR**: Split Pydantic configurations to simplify validation. Following #320, this PR adds algorithm-specific main configurations, as well as algorithm and data sub-configurations when necessary (e.g. N2V) in order to simplify the code base. Additionally, it fix mypy handling of Pydantic models using the Pydantic mypy plugin. Finally, the `CustomModel` capacity was removed, fixing errors in the mypy report. #### Pydantic mypy plugin Natively, mypy does not have a good handle of Pydantic models. Adding the [Pydantic mypy plugin](https://docs.pydantic.dev/latest/integrations/mypy/#configuring-the-plugin) yields new errors that were hidden from us (see [report here](https://results.pre-commit.ci/run/github/607105530/1729457186.StZU4dBMTrq2a55jjzwmGw)). #### Custom models The custom models were added without having a real use-case, and the examples were mostly very simple toy examples. It is likely that any real use case will require more than plugging in a model in the CAREamist API, and will probably start with using one's own Lightning module in the Lightning API. In addition, the presence of unknown Pydantic models creates issues with mypy (see previous point), and adds complexity in the code for a feature we are not sure will be used. This PR removes the `CustomModel` class and all mechanism to register models. #### Splitting configurations Another pain point in the code base (#320) is that having one single `Configuration` for vastly different algorithms (and their constraints) generates complex validation code: validation across sub-configurations, many `if else` statements, and long files with tons of `field_validators`. An elegant solution is to define a general configuration and sub-class it with algorithm-specific Pydantic models. Then, a "factory" model can leverage Pydantic model selection to choose the correct algorithm-specific configuration based on the parameters. Here is an illustration: ```python class Configuration(BaseModel): # some example parameters version: Literal["0.1.0"] = "0.1.0" algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm] data_config: Union[DataConfig, N2VDataConfig] class N2VConfiguration(Configuration): algorithm_config: UNetBasedAlgorithm data_config: N2VDataConfig class HDN(Configuration): algorithm_config: VAEBasedAlgorithm data_config: DataConfig ``` In this scenario, all the validation related to the `N2VManipulate` is in `N2VDataConfig` and `N2VConfiguration` can freely assume that the correct transforms are present. Similarly, `DataConfig` does not need to care about `N2VManipulate`, and other subclasses can exclude it from the allowed types in its transforms. In the same spirit, the algorithm config is always `UNet` for N2V, and `VAE` for HDN. In this PR, we followed that idea with algorithm specific configurations (both for the main configuration and the algorithm sub configuration). The main issue raised by this approach is that `Configuration` can no longer be used to instantiate the correct configuration, but rather `configuration_factory` should be used: ``` python my_cfg = { ... } # old way cfg = Configuration(**my_cfg) # new approach cfg = configuration_factory(my_cfg) ``` Here, the model choice (`N2VConfiguration`, `CAREConfiguration`, etc...) is delegated to Pydantic via the `configuration_factory`. Similar factories have been created for the algorithm and data sub-configurations in order to allow correct instantiation in the Lightning API. ### Changes Made - **Added**: - `N2VConfiguration`: main configuration for N2V - `N2NConfiguration`: idem for N2N - `CAREConfiguration`: idem for CARE - `N2VDataConfig`: data subconfiguration for N2V (enforced `N2VManipulate`) - `N2VAlgorithm`: algorithm subconfiguration for N2V - `N2NAlgorithm`: idem for N2N - `CAREAlgorithm`: idem for CARE - `ConfigurationFactory`: a Pydantic model instantiating the correct configuration - `AlgorithmFactory`: idem for algorithm subconfigurations - `DataFactory`: idem for data subconfigurations - `configuration_io.py`: read/save methods for configurations, refactored from `configuration_factory.py` - **Modified**: - All calls to `Configuration` and `Configuration` itself. - Renamed UNet and VAE algorithm configurations. - **Removed**: `references` module, all calls to `has_n2v_manipulate`, anything to do with `CustomModel`. > [!WARNING] > Testing could be improve further and there are probably some aspects that I am missing. But I'd rather get reviews already! ### Related Issues - Resolves #320 ### Breaking changes - Configuration, algorithm subconfig, and data subconfig should be instantiated via the factories. - Renaming impacts the algorithm subconfigurations (`UNetAlgorithm`, `VAEAlgorithm`). - `CustomModel` is now completely removed. - Calls to `has_n2v_manipulate` no longer work, prefer `isinstance(cfg, N2VConfiguration)`. > [!WARNING] > Any current work on adding new algorithm and configuration will be largely impacted by this PR. @CatEek --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features)
- Loading branch information
1 parent
ec1ff2a
commit 2bdd021
Showing
75 changed files
with
2,551 additions
and
2,655 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,63 @@ | ||
"""Configuration module.""" | ||
"""CAREamics Pydantic configuration models. | ||
To maintain clarity at the module level, we follow the following naming conventions: | ||
`*_model` is specific for sub-configurations (e.g. architecture, data, algorithm), | ||
while `*_configuration` is reserved for the main configuration models, including the | ||
`Configuration` base class and its algorithm-specific child classes. | ||
""" | ||
|
||
__all__ = [ | ||
"CAREAlgorithm", | ||
"CAREConfiguration", | ||
"CheckpointModel", | ||
"Configuration", | ||
"CustomModel", | ||
"DataConfig", | ||
"FCNAlgorithmConfig", | ||
"GaussianMixtureNMConfig", | ||
"GeneralDataConfig", | ||
"InferenceConfig", | ||
"LVAELossConfig", | ||
"MultiChannelNMConfig", | ||
"N2NAlgorithm", | ||
"N2NConfiguration", | ||
"N2VAlgorithm", | ||
"N2VConfiguration", | ||
"N2VDataConfig", | ||
"TrainingConfig", | ||
"VAEAlgorithmConfig", | ||
"clear_custom_models", | ||
"UNetBasedAlgorithm", | ||
"VAEBasedAlgorithm", | ||
"algorithm_factory", | ||
"configuration_factory", | ||
"create_care_configuration", | ||
"create_n2n_configuration", | ||
"create_n2v_configuration", | ||
"data_factory", | ||
"load_configuration", | ||
"register_model", | ||
"save_configuration", | ||
] | ||
from .architectures import CustomModel, clear_custom_models, register_model | ||
|
||
from .algorithms import ( | ||
CAREAlgorithm, | ||
N2NAlgorithm, | ||
N2VAlgorithm, | ||
UNetBasedAlgorithm, | ||
VAEBasedAlgorithm, | ||
) | ||
from .callback_model import CheckpointModel | ||
from .configuration_factory import ( | ||
from .care_configuration import CAREConfiguration | ||
from .configuration import Configuration | ||
from .configuration_factories import ( | ||
algorithm_factory, | ||
configuration_factory, | ||
create_care_configuration, | ||
create_n2n_configuration, | ||
create_n2v_configuration, | ||
data_factory, | ||
) | ||
from .configuration_model import ( | ||
Configuration, | ||
load_configuration, | ||
save_configuration, | ||
) | ||
from .data_model import DataConfig | ||
from .fcn_algorithm_model import FCNAlgorithmConfig | ||
from .configuration_io import load_configuration, save_configuration | ||
from .data import DataConfig, GeneralDataConfig, N2VDataConfig | ||
from .inference_model import InferenceConfig | ||
from .loss_model import LVAELossConfig | ||
from .n2n_configuration import N2NConfiguration | ||
from .n2v_configuration import N2VConfiguration | ||
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig | ||
from .training_model import TrainingConfig | ||
from .vae_algorithm_model import VAEAlgorithmConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Algorithm configurations.""" | ||
|
||
__all__ = [ | ||
"CAREAlgorithm", | ||
"N2NAlgorithm", | ||
"N2VAlgorithm", | ||
"UNetBasedAlgorithm", | ||
"VAEBasedAlgorithm", | ||
] | ||
|
||
from .care_algorithm_model import CAREAlgorithm | ||
from .n2n_algorithm_model import N2NAlgorithm | ||
from .n2v_algorithm_model import N2VAlgorithm | ||
from .unet_algorithm_model import UNetBasedAlgorithm | ||
from .vae_algorithm_model import VAEBasedAlgorithm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
"""CARE algorithm configuration.""" | ||
|
||
from typing import Literal | ||
|
||
from pydantic import field_validator | ||
|
||
from careamics.config.architectures import UNetModel | ||
|
||
from .unet_algorithm_model import UNetBasedAlgorithm | ||
|
||
|
||
class CAREAlgorithm(UNetBasedAlgorithm): | ||
"""CARE algorithm configuration. | ||
Attributes | ||
---------- | ||
algorithm : "care" | ||
CARE Algorithm name. | ||
loss : {"mae", "mse"} | ||
CARE-compatible loss function. | ||
""" | ||
|
||
algorithm: Literal["care"] = "care" | ||
"""CARE Algorithm name.""" | ||
|
||
loss: Literal["mae", "mse"] = "mae" | ||
"""CARE-compatible loss function.""" | ||
|
||
@classmethod | ||
@field_validator("model") | ||
def model_without_n2v2(cls, value: UNetModel) -> UNetModel: | ||
"""Validate that the model does not have the n2v2 attribute. | ||
Parameters | ||
---------- | ||
value : UNetModel | ||
Model to validate. | ||
Returns | ||
------- | ||
UNetModel | ||
The validated model. | ||
""" | ||
if value.n2v2: | ||
raise ValueError( | ||
"The N2N algorithm does not support the `n2v2` attribute. " | ||
"Set it to `False`." | ||
) | ||
|
||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""N2N Algorithm configuration.""" | ||
|
||
from typing import Literal | ||
|
||
from pydantic import field_validator | ||
|
||
from careamics.config.architectures import UNetModel | ||
|
||
from .unet_algorithm_model import UNetBasedAlgorithm | ||
|
||
|
||
class N2NAlgorithm(UNetBasedAlgorithm): | ||
"""N2N Algorithm configuration.""" | ||
|
||
algorithm: Literal["n2n"] = "n2n" | ||
"""N2N Algorithm name.""" | ||
|
||
loss: Literal["mae", "mse"] = "mae" | ||
"""N2N-compatible loss function.""" | ||
|
||
@classmethod | ||
@field_validator("model") | ||
def model_without_n2v2(cls, value: UNetModel) -> UNetModel: | ||
"""Validate that the model does not have the n2v2 attribute. | ||
Parameters | ||
---------- | ||
value : UNetModel | ||
Model to validate. | ||
Returns | ||
------- | ||
UNetModel | ||
The validated model. | ||
""" | ||
if value.n2v2: | ||
raise ValueError( | ||
"The N2N algorithm does not support the `n2v2` attribute. " | ||
"Set it to `False`." | ||
) | ||
|
||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""""N2V Algorithm configuration.""" | ||
|
||
from typing import Literal | ||
|
||
from pydantic import model_validator | ||
from typing_extensions import Self | ||
|
||
from .unet_algorithm_model import UNetBasedAlgorithm | ||
|
||
|
||
class N2VAlgorithm(UNetBasedAlgorithm): | ||
"""N2V Algorithm configuration.""" | ||
|
||
algorithm: Literal["n2v"] = "n2v" | ||
"""N2V Algorithm name.""" | ||
|
||
loss: Literal["n2v"] = "n2v" | ||
"""N2V loss function.""" | ||
|
||
@model_validator(mode="after") | ||
def algorithm_cross_validation(self: Self) -> Self: | ||
"""Validate the algorithm model for N2V. | ||
Returns | ||
------- | ||
Self | ||
The validated model. | ||
""" | ||
if self.model.in_channels != self.model.num_classes: | ||
raise ValueError( | ||
"N2V requires the same number of input and output channels. Make " | ||
"sure that `in_channels` and `num_classes` are the same." | ||
) | ||
|
||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""UNet-based algorithm Pydantic model.""" | ||
|
||
from pprint import pformat | ||
from typing import Literal | ||
|
||
from pydantic import BaseModel, ConfigDict | ||
|
||
from careamics.config.architectures import UNetModel | ||
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel | ||
|
||
|
||
class UNetBasedAlgorithm(BaseModel): | ||
"""General UNet-based algorithm configuration. | ||
This Pydantic model validates the parameters governing the components of the | ||
training algorithm: which algorithm, loss function, model architecture, optimizer, | ||
and learning rate scheduler to use. | ||
Currently, we only support N2V, CARE, and N2N algorithms. In order to train these | ||
algorithms, use the corresponding configuration child classes (e.g. | ||
`N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses). | ||
Attributes | ||
---------- | ||
algorithm : {"n2v", "care", "n2n"} | ||
Algorithm to use. | ||
loss : {"n2v", "mae", "mse"} | ||
Loss function to use. | ||
model : UNetModel | ||
Model architecture to use. | ||
optimizer : OptimizerModel, optional | ||
Optimizer to use. | ||
lr_scheduler : LrSchedulerModel, optional | ||
Learning rate scheduler to use. | ||
Raises | ||
------ | ||
ValueError | ||
Algorithm parameter type validation errors. | ||
ValueError | ||
If the algorithm, loss and model are not compatible. | ||
""" | ||
|
||
# Pydantic class configuration | ||
model_config = ConfigDict( | ||
protected_namespaces=(), # allows to use model_* as a field name | ||
validate_assignment=True, | ||
extra="allow", | ||
) | ||
|
||
# Mandatory fields | ||
algorithm: Literal["n2v", "care", "n2n"] | ||
"""Algorithm name, as defined in SupportedAlgorithm.""" | ||
|
||
loss: Literal["n2v", "mae", "mse"] | ||
"""Loss function to use, as defined in SupportedLoss.""" | ||
|
||
model: UNetModel | ||
"""UNet model configuration.""" | ||
|
||
# Optional fields | ||
optimizer: OptimizerModel = OptimizerModel() | ||
"""Optimizer to use, defined in SupportedOptimizer.""" | ||
|
||
lr_scheduler: LrSchedulerModel = LrSchedulerModel() | ||
"""Learning rate scheduler to use, defined in SupportedLrScheduler.""" | ||
|
||
def __str__(self) -> str: | ||
"""Pretty string representing the configuration. | ||
Returns | ||
------- | ||
str | ||
Pretty string. | ||
""" | ||
return pformat(self.model_dump()) | ||
|
||
@classmethod | ||
def get_compatible_algorithms(cls) -> list[str]: | ||
"""Get the list of compatible algorithms. | ||
Returns | ||
------- | ||
list of str | ||
List of compatible algorithms. | ||
""" | ||
return ["n2v", "care", "n2n"] |
Oops, something went wrong.