Skip to content

Commit

Permalink
Update file format in-situ (#596)
Browse files Browse the repository at this point in the history
* changes to the RasterAsGridReader and ReTileOnRead processes

* removed some cruft

* switch runners

* fixed failing zarr tests

* disabled the test for netcdf with checkpointing disabled

* added function `RST_AsFormat()` to change raster format / driver in-situ

* fixed zarr test

* added python bindings and tests
  • Loading branch information
sllynn authored Nov 13, 2024
1 parent 27fd3ec commit a8ef2d3
Show file tree
Hide file tree
Showing 16 changed files with 432 additions and 5 deletions.
46 changes: 46 additions & 0 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#######################

__all__ = [
"rst_asformat",
"rst_avg",
"rst_bandmetadata",
"rst_boundingbox",
Expand All @@ -23,6 +24,7 @@
"rst_derivedband",
"rst_dtmfromgeoms",
"rst_filter",
"rst_format",
"rst_frombands",
"rst_fromcontent",
"rst_fromfile",
Expand Down Expand Up @@ -85,6 +87,30 @@
]


def rst_asformat(raster_tile: ColumnOrName, driver: ColumnOrName) -> Column:
"""
Translates the raster to the specified format.
Parameters
----------
raster_tile : Column (RasterTileType)
Mosaic raster tile struct column.
driver : Column (StringType)
The format driver to use.
Returns
-------
Column (RasterTileType)
The updated raster.
"""
return config.mosaic_context.invoke_function(
"rst_asformat",
pyspark_to_java_column(raster_tile),
pyspark_to_java_column(driver),
)


def rst_avg(raster_tile: ColumnOrName) -> Column:
"""
Returns an array containing mean value for each band.
Expand Down Expand Up @@ -356,6 +382,26 @@ def rst_filter(raster_tile: ColumnOrName, kernel_size: Any, operation: Any) -> C
)


def rst_format(raster_tile: ColumnOrName) -> Column:
"""
Returns the format of the raster.
Parameters
----------
raster_tile : Column (RasterTileType)
Mosaic raster tile struct column.
Returns
-------
Column (StringType)
The format of the raster (driver required for reading).
"""
return config.mosaic_context.invoke_function(
"rst_format", pyspark_to_java_column(raster_tile)
)


def rst_frombands(bands: ColumnOrName) -> Column:
"""
Stack an array of bands into a raster tile.
Expand Down
3 changes: 3 additions & 0 deletions python/test/test_raster_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_raster_scalar_functions(self):
.withColumn("rst_max", api.rst_max("tile"))
.withColumn("rst_median", api.rst_median("tile"))
.withColumn("rst_min", api.rst_min("tile"))
.withColumn("rst_setsrid", api.rst_setsrid("tile", lit(4326)))
.withColumn("rst_format", api.rst_format("rst_setsrid"))
.withColumn("rst_asformat", api.rst_asformat("rst_setsrid", lit("GRIB")))
.withColumn("rst_frombands", api.rst_frombands(array("tile", "tile")))
.withColumn("rst_georeference", api.rst_georeference("tile"))
.withColumn("rst_getnodata", api.rst_getnodata("tile"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.databricks.labs.mosaic.core.raster.operator.RasterTranslate

import com.databricks.labs.mosaic.core.raster.api.GDAL
import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate
import com.databricks.labs.mosaic.utils.PathUtils

object TranslateFormat {

/**
* Converts the data type of a raster's bands
*
* @param raster
* The raster to update.
* @param newFormat
* The new format of the raster.
* @return
* A MosaicRasterGDAL object.
*/
def update(
raster: MosaicRasterGDAL,
newFormat: String
): MosaicRasterGDAL = {

val outOptions = raster.getWriteOptions.copy(format = newFormat, extension = GDAL.getExtension(newFormat))
val resultFileName = PathUtils.createTmpFilePath(outOptions.extension)

val result = GDALTranslate.executeTranslate(
resultFileName,
raster,
command = s"gdal_translate",
outOptions
)

result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ object VectorClipper {
* The shapefile name.
*/
private def getShapefileName: String = {
val shapeFileName = PathUtils.createTmpFilePath(".shp")
val shapeFileName = PathUtils.createTmpFilePath("shp")
shapeFileName
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ case class MosaicRasterTile(

def getDriver: String = driver

def setDriver(value: String): MosaicRasterTile = {
new MosaicRasterTile(index, raster.copy(createInfo = raster.createInfo.updated("driver", value)))
}

def driver: String = raster.createInfo("driver")

def getRaster: MosaicRasterGDAL = raster
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,16 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead

val retiledDf = retileRaster(pathsDf, config)

val convertToFormat = if (config("convertToFormat").isEmpty) {
col("tile.metadata").getItem("driver") // which should be a noop
} else {
lit(config("convertToFormat"))
}

val rasterToGridCombiner = getRasterToGridFunc(config("combiner"))

val loadedDf = retiledDf
.withColumn("tile", rst_asformat(col("tile"), convertToFormat))
.withColumn(
"tile",
rst_tessellate(col("tile"), lit(resolution))
Expand Down Expand Up @@ -225,7 +232,8 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
"retile" -> this.extraOptions.getOrElse("retile", "false"),
"tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"),
"sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", "-1"),
"kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0")
"kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0"),
"convertToFormat" -> this.extraOptions.getOrElse("convertToFormat", "")
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.raster.api.GDAL
import com.databricks.labs.mosaic.core.raster.operator.RasterTranslate.TranslateFormat
import com.databricks.labs.mosaic.core.types.RasterTileType
import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types.DataType
import org.apache.spark.unsafe.types.UTF8String

case class RST_AsFormat (
tileExpr: Expression,
newFormat: Expression,
expressionConfig: MosaicExpressionConfig
) extends Raster1ArgExpression[RST_AsFormat](
tileExpr,
newFormat,
returnsRaster = true,
expressionConfig
)
with NullIntolerant
with CodegenFallback {

override def dataType: DataType = {
GDAL.enable(expressionConfig)
RasterTileType(expressionConfig.getCellIdType, tileExpr, expressionConfig.isRasterUseCheckpoint)
}

/** Changes the data type of a band of the raster. */
override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = {

val newFormat = arg1.asInstanceOf[UTF8String].toString
if (tile.getRaster.driverShortName.getOrElse("") == newFormat) {
return tile
}
val result = TranslateFormat.update(tile.getRaster, newFormat)
tile.copy(raster = result).setDriver(newFormat)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_AsFormat extends WithExpressionInfo {

override def name: String = "rst_asformat"

override def usage: String = "_FUNC_(expr1) - Returns a raster tile in a different underlying format"

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(tile, 'GTiff')
| {index_id, updated_raster, parentPath, driver}
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_AsFormat](2, expressionConfig)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

case class RST_Format (
tileExpr: Expression,
expressionConfig: MosaicExpressionConfig
) extends RasterExpression[RST_Format](
tileExpr,
returnsRaster = false,
expressionConfig
)
with NullIntolerant
with CodegenFallback {

override def dataType: DataType = StringType

/** Returns the format of the raster. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
UTF8String.fromString(tile.getDriver)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_Format extends WithExpressionInfo {

override def name: String = "rst_format"

override def usage: String = "_FUNC_(expr1) - Returns the driver used to read the raster"

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(tile)
| 'GTiff'
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_Format](1, expressionConfig)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
)

/** RasterAPI dependent functions */
mosaicRegistry.registerExpression[RST_AsFormat](expressionConfig)
mosaicRegistry.registerExpression[RST_Avg](expressionConfig)
mosaicRegistry.registerExpression[RST_BandMetaData](expressionConfig)
mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig)
Expand All @@ -285,6 +286,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
mosaicRegistry.registerExpression[RST_DerivedBand](expressionConfig)
mosaicRegistry.registerExpression[RST_DTMFromGeoms](expressionConfig)
mosaicRegistry.registerExpression[RST_Filter](expressionConfig)
mosaicRegistry.registerExpression[RST_Format](expressionConfig)
mosaicRegistry.registerExpression[RST_GeoReference](expressionConfig)
mosaicRegistry.registerExpression[RST_GetNoData](expressionConfig)
mosaicRegistry.registerExpression[RST_GetSubdataset](expressionConfig)
Expand Down Expand Up @@ -695,6 +697,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
def st_within(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Within(geom1.expr, geom2.expr, expressionConfig))

/** RasterAPI dependent functions */
def rst_asformat(raster: Column, driver: Column): Column =
ColumnAdapter(RST_AsFormat(raster.expr, driver.expr, expressionConfig))
def rst_asformat(raster: Column, driver: String): Column =
ColumnAdapter(RST_AsFormat(raster.expr, lit(driver).expr, expressionConfig))
def rst_bandmetadata(raster: Column, band: Column): Column =
ColumnAdapter(RST_BandMetaData(raster.expr, band.expr, expressionConfig))
def rst_bandmetadata(raster: Column, band: Int): Column =
Expand All @@ -716,6 +722,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
ColumnAdapter(RST_Filter(raster.expr, kernelSize.expr, operation.expr, expressionConfig))
def rst_filter(raster: Column, kernelSize: Int, operation: String): Column =
ColumnAdapter(RST_Filter(raster.expr, lit(kernelSize).expr, lit(operation).expr, expressionConfig))
def rst_format(raster: Column): Column =
ColumnAdapter(RST_Format(raster.expr, expressionConfig))
def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig))
def rst_getnodata(raster: Column): Column = ColumnAdapter(RST_GetNoData(raster.expr, expressionConfig))
def rst_getsubdataset(raster: Column, subdatasetName: Column): Column =
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,14 @@ object PathUtils {
for (path <- toCopy) {
val destination = Paths.get(copyToPath, path.getFileName.toString)
// noinspection SimplifyBooleanMatch
if (Files.isDirectory(path)) FileUtils.copyDirectory(path.toFile, destination.toFile)
else FileUtils.copyFile(path.toFile, destination.toFile)
if (path != destination) {
if (Files.isDirectory(path)) {
FileUtils.copyDirectory(path.toFile, destination.toFile)
}
else {
FileUtils.copyFile(path.toFile, destination.toFile)
}
}
}
}

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest
import org.apache.spark.sql.test.SharedSparkSessionGDAL
import org.scalatest.Tag
import org.scalatest.matchers.must.Matchers.{be, noException}
import org.scalatest.matchers.should.Matchers.an
import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper}

import java.nio.file.{Files, Paths}

Expand Down Expand Up @@ -41,6 +41,36 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess

}

test("Read ECMWF netcdf with Raster As Grid Reader") {
assume(System.getProperty("os.name") == "Linux")
assume(checkpointingEnabled)
val mc = MosaicContext.build(H3IndexSystem, JTS)
mc.register(spark)


val netcdf = "/binary/netcdf-ECMWF/"
val filePath = this.getClass.getResource(netcdf).getPath

val result = MosaicContext.read
.format("raster_to_grid")
.option("sizeInMB", "16")
.option("convertToFormat", "GTiff")
.option("resolution", "0")
.option("readSubdataset", "true")
.option("subdatasetName", "t2m")
.option("retile", "true")
.option("tileSize", "600")
.option("combiner", "avg")
.load(filePath)
.select("measure")
.cache()

result.count shouldBe 1098

noException should be thrownBy result.take(1)

}

test("Read grib with Raster As Grid Reader", ExcludeLocalTag) {
assume(System.getProperty("os.name") == "Linux")
MosaicContext.build(H3IndexSystem, JTS)
Expand Down
Loading

0 comments on commit a8ef2d3

Please sign in to comment.