From 167c3d69359f9b3abb49a3c1c5aa6249f76c0992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Antol=C3=ADn?= Date: Thu, 31 Oct 2024 13:53:21 +0100 Subject: [PATCH] Bug/sc 425841/internal team raster loader killed when using (#153) Co-authored-by: cayetanobv --- raster_loader/cli/bigquery.py | 16 + raster_loader/cli/snowflake.py | 16 + raster_loader/io/bigquery.py | 39 +- raster_loader/io/common.py | 615 +++++++++++++++++----- raster_loader/io/snowflake.py | 45 +- raster_loader/tests/bigquery/test_cli.py | 54 ++ raster_loader/tests/bigquery/test_io.py | 14 +- raster_loader/tests/snowflake/test_cli.py | 70 +++ raster_loader/tests/snowflake/test_io.py | 10 + raster_loader/utils.py | 18 + 10 files changed, 754 insertions(+), 143 deletions(-) diff --git a/raster_loader/cli/bigquery.py b/raster_loader/cli/bigquery.py index caf90c7..e1c1279 100644 --- a/raster_loader/cli/bigquery.py +++ b/raster_loader/cli/bigquery.py @@ -77,6 +77,18 @@ def bigquery(args=None): default=False, is_flag=True, ) +@click.option( + "--exact_stats", + help="Compute exact statistics for the raster bands.", + default=False, + is_flag=True, +) +@click.option( + "--all_stats", + help="Compute all statistics including quantiles and most frequent values.", + required=False, + is_flag=True, +) @catch_exception() def upload( file_path, @@ -91,6 +103,8 @@ def upload( overwrite=False, append=False, cleanup_on_failure=False, + exact_stats=False, + all_stats=False, ): from raster_loader.io.common import ( get_number_of_blocks, @@ -161,6 +175,8 @@ def upload( overwrite=overwrite, append=append, cleanup_on_failure=cleanup_on_failure, + exact_stats=exact_stats, + all_stats=all_stats, ) click.echo("Raster file uploaded to Google BigQuery") diff --git a/raster_loader/cli/snowflake.py b/raster_loader/cli/snowflake.py index 1574f37..e5894c5 100644 --- a/raster_loader/cli/snowflake.py +++ b/raster_loader/cli/snowflake.py @@ -86,6 +86,18 @@ def snowflake(args=None): default=False, is_flag=True, ) +@click.option( + "--exact_stats", + help="Compute exact statistics for the raster bands.", + default=False, + is_flag=True, +) +@click.option( + "--all_stats", + help="Compute all statistics including quantiles and most frequent values.", + required=False, + is_flag=True, +) @catch_exception() def upload( account, @@ -104,6 +116,8 @@ def upload( overwrite=False, append=False, cleanup_on_failure=False, + exact_stats=False, + all_stats=False, ): from raster_loader.io.common import ( get_number_of_blocks, @@ -185,6 +199,8 @@ def upload( overwrite=overwrite, append=append, cleanup_on_failure=cleanup_on_failure, + exact_stats=exact_stats, + all_stats=all_stats, ) click.echo("Raster file uploaded to Snowflake") diff --git a/raster_loader/io/bigquery.py b/raster_loader/io/bigquery.py index e995f51..17848a8 100644 --- a/raster_loader/io/bigquery.py +++ b/raster_loader/io/bigquery.py @@ -4,14 +4,17 @@ import rasterio import re +from itertools import chain from raster_loader import __version__ from raster_loader.errors import import_error_bigquery, IncompatibleRasterException from raster_loader.utils import ask_yes_no_question, batched from raster_loader.io.common import ( + check_metadata_is_compatible, + get_number_of_blocks, + get_number_of_overviews_blocks, rasterio_metadata, + rasterio_overview_to_records, rasterio_windows_to_records, - get_number_of_blocks, - check_metadata_is_compatible, update_metadata, ) @@ -104,6 +107,8 @@ def upload_raster( overwrite: bool = False, append: bool = False, cleanup_on_failure: bool = False, + exact_stats: bool = False, + all_stats: bool = False, ): """Write a raster file to a BigQuery table.""" print("Loading raster file to BigQuery...") @@ -126,21 +131,31 @@ def upload_raster( exit() metadata = rasterio_metadata( - file_path, bands_info, self.band_rename_function + file_path, bands_info, self.band_rename_function, exact_stats, all_stats ) - records_gen = rasterio_windows_to_records( + overviews_records_gen = rasterio_overview_to_records( + file_path, + self.band_rename_function, + bands_info + ) + + windows_records_gen = rasterio_windows_to_records( file_path, self.band_rename_function, bands_info, ) + records_gen = chain(overviews_records_gen, windows_records_gen) if append_records: old_metadata = self.get_metadata(fqn) check_metadata_is_compatible(metadata, old_metadata) update_metadata(metadata, old_metadata) - total_blocks = get_number_of_blocks(file_path) + number_of_blocks = get_number_of_blocks(file_path) + number_of_overview_tiles = get_number_of_overviews_blocks(file_path) + total_blocks = number_of_blocks + number_of_overview_tiles + if chunk_size is None: job = self.upload_records(records_gen, fqn) # raise error if job went wrong (blocking call) @@ -150,7 +165,10 @@ def upload_raster( jobs = [] errors = [] - print(f"Writing {total_blocks} blocks to BigQuery...") + print( + f"Writing {number_of_blocks} blocks and {number_of_overview_tiles} " + "overview tiles to BigQuery..." + ) with tqdm(total=total_blocks) as pbar: if total_blocks < chunk_size: chunk_size = total_blocks @@ -167,9 +185,11 @@ def done_callback(job): # job already removed because failed pass + processed_blocks = 0 for records in batched(records_gen, chunk_size): job = self.upload_records(records, fqn) job.num_records = len(records) + processed_blocks += len(records) job.add_done_callback(partial(lambda job: done_callback(job))) jobs.append(job) @@ -185,7 +205,10 @@ def done_callback(job): if len(errors): raise Exception(errors) - pbar.update(1) + empty_blocks = total_blocks - processed_blocks + pbar.update(empty_blocks) + + print("Number of empty blocks: ", empty_blocks) print("Writing metadata to BigQuery...") self.write_metadata(metadata, append_records, fqn) @@ -220,6 +243,8 @@ def done_callback(job): if delete: self.delete_table(fqn) + import traceback + print(traceback.print_exc()) raise IOError("Error uploading to BigQuery: {}".format(e)) print("Done.") diff --git a/raster_loader/io/common.py b/raster_loader/io/common.py index 6875061..142abf0 100644 --- a/raster_loader/io/common.py +++ b/raster_loader/io/common.py @@ -1,15 +1,12 @@ -import sys import math +import numpy as np import pyproj import shapely -import numpy as np +import sys from raster_loader._version import __version__ from collections import Counter -from typing import Iterable -from typing import Callable -from typing import List -from typing import Tuple +from typing import Dict, Callable, Iterable, List, Tuple, Union from affine import Affine from shapely import wkt # Can not use directly from shapely.wkt @@ -21,6 +18,7 @@ from raster_loader.errors import ( error_not_google_compatible, ) +from raster_loader.utils import warnings DEFAULT_COG_BLOCK_SIZE = 256 @@ -38,8 +36,15 @@ "float64": np.nan, } +DEFAULT_MAX_MOST_COMMON = 10 +DEFAULT_SAMPLING_MAX_ITERATIONS = 10 +DEFAULT_SAMPLING_MAX_SAMPLES = 1000 +DEFAULT_OVERVIEWS = range(3, 20) + should_swap = {"=": sys.byteorder != "little", "<": False, ">": True, "|": False} +Samples = Dict[int, List[Union[int, float]]] + def band_field_name(custom_name: str, band: int, band_rename_function: Callable) -> str: return band_rename_function(custom_name or "band_" + str(band)) @@ -99,12 +104,17 @@ def array_to_record( geotransform: Affine, resolution: int, window: rasterio.windows.Window, + no_data_value: float = None ) -> dict: row_off = window.row_off col_off = window.col_off width = window.width height = window.height + # Skip blocks without any data to relieve loading burden + if no_data_value is not None and np.all(arr == no_data_value): + return None + x, y = transformer.transform( *(geotransform * (col_off + width * 0.5, row_off + height * 0.5)) ) @@ -152,10 +162,21 @@ def get_resolution_and_block_sizes( return block_width, block_height, resolution +def get_color_table(raster_dataset: rasterio.io.DatasetReader, band: int): + try: + if raster_dataset.colorinterp[band - 1].name == "palette": + return raster_dataset.colormap(band) + return None + except ValueError: + return None + + def rasterio_metadata( file_path: str, bands_info: List[Tuple[int, str]], band_rename_function: Callable, + exact_stats: bool = False, + all_stats: bool = False, ): """Open a raster file with rasterio.""" raster_info = rio_cogeo.cog_info(file_path).dict() @@ -190,8 +211,40 @@ def rasterio_metadata( if metadata["nodata"] is not None and math.isnan(metadata["nodata"]): metadata["nodata"] = None bands_metadata = [] + + # We only need to sample if we are not computing exact stats + # and we need to sample from the raster dataset just once! + if not exact_stats: + samples = sample_not_masked_values( + raster_dataset, DEFAULT_SAMPLING_MAX_SAMPLES + ) + for band, band_name in bands_info: - band_colorinterp = raster_dataset.colorinterp[band - 1].name + + if exact_stats: + print("Computing exact stats...") + warnings.warn( + "Exact statistics can be quite resources demanding. " + "User is encourage to compute approximate statistics instead.", + UserWarning, + ) + stats = raster_band_stats(raster_dataset, band, all_stats) + else: + print("Computing approximate stats...") + stats = raster_band_approx_stats( + raster_dataset, samples, band, all_stats + ) + + try: + # There is [an issue](https://github.com/OSGeo/gdal/issues/1928) + # in gdal with the same error message that we see in this line: + # "Failed to compute statistics, no valid pixels found in sampling." + # + # It seems to be an error with cropped rasters. + band_colorinterp = raster_dataset.colorinterp[band - 1].name + except Exception: + band_colorinterp = None + if band_colorinterp == "alpha": band_nodata = "0" else: @@ -203,7 +256,8 @@ def rasterio_metadata( "type": raster_band_type(raster_dataset, band), "name": band_field_name(band_name, band, band_rename_function), "colorinterp": band_colorinterp, - "stats": raster_band_stats(raster_dataset, band), + "colortable": get_color_table(raster_dataset, band), + "stats": stats, "nodata": band_nodata, } bands_metadata.append(meta) @@ -214,7 +268,6 @@ def rasterio_metadata( bounds_coords = list(bounds_polygon.bounds) center_coords = list(*bounds_polygon.centroid.coords) center_coords.append(resolution) - pixel_resolution = int(resolution + math.log(block_width * block_height, 4)) if pixel_resolution > 26: raise ValueError( @@ -230,6 +283,7 @@ def rasterio_metadata( "stats": e["stats"], "colorinterp": e["colorinterp"], "nodata": e["nodata"], + "colortable": e["colortable"], } for e in bands_metadata ] @@ -317,141 +371,330 @@ def band_with_nodata_mask( return (raw_data, raw_data == nodata_value) -def raster_band_stats(raster_dataset: rasterio.io.DatasetReader, band: int) -> dict: +def quantile_ranges() -> List[List[float]]: + """Return a list of ranges to compute quantiles.""" + return [[j / i for j in range(1, i)] for i in DEFAULT_OVERVIEWS] + + +def sample_not_masked_values( + raster_dataset: rasterio.io.DatasetReader, n_samples: int +) -> Samples: + """Compute quantiles for a raster dataset band.""" + def not_enough_samples(): + return ( + len(not_masked_samples[1]) < n_samples + and iterations < DEFAULT_SAMPLING_MAX_ITERATIONS + ) + + west = raster_dataset.bounds.left + south = raster_dataset.bounds.bottom + east = raster_dataset.bounds.right + north = raster_dataset.bounds.top + + bands = range(1, raster_dataset.count + 1) + + not_masked_samples = {b: [] for b in bands} + + iterations = 0 + + print('Sampling raster...') + rng = np.random.default_rng() + while not_enough_samples(): + x = rng.uniform(west, east, n_samples) + y = rng.uniform(south, north, n_samples) + + try: + from rasterio.sample import sort_xy + except ImportError: + coords = zip(x, y) + else: + coords = sort_xy(zip(x, y)) + + samples = raster_dataset.sample(coords, indexes=bands, masked=True) + for sample in samples: + raster_is_masked = ( + sample.mask if isinstance(sample.mask, np.bool_) else sample.mask.any() + ) + if not raster_is_masked: + for band in bands: + not_masked_samples[band].append(sample[band - 1]) + + iterations += 1 + + if len(not_masked_samples[1]) < n_samples: + warnings.warn( + "The data is very sparse and there are not enough non-masked samples.\n" + f"Only {len(not_masked_samples[1])} samples were collected and " + "quantiles and most common values may be inaccurate.", + UserWarning, + ) + + for b in bands: + not_masked_samples[b] = not_masked_samples[b][:n_samples] + + return not_masked_samples + + +def most_common_approx(samples: List[Union[int, float]]) -> Dict[int, int]: + """Compute the most common values in a list of int samples.""" + counts = np.bincount(samples) + nth = min(DEFAULT_MAX_MOST_COMMON, len(counts)) + idx = np.argpartition(counts, -nth)[-nth:] + return dict([(int(i), int(counts[i])) for i in idx if counts[i] > 0]) + + +def compute_quantiles( + data: List[Union[int, float]], cast_function: Callable +) -> dict: + """Compute quantiles for a raster dataset band.""" + print("Computing quantiles...") + quantiles = [ + [cast_function(np.quantile(data, q, method="lower")) for q in r] + for r in quantile_ranges() + ] + return dict(zip(DEFAULT_OVERVIEWS, quantiles)) + + +def get_stats( + raster_dataset: rasterio.io.DatasetReader, band: int +) -> rasterio.Statistics: """Get statistics for a raster band.""" + try: + # stats method is supported since rasterio 1.4.0 and statistics + # method will be deprecated in future versions of rasterio + return raster_dataset.stats(indexes=[band], approx=True)[0] + except AttributeError: + return raster_dataset.statistics(band, approx=True) + + +def raster_band_approx_stats( + raster_dataset: rasterio.io.DatasetReader, + samples: Samples, + band: int, + all_stats: bool, +) -> dict: + """Get approximate statistics for a raster band.""" + + stats = get_stats(raster_dataset, band) + + samples_band = samples[band] + + count = len(samples_band) + _sum = 0 + sum_squares = 0 + if count > 0: + _sum = int(np.sum(samples_band)) + sum_squares = int(np.sum(np.array(samples_band) ** 2)) + + quantiles = None + most_common = None + if all_stats: + + quantiles = compute_quantiles(samples_band, int) + + most_common = dict() + if not band_is_float(raster_dataset, band): + most_common = most_common_approx(samples_band) + + return { + "min": stats.min, + "max": stats.max, + "mean": stats.mean, + "stddev": stats.std, + "sum": _sum, + "sum_squares": sum_squares, + "count": count, + "quantiles": quantiles, + "top_values": most_common, + "version": ".".join(__version__.split(".")[:3]), + "approximated_stats": True, + } + + +def is_masked_band(raster_dataset: rasterio.io.DatasetReader, band: int) -> bool: + """Check if a band is masked.""" alpha_band = get_alpha_band(raster_dataset) original_nodata_value = band_original_nodata_value(raster_dataset, band) - if band == alpha_band or ( + return band == alpha_band or ( alpha_band is None and original_nodata_value is None and not band_is_float(raster_dataset, band) - ): - masked = False + ) + + +def read_raster_band(raster_dataset: rasterio.io.DatasetReader, band: int) -> np.array: + band_is_masked = is_masked_band(raster_dataset, band) + if band_is_masked: unmasked_data = raster_dataset.read(band) - stats = np.ma.masked_array(data=unmasked_data, mask=False) + return np.ma.masked_array(data=unmasked_data, mask=False) + + alpha_band = get_alpha_band(raster_dataset) + if alpha_band: + # mask data with alpha band to exclude from stats + return read_masked(raster_dataset, band, alpha_band) + + # mask nodata values to exclude from stats + compound_bands = get_compound_bands(raster_dataset, band) + if len(compound_bands) > 1: + # if band is part of a RGB triplet, + # we need to use the three bands for masking + bands_and_masks = [ + band_with_nodata_mask(raster_dataset, b) for b in compound_bands + ] + mask = np.logical_and.reduce( + [band_and_mask[1] for band_and_mask in bands_and_masks] + ) + raw_data = bands_and_masks[compound_bands.index(band)][0] else: - masked = True - if alpha_band: - # mask data with alpha band to exclude from stats - stats = read_masked(raster_dataset, band, alpha_band) - else: - # mask nodata values to exclude from stats - compound_bands = get_compound_bands(raster_dataset, band) - if len(compound_bands) > 1: - # if band is part of a RGB triplet, - # we need to use the three bands for masking - bands_and_masks = [ - band_with_nodata_mask(raster_dataset, b) for b in compound_bands - ] - mask = np.logical_and.reduce( - [band_and_mask[1] for band_and_mask in bands_and_masks] - ) - raw_data = bands_and_masks[compound_bands.index(band)][0] - else: - (raw_data, mask) = band_with_nodata_mask(raster_dataset, band) - stats = np.ma.masked_array(data=raw_data, mask=mask) - qdata = stats.compressed() - ranges = [[j / i for j in range(1, i)] for i in range(3, 20)] - casting_function = int if np.issubdtype(stats.dtype, np.integer) else float - quantiles = [ - [casting_function(np.quantile(qdata, q, method="lower")) for q in r] - for r in ranges - ] - quantiles = dict(zip(range(3, 20), quantiles)) - most_common = Counter(qdata).most_common(100) - most_common.sort(key=lambda x: x[1], reverse=True) - most_common = dict([(casting_function(x[0]), x[1]) for x in most_common]) + (raw_data, mask) = band_with_nodata_mask(raster_dataset, band) + + return np.ma.masked_array(data=raw_data, mask=mask) + + +def raster_band_stats( + raster_dataset: rasterio.io.DatasetReader, band: int, all_stats: bool +) -> dict: + """Get statistics for a raster band.""" + + print('Computing stats for band {0}...'.format(band)) + + _stats = get_stats(raster_dataset, band) + _min = _stats.min + _max = _stats.max + _mean = _stats.mean + _std = _stats.std + + count = math.prod(_stats.shape) + if is_masked_band(raster_dataset, band): + count = np.count_nonzero(_stats.mask is False) + + _sum = _mean * count + sum_squares = count * _std ** 2 + _mean ** 2 + + quantiles = None + most_common = None + if all_stats: + raster_band = read_raster_band(raster_dataset=raster_dataset, band=band) + + print("Removing masked data...") + qdata = raster_band.compressed() + + casting_function = ( + int if np.issubdtype(raster_band.dtype, np.integer) else float + ) + + quantiles = compute_quantiles(qdata, casting_function) + + print("Computing most commons values...") + warnings.warn( + "Most common values are meant for categorical data. " + "Computing them for float bands can be meaningless." + ) + most_common = Counter(qdata).most_common(100) + most_common.sort(key=lambda x: x[1], reverse=True) + most_common = dict([(casting_function(x[0]), x[1]) for x in most_common]) + version = ".".join(__version__.split(".")[:3]) + return { - "min": float(stats.min()), - "max": float(stats.max()), - "mean": float(stats.mean()), - "stddev": float(stats.std()), - "sum": float(stats.sum()), - "sum_squares": float((stats**2).sum()), + "min": float(_min), + "max": float(_max), + "mean": float(_mean), + "stddev": float(_std), + "sum": _sum, + "sum_squares": sum_squares, + "count": count, "quantiles": quantiles, "top_values": most_common, "version": version, - "count": ( - np.count_nonzero(stats.mask is False) if masked else math.prod(stats.shape) - ), # noqa: E712 + "approximated_stats": False, } -def rasterio_windows_to_records( - file_path: str, - band_rename_function: Callable, - bands_info: List[Tuple[int, str]], -) -> Iterable: - invalid_names = [ - name for _, name in bands_info if name and name.lower() in ["block", "metadata"] - ] - if invalid_names: - raise ValueError(f"Invalid band names: {', '.join(invalid_names)}") +def get_number_of_overviews_blocks(file_path: str) -> int: - """Open a raster file with rio-cogeo.""" raster_info = rio_cogeo.cog_info(file_path).dict() - - """Check if raster is compatible.""" - if "GoogleMapsCompatible" != raster_info.get("Tags", {}).get( - "Tiling Scheme", {} - ).get("NAME"): - error_not_google_compatible() - - """Open a raster file with rasterio.""" with rasterio.open(file_path) as raster_dataset: + overview_factors = raster_dataset.overviews(1) block_width, block_height, resolution = get_resolution_and_block_sizes( raster_dataset, raster_info ) raster_crs = raster_dataset.crs.to_string() - raster_to_4326_transformer = pyproj.Transformer.from_crs( raster_crs, "EPSG:4326", always_xy=True ) - # raster_crs must be 3857 pixels_to_raster_transform = raster_dataset.transform - # Base raster - for _, window in raster_dataset.block_windows(): - record = {} - no_data_value = get_nodata_value(raster_dataset) - for band, band_name in bands_info: - tile_data = read_filled( - raster_dataset, band, no_data_value, window=window, boundless=True - ) - newrecord = array_to_record( - tile_data, - band_field_name(band_name, band, band_rename_function), - band_rename_function, - raster_to_4326_transformer, - pixels_to_raster_transform, - resolution, - window, + # results are crs 4326, so x = long, y = lat + min_base_tile_lng, min_base_tile_lat = raster_to_4326_transformer.transform( + *(pixels_to_raster_transform * (block_width * 0.5, block_height * 0.5)) + ) + max_base_tile_lng, max_base_tile_lat = raster_to_4326_transformer.transform( + *( + pixels_to_raster_transform + * ( + raster_dataset.width - block_width * 0.5, + raster_dataset.height - block_height * 0.5, ) + ) + ) - # add the new columns generated by array_t - # o_record - # but leaving unchanged the index e.g. the block column - record.update(newrecord) + # quadbin cell at base resolution + min_base_tile = quadbin.point_to_cell( + min_base_tile_lng, min_base_tile_lat, resolution + ) + min_base_x, min_base_y, _z = quadbin.cell_to_tile(min_base_tile) + + n_records = 0 + for overview_index in range(0, len(overview_factors)): + # quadbin cell at overview resolution (quadbin_tile -> quadbin_cell) + min_tile = quadbin.point_to_cell( + min_base_tile_lng, min_base_tile_lat, resolution - overview_index - 1 + ) + max_tile = quadbin.point_to_cell( + max_base_tile_lng, max_base_tile_lat, resolution - overview_index - 1 + ) + min_x, min_y, min_z = quadbin.cell_to_tile(min_tile) + max_x, max_y, _z = quadbin.cell_to_tile(max_tile) - yield record + n_records += (max_x - min_x + 1) * (max_y - min_y + 1) - # Overviews + return n_records - # Block size must be equal for all bands; - # We avoid looping here over bands because we need - # to loop internally to accumulate, for each block - # the data for all bands. - if not is_valid_block_shapes(raster_dataset.block_shapes): - raise ValueError("Invalid block shapes: must be equal for all bands") - (block_width, block_height) = raster_dataset.block_shapes[0] - overview_factors = raster_dataset.overviews(1) +def rasterio_overview_to_records( + # raster_dataset: rasterio.io.DatasetReader, + file_path: str, + band_rename_function: Callable, + bands_info: List[Tuple[int, str]] +) -> Iterable: + raster_info = rio_cogeo.cog_info(file_path).dict() + with rasterio.open(file_path) as raster_dataset: + block_width, block_height, resolution = get_resolution_and_block_sizes( + raster_dataset, raster_info + ) + raster_crs = raster_dataset.crs.to_string() - if not is_valid_overview_indexes(overview_factors): - raise ValueError( - "Invalid overview factors: must be consecutive powers of 2" - ) + raster_to_4326_transformer = pyproj.Transformer.from_crs( + raster_crs, "EPSG:4326", always_xy=True + ) + # raster_crs must be 3857 + pixels_to_raster_transform = raster_dataset.transform + is_valid_raster_dataset(raster_dataset) + + block_width, block_height, resolution = get_resolution_and_block_sizes( + raster_dataset, raster_info + ) + raster_crs = raster_dataset.crs.to_string() + + raster_to_4326_transformer = pyproj.Transformer.from_crs( + raster_crs, "EPSG:4326", always_xy=True + ) + # raster_crs must be 3857 + pixels_to_raster_transform = raster_dataset.transform + + overview_factors = raster_dataset.overviews(1) + (block_width, block_height) = raster_dataset.block_shapes[0] for overview_index in range(0, len(overview_factors)): # results are crs 4326, so x = long, y = lat @@ -482,6 +725,7 @@ def rasterio_windows_to_records( ) min_x, min_y, min_z = quadbin.cell_to_tile(min_tile) max_x, max_y, _z = quadbin.cell_to_tile(max_tile) + for tile_x in range(min_x, max_x + 1): for tile_y in range(min_y, max_y + 1): children = quadbin.cell_to_children( @@ -494,6 +738,7 @@ def rasterio_windows_to_records( min_child_x, max_child_x = min(child_xs), max(child_xs) min_child_y, max_child_y = min(child_ys), max(child_ys) factor = overview_factors[overview_index] + # tile_window for current overview tile_window = rasterio.windows.Window( col_off=block_width * (min_child_x - min_base_x), @@ -527,10 +772,74 @@ def rasterio_windows_to_records( pixels_to_raster_transform, resolution - overview_index - 1, tile_window, + no_data_value ) - record.update(newrecord) + if newrecord: + record.update(newrecord) + + if record: + yield record - yield record + +def rasterio_windows_to_records( + file_path: str, + band_rename_function: Callable, + bands_info: List[Tuple[int, str]], +) -> Iterable: + invalid_names = [ + name for _, name in bands_info if name and name.lower() in ["block", "metadata"] + ] + if invalid_names: + raise ValueError(f"Invalid band names: {', '.join(invalid_names)}") + + """Open a raster file with rio-cogeo.""" + raster_info = rio_cogeo.cog_info(file_path).dict() + + """Check if raster is compatible.""" + if "GoogleMapsCompatible" != raster_info.get("Tags", {}).get( + "Tiling Scheme", {} + ).get("NAME"): + error_not_google_compatible() + + """Open a raster file with rasterio.""" + with rasterio.open(file_path) as raster_dataset: + block_width, block_height, resolution = get_resolution_and_block_sizes( + raster_dataset, raster_info + ) + raster_crs = raster_dataset.crs.to_string() + + raster_to_4326_transformer = pyproj.Transformer.from_crs( + raster_crs, "EPSG:4326", always_xy=True + ) + # raster_crs must be 3857 + pixels_to_raster_transform = raster_dataset.transform + + # Base raster + for _, window in raster_dataset.block_windows(): + record = {} + no_data_value = get_nodata_value(raster_dataset) + for band, band_name in bands_info: + tile_data = read_filled( + raster_dataset, band, no_data_value, window=window, boundless=True + ) + newrecord = array_to_record( + tile_data, + band_field_name(band_name, band, band_rename_function), + band_rename_function, + raster_to_4326_transformer, + pixels_to_raster_transform, + resolution, + window, + no_data_value + ) + + # add the new columns generated by array_t + # o_record but leaving unchanged the index e.g. the block column + if newrecord: + record.update(newrecord) + + if record: + yield record def is_valid_overview_indexes(overview_factors) -> bool: @@ -541,6 +850,10 @@ def is_valid_overview_indexes(overview_factors) -> bool: def is_valid_block_shapes(block_shapes) -> bool: + # Block size must be equal for all bands; + # We avoid looping here over bands because we need + # to loop internally to accumulate, for each block + # the data for all bands. (block_width, block_height) = block_shapes[0] for block_shape_index in range(0, len(block_shapes)): (index_block_width, index_block_height) = block_shapes[block_shape_index] @@ -549,8 +862,24 @@ def is_valid_block_shapes(block_shapes) -> bool: return True +def is_valid_raster_dataset(raster_dataset: rasterio.io.DatasetReader) -> bool: + + if not is_valid_block_shapes(raster_dataset.block_shapes): + raise ValueError("Invalid block shapes: must be equal for all bands") + + if not is_valid_overview_indexes(raster_dataset.overviews(1)): + raise ValueError( + "Invalid overview factors: must be consecutive powers of 2" + ) + + return True + + def band_without_stats(band): - return {k: band[k] for k in set(list(band.keys())) - set(["stats"])} + return { + k: band[k] + for k in set(list(band.keys())) - set(["stats", "colorinterp", "colortable"]) + } def bands_without_stats(metadata): @@ -633,35 +962,67 @@ def update_metadata(metadata, old_metadata): metadata["height"] = (s - n) * metadata["block_height"] metadata["width"] = (e - w) * metadata["block_width"] - for old_band in old_metadata["bands"]: - new_band = next( - (band for band in metadata["bands"] if band["name"] == old_band["name"]), + for band in metadata["bands"]: + old_band = next( + ( + old_band + for old_band in old_metadata["bands"] + if old_band["name"] == band["name"] + ), None, ) - if new_band is None: + + if old_band is None: + # Extra precaution as this should never happen raise ValueError( "Cannot append records to a table with different bands" f"(band {old_band['name']} not found)." ) - new_stats = new_band["stats"] + + new_stats = band["stats"] old_stats = old_band["stats"] - sum = old_stats["sum"] + new_stats["sum"] + + _min = min(old_stats["min"], new_stats["min"]) + _max = max(old_stats["max"], new_stats["max"]) + _sum = old_stats["sum"] + new_stats["sum"] sum_squares = old_stats["sum_squares"] + new_stats["sum_squares"] count = old_stats["count"] + new_stats["count"] - mean = sum / count - new_band["stats"] = { - "min": min(old_stats["min"], new_stats["min"]), - "max": max(old_stats["max"], new_stats["max"]), - "sum": sum, + + if old_stats["count"] == 0 or new_stats["count"] == 0: + mean = (old_stats["mean"] + new_stats["mean"]) / 2 + stdev = math.sqrt(old_stats["stddev"] ** 2 + new_stats["stddev"] ** 2) + else: + mean = _sum / count + stdev = math.sqrt(sum_squares / count - mean * mean) + + approximated_stats = ( + old_stats["approximated_stats"] or new_stats["approximated_stats"] + ) + + try: + top_values = set(new_stats["top_values"] + old_stats["top_values"]) + top_values = list(top_values).sort(reverse=True)[:DEFAULT_MAX_MOST_COMMON] + except TypeError: + # If top values in any of either metadata is None, + # the above will raise a TypeError + top_values = None + + version = max(old_stats["version"], new_stats["version"]) + + band["stats"] = { + "min": _min, + "max": _max, + "sum": _sum, "sum_squares": sum_squares, "count": count, "mean": mean, - "stddev": math.sqrt(sum_squares / count - mean * mean), + "stddev": stdev, + 'quantiles': None, + 'top_values': top_values, + 'version': version, + 'approximated_stats': approximated_stats, } - if old_band not in metadata["bands"]: - metadata["bands"].append(old_band) - def get_number_of_blocks(file_path: str) -> int: """Get the number of blocks in a raster file.""" diff --git a/raster_loader/io/snowflake.py b/raster_loader/io/snowflake.py index fcd10be..af0caa5 100644 --- a/raster_loader/io/snowflake.py +++ b/raster_loader/io/snowflake.py @@ -2,6 +2,7 @@ import rasterio import pandas as pd +from itertools import chain from typing import Iterable, List, Tuple from raster_loader.errors import ( @@ -13,8 +14,10 @@ from raster_loader.io.common import ( rasterio_metadata, + rasterio_overview_to_records, rasterio_windows_to_records, get_number_of_blocks, + get_number_of_overviews_blocks, check_metadata_is_compatible, update_metadata, ) @@ -175,7 +178,12 @@ def upload_raster( overwrite: bool = False, append: bool = False, cleanup_on_failure: bool = False, + exact_stats: bool = False, + all_stats: bool = False, ) -> bool: + def band_rename_function(x): + return x.upper() + print("Loading raster file to Snowflake...") bands_info = bands_info or [(1, None)] @@ -198,15 +206,26 @@ def upload_raster( if not append_records: exit() - metadata = rasterio_metadata(file_path, bands_info, lambda x: x.upper()) + metadata = rasterio_metadata( + file_path, bands_info, band_rename_function, exact_stats, all_stats + ) - records_gen = rasterio_windows_to_records( + overviews_records_gen = rasterio_overview_to_records( + file_path, + band_rename_function, + bands_info, + ) + windows_records_gen = rasterio_windows_to_records( file_path, - lambda x: x.upper(), + band_rename_function, bands_info, ) - total_blocks = get_number_of_blocks(file_path) + records_gen = chain(overviews_records_gen, windows_records_gen) + + number_of_blocks = get_number_of_blocks(file_path) + number_of_overview_tiles = get_number_of_overviews_blocks(file_path) + total_blocks = number_of_blocks + number_of_overview_tiles if chunk_size is None: ret = self.upload_records(records_gen, fqn, overwrite) @@ -215,20 +234,32 @@ def upload_raster( else: from tqdm.auto import tqdm - print(f"Writing {total_blocks} blocks to Snowflake...") + processed_blocks = 0 + print( + f"Writing {number_of_blocks} blocks and {number_of_overview_tiles} " + "overview tiles to Snowflake..." + ) with tqdm(total=total_blocks) as pbar: if total_blocks < chunk_size: chunk_size = total_blocks isFirstBatch = True + for records in batched(records_gen, chunk_size): ret = self.upload_records( records, fqn, overwrite and isFirstBatch ) - pbar.update(chunk_size) + num_records = len(records) + processed_blocks += num_records + pbar.update(num_records) + if not ret: raise IOError("Error uploading to Snowflake.") isFirstBatch = False - pbar.update(1) + + empty_blocks = total_blocks - processed_blocks + pbar.update(empty_blocks) + + print("Number of empty blocks: ", empty_blocks) print("Writing metadata to Snowflake...") if append_records: diff --git a/raster_loader/tests/bigquery/test_cli.py b/raster_loader/tests/bigquery/test_cli.py index ed809ab..9ff2214 100644 --- a/raster_loader/tests/bigquery/test_cli.py +++ b/raster_loader/tests/bigquery/test_cli.py @@ -38,6 +38,60 @@ def test_bigquery_upload(*args, **kwargs): assert result.exit_code == 0 +@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None) +@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None) +def test_bigquery_upload_with_all_stats(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "bigquery", + "upload", + "--file_path", + f"{tiff}", + "--project", + "project", + "--dataset", + "dataset", + "--table", + "table", + "--chunk_size", + 1, + "--band", + 1, + "--all_stats", + ], + ) + assert result.exit_code == 0 + + +@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None) +@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None) +def test_bigquery_upload_with_exact_stats(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "bigquery", + "upload", + "--file_path", + f"{tiff}", + "--project", + "project", + "--dataset", + "dataset", + "--table", + "table", + "--chunk_size", + 1, + "--band", + 1, + "--exact_stats", + ], + ) + assert result.exit_code == 0 + + @patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None) @patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None) def test_bigquery_file_path_or_url_check(*args, **kwargs): diff --git a/raster_loader/tests/bigquery/test_io.py b/raster_loader/tests/bigquery/test_io.py index 58f0be3..88d01f3 100644 --- a/raster_loader/tests/bigquery/test_io.py +++ b/raster_loader/tests/bigquery/test_io.py @@ -477,8 +477,13 @@ def test_rasterio_to_table_overwrite(*args, **kwargs): "count": 100000, "sum": 2866073.989868164, "sum_squares": 1e15, + "approximated_stats": False, + "top_values": [1, 2, 3], + "version": "0.0.3", }, - "nodata": "255", + 'colorinterp': 'red', + 'nodata': '255', + 'colortable': None, } ], "num_blocks": 1, @@ -639,8 +644,13 @@ def test_rasterio_to_table_invalid_raster(*args, **kwargs): "count": 100000, "sum": 2866073.989868164, "sum_squares": 1e15, + "approximated_stats": False, + "top_values": [1, 2, 3], + "version": "0.0.3", }, - "nodata": "255", + 'colorinterp': 'red', + 'nodata': '255', + 'colortable': None, } ], "num_blocks": 1, diff --git a/raster_loader/tests/snowflake/test_cli.py b/raster_loader/tests/snowflake/test_cli.py index 70453fd..9c11771 100644 --- a/raster_loader/tests/snowflake/test_cli.py +++ b/raster_loader/tests/snowflake/test_cli.py @@ -46,6 +46,76 @@ def test_snowflake_upload(*args, **kwargs): assert result.exit_code == 0 +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.snowflake.SnowflakeConnection.__init__", return_value=None) +def test_snowflake_upload_with_all_stats(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--file_path", + f"{tiff}", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--password", + "password", + "--chunk_size", + 1, + "--band", + 1, + "--all_stats", + ], + ) + assert result.exit_code == 0 + + +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.snowflake.SnowflakeConnection.__init__", return_value=None) +def test_snowflake_upload_with_exact_stats(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--file_path", + f"{tiff}", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--password", + "password", + "--chunk_size", + 1, + "--band", + 1, + "--exact_stats", + ], + ) + assert result.exit_code == 0 + + @patch( "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None ) diff --git a/raster_loader/tests/snowflake/test_io.py b/raster_loader/tests/snowflake/test_io.py index 5e81245..3586b2f 100644 --- a/raster_loader/tests/snowflake/test_io.py +++ b/raster_loader/tests/snowflake/test_io.py @@ -455,8 +455,13 @@ def test_rasterio_to_table_overwrite(*args, **kwargs): "count": 100000, "sum": 2866073.989868164, "sum_squares": 1e15, + "approximated_stats": False, + "top_values": [1, 2, 3], + "version": "0.0.3", }, "nodata": "255", + "colorinterp": 'red', + "colortable": None, } ], "num_blocks": 1, @@ -617,8 +622,13 @@ def test_rasterio_to_table_invalid_raster(*args, **kwargs): "count": 100000, "sum": 2866073.989868164, "sum_squares": 1e15, + "approximated_stats": False, + "top_values": [1, 2, 3], + "version": "0.0.3", }, "nodata": "255", + "colorinterp": 'red', + "colortable": None, } ], "num_blocks": 1, diff --git a/raster_loader/utils.py b/raster_loader/utils.py index 2a60140..9ed9eb1 100644 --- a/raster_loader/utils.py +++ b/raster_loader/utils.py @@ -2,6 +2,7 @@ import os import re import uuid +import warnings def ask_yes_no_question(question: str) -> bool: @@ -34,3 +35,20 @@ def get_default_table_name(base_path: str, band): table = os.path.basename(base_path).split(".")[0] table = "_".join([table, "band", str(band), str(uuid.uuid4())]) return re.sub(r"[^a-zA-Z0-9_-]", "_", table) + + +# Modify the __init__ so that self.line = "" instead of None +def new_init( + self, message, category, filename, lineno, file=None, line=None, source=None +): + self.message = message + self.category = category + self.filename = filename + self.lineno = lineno + self.file = file + self.line = "" + self.source = source + self._category_name = category.__name__.upper() if category else None + + +warnings.WarningMessage.__init__ = new_init