From fa2fc5cce6763122dc901c74ad12b6544c0b6a76 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Wed, 17 Jan 2024 14:02:15 +0000 Subject: [PATCH] Add tests for RST_stats expressions. --- .../mosaic/expressions/raster/RST_Avg.scala | 2 +- .../expressions/raster/RST_AvgBehaviors.scala | 48 +++++++++++++++++++ .../expressions/raster/RST_AvgTest.scala | 32 +++++++++++++ .../expressions/raster/RST_MaxBehaviors.scala | 4 -- .../expressions/raster/RST_MinBehaviors.scala | 48 +++++++++++++++++++ .../expressions/raster/RST_MinTest.scala | 32 +++++++++++++ 6 files changed, 161 insertions(+), 5 deletions(-) create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala index be82af449..82752cad4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala @@ -39,7 +39,7 @@ case class RST_Avg(raster: Expression, expressionConfig: MosaicExpressionConfig) /** Expression info required for the expression registration for spark SQL. */ object RST_Avg extends WithExpressionInfo { - override def name: String = "rst_mean" + override def name: String = "rst_avg" override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala new file mode 100644 index 000000000..f01ce2d25 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_AvgBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_avg($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_avg(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_avg() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala new file mode 100644 index 000000000..6805f0723 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_AvgTest extends QueryTest with SharedSparkSessionGDAL with RST_AvgBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_avg behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala index 9c095488d..daab1ee90 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala @@ -35,10 +35,6 @@ trait RST_MaxBehaviors extends QueryTest { |select rst_max(tile) from source |""".stripMargin) - noException should be thrownBy rastersInMemory - .withColumn("result", rst_rastertogridmax($"tile", lit(3))) - .select("result") - val result = df.as[Double].collect().max result > 0 shouldBe true diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala new file mode 100644 index 000000000..bd867ee65 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MinBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_min($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_min(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().min + + result < 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_min() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala new file mode 100644 index 000000000..ec09792f9 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MinTest extends QueryTest with SharedSparkSessionGDAL with RST_MinBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_min behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +}