Add Snowflake plugin with integration test
Commit 4495e93 (1 parent: a6f63b1)
Showing 6 changed files with 279 additions and 1 deletion.
core/src/main/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePlugin.scala (72 additions, 0 deletions)
@@ -0,0 +1,72 @@
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.ReflectionUtils.extractValue
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.harvester.builder.SourceIdentifier
import za.co.absa.spline.harvester.plugin.Plugin.{Precedence, ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.plugin.embedded.SnowflakePlugin._
import za.co.absa.spline.harvester.plugin.{BaseRelationProcessing, Plugin, RelationProviderProcessing}

import javax.annotation.Priority
import scala.language.reflectiveCalls

@Priority(Precedence.Normal)
class SnowflakePlugin(spark: SparkSession)
  extends Plugin
    with BaseRelationProcessing
    with RelationProviderProcessing {

  import za.co.absa.spline.commons.ExtractorImplicits._

  override def baseRelationProcessor: PartialFunction[(BaseRelation, LogicalRelation), ReadNodeInfo] = {
    case (`_: SnowflakeRelation`(r), _) =>
      val params = extractValue[net.snowflake.spark.snowflake.Parameters.MergedParameters](r, "params")

      val url: String = params.sfURL
      val warehouse: String = params.sfWarehouse.getOrElse("")
      val database: String = params.sfDatabase
      val schema: String = params.sfSchema
      val table: String = params.table.getOrElse("").toString

      ReadNodeInfo(asSourceId(url, warehouse, database, schema, table), Map.empty)
  }

  override def relationProviderProcessor: PartialFunction[(AnyRef, SaveIntoDataSourceCommand), WriteNodeInfo] = {
    case (rp, cmd) if rp == "net.snowflake.spark.snowflake.DefaultSource" || SnowflakeSourceExtractor.matches(rp) =>
      val url: String = cmd.options("sfUrl")
      val warehouse: String = cmd.options("sfWarehouse")
      val database: String = cmd.options("sfDatabase")
      val schema: String = cmd.options("sfSchema")
      val table: String = cmd.options("dbtable")

      WriteNodeInfo(asSourceId(url, warehouse, database, schema, table), cmd.mode, cmd.query, cmd.options)
  }
}

object SnowflakePlugin {

  private object `_: SnowflakeRelation` extends SafeTypeMatchingExtractor[AnyRef]("net.snowflake.spark.snowflake.SnowflakeRelation")

  private object SnowflakeSourceExtractor extends SafeTypeMatchingExtractor(classOf[net.snowflake.spark.snowflake.DefaultSource])

  private def asSourceId(url: String, warehouse: String, database: String, schema: String, table: String) =
    SourceIdentifier(Some("snowflake"), s"snowflake://$url.$warehouse.$database.$schema.$table")
}
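
For context, a sketch of the call path from the user's side: a write through the Snowflake connector compiles to a SaveIntoDataSourceCommand, whose options feed asSourceId above. All option values below are placeholders for illustration, not a real account.

// Illustrative write that relationProviderProcessor would intercept
// (option values are assumed for the example):
df.write
  .format("net.snowflake.spark.snowflake")        // matched against DefaultSource
  .option("sfUrl", "acme.snowflakecomputing.com")
  .option("sfWarehouse", "WH")
  .option("sfDatabase", "DB")
  .option("sfSchema", "PUBLIC")
  .option("dbtable", "ORDERS")
  .mode("overwrite")
  .save()

// Source identifier produced by asSourceId for these options:
// SourceIdentifier(Some("snowflake"),
//   "snowflake://acme.snowflakecomputing.com.WH.DB.PUBLIC.ORDERS")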
core/src/test/scala/za/co/absa/spline/harvester/plugin/embedded/SnowflakePluginSpec.scala (79 additions, 0 deletions)
@@ -0,0 +1,79 @@
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.mockito.Mockito._
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import za.co.absa.spline.harvester.builder.SourceIdentifier
import za.co.absa.spline.harvester.plugin.Plugin.WriteNodeInfo

class SnowflakePluginSpec extends AnyFlatSpec with Matchers with MockitoSugar {

  "SnowflakePlugin" should "process Snowflake relation providers" in {
    // Setup
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    val options = Map(
      "sfUrl" -> "test-url",
      "sfWarehouse" -> "test-warehouse",
      "sfDatabase" -> "test-database",
      "sfSchema" -> "test-schema",
      "sfUser" -> "user1",
      "dbtable" -> "test-table"
    )

    val cmd = mock[SaveIntoDataSourceCommand]
    when(cmd.options).thenReturn(options)
    when(cmd.mode).thenReturn(SaveMode.Overwrite)
    when(cmd.query).thenReturn(null)

    // The Snowflake relation provider, identified by its fully qualified class name
    val snowflakeRP = "net.snowflake.spark.snowflake.DefaultSource"

    // Execute
    val result = plugin.relationProviderProcessor((snowflakeRP, cmd))

    // Verify
    val expectedSourceId = SourceIdentifier(Some("snowflake"), "snowflake://test-url.test-warehouse.test-database.test-schema.test-table")
    result shouldEqual WriteNodeInfo(expectedSourceId, SaveMode.Overwrite, null, options)
  }

  it should "not process non-Snowflake relation providers" in {
    // Setup
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    val cmd = mock[SaveIntoDataSourceCommand]

    // A relation provider that is not Snowflake
    val nonSnowflakeRP = "some.other.datasource"

    // Execute & Verify
    assertThrows[MatchError] {
      plugin.relationProviderProcessor((nonSnowflakeRP, cmd))
    }
  }
}
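
Since relationProviderProcessor is a PartialFunction, the negative case can also be asserted without relying on a MatchError. A small sketch under the same mocks (an alternative style, not part of the committed spec):

// isDefinedAt evaluates only the case pattern and guard, not the body:
plugin.relationProviderProcessor.isDefinedAt(("some.other.datasource", cmd)) shouldBe false
plugin.relationProviderProcessor.isDefinedAt(("net.snowflake.spark.snowflake.DefaultSource", cmd)) shouldBe true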
integration-tests/src/test/scala/za/co/absa/spline/SnowflakeSpec.scala (106 additions, 0 deletions)
@@ -0,0 +1,106 @@
/*
 * Copyright 2019 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spline

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{Row, RowFactory}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.Wait
import org.testcontainers.utility.DockerImageName
import za.co.absa.spline.commons.io.TempDirectory
import za.co.absa.spline.test.fixture.spline.SplineFixture
import za.co.absa.spline.test.fixture.{ReleasableResourceFixture, SparkFixture}

import java.util

class SnowflakeSpec
  extends AsyncFlatSpec
    with BeforeAndAfterAll
    with Matchers
    with SparkFixture
    with SplineFixture
    with ReleasableResourceFixture {

  val tableName = "testTable"
  val schemaName = "testSchema"
  val warehouseName = "testWarehouse"
  val databaseName = "test"
  val sparkFormat = "net.snowflake.spark.snowflake"

  it should "support Snowflake as a read and write source" in {
    usingResource(new GenericContainer(DockerImageName.parse("localstack/snowflake"))) { container =>
      // Register the wait strategy before start() so it blocks until the
      // container is healthy (assumes the image defines a Docker healthcheck)
      container.waitingFor(Wait.forHealthcheck())
      container.start()

      withNewSparkSession { implicit spark =>

        withLineageTracking { captor =>
          val sfOptions = Map(
            "sfURL" -> "snowflake.localhost.localstack.cloud",
            "sfUser" -> "test",
            "sfPassword" -> "test",
            "sfDatabase" -> databaseName,
            "sfWarehouse" -> warehouseName,
            "sfSchema" -> schemaName
          )

          // Define the test data as a Java list of single-column rows
          val data = new util.ArrayList[Row]()
          data.add(RowFactory.create(1.asInstanceOf[Object]))
          data.add(RowFactory.create(2.asInstanceOf[Object]))
          data.add(RowFactory.create(3.asInstanceOf[Object]))

          // Create the DataFrame with an explicit schema matching the rows
          val schema = StructType(Seq(StructField("id", IntegerType, nullable = false)))
          val testData = spark.createDataFrame(data, schema)

          for {
            (writePlan, _) <- captor.lineageOf(
              testData.write
                .format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName)
                .mode("overwrite")
                .save()
            )

            (readPlan, _) <- captor.lineageOf {
              val df = spark.read.format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName) // specify the source table
                .load()

              df.write.save(TempDirectory(pathOnly = true).deleteOnExit().path.toString)
            }
          } yield {
            writePlan.operations.write.append shouldBe false
            writePlan.operations.write.extra("destinationType") shouldBe Some("snowflake")
            // The lineage URI is derived from the sfURL option, not the container host
            writePlan.operations.write.outputSource shouldBe s"snowflake://${sfOptions("sfURL")}.$warehouseName.$databaseName.$schemaName.$tableName"

            readPlan.operations.reads.head.inputSources.head shouldBe writePlan.operations.write.outputSource
            readPlan.operations.reads.head.extra("sourceType") shouldBe Some("snowflake")
            readPlan.operations.write.append shouldBe false
          }
        }
      }
    }
  }
}
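
Outside the SplineFixture test harness, capturing lineage for a job like this ordinarily means registering the Spline listener on the session. A minimal sketch, assuming the standard Spline agent listener class and an otherwise default agent configuration:

// Sketch: enabling the Spline agent on a plain SparkSession (app name assumed)
val session = org.apache.spark.sql.SparkSession.builder()
  .appName("snowflake-lineage-demo")
  .config("spark.sql.queryExecutionListeners",
    "za.co.absa.spline.harvester.listener.SplineQueryExecutionListener")
  .getOrCreate()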