diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 954fcdf3b4d93..302a0275491a2 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5277,6 +5277,11 @@
         "Resilient Distributed Datasets (RDDs)."
       ]
     },
+    "REGISTER_UDAF" : {
+      "message" : [
+        "Registering User Defined Aggregate Functions (UDAFs)."
+      ]
+    },
     "SESSION_BASE_RELATION_TO_DATAFRAME" : {
       "message" : [
         "Invoking SparkSession 'baseRelationToDataFrame'. This is server side developer API"
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index bff6db25a21f2..ca82381eec9e3 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -31,8 +31,8 @@ import ammonite.util.Util.newLine
 
 import org.apache.spark.SparkBuildInfo.spark_version
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.SparkSession.withLocalConnectServer
+import org.apache.spark.sql.connect.SparkSession
+import org.apache.spark.sql.connect.SparkSession.withLocalConnectServer
 import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser}
 
 /**
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
deleted file mode 100644
index 86b1dbe4754e6..0000000000000
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ /dev/null
@@ -1,168 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.sql.catalog - -import java.util - -import org.apache.spark.sql.{api, DataFrame, Dataset} -import org.apache.spark.sql.connect.ConnectConversions._ -import org.apache.spark.sql.types.StructType - -/** @inheritdoc */ -abstract class Catalog extends api.Catalog { - - /** @inheritdoc */ - override def listDatabases(): Dataset[Database] - - /** @inheritdoc */ - override def listDatabases(pattern: String): Dataset[Database] - - /** @inheritdoc */ - override def listTables(): Dataset[Table] - - /** @inheritdoc */ - override def listTables(dbName: String): Dataset[Table] - - /** @inheritdoc */ - override def listTables(dbName: String, pattern: String): Dataset[Table] - - /** @inheritdoc */ - override def listFunctions(): Dataset[Function] - - /** @inheritdoc */ - override def listFunctions(dbName: String): Dataset[Function] - - /** @inheritdoc */ - override def listFunctions(dbName: String, pattern: String): Dataset[Function] - - /** @inheritdoc */ - override def listColumns(tableName: String): Dataset[Column] - - /** @inheritdoc */ - override def listColumns(dbName: String, tableName: String): Dataset[Column] - - /** @inheritdoc */ - override def createTable(tableName: String, path: String): DataFrame - - /** @inheritdoc */ - override def createTable(tableName: String, path: String, source: String): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - description: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - description: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def listCatalogs(): Dataset[CatalogMetadata] - - /** @inheritdoc */ - override def listCatalogs(pattern: String): Dataset[CatalogMetadata] - - /** @inheritdoc */ - override def createExternalTable(tableName: String, path: String): DataFrame = - super.createExternalTable(tableName, path) - - /** @inheritdoc */ - override def createExternalTable(tableName: String, path: String, source: String): DataFrame = - super.createExternalTable(tableName, path, source) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - options: util.Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: util.Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - description: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, description, options) - - /** @inheritdoc */ - 
override def createTable( - tableName: String, - source: String, - schema: StructType, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - description: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, schema, description, options) -} diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java index e5f06eb7dbcec..907105e370c08 100644 --- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -28,7 +28,7 @@ import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.sql.test.SparkConnectServerUtils; +import org.apache.spark.sql.connect.test.SparkConnectServerUtils; import org.apache.spark.sql.types.StructType; /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala index 33c9911e75c9b..ba77879a5a800 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession { import testImplicits._ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala index 12a49ad21676e..d11f276e8ed47 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSession { import testImplicits._ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 8aed696005799..a548ec7007dbe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -38,12 +38,13 @@ 
import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.{functions => pbFn} -import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.SparkFileUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala index fcd2b3a388042..83ad943a2253e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.unsafe.types.VariantVal class SQLExpressionsSuite extends QueryTest with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index bb7d1b25738c1..3596de39e86a3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -25,7 +25,7 @@ import scala.util.Properties import org.apache.commons.io.output.ByteArrayOutputStream import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession} import org.apache.spark.tags.AmmoniteTest import org.apache.spark.util.IvyTestUtils import org.apache.spark.util.MavenUtils.MavenCoordinate diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala index ce552bdd4f0f0..b2c19226dc542 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala @@ -15,14 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.{File, FilenameFilter} import org.apache.commons.io.FileUtils import org.apache.spark.SparkException -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} import org.apache.spark.sql.types.{DoubleType, LongType, StructType} import org.apache.spark.storage.StorageLevel diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CheckpointSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CheckpointSuite.scala index 0d9685d9c710f..eeda75ae13310 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CheckpointSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.{ByteArrayOutputStream, PrintStream} @@ -26,7 +26,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.SparkException import org.apache.spark.connect.proto -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} import org.apache.spark.storage.StorageLevel class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala index 84ed624a95214..a7e2e61a106f9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala @@ -15,14 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Random import org.scalatest.matchers.must.Matchers._ import org.apache.spark.SparkIllegalArgumentException -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} class ClientDataFrameStatSuite extends ConnectFunSuite with RemoteSparkSession { private def toLetter(i: Int): String = (i + 97).toChar.toString diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDatasetSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDatasetSuite.scala index c93f9b1c0dbdd..7e6cebfd972df 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDatasetSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Properties import java.util.concurrent.TimeUnit @@ -25,9 +25,10 @@ import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto +import org.apache.spark.sql.Column import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils // Add sample tests. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala index c7979b8e033ea..1b73f9f2f4543 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.{ByteArrayOutputStream, PrintStream} import java.nio.file.Files @@ -34,15 +34,17 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException} import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} -import org.apache.spark.sql.test.SparkConnectServerUtils.port import org.apache.spark.sql.types._ import org.apache.spark.util.SparkThreadUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala index 24e9156775c44..02f0c35c44a8f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import org.apache.spark.SparkException import org.apache.spark.connect.proto @@ -23,9 +23,10 @@ import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoDataTypes} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator} -import org.apache.spark.sql.test.ConnectFunSuite -import org.apache.spark.sql.types.{BinaryType, DataType, DoubleType, LongType, MetadataBuilder, ShortType, StringType, StructType} +import org.apache.spark.sql.internal._ +import org.apache.spark.sql.types._ /** * Test suite for [[ColumnNode]] to [[proto.Expression]] conversions. 
@@ -471,8 +472,8 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite { } } -private[internal] case class Nope(override val origin: Origin = CurrentOrigin.get) +private[connect] case class Nope(override val origin: Origin = CurrentOrigin.get) extends ColumnNode { override def sql: String = "nope" - override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty + override def children: Seq[ColumnNodeLike] = Seq.empty } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala index 86c7a20136851..863cb5872c72e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala @@ -14,14 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.ByteArrayOutputStream import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.{functions => fn} -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.{functions => fn, Column, ColumnName} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.types._ /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameNaFunctionSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameNaFunctionSuite.scala index 8a783d880560e..988d4d4c4f5da 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameNaFunctionSuite.scala @@ -15,12 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionSuite extends QueryTest with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/FunctionTestSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/FunctionTestSuite.scala index 40b66bcb8358d..8c5fc2b2b8ec9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/FunctionTestSuite.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Collections import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.Column import org.apache.spark.sql.avro.{functions => avroFn} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.{DataType, StructType} /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/KeyValueGroupedDatasetE2ETestSuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/KeyValueGroupedDatasetE2ETestSuite.scala index 021b4fea26e2a..c2046c6f26700 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/KeyValueGroupedDatasetE2ETestSuite.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.sql.Timestamp import java.util.Arrays +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.types._ import org.apache.spark.util.SparkSerDeUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/MergeIntoE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/MergeIntoE2ETestSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/MergeIntoE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/MergeIntoE2ETestSuite.scala index cdb72e3baf0e9..832efca96550e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/MergeIntoE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/MergeIntoE2ETestSuite.scala @@ -14,10 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect +import org.apache.spark.sql.Row +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} class MergeIntoE2ETestSuite extends ConnectFunSuite with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala index b3b8020b1e4c7..2791c6b6add55 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.sql.{Date, Timestamp} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} @@ -26,10 +26,11 @@ import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{Column, Encoder, SaveMode} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite /** * Test suite for SQL implicits. 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala similarity index 82% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala index ed930882ac2fd..cc6bc8af1f4b3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala @@ -14,17 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect -import org.apache.spark.sql.api.SparkSessionBuilder -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql +import org.apache.spark.sql.SparkSessionBuilder +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} /** * Make sure the api.SparkSessionBuilder binds to Connect implementation. */ class SparkSessionBuilderImplementationBindingSuite extends ConnectFunSuite - with api.SparkSessionBuilderImplementationBindingSuite + with sql.SparkSessionBuilderImplementationBindingSuite with RemoteSparkSession { override protected def configure(builder: SparkSessionBuilder): builder.type = { // We need to set this configuration because the port used by the server is random. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala index b116edb7df7ce..41e49e3d47d9e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.concurrent.ForkJoinPool @@ -26,7 +26,7 @@ import scala.util.{Failure, Success} import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkException -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} import org.apache.spark.util.SparkThreadUtils.awaitResult /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionSuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionSuite.scala index dec56554d143e..bab6ae39563f6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.concurrent.{Executors, Phaser} @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.SparkException -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/StubbingTestSuite.scala similarity index 91% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/StubbingTestSuite.scala index 5bcb17672d6a9..a0c9a2b992f1d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/StubbingTestSuite.scala @@ -14,10 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import org.apache.spark.sql.connect.client.ToStub -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} class StubbingTestSuite extends ConnectFunSuite with RemoteSparkSession { private def eval[T](f: => T): T = f diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala index c1e44b6fb11b2..b50442de31f04 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.File import java.nio.file.{Files, Paths} @@ -25,7 +25,7 @@ import com.google.protobuf.ByteString import org.apache.spark.connect.proto import org.apache.spark.sql.connect.common.ProtoDataTypes -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UnsupportedFeaturesSuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UnsupportedFeaturesSuite.scala index 42ae6987c9f36..a71d887b2d6c7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UnsupportedFeaturesSuite.scala @@ -14,13 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.StructType /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala index 19275326d6421..a4ef11554d30e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala @@ -14,21 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.lang.{Long => JLong} -import java.util.{Iterator => JIterator} -import java.util.Arrays +import java.util.{Arrays, Iterator => JIterator} import java.util.concurrent.atomic.AtomicLong import scala.jdk.CollectionConverters._ import org.apache.spark.api.java.function._ +import org.apache.spark.sql.{AnalysisException, Encoder, Encoders, Row} import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions.{col, struct, udaf, udf} -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.types.IntegerType /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionSuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionSuite.scala index f7dadea98c281..153026fcdbee3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionSuite.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.reflect.runtime.universe.typeTag import org.apache.spark.SparkException +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.common.UdfPacket +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.functions.{lit, udf} -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils class UserDefinedFunctionSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index 66a2c943af5f6..fb35812233562 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto.AddArtifactsRequest import org.apache.spark.sql.Artifact import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.util.IvyTestUtils import org.apache.spark.util.MavenUtils.MavenCoordinate diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 7bac10e79d0b4..11572da2c663c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -27,7 +27,7 @@ import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.lib.MiMaLib import org.apache.spark.SparkBuildInfo.spark_version -import org.apache.spark.sql.test.IntegrationTestUtils._ +import org.apache.spark.sql.connect.test.IntegrationTestUtils._ /** * A tool for checking the binary compatibility of the connect client API against the spark SQL diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala index ca23436675f87..2f8332878bbf5 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala @@ -20,7 +20,7 @@ import java.nio.file.Paths import org.apache.commons.io.FileUtils -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.util.SparkFileUtils class ClassFinderSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala index 68d2e86b19d70..b342d5b415692 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.client import java.util.UUID -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite /** * Test suite for [[SparkConnectClient.Builder]] parsing and configuration. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index ac56600392aa3..acee1b2775f17 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -33,9 +33,9 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, Relation, SparkConnectServiceGrpc, SQL} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.SparkSession import org.apache.spark.sql.connect.common.config.ConnectCommon -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { @@ -630,26 +630,29 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer var errorToThrowOnExecute: Option[Throwable] = None - private[sql] def getAndClearLatestInputPlan(): proto.Plan = { + private[sql] def getAndClearLatestInputPlan(): proto.Plan = synchronized { val plan = inputPlan inputPlan = null plan } - private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = { - val requests = inputArtifactRequests.toSeq - inputArtifactRequests.clear() - requests - } + private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = + synchronized { + val requests = inputArtifactRequests.toSeq + inputArtifactRequests.clear() + requests + } override def executePlan( request: ExecutePlanRequest, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - if (errorToThrowOnExecute.isDefined) { - val error = errorToThrowOnExecute.get - errorToThrowOnExecute = None - responseObserver.onError(error) - return + synchronized { + if (errorToThrowOnExecute.isDefined) { + val error = errorToThrowOnExecute.get + errorToThrowOnExecute = None + responseObserver.onError(error) + return + } } // Reply with a dummy response using the same client ID @@ -659,7 +662,9 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer } else { UUID.randomUUID().toString } - inputPlan = request.getPlan + synchronized { + inputPlan = request.getPlan + } val response = ExecutePlanResponse .newBuilder() .setSessionId(requestSessionId) @@ -668,7 +673,7 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver.onNext(response) // Reattachable execute must end with ResultComplete if (request.getRequestOptionsList.asScala.exists { 
option => - option.hasReattachOptions && option.getReattachOptions.getReattachable == true + option.hasReattachOptions && option.getReattachOptions.getReattachable }) { val resultComplete = ExecutePlanResponse .newBuilder() @@ -686,20 +691,22 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { // Reply with a dummy response using the same client ID val requestSessionId = request.getSessionId - request.getAnalyzeCase match { - case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => - inputPlan = request.getSchema.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => - inputPlan = request.getExplain.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => - inputPlan = request.getTreeString.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => - inputPlan = request.getIsLocal.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => - inputPlan = request.getIsStreaming.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => - inputPlan = request.getInputFiles.getPlan - case _ => inputPlan = null + synchronized { + request.getAnalyzeCase match { + case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => + inputPlan = request.getSchema.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => + inputPlan = request.getExplain.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => + inputPlan = request.getTreeString.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => + inputPlan = request.getIsLocal.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => + inputPlan = request.getIsStreaming.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => + inputPlan = request.getInputFiles.getPlan + case _ => inputPlan = null + } } val response = AnalyzePlanResponse .newBuilder() @@ -711,7 +718,8 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) : StreamObserver[AddArtifactsRequest] = new StreamObserver[AddArtifactsRequest] { - override def onNext(v: AddArtifactsRequest): Unit = inputArtifactRequests.append(v) + override def onNext(v: AddArtifactsRequest): Unit = + synchronized(inputArtifactRequests.append(v)) override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable) @@ -728,13 +736,15 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer request.getNamesList().iterator().asScala.foreach { name => val status = proto.ArtifactStatusesResponse.ArtifactStatus.newBuilder() val exists = if (name.startsWith("cache/")) { - inputArtifactRequests.exists { artifactReq => - if (artifactReq.hasBatch) { - val batch = artifactReq.getBatch - batch.getArtifactsList.asScala.exists { singleArtifact => - singleArtifact.getName == name - } - } else false + synchronized { + inputArtifactRequests.exists { artifactReq => + if (artifactReq.hasBatch) { + val batch = artifactReq.getBatch + batch.getArtifactsList.asScala.exists { singleArtifact => + singleArtifact.getName == name + } + } else false + } } } else false builder.putStatuses(name, status.setExists(exists).build()) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 2cdbdc67d2cad..f6662b3351ba7 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType} import org.apache.spark.unsafe.types.VariantVal diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala index 199a1507a3b19..d5f38231d3450 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect.streaming import java.io.{File, FileWriter} import java.nio.file.Paths @@ -29,10 +29,12 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.connect.SparkSession +import org.apache.spark.sql.connect.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions.{col, lit, udf, window} +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryListener, Trigger} import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent} -import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.SparkFileUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/FlatMapGroupsWithStateStreamingSuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/FlatMapGroupsWithStateStreamingSuite.scala index 9bd6614028cbf..2496038a77fe4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/FlatMapGroupsWithStateStreamingSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect.streaming import java.sql.Timestamp @@ -23,9 +23,10 @@ import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.connect.SparkSession +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} case class ClickEvent(id: String, timestamp: Timestamp) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/StreamingQueryProgressSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/StreamingQueryProgressSuite.scala index aed6c55b3e7fb..0b8447e3bb4c1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/StreamingQueryProgressSuite.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect.streaming import scala.jdk.CollectionConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite +import org.apache.spark.sql.streaming.StreamingQueryProgress import org.apache.spark.sql.types.StructType class StreamingQueryProgressSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/ConnectFunSuite.scala similarity index 94% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/ConnectFunSuite.scala index f46b98646ae4f..89f70d6f1214a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/ConnectFunSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.nio.file.Path @@ -22,7 +22,7 @@ import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.connect.proto import org.apache.spark.sql.Column -import org.apache.spark.sql.internal.ColumnNodeToProtoConverter +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter /** * The basic testsuite the client tests should extend from. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala index 3ae9b9fc73b48..79c7a797a97e0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.io.File import java.nio.file.{Files, Paths} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala index b22488858b8f5..5ae23368b9729 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.util.TimeZone @@ -24,8 +24,9 @@ import scala.jdk.CollectionConverters._ import org.scalatest.Assertions import org.apache.spark.{QueryContextType, SparkThrowable} -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide +import org.apache.spark.sql.connect.{DataFrame, Dataset, SparkSession} import org.apache.spark.util.ArrayImplicits._ abstract class QueryTest extends ConnectFunSuite with SQLHelper { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala index 36aaa2cc7fbf6..b6f17627fca85 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.io.{File, IOException, OutputStream} import java.lang.ProcessBuilder.Redirect @@ -23,17 +23,17 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.FiniteDuration +import IntegrationTestUtils._ import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkBuildInfo -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.SparkSession import org.apache.spark.sql.connect.client.RetryPolicy import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.common.config.ConnectCommon -import org.apache.spark.sql.test.IntegrationTestUtils._ import org.apache.spark.util.ArrayImplicits._ /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala similarity index 94% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala index d9828ae92267b..f23221f1a46c0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.io.File import java.util.UUID import org.scalatest.Assertions.fail -import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession, SQLImplicits} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.connect.{DataFrame, SparkSession, SQLImplicits} import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils} trait SQLHelper { @@ -39,9 +40,7 @@ trait SQLHelper { * because we create the `SparkSession` immediately before the first test is run, but the * implicits import is needed in the constructor. */ - protected object testImplicits extends SQLImplicits { - override protected def session: SparkSession = spark - } + protected object testImplicits extends SQLImplicits(spark) /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index b77bb94aaf46f..ff884310f660c 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index d43b22d9de922..1b52046b14833 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -26,9 +26,11 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{ERROR, FROM_OFFSET, OFFSETS, TIP, TOPIC_PARTITIONS, UNTIL_OFFSET} import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index e67b72e090601..ac8845b09069d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import 
org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ClassicConversions.ColumnConstructorExt +import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.ExpressionUtils.expression import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala b/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala index b3a3bc1a791ba..2141d4251eded 100644 --- a/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala +++ b/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala @@ -21,10 +21,10 @@ import org.apache.spark.mllib.linalg.{SparseVector => OldSparseVector, Vector => import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, StringLiteral} +import org.apache.spark.sql.classic.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF /** * Register a couple ML vector conversion UDFs in the internal function registry. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 2b20a282dd14a..ff9651c11250f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkContext import org.apache.spark.ml.feature._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TempDirectory -import org.apache.spark.sql.{SparkSession, SQLImplicits} +import org.apache.spark.sql.classic.{SparkSession, SQLImplicits} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d93f15d3a63a3..618d21149ab45 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -205,10 +205,33 @@ object MimaExcludes { // SPARK-50112: Moving avro files from connector to sql/core ProblemFilters.exclude[Problem]("org.apache.spark.sql.avro.*"), + // SPARK-49700: Unified Scala SQL Interface. 
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameNaFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameStatFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.KeyValueGroupedDataset"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLImplicits"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Builder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$implicits$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Catalog"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.DataStreamReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.DataStreamWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.DataStreamWriter$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryManager"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQuery"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$implicits$"), + // SPARK-50768: Introduce TaskContext.createResourceUninterruptibly to avoid stream leak by task interruption ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.interruptible"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.pendingInterrupt"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.createResourceUninterruptibly"), + ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 4f19421ca64f1..3dd8123d581c6 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -161,7 +161,10 @@ def killChild(): java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "org.apache.spark.resource.*") # TODO(davies): move into sql - java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.Encoders") + java_import(gateway.jvm, "org.apache.spark.sql.OnSuccessCall") + java_import(gateway.jvm, "org.apache.spark.sql.functions") + java_import(gateway.jvm, "org.apache.spark.sql.classic.*") java_import(gateway.jvm, "org.apache.spark.sql.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index 4d3465b320391..8548801266b26 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala 
@@ -25,6 +25,7 @@ import scala.tools.nsc.GenericRunnerSettings import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession.hiveClassesArePresent import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.Utils @@ -108,7 +109,7 @@ object Main extends Logging { val builder = SparkSession.builder().config(conf) if (conf.get(CATALOG_IMPLEMENTATION.key, "hive") == "hive") { - if (SparkSession.hiveClassesArePresent) { + if (hiveClassesArePresent) { // In the case that the property is not set at all, builder's config // does not have this value set to 'hive' yet. The original default // behavior is that when there are hive classes, we use hive catalog. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 327ef3d074207..8969bc8b5e2b9 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -25,7 +25,7 @@ import org.apache.logging.log4j.core.{Logger, LoggerContext} import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION class ReplSuite extends SparkFunSuite { diff --git a/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index 91a1231ec0303..e37a5dff01fd6 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -25,7 +25,7 @@ import org.apache.spark.sql.internal.TypedSumLong; /** - * Type-safe functions available for {@link org.apache.spark.sql.api.Dataset} operations in Java. + * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. * * Scala users should use {@link org.apache.spark.sql.expressions.scalalang.typed}. * diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala similarity index 85% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ef6cc64c058a4..005e418a3b29e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.util -import _root_.java.util +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.Row import org.apache.spark.util.ArrayImplicits._ /** @@ -37,7 +36,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(): Dataset[Row] = drop("any") + def drop(): DataFrame = drop("any") /** * Returns a new `DataFrame` that drops rows containing null or NaN values. 
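For orientation, a minimal sketch of the na-functions touched in this file; behaviour is unchanged, only the declared return type moves from `Dataset[Row]` to the `DataFrame` alias. The session, column names and values below are assumed for illustration, not taken from the patch.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("na-sketch").getOrCreate()
import spark.implicits._

// Option columns give drop()/fill() some nulls to work with.
val df = Seq((Some(1), Some("a")), (None, Some("b")), (None, None)).toDF("id", "label")

df.na.drop().show()              // drop rows containing any null or NaN
df.na.drop("all").show()         // drop rows where every column is null
df.na.drop(1, Seq("id")).show()  // keep rows with at least one non-null value in "id"
df.na.fill(Map("id" -> 0, "label" -> "unknown")).show()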
@@ -47,7 +46,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(how: String): Dataset[Row] = drop(toMinNonNulls(how)) + def drop(how: String): DataFrame = drop(toMinNonNulls(how)) /** * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified @@ -55,7 +54,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(cols: Array[String]): Dataset[Row] = drop(cols.toImmutableArraySeq) + def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values @@ -63,7 +62,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(cols: Seq[String]): Dataset[Row] = drop(cols.size, cols) + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) /** * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified @@ -74,7 +73,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(how: String, cols: Array[String]): Dataset[Row] = drop(how, cols.toImmutableArraySeq) + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in @@ -85,7 +84,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(how: String, cols: Seq[String]): Dataset[Row] = drop(toMinNonNulls(how), cols) + def drop(how: String, cols: Seq[String]): DataFrame = drop(toMinNonNulls(how), cols) /** * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and @@ -93,7 +92,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(minNonNulls: Int): Dataset[Row] = drop(Option(minNonNulls)) + def drop(minNonNulls: Int): DataFrame = drop(Option(minNonNulls)) /** * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and @@ -101,7 +100,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Array[String]): Dataset[Row] = + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toImmutableArraySeq) /** @@ -110,7 +109,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Seq[String]): Dataset[Row] = drop(Option(minNonNulls), cols) + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = drop(Option(minNonNulls), cols) private def toMinNonNulls(how: String): Option[Int] = { how.toLowerCase(util.Locale.ROOT) match { @@ -120,29 +119,29 @@ abstract class DataFrameNaFunctions { } } - protected def drop(minNonNulls: Option[Int]): Dataset[Row] + protected def drop(minNonNulls: Option[Int]): DataFrame - protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] + protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * * @since 2.2.0 */ - def fill(value: Long): Dataset[Row] + def fill(value: Long): DataFrame /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): Dataset[Row] + def fill(value: Double): DataFrame /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. 
* * @since 1.3.1 */ - def fill(value: String): Dataset[Row] + def fill(value: String): DataFrame /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a @@ -150,7 +149,7 @@ abstract class DataFrameNaFunctions { * * @since 2.2.0 */ - def fill(value: Long, cols: Array[String]): Dataset[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a @@ -158,7 +157,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: Double, cols: Array[String]): Dataset[Row] = + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** @@ -167,7 +166,7 @@ abstract class DataFrameNaFunctions { * * @since 2.2.0 */ - def fill(value: Long, cols: Seq[String]): Dataset[Row] + def fill(value: Long, cols: Seq[String]): DataFrame /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -175,7 +174,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): Dataset[Row] + def fill(value: Double, cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null values in specified string columns. If a @@ -183,7 +182,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: String, cols: Array[String]): Dataset[Row] = + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** @@ -192,14 +191,14 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): Dataset[Row] + def fill(value: String, cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. * * @since 2.3.0 */ - def fill(value: Boolean): Dataset[Row] + def fill(value: Boolean): DataFrame /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean @@ -207,7 +206,7 @@ abstract class DataFrameNaFunctions { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Seq[String]): Dataset[Row] + def fill(value: Boolean, cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a @@ -215,7 +214,7 @@ abstract class DataFrameNaFunctions { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Array[String]): Dataset[Row] = + def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** @@ -234,7 +233,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(valueMap: util.Map[String, Any]): Dataset[Row] = fillMap(valueMap.asScala.toSeq) + def fill(valueMap: util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values. @@ -254,9 +253,9 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(valueMap: Map[String, Any]): Dataset[Row] = fillMap(valueMap.toSeq) + def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq) - protected def fillMap(values: Seq[(String, Any)]): Dataset[Row] + protected def fillMap(values: Seq[(String, Any)]): DataFrame /** * Replaces values matching keys in `replacement` map with the corresponding values. 
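`replace` follows the same pattern. A short sketch of the Scala-specific overloads, with a made-up sentinel value and columns:

import spark.implicits._  // spark: an existing SparkSession (assumed)

val people = Seq(("UNKNOWN", -1.0), ("alice", 165.0)).toDF("name", "height")

// Replace sentinel values in a single numeric column, then in a string column.
val fixedHeight = people.na.replace("height", Map(-1.0 -> Double.NaN))
val fixedName   = fixedHeight.na.replace(Seq("name"), Map("UNKNOWN" -> "unnamed"))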
@@ -283,7 +282,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](col: String, replacement: util.Map[T, T]): Dataset[Row] = { + def replace[T](col: String, replacement: util.Map[T, T]): DataFrame = { replace[T](col, replacement.asScala.toMap) } @@ -309,7 +308,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](cols: Array[String], replacement: util.Map[T, T]): Dataset[Row] = { + def replace[T](cols: Array[String], replacement: util.Map[T, T]): DataFrame = { replace(cols.toImmutableArraySeq, replacement.asScala.toMap) } @@ -336,7 +335,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](col: String, replacement: Map[T, T]): Dataset[Row] + def replace[T](col: String, replacement: Map[T, T]): DataFrame /** * (Scala-specific) Replaces values matching keys in `replacement` map. @@ -358,5 +357,5 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](cols: Seq[String], replacement: Map[T, T]): Dataset[Row] + def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameReader.scala similarity index 94% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 8c88387714228..95f1eca665784 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -14,16 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.util -import _root_.java.util +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} import org.apache.spark.sql.errors.DataTypeErrors @@ -37,7 +36,6 @@ import org.apache.spark.sql.types.StructType */ @Stable abstract class DataFrameReader { - type DS[U] <: Dataset[U] /** * Specifies the input data source format. @@ -152,7 +150,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def load(): Dataset[Row] + def load(): DataFrame /** * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a @@ -160,7 +158,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def load(path: String): Dataset[Row] + def load(path: String): DataFrame /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if @@ -169,7 +167,7 @@ abstract class DataFrameReader { * @since 1.6.0 */ @scala.annotation.varargs - def load(paths: String*): Dataset[Row] + def load(paths: String*): DataFrame /** * Construct a `DataFrame` representing the database table accessible via JDBC URL url named @@ -182,7 +180,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def jdbc(url: String, table: String, properties: util.Properties): Dataset[Row] = { + def jdbc(url: String, table: String, properties: util.Properties): DataFrame = { assertNoSpecifiedSchema("jdbc") // properties should override settings in extraOptions. 
this.extraOptions ++= properties.asScala @@ -226,7 +224,7 @@ abstract class DataFrameReader { lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: util.Properties): Dataset[Row] = { + connectionProperties: util.Properties): DataFrame = { // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. this.extraOptions ++= Map( "partitionColumn" -> columnName, @@ -263,7 +261,7 @@ abstract class DataFrameReader { url: String, table: String, predicates: Array[String], - connectionProperties: util.Properties): Dataset[Row] + connectionProperties: util.Properties): DataFrame /** * Loads a JSON file and returns the results as a `DataFrame`. @@ -272,7 +270,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def json(path: String): Dataset[Row] = { + def json(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 json(Seq(path): _*) } @@ -293,7 +291,7 @@ abstract class DataFrameReader { * @since 2.0.0 */ @scala.annotation.varargs - def json(paths: String*): Dataset[Row] = { + def json(paths: String*): DataFrame = { validateJsonSchema() format("json").load(paths: _*) } @@ -309,7 +307,7 @@ abstract class DataFrameReader { * input Dataset with one JSON object per record * @since 2.2.0 */ - def json(jsonDataset: DS[String]): Dataset[Row] + def json(jsonDataset: Dataset[String]): DataFrame /** * Loads a `JavaRDD[String]` storing JSON objects (JSON Lines @@ -325,7 +323,7 @@ abstract class DataFrameReader { * @since 1.4.0 */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") - def json(jsonRDD: JavaRDD[String]): DS[Row] + def json(jsonRDD: JavaRDD[String]): DataFrame /** * Loads an `RDD[String]` storing JSON objects (JSON Lines text @@ -341,7 +339,7 @@ abstract class DataFrameReader { * @since 1.4.0 */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") - def json(jsonRDD: RDD[String]): DS[Row] + def json(jsonRDD: RDD[String]): DataFrame /** * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other @@ -349,7 +347,7 @@ abstract class DataFrameReader { * * @since 2.0.0 */ - def csv(path: String): Dataset[Row] = { + def csv(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 csv(Seq(path): _*) } @@ -375,7 +373,7 @@ abstract class DataFrameReader { * input Dataset with one CSV row per record * @since 2.2.0 */ - def csv(csvDataset: DS[String]): Dataset[Row] + def csv(csvDataset: Dataset[String]): DataFrame /** * Loads CSV files and returns the result as a `DataFrame`. @@ -391,7 +389,7 @@ abstract class DataFrameReader { * @since 2.0.0 */ @scala.annotation.varargs - def csv(paths: String*): Dataset[Row] = format("csv").load(paths: _*) + def csv(paths: String*): DataFrame = format("csv").load(paths: _*) /** * Loads a XML file and returns the result as a `DataFrame`. 
See the documentation on the other @@ -399,7 +397,7 @@ abstract class DataFrameReader { * * @since 4.0.0 */ - def xml(path: String): Dataset[Row] = { + def xml(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 xml(Seq(path): _*) } @@ -418,7 +416,7 @@ abstract class DataFrameReader { * @since 4.0.0 */ @scala.annotation.varargs - def xml(paths: String*): Dataset[Row] = { + def xml(paths: String*): DataFrame = { validateXmlSchema() format("xml").load(paths: _*) } @@ -433,7 +431,7 @@ abstract class DataFrameReader { * input Dataset with one XML object per record * @since 4.0.0 */ - def xml(xmlDataset: DS[String]): Dataset[Row] + def xml(xmlDataset: Dataset[String]): DataFrame /** * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the @@ -441,7 +439,7 @@ abstract class DataFrameReader { * * @since 2.0.0 */ - def parquet(path: String): Dataset[Row] = { + def parquet(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 parquet(Seq(path): _*) } @@ -456,7 +454,7 @@ abstract class DataFrameReader { * @since 1.4.0 */ @scala.annotation.varargs - def parquet(paths: String*): Dataset[Row] = format("parquet").load(paths: _*) + def parquet(paths: String*): DataFrame = format("parquet").load(paths: _*) /** * Loads an ORC file and returns the result as a `DataFrame`. @@ -465,7 +463,7 @@ abstract class DataFrameReader { * input path * @since 1.5.0 */ - def orc(path: String): Dataset[Row] = { + def orc(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 orc(Seq(path): _*) } @@ -482,7 +480,7 @@ abstract class DataFrameReader { * @since 2.0.0 */ @scala.annotation.varargs - def orc(paths: String*): Dataset[Row] = format("orc").load(paths: _*) + def orc(paths: String*): DataFrame = format("orc").load(paths: _*) /** * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch @@ -497,7 +495,7 @@ abstract class DataFrameReader { * database. Note that, the global temporary view database is also valid here. * @since 1.4.0 */ - def table(tableName: String): Dataset[Row] + def table(tableName: String): DataFrame /** * Loads text files and returns a `DataFrame` whose schema starts with a string column named @@ -506,7 +504,7 @@ abstract class DataFrameReader { * * @since 2.0.0 */ - def text(path: String): Dataset[Row] = { + def text(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 text(Seq(path): _*) } @@ -534,7 +532,7 @@ abstract class DataFrameReader { * @since 1.6.0 */ @scala.annotation.varargs - def text(paths: String*): Dataset[Row] = format("text").load(paths: _*) + def text(paths: String*): DataFrame = format("text").load(paths: _*) /** * Loads text files and returns a [[Dataset]] of String. 
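Apart from the return types, the reader API in this file is untouched, so existing call sites should compile as before. A hedged sketch of typical usage; the schema, column names and paths are placeholders:

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// spark: an existing SparkSession; the paths below are illustrative only.
val users = spark.read
  .schema(StructType(Seq(StructField("id", LongType), StructField("name", StringType))))
  .json("/tmp/users.jsonl")

val events = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("/tmp/events.csv")

users.join(events, "id").write.mode("overwrite").parquet("/tmp/joined")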
See the documentation on the other diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala similarity index 97% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index ae7c256b30ace..db74282dc1188 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.{lang => jl, util => ju} -import _root_.java.{lang => jl, util => ju} +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BinaryEncoder import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.functions.{count_min_sketch, lit} @@ -35,7 +34,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} */ @Stable abstract class DataFrameStatFunctions { - protected def df: Dataset[Row] + protected def df: DataFrame /** * Calculates the approximate quantiles of a numerical column of a DataFrame. @@ -202,7 +201,7 @@ abstract class DataFrameStatFunctions { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): Dataset[Row] + def crosstab(col1: String, col2: String): DataFrame /** * Finding frequent items for columns, possibly with false positives. Using the frequent element @@ -246,7 +245,7 @@ abstract class DataFrameStatFunctions { * }}} * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): Dataset[Row] = + def freqItems(cols: Array[String], support: Double): DataFrame = freqItems(cols.toImmutableArraySeq, support) /** @@ -263,7 +262,7 @@ abstract class DataFrameStatFunctions { * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Array[String]): Dataset[Row] = freqItems(cols, 0.01) + def freqItems(cols: Array[String]): DataFrame = freqItems(cols, 0.01) /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the @@ -307,7 +306,7 @@ abstract class DataFrameStatFunctions { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): Dataset[Row] + def freqItems(cols: Seq[String], support: Double): DataFrame /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the @@ -324,7 +323,7 @@ abstract class DataFrameStatFunctions { * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Seq[String]): Dataset[Row] = freqItems(cols, 0.01) + def freqItems(cols: Seq[String]): DataFrame = freqItems(cols, 0.01) /** * Returns a stratified sample without replacement based on the fraction given on each stratum. 
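The stat helpers keep their semantics as well. A small sketch built on an inline DataFrame; the support threshold and fractions are arbitrary:

import spark.implicits._  // spark: an existing SparkSession (assumed)

val clicks = Seq((1, "ad"), (1, "organic"), (2, "ad"), (3, "ad")).toDF("user", "channel")

clicks.stat.crosstab("user", "channel").show()
clicks.stat.freqItems(Seq("channel"), 0.3).show()

// Stratified sample: keep roughly half of the "ad" rows and all "organic" rows.
val sampled = clicks.stat.sampleBy("channel", Map("ad" -> 0.5, "organic" -> 1.0), seed = 42L)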
@@ -356,7 +355,7 @@ abstract class DataFrameStatFunctions { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): Dataset[Row] = { + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { sampleBy(Column(col), fractions, seed) } @@ -376,7 +375,7 @@ abstract class DataFrameStatFunctions { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } @@ -413,7 +412,7 @@ abstract class DataFrameStatFunctions { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): Dataset[Row] + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame /** * (Java-specific) Returns a stratified sample without replacement based on the fraction given @@ -432,7 +431,7 @@ abstract class DataFrameStatFunctions { * a new `DataFrame` that represents the stratified sample * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { + def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 565b64aa95472..0bde59155f3f3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.CompilationErrors /** - * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage systems - * (e.g. file systems, key-value stores, etc). Use `Dataset.write` to access this. + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage systems (e.g. + * file systems, key-value stores, etc). Use `Dataset.write` to access this. * * @since 1.4.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 37a29c2e4b66d..66a4b4232a22d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} /** - * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage using the v2 + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 * API. 
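Only the Scaladoc links change for the two writer interfaces in the hunks above. For orientation, a minimal sketch of the v1 and v2 write paths; `df`, its "day" column, the output path and the table name are assumptions for the example:

import org.apache.spark.sql.functions.col

// v1: Dataset.write returns a DataFrameWriter bound to a path- or table-based sink.
df.write.mode("append").format("parquet").save("/tmp/out")

// v2: Dataset.writeTo returns a DataFrameWriterV2 bound to a catalog table.
df.writeTo("main.db.events").partitionedBy(col("day")).createOrReplace()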
* * @since 3.0.0 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala similarity index 95% rename from sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala rename to sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala index 20c181e7b9cf6..c49a5d5a50886 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql + +import java.util import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag -import _root_.java.util - import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction} +import org.apache.spark.api.java.function._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} +import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ @@ -123,9 +123,8 @@ import org.apache.spark.util.SparkClassUtils */ @Stable abstract class Dataset[T] extends Serializable { - type DS[U] <: Dataset[U] - def sparkSession: SparkSession + val sparkSession: SparkSession val encoder: Encoder[T] @@ -141,7 +140,7 @@ abstract class Dataset[T] extends Serializable { */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): Dataset[Row] + def toDF(): DataFrame /** * Returns a new Dataset where each record has been mapped on to the specified type. The method @@ -180,7 +179,7 @@ abstract class Dataset[T] extends Serializable { * @group basic * @since 3.4.0 */ - def to(schema: StructType): Dataset[Row] + def to(schema: StructType): DataFrame /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -196,7 +195,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def toDF(colNames: String*): Dataset[Row] + def toDF(colNames: String*): DataFrame /** * Returns the schema of this Dataset. @@ -611,7 +610,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_]): Dataset[Row] + def join(right: Dataset[_]): DataFrame /** * Inner equi-join with another `DataFrame` using the given column. 
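With the `DS[U]` type member gone, `join` accepts a plain `Dataset[_]` again, so the familiar call patterns are unchanged. A sketch with invented columns:

import spark.implicits._  // spark: an existing SparkSession (assumed)

val left  = Seq((1, "a"), (2, "b")).toDF("id", "l")
val right = Seq((1, "x"), (3, "y")).toDF("id", "r")

left.join(right, "id").show()                    // inner equi-join on "id"
left.join(right, Seq("id"), "left_outer").show() // keep unmatched rows from the left side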
@@ -637,7 +636,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumn: String): Dataset[Row] = { + def join(right: Dataset[_], usingColumn: String): DataFrame = { join(right, Seq(usingColumn)) } @@ -653,7 +652,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String]): Dataset[Row] = { + def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = { join(right, usingColumns.toImmutableArraySeq) } @@ -681,7 +680,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String]): Dataset[Row] = { + def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = { join(right, usingColumns, "inner") } @@ -711,7 +710,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumn: String, joinType: String): Dataset[Row] = { + def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = { join(right, Seq(usingColumn), joinType) } @@ -732,7 +731,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String], joinType: String): Dataset[Row] = { + def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = { join(right, usingColumns.toImmutableArraySeq, joinType) } @@ -762,7 +761,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String], joinType: String): Dataset[Row] + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame /** * Inner join with another `DataFrame`, using the given join expression. @@ -776,7 +775,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column): Dataset[Row] = + def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner") /** @@ -806,7 +805,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row] + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame /** * Explicit cartesian join with another `DataFrame`. @@ -818,7 +817,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: DS[_]): Dataset[Row] + def crossJoin(right: Dataset[_]): DataFrame /** * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true. 
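`joinWith` likewise takes `Dataset[U]` directly and still yields typed pairs. A sketch with case classes invented for the example:

import spark.implicits._  // spark: an existing SparkSession (assumed)

case class User(id: Long, name: String)
case class Order(userId: Long, amount: Double)

val users  = Seq(User(1L, "ann"), User(2L, "bob")).toDS()
val orders = Seq(Order(1L, 10.0), Order(1L, 5.5)).toDS()

// Each result row is a (User, Order) pair rather than a flattened Row.
val pairs = users.joinWith(orders, users("id") === orders("userId"), "inner")
pairs.show(truncate = false)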
@@ -842,7 +841,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column, joinType: String): Dataset[(T, U)] + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where @@ -855,7 +854,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column): Dataset[(T, U)] = { + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -869,7 +868,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_]): Dataset[Row] + def lateralJoin(right: Dataset[_]): DataFrame /** * Lateral join with another `DataFrame`. @@ -883,7 +882,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_], joinExprs: Column): Dataset[Row] + def lateralJoin(right: Dataset[_], joinExprs: Column): DataFrame /** * Lateral join with another `DataFrame`. @@ -896,7 +895,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_], joinType: String): Dataset[Row] + def lateralJoin(right: Dataset[_], joinType: String): DataFrame /** * Lateral join with another `DataFrame`. @@ -911,7 +910,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row] + def lateralJoin(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] @@ -1101,7 +1100,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): Dataset[Row] + def select(cols: Column*): DataFrame /** * Selects a set of columns. This is a variant of `select` that can only select existing columns @@ -1117,7 +1116,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): Dataset[Row] = select((col +: cols).map(Column(_)): _*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)): _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions. 
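The untyped projection and aggregation methods in this stretch of the file only tighten their declared result to `DataFrame`. A brief sketch; the `amount` and `id` columns are assumed:

import org.apache.spark.sql.functions.{avg, col, count, lit}

// df: any DataFrame with a numeric "amount" column and an "id" column (assumed).
val projected = df.selectExpr("id", "amount * 1.1 AS amount_with_tax")
val summary   = df.agg(Map("amount" -> "avg", "id" -> "count"))
val explicit  = df.agg(avg(col("amount")), count(lit(1)).as("rows"))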
@@ -1132,7 +1131,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): Dataset[Row] = select(exprs.map(functions.expr): _*) + def selectExpr(exprs: String*): DataFrame = select(exprs.map(functions.expr): _*) /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expression for @@ -1449,7 +1448,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = { + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { groupBy().agg(aggExpr, aggExprs: _*) } @@ -1464,7 +1463,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: Map[String, String]): Dataset[Row] = groupBy().agg(exprs) + def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) /** * (Java-specific) Aggregates on the entire Dataset without groups. @@ -1477,7 +1476,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: util.Map[String, String]): Dataset[Row] = groupBy().agg(exprs) + def agg(exprs: util.Map[String, String]): DataFrame = groupBy().agg(exprs) /** * Aggregates on the entire Dataset without groups. @@ -1491,7 +1490,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): Dataset[Row] = groupBy().agg(expr, exprs: _*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) /** * (Scala-specific) Reduces the elements of this Dataset using the specified binary function. @@ -1594,7 +1593,7 @@ abstract class Dataset[T] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): Dataset[Row] + valueColumnName: String): DataFrame /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns @@ -1617,10 +1616,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def unpivot( - ids: Array[Column], - variableColumnName: String, - valueColumnName: String): Dataset[Row] + def unpivot(ids: Array[Column], variableColumnName: String, valueColumnName: String): DataFrame /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns @@ -1644,7 +1640,7 @@ abstract class Dataset[T] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): Dataset[Row] = + valueColumnName: String): DataFrame = unpivot(ids, values, variableColumnName, valueColumnName) /** @@ -1666,10 +1662,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def melt( - ids: Array[Column], - variableColumnName: String, - valueColumnName: String): Dataset[Row] = + def melt(ids: Array[Column], variableColumnName: String, valueColumnName: String): DataFrame = unpivot(ids, variableColumnName, valueColumnName) /** @@ -1732,7 +1725,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def transpose(indexColumn: Column): Dataset[Row] + def transpose(indexColumn: Column): DataFrame /** * Transposes a DataFrame, switching rows to columns. 
This function transforms the DataFrame @@ -1751,7 +1744,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def transpose(): Dataset[Row] + def transpose(): DataFrame /** * Return a `Column` object for a SCALAR Subquery containing exactly one row and one column. @@ -1872,7 +1865,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def union(other: DS[T]): Dataset[T] + def union(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is @@ -1886,7 +1879,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def unionAll(other: DS[T]): Dataset[T] = union(other) + def unionAll(other: Dataset[T]): Dataset[T] = union(other) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1917,7 +1910,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.3.0 */ - def unionByName(other: DS[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) + def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1961,7 +1954,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 3.1.0 */ - def unionByName(other: DS[T], allowMissingColumns: Boolean): Dataset[T] + def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is @@ -1973,7 +1966,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 1.6.0 */ - def intersect(other: DS[T]): Dataset[T] + def intersect(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset while @@ -1986,7 +1979,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.4.0 */ - def intersectAll(other: DS[T]): Dataset[T] + def intersectAll(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. This is @@ -1998,7 +1991,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def except(other: DS[T]): Dataset[T] + def except(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset while @@ -2011,7 +2004,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.4.0 */ - def exceptAll(other: DS[T]): Dataset[T] + def exceptAll(other: Dataset[T]): Dataset[T] /** * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a @@ -2096,7 +2089,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double], seed: Long): Array[_ <: Dataset[T]] + def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] /** * Returns a Java list that contains randomly split Dataset with the provided weights. @@ -2108,7 +2101,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: Dataset[T]] + def randomSplitAsList(weights: Array[Double], seed: Long): util.List[Dataset[T]] /** * Randomly splits this Dataset with the provided weights. 
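`randomSplit` now advertises a plain `Array[Dataset[T]]`, so the result destructures without a wildcard bound. A minimal sketch with arbitrary weights and seed:

import spark.implicits._  // spark: an existing SparkSession (assumed)

val ds = spark.range(0, 100).as[Long]

// The seed makes the split reproducible; every row lands in exactly one split.
val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 17L)

// The set-style combinators keep the element type T.
val everything = train.unionByName(test)
assert(everything.count() == 100)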
@@ -2118,7 +2111,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double]): Array[_ <: Dataset[T]] + def randomSplit(weights: Array[Double]): Array[Dataset[T]] /** * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows @@ -2148,7 +2141,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): Dataset[Row] + def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or @@ -2174,7 +2167,7 @@ abstract class Dataset[T] extends Serializable { */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( - f: A => IterableOnce[B]): Dataset[Row] + f: A => IterableOnce[B]): DataFrame /** * Returns a new Dataset by adding a column or replacing the existing column that has the same @@ -2191,7 +2184,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): Dataset[Row] = withColumns(Seq(colName), Seq(col)) + def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) /** * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns @@ -2203,7 +2196,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: Map[String, Column]): Dataset[Row] = { + def withColumns(colsMap: Map[String, Column]): DataFrame = { val (colNames, newCols) = colsMap.toSeq.unzip withColumns(colNames, newCols) } @@ -2218,14 +2211,41 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: util.Map[String, Column]): Dataset[Row] = withColumns( + def withColumns(colsMap: util.Map[String, Column]): DataFrame = withColumns( colsMap.asScala.toMap) /** * Returns a new Dataset by adding columns or replacing the existing columns that has the same * names. */ - protected def withColumns(colNames: Seq[String], cols: Seq[Column]): Dataset[Row] + private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame + + /** + * Returns a new Dataset by adding columns with metadata. + */ + private[spark] def withColumns( + colNames: Seq[String], + cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = { + require( + colNames.size == cols.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of columns: ${cols.size}") + require( + colNames.size == metadata.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of metadata elements: ${metadata.size}") + val colsWithMetadata = colNames.zip(cols).zip(metadata).map { case ((colName, col), meta) => + col.as(colName, meta) + } + withColumns(colNames, colsWithMetadata) + } + + /** + * Returns a new Dataset by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = + withColumn(colName, col.as(colName, metadata)) /** * Returns a new Dataset with a column renamed. 
This is a no-op if schema doesn't contain @@ -2234,7 +2254,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def withColumnRenamed(existingName: String, newName: String): Dataset[Row] = + def withColumnRenamed(existingName: String, newName: String): DataFrame = withColumnsRenamed(Seq(existingName), Seq(newName)) /** @@ -2249,7 +2269,7 @@ abstract class Dataset[T] extends Serializable { * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): Dataset[Row] = { + def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { val (colNames, newColNames) = colsMap.toSeq.unzip withColumnsRenamed(colNames, newColNames) } @@ -2263,10 +2283,10 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def withColumnsRenamed(colsMap: util.Map[String, String]): Dataset[Row] = + def withColumnsRenamed(colsMap: util.Map[String, String]): DataFrame = withColumnsRenamed(colsMap.asScala.toMap) - protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): Dataset[Row] + protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): DataFrame /** * Returns a new Dataset by updating an existing column with metadata. @@ -2274,7 +2294,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withMetadata(columnName: String, metadata: Metadata): Dataset[Row] + def withMetadata(columnName: String, metadata: Metadata): DataFrame /** * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column @@ -2347,7 +2367,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(colName: String): Dataset[Row] = drop(colName :: Nil: _*) + def drop(colName: String): DataFrame = drop(colName :: Nil: _*) /** * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column @@ -2360,7 +2380,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def drop(colNames: String*): Dataset[Row] + def drop(colNames: String*): DataFrame /** * Returns a new Dataset with column dropped. @@ -2375,7 +2395,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(col: Column): Dataset[Row] = drop(col, Nil: _*) + def drop(col: Column): DataFrame = drop(col, Nil: _*) /** * Returns a new Dataset with columns dropped. @@ -2387,7 +2407,7 @@ abstract class Dataset[T] extends Serializable { * @since 3.4.0 */ @scala.annotation.varargs - def drop(col: Column, cols: Column*): Dataset[Row] + def drop(col: Column, cols: Column*): DataFrame /** * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias @@ -2567,7 +2587,7 @@ abstract class Dataset[T] extends Serializable { * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): Dataset[Row] + def describe(cols: String*): DataFrame /** * Computes specified statistics for numeric and string columns. Available statistics are: