diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py
index 6a79498e..0437556f 100644
--- a/terratorch/tasks/base_task.py
+++ b/terratorch/tasks/base_task.py
@@ -1,15 +1,25 @@
+from lightning.pytorch.callbacks import Callback
 from torchgeo.trainers import BaseTask
 
 
 class TerraTorchTask(BaseTask):
+    """
+    Parent class used to share common methods among all the
+    tasks implemented in terratorch.
+    """
+
     def __init__(self, task:str=None):
         self.task = task
 
         super().__init__()
 
+    # overwrite early stopping
+    def configure_callbacks(self) -> list[Callback]:
+        return []
+
     def configure_models(self) -> None:
         if not hasattr(self, "model_factory"):
             if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py
index c7ab25bc..89974004 100644
--- a/terratorch/tasks/classification_tasks.py
+++ b/terratorch/tasks/classification_tasks.py
@@ -7,7 +7,6 @@
 from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
 from torch import Tensor, nn
 from torchgeo.datasets.utils import unbind_samples
-from torchgeo.trainers import BaseTask
 from torchmetrics import ClasswiseWrapper, MetricCollection
 from torchmetrics.classification import MulticlassAccuracy, MulticlassFBetaScore, MulticlassJaccardIndex
 
@@ -15,6 +14,7 @@
 from terratorch.registry.registry import MODEL_FACTORY_REGISTRY
 from terratorch.tasks.loss_handler import LossHandler
 from terratorch.tasks.optimizer_factory import optimizer_factory
+from terratorch.tasks.base_task import TerraTorchTask
 
 logger = logging.getLogger('terratorch')
 
@@ -23,7 +23,7 @@ def to_class_prediction(y: ModelOutput) -> Tensor:
     return y_hat.argmax(dim=1)
 
 
-class ClassificationTask(BaseTask):
+class ClassificationTask(TerraTorchTask):
     """Classification Task that accepts models from a range of sources.
 
     This class is analog in functionality to class:ClassificationTask defined by torchgeo.
@@ -109,7 +109,7 @@ def __init__(
         if model_factory and model is None:
             self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
 
-        super().__init__()
+        super().__init__(task="classification")
 
         if model:
             # Custom model
@@ -120,44 +120,6 @@ def __init__(
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"
 
-    # overwrite early stopping
-    def configure_callbacks(self) -> list[Callback]:
-        return []
-
-    def configure_models(self) -> None:
-        if not hasattr(self, "model_factory"):
-            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
-                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
-            # Skipping model factory because custom model is provided
-            return
-
-        self.model: Model = self.model_factory.build_model(
-            "classification", aux_decoders=self.aux_heads, **self.hparams["model_args"]
-        )
-
-        if self.hparams["freeze_backbone"]:
-            if self.hparams.get("peft_config", None) is not None:
-                msg = "PEFT should be run with freeze_backbone = False"
-                raise ValueError(msg)
-            self.model.freeze_encoder()
-        if self.hparams["freeze_decoder"]:
-            self.model.freeze_decoder()
-
-    def configure_optimizers(
-        self,
-    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
-        optimizer = self.hparams["optimizer"]
-        if optimizer is None:
-            optimizer = "Adam"
-        return optimizer_factory(
-            optimizer,
-            self.hparams["lr"],
-            self.parameters(),
-            self.hparams["optimizer_hparams"],
-            self.hparams["scheduler"],
-            self.monitor,
-            self.hparams["scheduler_hparams"],
-        )
 
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -248,11 +210,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         return loss["loss"]
 
-    def on_train_epoch_end(self) -> None:
-        self.log_dict(self.train_metrics.compute(), sync_dist=True)
-        self.train_metrics.reset()
-        return super().on_train_epoch_end()
-
     def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the validation loss and additional metrics.
@@ -271,11 +228,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         y_hat_hard = to_class_prediction(model_output)
         self.val_metrics.update(y_hat_hard, y)
 
-    def on_validation_epoch_end(self) -> None:
-        self.log_dict(self.val_metrics.compute(), sync_dist=True)
-        self.val_metrics.reset()
-        return super().on_validation_epoch_end()
-
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.
@@ -294,11 +246,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         y_hat_hard = to_class_prediction(model_output)
         self.test_metrics.update(y_hat_hard, y)
 
-    def on_test_epoch_end(self) -> None:
-        self.log_dict(self.test_metrics.compute(), sync_dist=True)
-        self.test_metrics.reset()
-        return super().on_test_epoch_end()
-
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.
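
For reference, the methods deleted from ClassificationTask above (and from the regression and segmentation tasks below) are consolidated in TerraTorchTask. The base_task.py hunk at the top of this diff is truncated, so the following is a plausible sketch, not the verbatim implementation, of how the shared configure_models and configure_optimizers would look once parameterized by self.task, which super().__init__(task=...) stores. The import path for Model is an assumption; the epoch-end hooks and _do_plot_samples removed in this diff presumably move into the same class.

# Sketch only: assumed continuation of terratorch/tasks/base_task.py beyond the
# truncated hunk above, reconstructed from the per-task methods deleted in this diff.
import logging

from lightning.pytorch.callbacks import Callback
from torchgeo.trainers import BaseTask

from terratorch.models.model import Model  # assumed import path
from terratorch.tasks.optimizer_factory import optimizer_factory

logger = logging.getLogger("terratorch")


class TerraTorchTask(BaseTask):
    def configure_models(self) -> None:
        if not hasattr(self, "model_factory"):
            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
            # Skipping model factory because a custom model was provided
            return

        # self.task ("classification", "regression" or "segmentation") replaces the
        # string each subclass previously hard-coded in its own configure_models
        self.model: Model = self.model_factory.build_model(
            self.task, aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )

        if self.hparams["freeze_backbone"]:
            if self.hparams.get("peft_config", None) is not None:
                msg = "PEFT should be run with freeze_backbone = False"
                raise ValueError(msg)
            self.model.freeze_encoder()
        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

With this in place, a subclass only sets self.model_factory and calls super().__init__(task=...), as the classification hunk above shows.
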
diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py
index a7211a37..29bbc00f 100644
--- a/terratorch/tasks/regression_tasks.py
+++ b/terratorch/tasks/regression_tasks.py
@@ -11,7 +11,6 @@
 from lightning.pytorch.callbacks import Callback
 from torch import Tensor, nn
 from torchgeo.datasets.utils import unbind_samples
-from torchgeo.trainers import BaseTask
 from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
 from torchmetrics.metric import Metric
 from torchmetrics.wrappers.abstract import WrapperMetric
@@ -21,6 +20,7 @@
 from terratorch.tasks.loss_handler import LossHandler
 from terratorch.tasks.optimizer_factory import optimizer_factory
 from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference
+from terratorch.tasks.base_task import TerraTorchTask
 
 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
 
@@ -118,7 +118,7 @@ def reset(self) -> None:
         self.metric.reset()
 
 
-class PixelwiseRegressionTask(BaseTask):
+class PixelwiseRegressionTask(TerraTorchTask):
     """Pixelwise Regression Task that accepts models from a range of sources.
 
     This class is analog in functionality to
@@ -199,8 +199,8 @@ def __init__(
         if model_factory and model is None:
             self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
 
-        super().__init__()
-
+        super().__init__(task="regression")
+
         if model:
             # Custom_model
             self.model = model
@@ -211,46 +211,6 @@ def __init__(
         self.monitor = f"{self.val_metrics.prefix}loss"
         self.plot_on_val = int(plot_on_val)
 
-    # overwrite early stopping
-    def configure_callbacks(self) -> list[Callback]:
-        return []
-
-    def configure_models(self) -> None:
-        if not hasattr(self, "model_factory"):
-            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
-                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
-            # Skipping model factory because custom model is provided
-            return
-
-        self.model: Model = self.model_factory.build_model(
-            "regression", aux_decoders=self.aux_heads, **self.hparams["model_args"]
-        )
-
-        if self.hparams["freeze_backbone"]:
-            if self.hparams.get("peft_config", None) is not None:
-                msg = "PEFT should be run with freeze_backbone = False"
-                raise ValueError(msg)
-            self.model.freeze_encoder()
-
-        if self.hparams["freeze_decoder"]:
-            self.model.freeze_decoder()
-
-    def configure_optimizers(
-        self,
-    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
-        optimizer = self.hparams["optimizer"]
-        if optimizer is None:
-            optimizer = "Adam"
-        return optimizer_factory(
-            optimizer,
-            self.hparams["lr"],
-            self.parameters(),
-            self.hparams["optimizer_hparams"],
-            self.hparams["scheduler"],
-            self.monitor,
-            self.hparams["scheduler_hparams"],
-        )
-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -316,24 +276,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         return loss["loss"]
 
-    def on_train_epoch_end(self) -> None:
-        self.log_dict(self.train_metrics.compute(), sync_dist=True)
-        self.train_metrics.reset()
-        return super().on_train_epoch_end()
-
-    def _do_plot_samples(self, batch_index):
-        if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
-            return False
-
-        return (
-            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
-            and hasattr(self.trainer, "datamodule")
-            and self.logger
-            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
-            and hasattr(self.logger, "experiment")
-            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
-        )
-
     def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the validation loss and additional metrics.
@@ -376,11 +318,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         finally:
             plt.close()
 
-    def on_validation_epoch_end(self) -> None:
-        self.log_dict(self.val_metrics.compute(), sync_dist=True)
-        self.val_metrics.reset()
-        return super().on_validation_epoch_end()
-
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.
@@ -399,11 +336,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         y_hat = model_output.output
         self.test_metrics.update(y_hat, y)
 
-    def on_test_epoch_end(self) -> None:
-        self.log_dict(self.test_metrics.compute(), sync_dist=True)
-        self.test_metrics.reset()
-        return super().on_test_epoch_end()
-
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index 290bdccd..a0214415 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -10,7 +10,6 @@
 from lightning.pytorch.callbacks import Callback
 from torch import Tensor, nn
 from torchgeo.datasets.utils import unbind_samples
-from torchgeo.trainers import BaseTask
 from torchmetrics import ClasswiseWrapper, MetricCollection
 from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassJaccardIndex
 
@@ -19,6 +18,7 @@
 from terratorch.tasks.loss_handler import LossHandler
 from terratorch.tasks.optimizer_factory import optimizer_factory
 from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference
+from terratorch.tasks.base_task import TerraTorchTask
 
 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
 
@@ -29,7 +29,7 @@ def to_segmentation_prediction(y: ModelOutput) -> Tensor:
     return y_hat.argmax(dim=1)
 
 
-class SemanticSegmentationTask(BaseTask):
+class SemanticSegmentationTask(TerraTorchTask):
     """Semantic Segmentation Task that accepts models from a range of sources.
 
     This class is analog in functionality to class:SemanticSegmentationTask defined by torchgeo.
@@ -119,7 +119,7 @@ def __init__(
         if model_factory and model is None:
             self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
 
-        super().__init__()
+        super().__init__(task="segmentation")
 
         if model is not None:
             # Custom model
@@ -133,44 +133,6 @@ def __init__(
         self.monitor = f"{self.val_metrics.prefix}loss"
         self.plot_on_val = int(plot_on_val)
 
-    # overwrite early stopping
-    def configure_callbacks(self) -> list[Callback]:
-        return []
-
-    def configure_models(self) -> None:
-        if not hasattr(self, "model_factory"):
-            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
-                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
-            # Skipping model factory because custom model is provided
-            return
-
-        self.model: Model = self.model_factory.build_model(
-            "segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"]
-        )
-
-        if self.hparams["freeze_backbone"]:
-            if self.hparams.get("peft_config", None) is not None:
-                msg = "PEFT should be run with freeze_backbone = False"
-                raise ValueError(msg)
-            self.model.freeze_encoder()
-        if self.hparams["freeze_decoder"]:
-            self.model.freeze_decoder()
-
-    def configure_optimizers(
-        self,
-    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
-        optimizer = self.hparams["optimizer"]
-        if optimizer is None:
-            optimizer = "Adam"
-        return optimizer_factory(
-            optimizer,
-            self.hparams["lr"],
-            self.parameters(),
-            self.hparams["optimizer_hparams"],
-            self.hparams["scheduler"],
-            self.monitor,
-            self.hparams["scheduler_hparams"],
-        )
 
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -275,79 +237,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         return loss["loss"]
 
-    def on_train_epoch_end(self) -> None:
-        self.log_dict(self.train_metrics.compute(), sync_dist=True)
-        self.train_metrics.reset()
-        return super().on_train_epoch_end()
-
-    def _do_plot_samples(self, batch_index):
-        if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
-            return False
-
-        return (
-            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
-            and hasattr(self.trainer, "datamodule")
-            and self.logger
-            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
-            and hasattr(self.logger, "experiment")
-            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
-        )
-
-    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
-        """Compute the validation loss and additional metrics.
-
-        Args:
-            batch: The output of your DataLoader.
-            batch_idx: Integer displaying index of this batch.
-            dataloader_idx: Index of the current dataloader.
-        """
-        x = batch["image"]
-        y = batch["mask"]
-
-        other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k: batch[k] for k in other_keys}
-        model_output: ModelOutput = self(x, **rest)
-
-        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
-        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
-        y_hat_hard = to_segmentation_prediction(model_output)
-        self.val_metrics.update(y_hat_hard, y)
-
-        if self._do_plot_samples(batch_idx):
-            try:
-                datamodule = self.trainer.datamodule
-                batch["prediction"] = y_hat_hard
-
-                if isinstance(batch["image"], dict):
-                    if hasattr(datamodule, 'rgb_modality'):
-                        # Generic multimodal dataset
-                        batch["image"] = batch["image"][datamodule.rgb_modality]
-                    else:
-                        # Multimodal dataset. Assuming first item to be the modality to visualize.
-                        batch["image"] = batch["image"][list(batch["image"].keys())[0]]
-
-                for key in ["image", "mask", "prediction"]:
-                    batch[key] = batch[key].cpu()
-                sample = unbind_samples(batch)[0]
-                fig = datamodule.val_dataset.plot(sample)
-                if fig:
-                    summary_writer = self.logger.experiment
-                    if hasattr(summary_writer, "add_figure"):
-                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
-                    elif hasattr(summary_writer, "log_figure"):
-                        summary_writer.log_figure(
-                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
-                        )
-            except ValueError:
-                pass
-            finally:
-                plt.close()
-
-    def on_validation_epoch_end(self) -> None:
-        self.log_dict(self.val_metrics.compute(), sync_dist=True)
-        self.val_metrics.reset()
-        return super().on_validation_epoch_end()
-
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.
@@ -373,12 +262,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         y_hat_hard = to_segmentation_prediction(model_output)
         self.test_metrics[dataloader_idx].update(y_hat_hard, y)
 
-    def on_test_epoch_end(self) -> None:
-        for metrics in self.test_metrics:
-            self.log_dict(metrics.compute(), sync_dist=True)
-            metrics.reset()
-        return super().on_test_epoch_end()
-
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.
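
One behavioral consequence worth noting: TerraTorchTask.configure_callbacks() returns an empty list for every task (the "overwrite early stopping" override each task previously carried), so early stopping is no longer attached by the task itself. A minimal sketch of re-enabling it at the Trainer level follows; the "val/loss" monitor key is an assumption and should match the task's monitor attribute, f"{val_metrics.prefix}loss".

# Sketch only: early stopping is now supplied by the user, not by the task.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# "val/loss" is an assumed monitor key; substitute the value of the task's
# `monitor` attribute from your configuration.
early_stop = EarlyStopping(monitor="val/loss", mode="min", patience=10)
trainer = Trainer(max_epochs=100, callbacks=[early_stop])
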