diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/chip_classification/spacenet_rio.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/chip_classification/spacenet_rio.py
index 79115558d..eef87f08b 100644
--- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/chip_classification/spacenet_rio.py
+++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/chip_classification/spacenet_rio.py
@@ -1,16 +1,16 @@
-# flake8: noqa
-
 import os
 from os.path import join
 
-import albumentations as A
-
-from rastervision.core.rv_pipeline import *
-from rastervision.core.backend import *
-from rastervision.core.data import *
-from rastervision.core.analyzer import *
-from rastervision.pytorch_backend import *
-from rastervision.pytorch_learner import *
+from rastervision.core.rv_pipeline import ChipClassificationConfig
+from rastervision.core.data import (
+    ChipClassificationLabelSourceConfig, ClassConfig,
+    ClassInferenceTransformerConfig, DatasetConfig, GeoJSONVectorSourceConfig,
+    RasterioSourceConfig, SceneConfig)
+from rastervision.pytorch_backend import PyTorchChipClassificationConfig
+from rastervision.pytorch_learner import (
+    Backbone, ClassificationGeoDataConfig, ClassificationImageDataConfig,
+    ClassificationModelConfig, ExternalModuleConfig, GeoDataWindowConfig,
+    GeoDataWindowMethod, SolverConfig)
 from rastervision.pytorch_backend.examples.utils import (get_scene_info,
                                                          save_image_crop)
diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/cowc_potsdam.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/cowc_potsdam.py
index f993bb37e..69e443848 100644
--- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/cowc_potsdam.py
+++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/cowc_potsdam.py
@@ -1,14 +1,19 @@
-# flake8: noqa
-
 import os
 from os.path import join
 
-from rastervision.core.rv_pipeline import *
-from rastervision.core.backend import *
-from rastervision.core.data import *
-from rastervision.core.analyzer import *
-from rastervision.pytorch_backend import *
-from rastervision.pytorch_learner import *
+from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
+                                           ObjectDetectionChipOptions,
+                                           ObjectDetectionPredictOptions)
+from rastervision.core.data import (
+    ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
+    GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
+    RasterioSourceConfig, SceneConfig)
+from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
+from rastervision.pytorch_learner import (
+    Backbone, ExternalModuleConfig, GeoDataWindowMethod,
+    ObjectDetectionGeoDataConfig, ObjectDetectionGeoDataWindowConfig,
+    ObjectDetectionImageDataConfig, ObjectDetectionModelConfig, PlotOptions,
+    SolverConfig)
 from rastervision.pytorch_backend.examples.utils import save_image_crop
 
 TRAIN_IDS = [
diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/xview.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/xview.py
index 2f2b7d8d6..7e4c935f0 100644
--- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/xview.py
+++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/object_detection/xview.py
@@ -1,14 +1,18 @@
-# flake8: noqa
-
 import os
 from os.path import join
 
-from rastervision.core.rv_pipeline import *
-from rastervision.core.backend import *
-from rastervision.core.data import *
-from rastervision.core.analyzer import *
-from rastervision.pytorch_backend import *
-from rastervision.pytorch_learner import *
+from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
+                                           ObjectDetectionChipOptions,
+                                           ObjectDetectionPredictOptions)
+from rastervision.core.data import (
+    ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
+    GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
+    RasterioSourceConfig, SceneConfig)
+from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
+from rastervision.pytorch_learner import (
+    Backbone, GeoDataWindowMethod, ObjectDetectionGeoDataConfig,
+    ObjectDetectionGeoDataWindowConfig, ObjectDetectionImageDataConfig,
+    ObjectDetectionModelConfig, SolverConfig)
 from rastervision.pytorch_backend.examples.utils import (get_scene_info,
                                                          save_image_crop)
diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam.py
index f98074b84..cce5ac8a5 100644
--- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam.py
+++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam.py
@@ -1,16 +1,22 @@
-# flake8: noqa
-
-import os
+from typing import Optional
 from os.path import join, basename
 
-from rastervision.core.rv_pipeline import *
-from rastervision.core.backend import *
-from rastervision.core.data import *
-from rastervision.core.analyzer import *
-from rastervision.pytorch_backend import *
-from rastervision.pytorch_learner import *
-from rastervision.pytorch_backend.examples.utils import (get_scene_info,
-                                                         save_image_crop)
+import albumentations as A
+
+from rastervision.core.rv_pipeline import (SemanticSegmentationConfig,
+                                           SemanticSegmentationChipOptions,
+                                           SemanticSegmentationWindowMethod)
+from rastervision.core.data import (
+    ClassConfig, DatasetConfig, PolygonVectorOutputConfig,
+    RasterioSourceConfig, RGBClassTransformerConfig, SceneConfig,
+    SemanticSegmentationLabelSourceConfig,
+    SemanticSegmentationLabelStoreConfig)
+from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
+from rastervision.pytorch_learner import (
+    Backbone, ExternalModuleConfig, GeoDataWindowConfig, GeoDataWindowMethod,
+    PlotOptions, SolverConfig, SemanticSegmentationGeoDataConfig,
+    SemanticSegmentationImageDataConfig, SemanticSegmentationModelConfig)
+from rastervision.pytorch_backend.examples.utils import save_image_crop
 from rastervision.pytorch_backend.examples.semantic_segmentation.utils import (
     example_multiband_transform, example_rgb_transform, imagenet_stats,
     Unnormalize)
@@ -32,22 +38,24 @@ def get_config(runner,
                raw_uri: str,
-               processed_uri: str,
                root_uri: str,
+               processed_uri: Optional[str] = None,
                multiband: bool = False,
                external_model: bool = True,
                augment: bool = False,
                nochip: bool = True,
-               test: bool = False):
+               num_epochs: int = 10,
+               batch_sz: int = 8,
+               test: bool = False) -> SemanticSegmentationConfig:
     """Generate the pipeline config for this task.
 
     This function will be called by RV, with arguments from the command line,
     when this example is run.
 
     Args:
         runner (Runner): Runner for the pipeline. Will be provided by RV.
         raw_uri (str): Directory where the raw data resides
-        processed_uri (str): Directory for storing processed data.
-            E.g. crops for testing.
         root_uri (str): Directory where all the output will be written.
+        processed_uri (str): Directory for storing processed data.
+            E.g. crops for testing. Defaults to None.
         multiband (bool, optional): If True, all 4 channels (R, G, B, & IR)
             available in the raster source will be used. If False, only
             IR, R, G (in that order) will be used. Defaults to False.
@@ -61,6 +69,8 @@ def get_config(runner,
             training instead of from pre-generated chips. The analyze and
             chip commands should not be run, if this is set to True.
            Defaults to True.
+        num_epochs (int): Number of epochs to train for.
+        batch_sz (int): Batch size.
         test (bool, optional): If True, does the following simplifications:
             (1) Uses only the first 2 scenes
             (2) Uses only a 600x600 crop of the scenes
@@ -203,7 +213,7 @@ def make_scene(id) -> SceneConfig:
        num_classes = len(class_config)
        model = SemanticSegmentationModelConfig(
            external_def=ExternalModuleConfig(
-               github_repo='AdeelH/pytorch-fpn:0.2',
+               github_repo='AdeelH/pytorch-fpn:0.3',
                name='fpn',
                entrypoint='make_fpn_resnet',
                entrypoint_kwargs={
@@ -217,11 +227,15 @@ def make_scene(id) -> SceneConfig:
     else:
         model = SemanticSegmentationModelConfig(backbone=Backbone.resnet50)
 
+    num_epochs = 2 if test else int(num_epochs)
+    batch_sz = 2 if test else int(batch_sz)
+    solver = SolverConfig(
+        lr=1e-4, num_epochs=num_epochs, batch_sz=batch_sz, one_cycle=True)
+
     backend = PyTorchSemanticSegmentationConfig(
         data=data,
         model=model,
-        solver=SolverConfig(
-            lr=1e-4, num_epochs=10, batch_sz=8, one_cycle=True),
+        solver=solver,
         log_tensorboard=True,
         run_tensorboard=False,
     )
diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/spacenet_vegas.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/spacenet_vegas.py
index 4f7ea4848..c02a054fa 100644
--- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/spacenet_vegas.py
+++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/spacenet_vegas.py
@@ -1,17 +1,24 @@
-# flake8: noqa
-
 from typing import Optional
 import re
 import random
 import os
 from abc import abstractmethod
 
-from rastervision.pipeline.file_system import list_paths
-from rastervision.core.rv_pipeline import *
-from rastervision.core.backend import *
-from rastervision.core.data import *
-from rastervision.pytorch_backend import *
-from rastervision.pytorch_learner import *
+from rastervision.pipeline.file_system.utils import list_paths
+from rastervision.core.rv_pipeline import (SemanticSegmentationConfig,
+                                           SemanticSegmentationChipOptions,
+                                           SemanticSegmentationWindowMethod)
+from rastervision.core.data import (
+    BufferTransformerConfig, ClassConfig, ClassInferenceTransformerConfig,
+    DatasetConfig, GeoJSONVectorSourceConfig, PolygonVectorOutputConfig,
+    RasterioSourceConfig, RasterizedSourceConfig, RasterizerConfig,
+    SceneConfig, SemanticSegmentationLabelSourceConfig,
+    SemanticSegmentationLabelStoreConfig, StatsTransformerConfig)
+from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
+from rastervision.pytorch_learner import (
+    Backbone, GeoDataWindowConfig, GeoDataWindowMethod, SolverConfig,
+    SemanticSegmentationGeoDataConfig, SemanticSegmentationImageDataConfig,
+    SemanticSegmentationModelConfig)
 
 BUILDINGS = 'buildings'
 ROADS = 'roads'
@@ -28,21 +35,22 @@ def create(raw_uri, target):
         elif target.lower() == ROADS:
             return VegasRoads(raw_uri)
         else:
-            raise ValueError('{} is not a valid target.'.format(target))
+            raise ValueError(f'{target} is not a valid target.')
 
     def get_raster_source_uri(self, id):
+        filename = f'{self.raster_fn_prefix}{id}.tif'
         return os.path.join(self.raw_uri, self.base_dir, self.raster_dir,
-                            '{}{}.tif'.format(self.raster_fn_prefix, id))
+                            filename)
 
     def get_geojson_uri(self, id):
+        filename = f'{self.label_fn_prefix}{id}.geojson'
         return os.path.join(self.raw_uri, self.base_dir, self.label_dir,
-                            '{}{}.geojson'.format(self.label_fn_prefix, id))
+                            filename)
 
     def get_scene_ids(self):
         label_dir = os.path.join(self.raw_uri, self.base_dir, self.label_dir)
         label_paths = list_paths(label_dir, ext='.geojson')
-        label_re = re.compile(r'.*{}(\d+)\.geojson'.format(
-            self.label_fn_prefix))
+        label_re = re.compile(rf'.*{self.label_fn_prefix}(\d+)\.geojson')
         scene_ids = [
             label_re.match(label_path).group(1) for label_path in label_paths
         ]
diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py
index 53f16da77..5c92814b7 100644
--- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py
+++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py
@@ -3,6 +3,7 @@
 import csv
 from io import StringIO
 
+from rastervision.pipeline.file_system.utils import file_exists
 from rastervision.core.data import (RasterioSource, GeoJSONVectorSource,
                                     ClassInferenceTransformer)
 from rastervision.core.data.utils import geoms_to_geojson, crop_geotiff
@@ -48,6 +49,9 @@ def save_image_crop(
         ValueError if cannot find a crop satisfying min_features constraint.
     """
     print(f'Saving test crop to {image_crop_uri}...')
+    if file_exists(image_crop_uri):
+        print('Already exists. Skipping.')
+        return
     old_environ = os.environ.copy()
     try:
         request_payer = S3FileSystem.get_request_payer()
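
Usage sketch for the new get_config() arguments in isprs_potsdam.py (an illustration, not part of the patch; the URIs below are placeholders). With num_epochs and batch_sz exposed as keyword arguments, they can be overridden per run instead of editing the hard-coded SolverConfig; in normal use RV supplies them on the command line (for example, rastervision run local isprs_potsdam.py -a num_epochs 5 -a batch_sz 4), but the config builder can also be called directly:

    # Hypothetical direct call; RV normally invokes get_config() itself with
    # values taken from `-a key value` command-line arguments.
    from rastervision.pytorch_backend.examples.semantic_segmentation.isprs_potsdam import (
        get_config)

    cfg = get_config(
        runner='local',
        raw_uri='/data/potsdam/raw',      # placeholder path
        root_uri='/data/potsdam/output',  # placeholder path
        num_epochs=5,                     # overrides the default of 10
        batch_sz=4)                       # overrides the default of 8
    # The overrides end up in the backend's SolverConfig built above.
    assert cfg.backend.solver.num_epochs == 5
    assert cfg.backend.solver.batch_sz == 4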