Commit

Update configuration

jdeschamps committed Jan 16, 2025
1 parent cbf971f commit 3957a1b
Showing 7 changed files with 273 additions and 194 deletions.
70 changes: 68 additions & 2 deletions guides/careamist_api/careamist_api.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# %%
# --8<-- [start:careamist_api]
# --8<-- [start:quick_start_n2v]
import numpy as np
from careamics import CAREamist
from careamics.config import create_n2v_configuration
@@ -25,4 +25,70 @@
# once trained, predict
pred_data = np.random.randint(0, 255, (128, 128)).astype(np.float32)
prediction = careamist.predict(source=pred_data)
# --8<-- [end:careamist_api]
# --8<-- [end:quick_start_n2v]

# --8<-- [start:quick_start_care]
import numpy as np
from careamics import CAREamist
from careamics.config import create_care_configuration

# create a configuration
config = create_care_configuration(
experiment_name="care_2D",
data_type="array",
axes="SYX",
patch_size=[64, 64],
batch_size=1,
num_epochs=1, # (1)!
)

# instantiate a careamist
careamist = CAREamist(config)

# train the model
train_data = np.random.randint(0, 255, (5, 256, 256)).astype(np.float32) # (2)!
train_target = np.random.randint(0, 255, (5, 256, 256)).astype(np.float32)
val_data = np.random.randint(0, 255, (2, 256, 256)).astype(np.float32)
val_target = np.random.randint(0, 255, (2, 256, 256)).astype(np.float32)
careamist.train(
train_source=train_data,
train_target=train_target,
val_source=val_data,
val_target=val_target,
)

# once trained, predict
pred_data = np.random.randint(0, 255, (128, 128)).astype(np.float32)
prediction = careamist.predict(source=pred_data, axes="YX")
# --8<-- [end:quick_start_care]

# --8<-- [start:quick_start_n2n]
import numpy as np
from careamics import CAREamist
from careamics.config import create_n2n_configuration

# create a configuration
config = create_n2n_configuration(
experiment_name="n2n_2D",
data_type="array",
axes="YX",
patch_size=[64, 64],
batch_size=1,
num_epochs=1, # (1)!
)

# instantiate a careamist
careamist = CAREamist(config)

# train the model
train_data = np.random.randint(0, 255, (256, 256)).astype(np.float32) # (2)!
train_target = np.random.randint(0, 255, (256, 256)).astype(np.float32)
careamist.train(
train_source=train_data,
train_target=train_target,
)

# once trained, predict
pred_data = np.random.randint(0, 255, (128, 128)).astype(np.float32)
prediction = careamist.predict(source=pred_data)
# --8<-- [end:quick_start_n2n]
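
As a side note, prediction on large images is typically run tile by tile rather than in one pass. The lines below are a minimal sketch (not part of the diff above) and assume that CAREamist.predict accepts tile_size and tile_overlap parameters:

import numpy as np

# assumes `careamist` is a trained CAREamist, as in the snippets above
large_image = np.random.randint(0, 255, (1024, 1024)).astype(np.float32)
tiled_prediction = careamist.predict(
    source=large_image,
    tile_size=(256, 256),  # assumed tiling parameters
    tile_overlap=(48, 48),
)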
32 changes: 0 additions & 32 deletions guides/careamist_api/configuration/advanced_configuration.py
@@ -11,35 +11,3 @@
num_epochs=20,
)
# --8<-- [end:data]

# %%
# --8<-- [start:model]
from careamics.config import FCNAlgorithmConfig, register_model
from torch import nn, ones


@register_model(name="linear_model") # (1)!
class LinearModel(nn.Module):
def __init__(self, in_features, out_features, *args, **kwargs):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(ones(in_features, out_features))
self.bias = nn.Parameter(ones(out_features))

def forward(self, input):
return (input @ self.weight) + self.bias


config = FCNAlgorithmConfig(
algorithm_type="fcn",
algorithm="custom", # (2)!
loss="mse",
model={
"architecture": "custom", # (3)!
"name": "linear_model", # (4)!
"in_features": 10,
"out_features": 5,
},
)
# --8<-- [end:model]
161 changes: 103 additions & 58 deletions guides/careamist_api/configuration/build_configuration.py
@@ -1,93 +1,138 @@
#!/usr/bin/env python
# %%
# --8<-- [start:as_dict]
from careamics import Configuration

config_as_dict = {
"experiment_name": "my_experiment", # (1)!
"algorithm_config": { # (2)!
"algorithm_type": "fcn",
"algorithm": "n2v",
"loss": "n2v",
"model": { # (3)!
"architecture": "UNet",
},
},
"data_config": { # (4)!
"data_type": "array",
"patch_size": [128, 128],
"axes": "YX",
},
"training_config": {
"num_epochs": 1,
},
}
config = Configuration(**config_as_dict) # (5)!
# --8<-- [end:as_dict]

# %%
# --8<-- [start:pydantic]
from careamics import Configuration
from careamics.config import ( # (1)!
DataConfig,
FCNAlgorithmConfig,
N2VAlgorithm,
N2VConfiguration,
N2VDataConfig,
TrainingConfig,
configuration_factory,
)
from careamics.config.architectures import UNetModel
from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
SupportedData,
SupportedLogger,
SupportedLoss,
SupportedTransform,
)
from careamics.config.transformations import N2VManipulateModel

experiment_name = "Pydantic N2V2 example"

# build AlgorithmConfig for the fully convolutional network
algorithm_model = FCNAlgorithmConfig( # (2)!
algorithm_type="fcn",
algorithm=SupportedAlgorithm.N2V.value, # (3)!
loss=SupportedLoss.N2V.value,
model=UNetModel( # (4)!
architecture=SupportedArchitecture.UNET.value,
in_channels=1,
num_classes=1,
),
from careamics.config.transformations import N2VManipulateModel, XYFlipModel

experiment_name = "N2V_example"

# build the model and algorithm configurations
model = UNetModel(
architecture="UNet", # (2)!
num_channels_init=64, # (3)!
depth=3,
# (4)!
)

algorithm_model = N2VAlgorithm( # (5)!
model=model,
# (6)!
)

# then the DataConfig
data_model = DataConfig(
# then the N2VDataConfig
data_model = N2VDataConfig( # (7)!
data_type=SupportedData.ARRAY.value,
patch_size=(256, 256),
batch_size=8,
axes="YX",
transforms=[
{ # (5)!
"name": SupportedTransform.XY_FLIP.value,
},
N2VManipulateModel( # (6)!
masked_pixel_percentage=0.15,
),
],
dataloader_params={ # (7)!
transforms=[XYFlipModel(flip_y=False), N2VManipulateModel()], # (8)! # (9)!
dataloader_params={ # (10)!
"num_workers": 4,
},
)

# then the TrainingConfig
earlystopping = EarlyStoppingModel(
# (11)!
)

checkpoints = CheckpointModel(every_n_epochs=10) # (12)!

training_model = TrainingConfig(
num_epochs=30,
logger=SupportedLogger.WANDB.value,
early_stopping_callback=earlystopping,
checkpoint_callback=checkpoints,
# (13)!
)

# finally, build the Configuration
config = Configuration( # (8)!
config = N2VConfiguration( # (14)!
experiment_name=experiment_name,
algorithm_config=algorithm_model,
data_config=data_model,
training_config=training_model,
)

# alternatively, use the factory method
config2 = configuration_factory( # (15)!
{
"experiment_name": experiment_name,
"algorithm_config": algorithm_model,
"data_config": data_model,
"training_config": training_model,
}
)
# --8<-- [end:pydantic]

if config != config2:
raise ValueError("Configurations are not equal (Pydantic).")

# %%
# --8<-- [start:as_dict]
from careamics.config import N2VConfiguration, configuration_factory

config_dict = {
"experiment_name": "N2V_example",
"algorithm_config": {
"algorithm": "n2v", # (1)!
"loss": "n2v",
"model": {
"architecture": "UNet", # (2)!
"num_channels_init": 64,
"depth": 3,
},
},
"data_config": {
"data_type": "array",
"patch_size": [256, 256],
"batch_size": 8,
"axes": "YX",
"transforms": [
{
"name": "XYFlip",
"flip_y": False,
},
{
"name": "N2VManipulate",
},
],
"dataloader_params": {
"num_workers": 4,
},
},
"training_config": {
"num_epochs": 30,
"logger": "wandb",
"early_stopping_callback": {}, # (3)!
"checkpoint_callback": {
"every_n_epochs": 10,
},
},
}

# instantiate specific configuration
config_as_dict = N2VConfiguration(**config_dict) # (4)!

# alternatively, use the factory method
config_as_dict2 = configuration_factory(config_dict) # (5)!
# --8<-- [end:as_dict]

if config_as_dict != config_as_dict2:
raise ValueError("Configurations are not equal (Dict).")

if config != config_as_dict:
raise ValueError("Configurations are not equal (Pydantic vs Dict).")
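
Either configuration can then be handed to a CAREamist, exactly as in the quick-start snippets. The lines below are a minimal sketch with random data (not part of the diff above); note that the training configuration built here uses the wandb logger and 30 epochs:

import numpy as np
from careamics import CAREamist

careamist = CAREamist(config)  # the N2V configuration built above

# N2V is self-supervised, so only a source array is required for training
train_data = np.random.randint(0, 255, (512, 512)).astype(np.float32)
careamist.train(train_source=train_data)

prediction = careamist.predict(source=train_data)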
33 changes: 33 additions & 0 deletions guides/careamist_api/configuration/configuration_errors.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# %%
from careamics.config import create_n2v_configuration

# %%
# Commented-out example: note the stray brace in the experiment name and the
# 3D patch size, which does not match the 2D "YX" axes.
# create_n2v_configuration(
# experiment_name="N2V_example}",
# data_type="array",
# axes="YX",
# patch_size=[256, 256, 512],
# batch_size=8,
# num_epochs=30,
# )

# %%
# Invalid data type: "arrray" is misspelled and is not a supported data type.
create_n2v_configuration(
experiment_name="N2V_example",
data_type="arrray",
axes="YX",
patch_size=[256, 256],
batch_size=8,
num_epochs=30,
)

# %%
# Mismatched dimensions: a 3D patch size is given for 2D "YX" axes.
create_n2v_configuration(
experiment_name="N2V_example",
data_type="array",
axes="YX",
patch_size=[256, 256, 512],
batch_size=8,
num_epochs=30,
)
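
Each uncommented call above is expected to fail. As a minimal sketch (not part of the diff above, and assuming the builder surfaces the problem as a ValueError, e.g. a pydantic ValidationError), the error can be caught and inspected like this:

from careamics.config import create_n2v_configuration

try:
    create_n2v_configuration(
        experiment_name="N2V_example",
        data_type="arrray",  # intentionally misspelled, as in the cell above
        axes="YX",
        patch_size=[256, 256],
        batch_size=8,
        num_epochs=30,
    )
except ValueError as err:  # pydantic's ValidationError is a ValueError subclass
    print(err)  # reports the offending field(s) and the reason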