Skip to content

Commit

Permalink
Merge pull request #2060 from AdeelH/onnx-pred
Browse files Browse the repository at this point in the history
Simplify inference with ONNX models
  • Loading branch information
AdeelH authored Feb 14, 2024
2 parents e6510d7 + 822cdae commit d6d28f7
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 162 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING
from os.path import join
import uuid

Expand All @@ -7,12 +7,10 @@
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_cc
from rastervision.pytorch_learner.dataset import (
ClassificationSlidingWindowGeoDataset)
from rastervision.core.data import ChipClassificationLabels
from rastervision.pytorch_learner.utils import predict_scene_cc

if TYPE_CHECKING:
import numpy as np
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import ChipOptions, PredictOptions
from rastervision.pytorch_learner import ClassificationGeoDataConfig
Expand Down Expand Up @@ -60,32 +58,9 @@ def chip_dataset(self,

def predict_scene(self, scene: 'Scene', predict_options: 'PredictOptions'
) -> 'ChipClassificationLabels':

if self.learner is None:
self.load_model()

chip_sz = predict_options.chip_sz
stride = predict_options.stride
batch_sz = predict_options.batch_sz

# Important to use self.learner.cfg.data instead of
# self.learner_cfg.data because of the updates
# Learner.from_model_bundle() makes to the custom transforms.
base_tf, _ = self.learner.cfg.data.get_data_transforms()
ds = ClassificationSlidingWindowGeoDataset(
scene, size=chip_sz, stride=stride, transform=base_tf)

predictions: Iterator['np.array'] = self.learner.predict_dataset(
ds,
raw_out=True,
numpy_out=True,
dataloader_kw=dict(batch_size=batch_sz),
progress_bar=True,
progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))

labels = ChipClassificationLabels.from_predictions(
ds.windows, predictions)

labels = predict_scene_cc(self.learner, scene, predict_options)
return labels

def _make_chip_data_config(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from typing import TYPE_CHECKING, Dict, Iterator
from typing import TYPE_CHECKING
from os.path import join, basename
import uuid

import numpy as np

from rastervision.pipeline.file_system import json_to_file
from rastervision.core.data_sample import DataSample
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_od
from rastervision.pytorch_learner.dataset import (
ObjectDetectionSlidingWindowGeoDataset)
from rastervision.pytorch_learner.utils import predict_scene_od

if TYPE_CHECKING:
from rastervision.core.data import DatasetConfig, Scene
Expand Down Expand Up @@ -117,39 +114,9 @@ def chip_dataset(self,
def predict_scene(self, scene: 'Scene',
predict_options: 'ObjectDetectionPredictOptions'
) -> ObjectDetectionLabels:

chip_sz = predict_options.chip_sz
stride = predict_options.stride
batch_sz = predict_options.batch_sz

if self.learner is None:
self.load_model()

# Important to use self.learner.cfg.data instead of
# self.learner_cfg.data because of the updates
# Learner.from_model_bundle() makes to the custom transforms.
base_tf, _ = self.learner.cfg.data.get_data_transforms()
ds = ObjectDetectionSlidingWindowGeoDataset(
scene, size=chip_sz, stride=stride, transform=base_tf)

predictions: Iterator[Dict[str, 'np.ndarray']] = (
self.learner.predict_dataset(
ds,
raw_out=True,
numpy_out=True,
predict_kw=dict(out_shape=(chip_sz, chip_sz)),
dataloader_kw=dict(batch_size=batch_sz),
progress_bar=True,
progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))
)

labels = ObjectDetectionLabels.from_predictions(
ds.windows, predictions)
labels = ObjectDetectionLabels.prune_duplicates(
labels,
score_thresh=predict_options.score_thresh,
merge_thresh=predict_options.merge_thresh)

labels = predict_scene_od(self.learner, scene, predict_options)
return labels

def _make_chip_data_config(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING
from os.path import join
import uuid

Expand All @@ -10,12 +10,10 @@
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_ss
from rastervision.pytorch_learner.dataset import (
SemanticSegmentationSlidingWindowGeoDataset)
from rastervision.pytorch_learner.utils import predict_scene_ss

if TYPE_CHECKING:
from rastervision.core.data import (DatasetConfig, Scene,
SemanticSegmentationLabelStore)
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import (
ChipOptions, SemanticSegmentationPredictOptions)
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig
Expand Down Expand Up @@ -71,51 +69,9 @@ def chip_dataset(self,
def predict_scene(self, scene: 'Scene',
predict_options: 'SemanticSegmentationPredictOptions'
) -> 'SemanticSegmentationLabels':

if scene.label_store is None:
raise ValueError(
f'Scene.label_store is not set for scene {scene.id}')

if self.learner is None:
self.load_model()

chip_sz = predict_options.chip_sz
stride = predict_options.stride
crop_sz = predict_options.crop_sz
batch_sz = predict_options.batch_sz

label_store: 'SemanticSegmentationLabelStore' = scene.label_store
raw_out = label_store.smooth_output

# Important to use self.learner.cfg.data instead of
# self.learner_cfg.data because of the updates
# Learner.from_model_bundle() makes to the custom transforms.
base_tf, _ = self.learner.cfg.data.get_data_transforms()
pad_direction = 'end' if crop_sz is None else 'both'
ds = SemanticSegmentationSlidingWindowGeoDataset(
scene,
size=chip_sz,
stride=stride,
pad_direction=pad_direction,
transform=base_tf)

predictions: Iterator[np.ndarray] = self.learner.predict_dataset(
ds,
raw_out=raw_out,
numpy_out=True,
predict_kw=dict(out_shape=(chip_sz, chip_sz)),
dataloader_kw=dict(batch_size=batch_sz),
progress_bar=True,
progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))

labels = SemanticSegmentationLabels.from_predictions(
ds.windows,
predictions,
smooth=raw_out,
extent=scene.extent,
num_classes=len(label_store.class_config),
crop_sz=crop_sz)

labels = predict_scene_ss(self.learner, scene, predict_options)
return labels

def _make_chip_data_config(self, dataset: 'DatasetConfig',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def build_default_model(self, num_classes: int,
class ClassificationLearnerConfig(LearnerConfig):
"""Configure a :class:`.ClassificationLearner`."""

data: Union[ClassificationImageDataConfig, ClassificationGeoDataConfig]
model: Optional[ClassificationModelConfig]

def build(self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,21 @@ def __init__(self,
training mode. Defaults to True.
"""
self.cfg = cfg

if model is None and cfg.model is None:
self.training = training
self._onnx_mode = (model_weights_path is not None
and model_weights_path.lower().endswith('.onnx'))
if self.onnx_mode and self.training:
raise ValueError('Training mode is not supported for ONNX models.')
if model is None and cfg.model is None and not self.onnx_mode:
raise ValueError(
'cfg.model can only be None if a custom model is specified.')
'cfg.model can only be None if a custom model is specified '
'or if model_weights_path is an .onnx file.')

if tmp_dir is None:
self._tmp_dir = get_tmp_dir()
tmp_dir = self._tmp_dir.name
self.tmp_dir = tmp_dir

self.training = training

self.train_ds = train_ds
self.valid_ds = valid_ds
self.test_ds = test_ds
Expand Down Expand Up @@ -198,38 +201,46 @@ def __init__(self,
# ---------------------------
# Set URIs
# ---------------------------
if output_dir is None and cfg.output_uri is None:
raise ValueError('output_dir or LearnerConfig.output_uri must '
'be specified.')
if output_dir is not None and cfg.output_uri is not None:
log.warning(
'Both output_dir and LearnerConfig.output_uri specified. '
'LearnerConfig.output_uri will be ignored.')
if output_dir is None:
assert cfg.output_uri is not None
self.output_dir = cfg.output_uri
self.model_bundle_uri = cfg.get_model_bundle_uri()
else:
self.output_dir = output_dir
self.model_bundle_uri = join(self.output_dir, 'model-bundle.zip')
if is_local(self.output_dir):
self.output_dir_local = self.output_dir
make_dir(self.output_dir_local)
else:
self.output_dir_local = get_local_path(self.output_dir, tmp_dir)
make_dir(self.output_dir_local, force_empty=True)
if self.training:
self.sync_from_cloud()
log.info(f'Local output dir: {self.output_dir_local}')
log.info(f'Remote output dir: {self.output_dir}')

self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
self.checkpoints_dir_local = join(self.output_dir_local,
CHECKPOINTS_DIRNAME)
make_dir(self.checkpoints_dir_local)
self.output_dir = None
self.output_dir_local = None
self.model_bundle_uri = None
self.modules_dir = None
self.checkpoints_dir_local = None

if self.training:
if output_dir is None and cfg.output_uri is None:
raise ValueError('output_dir or LearnerConfig.output_uri must '
'be specified in training mode.')
if output_dir is not None and cfg.output_uri is not None:
log.warning(
'Both output_dir and LearnerConfig.output_uri specified. '
'LearnerConfig.output_uri will be ignored.')
if output_dir is None:
assert cfg.output_uri is not None
self.output_dir = cfg.output_uri
self.model_bundle_uri = cfg.get_model_bundle_uri()
else:
self.output_dir = output_dir
self.model_bundle_uri = join(self.output_dir,
'model-bundle.zip')
if is_local(self.output_dir):
self.output_dir_local = self.output_dir
make_dir(self.output_dir_local)
else:
self.output_dir_local = get_local_path(self.output_dir,
tmp_dir)
make_dir(self.output_dir_local, force_empty=True)
if self.training:
self.sync_from_cloud()
log.info(f'Local output dir: {self.output_dir_local}')
log.info(f'Remote output dir: {self.output_dir}')

self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
self.checkpoints_dir_local = join(self.output_dir_local,
CHECKPOINTS_DIRNAME)
make_dir(self.checkpoints_dir_local)

# ---------------------------
self._onnx_mode = False
self.init_model_weights_path = model_weights_path
self.init_model_def_path = model_def_path
self.init_loss_def_path = loss_def_path
Expand Down Expand Up @@ -771,7 +782,7 @@ def predict_dataset(self,

dl_kw = dict(
collate_fn=self.get_collate_fn(),
batch_size=cfg.solver.batch_sz,
batch_size=cfg.solver.batch_sz if cfg.solver else 1,
num_workers=int(num_workers),
shuffle=False,
pin_memory=True)
Expand Down Expand Up @@ -1101,9 +1112,7 @@ def setup_model(self,
model_def_path (Optional[str], optional): Path to model definition.
Will be available when loading from a bundle. Defaults to None.
"""
self._onnx_mode = (model_weights_path is not None
and model_weights_path.lower().endswith('.onnx'))
if self._onnx_mode:
if self.onnx_mode:
self.model = self.load_onnx_model(model_weights_path)
return
if self.model is None:
Expand Down Expand Up @@ -1716,7 +1725,8 @@ def load_checkpoint(self):

def load_onnx_model(self, model_path: str) -> ONNXRuntimeAdapter:
log.info(f'Loading ONNX model from {model_path}')
onnx_model = ONNXRuntimeAdapter.from_file(model_path)
path = download_if_needed(model_path)
onnx_model = ONNXRuntimeAdapter.from_file(path)
return onnx_model

def log_data_stats(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1373,8 +1373,8 @@ def learner_config_upgrader(cfg_dict: dict, version: int) -> dict:
@register_config('learner', upgrader=learner_config_upgrader)
class LearnerConfig(Config):
"""Config for Learner."""
model: Optional[ModelConfig]
solver: SolverConfig
model: Optional[ModelConfig] = None
solver: Optional[SolverConfig] = None
data: DataConfig

eval_train: bool = Field(
Expand Down Expand Up @@ -1411,7 +1411,9 @@ def validate_run_tensorboard(cls, v: bool, values: dict) -> bool:

@root_validator(skip_on_failure=True)
def validate_class_loss_weights(cls, values: dict) -> dict:
solver: SolverConfig = values.get('solver')
solver: Optional[SolverConfig] = values.get('solver')
if solver is None:
return values
class_loss_weights = solver.class_loss_weights
if class_loss_weights is not None:
data: DataConfig = values.get('data')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def build_default_model(self, num_classes: int, in_channels: int,
class ObjectDetectionLearnerConfig(LearnerConfig):
"""Configure an :class:`.ObjectDetectionLearner`."""

data: Union[ObjectDetectionImageDataConfig, ObjectDetectionGeoDataConfig]
model: Optional[ObjectDetectionModelConfig]

def build(self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ class RegressionLearnerConfig(LearnerConfig):
"""Configure a :class:`.RegressionLearner`."""

model: Optional[RegressionModelConfig]
data: Union[RegressionImageDataConfig, RegressionGeoDataConfig]

def build(self,
tmp_dir,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, Optional
from os.path import join
from enum import Enum
import logging
Expand Down Expand Up @@ -209,8 +209,6 @@ def build_default_model(self, num_classes: int,
class SemanticSegmentationLearnerConfig(LearnerConfig):
"""Configure a :class:`.SemanticSegmentationLearner`."""

data: Union[SemanticSegmentationImageDataConfig,
SemanticSegmentationGeoDataConfig]
model: Optional[SemanticSegmentationModelConfig]

def build(self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from rastervision.pytorch_learner.utils.utils import *
from rastervision.pytorch_learner.utils.torch_hub import *
from rastervision.pytorch_learner.utils.distributed import *
from rastervision.pytorch_learner.utils.prediction import *

__all__ = [
SplitTensor.__name__,
Expand All @@ -24,4 +25,7 @@
torch_hub_load_local.__name__,
DDPContextManager.__name__,
'DDP_BACKEND',
predict_scene_cc.__name__,
predict_scene_od.__name__,
predict_scene_ss.__name__,
]
Loading

0 comments on commit d6d28f7

Please sign in to comment.