
Commit

Base task for terratorch
Signed-off-by: Joao Lucas de Sousa Almeida <[email protected]>
Joao-L-S-Almeida committed Dec 7, 2024
1 parent fc1bc25 commit 974f177
Showing 2 changed files with 75 additions and 1 deletion.
3 changes: 2 additions & 1 deletion terratorch/tasks/__init__.py
@@ -2,6 +2,7 @@
from terratorch.tasks.regression_tasks import PixelwiseRegressionTask
from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask
from terratorch.tasks.multilabel_classification_tasks import MultiLabelClassificationTask
from terratorch.tasks.base_task import TerraTorchTask
try:
    wxc_present = True
    from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask
@@ -20,4 +21,4 @@
)

if wxc_present:
    __all__.__add__(("WxCDownscalingTask", ))
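
With TerraTorchTask exported from terratorch.tasks, downstream code can import it directly; a minimal usage sketch (assuming the package is installed), mirroring the optional-import pattern above for the WxC task:

    # Hypothetical usage sketch, not part of this commit.
    from terratorch.tasks import TerraTorchTask

    try:
        # Only importable when the optional WxC dependencies are installed.
        from terratorch.tasks import WxCDownscalingTask
        wxc_available = True
    except ImportError:
        wxc_available = False
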
73 changes: 73 additions & 0 deletions terratorch/tasks/base_task.py
@@ -0,0 +1,73 @@

import logging

from torchgeo.trainers import BaseTask

# The names below are referenced in this module; their import locations are
# assumed from the surrounding terratorch package layout (not shown in this diff).
from terratorch.models.model import Model
from terratorch.tasks.optimizer_factory import optimizer_factory

logger = logging.getLogger("terratorch")
BATCH_IDX_FOR_VALIDATION_PLOTTING = 10  # assumed cutoff for validation sample plotting


class TerraTorchTask(BaseTask):
    """Base class collecting behaviour shared by the terratorch tasks."""

    def __init__(self, task: str | None = None):
        self.task = task
        super().__init__()

    def configure_models(self) -> None:
        """Build the model with the configured factory, unless a custom model was provided."""
        if not hasattr(self, "model_factory"):
            if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
                logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
            # Skip the model factory because a custom model is provided.
            return

        self.model: Model = self.model_factory.build_model(
            self.task, aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )

        if self.hparams["freeze_backbone"]:
            if self.hparams.get("peft_config", None) is not None:
                msg = "PEFT should be run with freeze_backbone = False"
                raise ValueError(msg)
            self.model.freeze_encoder()

        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        """Instantiate the optimizer and LR scheduler from the task hyperparameters."""
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def on_train_epoch_end(self) -> None:
        self.log_dict(self.train_metrics.compute(), sync_dist=True)
        self.train_metrics.reset()
        return super().on_train_epoch_end()

    def on_validation_epoch_end(self) -> None:
        self.log_dict(self.val_metrics.compute(), sync_dist=True)
        self.val_metrics.reset()
        return super().on_validation_epoch_end()


    def _do_plot_samples(self, batch_index):
        if not self.plot_on_val:  # don't plot if self.plot_on_val is 0
            return False

        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )
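
As a usage sketch, a concrete task would subclass TerraTorchTask, set the attributes the base class reads (model_factory, aux_heads, monitor, plot_on_val) before calling super().__init__(), and declare its hyperparameters in its own __init__ so that save_hyperparameters(), invoked by torchgeo's BaseTask, exposes them through self.hparams. The class name, the EncoderDecoderFactory import and the default values below are illustrative assumptions, not part of this commit:

    # Hypothetical subclass sketch -- illustrative only, not part of this commit.
    from terratorch.models import EncoderDecoderFactory  # assumed factory location

    class ExampleSegmentationTask(TerraTorchTask):
        def __init__(
            self,
            model_args: dict,
            lr: float = 1e-3,
            optimizer: str | None = None,
            optimizer_hparams: dict | None = None,
            scheduler: str | None = None,
            scheduler_hparams: dict | None = None,
            freeze_backbone: bool = False,
            freeze_decoder: bool = False,
            peft_config: dict | None = None,
        ) -> None:
            self.aux_heads = []                      # read by configure_models()
            self.model_factory = EncoderDecoderFactory()
            self.monitor = "val/loss"                # read by configure_optimizers()
            self.plot_on_val = 0                     # read by _do_plot_samples()
            # BaseTask.__init__ (reached via TerraTorchTask) calls save_hyperparameters()
            # and configure_models(), which read the arguments above from self.hparams.
            super().__init__(task="segmentation")
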

