[SPARK-45827][SQL] Move data type checks to CreatableRelationProvider
### What changes were proposed in this pull request?

In DataSource.scala, there are checks that prevent writing Variant and Interval types to a `CreatableRelationProvider`. This PR unifies those checks in a single method on `CreatableRelationProvider`, which data sources can override to specify a different set of supported data types.
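
As a sketch of the new extension point (the `NarrowTypesProvider` name and its accepted type set are made up for illustration, not part of this PR), a custom data source could override `supportsDataType` roughly like this:

```scala
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types._

// Hypothetical provider that only accepts a narrow set of primitive types,
// plus arrays and structs built from them.
class NarrowTypesProvider extends CreatableRelationProvider {

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // Actual write path elided: by the time this runs, DataSource has already
    // verified every output column type against supportsDataType.
    ???
  }

  // Advertise a different set of supported types than the trait's default.
  override def supportsDataType(dt: DataType): Boolean = dt match {
    case ArrayType(elementType, _) => supportsDataType(elementType)
    case StructType(fields) => fields.forall(f => supportsDataType(f.dataType))
    case StringType | IntegerType | LongType | DoubleType | BooleanType => true
    case _ => false
  }
}
```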

### Why are the changes needed?

Allows data sources to specify what types they support, while providing a sensible default for most data sources.
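
For illustration, a minimal sketch of the default behavior (the anonymous provider below is only a stub used to exercise the trait's default `supportsDataType`): common atomic and nested types are accepted, while Variant and interval types are rejected, including when nested.

```scala
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types._

// Anonymous provider that keeps the default supportsDataType (createRelation is a stub).
val provider = new CreatableRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = ???
}

provider.supportsDataType(StringType)                        // true
provider.supportsDataType(ArrayType(IntegerType))            // true
provider.supportsDataType(VariantType)                       // false: not in the default allow-list
provider.supportsDataType(MapType(IntegerType, VariantType)) // false: nested types are checked recursively
provider.supportsDataType(CalendarIntervalType)              // false: intervals are rejected by default
```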

### Does this PR introduce _any_ user-facing change?

The error messages for Variant and Interval types are now shared and are a bit more generic. Otherwise, the intent is to have no user-facing change.

### How was this patch tested?

Unit tests added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45409 from cashmand/SPARK-45827-CreatableRelationProvider.

Authored-by: cashmand <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
cashmand authored and cloud-fan committed Mar 7, 2024
1 parent 456d246 commit 0f91642
Showing 3 changed files with 72 additions and 15 deletions.
@@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, Tex
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{DataType, StructField, StructType, VariantType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
@@ -503,8 +503,12 @@ case class DataSource(
val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
providingInstance() match {
case dataSource: CreatableRelationProvider =>
disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = true)
disallowWritingVariant(outputColumns.map(_.dataType))
outputColumns.foreach { attr =>
if (!dataSource.supportsDataType(attr.dataType)) {
throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError(
dataSource.toString, StructField(attr.toString, attr.dataType))
}
}
dataSource.createRelation(
sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
case format: FileFormat =>
@@ -525,8 +529,12 @@
def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = {
providingInstance() match {
case dataSource: CreatableRelationProvider =>
disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = true)
disallowWritingVariant(data.schema.map(_.dataType))
data.schema.foreach { field =>
if (!dataSource.supportsDataType(field.dataType)) {
throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError(
dataSource.toString, field)
}
}
SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
case format: FileFormat =>
disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false)
@@ -563,14 +571,6 @@ case class DataSource(
throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
})
}

private def disallowWritingVariant(dataTypes: Seq[DataType]): Unit = {
dataTypes.foreach { dt =>
if (dt.existsRecursively(_.isInstanceOf[VariantType])) {
throw QueryCompilationErrors.cannotSaveVariantIntoExternalStorageError()
}
}
}
}

object DataSource extends Logging {
@@ -23,7 +23,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._

/**
* Data sources should implement this trait so that they can register an alias to their data source.
@@ -175,6 +175,27 @@ trait CreatableRelationProvider {
mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation

/**
* Check if the relation supports the given data type.
*
* @param dt Data type to check
* @return True if the data type is supported
*
* @since 4.0.0
*/
def supportsDataType(dt: DataType): Boolean = {
dt match {
case ArrayType(e, _) => supportsDataType(e)
case MapType(k, v, _) => supportsDataType(k) && supportsDataType(v)
case StructType(fields) => fields.forall(f => supportsDataType(f.dataType))
case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
case BinaryType | BooleanType | ByteType | CharType(_) | DateType | _ : DecimalType |
DoubleType | FloatType | IntegerType | LongType | NullType | ObjectType(_) | ShortType |
StringType | TimestampNTZType | TimestampType | VarcharType(_) => true
case _ => false
}
}
}

/**
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, TableScan}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, StructField, StructType}
@@ -71,6 +71,42 @@ class SaveIntoDataSourceCommandSuite extends QueryTest with SharedSparkSession {

FakeV1DataSource.data = null
}

test("Data type support") {

val dataSource = DataSource(
sparkSession = spark,
className = "jdbc",
partitionColumns = Nil,
options = Map())

val df = spark.range(1).selectExpr(
"cast('a' as binary) a", "true b", "cast(1 as byte) c", "1.23 d")
dataSource.planForWriting(SaveMode.ErrorIfExists, df.logicalPlan)

// Variant and Interval types are disallowed by default.
val unsupportedTypes = Seq(
("parse_json('1') col", "VARIANT"),
("array(parse_json('1')) col", "ARRAY<VARIANT>"),
("struct(1, parse_json('1')) col", "STRUCT<col1: INT NOT NULL, col2: VARIANT NOT NULL>"),
("map(1, parse_json('1')) col", "MAP<INT, VARIANT>"),
("INTERVAL '1' MONTH col", "INTERVAL MONTH"),
("make_ym_interval(1, 2) col", "INTERVAL YEAR TO MONTH"),
("make_dt_interval(1, 2, 3, 4) col", "INTERVAL DAY TO SECOND"))

unsupportedTypes.foreach { testCase =>
val df = spark.range(1).selectExpr(testCase._1)
checkError(
exception = intercept[AnalysisException] {
dataSource.planForWriting(SaveMode.ErrorIfExists, df.logicalPlan)
},
errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE",
parameters = Map("columnName" -> "`col`", "columnType" -> s"\"${testCase._2}\"",
"format" -> ".*JdbcRelationProvider.*"),
matchPVals = true
)
}
}
}

object FakeV1DataSource {
