Skip to content

Commit

Permalink
chore: Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Oct 20, 2024
1 parent 5cfc448 commit b9beb89
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 118 deletions.
3 changes: 1 addition & 2 deletions src/careamics_napari/signals/saving_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Union


class ExportType(Enum):
Expand Down Expand Up @@ -38,7 +37,7 @@ class SavingSignal:
elements.
"""

path_model: Union[str, Path] = ""
path_model: Path = Path(".")
"""Path in which to save the model."""

export_type: ExportType = ExportType.BMZ
Expand Down
127 changes: 80 additions & 47 deletions src/careamics_napari/workers/prediction_worker.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
"""A thread worker function running CAREamics prediction."""
from typing import Generator, Optional

from collections.abc import Generator
from queue import Queue
from threading import Thread

from napari.qt.threading import thread_worker
import napari.utils.notifications as ntf
from typing import Optional, Union

from careamics import CAREamist
from napari.qt.threading import thread_worker

from careamics_napari.signals import (
PredictionSignal,
PredictionState,
PredictionUpdate,
PredictionUpdateType,
PredictionState
)


# TODO register CAREamist to continue training and predict
# TODO register CAREamist to continue training and predict
# TODO how to load pre-trained?
# TODO pass careamist here if it already exists?
@thread_worker
Expand All @@ -25,22 +25,37 @@ def predict_worker(
config_signal: PredictionSignal,
update_queue: Queue,
) -> Generator[PredictionUpdate, None, None]:

"""Model prediction worker.
Parameters
----------
careamist : CAREamist
CAREamist instance.
config_signal : PredictionSignal
Prediction signal.
update_queue : Queue
Queue used to send updates to the UI.
Yields
------
Generator[PredictionUpdate, None, None]
Updates.
"""
# start training thread
training = Thread(
target=_predict,
target=_predict,
args=(
careamist,
config_signal,
config_signal,
update_queue,
)
),
)
training.start()

# look for updates
while True:
update: PredictionUpdate = update_queue.get(block=True)

yield update

if (
Expand All @@ -49,66 +64,84 @@ def predict_worker(
):
break


def _push_exception(queue: Queue, e: Exception) -> None:
"""Push an exception to the queue.
Parameters
----------
queue : Queue
Queue.
e : Exception
Exception.
"""
queue.put(PredictionUpdate(PredictionUpdateType.EXCEPTION, e))


def _predict(
careamist: CAREamist,
config_signal: PredictionSignal,
update_queue: Queue,
) -> None:

"""Run the prediction.
Parameters
----------
careamist : CAREamist
CAREamist instance.
config_signal : PredictionSignal
Prediction signal.
update_queue : Queue
Queue used to send updates to the UI.
"""
# Format data
if config_signal.load_from_disk:

if config_signal.path_pred == "":
_push_exception(
update_queue,
ValueError(
"Prediction data path is empty."
)
)
_push_exception(update_queue, ValueError("Prediction data path is empty."))
return

pred_data = config_signal.path_pred

else:
if config_signal.layer_pred is None:
_push_exception(
update_queue,
ValueError(
"Training data path is empty."
)
update_queue, ValueError("Prediction layer has not been selected.")
)

pred_data = config_signal.layer_pred.data
elif config_signal.layer_pred.data is None:
_push_exception(
update_queue,
ValueError(
f"Prediction layer {config_signal.layer_pred.name} is empty."
),
)
else:
pred_data = config_signal.layer_pred.data

# tiling
if config_signal.tiled:
if config_signal.is_3d:
tile_size = (
config_signal.tile_size_z,
config_signal.tile_size_xy,
config_signal.tile_size_xy
tile_size: Optional[Union[tuple[int, int, int], tuple[int, int]]] = (
config_signal.tile_size_z,
config_signal.tile_size_xy,
config_signal.tile_size_xy,
)
tile_overlap = (
config_signal.tile_overlap_z,
config_signal.tile_overlap_xy,
config_signal.tile_overlap_xy
tile_overlap: Optional[Union[tuple[int, int, int], tuple[int, int]]] = (
config_signal.tile_overlap_z,
config_signal.tile_overlap_xy,
config_signal.tile_overlap_xy,
)
else:
tile_size = (
config_signal.tile_size_xy,
config_signal.tile_size_xy
)
tile_size = (config_signal.tile_size_xy, config_signal.tile_size_xy)
tile_overlap = (
config_signal.tile_overlap_xy,
config_signal.tile_overlap_xy
config_signal.tile_overlap_xy,
config_signal.tile_overlap_xy,
)
else:
tile_size = None
tile_overlap = None

# Predict with CAREamist
try:
result = careamist.predict(
Expand All @@ -122,26 +155,26 @@ def _predict(

# # TODO can we use this to monkey patch the training process?
# import time
# update_queue.put(PredictionUpdate(PredictionUpdateType.MAX_SAMPLES, 1_000 // 10))
# update_queue.put(
# PredictionUpdate(PredictionUpdateType.MAX_SAMPLES, 1_000 // 10)
# )
# for i in range(1_000):

# # if stopper.stop:
# # update_queue.put(Update(UpdateType.STATE, TrainingState.STOPPED))
# # break

# if i % 10 == 0:
# update_queue.put(PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, i // 10))
# update_queue.put(
# PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, i // 10)
# )
# print(i)


# time.sleep(0.2)
# time.sleep(0.2)

except Exception as e:
update_queue.put(
PredictionUpdate(PredictionUpdateType.EXCEPTION, e)
)
update_queue.put(PredictionUpdate(PredictionUpdateType.EXCEPTION, e))
return


# signify end of prediction
update_queue.put(PredictionUpdate(PredictionUpdateType.STATE, PredictionState.DONE))
update_queue.put(PredictionUpdate(PredictionUpdateType.STATE, PredictionState.DONE))
40 changes: 28 additions & 12 deletions src/careamics_napari/workers/saving_worker.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
"""A thread worker function running CAREamics prediction."""
from typing import Generator, Optional
from queue import Queue
from threading import Thread

from napari.qt.threading import thread_worker
import napari.utils.notifications as ntf
from collections.abc import Generator

from careamics import CAREamist
from napari.qt.threading import thread_worker

from careamics_napari.signals import (
SavingState,
SavingStatus,
ExportType,
SavingSignal,
SavingState,
SavingUpdate,
SavingUpdateType,
ExportType,
TrainingSignal
TrainingSignal,
)


Expand All @@ -25,16 +21,36 @@ def save_worker(
training_signal: TrainingSignal,
config_signal: SavingSignal,
) -> Generator[SavingUpdate, None, None]:
"""Model saving worker.
Parameters
----------
careamist : CAREamist
CAREamist instance.
training_signal : TrainingSignal
Training signal.
config_signal : SavingSignal
Saving signal.
Yields
------
Generator[SavingUpdate, None, None]
Updates.
Raises
------
NotImplementedError
Export to BMZ not implemented yet.
"""
dims = "3D" if training_signal.is_3d else "2D"
name = f"{training_signal.algorithm}_{dims}_{training_signal.experiment_name}"

# save model
try:
if config_signal.export_type == ExportType.BMZ:

raise NotImplementedError("Export to BMZ not implemented yet.")
raise NotImplementedError("Export to BMZ not implemented yet (but soon).")

else:
name = name + ".ckpt"
# TODO: should we reexport the model every time?
Expand All @@ -45,4 +61,4 @@ def save_worker(
except Exception as e:
yield SavingUpdate(SavingUpdateType.EXCEPTION, e)

yield SavingUpdate(SavingUpdateType.STATE, SavingState.DONE)
yield SavingUpdate(SavingUpdateType.STATE, SavingState.DONE)
Loading

0 comments on commit b9beb89

Please sign in to comment.