From ba21d7345ccba79c463b82fb234d1283f33a14ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?=
Date: Mon, 9 Dec 2024 09:14:55 -0300
Subject: [PATCH] Fixing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: João Lucas de Sousa Almeida
---
 terratorch/tasks/base_task.py          |  2 ++
 terratorch/tasks/segmentation_tasks.py | 50 ++++++++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py
index 0437556f..762f666f 100644
--- a/terratorch/tasks/base_task.py
+++ b/terratorch/tasks/base_task.py
@@ -3,6 +3,8 @@
 
 from torchgeo.trainers import BaseTask
 
+BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
+
 
 class TerraTorchTask(BaseTask):
     """
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index a0214415..3b1fb04f 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -262,6 +262,56 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         y_hat_hard = to_segmentation_prediction(model_output)
         self.test_metrics[dataloader_idx].update(y_hat_hard, y)
 
+    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+
+        """Compute the validation loss and additional metrics.
+        Args:
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
+        """
+        x = batch["image"]
+        y = batch["mask"]
+
+        other_keys = batch.keys() - {"image", "mask", "filename"}
+        rest = {k: batch[k] for k in other_keys}
+        model_output: ModelOutput = self(x, **rest)
+
+        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
+        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
+        y_hat_hard = to_segmentation_prediction(model_output)
+        self.val_metrics.update(y_hat_hard, y)
+
+        if self._do_plot_samples(batch_idx):
+            try:
+                datamodule = self.trainer.datamodule
+                batch["prediction"] = y_hat_hard
+
+                if isinstance(batch["image"], dict):
+                    if hasattr(datamodule, 'rgb_modality'):
+                        # Generic multimodal dataset
+                        batch["image"] = batch["image"][datamodule.rgb_modality]
+                    else:
+                        # Multimodal dataset. Assuming first item to be the modality to visualize.
+                        batch["image"] = batch["image"][list(batch["image"].keys())[0]]
+
+                for key in ["image", "mask", "prediction"]:
+                    batch[key] = batch[key].cpu()
+                sample = unbind_samples(batch)[0]
+                fig = datamodule.val_dataset.plot(sample)
+                if fig:
+                    summary_writer = self.logger.experiment
+                    if hasattr(summary_writer, "add_figure"):
+                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
+                    elif hasattr(summary_writer, "log_figure"):
+                        summary_writer.log_figure(
+                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
+                        )
+            except ValueError:
+                pass
+            finally:
+                plt.close()
+
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.