Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#542 Add databricks driver processing in SqlGeneratorLoader and provide databricks dialect using #543

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.sql

import org.apache.spark.sql.jdbc.JdbcDialects
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.offset.OffsetValue
import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}
import za.co.absa.pramen.core.sql.dialects.DatabricksDialect

import java.time.format.DateTimeFormatter
import java.time.{LocalDate, LocalDateTime}

object SqlGeneratorDatabricks {
  private val log = LoggerFactory.getLogger(this.getClass)

  /**
    * Registers the Databricks JDBC dialect with Spark. This is required for Spark to be able
    * to handle data that comes from Databricks JDBC drivers.
    *
    * Being a `lazy val`, the registration body is evaluated on first access only.
    * The value itself is always `true`.
    */
  lazy val registerDialect: Boolean = performRegistration()

  /** Logs the registration and adds [[DatabricksDialect]] to Spark's JDBC dialects. */
  private def performRegistration(): Boolean = {
    log.info("Registering Databricks dialect...")
    JdbcDialects.registerDialect(DatabricksDialect)
    true
  }
}

/**
  * SQL generator used for the Databricks JDBC driver.
  *
  * It declares backticks as the begin/end escape characters and renders DATE/DATETIME
  * information dates via `to_date('yyyy-MM-dd')`. Creating an instance touches
  * `SqlGeneratorDatabricks.registerDialect`, so the Databricks dialect gets registered.
  *
  * @param sqlConfig the SQL configuration; this class reads `dateFormatApp`, `infoDateType`
  *                  and `serverTimeZone` from it.
  */
class SqlGeneratorDatabricks(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig) {
  // Formatter for information dates stored as STRING or NUMBER columns.
  private val dateFormatterApp = DateTimeFormatter.ofPattern(sqlConfig.dateFormatApp)

  override val beginEndEscapeChars: (Char, Char) = ('`', '`')

  // Make sure the dialect is registered before any query generated here is used.
  SqlGeneratorDatabricks.registerDialect

  /** Wraps the given table name or query in parentheses and gives it the alias `tbl`. */
  override def getDtable(sql: String): String = s"($sql) tbl"

  /** Builds a query counting all records of the table. */
  override def getCountQuery(tableName: String): String = {
    val table = escape(tableName)
    s"SELECT COUNT(*) FROM $table"
  }

  /** Builds a query counting records whose information date is within the inclusive range. */
  override def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = {
    val dateFilter = getWhere(infoDateBegin, infoDateEnd)
    val table = escape(tableName)
    s"SELECT COUNT(*) FROM $table WHERE $dateFilter"
  }

  /** Builds a query counting the records returned by an arbitrary SQL query. */
  override def getCountQueryForSql(filteredSql: String): String =
    s"SELECT COUNT(*) FROM ($filteredSql) AS query"

  /** Builds a query selecting the given columns of the table, optionally limited. */
  override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
    val projection = columnExpr(columns)
    val table = escape(tableName)
    val limitClause = getLimit(limit)
    s"SELECT $projection FROM $table$limitClause"
  }

  /** Builds a query selecting the given columns for the inclusive date range, optionally limited. */
  override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
    val dateFilter = getWhere(infoDateBegin, infoDateEnd)
    val projection = columnExpr(columns)
    val table = escape(tableName)
    val limitClause = getLimit(limit)
    s"SELECT $projection FROM $table WHERE $dateFilter$limitClause"
  }

  /**
    * Builds the information date condition for the inclusive range.
    *
    * When both literals are equal a single equality is produced, otherwise a `>= ... AND <= ...` pair.
    * For DATETIME columns the column is cast to DATE before the comparison.
    */
  override def getWhere(dateBegin: LocalDate, dateEnd: LocalDate): String = {
    val fromLiteral = getDateLiteral(dateBegin)
    val toLiteral = getDateLiteral(dateEnd)

    val dateColumn = sqlConfig.infoDateType match {
      case SqlColumnType.DATETIME => s"CAST($infoDateColumn AS DATE)"
      case _                      => infoDateColumn
    }

    if (fromLiteral != toLiteral) {
      s"$dateColumn >= $fromLiteral AND $dateColumn <= $toLiteral"
    } else {
      s"$dateColumn = $fromLiteral"
    }
  }

  /**
    * Renders the date as an SQL literal according to the configured information date type:
    * `to_date('<ISO date>')` for DATE and DATETIME, a quoted application-formatted string for STRING,
    * and the bare application-formatted value for NUMBER.
    */
  override def getDateLiteral(date: LocalDate): String = {
    sqlConfig.infoDateType match {
      case SqlColumnType.DATE | SqlColumnType.DATETIME =>
        s"to_date('${DateTimeFormatter.ISO_LOCAL_DATE.format(date)}')"
      case SqlColumnType.STRING =>
        s"'${dateFormatterApp.format(date)}'"
      case SqlColumnType.NUMBER =>
        dateFormatterApp.format(date)
    }
  }

  /**
    * Builds `<column> <condition> <literal>` for an offset value. Timestamps are converted to the
    * configured server time zone and quoted, strings are quoted, integral values are used as is.
    */
  override def getOffsetWhereCondition(column: String, condition: String, offset: OffsetValue): String = {
    val literal = offset match {
      case OffsetValue.DateTimeValue(ts) =>
        val localTs = LocalDateTime.ofInstant(ts, sqlConfig.serverTimeZone)
        s"'${timestampGenericDbFormatter.format(localTs)}'"
      case OffsetValue.IntegralValue(value) =>
        s"$value"
      case OffsetValue.StringValue(value) =>
        s"'$value'"
    }

    s"$column $condition $literal"
  }

  /** Returns ` LIMIT n` when a limit is defined, an empty string otherwise. */
  private def getLimit(limit: Option[Int]): String =
    limit.fold("")(n => s" LIMIT $n")
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ object SqlGeneratorLoader {
case "com.simba.spark.jdbc.Driver" => new SqlGeneratorHive(sqlConfig)
case "org.hsqldb.jdbc.JDBCDriver" => new SqlGeneratorHsqlDb(sqlConfig)
case "com.ibm.db2.jcc.DB2Driver" => new SqlGeneratorDb2(sqlConfig)
case "com.databricks.client.jdbc.Driver" => new SqlGeneratorDatabricks(sqlConfig)
case d =>
log.warn(s"Unsupported JDBC driver: '$d'. Trying to use a generic SQL generator.")
new SqlGeneratorGeneric(sqlConfig)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.sql.dialects

import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.{DataType, MetadataBuilder}
import org.slf4j.LoggerFactory

/**
  * This is required for Spark to be able to handle data that comes from Databricks JDBC drivers.
  *
  * The dialect applies to JDBC URLs starting with `jdbc:databricks` and quotes identifiers
  * with backticks.
  */
object DatabricksDialect extends JdbcDialect {
  private val logger = LoggerFactory.getLogger(this.getClass)

  /** Returns `true` when the JDBC URL starts with `jdbc:databricks`. */
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:databricks")

  /**
    * Quotes each dot-separated part of the identifier with backticks.
    *
    * A backtick that occurs inside a part is escaped by doubling it, so that a name containing
    * a backtick does not terminate the quoted identifier early and produce malformed SQL.
    *
    * @param colName the identifier, possibly consisting of several parts separated by '.'.
    * @return the identifier with every part quoted, joined by '.'.
    */
  override def quoteIdentifier(colName: String): String = {
    colName
      .split('.')
      .map { part =>
        val escaped = part.replace("`", "``")
        s"`$escaped`"
      }
      .mkString(".")
  }

  /** Delegates to the default type mapping; kept as an explicit extension point. */
  override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    super.getCatalystType(sqlType, typeName, size, md)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.tests.sql

import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.api.offset.OffsetValue
import za.co.absa.pramen.api.sql.{QuotingPolicy, SqlColumnType, SqlGenerator, SqlGeneratorBase}
import za.co.absa.pramen.core.mocks.DummySqlConfigFactory

import java.time.{Instant, LocalDate}

/**
  * Tests for the SQL generator selected by `SqlGeneratorLoader.getSqlGenerator` for the
  * Databricks JDBC driver. Each generator below is created from a dummy config that differs
  * in the information date column type, name, or quoting policy.
  */
class SqlGeneratorDatabricksSuite extends AnyWordSpec {

  import za.co.absa.pramen.core.sql.SqlGeneratorLoader._

  private val sqlConfigDate = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATE, infoDateColumn = "D")
  private val sqlConfigEscape = DummySqlConfigFactory.getDummyConfig(infoDateColumn = "Info date", identifierQuotingPolicy = QuotingPolicy.Always)
  private val sqlConfigDateTime = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATETIME, infoDateColumn = "D")
  private val sqlConfigString = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.STRING, infoDateColumn = "D")
  private val sqlConfigNumber = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.NUMBER, infoDateColumn = "D", dateFormatApp = "yyyyMMdd")
  private val columns = Seq("A", "D", "Column with spaces")

  private val date1 = LocalDate.of(2020, 8, 17)
  private val date2 = LocalDate.of(2020, 8, 30)

  val driver = "com.databricks.client.jdbc.Driver"

  val gen: SqlGenerator = getSqlGenerator(driver, sqlConfigDate)
  val genStr: SqlGenerator = getSqlGenerator(driver, sqlConfigString)
  val genNum: SqlGenerator = getSqlGenerator(driver, sqlConfigNumber)
  val genDateTime: SqlGenerator = getSqlGenerator(driver, sqlConfigDateTime)
  val genEscaped: SqlGenerator = getSqlGenerator(driver, sqlConfigEscape)
  // The information date column is already wrapped in backticks in this config.
  val genEscaped2: SqlGenerator = getSqlGenerator(driver, DummySqlConfigFactory.getDummyConfig(infoDateColumn = "`Info date`", identifierQuotingPolicy = QuotingPolicy.Auto))

  "generate count queries without date ranges" in {
    assert(gen.getCountQuery("A") == "SELECT COUNT(*) FROM A")
  }

  "generate data queries without date ranges" in {
    assert(gen.getDataQuery("A", Nil, None) == "SELECT * FROM A")
  }

  "generate data queries when list of columns is specified" in {
    assert(genEscaped.getDataQuery("A", columns, None) == "SELECT `A`, `D`, `Column with spaces` FROM `A`")
  }

  "generate data queries with limit clause without date ranges" in {
    assert(gen.getDataQuery("A", Nil, Some(100)) == "SELECT * FROM A LIMIT 100")
  }

  "generate ranged count queries" when {
    "date is in DATE format" in {
      assert(gen.getCountQuery("A", date1, date1) ==
        "SELECT COUNT(*) FROM A WHERE D = to_date('2020-08-17')")
      assert(gen.getCountQuery("A", date1, date2) ==
        "SELECT COUNT(*) FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')")
    }

    "date is in DATETIME format" in {
      assert(genDateTime.getCountQuery("A", date1, date1) ==
        "SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')")
      assert(genDateTime.getCountQuery("A", date1, date2) ==
        "SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')")
    }

    "date is in STRING format" in {
      assert(genStr.getCountQuery("A", date1, date1) ==
        "SELECT COUNT(*) FROM A WHERE D = '2020-08-17'")
      assert(genStr.getCountQuery("A", date1, date2) ==
        "SELECT COUNT(*) FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'")
    }

    "date is in NUMBER format" in {
      assert(genNum.getCountQuery("A", date1, date1) ==
        "SELECT COUNT(*) FROM A WHERE D = 20200817")
      assert(genNum.getCountQuery("A", date1, date2) ==
        "SELECT COUNT(*) FROM A WHERE D >= 20200817 AND D <= 20200830")
    }

    "the table name and column name need to be escaped" in {
      assert(genEscaped.getCountQuery("Input Table", date1, date1) ==
        "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')")
      assert(genEscaped.getCountQuery("Input Table", date1, date2) ==
        "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')")
    }

    "the table name and column name already escaped" in {
      assert(genEscaped2.getCountQuery("Input Table", date1, date1) ==
        "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')")
      assert(genEscaped2.getCountQuery("Input Table", date1, date2) ==
        "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')")
    }
  }

  "generate ranged data queries" when {
    "date is in DATE format" in {
      assert(gen.getDataQuery("A", date1, date1, Nil, None) ==
        "SELECT * FROM A WHERE D = to_date('2020-08-17')")
      assert(gen.getDataQuery("A", date1, date2, Nil, None) ==
        "SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')")
    }

    "date is in DATETIME format" in {
      assert(genDateTime.getDataQuery("A", date1, date1, Nil, None) ==
        "SELECT * FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')")
      assert(genDateTime.getDataQuery("A", date1, date2, Nil, None) ==
        "SELECT * FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')")
    }

    "date is in STRING format" in {
      assert(genStr.getDataQuery("A", date1, date1, Nil, None) ==
        "SELECT * FROM A WHERE D = '2020-08-17'")
      assert(genStr.getDataQuery("A", date1, date2, Nil, None) ==
        "SELECT * FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'")
    }

    "date is in NUMBER format" in {
      assert(genNum.getDataQuery("A", date1, date1, Nil, None) ==
        "SELECT * FROM A WHERE D = 20200817")
      assert(genNum.getDataQuery("A", date1, date2, Nil, None) ==
        "SELECT * FROM A WHERE D >= 20200817 AND D <= 20200830")
    }

    "with limit records" in {
      assert(gen.getDataQuery("A", date1, date1, Nil, Some(100)) ==
        "SELECT * FROM A WHERE D = to_date('2020-08-17') LIMIT 100")
      assert(gen.getDataQuery("A", date1, date2, Nil, Some(100)) ==
        "SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30') LIMIT 100")
    }
  }

  "getCountQueryForSql" should {
    "generate count queries for an SQL subquery" in {
      assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
    }
  }

  "getDtable" should {
    "wrap a table name in parentheses and add an alias" in {
      assert(gen.getDtable("A") == "(A) tbl")
    }

    "wrap an SQL query in parentheses and add an alias" in {
      assert(gen.getDtable("SELECT A FROM B") == "(SELECT A FROM B) tbl")
    }
  }

  "quote" should {
    "escape each subfield separately" in {
      val actual = gen.quote("System User.`Table Name`")

      assert(actual == "`System User`.`Table Name`")
    }
  }

  // NOTE(review): the casts below suggest getOffsetWhereCondition is not exposed on the
  // SqlGenerator interface, only on SqlGeneratorBase - confirm against the API module.
  "getOffsetWhereCondition" should {
    "return the correct condition for integral offsets" in {
      val actual = gen.asInstanceOf[SqlGeneratorBase]
        .getOffsetWhereCondition("offset", "<", OffsetValue.IntegralValue(1))

      assert(actual == "offset < 1")
    }

    "return the correct condition for datetime offsets" in {
      // NOTE(review): the expected literal depends on the server time zone of the dummy config.
      val actual = gen.asInstanceOf[SqlGeneratorBase]
        .getOffsetWhereCondition("offset", ">", OffsetValue.DateTimeValue(Instant.ofEpochMilli(1727761000)))

      assert(actual == "offset > '1970-01-21 01:56:01.000'")
    }

    "return the correct condition for string offsets" in {
      val actual = gen.asInstanceOf[SqlGeneratorBase]
        .getOffsetWhereCondition("offset", ">=", OffsetValue.StringValue("AAA"))

      assert(actual == "offset >= 'AAA'")
    }
  }
}
Loading