Skip to content

Commit

Permalink
Merge pull request #311 from IBM/basetask
Browse files Browse the repository at this point in the history
Fixing
  • Loading branch information
Joao-L-S-Almeida authored Dec 9, 2024
2 parents 3ad56c7 + ba21d73 commit ea44763
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
2 changes: 2 additions & 0 deletions terratorch/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from torchgeo.trainers import BaseTask


BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

class TerraTorchTask(BaseTask):

"""
Expand Down
50 changes: 50 additions & 0 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,56 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
y_hat_hard = to_segmentation_prediction(model_output)
self.test_metrics[dataloader_idx].update(y_hat_hard, y)

def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:

"""Compute the validation loss and additional metrics.
Args:
batch: The output of your DataLoader.
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
y = batch["mask"]

other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
y_hat_hard = to_segmentation_prediction(model_output)
self.val_metrics.update(y_hat_hard, y)

if self._do_plot_samples(batch_idx):
try:
datamodule = self.trainer.datamodule
batch["prediction"] = y_hat_hard

if isinstance(batch["image"], dict):
if hasattr(datamodule, 'rgb_modality'):
# Generic multimodal dataset
batch["image"] = batch["image"][datamodule.rgb_modality]
else:
# Multimodal dataset. Assuming first item to be the modality to visualize.
batch["image"] = batch["image"][list(batch["image"].keys())[0]]

for key in ["image", "mask", "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.val_dataset.plot(sample)
if fig:
summary_writer = self.logger.experiment
if hasattr(summary_writer, "add_figure"):
summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
elif hasattr(summary_writer, "log_figure"):
summary_writer.log_figure(
self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
)
except ValueError:
pass
finally:
plt.close()

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
"""Compute the predicted class probabilities.
Expand Down

0 comments on commit ea44763

Please sign in to comment.