Skip to content

Commit

Permalink
Fix the format name for grib files in tests.
Browse files Browse the repository at this point in the history
Fix temp location utils.
Separate temp on worker location and off worker location.
Introduce GDAL Block notion.
Implement Kernel filters via GDALBlocks.
Add additional params to gdal programs when they run.
Fix TILED=YES issue with TIF files.
Introduce writeOptions concept for tmp writes.
Update expressions to take into account new concepts.
Fix Zarr format issues with SerDeser.
  • Loading branch information
milos.colic committed Feb 1, 2024
1 parent aacf3d2 commit 79ff6e6
Show file tree
Hide file tree
Showing 114 changed files with 2,005 additions and 560 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, Mosaic
import com.databricks.labs.mosaic.core.raster.io.RasterCleaner
import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.databricks.labs.mosaic.gdal.MosaicGDAL
import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
Expand Down Expand Up @@ -114,6 +115,8 @@ object GDAL {
} else {
raster
}
case _ =>
throw new IllegalArgumentException(s"Unsupported data type: $inputDT")
}
}

Expand All @@ -122,19 +125,17 @@ object GDAL {
*
* @param generatedRasters
* The rasters to write.
* @param checkpointPath
* The path to write the rasters to.
* @return
* Returns the paths of the written rasters.
*/
def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = {
def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], rasterDT: DataType): Seq[Any] = {
generatedRasters.map(raster =>
if (raster != null) {
rasterDT match {
case StringType =>
val uuid = UUID.randomUUID().toString
val extension = GDAL.getExtension(raster.getDriversShortName)
val writePath = s"$checkpointPath/$uuid.$extension"
val writePath = s"${MosaicGDAL.checkpointPath}/$uuid.$extension"
val outPath = raster.writeToPath(writePath)
RasterCleaner.dispose(raster)
UTF8String.fromString(outPath)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package com.databricks.labs.mosaic.core.raster.gdal

import scala.reflect.ClassTag

case class GDALBlock[T: ClassTag](
block: Array[T],
maskBlock: Array[Double],
noDataValue: Double,
xOffset: Int,
yOffset: Int,
width: Int,
height: Int,
padding: Padding
)(implicit
num: Numeric[T]
) {

def elementAt(index: Int): T = block(index)

def maskAt(index: Int): Double = maskBlock(index)

def elementAt(x: Int, y: Int): T = block(y * width + x)

def maskAt(x: Int, y: Int): Double = maskBlock(y * width + x)

def rasterElementAt(x: Int, y: Int): T = block((y - yOffset) * width + (x - xOffset))

def rasterMaskAt(x: Int, y: Int): Double = maskBlock((y - yOffset) * width + (x - xOffset))

def valuesAt(x: Int, y: Int, kernelWidth: Int, kernelHeight: Int): Array[Double] = {
val kernelCenterX = kernelWidth / 2
val kernelCenterY = kernelHeight / 2
val values = Array.fill[Double](kernelWidth * kernelHeight)(noDataValue)
var n = 0
for (i <- 0 until kernelHeight) {
for (j <- 0 until kernelWidth) {
val xIndex = x + (j - kernelCenterX)
val yIndex = y + (i - kernelCenterY)
if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) {
val maskValue = maskAt(xIndex, yIndex)
val value = elementAt(xIndex, yIndex)
if (maskValue != 0.0 && num.toDouble(value) != noDataValue) {
values(n) = num.toDouble(value)
n += 1
}
}
}
}
val result = values.filter(_ != noDataValue)
// always return only one NoDataValue if no valid values are found
// one and only one NoDataValue can be returned
// in all cases that have some valid values, the NoDataValue will be filtered out
if (result.isEmpty) {
Array(noDataValue)
} else {
result
}
}

// TODO: Test and fix, not tested, other filters work.
def convolveAt(x: Int, y: Int, kernel: Array[Array[Double]]): Double = {
val kernelWidth = kernel.head.length
val kernelHeight = kernel.length
val kernelCenterX = kernelWidth / 2
val kernelCenterY = kernelHeight / 2
var sum = 0.0
for (i <- 0 until kernelHeight) {
for (j <- 0 until kernelWidth) {
val xIndex = x + (j - kernelCenterX)
val yIndex = y + (i - kernelCenterY)
if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) {
val maskValue = maskAt(xIndex, yIndex)
val value = rasterElementAt(xIndex, yIndex)
if (maskValue != 0.0 && num.toDouble(value) != noDataValue) {
sum += num.toDouble(value) * kernel(i)(j)
}
}
}
}
sum
}

def avgFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
val values = valuesAt(x, y, kernelSize, kernelSize)
values.sum / values.length
}

def minFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
val values = valuesAt(x, y, kernelSize, kernelSize)
values.min
}

def maxFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
val values = valuesAt(x, y, kernelSize, kernelSize)
values.max
}

def medianFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
val values = valuesAt(x, y, kernelSize, kernelSize)
val n = values.length
scala.util.Sorting.quickSort(values)
values(n / 2)
}

def modeFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
val values = valuesAt(x, y, kernelSize, kernelSize)
val counts = values.groupBy(identity).mapValues(_.length)
counts.maxBy(_._2)._1
}

def trimBlock(stride: Int): GDALBlock[Double] = {
val resultBlock = padding.removePadding(block.map(num.toDouble), width, stride)
val resultMask = padding.removePadding(maskBlock, width, stride)

val newOffset = padding.newOffset(xOffset, yOffset, stride)
val newSize = padding.newSize(width, height, stride)

new GDALBlock[Double](
resultBlock,
resultMask,
noDataValue,
newOffset._1,
newOffset._2,
newSize._1,
newSize._2,
Padding.NoPadding
)
}

}

object GDALBlock {

def getSize(offset: Int, maxSize: Int, blockSize: Int, stride: Int, paddingStrides: Int): Int = {
if (offset + blockSize + paddingStrides * stride > maxSize) {
maxSize - offset
} else {
blockSize + paddingStrides * stride
}
}

def apply(
band: MosaicRasterBandGDAL,
stride: Int,
xOffset: Int,
yOffset: Int,
blockSize: Int
): GDALBlock[Double] = {
val noDataValue = band.noDataValue
val rasterWidth = band.xSize
val rasterHeight = band.ySize
// Always read blockSize + stride pixels on every side
// This is fine since kernel size is always much smaller than blockSize

val padding = Padding(
left = xOffset != 0,
right = xOffset + blockSize < rasterWidth - 1, // not sure about -1
top = yOffset != 0,
bottom = yOffset + blockSize < rasterHeight - 1
)

val xo = Math.max(0, xOffset - stride)
val yo = Math.max(0, yOffset - stride)

val xs = getSize(xo, rasterWidth, blockSize, stride, padding.horizontalStrides)
val ys = getSize(yo, rasterHeight, blockSize, stride, padding.verticalStrides)

val block = Array.ofDim[Double](xs * ys)
val maskBlock = Array.ofDim[Double](xs * ys)

band.getBand.ReadRaster(xo, yo, xs, ys, block)
band.getBand.GetMaskBand().ReadRaster(xo, yo, xs, ys, maskBlock)

GDALBlock(
block,
maskBlock,
noDataValue,
xo,
yo,
xs,
ys,
padding
)
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.mosaic.core.raster.gdal

import com.databricks.labs.mosaic.gdal.MosaicGDAL
import org.gdal.gdal.Band
import org.gdal.gdalconst.gdalconstConstants

Expand Down Expand Up @@ -255,4 +256,80 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) {
*/
def isNoDataMask: Boolean = band.GetMaskFlags() == gdalconstConstants.GMF_NODATA

def convolve(kernel: Array[Array[Double]]): Unit = {
val kernelWidth = kernel.head.length
val kernelHeight = kernel.length
val blockSize = MosaicGDAL.defaultBlockSize
val strideX = kernelWidth / 2
val strideY = kernelHeight / 2

val block = Array.ofDim[Double](blockSize * blockSize)
val maskBlock = Array.ofDim[Double](blockSize * blockSize)
val result = Array.ofDim[Double](blockSize * blockSize)

for (yOffset <- 0 until ySize by blockSize - strideY) {
for (xOffset <- 0 until xSize by blockSize - strideX) {
val xSize = Math.min(blockSize, this.xSize - xOffset)
val ySize = Math.min(blockSize, this.ySize - yOffset)

band.ReadRaster(xOffset, yOffset, xSize, ySize, block)
band.GetMaskBand().ReadRaster(xOffset, yOffset, xSize, ySize, maskBlock)

val currentBlock = GDALBlock[Double](block, maskBlock, noDataValue, xOffset, yOffset, xSize, ySize, Padding.NoPadding)

for (y <- 0 until ySize) {
for (x <- 0 until xSize) {
result(y * xSize + x) = currentBlock.convolveAt(x, y, kernel)
}
}

band.WriteRaster(xOffset, yOffset, xSize, ySize, block)
}
}
}

def filter(kernelSize: Int, operation: String, outputBand: Band): Unit = {
require(kernelSize % 2 == 1, "Kernel size must be odd")

val blockSize = MosaicGDAL.blockSize
val stride = kernelSize / 2

for (yOffset <- 0 until ySize by blockSize) {
for (xOffset <- 0 until xSize by blockSize) {

val currentBlock = GDALBlock(
this,
stride,
xOffset,
yOffset,
blockSize
)

val result = Array.ofDim[Double](currentBlock.block.length)

for (y <- 0 until currentBlock.height) {
for (x <- 0 until currentBlock.width) {
result(y * currentBlock.width + x) = operation match {
case "avg" => currentBlock.avgFilterAt(x, y, kernelSize)
case "min" => currentBlock.minFilterAt(x, y, kernelSize)
case "max" => currentBlock.maxFilterAt(x, y, kernelSize)
case "median" => currentBlock.medianFilterAt(x, y, kernelSize)
case "mode" => currentBlock.modeFilterAt(x, y, kernelSize)
case _ => throw new Exception("Invalid operation")
}
}
}

val trimmedResult = currentBlock.copy(block = result).trimBlock(stride)

outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block)
outputBand.FlushCache()
outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock)
outputBand.GetMaskBand().FlushCache()

}
}

}

}
Loading

0 comments on commit 79ff6e6

Please sign in to comment.