diff --git a/core/pom.xml b/core/pom.xml index f66fecd7..53a1d920 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -139,6 +139,11 @@ mongo-spark-connector_${scala.binary.version} true + + net.snowflake + spark-snowflake_${scala.binary.version} + true + org.elasticsearch elasticsearch-hadoop diff --git a/core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala b/core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala new file mode 100644 index 00000000..9c070e3a --- /dev/null +++ b/core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala @@ -0,0 +1,72 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package za.co.absa.spline.harvester.plugin.embedded

import za.co.absa.spline.commons.reflect.ReflectionUtils.extractValue
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.harvester.builder.SourceIdentifier
import za.co.absa.spline.harvester.plugin.Plugin.{Precedence, ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.plugin.embedded.SnowflakePlugin._
import za.co.absa.spline.harvester.plugin.{BaseRelationProcessing, Plugin, RelationProviderProcessing}

import javax.annotation.Priority
import scala.language.reflectiveCalls

/**
 * Harvester plugin that recognises Snowflake reads and writes and describes them with a
 * `snowflake://<url>.<warehouse>.<database>.<schema>.<table>` source identifier.
 *
 * Reads are matched on the Snowflake relation type; writes are matched on the Snowflake
 * relation provider (either its class name given as a string or an instance of it).
 *
 * @param spark the active Spark session (not used by this plugin's processors)
 */
@Priority(Precedence.Normal)
class SnowflakePlugin(spark: SparkSession)
  extends Plugin
    with BaseRelationProcessing
    with RelationProviderProcessing {

  import za.co.absa.spline.commons.ExtractorImplicits._

  /**
   * Handles reads: pulls the merged connector parameters out of the relation via reflection
   * and turns them into a source identifier. Warehouse and table are optional on the
   * connector side and default to an empty string.
   */
  override def baseRelationProcessor: PartialFunction[(BaseRelation, LogicalRelation), ReadNodeInfo] = {
    case (`_: SnowflakeRelation`(r), _) =>
      val params = extractValue[net.snowflake.spark.snowflake.Parameters.MergedParameters](r, "params")

      val url: String = params.sfURL
      val warehouse: String = params.sfWarehouse.getOrElse("")
      val database: String = params.sfDatabase
      val schema: String = params.sfSchema
      val table: String = params.table.map(_.toString).getOrElse("")

      ReadNodeInfo(asSourceId(url, warehouse, database, schema, table), Map.empty)
  }

  /**
   * Handles writes: builds the source identifier from the options of the save command.
   *
   * `sfUrl`, `sfDatabase`, `sfSchema` and `dbtable` are looked up directly, so a missing key
   * raises `NoSuchElementException`. `sfWarehouse` falls back to an empty string, mirroring
   * the read path, which already treats the warehouse as optional.
   */
  override def relationProviderProcessor: PartialFunction[(AnyRef, SaveIntoDataSourceCommand), WriteNodeInfo] = {
    case (rp, cmd) if rp == "net.snowflake.spark.snowflake.DefaultSource" || SnowflakeSourceExtractor.matches(rp) =>
      val options = cmd.options

      val url: String = options("sfUrl")
      // The warehouse is optional: do not fail lineage capture when it is not supplied.
      val warehouse: String = options.getOrElse("sfWarehouse", "")
      val database: String = options("sfDatabase")
      val schema: String = options("sfSchema")
      val table: String = options("dbtable")

      WriteNodeInfo(asSourceId(url, warehouse, database, schema, table), cmd.mode, cmd.query, options)
  }
}

object SnowflakePlugin {

  // Matches the Snowflake relation by class name, so the class is only resolved reflectively.
  private object `_: SnowflakeRelation` extends SafeTypeMatchingExtractor[AnyRef]("net.snowflake.spark.snowflake.SnowflakeRelation")

  // Matches instances of the Snowflake relation provider.
  private object SnowflakeSourceExtractor extends SafeTypeMatchingExtractor(classOf[net.snowflake.spark.snowflake.DefaultSource])

  /** Builds the Snowflake source identifier shared by the read and write paths. */
  private def asSourceId(url: String, warehouse: String, database: String, schema: String, table: String) =
    SourceIdentifier(Some("snowflake"), s"snowflake://$url.$warehouse.$database.$schema.$table")
}
package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import za.co.absa.spline.harvester.plugin.Plugin.{ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.builder.SourceIdentifier
import org.mockito.Mockito.{mock, _}
import net.snowflake.spark.snowflake.Parameters
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.commons.reflect.{ReflectionUtils, ValueExtractor}

/**
 * Unit tests for the write-side (relation provider) processing of [[SnowflakePlugin]].
 */
class SnowflakePluginSpec extends AnyFlatSpec with Matchers with MockitoSugar {

  /** Creates the plugin under test on top of a mocked Spark session. */
  private def newPlugin(): SnowflakePlugin = new SnowflakePlugin(mock[SparkSession])

  behavior of "SnowflakePlugin"

  it should "process Snowflake relation providers" in {
    val plugin = newPlugin()

    // Options the save command reports; "sfUser" is extra and not part of the identifier.
    val writeOptions = Map(
      "sfUrl" -> "test-url",
      "sfWarehouse" -> "test-warehouse",
      "sfDatabase" -> "test-database",
      "sfSchema" -> "test-schema",
      "sfUser" -> "user1",
      "dbtable" -> "test-table"
    )

    val saveCommand = mock[SaveIntoDataSourceCommand]
    when(saveCommand.options).thenReturn(writeOptions)
    when(saveCommand.mode).thenReturn(SaveMode.Overwrite)
    when(saveCommand.query).thenReturn(null)

    // The provider is passed by its class name, one of the two forms the plugin accepts.
    val provider = "net.snowflake.spark.snowflake.DefaultSource"

    val actual = plugin.relationProviderProcessor((provider, saveCommand))

    val expected = WriteNodeInfo(
      SourceIdentifier(Some("snowflake"), "snowflake://test-url.test-warehouse.test-database.test-schema.test-table"),
      SaveMode.Overwrite,
      null,
      writeOptions
    )
    actual shouldEqual expected
  }

  it should "not process non-Snowflake relation providers" in {
    val plugin = newPlugin()
    val saveCommand = mock[SaveIntoDataSourceCommand]

    // Any provider name other than the Snowflake one must fall outside the partial function.
    val provider = "some.other.datasource"

    a[MatchError] should be thrownBy {
      plugin.relationProviderProcessor((provider, saveCommand))
    }
  }
}
package za.co.absa.spline

import org.apache.spark.sql.{Row, RowFactory}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.Wait
import org.testcontainers.utility.DockerImageName
import za.co.absa.spline.commons.io.TempDirectory
import za.co.absa.spline.test.fixture.spline.SplineFixture
import za.co.absa.spline.test.fixture.{ReleasableResourceFixture, SparkFixture}

import java.util

/**
 * Integration test: writes a small DataFrame to Snowflake, reads it back, and checks that
 * the captured lineage identifies Snowflake as the write destination and the read source.
 */
class SnowflakeSpec
  extends AsyncFlatSpec
    with BeforeAndAfterAll
    with Matchers
    with SparkFixture
    with SplineFixture
    with ReleasableResourceFixture {

  val tableName = "testTable"
  val schemaName = "testSchema"
  val warehouseName = "testWarehouse"
  val databaseName = "test"
  val sparkFormat = "net.snowflake.spark.snowflake"

  it should "support snowflake as a read and write source" in {
    usingResource(new GenericContainer(DockerImageName.parse("localstack/snowflake"))) { container =>
      // The wait strategy has to be installed before start(); a bare `Wait.forHealthcheck`
      // expression only creates a strategy object and discards it.
      // NOTE(review): this relies on the image declaring a HEALTHCHECK - confirm.
      container.setWaitStrategy(Wait.forHealthcheck())
      container.start()

      // The URL handed to the connector; the plugin derives the source identifier from it.
      val sfUrl = "snowflake.localhost.localstack.cloud"

      withNewSparkSession { implicit spark =>

        withLineageTracking { captor =>
          import spark.implicits._

          val sfOptions = Map(
            "sfURL" -> sfUrl,
            "sfUser" -> "test",
            "sfPassword" -> "test",
            "sfDatabase" -> databaseName,
            "sfWarehouse" -> warehouseName,
            "sfSchema" -> schemaName
          )

          // A single-column frame with an explicit column name. Passing classOf[Row] to
          // createDataFrame would select the Java-bean overload, which is not meant for rows.
          val testData = Seq(1, 2, 3).toDF("id")

          for {
            (writePlan, _) <- captor.lineageOf(
              testData.write
                .format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName)
                .mode("overwrite")
                .save()
            )

            (readPlan, _) <- captor.lineageOf {
              val df = spark.read.format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName) // specify the source table
                .load()

              df.write.save(TempDirectory(pathOnly = true).deleteOnExit().path.toString)
            }
          } yield {
            writePlan.operations.write.append shouldBe false
            writePlan.operations.write.extra("destinationType") shouldBe Some("snowflake")
            // Expected identifier is built from the configured URL, not from the container host.
            writePlan.operations.write.outputSource shouldBe s"snowflake://$sfUrl.$warehouseName.$databaseName.$schemaName.$tableName"

            readPlan.operations.reads.head.inputSources.head shouldBe writePlan.operations.write.outputSource
            readPlan.operations.reads.head.extra("sourceType") shouldBe Some("snowflake")
            readPlan.operations.write.append shouldBe false
          }
        }
      }
    }
  }

}