diff --git a/requirements.txt b/requirements.txt index 4ab931d..8169b9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ beaker-py>=1.32 fastapi>=0.115 +google-cloud-bigtable>=2.18 interrogate>=1.7 pydantic>=2.8 pytest>=8.2 diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py index fc043dc..f043905 100644 --- a/rslp/satlas/__init__.py +++ b/rslp/satlas/__init__.py @@ -8,8 +8,9 @@ """ from .job_launcher_worker import launch_workers, write_jobs, write_jobs_for_year_months -from .postprocess import postprocess_points +from .postprocess import merge_points, smooth_points from .predict_pipeline import predict_multi, predict_pipeline +from .publish import publish_points workflows = { "predict": predict_pipeline, @@ -17,5 +18,7 @@ "write_jobs": write_jobs, "write_jobs_for_year_months": write_jobs_for_year_months, "launch_workers": launch_workers, - "postprocess_points": postprocess_points, + "merge_points": merge_points, + "smooth_points": smooth_points, + "publish_points": publish_points, } diff --git a/rslp/satlas/bkt.py b/rslp/satlas/bkt.py new file mode 100644 index 0000000..e278445 --- /dev/null +++ b/rslp/satlas/bkt.py @@ -0,0 +1,406 @@ +"""Manage bucket files on GCS. + +We bucket together small (10-200 KB) files at high zoom levels (e.g. zoom 13) into a +single file at a lower zoom level (e.g. zoom 9) to save on GCS insert fee. + +This is similar to https://github.com/mactrem/com-tiles. + +The .bkt is just a concatenation of the small files. + +We record the byte offsets in a Google Cloud Bigtable database. +""" + +import functools +import io +import multiprocessing.pool +import os +import struct +import time +from collections.abc import Generator +from typing import Any + +import google.cloud.bigtable.row +import google.cloud.bigtable.row_filters +import google.cloud.bigtable.table +import numpy.typing as npt +import skimage.io +from google.cloud import bigtable, storage +from rslearn.utils.mp import star_imap_unordered + +from rslp.log_utils import get_logger + +logger = get_logger(__name__) + + +class BktInserter: + """A helper class that inserts metadata about bkt files into the database. + + The BktInserter is a separate class from BktWriter so that it can be pickled to + support use with multiprocessing. + """ + + def __init__( + self, + indexes: list[tuple[int, int, int, int]], + bkt_fname: str, + bkt_zoom: int, + zoom: int, + ): + """Create a new BktInserter. + + Args: + indexes: the byte offsets of the files within the bkt. It is a list of + (col, row, offset, length) tuples. + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt. + zoom: the zoom level of the tiles within the bkt. + """ + self.indexes = indexes + self.bkt_fname = bkt_fname + self.bkt_zoom = bkt_zoom + self.zoom = zoom + + def run(self, bkt_files_table: google.cloud.bigtable.table.Table) -> None: + """Insert the metadata into BigTable. + + Args: + bkt_files_table: the BigTable object + """ + # Row key in the table is just the bkt fname. + # Value is [4 byte bkt_zoom][4 byte zoom][indexes]. + # [indexes] is list of indexes encoded as [4 byte col][4 byte row][4 byte offset][4 byte length]. 
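+ # All integers are big-endian, so the value is 8 + 16 * len(self.indexes) bytes;
+ # the reader in download_from_bkt skips the 8-byte header and unpacks the
+ # remaining fixed-size 16-byte records.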
+ buf = io.BytesIO() + buf.write(struct.pack(">II", self.bkt_zoom, self.zoom)) + for col, row, offset, length in self.indexes: + buf.write(struct.pack(">IIII", col, row, offset, length)) + db_row = bkt_files_table.direct_row(self.bkt_fname) + db_row.set_cell(b"d", b"d", buf.getvalue()) + db_row.commit() + + +class BktWriter: + """Writer for bkt files.""" + + def __init__(self) -> None: + """Create a new BktWriter.""" + self.indexes: list[tuple[int, int, int, int]] = [] + self.buf = io.BytesIO() + self.offset = 0 + + def add(self, col: int, row: int, bytes: bytes) -> None: + """Add a file to the bkt. + + Args: + col: the tile column. + row: the tile row. + bytes: the data at this tile. + """ + offset = self.offset + length = len(bytes) + self.indexes.append((col, row, offset, length)) + self.buf.write(bytes) + self.offset += length + + def get_bytes(self) -> bytes: + """Returns the bytes of the whole bkt file.""" + return self.buf.getvalue() + + def get_inserter(self, bkt_fname: str, bkt_zoom: int, zoom: int) -> "BktInserter": + """Creates a BktInserter that manages inserting the byte offsets to BigTable. + + Args: + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt file. + zoom: the zoom of the tiles within the bkt file. + + Returns: + a corresponding BktInserter + """ + return BktInserter(self.indexes, bkt_fname, bkt_zoom, zoom) + + def insert( + self, + bkt_files_table: google.cloud.bigtable.table.Table, + bkt_fname: str, + bkt_zoom: int, + zoom: int, + ) -> None: + """Insert the byte offsets for this bkt to BigTable. + + Args: + bkt_files_table: the BigTable table object. + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt file. + zoom: the zoom of the tiles within the bkt file. + """ + self.get_inserter(bkt_fname, bkt_zoom, zoom).run(bkt_files_table) + + +@functools.cache +def get_bucket() -> storage.Bucket: + """Get the GCS bucket where bkt files should be stored.""" + storage_client = storage.Client(project=os.environ["BKT_PROJECT_ID"]) + bucket = storage_client.bucket(os.environ["BKT_BUCKET_NAME"]) + return bucket + + +def download_bkt( + bkt_fname: str, + idx_map: dict[tuple[int, int], tuple[int, int]], + wanted: list[tuple[int, int, Any]], + mode: str, +) -> list[tuple[Any, npt.NDArray | bytes]]: + """Download tiles in a bkt file. + + Args: + bkt_fname: the bkt filename in the bucket to download from. + idx_map: map from tile (col, row) to (offset, length). + wanted: list of tiles to download. It should be a list of (col, row, metadata) + where metadata can be arbitrary data used by the caller that will be + returned with the tile data (which will be emitted in arbitrary order). + Note that if a tile does not exist within the bkt, it will not be returned + at all. + mode: either "image" to decode image and return numpy array, or "raw" to return + the byte string directly. + + Returns: + a list of (metadata, contents) where contents is a numpy array if mode is + "image" or a byte string if mode is "raw". + """ + bucket = get_bucket() + output = [] + + # Helper to postprocess an output based on the specified return mode. 
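+ # In "image" mode the tile bytes are decoded with skimage into a numpy array;
+ # in "raw" mode the bytes are passed through unchanged.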
+ def add_output(metadata: Any, contents: npt.NDArray | bytes) -> None: + if mode == "image": + buf = io.BytesIO(contents) + image = skimage.io.imread(buf) + output.append((metadata, image)) + + elif mode == "raw": + output.append((metadata, contents)) + + else: + raise ValueError(f"invalid mode {mode}") + + wanted = [ + (col, row, metadata) for col, row, metadata in wanted if (col, row) in idx_map + ] + + if len(wanted) == 1: + col, row, metadata = wanted[0] + offset, length = idx_map[(col, row)] + blob = bucket.blob(bkt_fname) + contents = blob.download_as_bytes(start=offset, end=offset + length) + add_output(metadata, contents) + + elif len(wanted) > 1: + blob = bucket.blob(bkt_fname) + bkt_bytes = blob.download_as_bytes() + for col, row, metadata in wanted: + offset, length = idx_map[(col, row)] + contents = bkt_bytes[offset : offset + length] + add_output(metadata, contents) + + return output + + +# Parallel download from various bkt files. +# Jobs is a list of (bkt_fname, col, row, metadata). +# download_from_bkt is a generator that will yield (metadata, bytes) for each provided job. +def download_from_bkt( + bkt_files_table: google.cloud.bigtable.table.Table, + pool: multiprocessing.pool.Pool | None, + jobs: list[tuple[str, int, int, Any]], + mode: str = "raw", +) -> Generator[tuple[Any, npt.NDArray | bytes], None, None]: + """Download tile contents in parallel from several bkt files. + + Args: + bkt_files_table: the BigTable table containing byte offsets. + pool: the multiprocessing pool to use for parallelism, or None to read in + current process. + jobs: list of (bkt_fname, col, row, metadata) to work through. Jobs referencing + the same bkt_fname will be grouped together so we don't read the same bkt + file multiple times. + mode: the return mode (see download_bkt). + + Yields: + the (metadata, contents) tuples across all of the jobs. + """ + # Get indexes associated with each distinct bkt_fname. + by_bkt_fname: dict[str, list[tuple[int, int, Any]]] = {} + for bkt_fname, col, row, metadata in jobs: + if bkt_fname not in by_bkt_fname: + by_bkt_fname[bkt_fname] = [] + by_bkt_fname[bkt_fname].append((col, row, metadata)) + + bkt_jobs: list[dict[str, Any]] = [] + for bkt_fname, wanted in by_bkt_fname.items(): + # Use retry loop since we seem to get error reading from BigTable occasionally. + def bkt_retry_loop() -> google.cloud.bigtable.row.PartialRowData: + for _ in range(8): + try: + db_row = bkt_files_table.read_row( + bkt_fname, + filter_=google.cloud.bigtable.row_filters.CellsColumnLimitFilter( + 1 + ), + ) + return db_row + except Exception as e: + print( + f"got error reading bkt_files_table for {bkt_fname} (trying again): {e}" + ) + time.sleep(1) + raise Exception( + f"repeatedly failed to read bkt_files_table for {bkt_fname}" + ) + + db_row = bkt_retry_loop() + + # Ignore requested files that don't exist. + if not db_row: + continue + # Skip 8-byte header with bkt_zoom/zoom. 
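+ # The remainder is the index written by BktInserter.run: a sequence of
+ # big-endian 16-byte records, one (col, row, offset, length) tuple per tile.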
+ encoded_indexes = db_row.cell_value("d", b"d")[8:] + + indexes = {} + for i in range(0, len(encoded_indexes), 16): + col, row, offset, length = struct.unpack( + ">IIII", encoded_indexes[i : i + 16] + ) + indexes[(col, row)] = (offset, length) + bkt_jobs.append( + dict( + bkt_fname=bkt_fname, + idx_map=indexes, + wanted=wanted, + mode=mode, + ) + ) + + if pool is None: + for job in bkt_jobs: + for metadata, image in download_bkt(**job): + yield (metadata, image) + else: + outputs = star_imap_unordered(pool, download_bkt, bkt_jobs) + for output in outputs: + for metadata, image in output: + yield (metadata, image) + + +def upload_bkt(bkt_fname: str, contents: bytes) -> None: + """Upload a bkt file to GCS bucket. + + Args: + bkt_fname: the bkt filename within the bucket. + contents: the data to upload. + """ + bucket = get_bucket() + blob = bucket.blob(bkt_fname) + blob.upload_from_string(contents) + + +# Tuples is list of (bkt_writer, bkt_fname, bkt_zoom, zoom). +def upload_bkts( + bkt_files_table: google.cloud.bigtable.table.Table, + pool: multiprocessing.pool.Pool, + jobs: list[tuple[BktWriter, str, int, int]], +) -> None: + """Upload several bkt files to GCS in parallel. + + Args: + bkt_files_table: the BigTable table to store byte offsets. + pool: a multiprocessing pool for parallelism. + jobs: list of (bkt_writer, bkt_fname, bkt_zoom, zoom) tuples. bkt_writer is the + BktWriter where the bkt contents and metadata are stored. bkt_fname is the + path where the bkt should be written. bkt_zoom in the zoom level of the bkt + file. zoom is the zoom level of tiles within the bkt. + """ + # Upload. We upload first since reader will assume that anything existing in + # BigTable already exists on GCS. + upload_jobs: list[tuple[str, bytes]] = [] + for bkt_writer, bkt_fname, bkt_zoom, zoom in jobs: + upload_jobs.append((bkt_fname, bkt_writer.get_bytes())) + outputs = star_imap_unordered(pool, upload_bkt, upload_jobs) + for _ in outputs: + pass + # Now we insert the metadata. + for bkt_writer, bkt_fname, bkt_zoom, zoom in jobs: + bkt_writer.insert( + bkt_files_table=bkt_files_table, + bkt_fname=bkt_fname, + bkt_zoom=bkt_zoom, + zoom=zoom, + ) + + +def make_bkt(src_dir: str, dst_path: str) -> None: + """Make a bkt file from the specified local source directory. + + The source directory must contain files of the form zoom/col/row.ext (the extension + is ignored). + + A single bkt file is created, so the zoom level of the bkt is always 0. + + Args: + src_dir: the local directory to turn into a single bkt file. + dst_path: the bkt filename in the bkt GCS bucket to write to. It must have a + {zoom} placeholder where the zoom goes. + """ + bucket = get_bucket() + bigtable_client = bigtable.Client(project=os.environ["BKT_BIGTABLE_PROJECT_ID"]) + bigtable_instance = bigtable_client.instance(os.environ["BKT_BIGTABLE_INSTANCE_ID"]) + bkt_files_table = bigtable_instance.table("bkt_files") + + for zoom_str in os.listdir(src_dir): + zoom_dir = os.path.join(src_dir, zoom_str) + if not os.path.isdir(zoom_dir): + continue + zoom = int(zoom_str) + logger.debug( + "make_bkt(%s, %s): start collecting files at zoom level %d", + src_dir, + dst_path, + zoom, + ) + + # Read all files at this zoom level from local path into bkt (in memory). 
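+ # The layout under src_dir is zoom/col/row.ext, so each file becomes one
+ # (col, row) entry in the bkt for this zoom level.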
+ bkt_writer = BktWriter() + num_files = 0 + for col_str in os.listdir(zoom_dir): + col_dir = os.path.join(zoom_dir, col_str) + col = int(col_str) + for fname in os.listdir(col_dir): + row = int(fname.split(".")[0]) + num_files += 1 + with open(os.path.join(col_dir, fname), "rb") as f: + contents = f.read() + bkt_writer.add(col, row, contents) + logger.debug( + "make_bkt(%s, %s): processed %d files at zoom %d", + src_dir, + dst_path, + num_files, + zoom, + ) + + # Now upload to GCS. + bkt_fname = dst_path.format(zoom=zoom) + logger.debug( + "make_bkt(%s, %s) uploading bkt for zoom level %d to %s", + src_dir, + dst_path, + zoom, + bkt_fname, + ) + blob = bucket.blob(bkt_fname) + blob.upload_from_string(bkt_writer.get_bytes()) + bkt_writer.insert( + bkt_files_table=bkt_files_table, + bkt_fname=bkt_fname, + bkt_zoom=0, + zoom=zoom, + ) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 62fd573..2d67368 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -27,6 +27,13 @@ # exact. NMS_DISTANCE_THRESHOLD = 100 / MAX_METERS_PER_DEGREE +APP_CATEGORY_MAPS = { + Application.MARINE_INFRA: { + "platform": "offshore_platform", + "turbine": "offshore_wind_turbine", + } +} + logger = get_logger(__name__) @@ -90,19 +97,16 @@ def apply_nms( return good_features -def postprocess_points( +def merge_points( application: Application, label: str, predict_path: str, merged_path: str, - smoothed_path: str, workers: int = 32, ) -> None: - """Post-process Satlas point outputs. + """Merge Satlas point outputs. - This merges the outputs across different prediction tasks for this timestamp and - spatial tile. Then it applies Viterbi smoothing that takes into account merged - outputs from previous time ranges, and uploads the results. + This merges the outputs across different prediction tasks for this timestamp. Args: application: the application. @@ -111,11 +115,8 @@ def postprocess_points( the different tasks have been written. merged_path: folder to write merged predictions. The filename will be YYYY-MM.geojson. - smoothed_path: folder to write smoothed predictions. The filename will be - YYYY-MM.geojson. workers: number of worker processes. """ - # Merge the predictions. predict_upath = UPath(predict_path) merged_features = [] merged_patches: dict[str, list[tuple[int, int]]] = {} @@ -124,6 +125,9 @@ def postprocess_points( p = multiprocessing.Pool(workers) outputs = p.imap_unordered(_get_fc, fnames) + # Get category remapping in case one is specified for this application. + category_map = APP_CATEGORY_MAPS.get(application, {}) + for cur_fc in tqdm.tqdm(outputs, total=len(fnames)): # The projection information may be missing if there are no valid patches. 
if "crs" not in cur_fc["properties"]: @@ -151,6 +155,10 @@ def postprocess_points( dst_geom = src_geom.to_projection(WGS84_PROJECTION) feat["geometry"]["coordinates"] = [dst_geom.shp.x, dst_geom.shp.y] + category = feat["properties"]["category"] + if category in category_map: + feat["properties"]["category"] = category_map[category] + merged_features.append(feat) # Merge the valid patches too, these indicate which portions of the world @@ -162,18 +170,13 @@ def postprocess_points( p.close() - nms_features = apply_nms(merged_features, distance_threshold=NMS_DISTANCE_THRESHOLD) - logger.info( - "NMS filtered from %d -> %d features", len(merged_features), len(nms_features) - ) - merged_upath = UPath(merged_path) merged_fname = merged_upath / f"{label}.geojson" with merged_fname.open("w") as f: json.dump( { "type": "FeatureCollection", - "features": nms_features, + "features": merged_features, "properties": { "valid_patches": merged_patches, }, @@ -181,6 +184,27 @@ def postprocess_points( f, ) + +def smooth_points( + application: Application, + label: str, + merged_path: str, + smoothed_path: str, +) -> None: + """Smooth the Satlas point outputs. + + It applies Viterbi smoothing that takes into account merged outputs from previous + time ranges, and uploads the results. + + Args: + application: the application. + label: YYYY-MM representation of the time range used for this prediction run. + merged_path: folder to write merged predictions. The filename will be + YYYY-MM.geojson. + smoothed_path: folder to write smoothed predictions. The filename will be + YYYY-MM.geojson. + """ + merged_upath = UPath(merged_path) # Download the merged prediction history (ending with the one we just wrote) and # run smoothing. smoothed_upath = UPath(smoothed_path) diff --git a/rslp/satlas/publish.py b/rslp/satlas/publish.py new file mode 100644 index 0000000..ace959f --- /dev/null +++ b/rslp/satlas/publish.py @@ -0,0 +1,234 @@ +"""Publish Satlas outputs.""" + +import json +import os +import shutil +import subprocess # nosec +import tempfile +import zipfile +from typing import Any + +import boto3 +import boto3.s3 +from upath import UPath + +from rslp.log_utils import get_logger +from rslp.satlas.bkt import make_bkt + +from .predict_pipeline import Application + +logger = get_logger(__name__) + +# Number of timesteps to re-publish. +# Smoothing for points changes all of the outputs but we only upload outputs for this +# many of the most recent timesteps. +NUM_RECOMPUTE = 6 + +# Name on Cloudflare R2 for each application. +APP_NAME_ON_R2 = { + Application.MARINE_INFRA: "marine", +} + +APP_TIPPECANOE_LAYERS = { + Application.MARINE_INFRA: "marine", +} + +SHP_EXTENSIONS = [ + ".shp", + ".dbf", + ".prj", + ".shx", +] + +BKT_TILE_PATH = "output_mosaic/" + + +def get_cloudflare_r2_bucket() -> Any: + """Returns the Cloudflare R2 bucket where outputs are published.""" + s3 = boto3.resource( + "s3", + endpoint_url=os.environ["SATLAS_R2_ENDPOINT"], + aws_access_key_id=os.environ["SATLAS_R2_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["SATLAS_R2_SECRET_ACCESS_KEY"], + ) + bucket = s3.Bucket(os.environ["SATLAS_R2_BUCKET_NAME"]) + return bucket + + +def make_shapefile_zip(fname: str) -> str: + """Create zip file of the shapefile and its supporting files. + + If filename is "x" (for x.shp and supporting files) then output is "x.shp.zip". + + Args: + fname: fname without .shp extension + + Returns: + the local filename of the resulting zip file. 
+ """ + zip_fname = fname + ".shp.zip" + basename = os.path.basename(fname) + with zipfile.ZipFile(zip_fname, "w") as z: + for ext in SHP_EXTENSIONS: + z.write(fname + ext, arcname=basename + ext) + return zip_fname + + +def update_index(bucket: Any, prefix: str) -> None: + """Update index file on Cloudflare R2. + + The index file just has list of filenames, last modified time, and md5. + + There is one index for each application folder. + + Args: + bucket: the Cloudflare R2 bucket. + prefix: the folder's prefix in the bucket. + """ + index_lines = [] + for obj in bucket.objects.filter(Prefix=prefix): + if obj.key.endswith("/index.txt"): + continue + line = "{},{},{}".format( + obj.key, obj.last_modified, obj.e_tag.split("-")[0].replace('"', "") + ) + index_lines.append(line) + index_lines.append("") + index_data = "\n".join(index_lines) + bucket.put_object( + Body=index_data.encode(), + Key=prefix + "index.txt", + ) + + +def publish_points( + application: Application, + smoothed_path: str, + version: str, + workers: int = 32, +) -> None: + """Publish Satlas point outputs. + + The points are added to two locations: GeoJSONs are added to Cloudflare R2, while + tippecanoe is used to generate vector tiles that are uploaded to GCS for use by the + satlas.allen.ai website. + + Args: + application: the application. + smoothed_path: folder containing smoothed predictions (including + history.geojson file). + version: current model version for use to distinguish different outputs on GCS. + workers: number of worker processes. + """ + smoothed_upath = UPath(smoothed_path) + + # First upload files to R2. + bucket = get_cloudflare_r2_bucket() + with tempfile.TemporaryDirectory() as tmp_dir: + # Upload history. + logger.info("upload history") + local_hist_fname = os.path.join(tmp_dir, "history.geojson") + with (smoothed_upath / "history.geojson").open("rb") as src: + with open(local_hist_fname, "wb") as dst: + shutil.copyfileobj(src, dst) + app_name_on_r2 = APP_NAME_ON_R2[application] + bucket.upload_file(local_hist_fname, f"outputs/{app_name_on_r2}/marine.geojson") + + # Upload the latest outputs too. 
+ available_fnames: list[UPath] = [] + for fname in smoothed_upath.iterdir(): + if fname.name == "history.geojson": + continue + available_fnames.append(fname) + available_fnames.sort(key=lambda fname: fname.name) + for fname in available_fnames[-NUM_RECOMPUTE:]: + logger.info("upload %s", str(fname)) + local_geojson_fname = os.path.join(tmp_dir, "data.geojson") + # local_shp_prefix = os.path.join(tmp_dir, "shp_data") + # local_kml_fname = os.path.join(tmp_dir, "data.kml") + + with fname.open("rb") as src: + with open(local_geojson_fname, "wb") as dst: + shutil.copyfileobj(src, dst) + + """ + subprocess.check_call([ + 'ogr2ogr', + '-F', 'ESRI Shapefile', + '-nlt', 'POINT', + local_shp_prefix + ".shp", + local_geojson_fname, + ]) + make_shapefile_zip(local_shp_prefix) + subprocess.check_call([ + 'ogr2ogr', + '-F', 'KML', + local_kml_fname, + local_geojson_fname, + ]) + """ + + fname_prefix = fname.name.split(".")[0] + + bucket.upload_file( + local_geojson_fname, + f"outputs/{app_name_on_r2}/{fname_prefix}.geojson", + ) + """ + bucket.upload_file( + local_shp_prefix + ".shp.zip", + f"outputs/{app_name_on_r2}/{fname_prefix}.shp.zip", + ) + bucket.upload_file( + local_kml_fname, + f"outputs/{app_name_on_r2}/{fname_prefix}.kml", + ) + """ + if fname == available_fnames[-1]: + bucket.upload_file( + local_geojson_fname, + f"outputs/{app_name_on_r2}/latest.geojson", + ) + """ + bucket.upload_file( + local_shp_prefix + ".shp.zip", + f"outputs/{app_name_on_r2}/latest.shp.zip", + ) + bucket.upload_file( + local_kml_fname, + f"outputs/{app_name_on_r2}/latest.kml", + ) + """ + + update_index(bucket, f"outputs/{app_name_on_r2}/") + + # Generate the tippecanoe tiles. + # We set tippecanoe layer via property of each feature. + with tempfile.TemporaryDirectory() as tmp_dir: + tippecanoe_layer = APP_TIPPECANOE_LAYERS[application] + with (smoothed_upath / "history.geojson").open("rb") as f: + fc = json.load(f) + for feat in fc["features"]: + feat["tippecanoe"] = {"layer": tippecanoe_layer} + local_geojson_fname = os.path.join(tmp_dir, "history.geojson") + with open(local_geojson_fname, "w") as f: + json.dump(fc, f) + + local_tile_dir = os.path.join(tmp_dir, "tiles") + logger.info("run tippecanoe on history in local tmp dir %s", local_tile_dir) + subprocess.check_call( + [ + "tippecanoe", + "-z13", + "-r1", + "--cluster-densest-as-needed", + "--no-tile-compression", + "-e", + local_tile_dir, + local_geojson_fname, + ] + ) # nosec + + tile_dst_path = f"{BKT_TILE_PATH}{version}/history/{{zoom}}/0/0.bkt" + logger.info("make bkt at %s", tile_dst_path) + make_bkt(src_dir=local_tile_dir, dst_path=tile_dst_path) diff --git a/rslp/satlas/scripts/smooth_point_labels_viterbi.go b/rslp/satlas/scripts/smooth_point_labels_viterbi.go index 0c4dbb1..94cdd97 100644 --- a/rslp/satlas/scripts/smooth_point_labels_viterbi.go +++ b/rslp/satlas/scripts/smooth_point_labels_viterbi.go @@ -16,6 +16,11 @@ import ( const FUTURE_LABEL = "2030-01" const TILE_SIZE = 2048 +// Don't consider groups with fewer than this many valid timesteps. +// Note that the point doesn't need to be detected in all the timesteps, this is just +// timesteps where we have image coverage. 
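+// Groups with fewer valid timesteps than this are skipped entirely (see the
+// validLabelSet check before the Viterbi pass below).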
+const MIN_VALID_TIMESTEPS = 8 + type Tile struct { Projection string Column int @@ -23,6 +28,7 @@ type Tile struct { } type Point struct { + Type string `json:"type"` Geometry struct { Type string `json:"type"` Coordinates [2]float64 `json:"coordinates"` @@ -96,6 +102,7 @@ func main() { outFname := flag.String("out", "", "Output filename with LABEL placeholder like out/LABEL.geojson") histFname := flag.String("hist", "", "Merged history output filename") distanceThreshold := flag.Float64("max_dist", 200, "Matching distance threshold in meters") + nmsDistance := flag.Float64("nms_dist", 200.0/111111, "NMS distance in degrees") numThreads := flag.Int("threads", 32, "Number of threads") flag.Parse() @@ -149,20 +156,33 @@ func main() { for groupIdx, group := range groups { projection := *group[0].Properties.Projection center := group.Center() - indices := gridIndexes[projection].Search(common.Rectangle{ - Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, - Max: common.Point{float64(center[0]) + GridSize, float64(center[1]) + GridSize}, - }) + + // Lookup candidate new points that could match this group using the grid index. + var indices []int + if gridIndexes[projection] != nil { + indices = gridIndexes[projection].Search(common.Rectangle{ + Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, + Max: common.Point{float64(center[0]) + GridSize, float64(center[1]) + GridSize}, + }) + } + var closestIdx int = -1 var closestDistance float64 for _, idx := range indices { if matchedIndices[idx] { continue } - if *group[0].Properties.Category != *curPoints[idx].Properties.Category { - continue - } + // Double check distance threshold since the index may still return + // points that are slightly outside the threshold. + // We used to check category too, but now we use the category of the + // last prediction, and just apply a distance penalty for mismatched + // category, since we noticed that sometimes there are partially + // constructed wind turbines detected as platforms but then later + // detected as turbines once construction is done, and we don't want + // that to mess up the Viterbi smoothing. Put another way, marine + // infrastructure should show up in our map even if we're not exactly + // sure about the category. dx := center[0] - *curPoints[idx].Properties.Column dy := center[1] - *curPoints[idx].Properties.Row distance := math.Sqrt(float64(dx*dx + dy*dy)) @@ -170,6 +190,11 @@ func main() { if distance > *distanceThreshold/MetersPerPixel { continue } + + if *group[0].Properties.Category != *curPoints[idx].Properties.Category { + distance += *distanceThreshold / MetersPerPixel + } + if closestIdx == -1 || distance < closestDistance { closestIdx = idx closestDistance = distance @@ -205,6 +230,51 @@ func main() { } } + // Apply non-maximal suppression over groups. + // We prefer longer groups, or if they are the same length, the group with higher + // last score. 
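+ // A group is dropped only if a strictly better group (longer, or equally long
+ // with a higher last score) lies within nms_dist of its last point; exact ties
+ // keep both groups.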
+ log.Println("applying non-maximal suppression") + nmsIndex := common.NewGridIndex(*nmsDistance * 5) + for groupIdx, group := range groups { + last := group[len(group)-1] + coordinates := last.Geometry.Coordinates + nmsIndex.Insert(groupIdx, common.Point{coordinates[0], coordinates[1]}.Rectangle()) + } + var newGroups []Group + for groupIdx, group := range groups { + last := group[len(group)-1] + coordinates := last.Geometry.Coordinates + results := nmsIndex.Search(common.Point{coordinates[0], coordinates[1]}.RectangleTol(*nmsDistance)) + needsRemoval := false + for _, otherIdx := range results { + if otherIdx == groupIdx { + continue + } + other := groups[otherIdx] + otherLast := other[len(other)-1] + otherCoordinates := otherLast.Geometry.Coordinates + dx := coordinates[0] - otherCoordinates[0] + dy := coordinates[1] - otherCoordinates[1] + distance := math.Sqrt(float64(dx*dx + dy*dy)) + if distance >= *nmsDistance { + continue + } + + // It is within distance threshold, so see if group is worse than other. + if len(group) < len(other) { + needsRemoval = true + } else if len(group) == len(other) && *last.Properties.Score < *otherLast.Properties.Score { + needsRemoval = true + } + } + + if !needsRemoval { + newGroups = append(newGroups, group) + } + } + log.Printf("NMS filtered from %d to %d groups", len(groups), len(newGroups)) + groups = newGroups + // Apply Viterbi algorithm in each group. initialProbs := []float64{0.5, 0.5} transitionProbs := [][]float64{ @@ -318,6 +388,10 @@ func main() { validLabelSet[label] = true } + if len(validLabelSet) < MIN_VALID_TIMESTEPS { + continue + } + // Now make history of observations for Viterbi algorithm. // We only include timesteps where the tile was valid. // We also create a map from observed timesteps to original timestep index. @@ -363,6 +437,7 @@ func main() { for _, rng := range curRngs { last := rng.Group[len(rng.Group)-1] feat := Point{} + feat.Type = "Feature" feat.Geometry = last.Geometry feat.Properties.Category = last.Properties.Category feat.Properties.Score = last.Properties.Score