Skip to content

Commit

Permalink
Merge pull request #319 from fmartiescofet/fix_base_task
Browse files Browse the repository at this point in the history
Fix base task `on_test_epoch_end`
  • Loading branch information
Joao-L-S-Almeida authored Dec 11, 2024
2 parents 16e5af9 + 3fb54f7 commit 4a0afc2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
23 changes: 13 additions & 10 deletions terratorch/tasks/base_task.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import logging

import lightning
from lightning.pytorch.callbacks import Callback
from torchgeo.trainers import BaseTask

from terratorch.models.model import Model
from terratorch.tasks.optimizer_factory import optimizer_factory

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
logger = logging.getLogger("terratorch")

class TerraTorchTask(BaseTask):

class TerraTorchTask(BaseTask):
"""
Parent used to share common methods among all the
tasks implemented in terratorch
Parent used to share common methods among all the
tasks implemented in terratorch
"""

def __init__(self, task:str=None):

self.task = task
def __init__(self, task: str | None = None):
self.task = task

super().__init__()

Expand Down Expand Up @@ -61,13 +65,14 @@ def configure_optimizers(
def on_train_epoch_end(self) -> None:
self.log_dict(self.train_metrics.compute(), sync_dist=True)
self.train_metrics.reset()
return super().on_train_epoch_end()

def on_validation_epoch_end(self) -> None:
self.log_dict(self.val_metrics.compute(), sync_dist=True)
self.val_metrics.reset()
return super().on_validation_epoch_end()

def on_test_epoch_end(self) -> None:
self.log_dict(self.test_metrics.compute(), sync_dist=True)
self.test_metrics.reset()

def _do_plot_samples(self, batch_index):
if not self.plot_on_val: # dont plot if self.plot_on_val is 0
Expand All @@ -81,5 +86,3 @@ def _do_plot_samples(self, batch_index):
and hasattr(self.logger, "experiment")
and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
)


15 changes: 9 additions & 6 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@

from typing import Any
from typing import Any
from functools import partial
import os
import logging
Expand All @@ -13,7 +12,7 @@
from torchmetrics import ClasswiseWrapper, MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassJaccardIndex

from terratorch.models.model import AuxiliaryHead, Model, ModelOutput
from terratorch.models.model import AuxiliaryHead, ModelOutput
from terratorch.registry import MODEL_FACTORY_REGISTRY
from terratorch.tasks.loss_handler import LossHandler
from terratorch.tasks.optimizer_factory import optimizer_factory
Expand All @@ -22,7 +21,8 @@

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger('terratorch')
logger = logging.getLogger("terratorch")


def to_segmentation_prediction(y: ModelOutput) -> Tensor:
y_hat = y.output
Expand Down Expand Up @@ -133,7 +133,6 @@ def __init__(
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)


def configure_losses(self) -> None:
"""Initialize the loss criterion.
Expand Down Expand Up @@ -262,8 +261,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
y_hat_hard = to_segmentation_prediction(model_output)
self.test_metrics[dataloader_idx].update(y_hat_hard, y)

def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
def on_test_epoch_end(self) -> None:
for metrics in self.test_metrics:
self.log_dict(metrics.compute(), sync_dist=True)
metrics.reset()

def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""Compute the validation loss and additional metrics.
Args:
batch: The output of your DataLoader.
Expand Down

0 comments on commit 4a0afc2

Please sign in to comment.