Skip to content

Commit

Permalink
Merge pull request #2033 from AdeelH/examples
Browse files Browse the repository at this point in the history
Misc. minor fixes and improvements to examples
  • Loading branch information
AdeelH authored Jan 23, 2024
2 parents dcb7c39 + 92acae3 commit 6a918ba
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 57 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# flake8: noqa

import os
from os.path import join

import albumentations as A

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.core.rv_pipeline import ChipClassificationConfig
from rastervision.core.data import (
ChipClassificationLabelSourceConfig, ClassConfig,
ClassInferenceTransformerConfig, DatasetConfig, GeoJSONVectorSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchChipClassificationConfig
from rastervision.pytorch_learner import (
Backbone, ClassificationGeoDataConfig, ClassificationImageDataConfig,
ClassificationModelConfig, ExternalModuleConfig, GeoDataWindowConfig,
GeoDataWindowMethod, SolverConfig)
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# flake8: noqa

import os
from os.path import join

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
ObjectDetectionChipOptions,
ObjectDetectionPredictOptions)
from rastervision.core.data import (
ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
from rastervision.pytorch_learner import (
Backbone, ExternalModuleConfig, GeoDataWindowMethod,
ObjectDetectionGeoDataConfig, ObjectDetectionGeoDataWindowConfig,
ObjectDetectionImageDataConfig, ObjectDetectionModelConfig, PlotOptions,
SolverConfig)
from rastervision.pytorch_backend.examples.utils import save_image_crop

TRAIN_IDS = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# flake8: noqa

import os
from os.path import join

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
ObjectDetectionChipOptions,
ObjectDetectionPredictOptions)
from rastervision.core.data import (
ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
from rastervision.pytorch_learner import (
Backbone, GeoDataWindowMethod, ObjectDetectionGeoDataConfig,
ObjectDetectionGeoDataWindowConfig, ObjectDetectionImageDataConfig,
ObjectDetectionModelConfig, SolverConfig)
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
# flake8: noqa

import os
from typing import Optional
from os.path import join, basename

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)
import albumentations as A

from rastervision.core.rv_pipeline import (SemanticSegmentationConfig,
SemanticSegmentationChipOptions,
SemanticSegmentationWindowMethod)
from rastervision.core.data import (
ClassConfig, DatasetConfig, PolygonVectorOutputConfig,
RasterioSourceConfig, RGBClassTransformerConfig, SceneConfig,
SemanticSegmentationLabelSourceConfig,
SemanticSegmentationLabelStoreConfig)
from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
from rastervision.pytorch_learner import (
Backbone, ExternalModuleConfig, GeoDataWindowConfig, GeoDataWindowMethod,
PlotOptions, SolverConfig, SemanticSegmentationGeoDataConfig,
SemanticSegmentationImageDataConfig, SemanticSegmentationModelConfig)
from rastervision.pytorch_backend.examples.utils import save_image_crop
from rastervision.pytorch_backend.examples.semantic_segmentation.utils import (
example_multiband_transform, example_rgb_transform, imagenet_stats,
Unnormalize)
Expand All @@ -32,22 +38,24 @@

def get_config(runner,
raw_uri: str,
processed_uri: str,
root_uri: str,
processed_uri: Optional[str] = None,
multiband: bool = False,
external_model: bool = True,
augment: bool = False,
nochip: bool = True,
test: bool = False):
num_epochs: int = 10,
batch_sz: int = 8,
test: bool = False) -> SemanticSegmentationConfig:
"""Generate the pipeline config for this task. This function will be called
by RV, with arguments from the command line, when this example is run.
Args:
runner (Runner): Runner for the pipeline. Will be provided by RV.
raw_uri (str): Directory where the raw data resides
processed_uri (str): Directory for storing processed data.
E.g. crops for testing.
root_uri (str): Directory where all the output will be written.
processed_uri (str): Directory for storing processed data.
E.g. crops for testing. Defaults to None.
multiband (bool, optional): If True, all 4 channels (R, G, B, & IR)
available in the raster source will be used. If False, only
IR, R, G (in that order) will be used. Defaults to False.
Expand All @@ -61,6 +69,8 @@ def get_config(runner,
training instead of from pre-generated chips. The analyze and chip
commands should not be run, if this is set to True. Defaults to
True.
num_epochs (int): Number of epochs to train for.
batch_sz (int): Batch size.
test (bool, optional): If True, does the following simplifications:
(1) Uses only the first 2 scenes
(2) Uses only a 600x600 crop of the scenes
Expand Down Expand Up @@ -203,7 +213,7 @@ def make_scene(id) -> SceneConfig:
num_classes = len(class_config)
model = SemanticSegmentationModelConfig(
external_def=ExternalModuleConfig(
github_repo='AdeelH/pytorch-fpn:0.2',
github_repo='AdeelH/pytorch-fpn:0.3',
name='fpn',
entrypoint='make_fpn_resnet',
entrypoint_kwargs={
Expand All @@ -217,11 +227,15 @@ def make_scene(id) -> SceneConfig:
else:
model = SemanticSegmentationModelConfig(backbone=Backbone.resnet50)

num_epochs = 2 if test else int(num_epochs)
batch_sz = 2 if test else int(batch_sz)
solver = SolverConfig(
lr=1e-4, num_epochs=num_epochs, batch_sz=batch_sz, one_cycle=True)

backend = PyTorchSemanticSegmentationConfig(
data=data,
model=model,
solver=SolverConfig(
lr=1e-4, num_epochs=10, batch_sz=8, one_cycle=True),
solver=solver,
log_tensorboard=True,
run_tensorboard=False,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# flake8: noqa

from typing import Optional
import re
import random
import os
from abc import abstractmethod

from rastervision.pipeline.file_system import list_paths
from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.pipeline.file_system.utils import list_paths
from rastervision.core.rv_pipeline import (SemanticSegmentationConfig,
SemanticSegmentationChipOptions,
SemanticSegmentationWindowMethod)
from rastervision.core.data import (
BufferTransformerConfig, ClassConfig, ClassInferenceTransformerConfig,
DatasetConfig, GeoJSONVectorSourceConfig, PolygonVectorOutputConfig,
RasterioSourceConfig, RasterizedSourceConfig, RasterizerConfig,
SceneConfig, SemanticSegmentationLabelSourceConfig,
SemanticSegmentationLabelStoreConfig, StatsTransformerConfig)
from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
from rastervision.pytorch_learner import (
Backbone, GeoDataWindowConfig, GeoDataWindowMethod, SolverConfig,
SemanticSegmentationGeoDataConfig, SemanticSegmentationImageDataConfig,
SemanticSegmentationModelConfig)

BUILDINGS = 'buildings'
ROADS = 'roads'
Expand All @@ -28,21 +35,22 @@ def create(raw_uri, target):
elif target.lower() == ROADS:
return VegasRoads(raw_uri)
else:
raise ValueError('{} is not a valid target.'.format(target))
raise ValueError(f'{target} is not a valid target.')

def get_raster_source_uri(self, id):
filename = f'{self.raster_fn_prefix}{id}.tif'
return os.path.join(self.raw_uri, self.base_dir, self.raster_dir,
'{}{}.tif'.format(self.raster_fn_prefix, id))
filename)

def get_geojson_uri(self, id):
filename = f'{self.label_fn_prefix}{id}.geojson'
return os.path.join(self.raw_uri, self.base_dir, self.label_dir,
'{}{}.geojson'.format(self.label_fn_prefix, id))
filename)

def get_scene_ids(self):
label_dir = os.path.join(self.raw_uri, self.base_dir, self.label_dir)
label_paths = list_paths(label_dir, ext='.geojson')
label_re = re.compile(r'.*{}(\d+)\.geojson'.format(
self.label_fn_prefix))
label_re = re.compile(rf'.*{self.label_fn_prefix}(\d+)\.geojson')
scene_ids = [
label_re.match(label_path).group(1) for label_path in label_paths
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import csv
from io import StringIO

from rastervision.pipeline.file_system.utils import file_exists
from rastervision.core.data import (RasterioSource, GeoJSONVectorSource,
ClassInferenceTransformer)
from rastervision.core.data.utils import geoms_to_geojson, crop_geotiff
Expand Down Expand Up @@ -48,6 +49,8 @@ def save_image_crop(
ValueError if cannot find a crop satisfying min_features constraint.
"""
print(f'Saving test crop to {image_crop_uri}...')
if file_exists(image_crop_uri):
print(f'Already exists. Skipping.')
old_environ = os.environ.copy()
try:
request_payer = S3FileSystem.get_request_payer()
Expand Down

0 comments on commit 6a918ba

Please sign in to comment.