
Merge pull request #309 from IBM/basetask
Base task for terratorch
romeokienzler authored Dec 9, 2024
2 parents 0b57dc9 + dd1e76f commit 3ad56c7
Showing 5 changed files with 95 additions and 249 deletions.
3 changes: 2 additions & 1 deletion terratorch/tasks/__init__.py
@@ -2,6 +2,7 @@
from terratorch.tasks.regression_tasks import PixelwiseRegressionTask
from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask
from terratorch.tasks.multilabel_classification_tasks import MultiLabelClassificationTask
from terratorch.tasks.base_task import TerraTorchTask
try:
    wxc_present = True
    from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask
@@ -20,4 +21,4 @@
)

if wxc_present:
    __all__.__add__(("WxCDownscalingTask", ))
    __all__.__add__(("WxCDownscalingTask", ))
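
With this export in place, the base class can be imported directly from the tasks package; a trivial usage line, assuming nothing beyond the import added above:

from terratorch.tasks import TerraTorchTask
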
83 changes: 83 additions & 0 deletions terratorch/tasks/base_task.py
@@ -0,0 +1,83 @@

from lightning.pytorch.callbacks import Callback
from torchgeo.trainers import BaseTask


class TerraTorchTask(BaseTask):

    """
    Parent used to share common methods among all the
    tasks implemented in terratorch
    """

    def __init__(self, task:str=None):

        self.task = task

        super().__init__()

    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        return []

    def configure_models(self) -> None:
        if not hasattr(self, "model_factory"):
            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
            # Skipping model factory because custom model is provided
            return

        self.model: Model = self.model_factory.build_model(
            self.task, aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )

        if self.hparams["freeze_backbone"]:
            if self.hparams.get("peft_config", None) is not None:
                msg = "PEFT should be run with freeze_backbone = False"
                raise ValueError(msg)
            self.model.freeze_encoder()

        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def on_train_epoch_end(self) -> None:
        self.log_dict(self.train_metrics.compute(), sync_dist=True)
        self.train_metrics.reset()
        return super().on_train_epoch_end()

    def on_validation_epoch_end(self) -> None:
        self.log_dict(self.val_metrics.compute(), sync_dist=True)
        self.val_metrics.reset()
        return super().on_validation_epoch_end()


    def _do_plot_samples(self, batch_index):
        if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
            return False

        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )
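
For orientation, a minimal sketch (not part of this commit) of how a concrete task is expected to sit on top of TerraTorchTask after this refactor: configure_models, configure_optimizers, configure_callbacks and the train/validation epoch-end metric logging are inherited, while the subclass supplies its model factory, losses and metrics around super().__init__(task=...). The class name MyClassificationTask, the trimmed argument list and the single-metric collection are illustrative assumptions; the actual patterns are the ClassificationTask and PixelwiseRegressionTask changes below.

from torch import nn
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from terratorch.registry.registry import MODEL_FACTORY_REGISTRY
from terratorch.tasks.base_task import TerraTorchTask
from terratorch.tasks.loss_handler import LossHandler


class MyClassificationTask(TerraTorchTask):
    """Illustrative subclass; not part of this commit."""

    def __init__(
        self,
        model_args: dict,
        model_factory: str,
        num_classes: int,
        lr: float = 1e-3,
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
    ) -> None:
        # The inherited configure_models() runs inside super().__init__(),
        # so the factory and aux heads must exist before that call.
        self.aux_heads = None
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
        # The task string tells the shared configure_models() what to build.
        super().__init__(task="classification")

        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        # The inherited configure_optimizers() reads self.monitor for the scheduler.
        self.monitor = f"{self.val_metrics.prefix}loss"

    def configure_losses(self) -> None:
        self.criterion = nn.CrossEntropyLoss()

    def configure_metrics(self) -> None:
        num_classes = self.hparams["num_classes"]
        self.train_metrics = MetricCollection(
            {"Accuracy": MulticlassAccuracy(num_classes=num_classes)}, prefix="train/"
        )
        self.val_metrics = self.train_metrics.clone(prefix="val/")
        self.test_metrics = self.train_metrics.clone(prefix="test/")

    # training_step / validation_step / test_step / predict_step remain
    # task-specific, as in the classification and regression files below.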


59 changes: 3 additions & 56 deletions terratorch/tasks/classification_tasks.py
@@ -7,14 +7,14 @@
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor, nn
from torchgeo.datasets.utils import unbind_samples
from torchgeo.trainers import BaseTask
from torchmetrics import ClasswiseWrapper, MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassFBetaScore, MulticlassJaccardIndex

from terratorch.models.model import AuxiliaryHead, Model, ModelOutput
from terratorch.registry.registry import MODEL_FACTORY_REGISTRY
from terratorch.tasks.loss_handler import LossHandler
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.tasks.base_task import TerraTorchTask

logger = logging.getLogger('terratorch')

@@ -23,7 +23,7 @@ def to_class_prediction(y: ModelOutput) -> Tensor:
    return y_hat.argmax(dim=1)


class ClassificationTask(BaseTask):
class ClassificationTask(TerraTorchTask):
"""Classification Task that accepts models from a range of sources.
This class is analog in functionality to class:ClassificationTask defined by torchgeo.
@@ -109,7 +109,7 @@ def __init__(
        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__()
        super().__init__(task="classification")

        if model:
            # Custom model
@@ -120,44 +120,6 @@ def __init__(
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"

    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        return []

    def configure_models(self) -> None:
        if not hasattr(self, "model_factory"):
            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
            # Skipping model factory because custom model is provided
            return

        self.model: Model = self.model_factory.build_model(
            "classification", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )

        if self.hparams["freeze_backbone"]:
            if self.hparams.get("peft_config", None) is not None:
                msg = "PEFT should be run with freeze_backbone = False"
                raise ValueError(msg)
            self.model.freeze_encoder()
        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def configure_losses(self) -> None:
        """Initialize the loss criterion.
@@ -248,11 +210,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->

        return loss["loss"]

    def on_train_epoch_end(self) -> None:
        self.log_dict(self.train_metrics.compute(), sync_dist=True)
        self.train_metrics.reset()
        return super().on_train_epoch_end()

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.
@@ -271,11 +228,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
        y_hat_hard = to_class_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

    def on_validation_epoch_end(self) -> None:
        self.log_dict(self.val_metrics.compute(), sync_dist=True)
        self.val_metrics.reset()
        return super().on_validation_epoch_end()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.
@@ -294,11 +246,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
        y_hat_hard = to_class_prediction(model_output)
        self.test_metrics.update(y_hat_hard, y)

    def on_test_epoch_end(self) -> None:
        self.log_dict(self.test_metrics.compute(), sync_dist=True)
        self.test_metrics.reset()
        return super().on_test_epoch_end()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.
76 changes: 4 additions & 72 deletions terratorch/tasks/regression_tasks.py
@@ -11,7 +11,6 @@
from lightning.pytorch.callbacks import Callback
from torch import Tensor, nn
from torchgeo.datasets.utils import unbind_samples
from torchgeo.trainers import BaseTask
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchmetrics.metric import Metric
from torchmetrics.wrappers.abstract import WrapperMetric
@@ -21,6 +20,7 @@
from terratorch.tasks.loss_handler import LossHandler
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference
from terratorch.tasks.base_task import TerraTorchTask

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

@@ -118,7 +118,7 @@ def reset(self) -> None:
        self.metric.reset()


class PixelwiseRegressionTask(BaseTask):
class PixelwiseRegressionTask(TerraTorchTask):
"""Pixelwise Regression Task that accepts models from a range of sources.
This class is analog in functionality to
@@ -199,8 +199,8 @@ def __init__(
        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__()
        super().__init__(task="regression")

        if model:
            # Custom_model
            self.model = model
@@ -211,46 +211,6 @@
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)

    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        return []

    def configure_models(self) -> None:
        if not hasattr(self, "model_factory"):
            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
            # Skipping model factory because custom model is provided
            return

        self.model: Model = self.model_factory.build_model(
            "regression", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )

        if self.hparams["freeze_backbone"]:
            if self.hparams.get("peft_config", None) is not None:
                msg = "PEFT should be run with freeze_backbone = False"
                raise ValueError(msg)
            self.model.freeze_encoder()

        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def configure_losses(self) -> None:
        """Initialize the loss criterion.
@@ -316,24 +276,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->

        return loss["loss"]

    def on_train_epoch_end(self) -> None:
        self.log_dict(self.train_metrics.compute(), sync_dist=True)
        self.train_metrics.reset()
        return super().on_train_epoch_end()

    def _do_plot_samples(self, batch_index):
        if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
            return False

        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.
@@ -376,11 +318,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
            finally:
                plt.close()

    def on_validation_epoch_end(self) -> None:
        self.log_dict(self.val_metrics.compute(), sync_dist=True)
        self.val_metrics.reset()
        return super().on_validation_epoch_end()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.
@@ -399,11 +336,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
        y_hat = model_output.output
        self.test_metrics.update(y_hat, y)

    def on_test_epoch_end(self) -> None:
        self.log_dict(self.test_metrics.compute(), sync_dist=True)
        self.test_metrics.reset()
        return super().on_test_epoch_end()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.