Skip to content

Commit

Permalink
adds firescars and multicrop v2 versions
Browse files Browse the repository at this point in the history
Signed-off-by: Pedro Henrique Conrado <[email protected]>
  • Loading branch information
PedroConrado committed Dec 18, 2024
1 parent 97a2f61 commit 3e5034c
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 99 deletions.
64 changes: 45 additions & 19 deletions terratorch/datamodules/fire_scars.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,42 @@
from terratorch.datamodules.utils import wrap_in_compose_is_list
from terratorch.datasets import FireScarsHLS, FireScarsNonGeo, FireScarsSegmentationMask

MEANS = {
"BLUE": 0.033349706741586264,
"GREEN": 0.05701185520536176,
"RED": 0.05889748132001316,
"NIR_NARROW": 0.2323245113436119,
"SWIR_1": 0.1972854853760658,
"SWIR_2": 0.11944914225186566,
# Per-band normalization means for the fire scars dataset, keyed first by
# dataset version and then by band name. The data module looks these up as
# MEANS_PER_VERSION[version][band] for each requested band.
# NOTE(review): the "1" and "2" entries are currently identical - confirm
# whether version-specific statistics were intended.
MEANS_PER_VERSION = {
    "1": {
        "BLUE": 0.0535,
        "GREEN": 0.0788,
        "RED": 0.0963,
        "NIR_NARROW": 0.2119,
        "SWIR_1": 0.2360,
        "SWIR_2": 0.1731,
    },
    "2": {
        "BLUE": 0.0535,
        "GREEN": 0.0788,
        "RED": 0.0963,
        "NIR_NARROW": 0.2119,
        "SWIR_1": 0.2360,
        "SWIR_2": 0.1731,
    },
}

STDS = {
"BLUE": 0.02269135568823774,
"GREEN": 0.026807560223070237,
"RED": 0.04004109844362779,
"NIR_NARROW": 0.07791732423672691,
"SWIR_1": 0.08708738838140137,
"SWIR_2": 0.07241979477437814,
# Per-band normalization standard deviations for the fire scars dataset,
# keyed first by dataset version and then by band name. The data module
# looks these up as STDS_PER_VERSION[version][band] for each requested band.
# NOTE(review): the "1" and "2" entries are currently identical - confirm
# whether version-specific statistics were intended.
STDS_PER_VERSION = {
    "1": {
        "BLUE": 0.0308,
        "GREEN": 0.0378,
        "RED": 0.0550,
        "NIR_NARROW": 0.0707,
        "SWIR_1": 0.0919,
        "SWIR_2": 0.0841,
    },
    "2": {
        "BLUE": 0.0308,
        "GREEN": 0.0378,
        "RED": 0.0550,
        "NIR_NARROW": 0.0707,
        "SWIR_1": 0.0919,
        "SWIR_2": 0.0841,
    },
}


Expand All @@ -40,6 +60,7 @@ class FireScarsNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
data_root: str,
version: str = '2',
batch_size: int = 4,
num_workers: int = 0,
bands: Sequence[str] = FireScarsNonGeo.all_band_names,
Expand All @@ -54,14 +75,16 @@ def __init__(
) -> None:
super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
self.data_root = data_root

means = [MEANS[b] for b in bands]
stds = [STDS[b] for b in bands]
means = MEANS_PER_VERSION[version]
stds = STDS_PER_VERSION[version]
self.means = [means[b] for b in bands]
self.stds = [stds[b] for b in bands]
self.version = version
self.bands = bands
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"])
self.drop_last = drop_last
self.no_data_replace = no_data_replace
self.no_label_replace = no_label_replace
Expand All @@ -71,6 +94,7 @@ def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
split="train",
version=self.version,
data_root=self.data_root,
transform=self.train_transform,
bands=self.bands,
Expand All @@ -81,6 +105,7 @@ def setup(self, stage: str) -> None:
if stage in ["fit", "validate"]:
self.val_dataset = self.dataset_class(
split="val",
version=self.version,
data_root=self.data_root,
transform=self.val_transform,
bands=self.bands,
Expand All @@ -90,7 +115,8 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="val",
split="test",
version=self.version,
data_root=self.data_root,
transform=self.test_transform,
bands=self.bands,
Expand Down
62 changes: 44 additions & 18 deletions terratorch/datamodules/multi_temporal_crop_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,42 @@
from terratorch.datamodules.utils import wrap_in_compose_is_list
from terratorch.datasets import MultiTemporalCropClassification

MEANS = {
"BLUE": 494.905781,
"GREEN": 815.239594,
"RED": 924.335066,
"NIR_NARROW": 2968.881459,
"SWIR_1": 2634.621962,
"SWIR_2": 1739.579917,
# Per-band normalization means for the multi-temporal crop classification
# dataset, keyed first by dataset version and then by band name. The data
# module looks these up as MEANS_PER_VERSION[version][band].
# NOTE(review): values alternate low/high across bands (GREEN and SWIR_2 are
# roughly 3x RED and SWIR_1) - verify the statistics were computed per band
# and in the intended band order.
MEANS_PER_VERSION = {
    "1": {
        "BLUE": 830.5397,
        "GREEN": 2427.1667,
        "RED": 760.6795,
        "NIR_NARROW": 2575.2020,
        "SWIR_1": 649.9128,
        "SWIR_2": 2344.4357,
    },
    "2": {
        "BLUE": 829.5907,
        "GREEN": 2437.3473,
        "RED": 748.6308,
        "NIR_NARROW": 2568.9369,
        "SWIR_1": 638.9926,
        "SWIR_2": 2336.4087,
    },
}

STDS = {
"BLUE": 284.925432,
"GREEN": 357.84876,
"RED": 575.566823,
"NIR_NARROW": 896.601013,
"SWIR_1": 951.900334,
"SWIR_2": 921.407808,
# Per-band normalization standard deviations for the multi-temporal crop
# classification dataset, keyed first by dataset version and then by band
# name. The data module looks these up as STDS_PER_VERSION[version][band].
STDS_PER_VERSION = {
    "1": {
        "BLUE": 447.9155,
        "GREEN": 910.8289,
        "RED": 490.9398,
        "NIR_NARROW": 1142.5207,
        "SWIR_1": 430.9440,
        "SWIR_2": 1094.0881,
    },
    "2": {
        "BLUE": 447.1192,
        "GREEN": 913.5633,
        "RED": 480.5570,
        "NIR_NARROW": 1140.6160,
        "SWIR_1": 418.6212,
        "SWIR_2": 1091.6073,
    },
}


Expand All @@ -35,6 +55,7 @@ class MultiTemporalCropClassificationDataModule(NonGeoDataModule):
def __init__(
self,
data_root: str,
version: str = '2',
batch_size: int = 4,
num_workers: int = 0,
bands: Sequence[str] = MultiTemporalCropClassification.all_band_names,
Expand All @@ -51,9 +72,11 @@ def __init__(
) -> None:
super().__init__(MultiTemporalCropClassification, batch_size, num_workers, **kwargs)
self.data_root = data_root

self.means = [MEANS[b] for b in bands]
self.stds = [STDS[b] for b in bands]
means = MEANS_PER_VERSION[version]
stds = STDS_PER_VERSION[version]
self.means = [means[b] for b in bands]
self.stds = [stds[b] for b in bands]
self.version = version
self.bands = bands
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
Expand All @@ -70,6 +93,7 @@ def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
split="train",
version=self.version,
data_root=self.data_root,
transform=self.train_transform,
bands=self.bands,
Expand All @@ -82,6 +106,7 @@ def setup(self, stage: str) -> None:
if stage in ["fit", "validate"]:
self.val_dataset = self.dataset_class(
split="val",
version=self.version,
data_root=self.data_root,
transform=self.val_transform,
bands=self.bands,
Expand All @@ -93,7 +118,8 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="val",
split="test",
version=self.version,
data_root=self.data_root,
transform=self.test_transform,
bands=self.bands,
Expand Down
53 changes: 41 additions & 12 deletions terratorch/datasets/fire_scars.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import glob
import os
import re
from collections.abc import Sequence
from pathlib import Path
from typing import Any
from typing import Any, Sequence

import albumentations as A
import matplotlib as mpl
Expand All @@ -20,11 +19,22 @@
from torchgeo.datasets import NonGeoDataset, RasterDataset
from xarray import DataArray

from terratorch.datasets.utils import clip_image_percentile, default_transform, validate_bands
from terratorch.datasets.utils import clip_image_percentile, default_transform, filter_valid_files, validate_bands


class FireScarsNonGeo(NonGeoDataset):
"""NonGeo dataset implementation for fire scars."""
"""NonGeo dataset implementation for fire scars.
If using the version 2 dataset, we use the version 2 train/val/test splits from the dataset.
If using the version 1 dataset, we use the version 1 train/val splits from the dataset.
"""
versions = ('1', '2')

splits_per_version = {
'1': {"train": "train", "val": "val", "test": "val"},
'2': {"train": "train", "val": "val", "test": "test"},
}

all_band_names = (
"BLUE",
"GREEN",
Expand All @@ -39,11 +49,11 @@ class FireScarsNonGeo(NonGeoDataset):
BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

num_classes = 2
splits = {"train": "training", "val": "validation"} # Only train and val splits available

def __init__(
self,
data_root: str,
version: str = '2',
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transform: A.Compose | None = None,
Expand All @@ -66,20 +76,39 @@ def __init__(
use_metadata (bool): whether to return metadata info (time and location).
"""
super().__init__()
if split not in self.splits:
msg = f"Incorrect split '{split}', please choose one of {self.splits}."
if version not in self.versions:
msg = f"Incorrect version '{version}', please choose one of {self.versions}."
raise ValueError(msg)
splits = self.splits_per_version[version]
if split not in splits:
msg = f"Incorrect split '{split}', please choose one of {list(splits.keys())}."
raise ValueError(msg)
split_name = self.splits[split]
self.split = split
self.split = splits[split]

validate_bands(bands, self.all_band_names)
self.bands = bands
self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
self.data_root = Path(data_root)

input_dir = self.data_root / split_name
self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))
self.image_files = sorted(glob.glob(os.path.join(self.data_root, "*_merged.tif")))
self.segmentation_mask_files = sorted(glob.glob(os.path.join(self.data_root, "*.mask.tif")))

split_file = self.data_root / f"{self.split}_v{version}_data.txt"
with open(split_file) as f:
split = f.readlines()
valid_files = {rf"{substring.strip()}" for substring in split}
self.image_files = filter_valid_files(
self.image_files,
valid_files=valid_files,
ignore_extensions=True,
allow_substring=True,
)
self.segmentation_mask_files = filter_valid_files(
self.segmentation_mask_files,
valid_files=valid_files,
ignore_extensions=True,
allow_substring=True,
)

self.use_metadata = use_metadata
self.no_data_replace = no_data_replace
Expand Down
Loading

0 comments on commit 3e5034c

Please sign in to comment.