Skip to content

Commit

Permalink
chore: Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Oct 20, 2024
1 parent b9beb89 commit cd5946a
Show file tree
Hide file tree
Showing 16 changed files with 187 additions and 82 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ repos:
additional_dependencies:
- numpy
- types-PyYAML
- careamics-stubs

# check docstrings
- repo: https://github.com/numpy/numpydoc
Expand Down
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
plugins = pydantic.mypy
7 changes: 4 additions & 3 deletions src/careamics_napari/careamics_utils/callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""PyTorch Lightning callback for updating a GUI with training and prediction progress.
"""
"""PyTorch Lightning callback used to update GUI with progress."""

from queue import Queue
from typing import Any
Expand Down Expand Up @@ -78,10 +77,12 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
PyTorch Lightning module.
"""
# compute the number of batches
len_dataloader = len(trainer.train_dataloader) # type: ignore

self.training_queue.put(
TrainUpdate(
TrainUpdateType.MAX_BATCH,
int(len(trainer.train_dataloader) / trainer.accumulate_grad_batches),
int(len_dataloader / trainer.accumulate_grad_batches),
)
)

Expand Down
10 changes: 6 additions & 4 deletions src/careamics_napari/careamics_utils/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility to create CAREamics configurations from user-set settings."""

from typing import Union

from careamics import Configuration
from careamics.config import (
create_care_configuration,
Expand Down Expand Up @@ -37,13 +39,13 @@ def create_configuration(signal: TrainingSignal) -> Configuration:
experiment_name = signal.experiment_name

if signal.is_3d:
patches: tuple[int, ...] = (
patches: list[int] = [
signal.patch_size_xy,
signal.patch_size_xy,
signal.patch_size_z,
)
]
else:
patches = (signal.patch_size_xy, signal.patch_size_xy)
patches = [signal.patch_size_xy, signal.patch_size_xy]

# model params
model_params = {
Expand All @@ -52,7 +54,7 @@ def create_configuration(signal: TrainingSignal) -> Configuration:
}

# augmentations
augs = []
augs: list[Union[XYFlipModel, XYRandomRotate90Model]] = []
if signal.x_flip or signal.y_flip:
augs.append(XYFlipModel(flip_x=signal.x_flip, flip_y=signal.y_flip))

Expand Down
6 changes: 5 additions & 1 deletion src/careamics_napari/careamics_utils/free_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ def free_memory(careamist: CAREamist) -> None:
careamist : CAREamist
CAREamics instance.
"""
if careamist is not None:
if (
careamist is not None
and careamist.trainer is not None
and careamist.trainer.model is not None
):
careamist.trainer.model.cpu()
del careamist.trainer.model
del careamist.trainer
Expand Down
6 changes: 4 additions & 2 deletions src/careamics_napari/signals/training_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

from careamics.utils import get_careamics_home
from psygnal import evented
Expand Down Expand Up @@ -35,6 +35,8 @@ class TrainingSignalGroup(SignalGroup):
else:
_has_napari = True

HOME = get_careamics_home()


# TODO make sure defaults are used
@evented
Expand Down Expand Up @@ -62,7 +64,7 @@ class TrainingSignal:
"""Whether the data is 3D."""

# parameters set by widgets for training
work_dir: Union[str, Path] = get_careamics_home()
work_dir: Path = HOME
"""Directory where the checkpoints and logs are saved."""

load_from_disk: bool = True
Expand Down
28 changes: 17 additions & 11 deletions src/careamics_napari/training_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

if TYPE_CHECKING:
import napari
from careamics import CAREamist

# at run time
try:
Expand Down Expand Up @@ -108,23 +109,23 @@ def __init__(
"""
super().__init__()
self.viewer = napari_viewer
self.careamist = None
self.careamist: Optional[CAREamist] = None

# create statuses, used to keep track of the threads statuses
self.train_status = TrainingStatus()
self.pred_status = PredictionStatus()
self.save_status = SavingStatus()
self.train_status = TrainingStatus() # type: ignore
self.pred_status = PredictionStatus() # type: ignore
self.save_status = SavingStatus() # type: ignore

# create signals, used to hold the various parameters modified by the UI
self.train_config_signal = TrainingSignal()
self.train_config_signal = TrainingSignal() # type: ignore
self.pred_config_signal = PredictionSignal()
self.save_config_signal = SavingSignal()

self.train_config_signal.events.is_3d.connect(self._set_pred_3d)

# create queues, used to communicate between the threads and the UI
self._training_queue = Queue(10)
self._prediction_queue = Queue(10)
self._training_queue: Queue = Queue(10)
self._prediction_queue: Queue = Queue(10)

# set workdir
self.train_config_signal.work_dir = Path.cwd()
Expand Down Expand Up @@ -248,7 +249,8 @@ def _training_state_changed(self, state: TrainingState) -> None:
self.train_worker.start()

elif state == TrainingState.STOPPED:
self.careamist.stop_training()
if self.careamist is not None:
self.careamist.stop_training()

elif state == TrainingState.CRASHED or state == TrainingState.IDLE:
del self.careamist
Expand Down Expand Up @@ -302,12 +304,15 @@ def _update_from_training(self, update: TrainUpdate) -> None:
Update.
"""
if update.type == TrainUpdateType.CAREAMIST:
self.careamist = update.value
if isinstance(update.value, CAREamist):
self.careamist = update.value
elif update.type == TrainUpdateType.DEBUG:
print(update.value)
elif update.type == TrainUpdateType.EXCEPTION:
self.train_status.state = TrainingState.CRASHED
raise update.value

if isinstance(update.value, Exception):
raise update.value
else:
self.train_status.update(update)

Expand Down Expand Up @@ -340,7 +345,8 @@ def _update_from_prediction(self, update: PredictionUpdate) -> None:
if update.type == PredictionUpdateType.SAMPLE:
# add image to napari
# TODO keep scaling?
self.viewer.add_image(update.value, name="Prediction")
if self.viewer is not None:
self.viewer.add_image(update.value, name="Prediction")
else:
self.pred_status.update(update)

Expand Down
10 changes: 8 additions & 2 deletions src/careamics_napari/widgets/algorithm_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@


class AlgorithmSelectionWidget(QComboBox):
"""Algorithm selection widget."""
"""Algorithm selection widget.
Parameters
----------
training_signal : TrainingSignal or None, default=None
Training signal holding all parameters to be set by the user.
"""

def __init__(self, training_signal: Optional[TrainingSignal] = None) -> None:
"""Initialize the widget.
Expand Down Expand Up @@ -61,7 +67,7 @@ def algorithm_changed(self, index: int) -> None:

from qtpy.QtWidgets import QApplication

myalgo = TrainingSignal() # typing: ignore
myalgo = TrainingSignal() # type: ignore

@myalgo.events.name.connect
def print_algorithm(name: str):
Expand Down
37 changes: 32 additions & 5 deletions src/careamics_napari/widgets/axes_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@ class Highlight(Enum):


class LettersValidator(QtGui.QValidator):
"""Custom validator."""
"""Custom validator.
Parameters
----------
options : str
Allowed characters.
*args : Any
Variable length argument list.
**kwargs : Any
Arbitrary keyword arguments.
"""

def __init__(self: Self, options: str, *args: Any, **kwargs: Any) -> None:
"""Initialize the validator.
Expand Down Expand Up @@ -53,6 +63,11 @@ def validate(
Input value.
pos : int
Position of the cursor.
Returns
-------
(QtGui.QValidator.State, str, int)
Validation state, value, and position.
"""
if len(value) > 0:
if value[-1] in self._options:
Expand All @@ -66,8 +81,19 @@ def validate(
# TODO keep the validation?
# TODO is train layer selected, then show the orange and red, otherwise ignore?
class AxesWidget(QWidget):
"""A widget allowing users to specify axes."""

"""A widget allowing users to specify axes.
Parameters
----------
n_axes : int, default=3
Number of axes.
is_3D : bool, default=False
Whether the data is 3D.
training_signal : TrainingSignal or None, default=None
Signal holding all training parameters to be set by the user.
"""

# TODO unused parameters
def __init__(
self, n_axes=3, is_3D=False, training_signal: Optional[TrainingSignal] = None
) -> None:
Expand All @@ -79,7 +105,7 @@ def __init__(
Number of axes.
is_3D : bool, default=False
Whether the data is 3D.
signal : TrainingSignal or None, default=None
training_signal : TrainingSignal or None, default=None
Signal holding all training parameters to be set by the user.
"""
super().__init__()
Expand Down Expand Up @@ -227,10 +253,11 @@ def set_text_field(self: Self, text: str) -> None:
app = QApplication(sys.argv)

# Signals
myalgo = TrainingSignal()
myalgo = TrainingSignal() # type: ignore

@myalgo.events.use_channels.connect
def print_axes():
"""Print axes."""
print(f"Use channels: {myalgo.use_channels}")

# Instantiate widget
Expand Down
12 changes: 8 additions & 4 deletions src/careamics_napari/widgets/configuration_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def __init__(
"""
super().__init__(parent)

self.configuration_signal = training_signal
self.configuration_signal = (
TrainingSignal() # type: ignore
if training_signal is None
else training_signal
)

self.setLayout(QVBoxLayout())

Expand Down Expand Up @@ -181,7 +185,7 @@ def __init__(
self.model_depth = create_int_spinbox(2, 5, self.configuration_signal.depth, 1)
self.model_depth.setToolTip("Depth of the U-Net model.")
self.size_conv_filters = create_int_spinbox(
8, 1024, self.configuration_signal.num_conv_filters, 8, 8
8, 1024, self.configuration_signal.num_conv_filters, 8
)
self.size_conv_filters.setToolTip(
"Number of convolutional filters in the first layer."
Expand Down Expand Up @@ -274,10 +278,10 @@ def _save(self: Self) -> None:
app = QApplication(sys.argv)

# Signals
myalgo = TrainingSignal(use_channels=False)
myalgo = TrainingSignal(use_channels=False) # type: ignore

# Instantiate widget
widget = AdvancedConfigurationWindow(training_signal=myalgo)
widget = AdvancedConfigurationWindow(training_signal=myalgo) # type: ignore

# Show the widget
widget.show()
Expand Down
23 changes: 18 additions & 5 deletions src/careamics_napari/widgets/predict_data_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@


class PredictDataWidget(QTabWidget):
"""A widget offering to select a layer from napari or a path from disk."""
"""A widget offering to select a layer from napari or a path from disk.
Parameters
----------
prediction_signal : PredConfigurationSignal, default=None
Signal to be updated with changed in widgets values.
"""

def __init__(
self: Self,
Expand All @@ -38,11 +44,16 @@ def __init__(
Parameters
----------
signal : PredConfigurationSignal, default=None
prediction_signal : PredConfigurationSignal, default=None
Signal to be updated with changed in widgets values.
"""
super().__init__()
self.config_signal = prediction_signal

self.config_signal = (
PredictionSignal() # type: ignore
if prediction_signal is None
else prediction_signal
)

# QTabs
layer_tab = QWidget()
Expand Down Expand Up @@ -138,7 +149,8 @@ def _update_pred_layer(self: Self, layer: Image) -> None:
layer : Image
The selected layer.
"""
self.config_signal.layer_pred = layer
if self.config_signal.layer_pred is not None:
self.config_signal.layer_pred = layer

def _update_pred_folder(self: Self, folder: str) -> None:
"""Update the path attribute of the signal.
Expand All @@ -148,7 +160,8 @@ def _update_pred_folder(self: Self, folder: str) -> None:
folder : str
The selected folder.
"""
self.config_signal.path_pred = folder
if self.config_signal.path_pred is not None:
self.config_signal.path_pred = folder


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit cd5946a

Please sign in to comment.