Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix raster to grid #512

Merged
merged 30 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,30 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) {
}
}

/**
* Counts the number of pixels in the band. The mask is used to determine
* if a pixel is valid. If pixel value is noData or mask value is 0.0, the
* pixel is not counted.
*
* @return
* Returns the band's pixel count.
*/
def pixelCount: Int = {
val line = Array.ofDim[Double](band.GetXSize())
val maskLine = Array.ofDim[Double](band.GetXSize())
var count = 0
for (y <- 0 until band.GetYSize()) {
band.ReadRaster(0, y, band.GetXSize(), 1, line)
val maskRead = band.GetMaskBand().ReadRaster(0, y, band.GetXSize(), 1, maskLine)
if (maskRead != gdalconstConstants.CE_None) {
count = count + line.count(_ != noDataValue)
} else {
count = count + line.zip(maskLine).count { case (pixel, mask) => pixel != noDataValue && mask != 0.0 }
}
}
count
}

/**
* @return
* Returns the band's mask flags.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose
import com.databricks.labs.mosaic.core.raster.io.{RasterCleaner, RasterReader, RasterWriter}
import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector
import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.POLYGON
import com.databricks.labs.mosaic.utils.PathUtils
import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils}
import org.gdal.gdal.gdal.GDALInfo
import org.gdal.gdal.{Dataset, InfoOptions, gdal}
import org.gdal.gdalconst.gdalconstConstants._
Expand Down Expand Up @@ -405,7 +405,7 @@ case class MosaicRasterGDAL(
} else {
path
}
val byteArray = Files.readAllBytes(Paths.get(readPath))
val byteArray = FileUtils.readBytes(readPath)
if (dispose) RasterCleaner.dispose(this)
if (readPath != PathUtils.getCleanPath(parentPath)) {
Files.deleteIfExists(Paths.get(readPath))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.databricks.labs.mosaic.core.raster.operator.gdal

import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
import org.gdal.gdal.{InfoOptions, gdal}

/** GDALBuildVRT is a wrapper for the GDAL BuildVRT command. */
object GDALInfo {

/**
* Executes the GDAL BuildVRT command. For flags check the way gdalinfo.py
* script is called, InfoOptions expects a collection of same flags.
*
* @param raster
* The raster to get info from.
* @param command
* The GDAL Info command.
* @return
* A result json string.
*/
def executeInfo(raster: MosaicRasterGDAL, command: String): String = {
require(command.startsWith("gdalinfo"), "Not a valid GDAL Info command.")

val infoOptionsVec = OperatorOptions.parseOptions(command)
val infoOptions = new InfoOptions(infoOptionsVec)
val gdalInfo = gdal.GDALInfo(raster.getRaster, infoOptions)

if (gdalInfo == null) {
throw new Exception(s"""
|GDAL Info failed.
|Command: $command
|Error: ${gdal.GetLastErrorMsg}
|""".stripMargin)
}

gdalInfo
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ object BalancedSubdivision {
*/
def getNumSplits(raster: MosaicRasterGDAL, destSize: Int): Int = {
val size = raster.getMemSize
val n = size.toDouble / (destSize * 1000 * 1000)
val nInt = Math.ceil(n).toInt
Math.pow(4, Math.ceil(Math.log(nInt) / Math.log(4))).toInt
var n = 1
while (true) {
n *= 4
if (size / n <= destSize * 1000 * 1000) return n
}
n
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.mosaic.datasource.multiread

import com.databricks.labs.mosaic.MOSAIC_RASTER_READ_STRATEGY
import com.databricks.labs.mosaic.functions.MosaicContext
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
Expand All @@ -25,18 +26,36 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
nPartitions
}

private def workerNCores = {
sparkSession.sparkContext.range(0, 1).map(_ => java.lang.Runtime.getRuntime.availableProcessors).collect.head
}

private def nWorkers = sparkSession.sparkContext.getExecutorMemoryStatus.size

override def load(path: String): DataFrame = load(Seq(path): _*)

override def load(paths: String*): DataFrame = {

val config = getConfig
val resolution = config("resolution").toInt
val nPartitions = getNPartitions(config)
val readStrategy = config("retile") match {
case "true" => "retile_on_read"
case _ => "in_memory"
}
val tileSize = config("sizeInMB").toInt

val nCores = nWorkers * workerNCores
val stageCoefficient = math.ceil(math.log(nCores) / math.log(4))

val firstStageSize = (tileSize * math.pow(4, stageCoefficient)).toInt

val pathsDf = sparkSession.read
.format("gdal")
.option("extensions", config("extensions"))
.option("raster_storage", "in-memory")
.option(MOSAIC_RASTER_READ_STRATEGY, readStrategy)
.option("vsizip", config("vsizip"))
.option("sizeInMB", firstStageSize)
.load(paths: _*)
.repartition(nPartitions)

Expand All @@ -46,7 +65,12 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead

val retiledDf = retileRaster(rasterDf, config)

val loadedDf = retiledDf
val loadedDf = rasterDf
.withColumn(
"tile",
rst_tessellate(col("tile"), lit(resolution))
)
.repartition(nPartitions)
.withColumn(
"grid_measures",
rasterToGridCombiner(col("tile"), lit(resolution))
Expand All @@ -58,6 +82,7 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
.select(
posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures"))
)
.repartition(nPartitions)
.select(
col("band_id"),
explode(col("grid_measures")).alias("grid_measures")
Expand Down Expand Up @@ -88,16 +113,22 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
*/
private def retileRaster(rasterDf: DataFrame, config: Map[String, String]) = {
val retile = config("retile").toBoolean
val tileSize = config("tileSize").toInt
val tileSize = config.getOrElse("tileSize", "-1").toInt
val memSize = config.getOrElse("sizeInMB", "-1").toInt
val nPartitions = getNPartitions(config)

if (retile) {
rasterDf
.withColumn(
"tile",
rst_retile(col("tile"), lit(tileSize), lit(tileSize))
)
.repartition(nPartitions)
if (memSize > 0) {
rasterDf
.withColumn("tile", rst_subdivide(col("tile"), lit(memSize)))
.repartition(nPartitions)
} else if (tileSize > 0) {
rasterDf
.withColumn("tile", rst_retile(col("tile"), lit(tileSize), lit(tileSize)))
.repartition(nPartitions)
} else {
rasterDf
}
} else {
rasterDf
}
Expand Down Expand Up @@ -200,7 +231,8 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
"resolution" -> this.extraOptions.getOrElse("resolution", "0"),
"combiner" -> this.extraOptions.getOrElse("combiner", "mean"),
"retile" -> this.extraOptions.getOrElse("retile", "false"),
"tileSize" -> this.extraOptions.getOrElse("tileSize", "256"),
"tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"),
"sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", ""),
"kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0")
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo
import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._


/** Returns the upper left x of the raster. */
case class RST_Avg(raster: Expression, expressionConfig: MosaicExpressionConfig)
extends RasterExpression[RST_Avg](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig)
with NullIntolerant
with CodegenFallback {

/** Returns the upper left x of the raster. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
import org.json4s._
import org.json4s.jackson.JsonMethods._
implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats

val command = s"gdalinfo -stats -json -mm -nogcp -nomd -norat -noct"
val gdalInfo = GDALInfo.executeInfo(tile.raster, command)
// parse json from gdalinfo
val json = parse(gdalInfo).extract[Map[String, Any]]
val maxValues = json("bands").asInstanceOf[List[Map[String, Any]]].map { band =>
band("mean").asInstanceOf[Double]
}
ArrayData.toArrayData(maxValues.toArray)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_Avg extends WithExpressionInfo {

override def name: String = "rst_avg"

override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band."

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(raster_tile);
| [1.123, 2.123, 3.123]
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_Avg](1, expressionConfig)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo
import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._


/** Returns the upper left x of the raster. */
case class RST_Max(raster: Expression, expressionConfig: MosaicExpressionConfig)
extends RasterExpression[RST_Max](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig)
with NullIntolerant
with CodegenFallback {

/** Returns the upper left x of the raster. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
val nBands = tile.raster.raster.GetRasterCount()
val maxValues = (1 to nBands).map(tile.raster.getBand(_).maxPixelValue)
ArrayData.toArrayData(maxValues.toArray)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_Max extends WithExpressionInfo {

override def name: String = "rst_max"

override def usage: String = "_FUNC_(expr1) - Returns an array containing max values for each band."

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(raster_tile);
| [1.123, 2.123, 3.123]
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_Max](1, expressionConfig)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.raster.api.GDAL
import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALCalc, GDALInfo, GDALWarp}
import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.databricks.labs.mosaic.utils.PathUtils
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

/** Returns the upper left x of the raster. */
case class RST_Median(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig)
extends RasterExpression[RST_Median](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig)
with NullIntolerant
with CodegenFallback {

/** Returns the upper left x of the raster. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
val raster = tile.raster
val width = raster.xSize * raster.pixelXSize
val height = raster.ySize * raster.pixelYSize
val outShortName = raster.getDriversShortName
val resultFileName = PathUtils.createTmpFilePath(GDAL.getExtension(outShortName))
val medRaster = GDALWarp.executeWarp(
resultFileName,
Seq(raster),
command = s"gdalwarp -r med -tr $width $height -of $outShortName"
)
// Max pixel is a hack since we get a 1x1 raster back
val maxValues = (1 to medRaster.raster.GetRasterCount()).map(medRaster.getBand(_).maxPixelValue)
ArrayData.toArrayData(maxValues.toArray)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_Median extends WithExpressionInfo {

override def name: String = "rst_median"

override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band."

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(raster_tile);
| [1.123, 2.123, 3.123]
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_Median](1, expressionConfig)
}

}
Loading
Loading