diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf62813..506ce74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,7 @@ repos: additional_dependencies: - numpy - types-PyYAML + - careamics-stubs # check docstrings - repo: https://github.com/numpy/numpydoc diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..ff37064 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +plugins = pydantic.mypy diff --git a/src/careamics_napari/careamics_utils/callback.py b/src/careamics_napari/careamics_utils/callback.py index fbf49e4..d23247a 100644 --- a/src/careamics_napari/careamics_utils/callback.py +++ b/src/careamics_napari/careamics_utils/callback.py @@ -1,5 +1,4 @@ -"""PyTorch Lightning callback for updating a GUI with training and prediction progress. -""" +"""PyTorch Lightning callback used to update GUI with progress.""" from queue import Queue from typing import Any @@ -78,10 +77,12 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: PyTorch Lightning module. """ # compute the number of batches + len_dataloader = len(trainer.train_dataloader) # type: ignore + self.training_queue.put( TrainUpdate( TrainUpdateType.MAX_BATCH, - int(len(trainer.train_dataloader) / trainer.accumulate_grad_batches), + int(len_dataloader / trainer.accumulate_grad_batches), ) ) diff --git a/src/careamics_napari/careamics_utils/configuration.py b/src/careamics_napari/careamics_utils/configuration.py index 8aba78d..bac4736 100644 --- a/src/careamics_napari/careamics_utils/configuration.py +++ b/src/careamics_napari/careamics_utils/configuration.py @@ -1,5 +1,7 @@ """Utility to create CAREamics configurations from user-set settings.""" +from typing import Union + from careamics import Configuration from careamics.config import ( create_care_configuration, @@ -37,13 +39,13 @@ def create_configuration(signal: TrainingSignal) -> Configuration: experiment_name = signal.experiment_name if signal.is_3d: - patches: tuple[int, ...] = ( + patches: list[int] = [ signal.patch_size_xy, signal.patch_size_xy, signal.patch_size_z, - ) + ] else: - patches = (signal.patch_size_xy, signal.patch_size_xy) + patches = [signal.patch_size_xy, signal.patch_size_xy] # model params model_params = { @@ -52,7 +54,7 @@ def create_configuration(signal: TrainingSignal) -> Configuration: } # augmentations - augs = [] + augs: list[Union[XYFlipModel, XYRandomRotate90Model]] = [] if signal.x_flip or signal.y_flip: augs.append(XYFlipModel(flip_x=signal.x_flip, flip_y=signal.y_flip)) diff --git a/src/careamics_napari/careamics_utils/free_memory.py b/src/careamics_napari/careamics_utils/free_memory.py index de3464b..093c17c 100644 --- a/src/careamics_napari/careamics_utils/free_memory.py +++ b/src/careamics_napari/careamics_utils/free_memory.py @@ -14,7 +14,11 @@ def free_memory(careamist: CAREamist) -> None: careamist : CAREamist CAREamics instance. """ - if careamist is not None: + if ( + careamist is not None + and careamist.trainer is not None + and careamist.trainer.model is not None + ): careamist.trainer.model.cpu() del careamist.trainer.model del careamist.trainer diff --git a/src/careamics_napari/signals/training_signal.py b/src/careamics_napari/signals/training_signal.py index 6052acb..3423703 100644 --- a/src/careamics_napari/signals/training_signal.py +++ b/src/careamics_napari/signals/training_signal.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from careamics.utils import get_careamics_home from psygnal import evented @@ -35,6 +35,8 @@ class TrainingSignalGroup(SignalGroup): else: _has_napari = True +HOME = get_careamics_home() + # TODO make sure defaults are used @evented @@ -62,7 +64,7 @@ class TrainingSignal: """Whether the data is 3D.""" # parameters set by widgets for training - work_dir: Union[str, Path] = get_careamics_home() + work_dir: Path = HOME """Directory where the checkpoints and logs are saved.""" load_from_disk: bool = True diff --git a/src/careamics_napari/training_plugin.py b/src/careamics_napari/training_plugin.py index a819873..ba64ddc 100644 --- a/src/careamics_napari/training_plugin.py +++ b/src/careamics_napari/training_plugin.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: import napari + from careamics import CAREamist # at run time try: @@ -108,23 +109,23 @@ def __init__( """ super().__init__() self.viewer = napari_viewer - self.careamist = None + self.careamist: Optional[CAREamist] = None # create statuses, used to keep track of the threads statuses - self.train_status = TrainingStatus() - self.pred_status = PredictionStatus() - self.save_status = SavingStatus() + self.train_status = TrainingStatus() # type: ignore + self.pred_status = PredictionStatus() # type: ignore + self.save_status = SavingStatus() # type: ignore # create signals, used to hold the various parameters modified by the UI - self.train_config_signal = TrainingSignal() + self.train_config_signal = TrainingSignal() # type: ignore self.pred_config_signal = PredictionSignal() self.save_config_signal = SavingSignal() self.train_config_signal.events.is_3d.connect(self._set_pred_3d) # create queues, used to communicate between the threads and the UI - self._training_queue = Queue(10) - self._prediction_queue = Queue(10) + self._training_queue: Queue = Queue(10) + self._prediction_queue: Queue = Queue(10) # set workdir self.train_config_signal.work_dir = Path.cwd() @@ -248,7 +249,8 @@ def _training_state_changed(self, state: TrainingState) -> None: self.train_worker.start() elif state == TrainingState.STOPPED: - self.careamist.stop_training() + if self.careamist is not None: + self.careamist.stop_training() elif state == TrainingState.CRASHED or state == TrainingState.IDLE: del self.careamist @@ -302,12 +304,15 @@ def _update_from_training(self, update: TrainUpdate) -> None: Update. """ if update.type == TrainUpdateType.CAREAMIST: - self.careamist = update.value + if isinstance(update.value, CAREamist): + self.careamist = update.value elif update.type == TrainUpdateType.DEBUG: print(update.value) elif update.type == TrainUpdateType.EXCEPTION: self.train_status.state = TrainingState.CRASHED - raise update.value + + if isinstance(update.value, Exception): + raise update.value else: self.train_status.update(update) @@ -340,7 +345,8 @@ def _update_from_prediction(self, update: PredictionUpdate) -> None: if update.type == PredictionUpdateType.SAMPLE: # add image to napari # TODO keep scaling? - self.viewer.add_image(update.value, name="Prediction") + if self.viewer is not None: + self.viewer.add_image(update.value, name="Prediction") else: self.pred_status.update(update) diff --git a/src/careamics_napari/widgets/algorithm_choice.py b/src/careamics_napari/widgets/algorithm_choice.py index ad1097d..8d1a9a2 100644 --- a/src/careamics_napari/widgets/algorithm_choice.py +++ b/src/careamics_napari/widgets/algorithm_choice.py @@ -9,7 +9,13 @@ class AlgorithmSelectionWidget(QComboBox): - """Algorithm selection widget.""" + """Algorithm selection widget. + + Parameters + ---------- + training_signal : TrainingSignal or None, default=None + Training signal holding all parameters to be set by the user. + """ def __init__(self, training_signal: Optional[TrainingSignal] = None) -> None: """Initialize the widget. @@ -61,7 +67,7 @@ def algorithm_changed(self, index: int) -> None: from qtpy.QtWidgets import QApplication - myalgo = TrainingSignal() # typing: ignore + myalgo = TrainingSignal() # type: ignore @myalgo.events.name.connect def print_algorithm(name: str): diff --git a/src/careamics_napari/widgets/axes_widget.py b/src/careamics_napari/widgets/axes_widget.py index 3aaf320..ed11995 100644 --- a/src/careamics_napari/widgets/axes_widget.py +++ b/src/careamics_napari/widgets/axes_widget.py @@ -25,7 +25,17 @@ class Highlight(Enum): class LettersValidator(QtGui.QValidator): - """Custom validator.""" + """Custom validator. + + Parameters + ---------- + options : str + Allowed characters. + *args : Any + Variable length argument list. + **kwargs : Any + Arbitrary keyword arguments. + """ def __init__(self: Self, options: str, *args: Any, **kwargs: Any) -> None: """Initialize the validator. @@ -53,6 +63,11 @@ def validate( Input value. pos : int Position of the cursor. + + Returns + ------- + (QtGui.QValidator.State, str, int) + Validation state, value, and position. """ if len(value) > 0: if value[-1] in self._options: @@ -66,8 +81,19 @@ def validate( # TODO keep the validation? # TODO is train layer selected, then show the orange and red, otherwise ignore? class AxesWidget(QWidget): - """A widget allowing users to specify axes.""" - + """A widget allowing users to specify axes. + + Parameters + ---------- + n_axes : int, default=3 + Number of axes. + is_3D : bool, default=False + Whether the data is 3D. + training_signal : TrainingSignal or None, default=None + Signal holding all training parameters to be set by the user. + """ + + # TODO unused parameters def __init__( self, n_axes=3, is_3D=False, training_signal: Optional[TrainingSignal] = None ) -> None: @@ -79,7 +105,7 @@ def __init__( Number of axes. is_3D : bool, default=False Whether the data is 3D. - signal : TrainingSignal or None, default=None + training_signal : TrainingSignal or None, default=None Signal holding all training parameters to be set by the user. """ super().__init__() @@ -227,10 +253,11 @@ def set_text_field(self: Self, text: str) -> None: app = QApplication(sys.argv) # Signals - myalgo = TrainingSignal() + myalgo = TrainingSignal() # type: ignore @myalgo.events.use_channels.connect def print_axes(): + """Print axes.""" print(f"Use channels: {myalgo.use_channels}") # Instantiate widget diff --git a/src/careamics_napari/widgets/configuration_window.py b/src/careamics_napari/widgets/configuration_window.py index 05c5b85..2a8b7c4 100644 --- a/src/careamics_napari/widgets/configuration_window.py +++ b/src/careamics_napari/widgets/configuration_window.py @@ -55,7 +55,11 @@ def __init__( """ super().__init__(parent) - self.configuration_signal = training_signal + self.configuration_signal = ( + TrainingSignal() # type: ignore + if training_signal is None + else training_signal + ) self.setLayout(QVBoxLayout()) @@ -181,7 +185,7 @@ def __init__( self.model_depth = create_int_spinbox(2, 5, self.configuration_signal.depth, 1) self.model_depth.setToolTip("Depth of the U-Net model.") self.size_conv_filters = create_int_spinbox( - 8, 1024, self.configuration_signal.num_conv_filters, 8, 8 + 8, 1024, self.configuration_signal.num_conv_filters, 8 ) self.size_conv_filters.setToolTip( "Number of convolutional filters in the first layer." @@ -274,10 +278,10 @@ def _save(self: Self) -> None: app = QApplication(sys.argv) # Signals - myalgo = TrainingSignal(use_channels=False) + myalgo = TrainingSignal(use_channels=False) # type: ignore # Instantiate widget - widget = AdvancedConfigurationWindow(training_signal=myalgo) + widget = AdvancedConfigurationWindow(training_signal=myalgo) # type: ignore # Show the widget widget.show() diff --git a/src/careamics_napari/widgets/predict_data_widget.py b/src/careamics_napari/widgets/predict_data_widget.py index 4b7b71a..3bedaf8 100644 --- a/src/careamics_napari/widgets/predict_data_widget.py +++ b/src/careamics_napari/widgets/predict_data_widget.py @@ -28,7 +28,13 @@ class PredictDataWidget(QTabWidget): - """A widget offering to select a layer from napari or a path from disk.""" + """A widget offering to select a layer from napari or a path from disk. + + Parameters + ---------- + prediction_signal : PredConfigurationSignal, default=None + Signal to be updated with changed in widgets values. + """ def __init__( self: Self, @@ -38,11 +44,16 @@ def __init__( Parameters ---------- - signal : PredConfigurationSignal, default=None + prediction_signal : PredConfigurationSignal, default=None Signal to be updated with changed in widgets values. """ super().__init__() - self.config_signal = prediction_signal + + self.config_signal = ( + PredictionSignal() # type: ignore + if prediction_signal is None + else prediction_signal + ) # QTabs layer_tab = QWidget() @@ -138,7 +149,8 @@ def _update_pred_layer(self: Self, layer: Image) -> None: layer : Image The selected layer. """ - self.config_signal.layer_pred = layer + if self.config_signal.layer_pred is not None: + self.config_signal.layer_pred = layer def _update_pred_folder(self: Self, folder: str) -> None: """Update the path attribute of the signal. @@ -148,7 +160,8 @@ def _update_pred_folder(self: Self, folder: str) -> None: folder : str The selected folder. """ - self.config_signal.path_pred = folder + if self.config_signal.path_pred is not None: + self.config_signal.path_pred = folder if __name__ == "__main__": diff --git a/src/careamics_napari/widgets/prediction_widget.py b/src/careamics_napari/widgets/prediction_widget.py index 88ca8dc..17c0e5f 100644 --- a/src/careamics_napari/widgets/prediction_widget.py +++ b/src/careamics_napari/widgets/prediction_widget.py @@ -39,9 +39,9 @@ class PredictionWidget(QGroupBox): The training status signal. pred_status : PredictionStatus or None, default=None The prediction status signal. - train_config_signal : TrainingSignal or None, default=None + train_signal : TrainingSignal or None, default=None The training configuration signal. - pred_config_signal : PredictionSignal or None, default=None + pred_signal : PredictionSignal or None, default=None The prediction configuration signal. """ @@ -49,8 +49,8 @@ def __init__( self: Self, train_status: Optional[TrainingStatus] = None, pred_status: Optional[PredictionStatus] = None, - train_config_signal: Optional[TrainingSignal] = None, - pred_config_signal: Optional[PredictionSignal] = None, + train_signal: Optional[TrainingSignal] = None, + pred_signal: Optional[PredictionSignal] = None, ) -> None: """Initialize the widget. @@ -60,24 +60,30 @@ def __init__( The training status signal. pred_status : PredictionStatus or None, default=None The prediction status signal. - train_config_signal : TrainingSignal or None, default=None + train_signal : TrainingSignal or None, default=None The training configuration signal. - pred_config_signal : PredictionSignal or None, default=None + pred_signal : PredictionSignal or None, default=None The prediction configuration signal. """ super().__init__() - self.train_status = train_status - self.pred_status = pred_status - self.train_config_signal = train_config_signal - self.pred_config_signal = pred_config_signal + self.train_status = ( + TrainingStatus() if train_status is None else train_status # type: ignore + ) + self.pred_status = ( + PredictionStatus() if pred_status is None else pred_status # type: ignore + ) + self.train_signal = ( + TrainingSignal() if train_signal is None else train_signal # type: ignore + ) + self.pred_signal = PredictionSignal() if pred_signal is None else pred_signal self.setTitle("Prediction") self.setLayout(QVBoxLayout()) self.layout().setContentsMargins(20, 20, 20, 0) # data selection - predict_data_widget = PredictDataWidget(self.pred_config_signal) + predict_data_widget = PredictDataWidget(self.pred_signal) self.layout().addWidget(predict_data_widget) # checkbox @@ -89,13 +95,11 @@ def __init__( self.layout().addWidget(self.tiling_cbox) # tiling spinboxes - self.tile_size_xy = PowerOfTwoSpinBox( - 64, 1024, self.pred_config_signal.tile_size_xy - ) + self.tile_size_xy = PowerOfTwoSpinBox(64, 1024, self.pred_signal.tile_size_xy) self.tile_size_xy.setToolTip("Tile size in the xy dimension.") self.tile_size_xy.setEnabled(False) - self.tile_size_z = PowerOfTwoSpinBox(4, 32, self.pred_config_signal.tile_size_z) + self.tile_size_z = PowerOfTwoSpinBox(4, 32, self.pred_signal.tile_size_z) self.tile_size_z.setToolTip("Tile size in the z dimension.") self.tile_size_z.setEnabled(False) @@ -137,7 +141,7 @@ def __init__( self.tile_size_z.valueChanged.connect(self._set_z_tile_size) # listening to the signals - self.train_config_signal.events.is_3d.connect(self._set_3d) + self.train_signal.events.is_3d.connect(self._set_3d) self.train_status.events.state.connect(self._update_button_from_train) self.pred_status.events.state.connect(self._update_button_from_pred) @@ -152,8 +156,8 @@ def _set_xy_tile_size(self: Self, size: int) -> None: size : int The new tile size in the xy dimension. """ - if self.pred_config_signal is not None: - self.pred_config_signal.tile_size_xy = size + if self.pred_signal is not None: + self.pred_signal.tile_size_xy = size def _set_z_tile_size(self: Self, size: int) -> None: """Update the signal tile size in the z dimension. @@ -163,8 +167,8 @@ def _set_z_tile_size(self: Self, size: int) -> None: size : int The new tile size in the z dimension. """ - if self.pred_config_signal is not None: - self.pred_config_signal.tile_size_z = size + if self.pred_signal is not None: + self.pred_signal.tile_size_z = size def _set_3d(self: Self, state: bool) -> None: """Enable the z tile size spinbox if the data is 3D. @@ -174,7 +178,7 @@ def _set_3d(self: Self, state: bool) -> None: state : bool The new state of the 3D checkbox. """ - if self.pred_config_signal.tiled: + if self.pred_signal.tiled: self.tile_size_z.setEnabled(state) def _update_tiles(self: Self, state: bool) -> None: @@ -185,10 +189,10 @@ def _update_tiles(self: Self, state: bool) -> None: state : bool The new state of the tiling checkbox. """ - self.pred_config_signal.tiled = state + self.pred_signal.tiled = state self.tile_size_xy.setEnabled(state) - if self.train_config_signal.is_3d: + if self.train_signal.is_3d: self.tile_size_z.setEnabled(state) def _update_3d_tiles(self: Self, state: bool) -> None: @@ -199,7 +203,7 @@ def _update_3d_tiles(self: Self, state: bool) -> None: state : bool The new state of the 3D checkbox. """ - if self.pred_config_signal.tiled: + if self.pred_signal.tiled: self.tile_size_z.setEnabled(state) def _update_max_sample(self: Self, max_sample: int) -> None: @@ -270,12 +274,13 @@ def _update_button_from_pred(self: Self, state: PredictionState) -> None: app = QApplication(sys.argv) # create signal - train_signal = TrainingStatus() - pred_signal = PredictionStatus() - config_signal = PredictionSignal() + train_status = TrainingStatus() # type: ignore + pred_status = PredictionStatus() # type: ignore + pred_signal = PredictionSignal() # type: ignore + train_signal = TrainingSignal() # type: ignore # Instantiate widget - widget = PredictionWidget(train_signal, pred_signal, config_signal) + widget = PredictionWidget(train_status, pred_status, train_signal, pred_signal) # Show the widget widget.show() diff --git a/src/careamics_napari/widgets/saving_widget.py b/src/careamics_napari/widgets/saving_widget.py index 5edfb2e..9f5a5d3 100644 --- a/src/careamics_napari/widgets/saving_widget.py +++ b/src/careamics_napari/widgets/saving_widget.py @@ -25,7 +25,17 @@ class SavingWidget(QGroupBox): - """A widget allowing users to select a model type and a path.""" + """A widget allowing users to select a model type and a path. + + Parameters + ---------- + train_status : TrainingStatus or None, default=None + Signal containing training parameters. + save_status : SavingStatus or None, default=None + Signal containing saving parameters. + save_signal : SavingSignal or None, default=None + Signal to trigger saving. + """ def __init__( self: Self, @@ -106,7 +116,7 @@ def _update_training_state(self: Self, state: TrainingState) -> None: def _save_model(self: Self) -> None: """Prompt users with a path selection dialog and update the saving state.""" if self.save_status is not None: - if ( + if self.save_signal is not None and ( self.save_status.state == SavingState.IDLE or self.save_status.state == SavingState.DONE or self.save_status.state == SavingState.CRASHED diff --git a/src/careamics_napari/widgets/tbplot_widget.py b/src/careamics_napari/widgets/tbplot_widget.py index 018a21c..409db65 100644 --- a/src/careamics_napari/widgets/tbplot_widget.py +++ b/src/careamics_napari/widgets/tbplot_widget.py @@ -16,7 +16,21 @@ # TODO why is it a magicgui container and not just a widget? class TBPlotWidget(Container): - """A widget displaying losses and a button to open TensorBoard in the browser.""" + """A widget displaying losses and a button to open TensorBoard in the browser. + + Parameters + ---------- + min_width : int or None, default=None + Minimum width of the widget. + min_height : int or None, default=None + Minimum height of the widget. + max_width : int or None, default=None + Maximum width of the widget. + max_height : int or None, default=None + Maximum height of the widget. + train_signal : TrainingSignal or None, default=None + Signal containing training parameters. + """ # TODO what is this method used for? def __setitem__(self: Self, key: Any, value: Any) -> None: @@ -94,10 +108,10 @@ def __init__( self.native.layout().addWidget(button_widget) # set empty references - self.epochs = [] - self.train_loss = [] - self.val_loss = [] - self.url = None + self.epochs: list[int] = [] + self.train_loss: list[float] = [] + self.val_loss: list[float] = [] + self.url: Optional[str] = None self.tb = None def stop_tb(self: Self) -> None: @@ -111,18 +125,20 @@ def stop_tb(self: Self) -> None: def open_tb(self: Self) -> None: """Open TensorBoard in the browser.""" - if not self.tb: + if self.tb is not None and self.train_signal is not None: from tensorboard import program self.tb = program.TensorBoard() path = str(self.train_signal.work_dir / "logs" / "lightning_logs") - self.tb.configure(argv=[None, "--logdir", path]) - self.url = self.tb.launch() + self.tb.configure(argv=[None, "--logdir", path]) # type: ignore + self.url = self.tb.launch() # type: ignore - webbrowser.open(self.url) + if self.url is not None: + webbrowser.open(self.url) else: - webbrowser.open(self.url) + if self.url is not None: + webbrowser.open(self.url) def update_plot(self: Self, epoch: int, train_loss: float, val_loss: float) -> None: """Update the plot with new data. diff --git a/src/careamics_napari/widgets/training_configuration_widget.py b/src/careamics_napari/widgets/training_configuration_widget.py index 9df7a9d..17eb0eb 100644 --- a/src/careamics_napari/widgets/training_configuration_widget.py +++ b/src/careamics_napari/widgets/training_configuration_widget.py @@ -24,20 +24,26 @@ class ConfigurationWidget(QGroupBox): - """A widget allowing the creation of a CAREamics configuration.""" + """A widget allowing the creation of a CAREamics configuration. + + Parameters + ---------- + training_signal : TrainingSignal or None, default=None + Signal containing the training parameters. + """ def __init__(self: Self, training_signal: Optional[TrainingSignal] = None) -> None: """Initialize the widget. Parameters ---------- - signal : TrainingSignal or None, default=None + training_signal : TrainingSignal or None, default=None Signal containing the training parameters. """ super().__init__() self.configuration_signal = training_signal - self.config_window = None + self.config_window: Optional[AdvancedConfigurationWindow] = None self.setTitle("Training parameters") # self.setMinimumWidth(100) diff --git a/src/careamics_napari/workers/prediction_worker.py b/src/careamics_napari/workers/prediction_worker.py index ece5b06..be0d255 100644 --- a/src/careamics_napari/workers/prediction_worker.py +++ b/src/careamics_napari/workers/prediction_worker.py @@ -144,8 +144,8 @@ def _predict( # Predict with CAREamist try: - result = careamist.predict( - source=pred_data, + result = careamist.predict( # type: ignore + pred_data, data_type="tiff" if config_signal.load_from_disk else "array", tile_size=tile_size, tile_overlap=tile_overlap,