diff --git a/core/pom.xml b/core/pom.xml
index f66fecd7..53a1d920 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -139,6 +139,11 @@
             <artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
             <optional>true</optional>
         </dependency>
+        <dependency>
+            <groupId>net.snowflake</groupId>
+            <artifactId>spark-snowflake_${scala.binary.version}</artifactId>
+            <optional>true</optional>
+        </dependency>
         <dependency>
             <groupId>org.elasticsearch</groupId>
             <artifactId>elasticsearch-hadoop</artifactId>
diff --git a/core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala b/core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala
new file mode 100644
index 00000000..9c070e3a
--- /dev/null
+++ b/core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spline.harvester.plugin.embedded
+
+import za.co.absa.spline.commons.reflect.ReflectionUtils.extractValue
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
+import org.apache.spark.sql.sources.BaseRelation
+import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
+import za.co.absa.spline.harvester.builder.SourceIdentifier
+import za.co.absa.spline.harvester.plugin.Plugin.{Precedence, ReadNodeInfo, WriteNodeInfo}
+import za.co.absa.spline.harvester.plugin.embedded.SnowflakePlugin._
+import za.co.absa.spline.harvester.plugin.{BaseRelationProcessing, Plugin, RelationProviderProcessing}
+
+import javax.annotation.Priority
+import scala.language.reflectiveCalls
+
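+/**
+ * Captures lineage metadata for reads from and writes to Snowflake performed
+ * through the `spark-snowflake` connector. Reads are recognized by matching the
+ * relation's runtime type against `net.snowflake.spark.snowflake.SnowflakeRelation`;
+ * writes by matching the relation provider of a `SaveIntoDataSourceCommand`.
+ */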
+@Priority(Precedence.Normal)
+class SnowflakePlugin(spark: SparkSession)
+ extends Plugin
+ with BaseRelationProcessing
+ with RelationProviderProcessing {
+
+ import za.co.absa.spline.commons.ExtractorImplicits._
+
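+ // Read side: a Snowflake scan surfaces as a LogicalRelation wrapping a SnowflakeRelation;
+ // the connection coordinates are read reflectively from its private `params` field.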
+ override def baseRelationProcessor: PartialFunction[(BaseRelation, LogicalRelation), ReadNodeInfo] = {
+ case (`_: SnowflakeRelation`(r), _) =>
+ val params = extractValue[net.snowflake.spark.snowflake.Parameters.MergedParameters](r, "params")
+
+ val url: String = params.sfURL
+ val warehouse: String = params.sfWarehouse.getOrElse("")
+ val database: String = params.sfDatabase
+ val schema: String = params.sfSchema
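+ // params.table is an Option of the connector's table wrapper, hence getOrElse("") and toString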
+ val table: String = params.table.getOrElse("").toString
+
+ ReadNodeInfo(asSourceId(url, warehouse, database, schema, table), Map.empty)
+ }
+
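+ // Write side: SaveIntoDataSourceCommand carries the relation provider and the options map
+ // given to the DataFrameWriter; the same sf* options are used to rebuild the source URI.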
+ override def relationProviderProcessor: PartialFunction[(AnyRef, SaveIntoDataSourceCommand), WriteNodeInfo] = {
+ case (rp, cmd) if rp == "net.snowflake.spark.snowflake.DefaultSource" || SnowflakeSourceExtractor.matches(rp) =>
+ val url: String = cmd.options("sfUrl")
+ val warehouse: String = cmd.options("sfWarehouse")
+ val database: String = cmd.options("sfDatabase")
+ val schema: String = cmd.options("sfSchema")
+ val table: String = cmd.options("dbtable")
+
+ WriteNodeInfo(asSourceId(url, warehouse, database, schema, table), cmd.mode, cmd.query, cmd.options)
+ }
+}
+
+object SnowflakePlugin {
+
+ private object `_: SnowflakeRelation` extends SafeTypeMatchingExtractor[AnyRef]("net.snowflake.spark.snowflake.SnowflakeRelation")
+
+ private object SnowflakeSourceExtractor extends SafeTypeMatchingExtractor(classOf[net.snowflake.spark.snowflake.DefaultSource])
+
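+ // Example with hypothetical values: asSourceId("acme.snowflakecomputing.com", "WH", "DB", "PUBLIC", "T1")
+ // => SourceIdentifier(Some("snowflake"), "snowflake://acme.snowflakecomputing.com.WH.DB.PUBLIC.T1")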
+ private def asSourceId(url: String, warehouse: String, database: String, schema: String, table: String) =
+ SourceIdentifier(Some("snowflake"), s"snowflake://$url.$warehouse.$database.$schema.$table")
+}
diff --git a/core/src/test/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePluginSpec.scala b/core/src/test/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePluginSpec.scala
new file mode 100644
index 00000000..7c39b518
--- /dev/null
+++ b/core/src/test/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePluginSpec.scala
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spline.harvester.plugin.embedded
+
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand
+import org.mockito.Mockito.when
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar
+import za.co.absa.spline.harvester.builder.SourceIdentifier
+import za.co.absa.spline.harvester.plugin.Plugin.WriteNodeInfo
+
+class SnowflakePluginSpec extends AnyFlatSpec with Matchers with MockitoSugar {
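+ // The plugin's partial functions are invoked directly; no Snowflake connection is involved.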
+ "SnowflakePlugin" should "process Snowflake relation providers" in {
+ // Setup
+ val spark = mock[SparkSession]
+ val plugin = new SnowflakePlugin(spark)
+
+ val options = Map(
+ "sfUrl" -> "test-url",
+ "sfWarehouse" -> "test-warehouse",
+ "sfDatabase" -> "test-database",
+ "sfSchema" -> "test-schema",
+ "sfUser" -> "user1",
+ "dbtable" -> "test-table"
+ )
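+ // These keys are exactly what SnowflakePlugin.relationProviderProcessor reads from cmd.options;
+ // "sfUser" is not used when building the source URI.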
+
+ val cmd = mock[SaveIntoDataSourceCommand]
+ when(cmd.options).thenReturn(options)
+ when(cmd.mode).thenReturn(SaveMode.Overwrite)
+ when(cmd.query).thenReturn(null)
+
+ // Mocking the relation provider to be Snowflake
+ val snowflakeRP = "net.snowflake.spark.snowflake.DefaultSource"
+
+ // Execute
+ val result = plugin.relationProviderProcessor((snowflakeRP, cmd))
+
+ // Verify
+ val expectedSourceId = SourceIdentifier(Some("snowflake"), "snowflake://test-url.test-warehouse.test-database.test-schema.test-table")
+ result shouldEqual WriteNodeInfo(expectedSourceId, SaveMode.Overwrite, null, options)
+ }
+
+ it should "not process non-Snowflake relation providers" in {
+ // Setup
+ val spark = mock[SparkSession]
+ val plugin = new SnowflakePlugin(spark)
+
+ val cmd = mock[SaveIntoDataSourceCommand]
+
+ // Mocking the relation provider to be non-Snowflake
+ val nonSnowflakeRP = "some.other.datasource"
+
+ // Execute & Verify
+ assertThrows[MatchError] {
+ plugin.relationProviderProcessor((nonSnowflakeRP, cmd))
+ }
+ }
+}
diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml
index ec2b15bf..158ac6b1 100644
--- a/integration-tests/pom.xml
+++ b/integration-tests/pom.xml
@@ -131,6 +131,11 @@
             <artifactId>spark-cobol_${scala.binary.version}</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>net.snowflake</groupId>
+            <artifactId>spark-snowflake_${scala.binary.version}</artifactId>
+            <optional>true</optional>
+        </dependency>
@@ -163,6 +168,12 @@
             <version>${testcontainers.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>localstack</artifactId>
+            <version>1.19.8</version>
+            <scope>test</scope>
+        </dependency>
         <dependency>
             <groupId>org.elasticsearch</groupId>
diff --git a/integration-tests/src/test/scala/za/co/absa/spline/SnowflakeSpec.scala b/integration-tests/src/test/scala/za/co/absa/spline/SnowflakeSpec.scala
new file mode 100644
index 00000000..157e8b6f
--- /dev/null
+++ b/integration-tests/src/test/scala/za/co/absa/spline/SnowflakeSpec.scala
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2019 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spline
+
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.{Row, RowFactory}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.flatspec.AsyncFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.testcontainers.containers.GenericContainer
+import org.testcontainers.containers.wait.strategy.Wait
+import org.testcontainers.utility.DockerImageName
+import za.co.absa.spline.commons.io.TempDirectory
+import za.co.absa.spline.test.fixture.spline.SplineFixture
+import za.co.absa.spline.test.fixture.{ReleasableResourceFixture, SparkFixture}
+
+import java.util
+
+class SnowflakeSpec
+ extends AsyncFlatSpec
+ with BeforeAndAfterAll
+ with Matchers
+ with SparkFixture
+ with SplineFixture
+ with ReleasableResourceFixture {
+
+ val tableName = "testTable"
+ val schemaName = "testSchema"
+ val warehouseName = "testWarehouse"
+ val databaseName = "test"
+ val sparkFormat = "net.snowflake.spark.snowflake"
+
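+ // Runs against the LocalStack Snowflake emulator image; its endpoint
+ // snowflake.localhost.localstack.cloud resolves to the local container.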
+ it should "support snowflake as a read and write source" in {
+ usingResource(new GenericContainer(DockerImageName.parse("localstack/snowflake"))) { container =>
+ // Register the wait strategy before starting; a bare `Wait.forHealthcheck` expression is a no-op.
+ // (Assumes the localstack/snowflake image defines a Docker healthcheck.)
+ container.waitingFor(Wait.forHealthcheck())
+ container.start()
+
+ val sfURL = "snowflake.localhost.localstack.cloud"
+
+ withNewSparkSession { implicit spark =>
+
+ withLineageTracking { captor =>
+ val sfOptions = Map(
+ "sfURL" -> "snowflake.localhost.localstack.cloud",
+ "sfUser" -> "test",
+ "sfPassword" -> "test",
+ "sfDatabase" -> databaseName,
+ "sfWarehouse" -> warehouseName,
+ "sfSchema" -> schemaName
+ )
+
+ // Define the test data as a Java list of Rows
+ val data = new util.ArrayList[Row]()
+ data.add(RowFactory.create(1.asInstanceOf[Object]))
+ data.add(RowFactory.create(2.asInstanceOf[Object]))
+ data.add(RowFactory.create(3.asInstanceOf[Object]))
+
+ // Build the DataFrame with an explicit single-column schema ("id" is an arbitrary column name);
+ // createDataFrame(data, classOf[Row]) would treat Row as a Java bean and yield an empty schema
+ val schema = StructType(Seq(StructField("id", IntegerType, nullable = false)))
+ val testData = spark.sqlContext.createDataFrame(data, schema)
+
+ for {
+ (writePlan, _) <- captor.lineageOf(
+ testData.write
+ .format(sparkFormat)
+ .options(sfOptions)
+ .option("dbtable", tableName)
+ .mode("overwrite")
+ .save()
+ )
+
+ (readPlan, _) <- captor.lineageOf {
+ val df = spark.read.format(sparkFormat)
+ .options(sfOptions)
+ .option("dbtable", tableName) // specify the source table
+ .load()
+
+ df.write.save(TempDirectory(pathOnly = true).deleteOnExit().path.toString)
+ }
+ } yield {
+ writePlan.operations.write.append shouldBe false
+ writePlan.operations.write.extra("destinationType") shouldBe Some("snowflake")
+ writePlan.operations.write.outputSource shouldBe s"snowflake://$sfURL.$warehouseName.$databaseName.$schemaName.$tableName"
+
+ readPlan.operations.reads.head.inputSources.head shouldBe writePlan.operations.write.outputSource
+ readPlan.operations.reads.head.extra("sourceType") shouldBe Some("snowflake")
+ readPlan.operations.write.append shouldBe false
+ }
+ }
+ }
+ }
+ }
+
+}
diff --git a/pom.xml b/pom.xml
index db6717f8..86681fde 100644
--- a/pom.xml
+++ b/pom.xml
@@ -83,7 +83,7 @@
-        <spark.version>${spark-24.version}</spark.version>
+        <spark.version>${spark-33.version}</spark.version>
         <spark-22.version>2.2.3</spark-22.version>
         <spark-23.version>2.3.4</spark-23.version>
         <spark-24.version>2.4.8</spark-24.version>
@@ -472,6 +472,11 @@
                 <artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
                 <version>2.4.1</version>
             </dependency>
+            <dependency>
+                <groupId>net.snowflake</groupId>
+                <artifactId>spark-snowflake_${scala.binary.version}</artifactId>
+                <version>2.16.0-spark_3.3</version>
+            </dependency>
             <dependency>
                 <groupId>org.elasticsearch</groupId>
                 <artifactId>elasticsearch-hadoop</artifactId>