From 79ff6e6feeed61d98e542d9db8467bf6a595a4d1 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Thu, 1 Feb 2024 15:40:17 +0000 Subject: [PATCH] Fix the format name for grib files in tests. Fix temp location utils. Separate temp on worker location and off worker location. Introduce GDAL Block notion. Implement Kernel filters via GDALBlocks. Add additional params to gdal programs when they run. Fix TILED=YES issue with TIF files. Introduce writeOptions concept for tmp writes. Update expressions to take into account new concepts. Fix Zarr format issues with SerDeser. --- .../labs/mosaic/core/raster/api/GDAL.scala | 9 +- .../mosaic/core/raster/gdal/GDALBlock.scala | 186 ++++++++++++ .../raster/gdal/MosaicRasterBandGDAL.scala | 77 +++++ .../core/raster/gdal/MosaicRasterGDAL.scala | 228 +++++++++++---- .../gdal/MosaicRasterWriteOptions.scala | 55 ++++ .../mosaic/core/raster/gdal/Padding.scala | 58 ++++ .../operator/clip/RasterClipByVector.scala | 10 +- .../raster/operator/gdal/GDALBuildVRT.scala | 16 +- .../core/raster/operator/gdal/GDALCalc.scala | 17 +- .../raster/operator/gdal/GDALTranslate.scala | 25 +- .../core/raster/operator/gdal/GDALWarp.scala | 5 +- .../operator/gdal/OperatorOptions.scala | 35 +++ .../raster/operator/merge/MergeBands.scala | 14 +- .../raster/operator/merge/MergeRasters.scala | 18 +- .../operator/pixel/PixelCombineRasters.scala | 8 +- .../raster/operator/proj/RasterProject.scala | 7 +- .../operator/retile/OverlappingTiles.scala | 5 +- .../operator/retile/RasterTessellate.scala | 5 +- .../core/raster/operator/retile/ReTile.scala | 7 +- .../mosaic/core/types/RasterTileType.scala | 43 ++- .../core/types/model/MosaicRasterTile.scala | 26 +- .../mosaic/datasource/gdal/ReTileOnRead.scala | 15 +- .../mosaic/datasource/gdal/ReadInMemory.scala | 4 +- .../multiread/RasterAsGridReader.scala | 40 ++- .../mosaic/expressions/raster/RST_Avg.scala | 6 +- .../expressions/raster/RST_BandMetaData.scala | 3 +- .../expressions/raster/RST_BoundingBox.scala | 4 +- 
.../mosaic/expressions/raster/RST_Clip.scala | 3 +- .../expressions/raster/RST_CombineAvg.scala | 8 +- .../raster/RST_CombineAvgAgg.scala | 28 +- .../expressions/raster/RST_Convolve.scala | 73 +++++ .../expressions/raster/RST_DerivedBand.scala | 8 +- .../raster/RST_DerivedBandAgg.scala | 19 +- .../expressions/raster/RST_Filter.scala | 77 +++++ .../expressions/raster/RST_FromBands.scala | 8 +- .../expressions/raster/RST_FromContent.scala | 25 +- .../expressions/raster/RST_FromFile.scala | 10 +- .../expressions/raster/RST_GeoReference.scala | 4 +- .../expressions/raster/RST_GetNoData.scala | 5 +- .../raster/RST_GetSubdataset.scala | 13 +- .../expressions/raster/RST_Height.scala | 4 +- .../expressions/raster/RST_InitNoData.scala | 8 +- .../expressions/raster/RST_IsEmpty.scala | 4 +- .../expressions/raster/RST_MakeTiles.scala | 205 +++++++++++++ .../expressions/raster/RST_MapAlgebra.scala | 8 +- .../mosaic/expressions/raster/RST_Max.scala | 4 +- .../expressions/raster/RST_Median.scala | 7 +- .../expressions/raster/RST_MemSize.scala | 4 +- .../mosaic/expressions/raster/RST_Merge.scala | 8 +- .../expressions/raster/RST_MergeAgg.scala | 19 +- .../expressions/raster/RST_MetaData.scala | 4 +- .../mosaic/expressions/raster/RST_Min.scala | 5 +- .../mosaic/expressions/raster/RST_NDVI.scala | 8 +- .../expressions/raster/RST_NumBands.scala | 4 +- .../expressions/raster/RST_PixelCount.scala | 4 +- .../expressions/raster/RST_PixelHeight.scala | 4 +- .../expressions/raster/RST_PixelWidth.scala | 4 +- .../raster/RST_RasterToWorldCoord.scala | 4 +- .../raster/RST_RasterToWorldCoordX.scala | 4 +- .../raster/RST_RasterToWorldCoordY.scala | 4 +- .../expressions/raster/RST_ReTile.scala | 3 + .../expressions/raster/RST_Rotation.scala | 4 +- .../mosaic/expressions/raster/RST_SRID.scala | 4 +- .../expressions/raster/RST_ScaleX.scala | 4 +- .../expressions/raster/RST_ScaleY.scala | 4 +- .../expressions/raster/RST_SetNoData.scala | 8 +- .../mosaic/expressions/raster/RST_SkewX.scala | 4 +- 
.../mosaic/expressions/raster/RST_SkewY.scala | 4 +- .../expressions/raster/RST_Subdatasets.scala | 3 +- .../expressions/raster/RST_Summary.scala | 4 +- .../expressions/raster/RST_TryOpen.scala | 4 +- .../expressions/raster/RST_UpperLeftX.scala | 4 +- .../expressions/raster/RST_UpperLeftY.scala | 4 +- .../mosaic/expressions/raster/RST_Width.scala | 4 +- .../raster/RST_WorldToRasterCoord.scala | 5 +- .../raster/RST_WorldToRasterCoordX.scala | 4 +- .../raster/RST_WorldToRasterCoordY.scala | 4 +- .../raster/base/Raster1ArgExpression.scala | 17 +- .../raster/base/Raster2ArgExpression.scala | 16 +- .../base/RasterArray1ArgExpression.scala | 12 +- .../base/RasterArray2ArgExpression.scala | 12 +- .../raster/base/RasterArrayExpression.scala | 10 +- .../raster/base/RasterArrayUtils.scala | 8 +- .../raster/base/RasterBandExpression.scala | 17 +- .../raster/base/RasterExpression.scala | 16 +- .../base/RasterExpressionSerialization.scala | 4 +- .../base/RasterGeneratorExpression.scala | 11 +- .../RasterTessellateGeneratorExpression.scala | 15 +- .../raster/base/RasterToGridExpression.scala | 4 +- .../labs/mosaic/functions/MosaicContext.scala | 24 +- .../functions/MosaicExpressionConfig.scala | 2 + .../labs/mosaic/gdal/MosaicGDAL.scala | 66 +++-- .../com/databricks/labs/mosaic/package.scala | 6 +- .../labs/mosaic/utils/FileUtils.scala | 6 +- .../labs/mosaic/utils/PathUtils.scala | 66 ++++- .../labs/mosaic/utils/SysUtils.scala | 44 ++- ...-041ac051-015d-49b0-95df-b5daa7084c7e.grb} | Bin ...1-015d-49b0-95df-b5daa7084c7e.grb.aux.xml} | 0 ...-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb} | Bin ...6-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml} | 0 ...-0ede0273-89e3-4100-a0f2-48916ca607ed.grb} | Bin ...3-89e3-4100-a0f2-48916ca607ed.grb.aux.xml} | 0 .../core/raster/TestRasterBandGDAL.scala | 4 +- .../mosaic/core/raster/TestRasterGDAL.scala | 223 +++++++++++++- .../datasource/GDALFileFormatTest.scala | 19 +- .../multiread/RasterAsGridReaderTest.scala | 272 +++++++++--------- 
.../raster/RST_CombineAvgBehaviors.scala | 4 +- .../raster/RST_FilterBehaviors.scala | 36 +++ .../expressions/raster/RST_FilterTest.scala | 32 +++ .../expressions/raster/RST_MinBehaviors.scala | 2 +- .../raster/RST_TessellateBehaviors.scala | 11 +- .../databricks/labs/mosaic/test/package.scala | 2 +- .../sql/test/MosaicTestSparkSession.scala | 4 +- .../sql/test/SharedSparkSessionGDAL.scala | 6 +- 114 files changed, 2005 insertions(+), 560 deletions(-) create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib => adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml => adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib => adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml => 
adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib => adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml => adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml} (100%) create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala index 66bde39a3..b86489359 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, Mosaic import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{BinaryType, DataType, StringType} @@ -114,6 +115,8 @@ object GDAL { } else { raster } + case _ => + throw new IllegalArgumentException(s"Unsupported data type: $inputDT") } } @@ -122,19 +125,17 @@ object GDAL { * * @param generatedRasters * The rasters to write. - * @param checkpointPath - * The path to write the rasters to. 
* @return * Returns the paths of the written rasters. */ - def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = { + def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], rasterDT: DataType): Seq[Any] = { generatedRasters.map(raster => if (raster != null) { rasterDT match { case StringType => val uuid = UUID.randomUUID().toString val extension = GDAL.getExtension(raster.getDriversShortName) - val writePath = s"$checkpointPath/$uuid.$extension" + val writePath = s"${MosaicGDAL.checkpointPath}/$uuid.$extension" val outPath = raster.writeToPath(writePath) RasterCleaner.dispose(raster) UTF8String.fromString(outPath) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala new file mode 100644 index 000000000..8c5c7a495 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala @@ -0,0 +1,186 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import scala.reflect.ClassTag + +case class GDALBlock[T: ClassTag]( + block: Array[T], + maskBlock: Array[Double], + noDataValue: Double, + xOffset: Int, + yOffset: Int, + width: Int, + height: Int, + padding: Padding +)(implicit + num: Numeric[T] +) { + + def elementAt(index: Int): T = block(index) + + def maskAt(index: Int): Double = maskBlock(index) + + def elementAt(x: Int, y: Int): T = block(y * width + x) + + def maskAt(x: Int, y: Int): Double = maskBlock(y * width + x) + + def rasterElementAt(x: Int, y: Int): T = block((y - yOffset) * width + (x - xOffset)) + + def rasterMaskAt(x: Int, y: Int): Double = maskBlock((y - yOffset) * width + (x - xOffset)) + + def valuesAt(x: Int, y: Int, kernelWidth: Int, kernelHeight: Int): Array[Double] = { + val kernelCenterX = kernelWidth / 2 + val kernelCenterY = kernelHeight / 2 + val values = Array.fill[Double](kernelWidth * kernelHeight)(noDataValue) + var n = 0 + for (i <- 0 
until kernelHeight) { + for (j <- 0 until kernelWidth) { + val xIndex = x + (j - kernelCenterX) + val yIndex = y + (i - kernelCenterY) + if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { + val maskValue = maskAt(xIndex, yIndex) + val value = elementAt(xIndex, yIndex) + if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { + values(n) = num.toDouble(value) + n += 1 + } + } + } + } + val result = values.filter(_ != noDataValue) + // always return only one NoDataValue if no valid values are found + // one and only one NoDataValue can be returned + // in all cases that have some valid values, the NoDataValue will be filtered out + if (result.isEmpty) { + Array(noDataValue) + } else { + result + } + } + + // TODO: Test and fix, not tested, other filters work. + def convolveAt(x: Int, y: Int, kernel: Array[Array[Double]]): Double = { + val kernelWidth = kernel.head.length + val kernelHeight = kernel.length + val kernelCenterX = kernelWidth / 2 + val kernelCenterY = kernelHeight / 2 + var sum = 0.0 + for (i <- 0 until kernelHeight) { + for (j <- 0 until kernelWidth) { + val xIndex = x + (j - kernelCenterX) + val yIndex = y + (i - kernelCenterY) + if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { + val maskValue = maskAt(xIndex, yIndex) + val value = rasterElementAt(xIndex, yIndex) + if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { + sum += num.toDouble(value) * kernel(i)(j) + } + } + } + } + sum + } + + def avgFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.sum / values.length + } + + def minFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.min + } + + def maxFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.max + } + + def medianFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, 
y, kernelSize, kernelSize) + val n = values.length + scala.util.Sorting.quickSort(values) + values(n / 2) + } + + def modeFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + val counts = values.groupBy(identity).mapValues(_.length) + counts.maxBy(_._2)._1 + } + + def trimBlock(stride: Int): GDALBlock[Double] = { + val resultBlock = padding.removePadding(block.map(num.toDouble), width, stride) + val resultMask = padding.removePadding(maskBlock, width, stride) + + val newOffset = padding.newOffset(xOffset, yOffset, stride) + val newSize = padding.newSize(width, height, stride) + + new GDALBlock[Double]( + resultBlock, + resultMask, + noDataValue, + newOffset._1, + newOffset._2, + newSize._1, + newSize._2, + Padding.NoPadding + ) + } + +} + +object GDALBlock { + + def getSize(offset: Int, maxSize: Int, blockSize: Int, stride: Int, paddingStrides: Int): Int = { + if (offset + blockSize + paddingStrides * stride > maxSize) { + maxSize - offset + } else { + blockSize + paddingStrides * stride + } + } + + def apply( + band: MosaicRasterBandGDAL, + stride: Int, + xOffset: Int, + yOffset: Int, + blockSize: Int + ): GDALBlock[Double] = { + val noDataValue = band.noDataValue + val rasterWidth = band.xSize + val rasterHeight = band.ySize + // Always read blockSize + stride pixels on every side + // This is fine since kernel size is always much smaller than blockSize + + val padding = Padding( + left = xOffset != 0, + right = xOffset + blockSize < rasterWidth - 1, // not sure about -1 + top = yOffset != 0, + bottom = yOffset + blockSize < rasterHeight - 1 + ) + + val xo = Math.max(0, xOffset - stride) + val yo = Math.max(0, yOffset - stride) + + val xs = getSize(xo, rasterWidth, blockSize, stride, padding.horizontalStrides) + val ys = getSize(yo, rasterHeight, blockSize, stride, padding.verticalStrides) + + val block = Array.ofDim[Double](xs * ys) + val maskBlock = Array.ofDim[Double](xs * ys) + + 
band.getBand.ReadRaster(xo, yo, xs, ys, block) + band.getBand.GetMaskBand().ReadRaster(xo, yo, xs, ys, maskBlock) + + GDALBlock( + block, + maskBlock, + noDataValue, + xo, + yo, + xs, + ys, + padding + ) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index a7c9ece10..281eb8b01 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.core.raster.gdal +import com.databricks.labs.mosaic.gdal.MosaicGDAL import org.gdal.gdal.Band import org.gdal.gdalconst.gdalconstConstants @@ -255,4 +256,80 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { */ def isNoDataMask: Boolean = band.GetMaskFlags() == gdalconstConstants.GMF_NODATA + def convolve(kernel: Array[Array[Double]]): Unit = { + val kernelWidth = kernel.head.length + val kernelHeight = kernel.length + val blockSize = MosaicGDAL.defaultBlockSize + val strideX = kernelWidth / 2 + val strideY = kernelHeight / 2 + + val block = Array.ofDim[Double](blockSize * blockSize) + val maskBlock = Array.ofDim[Double](blockSize * blockSize) + val result = Array.ofDim[Double](blockSize * blockSize) + + for (yOffset <- 0 until ySize by blockSize - strideY) { + for (xOffset <- 0 until xSize by blockSize - strideX) { + val xSize = Math.min(blockSize, this.xSize - xOffset) + val ySize = Math.min(blockSize, this.ySize - yOffset) + + band.ReadRaster(xOffset, yOffset, xSize, ySize, block) + band.GetMaskBand().ReadRaster(xOffset, yOffset, xSize, ySize, maskBlock) + + val currentBlock = GDALBlock[Double](block, maskBlock, noDataValue, xOffset, yOffset, xSize, ySize, Padding.NoPadding) + + for (y <- 0 until ySize) { + for (x <- 0 until xSize) { + result(y * xSize + x) = currentBlock.convolveAt(x, y, kernel) + 
} + } + + band.WriteRaster(xOffset, yOffset, xSize, ySize, block) + } + } + } + + def filter(kernelSize: Int, operation: String, outputBand: Band): Unit = { + require(kernelSize % 2 == 1, "Kernel size must be odd") + + val blockSize = MosaicGDAL.blockSize + val stride = kernelSize / 2 + + for (yOffset <- 0 until ySize by blockSize) { + for (xOffset <- 0 until xSize by blockSize) { + + val currentBlock = GDALBlock( + this, + stride, + xOffset, + yOffset, + blockSize + ) + + val result = Array.ofDim[Double](currentBlock.block.length) + + for (y <- 0 until currentBlock.height) { + for (x <- 0 until currentBlock.width) { + result(y * currentBlock.width + x) = operation match { + case "avg" => currentBlock.avgFilterAt(x, y, kernelSize) + case "min" => currentBlock.minFilterAt(x, y, kernelSize) + case "max" => currentBlock.maxFilterAt(x, y, kernelSize) + case "median" => currentBlock.medianFilterAt(x, y, kernelSize) + case "mode" => currentBlock.modeFilterAt(x, y, kernelSize) + case _ => throw new Exception("Invalid operation") + } + } + } + + val trimmedResult = currentBlock.copy(block = result).trimBlock(stride) + + outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block) + outputBand.FlushCache() + outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock) + outputBand.GetMaskBand().FlushCache() + + } + } + + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 3ac467f53..b63bd851e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -8,9 +8,9 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.io.{RasterCleaner, 
RasterReader, RasterWriter} import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.POLYGON -import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils} -import org.gdal.gdal.gdal.GDALInfo -import org.gdal.gdal.{Dataset, InfoOptions, gdal} +import com.databricks.labs.mosaic.gdal.MosaicGDAL +import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils, SysUtils} +import org.gdal.gdal.{Dataset, gdal} import org.gdal.gdalconst.gdalconstConstants._ import org.gdal.osr import org.gdal.osr.SpatialReference @@ -32,25 +32,41 @@ case class MosaicRasterGDAL( ) extends RasterWriter with RasterCleaner { + def getWriteOptions: MosaicRasterWriteOptions = MosaicRasterWriteOptions(this) + + def getCompression: String = { + val compression = Option(raster.GetMetadata_Dict("IMAGE_STRUCTURE")) + .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) + .getOrElse(Map.empty[String, String]) + .getOrElse("COMPRESSION", "NONE") + compression + } + def getSpatialReference: SpatialReference = { - if (raster != null) { - raster.GetSpatialRef + val spatialRef = + if (raster != null) { + raster.GetSpatialRef + } else { + val tmp = refresh() + val result = tmp.raster.GetSpatialRef + dispose(tmp) + result + } + if (spatialRef == null) { + MosaicGDAL.WSG84 } else { - val tmp = refresh() - val result = tmp.spatialRef - dispose(tmp) - result + spatialRef } } + def isSubDataset: Boolean = { + val isSubdataset = PathUtils.isSubdataset(path) + isSubdataset + } + // Factory for creating CRS objects protected val crsFactory: CRSFactory = new CRSFactory - // Only use this with GDAL rasters - private val wsg84 = new osr.SpatialReference() - wsg84.ImportFromEPSG(4326) - wsg84.SetAxisMappingStrategy(osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) - /** * @return * The raster's driver short name. 
@@ -157,6 +173,7 @@ case class MosaicRasterGDAL( .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) .getOrElse(Map.empty[String, String]) val keys = subdatasetsMap.keySet + val sanitizedParentPath = PathUtils.getCleanPath(parentPath) keys.flatMap(key => if (key.toUpperCase(Locale.ROOT).contains("NAME")) { val path = subdatasetsMap(key) @@ -164,7 +181,7 @@ case class MosaicRasterGDAL( Seq( key -> pieces.last, s"${pieces.last}_tmp" -> path, - pieces.last -> s"${pieces.head}:$parentPath:${pieces.last}" + pieces.last -> s"${pieces.head}:$sanitizedParentPath:${pieces.last}" ) } else Seq(key -> subdatasetsMap(key)) ).toMap @@ -253,12 +270,6 @@ case class MosaicRasterGDAL( */ def getRaster: Dataset = this.raster - /** - * @return - * Returns the raster's spatial reference. - */ - def spatialRef: SpatialReference = raster.GetSpatialRef() - /** * Applies a function to each band of the raster. * @param f @@ -272,10 +283,10 @@ case class MosaicRasterGDAL( * @return * Returns MosaicGeometry representing bounding box of the raster. */ - def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = wsg84): MosaicGeometry = { + def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = MosaicGDAL.WSG84): MosaicGeometry = { val gt = getGeoTransform - val sourceCRS = spatialRef + val sourceCRS = getSpatialReference val transform = new osr.CoordinateTransformation(sourceCRS, destCRS) val bbox = geometryAPI.geometry( @@ -300,23 +311,10 @@ case class MosaicRasterGDAL( * compute since it requires reading the raster and computing statistics. 
*/ def isEmpty: Boolean = { - import org.json4s._ - import org.json4s.jackson.JsonMethods._ - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - - val vector = new JVector[String]() - vector.add("-stats") - vector.add("-json") - val infoOptions = new InfoOptions(vector) - val gdalInfo = GDALInfo(raster, infoOptions) - val json = parse(gdalInfo).extract[Map[String, Any]] - - if (json.contains("STATISTICS_VALID_PERCENT")) { - json("STATISTICS_VALID_PERCENT").asInstanceOf[Double] == 0.0 - } else if (subdatasets.nonEmpty) { + if (subdatasets.nonEmpty) { false } else { - getBandStats.values.map(_.getOrElse("mean", 0.0)).forall(_ == 0.0) + getValidCount.values.sum == 0 } } @@ -347,11 +345,18 @@ case class MosaicRasterGDAL( val isSubdataset = PathUtils.isSubdataset(path) val filePath = if (isSubdataset) PathUtils.fromSubdatasetPath(path) else path val pamFilePath = s"$filePath.aux.xml" + val cleanPath = filePath.replace("/vsizip/", "") + val zipPath = if (cleanPath.endsWith("zip")) cleanPath else s"$cleanPath.zip" if (path != PathUtils.getCleanPath(parentPath)) { Try(gdal.GetDriverByName(driverShortName).Delete(path)) + Try(Files.deleteIfExists(Paths.get(cleanPath))) Try(Files.deleteIfExists(Paths.get(path))) Try(Files.deleteIfExists(Paths.get(filePath))) Try(Files.deleteIfExists(Paths.get(pamFilePath))) + if (Files.exists(Paths.get(zipPath))) { + Try(Files.deleteIfExists(Paths.get(zipPath.replace(".zip", "")))) + } + Try(Files.deleteIfExists(Paths.get(zipPath))) } } @@ -382,12 +387,26 @@ case class MosaicRasterGDAL( * A boolean indicating if the write was successful. 
*/ def writeToPath(path: String, dispose: Boolean = true): String = { - val driver = raster.GetDriver() - val ds = driver.CreateCopy(path, this.flushCache().getRaster) - ds.FlushCache() - ds.delete() - if (dispose) RasterCleaner.dispose(this) - path + if (isSubDataset) { + val driver = raster.GetDriver() + val ds = driver.CreateCopy(path, this.flushCache().getRaster, 1) + if (ds == null) { + val error = gdal.GetLastErrorMsg() + throw new Exception(s"Error writing raster to path: $error") + } + ds.FlushCache() + ds.delete() + if (dispose) RasterCleaner.dispose(this) + path + } else { + val thisPath = Paths.get(this.path) + val fromDir = thisPath.getParent + val toDir = Paths.get(path).getParent + val stemRegex = PathUtils.getStemRegex(this.path) + PathUtils.wildcardCopy(fromDir.toString, toDir.toString, stemRegex) + if (dispose) RasterCleaner.dispose(this) + s"$toDir/${thisPath.getFileName}" + } } /** @@ -398,17 +417,33 @@ case class MosaicRasterGDAL( */ def writeToBytes(dispose: Boolean = true): Array[Byte] = { val isSubdataset = PathUtils.isSubdataset(path) - val readPath = - if (isSubdataset) { - val tmpPath = PathUtils.createTmpFilePath(getRasterFileExtension) - writeToPath(tmpPath, dispose = false) + val readPath = { + val tmpPath = + if (isSubdataset) { + val tmpPath = PathUtils.createTmpFilePath(getRasterFileExtension) + writeToPath(tmpPath, dispose = false) + tmpPath + } else { + path + } + if (Files.isDirectory(Paths.get(tmpPath))) { + SysUtils.runCommand(s"zip -r0 $tmpPath.zip $tmpPath") + s"$tmpPath.zip" } else { - path + tmpPath } + } val byteArray = FileUtils.readBytes(readPath) if (dispose) RasterCleaner.dispose(this) if (readPath != PathUtils.getCleanPath(parentPath)) { Files.deleteIfExists(Paths.get(readPath)) + if (readPath.endsWith(".zip")) { + val nonZipPath = readPath.replace(".zip", "") + if (Files.isDirectory(Paths.get(nonZipPath))) { + SysUtils.runCommand(s"rm -rf $nonZipPath") + } + Files.deleteIfExists(Paths.get(readPath.replace(".zip", 
""))) + } } byteArray } @@ -464,6 +499,20 @@ case class MosaicRasterGDAL( .toMap } + /** + * @return + * Returns the raster's band valid pixel count. + */ + def getValidCount: Map[Int, Long] = { + (1 to numBands) + .map(i => { + val band = raster.GetRasterBand(i) + val validCount = band.AsMDArray().GetStatistics().getValid_count + i -> validCount + }) + .toMap + } + /** * @param subsetName * The name of the subdataset to get. @@ -471,24 +520,59 @@ case class MosaicRasterGDAL( * Returns the raster's subdataset with given name. */ def getSubdataset(subsetName: String): MosaicRasterGDAL = { - subdatasets - val path = Option(raster.GetMetadata_Dict("SUBDATASETS")) - .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) - .getOrElse(Map.empty[String, String]) - .values - .find(_.toUpperCase(Locale.ROOT).endsWith(subsetName.toUpperCase(Locale.ROOT))) - .getOrElse(throw new Exception(s""" - |Subdataset $subsetName not found! - |Available subdatasets: - | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} - """.stripMargin)) - val ds = openRaster(path) + val path = subdatasets.getOrElse( + s"${subsetName}_tmp", + throw new Exception(s""" + |Subdataset $subsetName not found! 
+ |Available subdatasets: + | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} + | """.stripMargin) + ) + val sanitized = PathUtils.getCleanPath(path) + val subdatasetPath = PathUtils.getSubdatasetPath(sanitized) + + val ds = openRaster(subdatasetPath) // Avoid costly IO to compute MEM size here // It will be available when the raster is serialized for next operation // If value is needed then it will be computed when getMemSize is called MosaicRasterGDAL(ds, path, parentPath, driverShortName, -1) } + def convolve(kernel: Array[Array[Double]]): MosaicRasterGDAL = { + val resultRasterPath = PathUtils.createTmpFilePath(getRasterFileExtension) + val outputRaster = this.raster + .GetDriver() + .Create(resultRasterPath, this.xSize, this.ySize, this.numBands, this.raster.GetRasterBand(1).getDataType) + + for (bandIndex <- 1 to this.numBands) { + val band = this.getBand(bandIndex) + band.convolve(kernel) + } + + MosaicRasterGDAL(outputRaster, resultRasterPath, parentPath, driverShortName, -1) + + } + + def filter(kernelSize: Int, operation: String): MosaicRasterGDAL = { + val resultRasterPath = PathUtils.createTmpFilePath(getRasterFileExtension) + + this.raster + .GetDriver() + .CreateCopy(resultRasterPath, this.raster, 1) + .delete() + + val outputRaster = gdal.Open(resultRasterPath, GF_Write) + + for (bandIndex <- 1 to this.numBands) { + val band = this.getBand(bandIndex) + val outputBand = outputRaster.GetRasterBand(bandIndex) + band.filter(kernelSize, operation, outputBand) + } + + val result = MosaicRasterGDAL(outputRaster, resultRasterPath, parentPath, driverShortName, this.memSize) + result.flushCache() + } + } //noinspection ZeroIndexToHead @@ -583,11 +667,29 @@ object MosaicRasterGDAL extends RasterReader { // Try reading as a tmp file, if that fails, rename as a zipped file val dataset = openRaster(tmpPath, Some(driverShortName)) if (dataset == null) { - val zippedPath = PathUtils.createTmpFilePath("zip") + val zippedPath = 
s"$tmpPath.zip" Files.move(Paths.get(tmpPath), Paths.get(zippedPath), StandardCopyOption.REPLACE_EXISTING) val readPath = PathUtils.getZipPath(zippedPath) val ds = openRaster(readPath, Some(driverShortName)) - MosaicRasterGDAL(ds, readPath, parentPath, driverShortName, contentBytes.length) + if (ds == null) { + // the way we zip using uuid is not compatible with GDAL + // we need to unzip and read the file if it was zipped by us + val parentDir = Paths.get(zippedPath).getParent + val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && unzip -o $zippedPath -d /")) + // zipped files will have the old uuid name of the raster + // we need to get the last extracted file name, but the last extracted file name is not the raster name + // we can't list folders due to concurrent writes + val extension = GDAL.getExtension(driverShortName) + val lastExtracted = SysUtils.getLastOutputLine(prompt) + val unzippedPath = PathUtils.parseUnzippedPathFromExtracted(lastExtracted, extension) + val dataset = openRaster(unzippedPath, Some(driverShortName)) + if (dataset == null) { + throw new Exception(s"Error reading raster from bytes: ${prompt._3}") + } + MosaicRasterGDAL(dataset, unzippedPath, parentPath, driverShortName, contentBytes.length) + } else { + MosaicRasterGDAL(ds, readPath, parentPath, driverShortName, contentBytes.length) + } } else { MosaicRasterGDAL(dataset, tmpPath, parentPath, driverShortName, contentBytes.length) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala new file mode 100644 index 000000000..68a7bd75a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala @@ -0,0 +1,55 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import com.databricks.labs.mosaic.gdal.MosaicGDAL +import org.gdal.osr.SpatialReference + +case class 
MosaicRasterWriteOptions( + compression: String = "DEFLATE", + format: String = "GTiff", + extension: String = "tif", + resampling: String = "nearest", + crs: SpatialReference = MosaicGDAL.WSG84, // Assume WGS84 + pixelSize: Option[(Double, Double)] = None, + noDataValue: Option[Double] = None, + missingGeoRef: Boolean = false, + options: Map[String, String] = Map.empty[String, String] +) + +object MosaicRasterWriteOptions { + + val VRT: MosaicRasterWriteOptions = + MosaicRasterWriteOptions( + compression = "NONE", + format = "VRT", + extension = "vrt", + crs = MosaicGDAL.WSG84, + pixelSize = None, + noDataValue = None, + options = Map.empty[String, String] + ) + + val GTiff: MosaicRasterWriteOptions = MosaicRasterWriteOptions() + + def noGPCsNoTransform(raster: MosaicRasterGDAL): Boolean = { + val noGPCs = raster.getRaster.GetGCPCount == 0 + val noGeoTransform = raster.getRaster.GetGeoTransform == null || + (raster.getRaster.GetGeoTransform sameElements Array(0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) + noGPCs && noGeoTransform + } + + def apply(): MosaicRasterWriteOptions = new MosaicRasterWriteOptions() + + def apply(raster: MosaicRasterGDAL): MosaicRasterWriteOptions = { + val compression = raster.getCompression + val format = raster.getRaster.GetDriver.getShortName + val extension = raster.getRasterFileExtension + val resampling = "nearest" + val pixelSize = None + val noDataValue = None + val options = Map.empty[String, String] + val crs = raster.getSpatialReference + val missingGeoRef = noGPCsNoTransform(raster) + new MosaicRasterWriteOptions(compression, format, extension, resampling, crs, pixelSize, noDataValue, missingGeoRef, options) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala new file mode 100644 index 000000000..bb32e772f --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala @@ -0,0 +1,58 @@ +package 
com.databricks.labs.mosaic.core.raster.gdal + +case class Padding( + left: Boolean, + right: Boolean, + top: Boolean, + bottom: Boolean +) { + + def removePadding(array: Array[Double], rowWidth: Int, stride: Int): Array[Double] = { + val l = if (left) 1 else 0 + val r = if (right) 1 else 0 + val t = if (top) 1 else 0 + val b = if (bottom) 1 else 0 + + val yStart = t * stride * rowWidth + val yEnd = array.length - b * stride * rowWidth + + val slices = for (i <- yStart until yEnd by rowWidth) yield { + val xStart = i + l * stride + val xEnd = i + rowWidth - r * stride + array.slice(xStart, xEnd) + } + + slices.flatten.toArray + } + + def horizontalStrides: Int = { + if (left && right) 2 + else if (left || right) 1 + else 0 + } + + def verticalStrides: Int = { + if (top && bottom) 2 + else if (top || bottom) 1 + else 0 + } + + def newOffset(xOffset: Int, yOffset: Int, stride: Int): (Int, Int) = { + val x = if (left) xOffset + stride else xOffset + val y = if (top) yOffset + stride else yOffset + (x, y) + } + + def newSize(width: Int, height: Int, stride: Int): (Int, Int) = { + val w = if (left && right) width - 2 * stride else if (left || right) width - stride else width + val h = if (top && bottom) height - 2 * stride else if (top || bottom) height - stride else height + (w, h) + } + +} + +object Padding { + + val NoPadding: Padding = Padding(left = false, right = false, top = false, bottom = false) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala index 6daabc25c..56c29563f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala @@ -19,8 +19,7 @@ object RasterClipByVector { * abstractions over GDAL Warp. 
It uses CUTLINE_ALL_TOUCHED=TRUE to ensure * that all pixels that touch the geometry are included. This will avoid * the issue of having a pixel that is half in and half out of the - * geometry, important for tessellation. It also uses COMPRESS=DEFLATE to - * ensure that the output is compressed. The method also uses the geometry + * geometry, important for tessellation. The method also uses the geometry * API to generate a shapefile that is used to clip the raster. The * shapefile is deleted after the clip is complete. * @@ -38,16 +37,19 @@ object RasterClipByVector { def clip(raster: MosaicRasterGDAL, geometry: MosaicGeometry, geomCRS: SpatialReference, geometryAPI: GeometryAPI): MosaicRasterGDAL = { val rasterCRS = raster.getSpatialReference val outShortName = raster.getDriversShortName - val geomSrcCRS = if (geomCRS == null ) rasterCRS else geomCRS + val geomSrcCRS = if (geomCRS == null) rasterCRS else geomCRS val resultFileName = PathUtils.createTmpFilePath(GDAL.getExtension(outShortName)) val shapeFileName = VectorClipper.generateClipper(geometry, geomSrcCRS, rasterCRS, geometryAPI) + // For -wo consult https://gdal.org/doxygen/structGDALWarpOptions.html + // SOURCE_EXTRA=3 is used to ensure that when the raster is clipped, the + // pixels that touch the geometry are included. The default is 1, 3 seems to be a good empirical value. 
val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -of $outShortName -cutline $shapeFileName -crop_to_cutline -co COMPRESS=DEFLATE -dstalpha" + command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -wo SOURCE_EXTRA=3 -cutline $shapeFileName -crop_to_cutline" ) VectorClipper.cleanUpClipper(shapeFileName) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala index 389defad6..9e1e97401 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import org.gdal.gdal.{BuildVRTOptions, gdal} /** GDALBuildVRT is a wrapper for the GDAL BuildVRT command. */ @@ -20,16 +20,16 @@ object GDALBuildVRT { */ def executeVRT(outputPath: String, rasters: Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { require(command.startsWith("gdalbuildvrt"), "Not a valid GDAL Build VRT command.") - val vrtOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, MosaicRasterWriteOptions.VRT) + val vrtOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val vrtOptions = new BuildVRTOptions(vrtOptionsVec) val result = gdal.BuildVRT(outputPath, rasters.map(_.getRaster).toArray, vrtOptions) if (result == null) { - throw new Exception( - s""" - |Build VRT failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) + throw new Exception(s""" + |Build VRT failed. 
+ |Command: $effectiveCommand + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin) } // TODO: Figure out multiple parents, should this be an array? // VRT files are just meta files, mem size doesnt make much sense so we keep -1 diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala index 97a273d13..cc9c5e500 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala @@ -1,19 +1,16 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal import com.databricks.labs.mosaic.core.raster.api.GDAL -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import com.databricks.labs.mosaic.utils.SysUtils /** GDALCalc is a helper object for executing GDAL Calc commands. 
*/ object GDALCalc { val gdal_calc: String = { - val calcPath = SysUtils.runCommand("find / -iname gdal_calc.py")._1.split("\n").headOption.getOrElse("") - if (calcPath.isEmpty) { - throw new RuntimeException("Could not find gdal_calc.py.") - } - if (calcPath == "ERROR") { - "/usr/lib/python3/dist-packages/osgeo_utils/gdal_calc.py" + val calcPath = SysUtils.runCommand("""find / -maxdepth 20 -iname gdal_calc.py""")._1.split("\n").headOption.getOrElse("") + if (calcPath.isEmpty || calcPath.startsWith("ERROR")) { + "/usr/local/lib/python3.10/dist-packages/osgeo_utils/gdal_calc.py" } else { calcPath } @@ -30,11 +27,13 @@ object GDALCalc { */ def executeCalc(gdalCalcCommand: String, resultPath: String): MosaicRasterGDAL = { require(gdalCalcCommand.startsWith("gdal_calc"), "Not a valid GDAL Calc command.") - val toRun = gdalCalcCommand.replace("gdal_calc", gdal_calc) + val effectiveCommand = OperatorOptions.appendOptions(gdalCalcCommand, MosaicRasterWriteOptions.GTiff) + val toRun = effectiveCommand.replace("gdal_calc", gdal_calc) val commandRes = SysUtils.runCommand(s"python3 $toRun") - if (commandRes._1 == "ERROR") { + if (commandRes._1.startsWith("ERROR")) { throw new RuntimeException(s""" |GDAL Calc command failed: + |$toRun |STDOUT: |${commandRes._2} |STDERR: diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala index bf266cfbf..fd24a0f73 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import org.gdal.gdal.{TranslateOptions, gdal} import 
java.nio.file.{Files, Paths} @@ -20,21 +20,26 @@ object GDALTranslate { * @return * A MosaicRaster object. */ - def executeTranslate(outputPath: String, raster: MosaicRasterGDAL, command: String): MosaicRasterGDAL = { + def executeTranslate( + outputPath: String, + raster: MosaicRasterGDAL, + command: String, + writeOptions: MosaicRasterWriteOptions + ): MosaicRasterGDAL = { require(command.startsWith("gdal_translate"), "Not a valid GDAL Translate command.") - val translateOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, writeOptions) + val translateOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val translateOptions = new TranslateOptions(translateOptionsVec) val result = gdal.Translate(outputPath, raster.getRaster, translateOptions) if (result == null) { - throw new Exception( - s""" - |Translate failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) + throw new Exception(s""" + |Translate failed. 
+ |Command: $effectiveCommand + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin) } val size = Files.size(Paths.get(outputPath)) - raster.copy(raster = result, path = outputPath, memSize = size).flushCache() + raster.copy(raster = result, path = outputPath, memSize = size, driverShortName = writeOptions.format).flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala index 2b13a957b..ba6dce58d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala @@ -23,7 +23,8 @@ object GDALWarp { def executeWarp(outputPath: String, rasters: Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { require(command.startsWith("gdalwarp"), "Not a valid GDAL Warp command.") // Test: gdal.ParseCommandLine(command) - val warpOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, rasters.head.getWriteOptions) + val warpOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val warpOptions = new WarpOptions(warpOptionsVec) val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions) // TODO: Figure out multiple parents, should this be an array? @@ -31,7 +32,7 @@ object GDALWarp { if (result == null) { throw new Exception(s""" |Warp failed. 
- |Command: $command + |Command: $effectiveCommand |Error: ${gdal.GetLastErrorMsg} |""".stripMargin) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala index b1529d3e7..bc656ec01 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala @@ -1,5 +1,7 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterWriteOptions + /** OperatorOptions is a helper object for parsing GDAL command options. */ object OperatorOptions { @@ -18,4 +20,37 @@ object OperatorOptions { optionsVec } + /** + * Add default options to the command. Extract the compression from the + * raster and append it to the command. This operation does not change the + * output format. For changing the output format, use RST_ToFormat. + * + * @param command + * The command to append options to. + * @param writeOptions + * The write options to append. Note that not all available options are + * actually appended. At this point it is up to the bellow logic to + * decide what is supported and for which format. 
+ * @return + */ + def appendOptions(command: String, writeOptions: MosaicRasterWriteOptions): String = { + val compression = writeOptions.compression + if (command.startsWith("gdal_calc")) { + writeOptions.format match { + case f @ "GTiff" => command + s" --format $f --co TILED=YES --co COMPRESS=$compression" + case f @ "COG" => command + s" --format $f --co TILED=YES --co COMPRESS=$compression" + case f @ _ => command + s" --format $f --co COMPRESS=$compression" + } + } else { + writeOptions.format match { + case f @ "GTiff" => command + s" -of $f -co TILED=YES -co COMPRESS=$compression" + case f @ "COG" => command + s" -of $f -co TILED=YES -co COMPRESS=$compression" + case "VRT" => command + case f @ "Zarr" if writeOptions.missingGeoRef => + command + s" -of $f -co COMPRESS=$compression -to SRC_METHOD=NO_GEOTRANSFORM" + case f @ _ => command + s" -of $f -co COMPRESS=$compression" + } + } + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala index 6333c50c8..8a82d1238 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala @@ -19,10 +19,10 @@ object MergeBands { * A MosaicRaster object. 
*/ def merge(rasters: Seq[MosaicRasterGDAL], resampling: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -33,7 +33,8 @@ object MergeBands { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster, - command = s"gdal_translate -r $resampling -of $outShortName -co COMPRESS=DEFLATE" + command = s"gdal_translate -r $resampling", + outOptions ) dispose(vrtRaster) @@ -55,10 +56,10 @@ object MergeBands { * A MosaicRaster object. */ def merge(rasters: Seq[MosaicRasterGDAL], pixel: (Double, Double), resampling: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -69,7 +70,8 @@ object MergeBands { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster, - command = s"gdalwarp -r $resampling -of $outShortName -co COMPRESS=DEFLATE -overwrite" + command = s"gdalwarp -r $resampling", + outOptions ) dispose(vrtRaster) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala index 694d9940a..fafaffbc4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala @@ -17,21 +17,22 @@ object MergeRasters { * A MosaicRaster object. 
*/ def merge(rasters: Seq[MosaicRasterGDAL]): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( - vrtPath, - rasters, - command = s"gdalbuildvrt -resolution highest" + vrtPath, + rasters, + command = s"gdalbuildvrt -resolution highest" ) val result = GDALTranslate.executeTranslate( - rasterPath, - vrtRaster, - command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + rasterPath, + vrtRaster, + command = s"gdal_translate", + outOptions ) dispose(vrtRaster) @@ -39,5 +40,4 @@ object MergeRasters { result } - } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala index 5bf49fb96..cda9824dc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.core.raster.operator.pixel import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.gdal.MosaicGDAL.defaultBlockSize import com.databricks.labs.mosaic.utils.PathUtils import java.io.File @@ -20,10 +21,10 @@ object PixelCombineRasters { * A MosaicRaster object. 
*/ def combine(rasters: Seq[MosaicRasterGDAL], pythonFunc: String, pythonFuncName: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -37,7 +38,8 @@ object PixelCombineRasters { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster.refresh(), - command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + command = s"gdal_translate", + outOptions ) dispose(vrtRaster) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala index efd7c8c67..5d7c5f5f2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala @@ -15,8 +15,7 @@ object RasterProject { /** * Projects a raster to a new CRS. The method handles all the abstractions * over GDAL Warp. It uses cubic resampling to ensure that the output is - * smooth. It also uses COMPRESS=DEFLATE to ensure that the output is - * compressed. + * smooth. * * @param raster * The raster to project. 
@@ -33,11 +32,11 @@ object RasterProject { // Note that Null is the right value here val authName = destCRS.GetAuthorityName(null) val authCode = destCRS.GetAuthorityCode(null) - + val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -of $outShortName -t_srs $authName:$authCode -r cubic -overwrite -co COMPRESS=DEFLATE" + command = s"gdalwarp -t_srs $authName:$authCode" ) result diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala index c1498ea05..4e9f61c5e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala @@ -48,12 +48,13 @@ object OverlappingTiles { val fileExtension = GDAL.getExtension(tile.getDriver) val rasterPath = PathUtils.createTmpFilePath(fileExtension) - val shortName = raster.getRaster.GetDriver.getShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortName -srcwin $xOff $yOff $width $height" + command = s"gdal_translate -srcwin $xOff $yOff $width $height", + outOptions ) val isEmpty = result.isEmpty diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala index d186de0a5..fa47c6c1d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala @@ -41,9 +41,10 @@ object RasterTessellate { (false, MosaicRasterTile(cell.index, null, "", "")) } else { val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI) - val isValidRaster 
= cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty + val isValidRaster = cellRaster.getValidCount.values.sum > 0 && !cellRaster.isEmpty ( - isValidRaster, MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) + isValidRaster, + MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) ) } }) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala index edaab4720..f25a1f384 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala @@ -1,7 +1,7 @@ package com.databricks.labs.mosaic.core.raster.operator.retile import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose -import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.utils.PathUtils @@ -39,12 +39,13 @@ object ReTile { val fileExtension = raster.getRasterFileExtension val rasterPath = PathUtils.createTmpFilePath(fileExtension) - val shortDriver = raster.getDriversShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortDriver -srcwin $xMin $yMin $xOffset $yOffset -co COMPRESS=DEFLATE" + command = s"gdal_translate -srcwin $xMin $yMin $xOffset $yOffset", + outOptions ) val isEmpty = result.isEmpty diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala index 1cadf2c9a..5203178e0 100644 --- 
a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala @@ -1,11 +1,12 @@ package com.databricks.labs.mosaic.core.types +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types._ /** Type definition for the raster tile. */ class RasterTileType(fields: Array[StructField]) extends StructType(fields) { - def rasterType: DataType = fields(1).dataType + def rasterType: DataType = fields.find(_.name == "raster").get.dataType override def simpleString: String = "RASTER_TILE" @@ -19,20 +20,54 @@ object RasterTileType { * Creates a new instance of [[RasterTileType]]. * * @param idType - * Type of the index ID. + * Type of the index ID. Can be one of [[LongType]], [[IntegerType]] or + * [[StringType]]. + * @param rasterType + * Type of the raster. Can be one of [[ByteType]] or [[StringType]]. Not + * to be confused with the data type of the raster. This is the type of + * the column that contains the raster. + * * @return * An instance of [[RasterTileType]]. */ - def apply(idType: DataType): RasterTileType = { + def apply(idType: DataType, rasterType: DataType): DataType = { require(Seq(LongType, IntegerType, StringType).contains(idType)) new RasterTileType( Array( StructField("index_id", idType), - StructField("raster", BinaryType), + StructField("raster", rasterType), StructField("parentPath", StringType), StructField("driver", StringType) ) ) } + /** + * Creates a new instance of [[RasterTileType]]. + * + * @param idType + * Type of the index ID. Can be one of [[LongType]], [[IntegerType]] or + * [[StringType]]. + * @param tileExpr + * Expression containing a tile. This is used to infer the raster type + * when chaining expressions. 
+ * @return + */ + def apply(idType: DataType, tileExpr: Expression): DataType = { + require(Seq(LongType, IntegerType, StringType).contains(idType)) + tileExpr.dataType match { + case st @ StructType(_) => apply(idType, st.find(_.name == "raster").get.dataType) + case _ @ArrayType(elementType: StructType, _) => apply(idType, elementType.find(_.name == "raster").get.dataType) + case _ => throw new IllegalArgumentException("Unsupported raster type.") + } + } + + def apply(tileExpr: Expression): RasterTileType = { + tileExpr.dataType match { + case StructType(fields) => new RasterTileType(fields) + case ArrayType(elementType: StructType, _) => new RasterTileType(elementType.fields) + case _ => throw new IllegalArgumentException("Unsupported raster type.") + } + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala index e7a8e9218..30a7765c1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala @@ -106,19 +106,22 @@ case class MosaicRasterTile( * An instance of [[InternalRow]]. */ def serialize( - rasterDataType: DataType = BinaryType, - checkpointLocation: String = "" + rasterDataType: DataType ): InternalRow = { val parentPathUTF8 = UTF8String.fromString(parentPath) val driverUTF8 = UTF8String.fromString(driver) - val encodedRaster = encodeRaster(rasterDataType, checkpointLocation) + val encodedRaster = encodeRaster(rasterDataType) if (Option(index).isDefined) { if (index.isLeft) InternalRow.fromSeq( Seq(index.left.get, encodedRaster, parentPathUTF8, driverUTF8) ) - else InternalRow.fromSeq( - Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) - ) + else { + // Copy from tmp to checkpoint. + // Have to use GDAL Driver to do this since sidecar files are not copied by spark. 
+ InternalRow.fromSeq( + Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) + ) + } } else { InternalRow.fromSeq(Seq(null, encodedRaster, parentPathUTF8, driverUTF8)) } @@ -132,10 +135,9 @@ case class MosaicRasterTile( * An instance of [[Array]] of [[Byte]] representing WKB. */ private def encodeRaster( - rasterDataType: DataType = BinaryType, - checkpointLocation: String = "" + rasterDataType: DataType = BinaryType ): Any = { - GDAL.writeRasters(Seq(raster), checkpointLocation, rasterDataType).head + GDAL.writeRasters(Seq(raster), rasterDataType).head } } @@ -153,12 +155,12 @@ object MosaicRasterTile { * @return * An instance of [[MosaicRasterTile]]. */ - def deserialize(row: InternalRow, idDataType: DataType): MosaicRasterTile = { + def deserialize(row: InternalRow, idDataType: DataType, rasterType: DataType): MosaicRasterTile = { val index = row.get(0, idDataType) - val rasterBytes = row.get(1, BinaryType) + val rawRaster = row.get(1, rasterType) val parentPath = row.get(2, StringType).toString val driver = row.get(3, StringType).toString - val raster = GDAL.readRaster(rasterBytes, parentPath, driver, BinaryType) + val raster = GDAL.readRaster(rawRaster, parentPath, driver, rasterType) // noinspection TypeCheckCanBeMatch if (Option(index).isDefined) { if (index.isInstanceOf[Long]) { diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala index 285df2191..a38e76900 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala @@ -19,11 +19,17 @@ import java.nio.file.{Files, Paths} /** An object defining the retiling read strategy for the GDAL file format. 
*/ object ReTileOnRead extends ReadStrategy { + val tileDataType: DataType = StringType + // noinspection DuplicatedCode /** * Returns the schema of the GDAL file format. * @note - * Different read strategies can have different schemas. + * Different read strategies can have different schemas. This is because + * the schema is defined by the read strategy. For retiling we always use + * checkpoint location. In this case rasters are stored off spark rows. + * If you need the tiles in memory please load them from path stored in + * the tile returned by the reader. * * @param options * Options passed to the reader. @@ -54,7 +60,10 @@ object ReTileOnRead extends ReadStrategy { .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) .add(StructField(SRID, IntegerType, nullable = false)) .add(StructField(LENGTH, LongType, nullable = false)) - .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + // Note that for retiling we always use checkpoint location. + // In this case rasters are stored off spark rows. + // If you need the tiles in memory please load them from path stored in the tile returned by the reader. 
+ .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, tileDataType), nullable = false)) } /** @@ -103,7 +112,7 @@ object ReTileOnRead extends ReadStrategy { case other => throw new RuntimeException(s"Unsupported field name: $other") } // Writing to bytes is destructive so we delay reading content and content length until the last possible moment - val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize())) + val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize(tileDataType))) RasterCleaner.dispose(tile) row }) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala index 0517ac1d9..8c0c4a914 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala @@ -49,7 +49,9 @@ object ReadInMemory extends ReadStrategy { .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) .add(StructField(SRID, IntegerType, nullable = false)) - .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + // Note, for in memory reads the rasters are stored in the tile. + // For that we use Binary Columns. 
+ .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, BinaryType), nullable = false)) } /** diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala index d6f26caf4..2f5bf39b6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala @@ -65,36 +65,32 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val retiledDf = retileRaster(rasterDf, config) - val loadedDf = rasterDf + val loadedDf = retiledDf .withColumn( "tile", - rst_tessellate(col("tile"), lit(resolution)) + rst_tessellate(col("tile"), lit(resolution)) ) .repartition(nPartitions) + .groupBy("tile.index_id") + .agg(rst_combineavg_agg(col("tile")).alias("tile")) .withColumn( "grid_measures", - rasterToGridCombiner(col("tile"), lit(resolution)) + rasterToGridCombiner(col("tile")) ) .select( "grid_measures", "tile" ) .select( - posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures")) + posexplode(col("grid_measures")).as(Seq("band_id", "measure")), + col("tile").getField("index_id").alias("cell_id") ) .repartition(nPartitions) .select( col("band_id"), - explode(col("grid_measures")).alias("grid_measures") + col("cell_id"), + col("measure") ) - .repartition(nPartitions) - .select( - col("band_id"), - col("grid_measures").getItem("cellID").alias("cell_id"), - col("grid_measures").getItem("measure").alias("measure") - ) - .groupBy("band_id", "cell_id") - .agg(avg("measure").alias("measure")) kRingResample(loadedDf, config) @@ -203,15 +199,15 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead * @return * The raster to grid function. 
*/ - private def getRasterToGridFunc(combiner: String): (Column, Column) => Column = { + private def getRasterToGridFunc(combiner: String): Column => Column = { combiner match { - case "mean" => rst_rastertogridavg - case "min" => rst_rastertogridmin - case "max" => rst_rastertogridmax - case "median" => rst_rastertogridmedian - case "count" => rst_rastertogridcount - case "average" => rst_rastertogridavg - case "avg" => rst_rastertogridavg + case "mean" => rst_avg + case "min" => rst_min + case "max" => rst_max + case "median" => rst_median + case "count" => rst_pixelcount + case "average" => rst_avg + case "avg" => rst_avg case _ => throw new Error("Combiner not supported") } } @@ -232,7 +228,7 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead "combiner" -> this.extraOptions.getOrElse("combiner", "mean"), "retile" -> this.extraOptions.getOrElse("retile", "false"), "tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"), - "sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", ""), + "sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", "-1"), "kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0") ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala index 82752cad4..a5907dbe9 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala @@ -13,11 +13,13 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ -case class RST_Avg(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Avg](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) +case class RST_Avg(tileExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Avg](tileExpr, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { import org.json4s._ diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala index 241d913bc..fec760813 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala @@ -25,13 +25,14 @@ case class RST_BandMetaData(raster: Expression, band: Expression, expressionConf extends RasterBandExpression[RST_BandMetaData]( raster, band, - MapType(StringType, StringType), returnsRaster = false, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** * @param raster * The raster to be used. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala index 8fa2d7314..e79a8ec40 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala @@ -15,10 +15,12 @@ import org.apache.spark.sql.types._ case class RST_BoundingBox( raster: Expression, expressionConfig: MosaicExpressionConfig -) extends RasterExpression[RST_BoundingBox](raster, BinaryType, returnsRaster = false, expressionConfig = expressionConfig) +) extends RasterExpression[RST_BoundingBox](raster, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BinaryType + /** * Computes the bounding box of the raster. The bbox is returned as a WKB * polygon. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala index 557565afe..29449a6ef 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala @@ -19,13 +19,14 @@ case class RST_Clip( ) extends Raster1ArgExpression[RST_Clip]( rastersExpr, geometryExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) /** diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala index 1d923fdc1..d63fe6914 
100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala @@ -9,20 +9,22 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Expression for combining rasters using average of pixels. */ case class RST_CombineAvg( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_CombineAvg]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Combines the rasters using average of pixels. 
*/ override def rasterTransform(tiles: Seq[MosaicRasterTile]): Any = { val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala index f6b3ba1dc..5bbc01a7b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala @@ -6,6 +6,7 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.raster.operator.CombineAVG import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile.{deserialize => deserializeTile} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpressionSerialization import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow @@ -13,7 +14,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType} +import org.apache.spark.sql.types.{ArrayType, DataType} import scala.collection.mutable.ArrayBuffer @@ -23,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer */ //noinspection DuplicatedCode case class RST_CombineAvgAgg( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0 @@ -32,21 +33,23 @@ case class RST_CombineAvgAgg( with 
RasterExpressionSerialization { override lazy val deterministic: Boolean = true - override val child: Expression = rasterExpr + override val child: Expression = tileExpr override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + lazy val tileType: DataType = dataType.asInstanceOf[RasterTileType].rasterType override def prettyName: String = "rst_combine_avg_agg" + val cellIDType: DataType = expressionConfig.getCellIdType private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) private lazy val row = new UnsafeRow(1) - def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { + override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) buffer += InternalRow.copyValue(value) buffer } - def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { + override def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { buffer ++= input } @@ -63,10 +66,15 @@ case class RST_CombineAvgAgg( if (buffer.isEmpty) { null + } else if (buffer.size == 1) { + val result = buffer.head + buffer.clear() + result } else { // Do do move the expression - var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + var tiles = buffer.map(row => deserializeTile(row.asInstanceOf[InternalRow], cellIDType, tileType)) + buffer.clear() // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null @@ -77,9 +85,9 @@ case class RST_CombineAvgAgg( val result = MosaicRasterTile(idx, combined, parentPath, driver) 
.formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(tileType) - tiles.foreach(RasterCleaner.dispose(_)) + tiles.foreach(RasterCleaner.dispose) RasterCleaner.dispose(result) tiles = null @@ -101,7 +109,7 @@ case class RST_CombineAvgAgg( buffer } - override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(rasterExpr = newChild) + override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(tileExpr = newChild) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala new file mode 100644 index 000000000..db20f8a3a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala @@ -0,0 +1,73 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** The expression for applying kernel filter on a raster. 
*/ +case class RST_Convolve( + rastersExpr: Expression, + kernelExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_Convolve]( + rastersExpr, + kernelExpr, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Applies a convolution kernel to the raster. + * + * @param tile + * The raster to be used. + * @param arg1 + * The kernel (2D array of doubles) to be applied. + * @return + * The convolved raster. + */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { + val kernel = arg1.asInstanceOf[Array[Array[Double]]] + tile.copy( + raster = tile.getRaster.convolve(kernel) + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Convolve extends WithExpressionInfo { + + override def name: String = "rst_convolve" + + override def usage: String = + """ + |_FUNC_(expr1, expr2) - Returns a raster with the kernel filter applied. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, kernel); + | {index_id, convolved_raster, parentPath, driver} + | {index_id, convolved_raster, parentPath, driver} + | ...
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Convolve](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala index 822228a1b..fa576427a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala @@ -9,25 +9,27 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** Expression for combining rasters using average of pixels. */ case class RST_DerivedBand( - rastersExpr: Expression, + tileExpr: Expression, pythonFuncExpr: Expression, funcNameExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArray2ArgExpression[RST_DerivedBand]( - rastersExpr, + tileExpr, pythonFuncExpr, funcNameExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Combines the rasters using average of pixels. 
*/ override def rasterTransform(tiles: Seq[MosaicRasterTile], arg1: Any, arg2: Any): Any = { val pythonFunc = arg1.asInstanceOf[UTF8String].toString diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala index 47d4aa12a..f02194d62 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer */ //noinspection DuplicatedCode case class RST_DerivedBandAgg( - rasterExpr: Expression, + tileExpr: Expression, pythonFuncExpr: Expression, funcNameExpr: Expression, expressionConfig: MosaicExpressionConfig, @@ -36,13 +36,13 @@ case class RST_DerivedBandAgg( override lazy val deterministic: Boolean = true override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) override def prettyName: String = "rst_combine_avg_agg" private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) private lazy val row = new UnsafeRow(1) - override def first: Expression = rasterExpr + override def first: Expression = tileExpr override def second: Expression = pythonFuncExpr override def third: Expression = funcNameExpr @@ -74,9 +74,16 @@ case class RST_DerivedBandAgg( // This works for Literals only val pythonFunc = pythonFuncExpr.eval(null).asInstanceOf[UTF8String].toString val funcName = funcNameExpr.eval(null).asInstanceOf[UTF8String].toString + val rasterType = RasterTileType(tileExpr).rasterType // Do do move the expression - var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + 
var tiles = buffer.map(row => + MosaicRasterTile.deserialize( + row.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) + ) // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null @@ -88,7 +95,7 @@ case class RST_DerivedBandAgg( val result = MosaicRasterTile(idx, combined, parentPath, driver) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(BinaryType) tiles.foreach(RasterCleaner.dispose(_)) RasterCleaner.dispose(result) @@ -113,7 +120,7 @@ case class RST_DerivedBandAgg( } override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): RST_DerivedBandAgg = - copy(rasterExpr = newFirst, pythonFuncExpr = newSecond, funcNameExpr = newThird) + copy(tileExpr = newFirst, pythonFuncExpr = newSecond, funcNameExpr = newThird) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala new file mode 100644 index 000000000..ee8b34d3b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.unsafe.types.UTF8String + +/** The expression for applying NxN filter on a raster. */ +case class RST_Filter( + rastersExpr: Expression, + kernelSizeExpr: Expression, + operationExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster2ArgExpression[RST_Filter]( + rastersExpr, + kernelSizeExpr, + operationExpr, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Applies an NxN filter to the raster. + * + * @param tile + * The raster to be used. + * @param arg1 + * The kernel size N; arg2 is the operation name. + * @return + * The filtered raster. + */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any, arg2: Any): Any = { + val n = arg1.asInstanceOf[Int] + val operation = arg2.asInstanceOf[UTF8String].toString + tile.copy( + raster = tile.getRaster.filter(n, operation) + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Filter extends WithExpressionInfo { + + override def name: String = "rst_filter" + + override def usage: String = + """ + |_FUNC_(expr1, expr2, expr3) - Returns a raster with the filter applied. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, kernelSize, operation); + | {index_id, filtered_raster, parentPath, driver} + | {index_id, filtered_raster, parentPath, driver} + | ...
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Filter](3, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala index 2befb353c..90e49b654 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala @@ -9,6 +9,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.ArrayType /** The expression for stacking and resampling input bands. */ case class RST_FromBands( @@ -16,13 +17,18 @@ case class RST_FromBands( expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_FromBands]( bandsExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: org.apache.spark.sql.types.DataType = + RasterTileType( + expressionConfig.getCellIdType, + RasterTileType(bandsExpr).rasterType + ) + /** * Stacks and resamples input bands. 
* @param rasters diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala index bd2926bcb..59c701f71 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths} @@ -25,7 +25,7 @@ import java.nio.file.{Files, Paths} * expression in the expression tree for a raster tile. 
*/ case class RST_FromContent( - rasterExpr: Expression, + contentExpr: Expression, driverExpr: Expression, sizeInMB: Expression, expressionConfig: MosaicExpressionConfig @@ -33,8 +33,10 @@ case class RST_FromContent( with Serializable with NullIntolerant with CodegenFallback { + + val tileType: DataType = BinaryType - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileType) protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) @@ -46,12 +48,13 @@ case class RST_FromContent( override def inline: Boolean = false - override def children: Seq[Expression] = Seq(rasterExpr, driverExpr, sizeInMB) + override def children: Seq[Expression] = Seq(contentExpr, driverExpr, sizeInMB) override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) /** - * subdivides raster binary content into tiles of the specified size (in MB). + * subdivides raster binary content into tiles of the specified size (in + * MB). * @param input * The input file path. 
* @return @@ -61,13 +64,13 @@ case class RST_FromContent( GDAL.enable(expressionConfig) val driver = driverExpr.eval(input).asInstanceOf[UTF8String].toString val ext = GDAL.getExtension(driver) - var rasterArr = rasterExpr.eval(input).asInstanceOf[Array[Byte]] + var rasterArr = contentExpr.eval(input).asInstanceOf[Array[Byte]] val targetSize = sizeInMB.eval(input).asInstanceOf[Int] if (targetSize <= 0 || rasterArr.length <= targetSize) { // - no split required var raster = MosaicRasterGDAL.readRaster(rasterArr, PathUtils.NO_PATH_STRING, driver) var tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver) - val row = tile.formatCellId(indexSystem).serialize() + val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) rasterArr = null @@ -84,7 +87,7 @@ case class RST_FromContent( // split to tiles up to specifed threshold var tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, targetSize) - val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) tiles.foreach(RasterCleaner.dispose(_)) Files.deleteIfExists(Paths.get(rasterPath)) rasterArr = null @@ -118,10 +121,10 @@ object RST_FromContent extends WithExpressionInfo { | ... 
| """.stripMargin - override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { - (children: Seq[Expression]) => { + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) => + { val sizeExpr = if (children.length < 3) new Literal(-1, IntegerType) else children(2) - RST_FromContent(children(0), children(1), sizeExpr, expressionConfig) + RST_FromContent(children.head, children(1), sizeExpr, expressionConfig) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala index fbce5bf58..ee5bec721 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths, StandardCopyOption} @@ -32,8 +32,10 @@ case class RST_FromFile( with Serializable with NullIntolerant with CodegenFallback { + + val tileType: DataType = BinaryType - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileType) protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) @@ -66,7 +68,7 @@ case class RST_FromFile( if (targetSize <= 0 && 
Files.size(Paths.get(readPath)) <= Integer.MAX_VALUE) { var raster = MosaicRasterGDAL.readRaster(readPath, path) var tile = MosaicRasterTile(null, raster, path, raster.getDriversShortName) - val row = tile.formatCellId(indexSystem).serialize() + val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) raster = null @@ -79,7 +81,7 @@ case class RST_FromFile( Files.copy(Paths.get(readPath), Paths.get(tmpPath), StandardCopyOption.REPLACE_EXISTING) val size = if (targetSize <= 0) 64 else targetSize var tiles = ReTileOnRead.localSubdivide(tmpPath, path, size) - val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) tiles.foreach(RasterCleaner.dispose(_)) Files.deleteIfExists(Paths.get(tmpPath)) tiles = null diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala index f4213eee7..404eb4b90 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the georeference of the raster. */ case class RST_GeoReference(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_GeoReference](raster, MapType(StringType, DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_GeoReference](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, DoubleType) + /** Returns the georeference of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val geoTransform = tile.getRaster.getRaster.GetGeoTransform() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala index 8f10b89cb..aa07a6637 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala @@ -8,7 +8,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, DoubleType} +import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType} /** The expression for extracting the no data value of a raster. */ case class RST_GetNoData( @@ -16,13 +16,14 @@ case class RST_GetNoData( expressionConfig: MosaicExpressionConfig ) extends RasterExpression[RST_GetNoData]( rastersExpr, - ArrayType(DoubleType), returnsRaster = false, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** * Extracts the no data value of a raster. 
* diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala index a87f6fa25..8d1fc77f1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala @@ -8,20 +8,25 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** Returns the subdatasets of the raster. */ -case class RST_GetSubdataset(raster: Expression, subsetName: Expression, expressionConfig: MosaicExpressionConfig) - extends Raster1ArgExpression[RST_GetSubdataset]( - raster, +case class RST_GetSubdataset( + tileExpr: Expression, + subsetName: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_GetSubdataset]( + tileExpr, subsetName, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Returns the subdatasets of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { val subsetName = arg1.asInstanceOf[UTF8String].toString diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala index ceb638f29..f2508e1e6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the width of the raster. */ case class RST_Height(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Height](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Height](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.ySize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala index 8cf226664..3b1b806da 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala @@ -11,20 +11,22 @@ import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression that initializes no data values of a raster. 
*/ case class RST_InitNoData( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterExpression[RST_InitNoData]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Initializes no data values of a raster. * diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala index 4a5f5034f..7d6267bec 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_IsEmpty(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_IsEmpty](raster, BooleanType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_IsEmpty](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BooleanType + /** Returns true if the raster is empty. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { var raster = tile.getRaster diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala new file mode 100644 index 000000000..586337556 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala @@ -0,0 +1,205 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.MOSAIC_NO_DRIVER +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.datasource.gdal.ReTileOnRead +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.file.{Files, Paths} +import scala.util.Try + +/** + * Creates raster tiles from the input column. + * + * @param inputExpr + * The expression for the raster. If the raster is stored on disc, the path + * to the raster is provided. 
If the raster is stored in memory, the bytes of + * the raster are provided. + * @param sizeInMBExpr + * The size of the tiles in MB. If set to -1, the file is loaded and returned + * as a single tile. If set to 0, the file is loaded and subdivided into + * tiles of size 64MB. If set to a positive value, the file is loaded and + * subdivided into tiles of the specified size. If the file is too big to fit + * in memory, it is subdivided into tiles of size 64MB. + * @param driverExpr + * The driver to use for reading the raster. If not specified, the driver is + * inferred from the file extension. If the input is a byte array, the driver + * has to be specified. + * @param withCheckpointExpr + * If set to true, the tiles are written to the checkpoint directory. If set + * to false, the tiles are returned as in-memory byte arrays. + * @param expressionConfig + * Additional arguments for the expression (expressionConfigs). + */ +case class RST_MakeTiles( + inputExpr: Expression, + driverExpr: Expression, + sizeInMBExpr: Expression, + withCheckpointExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends CollectionGenerator + with Serializable + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = { + require(withCheckpointExpr.isInstanceOf[Literal]) + if (withCheckpointExpr.eval().asInstanceOf[Boolean]) { + // Raster is referenced via a path + RasterTileType(expressionConfig.getCellIdType, StringType) + } else { + // Raster is referenced via a byte array + RasterTileType(expressionConfig.getCellIdType, BinaryType) + } + } + + protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + + protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) + + protected val cellIdDataType: DataType = indexSystem.getCellIdDataType + + override def position: Boolean = false + + override def inline: Boolean = false + + override def children:
Seq[Expression] = Seq(inputExpr, driverExpr, sizeInMBExpr, withCheckpointExpr) + + override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) + + private def getDriver(rawInput: Any, rawDriver: String): String = { + if (rawDriver == MOSAIC_NO_DRIVER) { + if (inputExpr.dataType == StringType) { + val path = rawInput.asInstanceOf[UTF8String].toString + MosaicRasterGDAL.identifyDriver(path) + } else { + throw new IllegalArgumentException("Driver has to be specified for byte array input") + } + } else { + rawDriver + } + } + + private def getInputSize(rawInput: Any): Long = { + if (inputExpr.dataType == StringType) { + val path = rawInput.asInstanceOf[UTF8String].toString + Files.size(Paths.get(path)) + } else { + val bytes = rawInput.asInstanceOf[Array[Byte]] + bytes.length + } + } + + /** + * Loads a raster from a file and subdivides it into tiles of the specified + * size (in MB). + * @param input + * The input file path. + * @return + * The tiles. + */ + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + GDAL.enable(expressionConfig) + + val tileType = dataType.asInstanceOf[StructType].find(_.name == "raster").get.dataType + + val rawDriver = driverExpr.eval(input).asInstanceOf[UTF8String].toString + val rawInput = inputExpr.eval(input) + val driver = getDriver(rawInput, rawDriver) + val targetSize = sizeInMBExpr.eval(input).asInstanceOf[Int] + val inputSize = getInputSize(rawInput) + + if (targetSize <= 0 && inputSize <= Integer.MAX_VALUE) { + // - no split required + val raster = GDAL.readRaster(rawInput, PathUtils.NO_PATH_STRING, driver, inputExpr.dataType) + val tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver) + val row = tile.formatCellId(indexSystem).serialize(tileType) + RasterCleaner.dispose(raster) + RasterCleaner.dispose(tile) + Seq(InternalRow.fromSeq(Seq(row))) + } else { + // target size is > 0 and raster size > target size + // - write the initial raster to file (unsplit) + 
// - createDirectories in case of context isolation + val rasterPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver)) + Files.createDirectories(Paths.get(rasterPath).getParent) + Files.write(Paths.get(rasterPath), rawInput.asInstanceOf[Array[Byte]]) + val size = if (targetSize <= 0) 64 else targetSize + var tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, size) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) + tiles.foreach(RasterCleaner.dispose(_)) + Files.deleteIfExists(Paths.get(rasterPath)) + tiles = null + rows.map(row => InternalRow.fromSeq(Seq(row))) + } + } + + override def makeCopy(newArgs: Array[AnyRef]): Expression = + GenericExpressionFactory.makeCopyImpl[RST_MakeTiles](this, newArgs, children.length, expressionConfig) + + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = makeCopy(newChildren.toArray) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_MakeTiles extends WithExpressionInfo { + + override def name: String = "rst_maketiles" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a set of new rasters with the specified tile size (tileWidth x tileHeight). + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_path); + | {index_id, raster, parent_path, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) => + { + def checkSize(size: Expression) = Try(size.eval().asInstanceOf[Int]).isSuccess + def checkChkpnt(chkpnt: Expression) = Try(chkpnt.eval().asInstanceOf[Boolean]).isSuccess + def checkDriver(driver: Expression) = Try(driver.eval().asInstanceOf[UTF8String].toString).isSuccess + val noSize = new Literal(-1, IntegerType) + val noDriver = new Literal(MOSAIC_NO_DRIVER, StringType) + val noCheckpoint = new Literal(false, BooleanType) + + children match { + // Note type checking only works for literals + case Seq(input) => RST_MakeTiles(input, noDriver, noSize, noCheckpoint, expressionConfig) + case Seq(input, driver) if checkDriver(driver) => RST_MakeTiles(input, driver, noSize, noCheckpoint, expressionConfig) + case Seq(input, size) if checkSize(size) => RST_MakeTiles(input, noDriver, size, noCheckpoint, expressionConfig) + case Seq(input, checkpoint) if checkChkpnt(checkpoint) => + RST_MakeTiles(input, noDriver, noSize, checkpoint, expressionConfig) + case Seq(input, size, checkpoint) if checkSize(size) && checkChkpnt(checkpoint) => + RST_MakeTiles(input, noDriver, size, checkpoint, expressionConfig) + case Seq(input, driver, size) if checkDriver(driver) && checkSize(size) => + RST_MakeTiles(input, driver, size, noCheckpoint, expressionConfig) + case Seq(input, driver, checkpoint) if checkDriver(driver) && checkChkpnt(checkpoint) => + RST_MakeTiles(input, driver, noSize, checkpoint, expressionConfig) + case Seq(input, driver, size, checkpoint) if checkDriver(driver) && checkSize(size) && checkChkpnt(checkpoint) => + RST_MakeTiles(input, driver, size, checkpoint, expressionConfig) + case _ => RST_MakeTiles(children.head, children(1), children(2), children(3), expressionConfig) + } + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala 
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala index 53e84d96b..1c74e1f0a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala @@ -11,23 +11,25 @@ import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** The expression for map algebra. */ case class RST_MapAlgebra( - rastersExpr: Expression, + tileExpr: Expression, jsonSpecExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArray1ArgExpression[RST_MapAlgebra]( - rastersExpr, + tileExpr, jsonSpecExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Map Algebra. * @param tiles diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala index abe042c2b..434be4a68 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala @@ -14,10 +14,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_Max(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Max](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Max](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val nBands = tile.raster.raster.GetRasterCount() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala index 091121e91..19d3fc0a6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala @@ -1,8 +1,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL -import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALCalc, GDALInfo, GDALWarp} +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression @@ -16,10 +15,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_Median(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Median](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Median](rasterExpr, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val raster = tile.raster diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala index 804c4f195..f77058a65 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the memory size of the raster in bytes. */ case class RST_MemSize(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MemSize](raster, LongType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_MemSize](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = LongType + /** Returns the memory size of the raster in bytes. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.getMemSize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala index cb9907848..c8ef6846d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala @@ -9,20 +9,22 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Returns a raster that is a result of merging an array of rasters. */ case class RST_Merge( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_Merge]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Merges an array of rasters. * @param tiles diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala index 5902eac3b..88705dfe0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer /** Merges rasters into a single raster. 
*/ //noinspection DuplicatedCode case class RST_MergeAgg( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0 @@ -29,9 +29,9 @@ case class RST_MergeAgg( with RasterExpressionSerialization { override lazy val deterministic: Boolean = true - override val child: Expression = rasterExpr + override val child: Expression = tileExpr override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) override def prettyName: String = "rst_merge_agg" private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) @@ -66,8 +66,15 @@ case class RST_MergeAgg( // This is a trick to get the rasters sorted by their parent path to ensure more consistent results // when merging rasters with large overlaps + val rasterType = RasterTileType(tileExpr).rasterType var tiles = buffer - .map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + .map(row => + MosaicRasterTile.deserialize( + row.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) + ) .sortBy(_.getParentPath) // If merging multiple index rasters, the index value is dropped @@ -79,7 +86,7 @@ case class RST_MergeAgg( val result = MosaicRasterTile(idx, merged, parentPath, driver) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(BinaryType) tiles.foreach(RasterCleaner.dispose(_)) RasterCleaner.dispose(merged) @@ -103,7 +110,7 @@ case class RST_MergeAgg( buffer } - override protected def withNewChildInternal(newChild: Expression): RST_MergeAgg = copy(rasterExpr = newChild) + override protected def withNewChildInternal(newChild: 
Expression): RST_MergeAgg = copy(tileExpr = newChild) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala index 8a96ff0d1..0b6754ebe 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the metadata of the raster. */ case class RST_MetaData(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MetaData](raster, MapType(StringType, StringType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_MetaData](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** Returns the metadata of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = buildMapString(tile.getRaster.metadata) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala index 67fdb30d3..ea62e106f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala @@ -1,6 +1,5 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression @@ -14,10 +13,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_Min(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Min](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Min](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val nBands = tile.raster.raster.GetRasterCount() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala index fa595fd4b..67b580f0c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala @@ -9,24 +9,26 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression for computing NDVI index. */ case class RST_NDVI( - rastersExpr: Expression, + tileExpr: Expression, redIndex: Expression, nirIndex: Expression, expressionConfig: MosaicExpressionConfig ) extends Raster2ArgExpression[RST_NDVI]( - rastersExpr, + tileExpr, redIndex, nirIndex, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Computes NDVI index. 
* @param tile diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala index f5dd09551..e0a8c8d9e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the number of bands in the raster. */ case class RST_NumBands(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_NumBands](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_NumBands](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the number of bands in the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.numBands diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala index 79f44db03..b2543a87e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala @@ -12,10 +12,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. */ case class RST_PixelCount(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelCount](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelCount](rasterExpr, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(LongType) + /** Returns the upper left x of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val bandCount = tile.raster.raster.GetRasterCount() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala index d1c3713ef..0c34be59b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the pixel height of the raster. */ case class RST_PixelHeight(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelHeight](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelHeight](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the pixel height of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getGeoTransform diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala index 6a4956e9e..b1645696b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the pixel width of the raster. 
*/ case class RST_PixelWidth(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelWidth](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelWidth](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the pixel width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getGeoTransform diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala index 42b9a928a..9da0f19ef 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala @@ -17,10 +17,12 @@ case class RST_RasterToWorldCoord( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, StringType, returnsRaster = false, expressionConfig = expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = StringType + /** * Returns the world coordinates of the raster (x,y) pixel by applying * GeoTransform. This ensures the projection of the raster is respected. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala index 4bd06646a..5fea59b49 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala @@ -16,10 +16,12 @@ case class RST_RasterToWorldCoordX( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** * Returns the world coordinates of the raster x pixel by applying * GeoTransform. This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala index 262d6bbad..ae170709c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala @@ -16,10 +16,12 @@ case class RST_RasterToWorldCoordY( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** * Returns the world coordinates of the raster y pixel by applying * GeoTransform. 
This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala index 4465866dc..939011882 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala @@ -8,6 +8,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** * Returns a set of new rasters with the specified tile size (tileWidth x @@ -22,6 +23,8 @@ case class RST_ReTile( with NullIntolerant with CodegenFallback { + override def dataType: DataType = rasterExpr.dataType + /** * Returns a set of new rasters with the specified tile size (tileWidth x * tileHeight). diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala index c3cd097c7..5933c7133 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the rotation angle of the raster. */ case class RST_Rotation(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Rotation](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Rotation](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the rotation angle of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getRaster.GetGeoTransform() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala index c8bce06b7..648260ae5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala @@ -14,10 +14,12 @@ import scala.util.Try /** Returns the SRID of the raster. */ case class RST_SRID(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SRID](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SRID](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the SRID of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { // Reference: https://gis.stackexchange.com/questions/267321/extracting-epsg-from-a-raster-using-gdal-bindings-in-python diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala index c16891871..e13af4763 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the scale x of the raster. 
*/ case class RST_ScaleX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_ScaleX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the scale x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(1) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala index 3b0779763..8defba49a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the scale y of the raster. */ case class RST_ScaleY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_ScaleY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the scale y of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(5) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala index 911271d33..f4350e7d3 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala @@ -12,22 +12,24 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.DataType /** Returns a raster with the specified no data values. */ case class RST_SetNoData( - rastersExpr: Expression, + tileExpr: Expression, noDataExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends Raster1ArgExpression[RST_SetNoData]( - rastersExpr, + tileExpr, noDataExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Returns a raster with the specified no data values. * @param tile diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala index ee3d0c4dd..439592e73 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the skew x of the raster. 
*/ case class RST_SkewX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SkewX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the skew x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(2) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala index ff9903687..1f259b5de 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the skew y of the raster. */ case class RST_SkewY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SkewY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the skew y of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(4) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala index 8c58e7f74..3f1536510 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala @@ -13,13 +13,14 @@ import org.apache.spark.sql.types._ case class RST_Subdatasets(raster: Expression, expressionConfig: MosaicExpressionConfig) extends RasterExpression[RST_Subdatasets]( raster, - MapType(StringType, StringType), returnsRaster = false, expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** Returns the subdatasets of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = buildMapString(tile.getRaster.subdatasets) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala index 6351d47f2..4900eaab2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala @@ -16,10 +16,12 @@ import java.util.{Vector => JVector} /** Returns the summary info the raster. */ case class RST_Summary(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Summary](raster, StringType, returnsRaster = false, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Summary](raster, returnsRaster = false, expressionConfig: MosaicExpressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = StringType + /** Returns the summary info the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val vector = new JVector[String]() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala index b364d39da..16dc25ee0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_TryOpen(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_TryOpen](raster, BooleanType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_TryOpen](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BooleanType + /** Returns true if the raster can be opened. */ override def rasterTransform(tile: MosaicRasterTile): Any = { Option(tile.getRaster.getRaster).isDefined diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala index 4f050bc7e..143158736 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_UpperLeftX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_UpperLeftX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(0) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala index 0e052e3ae..702c8a0c4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left y of the raster. */ case class RST_UpperLeftY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_UpperLeftY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the upper left y of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(3) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala index 4bd56686a..953eb17bd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the width of the raster. */ case class RST_Width(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Width](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Width](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.xSize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala index 2d0884a81..2d5438c3b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala @@ -9,6 +9,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Returns the world coordinate of the raster. 
*/ case class RST_WorldToRasterCoord( @@ -16,10 +17,12 @@ case class RST_WorldToRasterCoord( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, PixelCoordsType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = PixelCoordsType + /** * Returns the x and y of the raster by applying GeoTransform as a tuple of * Integers. This will ensure projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala index 26c888fe1..41d6e8b9b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala @@ -16,10 +16,12 @@ case class RST_WorldToRasterCoordX( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** * Returns the x coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala index 8bb125faa..62ba72228 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala @@ -16,10 +16,12 @@ case class RST_WorldToRasterCoordY( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** * Returns the y coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala index f01027ff1..35ad927c6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala @@ -2,12 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -21,8 +21,6 @@ import scala.reflect.ClassTag * containing the raster file content. * @param arg1Expr * The expression for the first argument. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -31,7 +29,6 @@ import scala.reflect.ClassTag abstract class Raster1ArgExpression[T <: Expression: ClassTag]( rasterExpr: Expression, arg1Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -43,9 +40,6 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( override def right: Expression = arg1Expr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. 
It is called when * the expression is evaluated. It provides the raster and the arguments to @@ -75,10 +69,15 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(input: Any, arg1: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val raster = tile.getRaster val result = rasterTransform(tile, arg1) - val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(raster) RasterCleaner.dispose(result) serialized diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala index ccdc7d5b3..c5be60724 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -22,8 +23,6 @@ import scala.reflect.ClassTag * The expression for the first argument. * @param arg2Expr * The expression for the second argument. - * @param outputType - * The output type of the result. 
* @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -33,7 +32,6 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( rasterExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression @@ -47,9 +45,6 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( override def third: Expression = arg2Expr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster and the arguments to @@ -83,9 +78,14 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(input: Any, arg1: Any, arg2: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val result = rasterTransform(tile, arg1, arg2) - val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) // passed by name makes things re-evaluated RasterCleaner.dispose(tile) serialized diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala index d21f96c2d..5dbfd08cc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala @@ -2,11 +2,12 @@ package 
com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -18,8 +19,6 @@ import scala.reflect.ClassTag * @param rastersExpr * The rasters expression. It is an array column containing rasters as either * paths or as content byte arrays. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -28,7 +27,6 @@ import scala.reflect.ClassTag abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( rastersExpr: Expression, arg1Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -36,9 +34,6 @@ abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( with Serializable with RasterExpressionSerialization { - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - override def left: Expression = rastersExpr override def right: Expression = arg1Expr @@ -72,7 +67,8 @@ abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles, arg1) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val 
resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala index a26082f2d..9de963684 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, TernaryExpression} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -18,8 +19,6 @@ import scala.reflect.ClassTag * @param rastersExpr * The rasters expression. It is an array column containing rasters as either * paths or as content byte arrays. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). 
* @tparam T @@ -29,7 +28,6 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( rastersExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression @@ -37,9 +35,6 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( with Serializable with RasterExpressionSerialization { - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - override def first: Expression = rastersExpr override def second: Expression = arg1Expr @@ -77,7 +72,8 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles, arg1, arg2) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala index b8ad9fc12..8c3a52d9a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import 
com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -27,7 +28,6 @@ import scala.reflect.ClassTag */ abstract class RasterArrayExpression[T <: Expression: ClassTag]( rastersExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression @@ -37,9 +37,6 @@ abstract class RasterArrayExpression[T <: Expression: ClassTag]( override def child: Expression = rastersExpr - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the rasters to the expression. 
@@ -67,7 +64,8 @@ abstract class RasterArrayExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala index 3162bb421..f2d399350 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.expressions.raster.base +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow @@ -12,11 +13,16 @@ object RasterArrayUtils { def getTiles(input: Any, rastersExpr: Expression, expressionConfig: MosaicExpressionConfig): Seq[MosaicRasterTile] = { val rasterDT = rastersExpr.dataType.asInstanceOf[ArrayType].elementType val arrayData = input.asInstanceOf[ArrayData] + val rasterType = RasterTileType(rastersExpr).rasterType val n = arrayData.numElements() (0 until n) .map(i => MosaicRasterTile - .deserialize(arrayData.get(i, rasterDT).asInstanceOf[InternalRow], expressionConfig.getCellIdType) + .deserialize( + arrayData.get(i, rasterDT).asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) ) } diff --git 
a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala index 7cee607ca..97bd3e333 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala @@ -3,12 +3,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterBandGDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -23,8 +23,6 @@ import scala.reflect.ClassTag * MOSAIC_RASTER_STORAGE is set to MOSAIC_RASTER_STORAGE_BYTE. * @param bandExpr * The expression for the band index. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). 
* @tparam T @@ -33,7 +31,6 @@ import scala.reflect.ClassTag abstract class RasterBandExpression[T <: Expression: ClassTag]( rasterExpr: Expression, bandExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -45,9 +42,6 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( override def right: Expression = bandExpr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster band to the @@ -79,13 +73,18 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(inputRaster: Any, inputBand: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(inputRaster.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + inputRaster.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val bandIndex = inputBand.asInstanceOf[Int] val band = tile.getRaster.getBand(bandIndex) val result = bandTransform(tile, band) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(tile) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala index 462d3204b..66435f101 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.expressions.raster.base import 
com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -20,8 +21,6 @@ import scala.reflect.ClassTag * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -29,7 +28,6 @@ import scala.reflect.ClassTag */ abstract class RasterExpression[T <: Expression: ClassTag]( rasterExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression @@ -43,9 +41,6 @@ abstract class RasterExpression[T <: Expression: ClassTag]( override def child: Expression = rasterExpr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster to the expression. 
@@ -69,9 +64,14 @@ abstract class RasterExpression[T <: Expression: ClassTag]( */ override def nullSafeEval(input: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], cellIdDataType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + cellIdDataType, + rasterType + ) val result = rasterTransform(tile) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(tile) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala index a9bf17917..dc04cb1c7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala @@ -35,11 +35,9 @@ trait RasterExpressionSerialization { ): Any = { if (returnsRaster) { val tile = data.asInstanceOf[MosaicRasterTile] - val checkpoint = expressionConfig.getRasterCheckpoint - val rasterType = outputDataType.asInstanceOf[StructType].fields(1).dataType val result = tile .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(rasterType, checkpoint) + .serialize(outputDataType) RasterCleaner.dispose(tile) result } else { diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala index 29c714788..3fc80752d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag * rasters based on the input raster. The new rasters are written in the * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. - * @param rasterExpr + * @param tileExpr * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. @@ -32,13 +32,13 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator with NullIntolerant with Serializable { - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) val uuid: String = java.util.UUID.randomUUID().toString.replace("-", "_") @@ -72,11 +72,12 @@ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(rasterExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType) + val rasterType = RasterTileType(tileExpr).rasterType + val tile = MosaicRasterTile.deserialize(tileExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType, rasterType) val generatedRasters = rasterGenerator(tile) // Writing rasters disposes of the written raster - val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize()) + val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize(rasterType)) generatedRasters.foreach(gr => RasterCleaner.dispose(gr)) RasterCleaner.dispose(tile) diff --git 
a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala index f2545942b..98ff86ca7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. * - * @param rasterExpr + * @param tileExpr * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. @@ -33,7 +33,7 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( - rasterExpr: Expression, + tileExpr: Expression, resolutionExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator @@ -55,7 +55,8 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( * needs to be wrapped in a StructType. The actually type is that of the * structs element. */ - override def elementSchema: StructType = StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType)))) + override def elementSchema: StructType = + StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType, tileExpr)))) /** * The function to be overridden by the extending class. 
It is called when @@ -71,17 +72,15 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) + val rasterType = RasterTileType(tileExpr).rasterType val tile = MosaicRasterTile - .deserialize( - rasterExpr.eval(input).asInstanceOf[InternalRow], - indexSystem.getCellIdDataType - ) + .deserialize(tileExpr.eval(input).asInstanceOf[InternalRow], indexSystem.getCellIdDataType, rasterType) val inResolution: Int = indexSystem.getResolution(resolutionExpr.eval(input)) val generatedChips = rasterGenerator(tile, inResolution) .map(chip => chip.formatCellId(indexSystem)) val rows = generatedChips - .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize()))) + .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize(rasterType)))) RasterCleaner.dispose(tile) generatedChips.foreach(chip => RasterCleaner.dispose(chip)) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala index 743f9cbd6..e7b04f989 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala @@ -37,11 +37,13 @@ abstract class RasterToGridExpression[T <: Expression: ClassTag, P]( resolution: Expression, measureType: DataType, expressionConfig: MosaicExpressionConfig -) extends Raster1ArgExpression[T](rasterExpr, resolution, RasterToGridType(expressionConfig.getCellIdType, measureType), returnsRaster = false, expressionConfig) +) extends Raster1ArgExpression[T](rasterExpr, resolution, returnsRaster = false, expressionConfig) with RasterGridExpression with NullIntolerant with Serializable { + override def dataType: DataType = 
RasterToGridType(expressionConfig.getCellIdType, measureType) + /** The index system to be used. */ val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 2b85c9785..8398e6882 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -8,20 +8,19 @@ import com.databricks.labs.mosaic.core.types.ChipType import com.databricks.labs.mosaic.datasource.multiread.MosaicDataFrameReader import com.databricks.labs.mosaic.expressions.constructors._ import com.databricks.labs.mosaic.expressions.format._ -import com.databricks.labs.mosaic.expressions.geometry._ import com.databricks.labs.mosaic.expressions.geometry.ST_MinMaxXYZ._ +import com.databricks.labs.mosaic.expressions.geometry._ import com.databricks.labs.mosaic.expressions.index._ import com.databricks.labs.mosaic.expressions.raster._ import com.databricks.labs.mosaic.expressions.util.TrySql -import com.databricks.labs.mosaic.functions.MosaicContext.mosaicVersion import com.databricks.labs.mosaic.utils.FileUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{LongType, StringType} +import org.apache.spark.sql.{Column, SparkSession} import scala.reflect.runtime.universe @@ -270,6 +269,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_Height](expressionConfig) 
mosaicRegistry.registerExpression[RST_InitNoData](expressionConfig) mosaicRegistry.registerExpression[RST_IsEmpty](expressionConfig) + mosaicRegistry.registerExpression[RST_MakeTiles](expressionConfig) mosaicRegistry.registerExpression[RST_Max](expressionConfig) mosaicRegistry.registerExpression[RST_Min](expressionConfig) mosaicRegistry.registerExpression[RST_Median](expressionConfig) @@ -655,6 +655,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column = ColumnAdapter(RST_DerivedBand(raster.expr, pythonFunc.expr, funcName.expr, expressionConfig)) + def rst_filter(raster: Column, kernelSize: Column, operation: Column): Column = + ColumnAdapter(RST_Filter(raster.expr, kernelSize.expr, operation.expr, expressionConfig)) + def rst_filter(raster: Column, kernelSize: Int, operation: String): Column = + ColumnAdapter(RST_Filter(raster.expr, lit(kernelSize).expr, lit(operation).expr, expressionConfig)) def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig)) def rst_getnodata(raster: Column): Column = ColumnAdapter(RST_GetNoData(raster.expr, expressionConfig)) def rst_getsubdataset(raster: Column, subdatasetName: Column): Column = @@ -664,6 +668,20 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_height(raster: Column): Column = ColumnAdapter(RST_Height(raster.expr, expressionConfig)) def rst_initnodata(raster: Column): Column = ColumnAdapter(RST_InitNoData(raster.expr, expressionConfig)) def rst_isempty(raster: Column): Column = ColumnAdapter(RST_IsEmpty(raster.expr, expressionConfig)) + def rst_maketiles(input: Column, driver: Column, size: Column, withCheckpoint: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, driver.expr, size.expr, 
withCheckpoint.expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, size: Int, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) + def rst_maketiles(input: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(-1).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, size: Int): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(size).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(withCheckpoint).expr, expressionConfig)) + def rst_maketiles(input: Column, size: Int, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) def rst_max(raster: Column): Column = ColumnAdapter(RST_Max(raster.expr, expressionConfig)) def rst_min(raster: Column): Column = ColumnAdapter(RST_Min(raster.expr, expressionConfig)) def rst_median(raster: Column): Column = ColumnAdapter(RST_Median(raster.expr, expressionConfig)) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala index f306d4e9c..d6643f59b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala @@ -33,6 +33,8 @@ case class MosaicExpressionConfig(configs: Map[String, String]) { def getRasterCheckpoint: String = 
configs.getOrElse(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT) def getCellIdType: DataType = IndexSystemFactory.getIndexSystem(getIndexSystem).cellIdType + + def getRasterBlockSize: Int = configs.getOrElse(MOSAIC_RASTER_BLOCKSIZE, MOSAIC_RASTER_BLOCKSIZE_DEFAULT).toInt def setGDALConf(conf: RuntimeConfig): MosaicExpressionConfig = { val toAdd = conf.getAll.filter(_._1.startsWith(MOSAIC_GDAL_PREFIX)) diff --git a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala index 9e8bf1132..92438844a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala @@ -1,11 +1,12 @@ package com.databricks.labs.mosaic.gdal +import com.databricks.labs.mosaic.MOSAIC_RASTER_BLOCKSIZE_DEFAULT import com.databricks.labs.mosaic.functions.{MosaicContext, MosaicExpressionConfig} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal +import org.gdal.osr.SpatialReference -import java.io.{BufferedInputStream, File, PrintWriter} import java.nio.file.{Files, Paths} import scala.language.postfixOps import scala.util.Try @@ -22,9 +23,22 @@ object MosaicGDAL extends Logging { private val libjniso3003Path = "/usr/lib/libgdalalljni.so.30.0.3" private val libogdisoPath = "/usr/lib/ogdi/4.1/libgdal.so" + val defaultBlockSize = 1024 + val vrtBlockSize = 128 // This is a must value for VRTs before GDAL 3.7 + var blockSize: Int = MOSAIC_RASTER_BLOCKSIZE_DEFAULT.toInt + // noinspection ScalaWeakerAccess val GDAL_ENABLED = "spark.mosaic.gdal.native.enabled" var isEnabled = false + var checkpointPath: String = _ + + // Only use this with GDAL rasters + val WSG84: SpatialReference = { + val wsg84 = new SpatialReference() + wsg84.ImportFromEPSG(4326) + wsg84.SetAxisMappingStrategy(org.gdal.osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + wsg84 + } /** Returns true if GDAL is 
enabled. */ def wasEnabled(spark: SparkSession): Boolean = @@ -33,15 +47,28 @@ object MosaicGDAL extends Logging { /** Configures the GDAL environment. */ def configureGDAL(mosaicConfig: MosaicExpressionConfig): Unit = { val CPL_TMPDIR = MosaicContext.tmpDir - val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") - gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "EMPTY_DIR") + gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "TRUE") gdal.SetConfigOption("CPL_TMPDIR", CPL_TMPDIR) - gdal.SetConfigOption("GDAL_PAM_PROXY_DIR", GDAL_PAM_PROXY_DIR) - gdal.SetConfigOption("GDAL_PAM_ENABLED", "NO") - gdal.SetConfigOption("CPL_VSIL_USE_TEMP_FILE_FOR_RANDOM_WRITE", "NO") gdal.SetConfigOption("CPL_LOG", s"$CPL_TMPDIR/gdal.log") + gdal.SetConfigOption("GDAL_CACHEMAX", "512") + gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS") mosaicConfig.getGDALConf.foreach { case (k, v) => gdal.SetConfigOption(k.split("\\.").last, v) } + setBlockSize(mosaicConfig) + checkpointPath = mosaicConfig.getRasterCheckpoint + } + + def setBlockSize(mosaicConfig: MosaicExpressionConfig): Unit = { + val blockSize = mosaicConfig.getRasterBlockSize + if (blockSize > 0) { + this.blockSize = blockSize + } + } + + def setBlockSize(size: Int): Unit = { + if (size > 0) { + this.blockSize = size + } } /** Enables the GDAL environment. */ @@ -91,18 +118,19 @@ object MosaicGDAL extends Logging { } } - /** Reads the resource bytes. */ - private def readResourceBytes(name: String): Array[Byte] = { - val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) - try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } - finally bis.close() - } +// /** Reads the resource bytes. 
*/ +// private def readResourceBytes(name: String): Array[Byte] = { +// val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) +// try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } +// finally bis.close() +// } + +// /** Reads the resource lines. */ +// // noinspection SameParameterValue +// private def readResourceLines(name: String): Array[String] = { +// val bytes = readResourceBytes(name) +// val lines = new String(bytes).split("\n") +// lines +// } - /** Reads the resource lines. */ - // noinspection SameParameterValue - private def readResourceLines(name: String): Array[String] = { - val bytes = readResourceBytes(name) - val lines = new String(bytes).split("\n") - lines - } } diff --git a/src/main/scala/com/databricks/labs/mosaic/package.scala b/src/main/scala/com/databricks/labs/mosaic/package.scala index 58ee2f98e..eea63cd79 100644 --- a/src/main/scala/com/databricks/labs/mosaic/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/package.scala @@ -21,13 +21,17 @@ package object mosaic { val MOSAIC_GDAL_PREFIX = "spark.databricks.labs.mosaic.gdal." 
val MOSAIC_GDAL_NATIVE = "spark.databricks.labs.mosaic.gdal.native" val MOSAIC_RASTER_CHECKPOINT = "spark.databricks.labs.mosaic.raster.checkpoint" - val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "dbfs:/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "/dbfs/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_BLOCKSIZE = "spark.databricks.labs.mosaic.raster.blocksize" + val MOSAIC_RASTER_BLOCKSIZE_DEFAULT = "128" val MOSAIC_RASTER_READ_STRATEGY = "raster.read.strategy" val MOSAIC_RASTER_READ_IN_MEMORY = "in_memory" val MOSAIC_RASTER_READ_AS_PATH = "as_path" val MOSAIC_RASTER_RE_TILE_ON_READ = "retile_on_read" + val MOSAIC_NO_DRIVER = "no_driver" + def read: MosaicDataFrameReader = new MosaicDataFrameReader(SparkSession.builder().getOrCreate()) diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala index a1aac5c2f..fc01cfaa0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala @@ -22,10 +22,10 @@ object FileUtils { bytes } - def createMosaicTempDir(): String = { - val tempRoot = Paths.get("/mosaic_tmp/") + def createMosaicTempDir(prefix: String = ""): String = { + val tempRoot = Paths.get(s"$prefix/mosaic_tmp/") if (!Files.exists(tempRoot)) { - Files.createDirectory(tempRoot) + Files.createDirectories(tempRoot) } val tempDir = Files.createTempDirectory(tempRoot, "mosaic") tempDir.toFile.getAbsolutePath diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala index d48c03bfd..469bb0f44 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala @@ -1,7 +1,5 @@ package com.databricks.labs.mosaic.utils -import com.databricks.labs.mosaic.core.raster.api.GDAL -import 
com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.functions.MosaicContext import java.nio.file.{Files, Paths} @@ -10,11 +8,15 @@ object PathUtils { val NO_PATH_STRING = "no_path" - def getCleanPath(path: String): String = { - val cleanPath = path + def replaceDBFSTokens(path: String): String = { + path .replace("file:/", "/") .replace("dbfs:/Volumes", "/Volumes") - .replace("dbfs:/","/dbfs/") + .replace("dbfs:/", "/dbfs/") + } + + def getCleanPath(path: String): String = { + val cleanPath = replaceDBFSTokens(path) if (cleanPath.endsWith(".zip") || cleanPath.contains(".zip:")) { getZipPath(cleanPath) } else { @@ -61,17 +63,51 @@ object PathUtils { if (filePath.endsWith("\"")) result = result.dropRight(1) result } + + def getStemRegex(path: String): String = { + val cleanPath = replaceDBFSTokens(path) + val fileName = Paths.get(cleanPath).getFileName.toString + val stemName = fileName.substring(0, fileName.lastIndexOf(".")) + val stemEscaped = stemName.replace(".", "\\.") + val stemRegex = s"$stemEscaped\\..*".r + stemRegex.toString + } - def copyToTmp(inPath: String): String = { - val copyFromPath = inPath - .replace("file:/", "/") - .replace("dbfs:/Volumes", "/Volumes") - .replace("dbfs:/","/dbfs/") - val driver = MosaicRasterGDAL.identifyDriver(getCleanPath(inPath)) - val extension = if (inPath.endsWith(".zip")) "zip" else GDAL.getExtension(driver) - val tmpPath = createTmpFilePath(extension) - Files.copy(Paths.get(copyFromPath), Paths.get(tmpPath)) - tmpPath + def copyToTmp(inPath: String): String = { + val copyFromPath = replaceDBFSTokens(inPath) + val inPathDir = Paths.get(copyFromPath).getParent.toString + + val fullFileName = copyFromPath.split("/").last + val stemRegex = getStemRegex(inPath) + + wildcardCopy(inPathDir, MosaicContext.tmpDir, stemRegex.toString) + + s"${MosaicContext.tmpDir}/$fullFileName" + } + + def wildcardCopy(inDirPath: String, outDirPath: String, pattern: String): Unit = { + import 
org.apache.commons.io.FileUtils + val copyFromPath = replaceDBFSTokens(inDirPath) + val copyToPath = replaceDBFSTokens(outDirPath) + + val toCopy = Files + .list(Paths.get(copyFromPath)) + .filter(_.getFileName.toString.matches(pattern)) + + toCopy.forEach(path => { + val destination = Paths.get(copyToPath, path.getFileName.toString) + //noinspection SimplifyBooleanMatch + Files.isDirectory(path) match { + case true => FileUtils.copyDirectory(path.toFile, destination.toFile) + case false => Files.copy(path, destination) + } + }) + } + + def parseUnzippedPathFromExtracted(lastExtracted: String, extension: String): String = { + val trimmed = lastExtracted.replace("extracting: ", "").replace(" ", "") + val indexOfFormat = trimmed.indexOf(s".$extension/") + trimmed.substring(0, indexOfFormat + extension.length + 1) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala index 85fa12785..ba1d9c417 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.utils -import java.io.{ByteArrayOutputStream, PrintWriter} +import java.io.{BufferedReader, ByteArrayOutputStream, InputStreamReader, PrintWriter} object SysUtils { @@ -11,16 +11,40 @@ object SysUtils { val stderrStream = new ByteArrayOutputStream val stdoutWriter = new PrintWriter(stdoutStream) val stderrWriter = new PrintWriter(stderrStream) - val exitValue = try { - //noinspection ScalaStyle - cmd.!!(ProcessLogger(stdoutWriter.println, stderrWriter.println)) - } catch { - case _: Exception => "ERROR" - } finally { - stdoutWriter.close() - stderrWriter.close() - } + val exitValue = + try { + // noinspection ScalaStyle + cmd.!!(ProcessLogger(stdoutWriter.println, stderrWriter.println)) + } catch { + case e: Exception => s"ERROR: ${e.getMessage}" + } finally { + stdoutWriter.close() + 
stderrWriter.close() + } (exitValue, stdoutStream.toString, stderrStream.toString) } + def runScript(cmd: Array[String]): (String, String, String) = { + val p = Runtime.getRuntime.exec(cmd) + val stdinStream = new BufferedReader(new InputStreamReader(p.getInputStream)) + val stderrStream = new BufferedReader(new InputStreamReader(p.getErrorStream)) + val exitValue = + try { + p.waitFor() + } catch { + case e: Exception => s"ERROR: ${e.getMessage}" + } + val stdinOutput = stdinStream.lines().toArray.mkString("\n") + val stderrOutput = stderrStream.lines().toArray.mkString("\n") + stdinStream.close() + stderrStream.close() + (s"$exitValue", stdinOutput, stderrOutput) + } + + def getLastOutputLine(prompt: (String, String, String)): String = { + val (_, stdout, _) = prompt + val lines = stdout.split("\n") + lines.last + } + } diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml 
diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml 
b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala index 15eef2009..1337ae6d2 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala @@ -37,8 +37,8 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) val testBand = testRaster.getBand(1) testBand.description shouldBe "1[-] HYBL=\"Hybrid level\"" diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala index e39279843..bb53d6b79 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala @@ -1,9 +1,12 @@ package 
com.databricks.labs.mosaic.core.raster import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.test.mocks.filePath import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.should.Matchers._ +import org.gdal.gdal.{gdal => gdalJNI} +import org.gdal.gdalconst import scala.sys.process._ import scala.util.Try @@ -43,7 +46,7 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.SRID shouldBe 0 testRaster.extent shouldBe Seq(-8895604.157333, 1111950.519667, -7783653.637667, 2223901.039333) testRaster.getRaster.GetProjection() - noException should be thrownBy testRaster.spatialRef + noException should be thrownBy testRaster.getSpatialReference an[Exception] should be thrownBy testRaster.getBand(-1) an[Exception] should be thrownBy testRaster.getBand(Int.MaxValue) @@ -54,8 +57,8 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) testRaster.xSize shouldBe 14 testRaster.ySize shouldBe 14 @@ -96,8 +99,8 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + 
filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) testRaster.pixelXSize - 463.312716527 < 0.0000001 shouldBe true @@ -115,4 +118,214 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.getRaster.delete() } + test("Raster filter operations are correct.") { + assume(System.getProperty("os.name") == "Linux") + + gdalJNI.AllRegister() + + MosaicGDAL.setBlockSize(30) + + val ds = gdalJNI.GetDriverByName("GTiff").Create("/mosaic_tmp/test.tif", 50, 50, 1, gdalconst.gdalconstConstants.GDT_Float32) + + val values = 0 until 50 * 50 + ds.GetRasterBand(1).WriteRaster(0, 0, 50, 50, values.toArray) + ds.FlushCache() + + var result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "avg").flushCache() + + var resultValues = result.getBand(1).values + + var inputMatrix = values.toArray.grouped(50).toArray + var resultMatrix = resultValues.grouped(50).toArray + + // first block + resultMatrix(10)(11) shouldBe ( + inputMatrix(8)(9) + inputMatrix(8)(10) + inputMatrix(8)(11) + inputMatrix(8)(12) + inputMatrix(8)(13) + + inputMatrix(9)(9) + inputMatrix(9)(10) + inputMatrix(9)(11) + inputMatrix(9)(12) + inputMatrix(9)(13) + + inputMatrix(10)(9) + inputMatrix(10)(10) + inputMatrix(10)(11) + inputMatrix(10)(12) + inputMatrix(10)(13) + + inputMatrix(11)(9) + inputMatrix(11)(10) + inputMatrix(11)(11) + inputMatrix(11)(12) + inputMatrix(11)(13) + + inputMatrix(12)(9) + inputMatrix(12)(10) + inputMatrix(12)(11) + inputMatrix(12)(12) + inputMatrix(12)(13) + ).toDouble / 25.0 + + // block overlap + resultMatrix(30)(32) shouldBe ( + inputMatrix(28)(30) + inputMatrix(28)(31) + inputMatrix(28)(32) + inputMatrix(28)(33) + inputMatrix(28)(34) + + inputMatrix(29)(30) + inputMatrix(29)(31) + inputMatrix(29)(32) + inputMatrix(29)(33) + inputMatrix(29)(34) + + inputMatrix(30)(30) + inputMatrix(30)(31) + inputMatrix(30)(32) + inputMatrix(30)(33) + inputMatrix(30)(34) + + inputMatrix(31)(30) + 
inputMatrix(31)(31) + inputMatrix(31)(32) + inputMatrix(31)(33) + inputMatrix(31)(34) + + inputMatrix(32)(30) + inputMatrix(32)(31) + inputMatrix(32)(32) + inputMatrix(32)(33) + inputMatrix(32)(34) + ).toDouble / 25.0 + + // mode + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "mode").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).groupBy(identity).maxBy(_._2.size)._1.toDouble + + // corner + + resultMatrix(49)(49) shouldBe Seq( + inputMatrix(47)(47), + inputMatrix(47)(48), + inputMatrix(47)(49), + inputMatrix(48)(47), + inputMatrix(48)(48), + inputMatrix(48)(49), + inputMatrix(49)(47), + inputMatrix(49)(48), + inputMatrix(49)(49) + ).groupBy(identity).maxBy(_._2.size)._1.toDouble + + // median + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "median").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + 
inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).sorted.apply(12).toDouble + + // min filter + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "min").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).min.toDouble + + // max filter + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "max").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + 
inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).max.toDouble + + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala index 99a1563ca..623993b01 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala @@ -34,7 +34,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .take(1) } - + test("Read grib with GDALFileFormat") { assume(System.getProperty("os.name") == "Linux") @@ -43,25 +43,22 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { noException should be thrownBy spark.read .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") .load(filePath) .take(1) noException should be thrownBy spark.read .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") .load(filePath) .take(1) noException should be thrownBy spark.read .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") .load(filePath) .select("metadata") .take(1) @@ -92,7 +89,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .select("metadata") .take(1) - noException should be thrownBy spark.read + noException should be thrownBy spark.read .format("gdal") .option(MOSAIC_RASTER_READ_STRATEGY, "retile_on_read") .load(filePath) diff --git 
a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index 6e99aa1df..fba2b74cb 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -1,12 +1,12 @@ package com.databricks.labs.mosaic.datasource.multiread -import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.JTS import com.databricks.labs.mosaic.core.index.H3IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.must.Matchers.{be, noException} -import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper} +import org.scalatest.matchers.should.Matchers.an import java.nio.file.{Files, Paths} @@ -14,149 +14,161 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess test("Read big tif with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") - spark.sparkContext.setLogLevel("INFO") MosaicContext.build(H3IndexSystem, JTS) - val tif = "/binary/big_tiff.tif" + val tif = "/modis/" val filePath = getClass.getResource(tif).getPath val df = MosaicContext.read .format("raster_to_grid") .option("retile", "true") - .option("sizeInMB", "64") + .option("sizeInMB", "128") .option("resolution", "1") .load(filePath) .select("measure") - //df.queryExecution.optimizedPlan + df.queryExecution.optimizedPlan - //noException should be thrownBy df.queryExecution.executedPlan + noException should be thrownBy df.queryExecution.executedPlan df.count() } -// test("Read netcdf with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, 
JTS) -// -// val netcdf = "/binary/netcdf-coral/" -// val filePath = getClass.getResource(netcdf).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("retile", "true") -// .option("tileSize", "10") -// .option("readSubdataset", "true") -// .option("subdataset", "1") -// .option("kRingInterpolate", "3") -// .load(filePath) -// .select("measure") -// .queryExecution -// .executedPlan -// -// } -// -// test("Read grib with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, JTS) -// -// val grib = "/binary/grib-cams/" -// val filePath = getClass.getResource(grib).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("extensions", "grib") -// .option("combiner", "min") -// .option("retile", "true") -// .option("tileSize", "10") -// .option("kRingInterpolate", "3") -// .load(filePath) -// .select("measure") -// .take(1) -// -// } -// -// test("Read tif with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, JTS) -// -// val tif = "/modis/" -// val filePath = getClass.getResource(tif).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "max") -// .option("tileSize", "10") -// .option("kRingInterpolate", "3") -// .load(filePath) -// .select("measure") -// .take(1) -// -// } -// -// test("Read zarr with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, JTS) -// -// val zarr = "/binary/zarr-example/" -// val filePath = getClass.getResource(zarr).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "median") -// .option("vsizip", "true") -// .option("tileSize", "10") -// .load(filePath) -// .select("measure") -// .take(1) -// -// 
noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "count") -// .option("vsizip", "true") -// .load(filePath) -// .select("measure") -// .take(1) -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "average") -// .option("vsizip", "true") -// .load(filePath) -// .select("measure") -// .take(1) -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "avg") -// .option("vsizip", "true") -// .load(filePath) -// .select("measure") -// .take(1) -// -// val paths = Files.list(Paths.get(filePath)).toArray.map(_.toString) -// -// an[Error] should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "count_+") -// .option("vsizip", "true") -// .load(paths: _*) -// .select("measure") -// .take(1) -// -// an[Error] should be thrownBy MosaicContext.read -// .format("invalid") -// .load(paths: _*) -// -// an[Error] should be thrownBy MosaicContext.read -// .format("invalid") -// .load(filePath) -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("kRingInterpolate", "3") -// .load(filePath) -// -// } + test("Read netcdf with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val netcdf = "/binary/netcdf-coral/" + val filePath = getClass.getResource(netcdf).getPath + + //noException should be thrownBy + + + MosaicContext.read + .format("raster_to_grid") + .option("retile", "true") + .option("tileSize", "10") + .option("readSubdataset", "true") + .option("subdataset", "1") + .option("kRingInterpolate", "3") + .load(filePath) + .select("measure") + .queryExecution + .executedPlan + + } + + test("Read grib with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val grib = "/binary/grib-cams/" + 
val filePath = getClass.getResource(grib).getPath + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("extensions", "grib") + .option("combiner", "min") + .option("retile", "true") + .option("tileSize", "10") + .option("kRingInterpolate", "3") + .load(filePath) + .select("measure") + .take(1) + + } + + test("Read tif with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val tif = "/modis/" + val filePath = getClass.getResource(tif).getPath + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("combiner", "max") + .option("tileSize", "10") + .option("kRingInterpolate", "3") + .load(filePath) + .select("measure") + .take(1) + + } + + test("Read zarr with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val zarr = "/binary/zarr-example/" + val filePath = getClass.getResource(zarr).getPath + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "median") + .option("vsizip", "true") + .option("tileSize", "10") + .load(filePath) + .select("measure") + .take(1) + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "count") + .option("vsizip", "true") + .load(filePath) + .select("measure") + .take(1) + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "average") + .option("vsizip", "true") + .load(filePath) + .select("measure") + .take(1) + + noException should be thrownBy MosaicContext.read + 
.format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "avg") + .option("vsizip", "true") + .load(filePath) + .select("measure") + .take(1) + + val paths = Files.list(Paths.get(filePath)).toArray.map(_.toString) + + an[Error] should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("combiner", "count_+") + .option("vsizip", "true") + .load(paths: _*) + .select("measure") + .take(1) + + an[Error] should be thrownBy MosaicContext.read + .format("invalid") + .load(paths: _*) + + an[Error] should be thrownBy MosaicContext.read + .format("invalid") + .load(filePath) + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("kRingInterpolate", "3") + .load(filePath) + + } } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala index 8ce57f5b8..611bf8f77 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala @@ -35,7 +35,9 @@ trait RST_CombineAvgBehaviors extends QueryTest { rastersInMemory.union(rastersInMemory) .createOrReplaceTempView("source") - noException should be thrownBy spark.sql(""" + //noException should be thrownBy + + spark.sql(""" |select rst_combineavg(collect_set(tiles)) as tiles |from ( | select path, rst_tessellate(tile, 2) as tiles diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala new file mode 100644 index 000000000..d06923dc1 --- /dev/null +++ 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala @@ -0,0 +1,36 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_FilterBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("result", rst_filter($"tile", 3, "mode")) + .select("result") + .collect() + + gridTiles.length should be(7) + + rastersInMemory.createOrReplaceTempView("source") + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala new file mode 100644 index 000000000..a243f7168 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_FilterTest extends QueryTest with SharedSparkSessionGDAL with RST_FilterBehaviors { + + private 
val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Filter with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala index bd867ee65..d01f79fec 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala @@ -37,7 +37,7 @@ trait RST_MinBehaviors extends QueryTest { val result = df.as[Double].collect().min - result < 0 shouldBe true + result == 0 shouldBe true an[Exception] should be thrownBy spark.sql(""" |select rst_min() from source diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala index c346e82db..cfffd9e6b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ trait RST_TessellateBehaviors extends 
QueryTest { @@ -24,8 +25,10 @@ trait RST_TessellateBehaviors extends QueryTest { val gridTiles = rastersInMemory .withColumn("tiles", rst_tessellate($"tile", 3)) - .select("tiles") - + .withColumn("bbox", st_aswkt(rst_boundingbox($"tile"))) + .select("bbox", "path", "tiles") + .withColumn("avg", rst_avg($"tiles")) + rastersInMemory .createOrReplaceTempView("source") @@ -37,9 +40,9 @@ trait RST_TessellateBehaviors extends QueryTest { .withColumn("tiles", rst_tessellate($"tile", 3)) .select("tiles") - val result = gridTiles.collect() + val result = gridTiles.select(explode(col("avg")).alias("a")).groupBy("a").count().collect() - result.length should be(380) + result.length should be(441) } diff --git a/src/test/scala/com/databricks/labs/mosaic/test/package.scala b/src/test/scala/com/databricks/labs/mosaic/test/package.scala index 435ee552c..2e5951be7 100644 --- a/src/test/scala/com/databricks/labs/mosaic/test/package.scala +++ b/src/test/scala/com/databricks/labs/mosaic/test/package.scala @@ -164,7 +164,7 @@ package object test { } val geotiffBytes: Array[Byte] = fileBytes("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") val gribBytes: Array[Byte] = - fileBytes("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + fileBytes("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") val netcdfBytes: Array[Byte] = fileBytes("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") def polyDf(sparkSession: SparkSession, mosaicContext: MosaicContext): DataFrame = { diff --git a/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala index 8029c30a7..84a613b31 100644 --- a/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala +++ b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala @@ -8,13 +8,13 @@ class MosaicTestSparkSession(sc: 
SparkContext) extends TestSparkSession(sc) { this( new SparkContext( - "local[4]", + "local[8]", "test-sql-context", sparkConf .set("spark.sql.adaptive.enabled", "false") .set("spark.driver.memory", "32g") .set("spark.executor.memory", "32g") - .set("spark.sql.shuffle.partitions", "4") + .set("spark.sql.shuffle.partitions", "8") .set("spark.sql.testkey", "true") ) ) diff --git a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala index 984fff9d8..12dcac6f3 100644 --- a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala +++ b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala @@ -1,6 +1,5 @@ package org.apache.spark.sql.test -import com.databricks.labs.mosaic._ import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.utils.FileUtils import com.databricks.labs.mosaic.{MOSAIC_GDAL_NATIVE, MOSAIC_RASTER_CHECKPOINT} @@ -8,7 +7,6 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal -import java.nio.file.{Files, Paths} import scala.util.Try trait SharedSparkSessionGDAL extends SharedSparkSession { @@ -20,10 +18,10 @@ trait SharedSparkSessionGDAL extends SharedSparkSession { override def createSparkSession: TestSparkSession = { val conf = sparkConf - conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir()) + conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/mnt/")) SparkSession.cleanupAnyExistingSession() val session = new MosaicTestSparkSession(conf) - session.sparkContext.setLogLevel("INFO") + session.sparkContext.setLogLevel("FATAL") Try { MosaicGDAL.enableGDAL(session) }