
Commit 4495e93
Add Snowflake plugin with integration test
sethj-pantomath committed Jul 17, 2024
1 parent a6f63b1 commit 4495e93
Showing 6 changed files with 279 additions and 1 deletion.
5 changes: 5 additions & 0 deletions core/pom.xml
@@ -139,6 +139,11 @@
<artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>spark-snowflake_${scala.binary.version}</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch-hadoop</artifactId>
@@ -0,0 +1,72 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.spline.harvester.plugin.embedded

import za.co.absa.spline.commons.reflect.ReflectionUtils.extractValue
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.harvester.builder.SourceIdentifier
import za.co.absa.spline.harvester.plugin.Plugin.{Precedence, ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.plugin.embedded.SnowflakePlugin._
import za.co.absa.spline.harvester.plugin.{BaseRelationProcessing, Plugin, RelationProviderProcessing}

import javax.annotation.Priority
import scala.language.reflectiveCalls

@Priority(Precedence.Normal)
class SnowflakePlugin(spark: SparkSession)
  extends Plugin
    with BaseRelationProcessing
    with RelationProviderProcessing {

  import za.co.absa.spline.commons.ExtractorImplicits._

  // Read side: matches Snowflake's SnowflakeRelation and extracts its merged parameters via reflection.
  override def baseRelationProcessor: PartialFunction[(BaseRelation, LogicalRelation), ReadNodeInfo] = {
    case (`_: SnowflakeRelation`(r), _) =>
      val params = extractValue[net.snowflake.spark.snowflake.Parameters.MergedParameters](r, "params")

      val url: String = params.sfURL
      val warehouse: String = params.sfWarehouse.getOrElse("")
      val database: String = params.sfDatabase
      val schema: String = params.sfSchema
      val table: String = params.table.getOrElse("").toString

      ReadNodeInfo(asSourceId(url, warehouse, database, schema, table), Map.empty)
  }

  // Write side: matches the Snowflake relation provider and reads the connection options off the command.
  override def relationProviderProcessor: PartialFunction[(AnyRef, SaveIntoDataSourceCommand), WriteNodeInfo] = {
    case (rp, cmd) if rp == "net.snowflake.spark.snowflake.DefaultSource" || SnowflakeSourceExtractor.matches(rp) =>
      val url: String = cmd.options("sfUrl")
      val warehouse: String = cmd.options("sfWarehouse")
      val database: String = cmd.options("sfDatabase")
      val schema: String = cmd.options("sfSchema")
      val table: String = cmd.options("dbtable")

      WriteNodeInfo(asSourceId(url, warehouse, database, schema, table), cmd.mode, cmd.query, cmd.options)
  }
}

object SnowflakePlugin {

  private object `_: SnowflakeRelation` extends SafeTypeMatchingExtractor[AnyRef]("net.snowflake.spark.snowflake.SnowflakeRelation")

  private object SnowflakeSourceExtractor extends SafeTypeMatchingExtractor(classOf[net.snowflake.spark.snowflake.DefaultSource])

  private def asSourceId(url: String, warehouse: String, database: String, schema: String, table: String) =
    SourceIdentifier(Some("snowflake"), s"snowflake://$url.$warehouse.$database.$schema.$table")
}
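
Note (editor's sketch, not part of the commit): the write path above derives the Spline source identifier purely from the connector options carried by SaveIntoDataSourceCommand. The following self-contained Scala snippet reproduces that URI composition with placeholder option values, mirroring what asSourceId produces:

// Sketch only: same option keys as relationProviderProcessor, placeholder values.
object SnowflakeSourceIdSketch extends App {
  val options = Map(
    "sfUrl" -> "acme.snowflakecomputing.com",
    "sfWarehouse" -> "WH",
    "sfDatabase" -> "DB",
    "sfSchema" -> "PUBLIC",
    "dbtable" -> "ORDERS"
  )

  // Mirrors asSourceId: "snowflake://<url>.<warehouse>.<database>.<schema>.<table>"
  val uri = s"snowflake://${options("sfUrl")}.${options("sfWarehouse")}." +
    s"${options("sfDatabase")}.${options("sfSchema")}.${options("dbtable")}"

  println(uri) // snowflake://acme.snowflakecomputing.com.WH.DB.PUBLIC.ORDERS
}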
@@ -0,0 +1,79 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import za.co.absa.spline.harvester.plugin.Plugin.{ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.builder.SourceIdentifier
import org.mockito.Mockito.{mock, _}
import net.snowflake.spark.snowflake.Parameters
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.commons.reflect.{ReflectionUtils, ValueExtractor}

class SnowflakePluginSpec extends AnyFlatSpec with Matchers with MockitoSugar {

  "SnowflakePlugin" should "process Snowflake relation providers" in {
    // Setup
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    val options = Map(
      "sfUrl" -> "test-url",
      "sfWarehouse" -> "test-warehouse",
      "sfDatabase" -> "test-database",
      "sfSchema" -> "test-schema",
      "sfUser" -> "user1",
      "dbtable" -> "test-table"
    )

    val cmd = mock[SaveIntoDataSourceCommand]
    when(cmd.options) thenReturn (options)
    when(cmd.mode) thenReturn (SaveMode.Overwrite)
    when(cmd.query) thenReturn (null)

    // Mocking the relation provider to be Snowflake
    val snowflakeRP = "net.snowflake.spark.snowflake.DefaultSource"

    // Execute
    val result = plugin.relationProviderProcessor((snowflakeRP, cmd))

    // Verify
    val expectedSourceId = SourceIdentifier(Some("snowflake"), "snowflake://test-url.test-warehouse.test-database.test-schema.test-table")
    result shouldEqual WriteNodeInfo(expectedSourceId, SaveMode.Overwrite, null, options)
  }

  it should "not process non-Snowflake relation providers" in {
    // Setup
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    val cmd = mock[SaveIntoDataSourceCommand]

    // Mocking the relation provider to be non-Snowflake
    val nonSnowflakeRP = "some.other.datasource"

    // Execute & Verify
    assertThrows[MatchError] {
      plugin.relationProviderProcessor((nonSnowflakeRP, cmd))
    }
  }
}
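
A symmetric negative case for the read side could be added to the same spec. The sketch below is an assumption, not part of the commit: it mocks a non-Snowflake BaseRelation and LogicalRelation (both already in this file's imports) and expects the read-side partial function to reject the pair:

  it should "not process non-Snowflake base relations" in {
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    // Any relation that is not a net.snowflake.spark.snowflake.SnowflakeRelation
    val relation = mock[BaseRelation]
    val logicalRelation = mock[LogicalRelation]

    // The SnowflakeRelation extractor should not match, so applying the partial function fails
    assertThrows[MatchError] {
      plugin.baseRelationProcessor((relation, logicalRelation))
    }
  }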
11 changes: 11 additions & 0 deletions integration-tests/pom.xml
@@ -131,6 +131,11 @@
<artifactId>spark-cobol_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>spark-snowflake_${scala.binary.version}</artifactId>
<optional>true</optional>
</dependency>

<!-- to force newer version of jackson-annotations - needed for testcontainers -->
<dependency>
@@ -163,6 +168,12 @@
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
<version>1.19.8</version>
<scope>test</scope>
</dependency>

<!-- required for spark-cassandra-connector -->
<dependency>
106 changes: 106 additions & 0 deletions integration-tests/src/test/scala/za/co/absa/spline/SnowflakeSpec.scala
@@ -0,0 +1,106 @@
/*
* Copyright 2019 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.spline

import org.apache.spark.sql.{Row, RowFactory}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.Wait
import org.testcontainers.utility.DockerImageName
import za.co.absa.spline.commons.io.TempDirectory
import za.co.absa.spline.test.fixture.spline.SplineFixture
import za.co.absa.spline.test.fixture.{ReleasableResourceFixture, SparkFixture}

import java.util

class SnowflakeSpec
  extends AsyncFlatSpec
    with BeforeAndAfterAll
    with Matchers
    with SparkFixture
    with SplineFixture
    with ReleasableResourceFixture {

  val tableName = "testTable"
  val schemaName = "testSchema"
  val warehouseName = "testWarehouse"
  val databaseName = "test"
  val sparkFormat = "net.snowflake.spark.snowflake"

  it should "support snowflake as a read and write source" in {
    usingResource(new GenericContainer(DockerImageName.parse("localstack/snowflake"))) { container =>
      container.start()
      Wait.forHealthcheck

      val host = container.getHost

      withNewSparkSession { implicit spark =>

        withLineageTracking { captor =>
          val sfOptions = Map(
            "sfURL" -> "snowflake.localhost.localstack.cloud",
            "sfUser" -> "test",
            "sfPassword" -> "test",
            "sfDatabase" -> databaseName,
            "sfWarehouse" -> warehouseName,
            "sfSchema" -> schemaName
          )

          // Define the test data as a Java list of rows
          val data = new util.ArrayList[Row]()
          data.add(RowFactory.create(1.asInstanceOf[Object]))
          data.add(RowFactory.create(2.asInstanceOf[Object]))
          data.add(RowFactory.create(3.asInstanceOf[Object]))

          // Create a DataFrame from the test data
          val testData = spark.sqlContext.createDataFrame(data, classOf[Row])

          for {
            (writePlan, _) <- captor.lineageOf(
              testData.write
                .format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName)
                .mode("overwrite")
                .save()
            )

            (readPlan, _) <- captor.lineageOf {
              val df = spark.read.format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName) // specify the source table
                .load()

              df.write.save(TempDirectory(pathOnly = true).deleteOnExit().path.toString)
            }
          } yield {
            writePlan.operations.write.append shouldBe false
            writePlan.operations.write.extra("destinationType") shouldBe Some("snowflake")
            writePlan.operations.write.outputSource shouldBe s"snowflake://$host.$warehouseName.$databaseName.$schemaName.$tableName"

            readPlan.operations.reads.head.inputSources.head shouldBe writePlan.operations.write.outputSource
            readPlan.operations.reads.head.extra("sourceType") shouldBe Some("snowflake")
            readPlan.operations.write.append shouldBe false
          }
        }
      }
    }
  }

}
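
Stripped of the Spline test fixtures, the Spark operations this integration test captures lineage for amount to a plain connector round trip. Below is a minimal sketch (not part of the commit), assuming a reachable Snowflake or localstack endpoint and the same placeholder credentials used above:

import org.apache.spark.sql.SparkSession

object SnowflakeRoundTripSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("snowflake-sketch").getOrCreate()
  import spark.implicits._

  // Connection options; keys mirror sfOptions in SnowflakeSpec, values are placeholders.
  val sfOptions = Map(
    "sfURL" -> "snowflake.localhost.localstack.cloud",
    "sfUser" -> "test",
    "sfPassword" -> "test",
    "sfDatabase" -> "test",
    "sfWarehouse" -> "testWarehouse",
    "sfSchema" -> "testSchema"
  )

  // Write a small DataFrame to a Snowflake table...
  Seq(1, 2, 3).toDF("id")
    .write
    .format("net.snowflake.spark.snowflake")
    .options(sfOptions)
    .option("dbtable", "testTable")
    .mode("overwrite")
    .save()

  // ...then read it back from the same table.
  spark.read
    .format("net.snowflake.spark.snowflake")
    .options(sfOptions)
    .option("dbtable", "testTable")
    .load()
    .show()

  spark.stop()
}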
7 changes: 6 additions & 1 deletion pom.xml
@@ -83,7 +83,7 @@

<!-- Spark -->

- <spark.version>${spark-24.version}</spark.version>
+ <spark.version>${spark-33.version}</spark.version>
<spark-22.version>2.2.3</spark-22.version>
<spark-23.version>2.3.4</spark-23.version>
<spark-24.version>2.4.8</spark-24.version>
@@ -452,6 +452,11 @@
<artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
<version>2.4.1</version>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>spark-snowflake_${scala.binary.version}</artifactId>
<version>2.16.0-spark_3.3</version>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch-hadoop</artifactId>
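For projects consuming the agent through sbt instead of Maven, the managed spark-snowflake version added above would translate roughly to the following coordinate (a sketch; it assumes the Scala binary version matches the Spark 3.3 build):

// build.sbt sketch: %% appends the Scala binary version, like ${scala.binary.version} in the POM
libraryDependencies += "net.snowflake" %% "spark-snowflake" % "2.16.0-spark_3.3"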
