Skip to content

Commit

Permalink
Fix guides (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Feb 18, 2025
1 parent b210f07 commit 78dcdef
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 45 deletions.
52 changes: 16 additions & 36 deletions guides/careamist_api/configuration/build_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
# %%
# --8<-- [start:pydantic]
from careamics.config import ( # (1)!
Configuration,
DataConfig,
N2VAlgorithm,
N2VConfiguration,
N2VDataConfig,
TrainingConfig,
configuration_factory,
)
from careamics.config.architectures import UNetModel
from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel
from careamics.config.support import (
SupportedData,
SupportedLogger,
)
from careamics.config.transformations import N2VManipulateModel, XYFlipModel
from careamics.config.transformations import XYFlipModel

experiment_name = "N2V_example"

Expand All @@ -32,58 +31,46 @@
# (6)!
)

# then the N2VDataConfig
data_model = N2VDataConfig( # (7)!
# then the data configuration for N2V
data_model = DataConfig( # (7)!
data_type=SupportedData.ARRAY.value,
patch_size=(256, 256),
batch_size=8,
axes="YX",
transforms=[XYFlipModel(flip_y=False), N2VManipulateModel()], # (8)! # (9)!
dataloader_params={ # (10)!
transforms=[XYFlipModel(flip_y=False)], # (8)!
dataloader_params={ # (9)!
"num_workers": 4,
"shuffle": True,
},
)

# then the TrainingConfig
earlystopping = EarlyStoppingModel(
# (11)!
# (10)!
)

checkpoints = CheckpointModel(every_n_epochs=10) # (12)!
checkpoints = CheckpointModel(every_n_epochs=10) # (11)!

training_model = TrainingConfig(
num_epochs=30,
logger=SupportedLogger.WANDB.value,
early_stopping_callback=earlystopping,
checkpoint_callback=checkpoints,
# (13)!
# (12)!
)

# finally, build the Configuration
config = N2VConfiguration( # (14)!
config = Configuration( # (13)!
experiment_name=experiment_name,
algorithm_config=algorithm_model,
data_config=data_model,
training_config=training_model,
)

# alternatively, use the factory method
config2 = configuration_factory( # (15)!
{
"experiment_name": experiment_name,
"algorithm_config": algorithm_model,
"data_config": data_model,
"training_config": training_model,
}
)
# --8<-- [end:pydantic]

if config != config2:
raise ValueError("Configurations are not equal (Pydantic).")

# %%
# --8<-- [start:as_dict]
from careamics.config import N2VConfiguration, configuration_factory
from careamics.config import Configuration

config_dict = {
"experiment_name": "N2V_example",
Expand All @@ -95,6 +82,7 @@
"num_channels_init": 64,
"depth": 3,
},
# (3)!
},
"data_config": {
"data_type": "array",
Expand All @@ -106,9 +94,6 @@
"name": "XYFlip",
"flip_y": False,
},
{
"name": "N2VManipulate",
},
],
"dataloader_params": {
"num_workers": 4,
Expand All @@ -117,22 +102,17 @@
"training_config": {
"num_epochs": 30,
"logger": "wandb",
"early_stopping_callback": {}, # (3)!
"early_stopping_callback": {}, # (4)!
"checkpoint_callback": {
"every_n_epochs": 10,
},
},
}

# instantiate specific configuration
config_as_dict = N2VConfiguration(**config_dict) # (4)!
config_as_dict = Configuration(**config_dict) # (5)!

# alternatively, use the factory method
config_as_dict2 = configuration_factory(config_dict) # (5)!
# --8<-- [end:as_dict]

if config_as_dict != config_as_dict2:
raise ValueError("Configurations are not equal (Dict).")

if config != config_as_dict:
raise ValueError("Configurations are not equal (Pydantic vs Dict).")
17 changes: 11 additions & 6 deletions guides/careamist_api/configuration/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ algorithm_config:
lr_scheduler:
name: ReduceLROnPlateau
parameters: {}
n2v_config:
name: N2VManipulate
roi_size: 11
masked_pixel_percentage: 0.2
remove_center: true
strategy: uniform
struct_mask_axis: none
struct_mask_span: 5
data_config:
data_type: tiff
axes: ZYX
Expand All @@ -35,12 +43,9 @@ data_config:
p: 0.5
- name: XYRandomRotate90
p: 0.5
- name: N2VManipulate
roi_size: 11
masked_pixel_percentage: 0.2
strategy: uniform
struct_mask_axis: none
struct_mask_span: 5
train_dataloader_params:
shuffle: true
val_dataloader_params: {}
training_config:
num_epochs: 20
precision: '32'
Expand Down
15 changes: 12 additions & 3 deletions guides/careamist_api/configuration/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,12 @@
patch_size=[16, 64, 64],
batch_size=8,
num_epochs=20,
dataloader_params={
train_dataloader_params={
"num_workers": 4, # (1)!
},
val_dataloader_params={
"num_workers": 2, # (2)!
},
)
# --8<-- [end:n2v_dataloader_kwargs]
# N2N with dataloader parameters
Expand All @@ -373,9 +376,12 @@
patch_size=[16, 64, 64],
batch_size=8,
num_epochs=20,
dataloader_params={
train_dataloader_params={
"num_workers": 4, # (1)!
},
val_dataloader_params={
"num_workers": 2, # (2)!
},
)
# --8<-- [end:care_dataloader_kwargs]
# N2N with dataloader parameters
Expand All @@ -387,8 +393,11 @@
patch_size=[16, 64, 64],
batch_size=8,
num_epochs=20,
dataloader_params={
train_dataloader_params={
"num_workers": 4, # (1)!
},
val_dataloader_params={
"num_workers": 2, # (2)!
},
)
# --8<-- [end:n2n_dataloader_kwargs]

0 comments on commit 78dcdef

Please sign in to comment.