Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix raster to grid #512

Merged
merged 30 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, Mosaic
import com.databricks.labs.mosaic.core.raster.io.RasterCleaner
import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.databricks.labs.mosaic.gdal.MosaicGDAL
import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
Expand Down Expand Up @@ -114,6 +115,8 @@ object GDAL {
} else {
raster
}
case _ =>
throw new IllegalArgumentException(s"Unsupported data type: $inputDT")
}
}

Expand All @@ -122,19 +125,17 @@ object GDAL {
*
* @param generatedRasters
* The rasters to write.
* @param checkpointPath
* The path to write the rasters to.
* @return
* Returns the paths of the written rasters.
*/
def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = {
def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], rasterDT: DataType): Seq[Any] = {
generatedRasters.map(raster =>
if (raster != null) {
rasterDT match {
case StringType =>
val uuid = UUID.randomUUID().toString
val extension = GDAL.getExtension(raster.getDriversShortName)
val writePath = s"$checkpointPath/$uuid.$extension"
val writePath = s"${MosaicGDAL.checkpointPath}/$uuid.$extension"
val outPath = raster.writeToPath(writePath)
RasterCleaner.dispose(raster)
UTF8String.fromString(outPath)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package com.databricks.labs.mosaic.core.raster.gdal

import scala.reflect.ClassTag

/**
  * A rectangular window of band values held in memory together with the
  * matching window of mask values. Both arrays are addressed row-major with
  * a row length of `width`.
  *
  * @param block
  *   The pixel values of the window.
  * @param maskBlock
  *   The mask values of the window; a value of 0.0 marks a pixel as invalid.
  * @param noDataValue
  *   The value that marks a pixel as missing.
  * @param xOffset
  *   The x coordinate of the window's first column in raster coordinates.
  * @param yOffset
  *   The y coordinate of the window's first row in raster coordinates.
  * @param width
  *   The number of columns in the window.
  * @param height
  *   The number of rows in the window.
  * @param padding
  *   The padding descriptor used by [[trimBlock]] to strip the extra border.
  */
case class GDALBlock[T: ClassTag](
    block: Array[T],
    maskBlock: Array[Double],
    noDataValue: Double,
    xOffset: Int,
    yOffset: Int,
    width: Int,
    height: Int,
    padding: Padding
)(implicit
    num: Numeric[T]
) {

    /** Returns the pixel value stored at the flat array index. */
    def elementAt(index: Int): T = block(index)

    /** Returns the mask value stored at the flat array index. */
    def maskAt(index: Int): Double = maskBlock(index)

    /** Returns the pixel value at block-local coordinates (x, y). */
    def elementAt(x: Int, y: Int): T = block(y * width + x)

    /** Returns the mask value at block-local coordinates (x, y). */
    def maskAt(x: Int, y: Int): Double = maskBlock(y * width + x)

    /** Returns the pixel value at raster coordinates (x, y); the block offsets are subtracted first. */
    def rasterElementAt(x: Int, y: Int): T = block((y - yOffset) * width + (x - xOffset))

    /** Returns the mask value at raster coordinates (x, y); the block offsets are subtracted first. */
    def rasterMaskAt(x: Int, y: Int): Double = maskBlock((y - yOffset) * width + (x - xOffset))

    /**
      * NaN-aware noData test. A plain `!=` comparison never matches when the
      * noData value is NaN, so NaN is handled explicitly.
      */
    private def isNoData(value: Double): Boolean = value == noDataValue || (noDataValue.isNaN && value.isNaN)

    /**
      * Collects the valid values inside the kernel window centred on the
      * block-local pixel (x, y). Positions outside the block, masked pixels
      * (mask value 0.0) and noData pixels are skipped.
      *
      * @param x
      *   Block-local x coordinate of the window centre.
      * @param y
      *   Block-local y coordinate of the window centre.
      * @param kernelWidth
      *   The width of the window.
      * @param kernelHeight
      *   The height of the window.
      * @return
      *   The valid values in row-major window order, or a single-element
      *   array holding the noData value when no valid value was found.
      */
    def valuesAt(x: Int, y: Int, kernelWidth: Int, kernelHeight: Int): Array[Double] = {
        val kernelCenterX = kernelWidth / 2
        val kernelCenterY = kernelHeight / 2
        val values = new Array[Double](math.max(0, kernelWidth * kernelHeight))
        var n = 0
        for (i <- 0 until kernelHeight; j <- 0 until kernelWidth) {
            val xIndex = x + (j - kernelCenterX)
            val yIndex = y + (i - kernelCenterY)
            if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) {
                val value = num.toDouble(elementAt(xIndex, yIndex))
                if (maskAt(xIndex, yIndex) != 0.0 && !isNoData(value)) {
                    values(n) = value
                    n += 1
                }
            }
        }
        // Only the first n slots hold collected values. Taking exactly those
        // (instead of filtering on the noData value) stays correct when the
        // noData value is NaN.
        if (n == 0) Array(noDataValue) else values.take(n)
    }

    /**
      * Computes the weighted sum of the valid values around the block-local
      * pixel (x, y). Positions outside the block, masked pixels and noData
      * pixels contribute nothing to the sum.
      *
      * @param x
      *   Block-local x coordinate of the kernel centre.
      * @param y
      *   Block-local y coordinate of the kernel centre.
      * @param kernel
      *   The kernel weights, indexed as kernel(row)(column).
      * @return
      *   The weighted sum; 0.0 when the kernel is empty or no value is valid.
      */
    def convolveAt(x: Int, y: Int, kernel: Array[Array[Double]]): Double = {
        val kernelHeight = kernel.length
        val kernelWidth = kernel.headOption.map(_.length).getOrElse(0)
        val kernelCenterX = kernelWidth / 2
        val kernelCenterY = kernelHeight / 2
        var sum = 0.0
        for (i <- 0 until kernelHeight; j <- 0 until kernelWidth) {
            val xIndex = x + (j - kernelCenterX)
            val yIndex = y + (i - kernelCenterY)
            if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) {
                // xIndex/yIndex are block-local (they are bounds-checked against
                // width/height and used for the mask lookup), so the pixel must
                // be read with the block-local accessor as well.
                val value = num.toDouble(elementAt(xIndex, yIndex))
                if (maskAt(xIndex, yIndex) != 0.0 && !isNoData(value)) {
                    sum += value * kernel(i)(j)
                }
            }
        }
        sum
    }

    /** Mean of the valid values in the square window; the noData value when there are none. */
    def avgFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
        val values = valuesAt(x, y, kernelSize, kernelSize)
        values.sum / values.length
    }

    /** Minimum of the valid values in the square window; the noData value when there are none. */
    def minFilterAt(x: Int, y: Int, kernelSize: Int): Double = valuesAt(x, y, kernelSize, kernelSize).min

    /** Maximum of the valid values in the square window; the noData value when there are none. */
    def maxFilterAt(x: Int, y: Int, kernelSize: Int): Double = valuesAt(x, y, kernelSize, kernelSize).max

    /**
      * Median of the valid values in the square window. For an even number of
      * values the upper of the two middle values is returned.
      */
    def medianFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
        val values = valuesAt(x, y, kernelSize, kernelSize)
        scala.util.Sorting.quickSort(values)
        values(values.length / 2)
    }

    /** Most frequent of the valid values in the square window; the noData value when there are none. */
    def modeFilterAt(x: Int, y: Int, kernelSize: Int): Double = {
        val values = valuesAt(x, y, kernelSize, kernelSize)
        values.groupBy(identity).maxBy(_._2.length)._1
    }

    /**
      * Builds a new block of doubles with the padding stripped. The trimmed
      * arrays, offset and size are all obtained from the `padding` helper.
      *
      * @param stride
      *   The stride passed through to the padding helper.
      * @return
      *   A block without padding.
      */
    def trimBlock(stride: Int): GDALBlock[Double] = {
        val resultBlock = padding.removePadding(block.map(num.toDouble), width, stride)
        val resultMask = padding.removePadding(maskBlock, width, stride)

        val newOffset = padding.newOffset(xOffset, yOffset, stride)
        val newSize = padding.newSize(width, height, stride)

        new GDALBlock[Double](
          resultBlock,
          resultMask,
          noDataValue,
          newOffset._1,
          newOffset._2,
          newSize._1,
          newSize._2,
          Padding.NoPadding
        )
    }

}

object GDALBlock {

    /**
      * Computes the extent of a block along one axis, clamped to the raster.
      *
      * @param offset
      *   The position where the block starts.
      * @param maxSize
      *   The raster extent along this axis.
      * @param blockSize
      *   The unpadded block extent.
      * @param stride
      *   The width of a single padding strip.
      * @param paddingStrides
      *   How many padding strips are added along this axis.
      * @return
      *   The padded block extent, or the distance from the offset to the
      *   raster edge when the padded block would run past it.
      */
    def getSize(offset: Int, maxSize: Int, blockSize: Int, stride: Int, paddingStrides: Int): Int = {
        val paddedSize = blockSize + paddingStrides * stride
        if (offset + paddedSize <= maxSize) paddedSize else maxSize - offset
    }

    /**
      * Reads a block of the band, and of its mask band, into memory. The read
      * window starts up to one stride before the requested offset (never
      * before 0) and its extent is computed with [[getSize]].
      *
      * @param band
      *   The band to read from.
      * @param stride
      *   The width of a padding strip.
      * @param xOffset
      *   The x offset of the unpadded block.
      * @param yOffset
      *   The y offset of the unpadded block.
      * @param blockSize
      *   The unpadded block extent along both axes.
      * @return
      *   The block as read, carrying the padding descriptor and the offset
      *   and size of the window that was actually read.
      */
    def apply(
        band: MosaicRasterBandGDAL,
        stride: Int,
        xOffset: Int,
        yOffset: Int,
        blockSize: Int
    ): GDALBlock[Double] = {
        val noData = band.noDataValue
        val rasterW = band.xSize
        val rasterH = band.ySize

        // A side receives padding only when the block does not start at the
        // raster origin on that side (left/top) or when the comparison against
        // the raster extent below holds (right/bottom).
        // The kernel is expected to be much smaller than blockSize.
        val blockPadding = Padding(
          left = xOffset != 0,
          right = xOffset + blockSize < rasterW - 1, // not sure about -1
          top = yOffset != 0,
          bottom = yOffset + blockSize < rasterH - 1
        )

        // Step back one stride to pick up the leading padding, clamped at 0.
        val readX = Math.max(0, xOffset - stride)
        val readY = Math.max(0, yOffset - stride)

        val readW = getSize(readX, rasterW, blockSize, stride, blockPadding.horizontalStrides)
        val readH = getSize(readY, rasterH, blockSize, stride, blockPadding.verticalStrides)

        val pixels = Array.ofDim[Double](readW * readH)
        val mask = Array.ofDim[Double](readW * readH)

        band.getBand.ReadRaster(readX, readY, readW, readH, pixels)
        band.getBand.GetMaskBand().ReadRaster(readX, readY, readW, readH, mask)

        new GDALBlock[Double](pixels, mask, noData, readX, readY, readW, readH, blockPadding)
    }

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.mosaic.core.raster.gdal

import com.databricks.labs.mosaic.gdal.MosaicGDAL
import org.gdal.gdal.Band
import org.gdal.gdalconst.gdalconstConstants

Expand Down Expand Up @@ -219,6 +220,30 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) {
}
}

/**
* Counts the number of pixels in the band. The mask is used to determine
* if a pixel is valid. If pixel value is noData or mask value is 0.0, the
* pixel is not counted.
*
* @return
* Returns the band's pixel count.
*/
def pixelCount: Int = {
    // The band dimensions and the mask band handle are the same for every
    // row, so they are fetched once instead of once per scanned row.
    val width = band.GetXSize()
    val height = band.GetYSize()
    val maskBand = band.GetMaskBand()

    val line = Array.ofDim[Double](width)
    val maskLine = Array.ofDim[Double](width)
    var count = 0
    for (y <- 0 until height) {
        band.ReadRaster(0, y, width, 1, line)
        val maskRead = maskBand.ReadRaster(0, y, width, 1, maskLine)
        if (maskRead != gdalconstConstants.CE_None) {
            // The mask row could not be read: fall back to the noData comparison only.
            count = count + line.count(_ != noDataValue)
        } else {
            // Count pixels that are neither noData nor masked out. An index
            // loop avoids allocating a zipped tuple array for every row.
            var x = 0
            while (x < width) {
                if (line(x) != noDataValue && maskLine(x) != 0.0) count += 1
                x += 1
            }
        }
    }
    count
}

/**
* @return
* Returns the band's mask flags.
Expand All @@ -231,4 +256,80 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) {
*/
def isNoDataMask: Boolean = {
    // True only when the mask flags are exactly equal to GMF_NODATA.
    val maskFlags = band.GetMaskFlags()
    maskFlags == gdalconstConstants.GMF_NODATA
}

/**
  * Convolves the band with the provided kernel, block by block, and writes
  * the convolved values back into this band.
  *
  * NOTE(review): blocks advance by `blockSize - stride` and the results are
  * written into the same band that is being read, so a later block may read
  * pixels that an earlier block already overwrote in the overlapping region.
  * Confirm whether an in-place convolution is intended.
  *
  * @param kernel
  *   The kernel weights, indexed as kernel(row)(column). Must not be empty.
  */
def convolve(kernel: Array[Array[Double]]): Unit = {
    val kernelWidth = kernel.head.length
    val kernelHeight = kernel.length
    val blockSize = MosaicGDAL.defaultBlockSize
    val strideX = kernelWidth / 2
    val strideY = kernelHeight / 2

    // Buffers are sized for a full block and reused; edge blocks only use
    // the leading blockWidth * blockHeight entries.
    val block = Array.ofDim[Double](blockSize * blockSize)
    val maskBlock = Array.ofDim[Double](blockSize * blockSize)
    val result = Array.ofDim[Double](blockSize * blockSize)

    for (yOffset <- 0 until ySize by blockSize - strideY) {
        for (xOffset <- 0 until xSize by blockSize - strideX) {
            // Clamp the block to the raster extent on the right/bottom edges.
            val blockWidth = Math.min(blockSize, this.xSize - xOffset)
            val blockHeight = Math.min(blockSize, this.ySize - yOffset)

            band.ReadRaster(xOffset, yOffset, blockWidth, blockHeight, block)
            band.GetMaskBand().ReadRaster(xOffset, yOffset, blockWidth, blockHeight, maskBlock)

            val currentBlock =
                GDALBlock[Double](block, maskBlock, noDataValue, xOffset, yOffset, blockWidth, blockHeight, Padding.NoPadding)

            for (y <- 0 until blockHeight) {
                for (x <- 0 until blockWidth) {
                    result(y * blockWidth + x) = currentBlock.convolveAt(x, y, kernel)
                }
            }

            // Write the convolved values; writing the input buffer back here
            // would leave the band unchanged.
            band.WriteRaster(xOffset, yOffset, blockWidth, blockHeight, result)
        }
    }
}

/**
  * Applies a square moving-window filter to this band, block by block, and
  * writes the filtered values and the matching mask values to the output
  * band.
  *
  * @param kernelSize
  *   The side length of the window; must be odd.
  * @param operation
  *   One of "avg", "min", "max", "median" or "mode".
  * @param outputBand
  *   The band that receives the filtered values.
  * @throws IllegalArgumentException
  *   If the kernel size is not odd.
  * @throws Exception
  *   If the operation is not one of the supported names.
  */
def filter(kernelSize: Int, operation: String, outputBand: Band): Unit = {
    require(kernelSize % 2 == 1, "Kernel size must be odd")

    // Resolve the operation once, before any raster I/O: an unsupported name
    // fails immediately and the string is not matched again for every pixel.
    val filterAt: (GDALBlock[Double], Int, Int) => Double = operation match {
        case "avg"    => (b, x, y) => b.avgFilterAt(x, y, kernelSize)
        case "min"    => (b, x, y) => b.minFilterAt(x, y, kernelSize)
        case "max"    => (b, x, y) => b.maxFilterAt(x, y, kernelSize)
        case "median" => (b, x, y) => b.medianFilterAt(x, y, kernelSize)
        case "mode"   => (b, x, y) => b.modeFilterAt(x, y, kernelSize)
        case _        => throw new Exception("Invalid operation")
    }

    val blockSize = MosaicGDAL.blockSize
    val stride = kernelSize / 2

    for (yOffset <- 0 until ySize by blockSize) {
        for (xOffset <- 0 until xSize by blockSize) {

            // The block is read together with its padding (see GDALBlock.apply).
            val currentBlock = GDALBlock(
              this,
              stride,
              xOffset,
              yOffset,
              blockSize
            )

            val result = Array.ofDim[Double](currentBlock.block.length)

            for (y <- 0 until currentBlock.height) {
                for (x <- 0 until currentBlock.width) {
                    result(y * currentBlock.width + x) = filterAt(currentBlock, x, y)
                }
            }

            // Swap in the filtered values, then strip the padding before writing.
            val trimmedResult = currentBlock.copy(block = result).trimBlock(stride)

            outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block)
            outputBand.FlushCache()
            outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock)
            outputBand.GetMaskBand().FlushCache()

        }
    }

}

}
Loading
Loading