diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py
index 762f666f..e59aaf39 100644
--- a/terratorch/tasks/base_task.py
+++ b/terratorch/tasks/base_task.py
@@ -1,20 +1,24 @@
+import logging
+import lightning
 from lightning.pytorch.callbacks import Callback
 from torchgeo.trainers import BaseTask

+from terratorch.models.model import Model
+from terratorch.tasks.optimizer_factory import optimizer_factory

 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

+logger = logging.getLogger("terratorch")

-class TerraTorchTask(BaseTask):
+class TerraTorchTask(BaseTask):
     """
-    Parent used to share common methods among all the
-    tasks implemented in terratorch
+    Parent used to share common methods among all the
+    tasks implemented in terratorch
     """

-    def __init__(self, task:str=None):
-
-        self.task = task
+    def __init__(self, task: str | None = None):
+        self.task = task

         super().__init__()

@@ -61,13 +65,14 @@ def configure_optimizers(
     def on_train_epoch_end(self) -> None:
         self.log_dict(self.train_metrics.compute(), sync_dist=True)
         self.train_metrics.reset()
-        return super().on_train_epoch_end()

     def on_validation_epoch_end(self) -> None:
         self.log_dict(self.val_metrics.compute(), sync_dist=True)
         self.val_metrics.reset()
-        return super().on_validation_epoch_end()

+    def on_test_epoch_end(self) -> None:
+        self.log_dict(self.test_metrics.compute(), sync_dist=True)
+        self.test_metrics.reset()

     def _do_plot_samples(self, batch_index):
         if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
@@ -81,5 +86,3 @@ def _do_plot_samples(self, batch_index):
             and hasattr(self.logger, "experiment")
             and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
         )
-
-
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index 3b1fb04f..48e80221 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -1,5 +1,4 @@
-
-from typing import Any
+from typing import Any
 from functools import partial
 import os
 import logging
@@ -13,7 +12,7 @@
 from torchmetrics import ClasswiseWrapper, MetricCollection
 from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassJaccardIndex

-from terratorch.models.model import AuxiliaryHead, Model, ModelOutput
+from terratorch.models.model import AuxiliaryHead, ModelOutput
 from terratorch.registry import MODEL_FACTORY_REGISTRY
 from terratorch.tasks.loss_handler import LossHandler
 from terratorch.tasks.optimizer_factory import optimizer_factory
@@ -22,7 +21,8 @@

 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 def to_segmentation_prediction(y: ModelOutput) -> Tensor:
     y_hat = y.output
@@ -133,7 +133,6 @@ def __init__(
         self.monitor = f"{self.val_metrics.prefix}loss"
         self.plot_on_val = int(plot_on_val)

-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.

@@ -262,8 +261,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         y_hat_hard = to_segmentation_prediction(model_output)
         self.test_metrics[dataloader_idx].update(y_hat_hard, y)

-    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+    def on_test_epoch_end(self) -> None:
+        for metrics in self.test_metrics:
+            self.log_dict(metrics.compute(), sync_dist=True)
+            metrics.reset()

+    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the validation loss and additional metrics.

         Args:
             batch: The output of your DataLoader.