Skip to content

Commit

Permalink
fix: add explicit aerial and sentinel band naming to toy dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathias Baumgartinger committed Dec 15, 2024
1 parent d6d62c6 commit bf0b7fc
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions torchgeo/datasets/flair2.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def _load_image(self, path: Path) -> Tensor:
"""
with rasterio.open(path) as f:
array: np.typing.NDArray[np.int_] = f.read()
tensor = torch.from_numpy(array).float()
tensor = torch.from_numpy(array).float() / 255

# Extract the bands of interest
tensor = tensor[[int(band[-2:]) - 1 for band in self.aerial_bands]]
Expand Down Expand Up @@ -731,22 +731,24 @@ def __init__(
self,
root: Path = 'data',
split: str = 'train',
bands: Sequence[str] = FLAIR2.aerial_all_bands,
aerial_bands: Sequence[str] = FLAIR2.aerial_all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
use_sentinel: bool = False,
sentinel_bands: Sequence[str] = FLAIR2.sentinel_all_bands
) -> None:
"""Initialize a new FLAIR2Toy dataset instance.
Args:
root: root directory where dataset can be found
split: which split to load, one of 'train' or 'test'
bands: which bands to load (B01, B02, B03, B04, B05)
aerial_bands: which bands to load (B01, B02, B03, B04, B05)
transforms: optional transforms to apply to sample
download: whether to download the dataset if it is not found
checksum: whether to verify the dataset using checksums
use_sentinel: whether to use sentinel data in the dataset # FIXME: sentinel does not work with dataloader due to varying dimensions
sentinel_bands: which bands to load from sentinel data (B01, B02, ..., B10)
Raises:
DatasetNotFoundError
Expand All @@ -761,7 +763,7 @@ def __init__(
)
print('-' * 80)
super().__init__(
root, split, bands, transforms, download, checksum, use_sentinel
root, split, aerial_bands, transforms, download, checksum, use_sentinel, sentinel_bands
)

def _verify(self) -> None:
Expand Down

0 comments on commit bf0b7fc

Please sign in to comment.