diff --git a/docs/source/user_guide/cli.rst b/docs/source/user_guide/cli.rst index ff5147a..6af713b 100644 --- a/docs/source/user_guide/cli.rst +++ b/docs/source/user_guide/cli.rst @@ -152,7 +152,7 @@ Or, with band names: --band_name red \ --band_name green -You can enable compression of the band data using the ``--compress`` flag. This uses gzip compression which can significantly reduce storage size: +You can enable compression of the band data using the ``--compress`` flag. This uses gzip compression which can significantly reduce storage size. By default, it uses compression level 6, which provides a good balance between compression ratio and performance. You can adjust this using the ``--compression-level`` parameter (values from 1 to 9, where 1 is fastest but least compressed, and 9 gives maximum compression): .. code-block:: bash @@ -161,7 +161,8 @@ You can enable compression of the band data using the ``--compress`` flag. This --project my-gcp-project \ --dataset my-bigquery-dataset \ --table my-bigquery-table \ - --compress + --compress \ + --compression-level 3 The same works for Snowflake: @@ -175,7 +176,8 @@ The same works for Snowflake: --account my-snowflake-account \ --username my-snowflake-user \ --password my-snowflake-password \ - --compress + --compress \ + --compression-level 3 .. seealso:: See the :ref:`cli_details` for a full list of options. diff --git a/docs/source/user_guide/use_with_python.rst b/docs/source/user_guide/use_with_python.rst index 1434764..1495cc0 100644 --- a/docs/source/user_guide/use_with_python.rst +++ b/docs/source/user_guide/use_with_python.rst @@ -91,7 +91,8 @@ To enable compression of the band data, which can significantly reduce storage s connector.upload_raster( file_path = 'path/to/raster.tif', fqn = 'database.schema.tablename', - compress = True # Enable gzip compression of band data + compress = True, # Enable gzip compression of band data + compression_level = 3 # Optional: Set compression level (1-9, default=6) ) The compression information will be stored in the metadata of the table, and the data will be automatically decompressed when reading it back. diff --git a/raster_loader/cli/bigquery.py b/raster_loader/cli/bigquery.py index 0fc7a09..7ab3398 100644 --- a/raster_loader/cli/bigquery.py +++ b/raster_loader/cli/bigquery.py @@ -95,6 +95,12 @@ def bigquery(args=None): required=False, is_flag=True, ) +@click.option( + "--compression-level", + help="Compression level (1-9, higher = better compression but slower)", + type=int, + default=6, +) @catch_exception() def upload( file_path, @@ -112,6 +118,7 @@ def upload( cleanup_on_failure=False, exact_stats=False, basic_stats=False, + compression_level=6, ): from raster_loader.io.common import ( get_number_of_blocks, @@ -186,6 +193,7 @@ def upload( exact_stats=exact_stats, basic_stats=basic_stats, compress=compress, + compression_level=compression_level, ) click.echo("Raster file uploaded to Google BigQuery") diff --git a/raster_loader/cli/snowflake.py b/raster_loader/cli/snowflake.py index 5819dd2..1ae1dd5 100644 --- a/raster_loader/cli/snowflake.py +++ b/raster_loader/cli/snowflake.py @@ -117,6 +117,12 @@ def snowflake(args=None): is_flag=True, default=False, ) +@click.option( + "--compression-level", + help="Compression level (1-9, higher = better compression but slower)", + type=int, + default=6, +) @catch_exception() def upload( account, @@ -141,6 +147,7 @@ def upload( cleanup_on_failure=False, exact_stats=False, basic_stats=False, + compression_level=6, ): from raster_loader.io.common import ( get_number_of_blocks, @@ -247,6 +254,7 @@ def upload( exact_stats=exact_stats, basic_stats=basic_stats, compress=compress, + compression_level=compression_level, ) click.echo("Raster file uploaded to Snowflake") diff --git a/raster_loader/io/bigquery.py b/raster_loader/io/bigquery.py index 7f87b68..e22951d 100644 --- a/raster_loader/io/bigquery.py +++ b/raster_loader/io/bigquery.py @@ -110,6 +110,7 @@ def upload_raster( exact_stats: bool = False, basic_stats: bool = False, compress: bool = False, + compression_level: int = 6, ): """Write a raster file to a BigQuery table.""" print("Loading raster file to BigQuery...") @@ -145,6 +146,7 @@ def upload_raster( self.band_rename_function, bands_info, compress=compress, + compression_level=compression_level, ) windows_records_gen = rasterio_windows_to_records( @@ -152,6 +154,7 @@ def upload_raster( self.band_rename_function, bands_info, compress=compress, + compression_level=compression_level, ) records_gen = chain(overviews_records_gen, windows_records_gen) diff --git a/raster_loader/io/common.py b/raster_loader/io/common.py index 85c3a66..814f5bd 100644 --- a/raster_loader/io/common.py +++ b/raster_loader/io/common.py @@ -101,15 +101,15 @@ def get_default_nodata_value(dtype: str) -> float: # TODO: Remove this once we drop support for Python < 3.11 if sys.version_info < (3, 11): - def compress_bytes(arr_bytes): - compressed = zlib.compress(arr_bytes, level=6) + def compress_bytes(arr_bytes, level=6): + compressed = zlib.compress(arr_bytes, level=level) # Add gzip header corresponding to wbits=31 return b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03" + compressed else: - def compress_bytes(arr_bytes): - return zlib.compress(arr_bytes, level=6, wbits=31) + def compress_bytes(arr_bytes, level=6): + return zlib.compress(arr_bytes, level=level, wbits=31) def array_to_record( @@ -122,6 +122,7 @@ def array_to_record( window: rasterio.windows.Window, no_data_value: float = None, compress: bool = False, + compression_level: int = 6, ) -> dict: row_off = window.row_off col_off = window.col_off @@ -144,7 +145,7 @@ def array_to_record( arr_bytes = np.ascontiguousarray(arr).tobytes() # Apply compression if requested - arr_bytes = compress_bytes(arr_bytes) if compress else arr_bytes + arr_bytes = compress_bytes(arr_bytes, compression_level) if compress else arr_bytes record = { band_rename_function("block"): block, @@ -724,6 +725,7 @@ def rasterio_overview_to_records( band_rename_function: Callable, bands_info: List[Tuple[int, str]], compress: bool = False, + compression_level: int = 6, ) -> Iterable: raster_info = rio_cogeo.cog_info(file_path).dict() with rasterio.open(file_path) as raster_dataset: @@ -831,6 +833,7 @@ def rasterio_overview_to_records( tile_window, no_data_value, compress=compress, + compression_level=compression_level, ) if newrecord: record.update(newrecord) @@ -844,6 +847,7 @@ def rasterio_windows_to_records( band_rename_function: Callable, bands_info: List[Tuple[int, str]], compress: bool = False, + compression_level: int = 6, ) -> Iterable: invalid_names = [ name for _, name in bands_info if name and name.lower() in ["block", "metadata"] @@ -891,6 +895,7 @@ def rasterio_windows_to_records( window, no_data_value, compress=compress, + compression_level=compression_level, ) # add the new columns generated by array_t diff --git a/raster_loader/io/snowflake.py b/raster_loader/io/snowflake.py index 49e1c02..696f399 100644 --- a/raster_loader/io/snowflake.py +++ b/raster_loader/io/snowflake.py @@ -207,7 +207,10 @@ def upload_raster( exact_stats: bool = False, basic_stats: bool = False, compress: bool = False, + compression_level: int = 6, ) -> bool: + """Write a raster file to a Snowflake table.""" + def band_rename_function(x): return x.upper() @@ -242,12 +245,14 @@ def band_rename_function(x): band_rename_function, bands_info, compress=compress, + compression_level=compression_level, ) windows_records_gen = rasterio_windows_to_records( file_path, band_rename_function, bands_info, compress=compress, + compression_level=compression_level, ) records_gen = chain(overviews_records_gen, windows_records_gen) diff --git a/raster_loader/tests/bigquery/test_io.py b/raster_loader/tests/bigquery/test_io.py index 3f5ae3e..b02b483 100644 --- a/raster_loader/tests/bigquery/test_io.py +++ b/raster_loader/tests/bigquery/test_io.py @@ -759,3 +759,62 @@ def test_rasterio_to_bigquery_with_compression(*args, **kwargs): compress=True, ) assert success + + +@patch( + "raster_loader.io.bigquery.BigQueryConnection.check_if_table_exists", + return_value=True, +) +@patch("raster_loader.io.bigquery.BigQueryConnection.delete_table", return_value=None) +@patch( + "raster_loader.io.bigquery.BigQueryConnection.check_if_table_is_empty", + return_value=False, +) +@patch("raster_loader.io.bigquery.ask_yes_no_question", return_value=True) +@patch("raster_loader.io.bigquery.BigQueryConnection.delete_table", return_value=None) +@patch("raster_loader.io.bigquery.BigQueryConnection.write_metadata", return_value=None) +@patch("raster_loader.io.bigquery.BigQueryConnection.update_labels", return_value=None) +@patch( + "raster_loader.io.bigquery.BigQueryConnection.get_metadata", + return_value={ + "bounds": [0, 0, 0, 0], + "block_resolution": 5, + "nodata": 0, + "block_width": 256, + "block_height": 256, + "compression": "gzip", + "bands": [ + { + "type": "uint8", + "name": "band_1", + "colorinterp": "red", + "stats": { + "min": 0.0, + "max": 255.0, + "mean": 28.66073989868164, + "stddev": 41.5693439511935, + "count": 100000, + "sum": 2866073.989868164, + "sum_squares": 1e15, + "approximated_stats": False, + "top_values": [1, 2, 3], + "version": "0.0.3", + }, + "nodata": "0", + "colortable": None, + } + ], + "num_blocks": 1, + "num_pixels": 1, + }, +) +def test_rasterio_to_bigquery_with_compression_level(*args, **kwargs): + table_name = "test_mosaic_compressed" + connector = mocks.MockBigQueryConnection() + success = connector.upload_raster( + os.path.join(fixtures_dir, "mosaic_cog.tif"), + f"{BQ_PROJECT_ID}.{BQ_DATASET_ID}.{table_name}", + compress=True, + compression_level=3, + ) + assert success diff --git a/raster_loader/tests/snowflake/test_io.py b/raster_loader/tests/snowflake/test_io.py index 9228a14..c6fe909 100644 --- a/raster_loader/tests/snowflake/test_io.py +++ b/raster_loader/tests/snowflake/test_io.py @@ -728,3 +728,64 @@ def test_rasterio_to_snowflake_with_compression(*args, **kwargs): compress=True, ) assert success + + +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.check_if_table_exists", + return_value=True, +) +@patch("raster_loader.io.snowflake.SnowflakeConnection.delete_table", return_value=None) +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.check_if_table_is_empty", + return_value=False, +) +@patch("raster_loader.io.snowflake.ask_yes_no_question", return_value=True) +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.write_metadata", return_value=None +) +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.get_metadata", + return_value={ + "bounds": [0, 0, 0, 0], + "block_resolution": 5, + "nodata": 0, + "block_width": 256, + "block_height": 256, + "compression": "gzip", + "bands": [ + { + "type": "uint8", + "name": "BAND_1", + "colorinterp": "red", + "stats": { + "min": 0.0, + "max": 255.0, + "mean": 28.66073989868164, + "stddev": 41.5693439511935, + "count": 100000, + "sum": 2866073.989868164, + "sum_squares": 1e15, + "approximated_stats": False, + "top_values": [1, 2, 3], + "version": "0.0.3", + }, + "nodata": "0", + "colorinterp": "red", + "colortable": None, + } + ], + "num_blocks": 1, + "num_pixels": 1, + }, +) +@patch("raster_loader.io.snowflake.write_pandas", return_value=[True]) +def test_rasterio_to_snowflake_with_compression_level(*args, **kwargs): + table_name = "test_mosaic_compressed".upper() + connector = mocks.MockSnowflakeConnection() + success = connector.upload_raster( + os.path.join(fixtures_dir, "mosaic_cog.tif"), + f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}", + compress=True, + compression_level=3, + ) + assert success