Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GeoJSON output option for vessel detection pipelines #97

Merged
merged 15 commits into from
Feb 6, 2025
Merged
1 change: 1 addition & 0 deletions data/sentinel2_vessels/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ data:
load_all_patches: true
skip_targets: true
patch_size: 512
overlap_ratio: 0.1
trainer:
max_epochs: 500
callbacks:
Expand Down
1 change: 1 addition & 0 deletions rslp/landsat_vessels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Landsat config
LANDSAT_LAYER_NAME = "landsat"
LANDSAT_RESOLUTION = 15
LANDSAT_SOURCE = "landsat"

# Data config
LOCAL_FILES_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config.json"
Expand Down
86 changes: 36 additions & 50 deletions rslp/landsat_vessels/predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rslearn.data_sources import Item, data_source_from_config
from rslearn.data_sources.aws_landsat import LandsatOliTirs
from rslearn.dataset import Dataset, Window, WindowLayerData
from rslearn.utils import Projection, STGeometry
from rslearn.utils import Projection
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
from typing_extensions import TypedDict
from upath import UPath
Expand All @@ -29,6 +29,7 @@
LANDSAT_BANDS,
LANDSAT_LAYER_NAME,
LANDSAT_RESOLUTION,
LANDSAT_SOURCE,
LOCAL_FILES_DATASET_CONFIG,
)
from rslp.log_utils import get_logger
Expand All @@ -42,37 +43,11 @@
materialize_dataset,
run_model_predict,
)
from rslp.vessels import VesselDetection

logger = get_logger(__name__)


class VesselDetection:
"""A vessel detected in a Landsat scene."""

def __init__(
self,
col: int,
row: int,
projection: Projection,
score: float,
crop_window_dir: UPath | None = None,
) -> None:
"""Create a new VesselDetection.

Args:
col: the column in projection coordinates.
row: the row in projection coordinates.
projection: the projection used.
score: confidence score from object detector.
crop_window_dir: the path to the window used for classifying the crop.
"""
self.col = col
self.row = row
self.projection = projection
self.score = score
self.crop_window_dir = crop_window_dir


class FormattedPrediction(TypedDict):
"""Formatted prediction for a single vessel detection."""

Expand Down Expand Up @@ -151,14 +126,18 @@ def get_vessel_detections(
col = int(shp.centroid.x)
row = int(shp.centroid.y)
score = feature["properties"]["score"]
detections.append(
VesselDetection(
col=col,
row=row,
projection=projection,
score=score,
)

detection = VesselDetection(
source=LANDSAT_SOURCE,
col=col,
row=row,
projection=projection,
score=score,
)
if item:
detection.scene_id = item.name
detection.ts = item.geometry.time_range[0]
detections.append(detection)

return detections

Expand Down Expand Up @@ -192,7 +171,7 @@ def run_classifier(
for detection in detections:
favyen2 marked this conversation as resolved.
Show resolved Hide resolved
window_name = f"{detection.col}_{detection.row}"
window_path = ds_path / "windows" / group / window_name
detection.crop_window_dir = window_path
detection.metadata["crop_window_dir"] = window_path
bounds = [
detection.col - CLASSIFY_WINDOW_SIZE // 2,
detection.row - CLASSIFY_WINDOW_SIZE // 2,
Expand Down Expand Up @@ -275,6 +254,7 @@ def predict_pipeline(
json_path: str | None = None,
scratch_path: str | None = None,
crop_path: str | None = None,
geojson_path: str | None = None,
) -> list[FormattedPrediction]:
"""Run the Landsat vessel prediction pipeline.

Expand All @@ -292,6 +272,7 @@ def predict_pipeline(
json_path: path to write vessel detections as JSON file.
scratch_path: directory to use to store temporary dataset.
crop_path: path to write the vessel crop images.
geojson_path: path to write vessel detections as GeoJSON file.
"""
if scratch_path is None:
tmp_scratch_dir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -430,31 +411,23 @@ def predict_pipeline(
crop_upath.mkdir(parents=True, exist_ok=True)

json_data = []
geojson_features = []
near_infra_filter = NearInfraFilter(infra_distance_threshold=INFRA_THRESHOLD_KM)
infra_detections = 0
for idx, detection in enumerate(detections):
# Get longitude/latitude.
src_geom = STGeometry(
detection.projection, shapely.Point(detection.col, detection.row), None
)
dst_geom = src_geom.to_projection(WGS84_PROJECTION)
lon = dst_geom.shp.x
lat = dst_geom.shp.y
# Apply near infra filter (True -> filter out, False -> keep)
if near_infra_filter.should_filter(lat, lon):
lon, lat = detection.get_lon_lat()
if near_infra_filter.should_filter(lon, lat):
infra_detections += 1
continue
# Load crops from the window directory.
images = {}
if detection.crop_window_dir is None:
crop_window_dir = detection.metadata["crop_window_dir"]
if crop_window_dir is None:
raise ValueError("Crop window directory is None")
for band in ["B2", "B3", "B4", "B8"]:
favyen2 marked this conversation as resolved.
Show resolved Hide resolved
image_fname = (
detection.crop_window_dir
/ "layers"
/ LANDSAT_LAYER_NAME
/ band
/ "geotiff.tif"
crop_window_dir / "layers" / LANDSAT_LAYER_NAME / band / "geotiff.tif"
favyen2 marked this conversation as resolved.
Show resolved Hide resolved
)
with image_fname.open("rb") as f:
with rasterio.open(f) as src:
Expand Down Expand Up @@ -499,10 +472,23 @@ def predict_pipeline(
b8_fname=str(b8_fname),
),
)
geojson_features.append(detection.to_feature())

if json_path:
json_upath = UPath(json_path)
with json_upath.open("w") as f:
json.dump(json_data, f)

if geojson_path:
geojson_upath = UPath(geojson_path)
with geojson_upath.open("w") as f:
json.dump(
{
"type": "FeatureCollection",
"properties": {},
"features": geojson_features,
},
f,
)

return json_data
Loading
Loading