diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py
index 4e96cfb2b..dfa06dc14 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
 from enum import Enum
 import logging
 
@@ -15,6 +15,9 @@
     ClassificationRandomWindowGeoDataset)
 from rastervision.pytorch_learner.utils import adjust_conv_channels
 
+if TYPE_CHECKING:
+    from rastervision.core.data import SceneConfig
+
 log = logging.getLogger(__name__)
 
 
@@ -53,11 +56,13 @@ class ClassificationGeoDataConfig(ClassificationDataConfig, GeoDataConfig):
     See :mod:`rastervision.pytorch_learner.dataset.classification_dataset`.
     """
 
-    def build_scenes(self, tmp_dir: str):
-        for s in self.scene_dataset.all_scenes:
+    def build_scenes(self,
+                     scene_configs: Iterable['SceneConfig'],
+                     tmp_dir: Optional[str] = None):
+        for s in scene_configs:
             if s.label_source is not None:
                 s.label_source.lazy = True
-        return super().build_scenes(tmp_dir=tmp_dir)
+        return super().build_scenes(scene_configs, tmp_dir=tmp_dir)
 
     def scene_to_dataset(self,
                          scene: Scene,
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py
index 72e396b8a..2523ee0de 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py
@@ -1,11 +1,11 @@
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
+                    Literal, Optional, Sequence, Tuple, Union)
 from os.path import join, isdir
 from enum import Enum
 import random
 import uuid
 import logging
-from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
-                    Literal, Optional, Sequence, Tuple, Union)
 
 from pydantic import (PositiveFloat, PositiveInt as PosInt, constr, confloat,
                       conint)
 from pydantic.utils import sequence_like
@@ -29,6 +29,7 @@
     torch_hub_load_local, torch_hub_load_github, torch_hub_load_uri)
 
 if TYPE_CHECKING:
+    from rastervision.core.data import SceneConfig
     from rastervision.pytorch_learner.learner import Learner
 
 log = logging.getLogger(__name__)
@@ -694,9 +695,6 @@ def validate_plot_options(cls, values: dict) -> dict:
             plot_options.update(img_channels=img_channels)
         return values
 
-    def make_datasets(self) -> Tuple[Dataset, Dataset, Dataset]:
-        raise NotImplementedError()
-
     def get_custom_albumentations_transforms(self) -> List[dict]:
         """Returns all custom transforms found in this config.
@@ -775,12 +773,20 @@ def get_data_transforms(self) -> Tuple[A.BasicTransform, A.BasicTransform]:
         return base_transform, aug_transform
 
     def build(self,
-              tmp_dir: str,
+              tmp_dir: Optional[str] = None,
               overfit_mode: bool = False,
               test_mode: bool = False) -> Tuple[Dataset, Dataset, Dataset]:
         """Build and return train, val, and test datasets."""
         raise NotImplementedError()
 
+    def build_dataset(self,
+                      split: Literal['train', 'valid', 'test'],
+                      tmp_dir: Optional[str] = None,
+                      overfit_mode: bool = False,
+                      test_mode: bool = False) -> Dataset:
+        """Build and return dataset for a single split."""
+        raise NotImplementedError()
+
     def random_subset_dataset(self,
                               ds: Dataset,
                               size: Optional[int] = None,
@@ -861,14 +867,34 @@ def validate_group_uris(cls, values: dict) -> dict:
                                 'len(group_train_sz_rel) != len(group_uris).')
         return values
 
-    def make_datasets(self,
-                      train_dirs: Iterable[str],
-                      val_dirs: Iterable[str],
-                      test_dirs: Iterable[str],
-                      train_tf: Optional[A.BasicTransform] = None,
-                      val_tf: Optional[A.BasicTransform] = None,
-                      test_tf: Optional[A.BasicTransform] = None
-                      ) -> Tuple[Dataset, Dataset, Dataset]:
+    def _build_dataset(
+            self,
+            dirs: Iterable[str],
+            tf: Optional[A.BasicTransform] = None,
+    ) -> Dataset:
+        """Make a dataset for a single split.
+
+        Args:
+            dirs: Directories where the data is located.
+            tf: Transform for the dataset. Defaults to None.
+
+        Returns:
+            PyTorch-compatible dataset.
+        """
+        per_dir_datasets = [self.dir_to_dataset(d, tf) for d in dirs]
+        if len(per_dir_datasets) == 0:
+            per_dir_datasets.append([])
+        combined_dataset = ConcatDataset(per_dir_datasets)
+        return combined_dataset
+
+    def _build_datasets(self,
+                        train_dirs: Iterable[str],
+                        val_dirs: Iterable[str],
+                        test_dirs: Iterable[str],
+                        train_tf: Optional[A.BasicTransform] = None,
+                        val_tf: Optional[A.BasicTransform] = None,
+                        test_tf: Optional[A.BasicTransform] = None
+                        ) -> Tuple[Dataset, Dataset, Dataset]:
         """Make training, validation, and test datasets.
 
         Args:
@@ -885,18 +911,9 @@ def make_datasets(self,
         Returns:
             PyTorch-compatiable training, validation, and test datasets.
         """
-        train_ds_list = [self.dir_to_dataset(d, train_tf) for d in train_dirs]
-        val_ds_list = [self.dir_to_dataset(d, val_tf) for d in val_dirs]
-        test_ds_list = [self.dir_to_dataset(d, test_tf) for d in test_dirs]
-
-        for ds_list in [train_ds_list, val_ds_list, test_ds_list]:
-            if len(ds_list) == 0:
-                ds_list.append([])
-
-        train_ds = ConcatDataset(train_ds_list)
-        val_ds = ConcatDataset(val_ds_list)
-        test_ds = ConcatDataset(test_ds_list)
-
+        train_ds = self._build_dataset(train_dirs, train_tf)
+        val_ds = self._build_dataset(val_dirs, val_tf)
+        test_ds = self._build_dataset(test_dirs, test_tf)
         return train_ds, val_ds, test_ds
 
     def dir_to_dataset(self, data_dir: str,
@@ -909,7 +926,7 @@ def build(self,
                  test_mode: bool = False) -> Tuple[Dataset, Dataset, Dataset]:
 
         if self.group_uris is None:
-            return self.get_datasets_from_uri(
+            return self._get_datasets_from_uri(
                 self.uri,
                 tmp_dir=tmp_dir,
                 overfit_mode=overfit_mode,
@@ -919,7 +936,7 @@ def build(self,
 
         log.warning('Both DataConfig.uri and DataConfig.group_uris '
                     'specified. Only DataConfig.group_uris will be used.')
-        train_ds, valid_ds, test_ds = self.get_datasets_from_group_uris(
+        train_ds, valid_ds, test_ds = self._get_datasets_from_group_uris(
             self.group_uris,
             tmp_dir=tmp_dir,
             overfit_mode=overfit_mode,
@@ -931,7 +948,40 @@ def build(self,
 
         return train_ds, valid_ds, test_ds
 
-    def get_datasets_from_uri(
+    def build_dataset(self,
+                      split: Literal['train', 'valid', 'test'],
+                      tmp_dir: Optional[str] = None,
+                      overfit_mode: bool = False,
+                      test_mode: bool = False) -> Dataset:
+
+        if self.group_uris is None:
+            ds = self._get_dataset_from_uri(
+                self.uri,
+                split=split,
+                tmp_dir=tmp_dir,
+                overfit_mode=overfit_mode,
+                test_mode=test_mode)
+            return ds
+
+        if self.uri is not None:
+            log.warning('Both DataConfig.uri and DataConfig.group_uris '
+                        'specified. Only DataConfig.group_uris will be used.')
+
+        ds = self._get_dataset_from_group_uris(
+            self.group_uris,
+            split=split,
+            tmp_dir=tmp_dir,
+            overfit_mode=overfit_mode,
+            test_mode=test_mode)
+
+        if split == 'train':
+            if self.train_sz is not None or self.train_sz_rel is not None:
+                ds = self.random_subset_dataset(
+                    ds, size=self.train_sz, fraction=self.train_sz_rel)
+
+        return ds
+
+    def _get_datasets_from_uri(
             self,
             uri: Union[str, List[str]],
             tmp_dir: str,
@@ -960,7 +1010,7 @@ def get_datasets_from_uri(
 
         train_tf = (aug_transform if not overfit_mode else base_transform)
         val_tf, test_tf = base_transform, base_transform
-        train_ds, val_ds, test_ds = self.make_datasets(
+        train_ds, val_ds, test_ds = self._build_datasets(
             train_dirs=train_dirs,
             val_dirs=val_dirs,
             test_dirs=test_dirs,
@@ -969,7 +1019,36 @@ def get_datasets_from_uri(
             test_tf=test_tf)
         return train_ds, val_ds, test_ds
 
-    def get_datasets_from_group_uris(
+    def _get_dataset_from_uri(self,
+                              uri: Union[str, List[str]],
+                              split: Literal['train', 'valid', 'test'],
+                              tmp_dir: str,
+                              overfit_mode: bool = False,
+                              test_mode: bool = False) -> Dataset:
+        """Get image dataset from a single zip file.
+
+        Args:
+            uri (Union[str, List[str]]): Uri of a zip file containing the
+                images.
+
+        Returns:
+            Dataset for the given split.
+        """
+        data_dirs = self.get_data_dirs(uri, unzip_dir=tmp_dir)
+
+        dirs = [join(d, split) for d in data_dirs if isdir(d)]
+        dirs = [d for d in dirs if isdir(d)]
+
+        base_transform, aug_transform = self.get_data_transforms()
+        if split == 'train' and not overfit_mode:
+            tf = aug_transform
+        else:
+            tf = base_transform
+
+        ds = self._build_dataset(dirs, tf)
+        return ds
+
+    def _get_datasets_from_group_uris(
             self,
             uris: Union[str, List[str]],
             tmp_dir: str,
@@ -989,7 +1068,7 @@ def get_datasets_from_group_uris(
             group_sizes = [group_sizes] * len(uris)
 
         for uri, size in zip(uris, group_sizes):
-            train_ds, valid_ds, test_ds = self.get_datasets_from_uri(
+            train_ds, valid_ds, test_ds = self._get_datasets_from_uri(
                 uri,
                 tmp_dir=tmp_dir,
                 overfit_mode=overfit_mode,
@@ -1010,6 +1089,43 @@
                                        ConcatDataset(test_ds_lst))
         return train_ds, valid_ds, test_ds
 
+    def _get_dataset_from_group_uris(
+            self,
+            uris: Union[str, List[str]],
+            split: Literal['train', 'valid', 'test'],
+            tmp_dir: str,
+            group_sz: Optional[int] = None,
+            group_sz_rel: Optional[float] = None,
+            overfit_mode: bool = False,
+            test_mode: bool = False,
+    ) -> Dataset:
+
+        group_sizes = None
+        if group_sz is not None:
+            group_sizes = group_sz
+        elif group_sz_rel is not None:
+            group_sizes = group_sz_rel
+        if not sequence_like(group_sizes):
+            group_sizes = [group_sizes] * len(uris)
+
+        per_uri_dataset = []
+        for uri, size in zip(uris, group_sizes):
+            ds = self._get_dataset_from_uri(
+                uri,
+                split=split,
+                tmp_dir=tmp_dir,
+                overfit_mode=overfit_mode,
+                test_mode=test_mode)
+            if size is not None:
+                if isinstance(size, float):
+                    ds = self.random_subset_dataset(ds, fraction=size)
+                else:
+                    ds = self.random_subset_dataset(ds, size=size)
+            per_uri_dataset.append(ds)
+
+        combined_dataset = ConcatDataset(per_uri_dataset)
+        return combined_dataset
+
     def get_data_dirs(self, uri: Union[str, List[str]],
                       unzip_dir: str) -> List[str]:
         """Extract data dirs from uri.
@@ -1242,34 +1358,62 @@ def get_class_info_from_class_config_if_needed(cls, values: dict) -> dict:
             values['class_colors'] = class_config.colors
         return values
 
-    def build_scenes(self, tmp_dir: str
-                     ) -> Tuple[List[Scene], List[Scene], List[Scene]]:
+    def build_scenes(self,
+                     scene_configs: Iterable['SceneConfig'],
+                     tmp_dir: Optional[str] = None) -> List[Scene]:
         """Build training, validation, and test scenes."""
+        class_config = self.scene_dataset.class_config
+        scenes = [
+            s.build(class_config, tmp_dir, use_transformers=True)
+            for s in scene_configs
+        ]
+        return scenes
+
+    def _build_dataset(self,
+                       split: Literal['train', 'valid', 'test'],
+                       tf: Optional[A.BasicTransform] = None,
+                       tmp_dir: Optional[str] = None,
+                       **kwargs) -> Dataset:
+        """Make a dataset for a single split.
+
+        Args:
+            split: Name of data split. One of: 'train', 'valid', 'test'.
+            tf: Transform for the dataset.
+                Defaults to None.
+            tmp_dir: Temporary directory to be used for building scenes.
+            **kwargs: Kwargs to pass to :meth:`.scene_to_dataset`.
+
+        Returns:
+            Dataset: PyTorch-compatible dataset.
+ """ if self.scene_dataset is None: raise ValueError('Cannot build scenes if scene_dataset is None.') - class_cfg = self.scene_dataset.class_config - train_scenes = [ - s.build(class_cfg, tmp_dir, use_transformers=True) - for s in self.scene_dataset.train_scenes - ] - val_scenes = [ - s.build(class_cfg, tmp_dir, use_transformers=True) - for s in self.scene_dataset.validation_scenes - ] - test_scenes = [ - s.build(class_cfg, tmp_dir, use_transformers=True) - for s in self.scene_dataset.test_scenes + if split == 'train': + scene_configs = self.scene_dataset.train_scenes + elif split == 'valid': + scene_configs = self.scene_dataset.validation_scenes + elif split == 'test': + scene_configs = self.scene_dataset.test_scenes + + scenes = self.build_scenes(scene_configs, tmp_dir) + per_scene_datasets = [ + self.scene_to_dataset(s, tf, **kwargs) for s in scenes ] - return train_scenes, val_scenes, test_scenes + if len(per_scene_datasets) == 0: + per_scene_datasets.append([]) - def make_datasets(self, - tmp_dir: str, - train_tf: Optional[A.BasicTransform] = None, - val_tf: Optional[A.BasicTransform] = None, - test_tf: Optional[A.BasicTransform] = None, - **kwargs) -> Tuple[Dataset, Dataset, Dataset]: + combined_dataset = ConcatDataset(per_scene_datasets) + + return combined_dataset + + def _build_datasets(self, + tmp_dir: Optional[str] = None, + train_tf: Optional[A.BasicTransform] = None, + val_tf: Optional[A.BasicTransform] = None, + test_tf: Optional[A.BasicTransform] = None, + **kwargs) -> Tuple[Dataset, Dataset, Dataset]: """Make training, validation, and test datasets. Args: @@ -1286,26 +1430,9 @@ def make_datasets(self, Tuple[Dataset, Dataset, Dataset]: PyTorch-compatiable training, validation, and test datasets. """ - train_scenes, val_scenes, test_scenes = self.build_scenes(tmp_dir) - - train_ds_list = [ - self.scene_to_dataset(s, train_tf, **kwargs) for s in train_scenes - ] - val_ds_list = [ - self.scene_to_dataset(s, val_tf, **kwargs) for s in val_scenes - ] - test_ds_list = [ - self.scene_to_dataset(s, test_tf, **kwargs) for s in test_scenes - ] - - for ds_list in [train_ds_list, val_ds_list, test_ds_list]: - if len(ds_list) == 0: - ds_list.append([]) - - train_ds = ConcatDataset(train_ds_list) - val_ds = ConcatDataset(val_ds_list) - test_ds = ConcatDataset(test_ds_list) - + train_ds = self._build_dataset('train', train_tf, tmp_dir, **kwargs) + val_ds = self._build_dataset('valid', train_tf, tmp_dir, **kwargs) + test_ds = self._build_dataset('test', train_tf, tmp_dir, **kwargs) return train_ds, val_ds, test_ds def scene_to_dataset(self, @@ -1316,15 +1443,36 @@ def scene_to_dataset(self, """ raise NotImplementedError() + def build_dataset(self, + split: Literal['train', 'valid', 'test'], + tmp_dir: Optional[str] = None, + overfit_mode: bool = False, + test_mode: bool = False) -> Dataset: + + base_transform, aug_transform = self.get_data_transforms() + if split == 'train' and not overfit_mode: + tf = aug_transform + else: + tf = base_transform + + ds = self._build_dataset(split, tf, tmp_dir) + + if split == 'train': + if self.train_sz is not None or self.train_sz_rel is not None: + ds = self.random_subset_dataset( + ds, size=self.train_sz, fraction=self.train_sz_rel) + + return ds + def build(self, - tmp_dir: str, + tmp_dir: Optional[str] = None, overfit_mode: bool = False, test_mode: bool = False) -> Tuple[Dataset, Dataset, Dataset]: base_transform, aug_transform = self.get_data_transforms() train_tf = (aug_transform if not overfit_mode else base_transform) val_tf, test_tf = 
base_transform, base_transform - train_ds, val_ds, test_ds = self.make_datasets( + train_ds, val_ds, test_ds = self._build_datasets( tmp_dir=tmp_dir, train_tf=train_tf, val_tf=val_tf, test_tf=test_tf) if self.train_sz is not None or self.train_sz_rel is not None:
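Usage note (not part of the patch): a minimal sketch of the split-aware API introduced above, assuming `data_cfg` is an already-configured instance of a concrete `DataConfig` subclass such as `ClassificationGeoDataConfig`; the temporary directory path is illustrative only.

    # Old pattern: all three splits are built in one call.
    train_ds, valid_ds, test_ds = data_cfg.build(tmp_dir='/tmp/rv')

    # New pattern: build only the split that is needed, e.g. for evaluation,
    # without paying the cost of building the other splits.
    valid_ds = data_cfg.build_dataset('valid', tmp_dir='/tmp/rv')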