diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
index dbbbac43eb741..802e373ddd112 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
@@ -19,6 +19,7 @@
 
 import org.apache.spark.network.buffer.ManagedBuffer;
 
+import java.io.IOException;
 import java.nio.channels.WritableByteChannel;
 
 /**
@@ -26,5 +27,5 @@
  * after the writer has been closed. Used with DownloadFile and DownloadFileManager.
  */
 public interface DownloadFileWritableChannel extends WritableByteChannel {
-  ManagedBuffer closeAndRead();
+  ManagedBuffer closeAndRead() throws IOException;
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
index 97ecaa627b66c..0acca3fdd6a44 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
@@ -69,7 +69,8 @@ private class SimpleDownloadWritableChannel implements DownloadFileWritableChann
     }
 
     @Override
-    public ManagedBuffer closeAndRead() {
+    public ManagedBuffer closeAndRead() throws IOException {
+      channel.close();
       return new FileSegmentManagedBuffer(transportConf, file, 0, file.length());
     }
 
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SimpleDownloadFileSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SimpleDownloadFileSuite.java
new file mode 100644
index 0000000000000..120455eeb9351
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SimpleDownloadFileSuite.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.junit.jupiter.api.Test;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.junit.jupiter.api.Assertions;
+
+public class SimpleDownloadFileSuite {
+  @Test
+  public void testChannelIsClosedAfterCloseAndRead() throws IOException {
+    File tempFile = File.createTempFile("testChannelIsClosed", ".tmp");
+    tempFile.deleteOnExit();
+    TransportConf conf = new TransportConf("test", MapConfigProvider.EMPTY);
+
+    DownloadFile downloadFile = null;
+    try {
+      downloadFile = new SimpleDownloadFile(tempFile, conf);
+      DownloadFileWritableChannel channel = downloadFile.openForWriting();
+      channel.closeAndRead();
+      Assertions.assertFalse(channel.isOpen(), "Channel should be closed after closeAndRead.");
+    } finally {
+      if (downloadFile != null) {
+        downloadFile.delete();
+      }
+    }
+  }
+}
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 13fdbce0211fb..302a0275491a2 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3135,6 +3135,24 @@
     ],
     "sqlState" : "42836"
   },
+  "INVALID_RECURSIVE_REFERENCE" : {
+    "message" : [
+      "Invalid recursive reference found inside WITH RECURSIVE clause."
+    ],
+    "subClass" : {
+      "NUMBER" : {
+        "message" : [
+          "Multiple self-references to one recursive CTE are not allowed."
+        ]
+      },
+      "PLACE" : {
+        "message" : [
+          "Recursive references cannot be used on the right side of left outer/semi/anti joins, on the left side of right outer joins, in full outer joins, in aggregates, and in subquery expressions."
+        ]
+      }
+    },
+    "sqlState" : "42836"
+  },
   "INVALID_REGEXP_REPLACE" : {
     "message" : [
       "Could not perform regexp_replace for source = \"<source>\", pattern = \"<pattern>\", replacement = \"<replacement>\" and position = <position>."
@@ -5259,6 +5277,11 @@
           "Resilient Distributed Datasets (RDDs)."
         ]
       },
+      "REGISTER_UDAF" : {
+        "message" : [
+          "Registering User Defined Aggregate Functions (UDAFs)."
+        ]
+      },
       "SESSION_BASE_RELATION_TO_DATAFRAME" : {
         "message" : [
           "Invoking SparkSession 'baseRelationToDataFrame'. This is server side developer API"
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 0df6a7c4bc90e..8342ca4e84275 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -25,7 +25,7 @@ import java.util.UUID
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder}
+import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, SchemaFormatter}
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.Schema.Type._
 import org.apache.avro.file.{DataFileReader, DataFileWriter}
@@ -86,7 +86,7 @@ abstract class AvroSuite
   }
 
   def getAvroSchemaStringFromFiles(filePath: String): String = {
-    new DataFileReader({
+    val schema = new DataFileReader({
       val file = new File(filePath)
       if (file.isFile) {
         file
@@ -96,7 +96,8 @@ abstract class AvroSuite
           .filter(_.getName.endsWith("avro"))
           .head
       }
-    }, new GenericDatumReader[Any]()).getSchema.toString(false)
+    }, new GenericDatumReader[Any]()).getSchema
+    SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema)
   }
 
   // Check whether an Avro schema of union type is converted to SQL in an expected way, when the
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index bff6db25a21f2..ca82381eec9e3 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -31,8 +31,8 @@ import ammonite.util.Util.newLine
 
 import org.apache.spark.SparkBuildInfo.spark_version
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.SparkSession.withLocalConnectServer
+import org.apache.spark.sql.connect.SparkSession
+import org.apache.spark.sql.connect.SparkSession.withLocalConnectServer
 import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser}
 
 /**
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
deleted file mode 100644
index 86b1dbe4754e6..0000000000000
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ /dev/null
@@ -1,168 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.sql.catalog - -import java.util - -import org.apache.spark.sql.{api, DataFrame, Dataset} -import org.apache.spark.sql.connect.ConnectConversions._ -import org.apache.spark.sql.types.StructType - -/** @inheritdoc */ -abstract class Catalog extends api.Catalog { - - /** @inheritdoc */ - override def listDatabases(): Dataset[Database] - - /** @inheritdoc */ - override def listDatabases(pattern: String): Dataset[Database] - - /** @inheritdoc */ - override def listTables(): Dataset[Table] - - /** @inheritdoc */ - override def listTables(dbName: String): Dataset[Table] - - /** @inheritdoc */ - override def listTables(dbName: String, pattern: String): Dataset[Table] - - /** @inheritdoc */ - override def listFunctions(): Dataset[Function] - - /** @inheritdoc */ - override def listFunctions(dbName: String): Dataset[Function] - - /** @inheritdoc */ - override def listFunctions(dbName: String, pattern: String): Dataset[Function] - - /** @inheritdoc */ - override def listColumns(tableName: String): Dataset[Column] - - /** @inheritdoc */ - override def listColumns(dbName: String, tableName: String): Dataset[Column] - - /** @inheritdoc */ - override def createTable(tableName: String, path: String): DataFrame - - /** @inheritdoc */ - override def createTable(tableName: String, path: String, source: String): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - description: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - description: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def listCatalogs(): Dataset[CatalogMetadata] - - /** @inheritdoc */ - override def listCatalogs(pattern: String): Dataset[CatalogMetadata] - - /** @inheritdoc */ - override def createExternalTable(tableName: String, path: String): DataFrame = - super.createExternalTable(tableName, path) - - /** @inheritdoc */ - override def createExternalTable(tableName: String, path: String, source: String): DataFrame = - super.createExternalTable(tableName, path, source) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - options: util.Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: util.Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - description: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, description, options) - - /** @inheritdoc */ - 
override def createTable( - tableName: String, - source: String, - schema: StructType, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - description: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, schema, description, options) -} diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java index e5f06eb7dbcec..907105e370c08 100644 --- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -28,7 +28,7 @@ import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.sql.test.SparkConnectServerUtils; +import org.apache.spark.sql.connect.test.SparkConnectServerUtils; import org.apache.spark.sql.types.StructType; /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala index 33c9911e75c9b..ba77879a5a800 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession { import testImplicits._ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala index 12a49ad21676e..d11f276e8ed47 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSession { import testImplicits._ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 8aed696005799..a548ec7007dbe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -38,12 +38,13 @@ 
import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.{functions => pbFn} -import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.SparkFileUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala index fcd2b3a388042..83ad943a2253e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.unsafe.types.VariantVal class SQLExpressionsSuite extends QueryTest with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index bb7d1b25738c1..3596de39e86a3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -25,7 +25,7 @@ import scala.util.Properties import org.apache.commons.io.output.ByteArrayOutputStream import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession} import org.apache.spark.tags.AmmoniteTest import org.apache.spark.util.IvyTestUtils import org.apache.spark.util.MavenUtils.MavenCoordinate diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala index ce552bdd4f0f0..b2c19226dc542 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala @@ -15,14 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.{File, FilenameFilter} import org.apache.commons.io.FileUtils import org.apache.spark.SparkException -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} import org.apache.spark.sql.types.{DoubleType, LongType, StructType} import org.apache.spark.storage.StorageLevel diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CheckpointSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CheckpointSuite.scala index 0d9685d9c710f..eeda75ae13310 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CheckpointSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.{ByteArrayOutputStream, PrintStream} @@ -26,7 +26,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.SparkException import org.apache.spark.connect.proto -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} import org.apache.spark.storage.StorageLevel class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala index 84ed624a95214..a7e2e61a106f9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala @@ -15,14 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Random import org.scalatest.matchers.must.Matchers._ import org.apache.spark.SparkIllegalArgumentException -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} class ClientDataFrameStatSuite extends ConnectFunSuite with RemoteSparkSession { private def toLetter(i: Int): String = (i + 97).toChar.toString diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDatasetSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDatasetSuite.scala index c93f9b1c0dbdd..7e6cebfd972df 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDatasetSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Properties import java.util.concurrent.TimeUnit @@ -25,9 +25,10 @@ import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto +import org.apache.spark.sql.Column import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils // Add sample tests. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala index a4615784d2e98..99374efeff1d1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.{ByteArrayOutputStream, PrintStream} import java.nio.file.Files @@ -34,11 +34,15 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException} import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala index 24e9156775c44..02f0c35c44a8f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import org.apache.spark.SparkException import org.apache.spark.connect.proto @@ -23,9 +23,10 @@ import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoDataTypes} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator} -import org.apache.spark.sql.test.ConnectFunSuite -import org.apache.spark.sql.types.{BinaryType, DataType, DoubleType, LongType, MetadataBuilder, ShortType, StringType, StructType} +import org.apache.spark.sql.internal._ +import org.apache.spark.sql.types._ /** * Test suite for [[ColumnNode]] to [[proto.Expression]] conversions. 
@@ -471,8 +472,8 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite { } } -private[internal] case class Nope(override val origin: Origin = CurrentOrigin.get) +private[connect] case class Nope(override val origin: Origin = CurrentOrigin.get) extends ColumnNode { override def sql: String = "nope" - override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty + override def children: Seq[ColumnNodeLike] = Seq.empty } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala index 86c7a20136851..863cb5872c72e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala @@ -14,14 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.ByteArrayOutputStream import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.{functions => fn} -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.{functions => fn, Column, ColumnName} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.types._ /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameNaFunctionSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameNaFunctionSuite.scala index 8a783d880560e..988d4d4c4f5da 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameNaFunctionSuite.scala @@ -15,12 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionSuite extends QueryTest with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/FunctionTestSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/FunctionTestSuite.scala index 40b66bcb8358d..8c5fc2b2b8ec9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/FunctionTestSuite.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Collections import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.Column import org.apache.spark.sql.avro.{functions => avroFn} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.{DataType, StructType} /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/KeyValueGroupedDatasetE2ETestSuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/KeyValueGroupedDatasetE2ETestSuite.scala index 021b4fea26e2a..c2046c6f26700 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/KeyValueGroupedDatasetE2ETestSuite.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.sql.Timestamp import java.util.Arrays +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.types._ import org.apache.spark.util.SparkSerDeUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/MergeIntoE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/MergeIntoE2ETestSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/MergeIntoE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/MergeIntoE2ETestSuite.scala index cdb72e3baf0e9..832efca96550e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/MergeIntoE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/MergeIntoE2ETestSuite.scala @@ -14,10 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect +import org.apache.spark.sql.Row +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} class MergeIntoE2ETestSuite extends ConnectFunSuite with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala index b3b8020b1e4c7..2791c6b6add55 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.sql.{Date, Timestamp} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} @@ -26,10 +26,11 @@ import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{Column, Encoder, SaveMode} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite /** * Test suite for SQL implicits. 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala similarity index 82% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala index ed930882ac2fd..cc6bc8af1f4b3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala @@ -14,17 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect -import org.apache.spark.sql.api.SparkSessionBuilder -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql +import org.apache.spark.sql.SparkSessionBuilder +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} /** * Make sure the api.SparkSessionBuilder binds to Connect implementation. */ class SparkSessionBuilderImplementationBindingSuite extends ConnectFunSuite - with api.SparkSessionBuilderImplementationBindingSuite + with sql.SparkSessionBuilderImplementationBindingSuite with RemoteSparkSession { override protected def configure(builder: SparkSessionBuilder): builder.type = { // We need to set this configuration because the port used by the server is random. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala index b116edb7df7ce..41e49e3d47d9e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.concurrent.ForkJoinPool @@ -26,7 +26,7 @@ import scala.util.{Failure, Success} import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkException -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} import org.apache.spark.util.SparkThreadUtils.awaitResult /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionSuite.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionSuite.scala index dec56554d143e..bab6ae39563f6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.concurrent.{Executors, Phaser} @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.SparkException -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/StubbingTestSuite.scala similarity index 91% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/StubbingTestSuite.scala index 5bcb17672d6a9..a0c9a2b992f1d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/StubbingTestSuite.scala @@ -14,10 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import org.apache.spark.sql.connect.client.ToStub -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} class StubbingTestSuite extends ConnectFunSuite with RemoteSparkSession { private def eval[T](f: => T): T = f diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala index c1e44b6fb11b2..b50442de31f04 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.io.File import java.nio.file.{Files, Paths} @@ -25,7 +25,7 @@ import com.google.protobuf.ByteString import org.apache.spark.connect.proto import org.apache.spark.sql.connect.common.ProtoDataTypes -import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UnsupportedFeaturesSuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UnsupportedFeaturesSuite.scala index 42ae6987c9f36..a71d887b2d6c7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UnsupportedFeaturesSuite.scala @@ -14,13 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.StructType /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala index 19275326d6421..a4ef11554d30e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala @@ -14,21 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.lang.{Long => JLong} -import java.util.{Iterator => JIterator} -import java.util.Arrays +import java.util.{Arrays, Iterator => JIterator} import java.util.concurrent.atomic.AtomicLong import scala.jdk.CollectionConverters._ import org.apache.spark.api.java.function._ +import org.apache.spark.sql.{AnalysisException, Encoder, Encoders, Row} import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions.{col, struct, udaf, udf} -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} import org.apache.spark.sql.types.IntegerType /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionSuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionSuite.scala index f7dadea98c281..153026fcdbee3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionSuite.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.reflect.runtime.universe.typeTag import org.apache.spark.SparkException +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.common.UdfPacket +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.functions.{lit, udf} -import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils class UserDefinedFunctionSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index 66a2c943af5f6..fb35812233562 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto.AddArtifactsRequest import org.apache.spark.sql.Artifact import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.util.IvyTestUtils import org.apache.spark.util.MavenUtils.MavenCoordinate diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 7bac10e79d0b4..11572da2c663c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -27,7 +27,7 @@ import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.lib.MiMaLib import org.apache.spark.SparkBuildInfo.spark_version -import org.apache.spark.sql.test.IntegrationTestUtils._ +import org.apache.spark.sql.connect.test.IntegrationTestUtils._ /** * A tool for checking the binary compatibility of the connect client API against the spark SQL diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala index ca23436675f87..2f8332878bbf5 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala @@ -20,7 +20,7 @@ import java.nio.file.Paths import org.apache.commons.io.FileUtils -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.util.SparkFileUtils class ClassFinderSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala index 68d2e86b19d70..b342d5b415692 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.client import java.util.UUID -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite /** * Test suite for [[SparkConnectClient.Builder]] parsing and configuration. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index ac56600392aa3..acee1b2775f17 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -33,9 +33,9 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, Relation, SparkConnectServiceGrpc, SQL} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.SparkSession import org.apache.spark.sql.connect.common.config.ConnectCommon -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { @@ -630,26 +630,29 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer var errorToThrowOnExecute: Option[Throwable] = None - private[sql] def getAndClearLatestInputPlan(): proto.Plan = { + private[sql] def getAndClearLatestInputPlan(): proto.Plan = synchronized { val plan = inputPlan inputPlan = null plan } - private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = { - val requests = inputArtifactRequests.toSeq - inputArtifactRequests.clear() - requests - } + private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = + synchronized { + val requests = inputArtifactRequests.toSeq + inputArtifactRequests.clear() + requests + } override def executePlan( request: ExecutePlanRequest, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - if (errorToThrowOnExecute.isDefined) { - val error = errorToThrowOnExecute.get - errorToThrowOnExecute = None - responseObserver.onError(error) - return + synchronized { + if (errorToThrowOnExecute.isDefined) { + val error = errorToThrowOnExecute.get + errorToThrowOnExecute = None + responseObserver.onError(error) + return + } } // Reply with a dummy response using the same client ID @@ -659,7 +662,9 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer } else { UUID.randomUUID().toString } - inputPlan = request.getPlan + synchronized { + inputPlan = request.getPlan + } val response = ExecutePlanResponse .newBuilder() .setSessionId(requestSessionId) @@ -668,7 +673,7 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver.onNext(response) // Reattachable execute must end with ResultComplete if (request.getRequestOptionsList.asScala.exists { 
option => - option.hasReattachOptions && option.getReattachOptions.getReattachable == true + option.hasReattachOptions && option.getReattachOptions.getReattachable }) { val resultComplete = ExecutePlanResponse .newBuilder() @@ -686,20 +691,22 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { // Reply with a dummy response using the same client ID val requestSessionId = request.getSessionId - request.getAnalyzeCase match { - case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => - inputPlan = request.getSchema.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => - inputPlan = request.getExplain.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => - inputPlan = request.getTreeString.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => - inputPlan = request.getIsLocal.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => - inputPlan = request.getIsStreaming.getPlan - case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => - inputPlan = request.getInputFiles.getPlan - case _ => inputPlan = null + synchronized { + request.getAnalyzeCase match { + case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => + inputPlan = request.getSchema.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => + inputPlan = request.getExplain.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => + inputPlan = request.getTreeString.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => + inputPlan = request.getIsLocal.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => + inputPlan = request.getIsStreaming.getPlan + case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => + inputPlan = request.getInputFiles.getPlan + case _ => inputPlan = null + } } val response = AnalyzePlanResponse .newBuilder() @@ -711,7 +718,8 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) : StreamObserver[AddArtifactsRequest] = new StreamObserver[AddArtifactsRequest] { - override def onNext(v: AddArtifactsRequest): Unit = inputArtifactRequests.append(v) + override def onNext(v: AddArtifactsRequest): Unit = + synchronized(inputArtifactRequests.append(v)) override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable) @@ -728,13 +736,15 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer request.getNamesList().iterator().asScala.foreach { name => val status = proto.ArtifactStatusesResponse.ArtifactStatus.newBuilder() val exists = if (name.startsWith("cache/")) { - inputArtifactRequests.exists { artifactReq => - if (artifactReq.hasBatch) { - val batch = artifactReq.getBatch - batch.getArtifactsList.asScala.exists { singleArtifact => - singleArtifact.getName == name - } - } else false + synchronized { + inputArtifactRequests.exists { artifactReq => + if (artifactReq.hasBatch) { + val batch = artifactReq.getBatch + batch.getArtifactsList.asScala.exists { singleArtifact => + singleArtifact.getName == name + } + } else false + } } } else false builder.putStatuses(name, status.setExists(exists).build()) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 2cdbdc67d2cad..f6662b3351ba7 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType} import org.apache.spark.unsafe.types.VariantVal diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala index 199a1507a3b19..d5f38231d3450 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect.streaming import java.io.{File, FileWriter} import java.nio.file.Paths @@ -29,10 +29,12 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.connect.SparkSession +import org.apache.spark.sql.connect.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession} import org.apache.spark.sql.functions.{col, lit, udf, window} +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryListener, Trigger} import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent} -import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.SparkFileUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/FlatMapGroupsWithStateStreamingSuite.scala similarity index 96% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/FlatMapGroupsWithStateStreamingSuite.scala index 9bd6614028cbf..2496038a77fe4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/FlatMapGroupsWithStateStreamingSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect.streaming import java.sql.Timestamp @@ -23,9 +23,10 @@ import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append -import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.connect.SparkSession +import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} case class ClickEvent(id: String, timestamp: Timestamp) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/StreamingQueryProgressSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/StreamingQueryProgressSuite.scala index aed6c55b3e7fb..0b8447e3bb4c1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/StreamingQueryProgressSuite.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect.streaming import scala.jdk.CollectionConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.connect.test.ConnectFunSuite +import org.apache.spark.sql.streaming.StreamingQueryProgress import org.apache.spark.sql.types.StructType class StreamingQueryProgressSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/ConnectFunSuite.scala similarity index 94% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/ConnectFunSuite.scala index f46b98646ae4f..89f70d6f1214a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/ConnectFunSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.nio.file.Path @@ -22,7 +22,7 @@ import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.connect.proto import org.apache.spark.sql.Column -import org.apache.spark.sql.internal.ColumnNodeToProtoConverter +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter /** * The basic testsuite the client tests should extend from. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala index 3ae9b9fc73b48..79c7a797a97e0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.io.File import java.nio.file.{Files, Paths} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala index b22488858b8f5..5ae23368b9729 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.util.TimeZone @@ -24,8 +24,9 @@ import scala.jdk.CollectionConverters._ import org.scalatest.Assertions import org.apache.spark.{QueryContextType, SparkThrowable} -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide +import org.apache.spark.sql.connect.{DataFrame, Dataset, SparkSession} import org.apache.spark.util.ArrayImplicits._ abstract class QueryTest extends ConnectFunSuite with SQLHelper { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala similarity index 98% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala index 36aaa2cc7fbf6..b6f17627fca85 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.io.{File, IOException, OutputStream} import java.lang.ProcessBuilder.Redirect @@ -23,17 +23,17 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.FiniteDuration +import IntegrationTestUtils._ import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkBuildInfo -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.SparkSession import org.apache.spark.sql.connect.client.RetryPolicy import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.common.config.ConnectCommon -import org.apache.spark.sql.test.IntegrationTestUtils._ import org.apache.spark.util.ArrayImplicits._ /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala similarity index 94% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala index d9828ae92267b..f23221f1a46c0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala @@ -14,15 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.test +package org.apache.spark.sql.connect.test import java.io.File import java.util.UUID import org.scalatest.Assertions.fail -import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession, SQLImplicits} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.connect.{DataFrame, SparkSession, SQLImplicits} import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils} trait SQLHelper { @@ -39,9 +40,7 @@ trait SQLHelper { * because we create the `SparkSession` immediately before the first test is run, but the * implicits import is needed in the constructor. */ - protected object testImplicits extends SQLImplicits { - override protected def session: SparkSession = spark - } + protected object testImplicits extends SQLImplicits(spark) /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index b77bb94aaf46f..ff884310f660c 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index d43b22d9de922..1b52046b14833 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -26,9 +26,11 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{ERROR, FROM_OFFSET, OFFSETS, TIP, TOPIC_PARTITIONS, UNTIL_OFFSET} import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming._ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a73380cab690e..a30759e5d794e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -265,8 +265,9 @@ private[spark] class CoarseGrainedExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit = { - val resources = 
executor.runningTasks.get(taskId).taskDescription.resources - val cpus = executor.runningTasks.get(taskId).taskDescription.cpus + val taskDescription = executor.runningTasks.get(taskId).taskDescription + val resources = taskDescription.resources + val cpus = taskDescription.cpus val msg = StatusUpdate(executorId, taskId, state, data, cpus, resources) if (TaskState.isFinished(state)) { lastTaskFinishTime.set(System.nanoTime()) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index d24466ce711f8..f7820e12d101c 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -63,14 +63,14 @@ derby/10.16.1.1//derby-10.16.1.1.jar derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar derbytools/10.16.1.1//derbytools-10.16.1.1.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar -error_prone_annotations/2.28.0//error_prone_annotations-2.28.0.jar +error_prone_annotations/2.36.0//error_prone_annotations-2.36.0.jar esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar failureaccess/1.0.2//failureaccess-1.0.2.jar flatbuffers-java/24.3.25//flatbuffers-java-24.3.25.jar gcs-connector/hadoop3-2.2.25/shaded/gcs-connector-hadoop3-2.2.25-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.11.0//gson-2.11.0.jar -guava/33.3.1-jre//guava-33.3.1-jre.jar +guava/33.4.0-jre//guava-33.4.0-jre.jar hadoop-aliyun/3.4.1//hadoop-aliyun-3.4.1.jar hadoop-annotations/3.4.1//hadoop-annotations-3.4.1.jar hadoop-aws/3.4.1//hadoop-aws-3.4.1.jar diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 94548afb6c292..a86cd2c994498 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -240,7 +240,7 @@ sealed trait Vector extends Serializable { @Since("2.0.0") object Vectors { - private[ml] val empty: Vector = zeros(0) + private[ml] val empty: DenseVector = new DenseVector(Array.emptyDoubleArray) /** * Creates a dense vector from its values. diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator index 6c5bbd858d9cc..ca728566490de 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator @@ -19,6 +19,7 @@ # So register the supported estimator here if you're trying to add a new one. 
# classification +org.apache.spark.ml.classification.LinearSVC org.apache.spark.ml.classification.LogisticRegression org.apache.spark.ml.classification.DecisionTreeClassifier org.apache.spark.ml.classification.RandomForestClassifier @@ -35,6 +36,7 @@ org.apache.spark.ml.regression.GBTRegressor # clustering org.apache.spark.ml.clustering.KMeans org.apache.spark.ml.clustering.BisectingKMeans +org.apache.spark.ml.clustering.GaussianMixture # recommendation @@ -50,6 +52,12 @@ org.apache.spark.ml.feature.StandardScaler org.apache.spark.ml.feature.MaxAbsScaler org.apache.spark.ml.feature.MinMaxScaler org.apache.spark.ml.feature.RobustScaler +org.apache.spark.ml.feature.ChiSqSelector +org.apache.spark.ml.feature.UnivariateFeatureSelector +org.apache.spark.ml.feature.VarianceThresholdSelector org.apache.spark.ml.feature.StringIndexer org.apache.spark.ml.feature.PCA org.apache.spark.ml.feature.Word2Vec +org.apache.spark.ml.feature.CountVectorizer +org.apache.spark.ml.feature.OneHotEncoder + diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer index 0448117468198..dbedcf3e26e0a 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer @@ -18,10 +18,18 @@ # Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer. # So register the supported transformer here if you're trying to add a new one. ########### Transformers +org.apache.spark.ml.feature.DCT +org.apache.spark.ml.feature.Binarizer +org.apache.spark.ml.feature.Bucketizer org.apache.spark.ml.feature.VectorAssembler +org.apache.spark.ml.feature.Tokenizer +org.apache.spark.ml.feature.RegexTokenizer +org.apache.spark.ml.feature.SQLTransformer +org.apache.spark.ml.feature.StopWordsRemover ########### Model for loading # classification +org.apache.spark.ml.classification.LinearSVCModel org.apache.spark.ml.classification.LogisticRegressionModel org.apache.spark.ml.classification.DecisionTreeClassificationModel org.apache.spark.ml.classification.RandomForestClassificationModel @@ -36,6 +44,7 @@ org.apache.spark.ml.regression.GBTRegressionModel # clustering org.apache.spark.ml.clustering.KMeansModel org.apache.spark.ml.clustering.BisectingKMeansModel +org.apache.spark.ml.clustering.GaussianMixtureModel # recommendation org.apache.spark.ml.recommendation.ALSModel @@ -48,6 +57,12 @@ org.apache.spark.ml.feature.StandardScalerModel org.apache.spark.ml.feature.MaxAbsScalerModel org.apache.spark.ml.feature.MinMaxScalerModel org.apache.spark.ml.feature.RobustScalerModel +org.apache.spark.ml.feature.ChiSqSelectorModel +org.apache.spark.ml.feature.UnivariateFeatureSelectorModel +org.apache.spark.ml.feature.VarianceThresholdSelectorModel org.apache.spark.ml.feature.StringIndexerModel org.apache.spark.ml.feature.PCAModel org.apache.spark.ml.feature.Word2VecModel +org.apache.spark.ml.feature.CountVectorizerModel +org.apache.spark.ml.feature.OneHotEncoderModel + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 161e8f4cbd2c5..6fa7f4d5d493c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -365,6 +365,8 @@ class LinearSVCModel private[classification] ( extends 
ClassificationModel[Vector, LinearSVCModel] with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] { + private[ml] def this() = this(Identifiable.randomUID("linearsvc"), Vectors.empty, 0.0) + @Since("2.2.0") override val numClasses: Int = 2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index d0db5dcba87b5..ad1533cd37a9e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -93,6 +93,9 @@ class GaussianMixtureModel private[ml] ( extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable with HasTrainingSummary[GaussianMixtureSummary] { + private[ml] def this() = this(Identifiable.randomUID("gmm"), + Array.emptyDoubleArray, Array.empty) + @Since("3.0.0") lazy val numFeatures: Int = gaussians.head.mean.size diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index eb2122b09b2fb..e93d96cf9717f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -137,6 +137,9 @@ final class ChiSqSelectorModel private[ml] ( import ChiSqSelectorModel._ + private[ml] def this() = this( + Identifiable.randomUID("chiSqSelector"), Array.emptyIntArray) + override protected def isNumericAttribute = false /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 611b5c710add1..95788be6bd2bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -277,6 +277,8 @@ class CountVectorizerModel( import CountVectorizerModel._ + private[ml] def this() = this(Identifiable.randomUID("cntVecModel"), Array.empty) + @Since("1.5.0") def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 44b8b2047681b..25bcdc9a1c293 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -234,6 +234,8 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ + private[ml] def this() = this(Identifiable.randomUID("oneHotEncoder"), Array.emptyIntArray) + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account.
private def getConfigedCategorySizes: Array[Int] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 7d630233eb0a4..67c8fcf15eec2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -130,7 +130,7 @@ class PCAModel private[ml] ( // For ml connect only @Since("4.0.0") private[ml] def this() = this(Identifiable.randomUID("pca"), - DenseMatrix.zeros(1, 1), Vectors.empty.asInstanceOf[DenseVector]) + DenseMatrix.zeros(1, 1), Vectors.empty) /** @group setParam */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index d0046e3f0c5bc..48783410448bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -282,8 +282,8 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { */ @Since("4.0.0") class TargetEncoderModel private[ml] ( - @Since("4.0.0") override val uid: String, - @Since("4.0.0") val stats: Array[Map[Double, (Double, Double)]]) + @Since("4.0.0") override val uid: String, + @Since("4.0.0") private[ml] val stats: Array[Map[Double, (Double, Double)]]) extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable { /** @group setParam */ @@ -403,13 +403,18 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { private[TargetEncoderModel] class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter { - private case class Data(stats: Array[Map[Double, (Double, Double)]]) + private case class Data(index: Int, categories: Array[Double], + counts: Array[Double], stats: Array[Double]) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) - val data = Data(instance.stats) + val datum = instance.stats.iterator.zipWithIndex.map { case (stat, index) => + val (_categories, _countsAndStats) = stat.toSeq.unzip + val (_counts, _stats) = _countsAndStats.unzip + Data(index, _categories.toArray, _counts.toArray, _stats.toArray) + }.toSeq val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + sparkSession.createDataFrame(datum).write.parquet(dataPath) } } @@ -420,10 +425,18 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { override def load(path: String): TargetEncoderModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("encodings") - .head() - val stats = data.getAs[Array[Map[Double, (Double, Double)]]](0) + + val stats = sparkSession.read.parquet(dataPath) + .select("index", "categories", "counts", "stats") + .collect() + .map { row => + val index = row.getInt(0) + val categories = row.getAs[Seq[Double]](1).toArray + val counts = row.getAs[Seq[Double]](2).toArray + val stats = row.getAs[Seq[Double]](3).toArray + (index, categories.zip(counts.zip(stats)).toMap) + }.sortBy(_._1).map(_._2) + val model = new TargetEncoderModel(metadata.uid, stats) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala index 
ea1a8c6438c8d..d845e2887a647 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala @@ -289,6 +289,9 @@ class UnivariateFeatureSelectorModel private[ml]( extends Model[UnivariateFeatureSelectorModel] with UnivariateFeatureSelectorParams with MLWritable { + private[ml] def this() = this( + Identifiable.randomUID("UnivariateFeatureSelector"), Array.emptyIntArray) + /** @group setParam */ @Since("3.1.1") def setFeaturesCol(value: String): this.type = set(featuresCol, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala index d767e113144c2..23ea1ee3066e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala @@ -126,6 +126,9 @@ class VarianceThresholdSelectorModel private[ml]( extends Model[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams with MLWritable { + private[ml] def this() = this( + Identifiable.randomUID("VarianceThresholdSelector"), Array.emptyIntArray) + if (selectedFeatures.length >= 2) { require(selectedFeatures.sliding(2).forall(l => l(0) < l(1)), "Index should be strictly increasing.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index e67b72e090601..ac8845b09069d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ClassicConversions.ColumnConstructorExt +import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.ExpressionUtils.expression import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala b/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala index b3a3bc1a791ba..2141d4251eded 100644 --- a/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala +++ b/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala @@ -21,10 +21,10 @@ import org.apache.spark.mllib.linalg.{SparseVector => OldSparseVector, Vector => import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, StringLiteral} +import org.apache.spark.sql.classic.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF /** * Register a couple ML vector conversion UDFs in the internal function registry. 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 2b20a282dd14a..ff9651c11250f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkContext import org.apache.spark.ml.feature._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TempDirectory -import org.apache.spark.sql.{SparkSession, SQLImplicits} +import org.apache.spark.sql.classic.{SparkSession, SQLImplicits} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils diff --git a/pom.xml b/pom.xml index 88738bf75204a..c3f959c8903e3 100644 --- a/pom.xml +++ b/pom.xml @@ -195,7 +195,7 @@ 2.12.0 4.1.17 - 33.3.1-jre + 33.4.0-jre 2.11.0 3.1.9 3.0.16 @@ -294,7 +294,7 @@ true - 33.3.1-jre + 33.4.0-jre 1.0.2 1.67.1 1.1.4 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d93f15d3a63a3..618d21149ab45 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -205,10 +205,33 @@ object MimaExcludes { // SPARK-50112: Moving avro files from connector to sql/core ProblemFilters.exclude[Problem]("org.apache.spark.sql.avro.*"), + // SPARK-49700: Unified Scala SQL Interface. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameNaFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameStatFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.KeyValueGroupedDataset"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLImplicits"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Builder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$implicits$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Catalog"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.DataStreamReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.DataStreamWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.DataStreamWriter$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryManager"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQuery"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$implicits$"), + // SPARK-50768: Introduce TaskContext.createResourceUninterruptibly to avoid stream leak by task interruption ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.interruptible"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.pendingInterrupt"), 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.createResourceUninterruptibly"), + ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index abf6b5df469ba..1d25215590af5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1058,7 +1058,7 @@ object KubernetesIntegrationTests { * Overrides to work around sbt's dependency resolution being different from Maven's. */ object DependencyOverrides { - lazy val guavaVersion = sys.props.get("guava.version").getOrElse("33.3.1-jre") + lazy val guavaVersion = sys.props.get("guava.version").getOrElse("33.4.0-jre") lazy val settings = Seq( dependencyOverrides += "com.google.guava" % "guava" % guavaVersion, dependencyOverrides += "jline" % "jline" % "2.14.6", diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 4f19421ca64f1..3dd8123d581c6 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -161,7 +161,10 @@ def killChild(): java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "org.apache.spark.resource.*") # TODO(davies): move into sql - java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.Encoders") + java_import(gateway.jvm, "org.apache.spark.sql.OnSuccessCall") + java_import(gateway.jvm, "org.apache.spark.sql.functions") + java_import(gateway.jvm, "org.apache.spark.sql.classic.*") java_import(gateway.jvm, "org.apache.spark.sql.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fa51d88283403..d8ed51a82abe9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -62,7 +62,6 @@ HasSolver, HasParallelism, ) -from pyspark.ml.util import try_remote_attribute_relation from pyspark.ml.tree import ( _DecisionTreeModel, _DecisionTreeParams, @@ -86,6 +85,7 @@ MLWriter, MLWritable, HasTrainingSummary, + try_remote_attribute_relation, ) from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper from pyspark.ml.common import inherit_doc diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 8a518dac380c6..6cd508a9e950b 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -241,6 +241,7 @@ def gaussians(self) -> List[MultivariateGaussian]: @property @since("2.0.0") + @try_remote_attribute_relation def gaussiansDF(self) -> DataFrame: """ Retrieve Gaussian distributions as a DataFrame. @@ -542,6 +543,7 @@ def probabilityCol(self) -> str: @property @since("2.1.0") + @try_remote_attribute_relation def probability(self) -> DataFrame: """ DataFrame of probabilities of each cluster for each training data point. 
diff --git a/python/pyspark/ml/connect/serialize.py b/python/pyspark/ml/connect/serialize.py index 62b21460feb7c..e93f917e27c96 100644 --- a/python/pyspark/ml/connect/serialize.py +++ b/python/pyspark/ml/connect/serialize.py @@ -18,8 +18,6 @@ import pyspark.sql.connect.proto as pb2 from pyspark.ml.linalg import ( - VectorUDT, - MatrixUDT, DenseVector, SparseVector, DenseMatrix, @@ -49,13 +47,23 @@ def build_float_list(value: List[float]) -> pb2.Expression.Literal: return p +def build_proto_udt(jvm_class: str) -> pb2.DataType: + ret = pb2.DataType() + ret.udt.type = "udt" + ret.udt.jvm_class = jvm_class + return ret + + +proto_vector_udt = build_proto_udt("org.apache.spark.ml.linalg.VectorUDT") +proto_matrix_udt = build_proto_udt("org.apache.spark.ml.linalg.MatrixUDT") + + def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal: - from pyspark.sql.connect.types import pyspark_types_to_proto_types from pyspark.sql.connect.expressions import LiteralExpression if isinstance(value, SparseVector): p = pb2.Expression.Literal() - p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType())) + p.struct.struct_type.CopyFrom(proto_vector_udt) # type = 0 p.struct.elements.append(pb2.Expression.Literal(byte=0)) # size @@ -68,7 +76,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression. elif isinstance(value, DenseVector): p = pb2.Expression.Literal() - p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType())) + p.struct.struct_type.CopyFrom(proto_vector_udt) # type = 1 p.struct.elements.append(pb2.Expression.Literal(byte=1)) # size = null @@ -81,7 +89,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression. elif isinstance(value, SparseMatrix): p = pb2.Expression.Literal() - p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType())) + p.struct.struct_type.CopyFrom(proto_matrix_udt) # type = 0 p.struct.elements.append(pb2.Expression.Literal(byte=0)) # numRows @@ -100,7 +108,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression. 
elif isinstance(value, DenseMatrix): p = pb2.Expression.Literal() - p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType())) + p.struct.struct_type.CopyFrom(proto_matrix_udt) # type = 1 p.struct.elements.append(pb2.Expression.Literal(byte=1)) # numRows @@ -134,14 +142,13 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]: def deserialize_param(literal: pb2.Expression.Literal) -> Any: - from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type from pyspark.sql.connect.expressions import LiteralExpression if literal.HasField("struct"): s = literal.struct - schema = proto_schema_to_pyspark_data_type(s.struct_type) + jvm_class = s.struct_type.udt.jvm_class - if schema == VectorUDT.sqlType(): + if jvm_class == "org.apache.spark.ml.linalg.VectorUDT": assert len(s.elements) == 4 tpe = s.elements[0].byte if tpe == 0: @@ -155,7 +162,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any: else: raise ValueError(f"Unknown Vector type {tpe}") - elif schema == MatrixUDT.sqlType(): + elif jvm_class == "org.apache.spark.ml.linalg.MatrixUDT": assert len(s.elements) == 7 tpe = s.elements[0].byte if tpe == 0: @@ -175,7 +182,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any: else: raise ValueError(f"Unknown Matrix type {tpe}") else: - raise ValueError(f"Unsupported parameter struct {schema}") + raise ValueError(f"Unknown UDT {jvm_class}") else: return LiteralExpression._to_value(literal) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index b2b2d32c31f0c..568583eb08ecb 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -311,6 +311,10 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) + def isLargerBetter(self) -> bool: + """Override this function to make it run on connect""" + return True + @inherit_doc class RegressionEvaluator( @@ -467,6 +471,10 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) + def isLargerBetter(self) -> bool: + """Override this function to make it run on connect""" + return self.getMetricName() in ["r2", "var"] + @inherit_doc class MulticlassClassificationEvaluator( @@ -700,6 +708,15 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) + def isLargerBetter(self) -> bool: + """Override this function to make it run on connect""" + return not self.getMetricName() in [ + "weightedFalsePositiveRate", + "falsePositiveRateByLabel", + "logLoss", + "hammingLoss", + ] + @inherit_doc class MultilabelClassificationEvaluator( @@ -843,6 +860,10 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) + def isLargerBetter(self) -> bool: + """Override this function to make it run on connect""" + return self.getMetricName() != "hammingLoss" + @inherit_doc class ClusteringEvaluator( @@ -1002,6 +1023,10 @@ def setWeightCol(self, value: str) -> "ClusteringEvaluator": """ return self._set(weightCol=value) + def isLargerBetter(self) -> bool: + """Override this function to make it run on connect""" + return True + @inherit_doc class RankingEvaluator( @@ -1138,6 +1163,10 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) + def isLargerBetter(self) -> bool: + """Override this function to make it run on connect""" + return True + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ff8555fadbd12..4c218267749cc 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ 
-4970,7 +4970,8 @@ class StopWordsRemover( Notes ----- - null values from input array are preserved unless adding null to stopWords explicitly. + - null values from input array are preserved unless adding null to stopWords explicitly. + - In Spark Connect Mode, the default value of parameter `locale` is not set. Examples -------- @@ -5069,11 +5070,19 @@ def __init__( self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.StopWordsRemover", self.uid ) - self._setDefault( - stopWords=StopWordsRemover.loadDefaultStopWords("english"), - caseSensitive=False, - locale=self._java_obj.getLocale(), - ) + if isinstance(self._java_obj, str): + # Skip setting the default value of 'locale', which needs to invoke a JVM method. + # So if users don't explicitly set 'locale', then getLocale fails. + self._setDefault( + stopWords=StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive=False, + ) + else: + self._setDefault( + stopWords=StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive=False, + locale=self._java_obj.getLocale(), + ) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -5491,15 +5500,6 @@ def setSmoothing(self, value: float) -> "TargetEncoderModel": """ return self._set(smoothing=value) - @property - @since("4.0.0") - def stats(self) -> List[Dict[float, Tuple[float, float]]]: - """ - Fitted statistics for each feature to being encoded. - The list contains a dictionary for each input column. - """ - return self._call_java("stats") - @inherit_doc class Tokenizer( diff --git a/python/pyspark/ml/tests/connect/test_parity_feature.py b/python/pyspark/ml/tests/connect/test_parity_feature.py index 105ba07df43bf..55d299c063708 100644 --- a/python/pyspark/ml/tests/connect/test_parity_feature.py +++ b/python/pyspark/ml/tests/connect/test_parity_feature.py @@ -22,10 +22,6 @@ class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase): - @unittest.skip("Need to support.") - def test_binarizer(self): - super().test_binarizer() - @unittest.skip("Need to support.") def test_idf(self): super().test_idf() @@ -34,10 +30,6 @@ def test_idf(self): def test_ngram(self): super().test_ngram() - @unittest.skip("Need to support.") - def test_stopwordsremover(self): - super().test_stopwordsremover() - @unittest.skip("Need to support.") def test_count_vectorizer_with_binary(self): super().test_count_vectorizer_with_binary() diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py index 8afa7327fea16..bcf376007198a 100644 --- a/python/pyspark/ml/tests/test_classification.py +++ b/python/pyspark/ml/tests/test_classification.py @@ -24,6 +24,10 @@ from pyspark.ml.linalg import Vectors, Matrices from pyspark.sql import SparkSession, DataFrame from pyspark.ml.classification import ( + LinearSVC, + LinearSVCModel, + LinearSVCSummary, + LinearSVCTrainingSummary, LogisticRegression, LogisticRegressionModel, LogisticRegressionSummary, @@ -299,6 +303,78 @@ def test_logistic_regression(self): except OSError: pass + def test_linear_svc(self): + df = ( + self.spark.createDataFrame( + [ + (1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.0, 3.0, Vectors.dense(2.0, 1.0)), + (0.0, 4.0, Vectors.dense(3.0, 3.0)), + ], + ["label", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("weight") + ) + + svc = LinearSVC(maxIter=1, regParam=1.0) + self.assertEqual(svc.getMaxIter(), 1) + self.assertEqual(svc.getRegParam(), 1.0) + + model = svc.fit(df) + self.assertEqual(model.numClasses, 2) + 
self.assertEqual(model.numFeatures, 2) + self.assertTrue(np.allclose(model.intercept, 0.025877458475338313, atol=1e-4)) + self.assertTrue( + np.allclose(model.coefficients.toArray(), [-0.03622844, 0.01035098], atol=1e-4) + ) + + vec = Vectors.dense(0.0, 5.0) + self.assertEqual(model.predict(vec), 1.0) + self.assertTrue( + np.allclose(model.predictRaw(vec).toArray(), [-0.07763238, 0.07763238], atol=1e-4) + ) + + output = model.transform(df) + expected_cols = [ + "label", + "weight", + "features", + "rawPrediction", + "prediction", + ] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 4) + + # model summary + self.assertTrue(model.hasSummary) + summary = model.summary() + self.assertIsInstance(summary, LinearSVCSummary) + self.assertIsInstance(summary, LinearSVCTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.5) + self.assertEqual(summary.areaUnderROC, 0.75) + self.assertEqual(summary.predictions.columns, expected_cols) + + summary2 = model.evaluate(df) + self.assertIsInstance(summary2, LinearSVCSummary) + self.assertFalse(isinstance(summary2, LinearSVCTrainingSummary)) + self.assertEqual(summary2.labels, [0.0, 1.0]) + self.assertEqual(summary2.accuracy, 0.5) + self.assertEqual(summary2.areaUnderROC, 0.75) + self.assertEqual(summary2.predictions.columns, expected_cols) + + # Model save & load + with tempfile.TemporaryDirectory(prefix="linear_svc") as d: + svc.write().overwrite().save(d) + svc2 = LinearSVC.load(d) + self.assertEqual(str(svc), str(svc2)) + + model.write().overwrite().save(d) + model2 = LinearSVCModel.load(d) + self.assertEqual(str(model), str(model2)) + def test_decision_tree_classifier(self): df = ( self.spark.createDataFrame( diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py index 98b7a4f57c1dc..a6685914eab80 100644 --- a/python/pyspark/ml/tests/test_clustering.py +++ b/python/pyspark/ml/tests/test_clustering.py @@ -29,13 +29,15 @@ BisectingKMeans, BisectingKMeansModel, BisectingKMeansSummary, + GaussianMixture, + GaussianMixtureModel, + GaussianMixtureSummary, ) class ClusteringTestsMixin: - @property - def df(self): - return ( + def test_kmeans(self): + df = ( self.spark.createDataFrame( [ (1, 1.0, Vectors.dense([-0.1, -0.05])), @@ -49,11 +51,9 @@ def df(self): ) .coalesce(1) .sortWithinPartitions("index") + .select("weight", "features") ) - def test_kmeans(self): - df = self.df.select("weight", "features") - km = KMeans( k=2, maxIter=2, @@ -68,11 +68,7 @@ def test_kmeans(self): # self.assertEqual(model.numFeatures, 2) output = model.transform(df) - expected_cols = [ - "weight", - "features", - "prediction", - ] + expected_cols = ["weight", "features", "prediction"] self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 6) @@ -107,7 +103,22 @@ def test_kmeans(self): self.assertEqual(str(model), str(model2)) def test_bisecting_kmeans(self): - df = self.df.select("weight", "features") + df = ( + self.spark.createDataFrame( + [ + (1, 1.0, Vectors.dense([-0.1, -0.05])), + (2, 2.0, Vectors.dense([-0.01, -0.1])), + (3, 3.0, Vectors.dense([0.9, 0.8])), + (4, 1.0, Vectors.dense([0.75, 0.935])), + (5, 1.0, Vectors.dense([-0.83, -0.68])), + (6, 1.0, Vectors.dense([-0.91, -0.76])), + ], + ["index", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("index") + .select("weight", "features") + ) bkm = BisectingKMeans( k=2, @@ -125,11 +136,7 @@ def test_bisecting_kmeans(self): # 
self.assertEqual(model.numFeatures, 2) output = model.transform(df) - expected_cols = [ - "weight", - "features", - "prediction", - ] + expected_cols = ["weight", "features", "prediction"] self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 6) @@ -166,6 +173,94 @@ def test_bisecting_kmeans(self): model2 = BisectingKMeansModel.load(d) self.assertEqual(str(model), str(model2)) + def test_gaussian_mixture(self): + df = ( + self.spark.createDataFrame( + [ + (1, 1.0, Vectors.dense([-0.1, -0.05])), + (2, 2.0, Vectors.dense([-0.01, -0.1])), + (3, 3.0, Vectors.dense([0.9, 0.8])), + (4, 1.0, Vectors.dense([0.75, 0.935])), + (5, 1.0, Vectors.dense([-0.83, -0.68])), + (6, 1.0, Vectors.dense([-0.91, -0.76])), + ], + ["index", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("index") + .select("weight", "features") + ) + + gmm = GaussianMixture( + k=2, + maxIter=2, + weightCol="weight", + seed=1, + ) + self.assertEqual(gmm.getK(), 2) + self.assertEqual(gmm.getMaxIter(), 2) + self.assertEqual(gmm.getWeightCol(), "weight") + self.assertEqual(gmm.getSeed(), 1) + + model = gmm.fit(df) + # TODO: support GMM.numFeatures in Python + # self.assertEqual(model.numFeatures, 2) + self.assertEqual(len(model.weights), 2) + self.assertTrue( + np.allclose(model.weights, [0.541014115744985, 0.4589858842550149], atol=1e-4), + model.weights, + ) + # TODO: support GMM.gaussians on connect + # self.assertEqual(model.gaussians, xxx) + self.assertEqual(model.gaussiansDF.columns, ["mean", "cov"]) + self.assertEqual(model.gaussiansDF.count(), 2) + + vec = Vectors.dense(0.0, 5.0) + pred = model.predict(vec) + self.assertTrue(np.allclose(pred, 0, atol=1e-4), pred) + pred = model.predictProbability(vec) + self.assertTrue(np.allclose(pred.toArray(), [0.5, 0.5], atol=1e-4), pred) + + output = model.transform(df) + expected_cols = ["weight", "features", "probability", "prediction"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 6) + + # Model summary + self.assertTrue(model.hasSummary) + summary = model.summary + self.assertTrue(isinstance(summary, GaussianMixtureSummary)) + self.assertEqual(summary.k, 2) + self.assertEqual(summary.numIter, 2) + self.assertEqual(len(summary.clusterSizes), 2) + self.assertEqual(summary.clusterSizes, [3, 3]) + ll = summary.logLikelihood + self.assertTrue(ll < 0, ll) + self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll) + + self.assertEqual(summary.featuresCol, "features") + self.assertEqual(summary.predictionCol, "prediction") + self.assertEqual(summary.probabilityCol, "probability") + + self.assertEqual(summary.cluster.columns, ["prediction"]) + self.assertEqual(summary.cluster.count(), 6) + + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 6) + + self.assertEqual(summary.probability.columns, ["probability"]) + self.assertEqual(summary.predictions.count(), 6) + + # save & load + with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d: + gmm.write().overwrite().save(d) + gmm2 = GaussianMixture.load(d) + self.assertEqual(str(gmm), str(gmm2)) + + model.write().overwrite().save(d) + model2 = GaussianMixtureModel.load(d) + self.assertEqual(str(model), str(model2)) + class ClusteringTests(ClusteringTestsMixin, unittest.TestCase): def setUp(self) -> None: diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py index 492c5e3967576..3cd84ba528a71 100644 --- 
a/python/pyspark/ml/tests/test_evaluation.py +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -42,6 +42,7 @@ def test_ranking_evaluator(self): # Initialize RankingEvaluator evaluator = RankingEvaluator().setPredictionCol("prediction") + self.assertTrue(evaluator.isLargerBetter()) # Evaluate the dataset using the default metric (mean average precision) mean_average_precision = evaluator.evaluate(dataset) @@ -94,6 +95,25 @@ def test_multilabel_classification_evaluator(self): self.assertEqual(evaluator2.getPredictionCol(), "prediction") self.assertEqual(str(evaluator), str(evaluator2)) + for metric in [ + "subsetAccuracy", + "accuracy", + "precision", + "recall", + "f1Measure", + "precisionByLabel", + "recallByLabel", + "f1MeasureByLabel", + "microPrecision", + "microRecall", + "microF1Measure", + ]: + evaluator.setMetricName(metric) + self.assertTrue(evaluator.isLargerBetter()) + + evaluator.setMetricName("hammingLoss") + self.assertTrue(not evaluator.isLargerBetter()) + def test_multiclass_classification_evaluator(self): dataset = self.spark.createDataFrame( [ @@ -163,6 +183,29 @@ def test_multiclass_classification_evaluator(self): log_loss = evaluator.evaluate(dataset) self.assertTrue(np.allclose(log_loss, 1.0093, atol=1e-4)) + for metric in [ + "f1", + "accuracy", + "weightedPrecision", + "weightedRecall", + "weightedTruePositiveRate", + "weightedFMeasure", + "truePositiveRateByLabel", + "precisionByLabel", + "recallByLabel", + "fMeasureByLabel", + ]: + evaluator.setMetricName(metric) + self.assertTrue(evaluator.isLargerBetter()) + for metric in [ + "weightedFalsePositiveRate", + "falsePositiveRateByLabel", + "logLoss", + "hammingLoss", + ]: + evaluator.setMetricName(metric) + self.assertTrue(not evaluator.isLargerBetter()) + def test_binary_classification_evaluator(self): # Define score and labels data data = map( @@ -180,6 +223,8 @@ def test_binary_classification_evaluator(self): dataset = self.spark.createDataFrame(data, ["raw", "label", "weight"]) evaluator = BinaryClassificationEvaluator().setRawPredictionCol("raw") + self.assertTrue(evaluator.isLargerBetter()) + auc_roc = evaluator.evaluate(dataset) self.assertTrue(np.allclose(auc_roc, 0.7083, atol=1e-4)) @@ -226,6 +271,8 @@ def test_clustering_evaluator(self): dataset = self.spark.createDataFrame(data, ["features", "prediction", "weight"]) evaluator = ClusteringEvaluator().setPredictionCol("prediction") + self.assertTrue(evaluator.isLargerBetter()) + score = evaluator.evaluate(dataset) self.assertTrue(np.allclose(score, 0.9079, atol=1e-4)) @@ -300,6 +347,13 @@ def test_regression_evaluator(self): through_origin = evaluator_with_weights.getThroughOrigin() self.assertEqual(through_origin, False) + for metric in ["mse", "rmse", "mae"]: + evaluator.setMetricName(metric) + self.assertTrue(not evaluator.isLargerBetter()) + for metric in ["r2", "var"]: + evaluator.setMetricName(metric) + self.assertTrue(evaluator.isLargerBetter()) + class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase): def setUp(self) -> None: diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 9766ab1b02438..a3dd889ba1f41 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -23,13 +23,20 @@ import numpy as np from pyspark.ml.feature import ( + DCT, Binarizer, + Bucketizer, CountVectorizer, CountVectorizerModel, + OneHotEncoder, + OneHotEncoderModel, HashingTF, IDF, NGram, RFormula, + Tokenizer, + SQLTransformer, + RegexTokenizer, StandardScaler, 
StandardScalerModel, MaxAbsScaler, @@ -38,10 +45,17 @@ MinMaxScalerModel, RobustScaler, RobustScalerModel, + ChiSqSelector, + ChiSqSelectorModel, + UnivariateFeatureSelector, + UnivariateFeatureSelectorModel, + VarianceThresholdSelector, + VarianceThresholdSelectorModel, StopWordsRemover, StringIndexer, StringIndexerModel, TargetEncoder, + TargetEncoderModel, VectorSizeHint, VectorAssembler, PCA, @@ -56,6 +70,34 @@ class FeatureTestsMixin: + def test_dct(self): + df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) + dct = DCT() + dct.setInverse(False) + dct.setInputCol("vec") + dct.setOutputCol("resultVec") + + self.assertFalse(dct.getInverse()) + self.assertEqual(dct.getInputCol(), "vec") + self.assertEqual(dct.getOutputCol(), "resultVec") + + output = dct.transform(df) + self.assertEqual(output.columns, ["vec", "resultVec"]) + self.assertEqual(output.count(), 1) + self.assertTrue( + np.allclose( + output.head().resultVec.toArray(), + [10.96965511, -0.70710678, -2.04124145], + atol=1e-4, + ) + ) + + # save & load + with tempfile.TemporaryDirectory(prefix="dct") as d: + dct.write().overwrite().save(d) + dct2 = DCT.load(d) + self.assertEqual(str(dct), str(dct2)) + def test_string_indexer(self): df = ( self.spark.createDataFrame( @@ -359,6 +401,102 @@ def test_robust_scaler(self): self.assertEqual(str(model), str(model2)) self.assertEqual(model2.getOutputCol(), "scaled") + def test_chi_sq_selector(self): + df = self.spark.createDataFrame( + [ + (Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0), + ], + ["features", "label"], + ) + + selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") + self.assertEqual(selector.getNumTopFeatures(), 1) + self.assertEqual(selector.getOutputCol(), "selectedFeatures") + + model = selector.fit(df) + self.assertEqual(model.selectedFeatures, [2]) + + output = model.transform(df) + self.assertEqual(output.columns, ["features", "label", "selectedFeatures"]) + self.assertEqual(output.count(), 3) + + # save & load + with tempfile.TemporaryDirectory(prefix="chi_sq_selector") as d: + selector.write().overwrite().save(d) + selector2 = ChiSqSelector.load(d) + self.assertEqual(str(selector), str(selector2)) + + model.write().overwrite().save(d) + model2 = ChiSqSelectorModel.load(d) + self.assertEqual(str(model), str(model2)) + + def test_univariate_selector(self): + df = self.spark.createDataFrame( + [ + (Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0), + ], + ["features", "label"], + ) + + selector = UnivariateFeatureSelector(outputCol="selectedFeatures") + selector.setFeatureType("continuous").setLabelType("categorical").setSelectionThreshold(1) + self.assertEqual(selector.getFeatureType(), "continuous") + self.assertEqual(selector.getLabelType(), "categorical") + self.assertEqual(selector.getOutputCol(), "selectedFeatures") + self.assertEqual(selector.getSelectionThreshold(), 1) + + model = selector.fit(df) + self.assertEqual(model.selectedFeatures, [3]) + + output = model.transform(df) + self.assertEqual(output.columns, ["features", "label", "selectedFeatures"]) + self.assertEqual(output.count(), 3) + + # save & load + with tempfile.TemporaryDirectory(prefix="univariate_selector") as d: + selector.write().overwrite().save(d) + selector2 = UnivariateFeatureSelector.load(d) + self.assertEqual(str(selector), str(selector2)) + + 
model.write().overwrite().save(d) + model2 = UnivariateFeatureSelectorModel.load(d) + self.assertEqual(str(model), str(model2)) + + def test_variance_threshold_selector(self): + df = self.spark.createDataFrame( + [ + (Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0), + ], + ["features", "label"], + ) + + selector = VarianceThresholdSelector(varianceThreshold=2, outputCol="selectedFeatures") + self.assertEqual(selector.getVarianceThreshold(), 2) + self.assertEqual(selector.getOutputCol(), "selectedFeatures") + + model = selector.fit(df) + self.assertEqual(model.selectedFeatures, [2]) + + output = model.transform(df) + self.assertEqual(output.columns, ["features", "label", "selectedFeatures"]) + self.assertEqual(output.count(), 3) + + # save & load + with tempfile.TemporaryDirectory(prefix="variance_threshold_selector") as d: + selector.write().overwrite().save(d) + selector2 = VarianceThresholdSelector.load(d) + self.assertEqual(str(selector), str(selector2)) + + model.write().overwrite().save(d) + model2 = VarianceThresholdSelectorModel.load(d) + self.assertEqual(str(model), str(model2)) + def test_word2vec(self): sent = ("a b " * 100 + "a c " * 10).split(" ") df = self.spark.createDataFrame([(sent,), (sent,)], ["sentence"]).coalesce(1) @@ -401,6 +539,176 @@ def test_word2vec(self): model2 = Word2VecModel.load(d) self.assertEqual(str(model), str(model2)) + def test_count_vectorizer(self): + df = self.spark.createDataFrame( + [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ["label", "raw"], + ) + + cv = CountVectorizer() + cv.setInputCol("raw") + cv.setOutputCol("vectors") + self.assertEqual(cv.getInputCol(), "raw") + self.assertEqual(cv.getOutputCol(), "vectors") + + model = cv.fit(df) + self.assertEqual(sorted(model.vocabulary), ["a", "b", "c"]) + + output = model.transform(df) + self.assertEqual(output.columns, ["label", "raw", "vectors"]) + self.assertEqual(output.count(), 2) + + # save & load + with tempfile.TemporaryDirectory(prefix="count_vectorizer") as d: + cv.write().overwrite().save(d) + cv2 = CountVectorizer.load(d) + self.assertEqual(str(cv), str(cv2)) + + model.write().overwrite().save(d) + model2 = CountVectorizerModel.load(d) + self.assertEqual(str(model), str(model2)) + + def test_one_hot_encoder(self): + df = self.spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) + + encoder = OneHotEncoder() + encoder.setInputCols(["input"]) + encoder.setOutputCols(["output"]) + self.assertEqual(encoder.getInputCols(), ["input"]) + self.assertEqual(encoder.getOutputCols(), ["output"]) + + model = encoder.fit(df) + self.assertEqual(model.categorySizes, [3]) + + output = model.transform(df) + self.assertEqual(output.columns, ["input", "output"]) + self.assertEqual(output.count(), 3) + + # save & load + with tempfile.TemporaryDirectory(prefix="count_vectorizer") as d: + encoder.write().overwrite().save(d) + encoder2 = OneHotEncoder.load(d) + self.assertEqual(str(encoder), str(encoder2)) + + model.write().overwrite().save(d) + model2 = OneHotEncoderModel.load(d) + self.assertEqual(str(model), str(model2)) + + def test_tokenizer(self): + df = self.spark.createDataFrame([("a b c",)], ["text"]) + + tokenizer = Tokenizer(outputCol="words") + tokenizer.setInputCol("text") + self.assertEqual(tokenizer.getInputCol(), "text") + self.assertEqual(tokenizer.getOutputCol(), "words") + + output = tokenizer.transform(df) + self.assertEqual(output.columns, ["text", "words"]) + 
self.assertEqual(output.count(), 1) + self.assertEqual(output.head().words, ["a", "b", "c"]) + + # save & load + with tempfile.TemporaryDirectory(prefix="tokenizer") as d: + tokenizer.write().overwrite().save(d) + tokenizer2 = Tokenizer.load(d) + self.assertEqual(str(tokenizer), str(tokenizer2)) + + def test_regex_tokenizer(self): + df = self.spark.createDataFrame([("A B c",)], ["text"]) + + tokenizer = RegexTokenizer(outputCol="words") + tokenizer.setInputCol("text") + self.assertEqual(tokenizer.getInputCol(), "text") + self.assertEqual(tokenizer.getOutputCol(), "words") + + output = tokenizer.transform(df) + self.assertEqual(output.columns, ["text", "words"]) + self.assertEqual(output.count(), 1) + self.assertEqual(output.head().words, ["a", "b", "c"]) + + # save & load + with tempfile.TemporaryDirectory(prefix="regex_tokenizer") as d: + tokenizer.write().overwrite().save(d) + tokenizer2 = RegexTokenizer.load(d) + self.assertEqual(str(tokenizer), str(tokenizer2)) + + def test_sql_transformer(self): + df = self.spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"]) + + statement = "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__" + sql = SQLTransformer(statement=statement) + self.assertEqual(sql.getStatement(), statement) + + output = sql.transform(df) + self.assertEqual(output.columns, ["id", "v1", "v2", "v3", "v4"]) + self.assertEqual(output.count(), 2) + self.assertEqual( + output.collect(), + [ + Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0), + Row(id=2, v1=2.0, v2=5.0, v3=7.0, v4=10.0), + ], + ) + + # save & load + with tempfile.TemporaryDirectory(prefix="sql_transformer") as d: + sql.write().overwrite().save(d) + sql2 = SQLTransformer.load(d) + self.assertEqual(str(sql), str(sql2)) + + def test_stop_words_remover(self): + df = self.spark.createDataFrame([(["a", "b", "c"],)], ["text"]) + + remover = StopWordsRemover(stopWords=["b"]) + remover.setInputCol("text") + remover.setOutputCol("words") + + self.assertEqual(remover.getStopWords(), ["b"]) + self.assertEqual(remover.getInputCol(), "text") + self.assertEqual(remover.getOutputCol(), "words") + + output = remover.transform(df) + self.assertEqual(output.columns, ["text", "words"]) + self.assertEqual(output.count(), 1) + self.assertEqual(output.head().words, ["a", "c"]) + + # save & load + with tempfile.TemporaryDirectory(prefix="stop_words_remover") as d: + remover.write().overwrite().save(d) + remover2 = StopWordsRemover.load(d) + self.assertEqual(str(remover), str(remover2)) + + def test_stop_words_remover_II(self): + dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEqual(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["panda"]) + self.assertEqual(type(stopWordRemover.getStopWords()), list) + self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str)) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a"]) + # with language selection + stopwords = StopWordsRemover.loadDefaultStopWords("turkish") + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + stopWordRemover.setStopWords(stopwords) + 
self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + def test_binarizer(self): b0 = Binarizer() self.assertListEqual( @@ -427,6 +735,112 @@ def test_binarizer(self): self.assertEqual(b1.getInputCol(), "input") self.assertEqual(b1.getOutputCol(), "output") + df = self.spark.createDataFrame( + [ + (0.1, 0.0), + (0.4, 1.0), + (1.2, 1.3), + (1.5, float("nan")), + (float("nan"), 1.0), + (float("nan"), 0.0), + ], + ["v1", "v2"], + ) + + binarizer = Binarizer(threshold=1.0, inputCol="v1", outputCol="f1") + output = binarizer.transform(df) + self.assertEqual(output.columns, ["v1", "v2", "f1"]) + self.assertEqual(output.count(), 6) + self.assertEqual( + [r.f1 for r in output.select("f1").collect()], + [0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + ) + + binarizer = Binarizer(threshold=1.0, inputCols=["v1", "v2"], outputCols=["f1", "f2"]) + output = binarizer.transform(df) + self.assertEqual(output.columns, ["v1", "v2", "f1", "f2"]) + self.assertEqual(output.count(), 6) + self.assertEqual( + [r.f1 for r in output.select("f1").collect()], + [0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + ) + self.assertEqual( + [r.f2 for r in output.select("f2").collect()], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + ) + + # save & load + with tempfile.TemporaryDirectory(prefix="binarizer") as d: + binarizer.write().overwrite().save(d) + binarizer2 = Binarizer.load(d) + self.assertEqual(str(binarizer), str(binarizer2)) + + def test_bucketizer(self): + df = self.spark.createDataFrame( + [ + (0.1, 0.0), + (0.4, 1.0), + (1.2, 1.3), + (1.5, float("nan")), + (float("nan"), 1.0), + (float("nan"), 0.0), + ], + ["v1", "v2"], + ) + + splits = [-float("inf"), 0.5, 1.4, float("inf")] + bucketizer = Bucketizer() + bucketizer.setSplits(splits) + bucketizer.setHandleInvalid("keep") + bucketizer.setInputCol("v1") + bucketizer.setOutputCol("b1") + + self.assertEqual(bucketizer.getSplits(), splits) + self.assertEqual(bucketizer.getHandleInvalid(), "keep") + self.assertEqual(bucketizer.getInputCol(), "v1") + self.assertEqual(bucketizer.getOutputCol(), "b1") + + output = bucketizer.transform(df) + self.assertEqual(output.columns, ["v1", "v2", "b1"]) + self.assertEqual(output.count(), 6) + self.assertEqual( + [r.b1 for r in output.select("b1").collect()], + [0.0, 0.0, 1.0, 2.0, 3.0, 3.0], + ) + + splitsArray = [ + [-float("inf"), 0.5, 1.4, float("inf")], + [-float("inf"), 0.5, float("inf")], + ] + bucketizer = Bucketizer( + splitsArray=splitsArray, + inputCols=["v1", "v2"], + outputCols=["b1", "b2"], + ) + bucketizer.setHandleInvalid("keep") + self.assertEqual(bucketizer.getSplitsArray(), splitsArray) + self.assertEqual(bucketizer.getHandleInvalid(), "keep") + self.assertEqual(bucketizer.getInputCols(), ["v1", "v2"]) + self.assertEqual(bucketizer.getOutputCols(), ["b1", "b2"]) + + output = bucketizer.transform(df) + self.assertEqual(output.columns, ["v1", "v2", "b1", "b2"]) + self.assertEqual(output.count(), 6) + self.assertEqual( + [r.b1 for r in output.select("b1").collect()], + [0.0, 0.0, 1.0, 2.0, 3.0, 3.0], + ) + self.assertEqual( + [r.b2 for r in output.select("b2").collect()], + [0.0, 1.0, 1.0, 2.0, 1.0, 0.0], + ) + + # save & 
load + with tempfile.TemporaryDirectory(prefix="bucketizer") as d: + bucketizer.write().overwrite().save(d) + bucketizer2 = Bucketizer.load(d) + self.assertEqual(str(bucketizer), str(bucketizer2)) + def test_idf(self): dataset = self.spark.createDataFrame( [(DenseVector([1.0, 2.0]),), (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], @@ -454,37 +868,6 @@ def test_ngram(self): transformedDF = ngram0.transform(dataset) self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) - def test_stopwordsremover(self): - dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) - stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") - # Default - self.assertEqual(stopWordRemover.getInputCol(), "input") - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["panda"]) - self.assertEqual(type(stopWordRemover.getStopWords()), list) - self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str)) - # Custom - stopwords = ["panda"] - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getInputCol(), "input") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a"]) - # with language selection - stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - # with locale - stopwords = ["BELKİ"] - dataset = self.spark.createDataFrame([Row(input=["belki"])]) - stopWordRemover.setStopWords(stopwords).setLocale("tr") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - def test_count_vectorizer_with_binary(self): dataset = self.spark.createDataFrame( [ @@ -731,148 +1114,22 @@ def test_target_encoder_binary(self): targetType="binary", ) model = encoder.fit(df) - te = model.transform(df) - actual = te.drop("label").collect() - expected = [ - Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), - Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3), - Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), - Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), - Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3), - Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), - Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0), - Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0), - Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0), - ] - self.assertEqual(actual, expected) - te = model.setSmoothing(1.0).transform(df) - actual = te.drop("label").collect() - expected = [ - Row( - input1=0, - input2=3, - input3=5.0, - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(1 - 5 / 6) * (4 / 9), - output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - ), - Row( - input1=1, - input2=4, - input3=5.0, - output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), - output3=(3 / 4) * (1 / 3) + (1 - 3 / 
4) * (4 / 9), - ), - Row( - input1=2, - input2=3, - input3=5.0, - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(1 - 5 / 6) * (4 / 9), - output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - ), - Row( - input1=0, - input2=4, - input3=6.0, - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), - output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), - ), - Row( - input1=1, - input2=3, - input3=6.0, - output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(1 - 5 / 6) * (4 / 9), - output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), - ), - Row( - input1=2, - input2=4, - input3=6.0, - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), - output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), - ), - Row( - input1=0, - input2=3, - input3=7.0, - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(1 - 5 / 6) * (4 / 9), - output3=(1 - 1 / 2) * (4 / 9), - ), - Row( - input1=1, - input2=4, - input3=8.0, - output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), - output3=(1 / 2) + (1 - 1 / 2) * (4 / 9), - ), - Row( - input1=2, - input2=3, - input3=9.0, - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), - output2=(1 - 5 / 6) * (4 / 9), - output3=(1 - 1 / 2) * (4 / 9), - ), - ] - self.assertEqual(actual, expected) - - def test_target_encoder_continuous(self): - df = self.spark.createDataFrame( - [ - (0, 3, 5.0, 10.0), - (1, 4, 5.0, 20.0), - (2, 3, 5.0, 30.0), - (0, 4, 6.0, 40.0), - (1, 3, 6.0, 50.0), - (2, 4, 6.0, 60.0), - (0, 3, 7.0, 70.0), - (1, 4, 8.0, 80.0), - (2, 3, 9.0, 90.0), - ], - schema="input1 short, input2 int, input3 double, label double", - ) - encoder = TargetEncoder( - inputCols=["input1", "input2", "input3"], - outputCols=["output", "output2", "output3"], - labelCol="label", - targetType="continuous", + output = model.transform(df) + self.assertEqual( + output.columns, + ["input1", "input2", "input3", "label", "output", "output2", "output3"], ) - model = encoder.fit(df) - te = model.transform(df) - actual = te.drop("label").collect() - expected = [ - Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0), - Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0), - Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0), - Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0), - Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0), - Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0), - Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0), - Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0), - Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0), - ] - self.assertEqual(actual, expected) - te = model.setSmoothing(1.0).transform(df) - actual = te.drop("label").collect() - expected = [ - Row(input1=0, input2=3, input3=5.0, output1=42.5, output2=50.0, output3=27.5), - Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=27.5), - Row(input1=2, input2=3, input3=5.0, output1=57.5, output2=50.0, output3=27.5), - Row(input1=0, input2=4, input3=6.0, output1=42.5, output2=50.0, output3=50.0), - Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0), - Row(input1=2, input2=4, input3=6.0, output1=57.5, output2=50.0, output3=50.0), - Row(input1=0, input2=3, input3=7.0, output1=42.5, 
output2=50.0, output3=60.0), - Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=65.0), - Row(input1=2, input2=3, input3=9.0, output1=57.5, output2=50.0, output3=70.0), - ] - self.assertEqual(actual, expected) + self.assertEqual(output.count(), 9) + + # save & load + with tempfile.TemporaryDirectory(prefix="target_encoder") as d: + encoder.write().overwrite().save(d) + encoder2 = TargetEncoder.load(d) + self.assertEqual(str(encoder), str(encoder2)) + + model.write().overwrite().save(d) + model2 = TargetEncoderModel.load(d) + self.assertEqual(str(model), str(model2)) def test_vector_size_hint(self): df = self.spark.createDataFrame( diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index c93366a31e315..d1ab273ef4f5e 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -29,6 +29,7 @@ from pyspark.sql import functions as F, Column as PySparkColumn from pyspark.sql.types import BooleanType, LongType, DataType +from pyspark.sql.utils import is_remote from pyspark.errors import AnalysisException from pyspark import pandas as ps # noqa: F401 from pyspark.pandas._typing import Label, Name, Scalar @@ -534,11 +535,19 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]: sdf = sdf.limit(sdf.count() + limit) sdf = sdf.drop(NATURAL_ORDER_COLUMN_NAME) except AnalysisException: - raise KeyError( - "[{}] don't exist in columns".format( - [col._jc.toString() for col in data_spark_columns] - ) - ) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + cols_as_str = [ + cast(ConnectColumn, col)._expr.__repr__() for col in data_spark_columns + ] + else: + from pyspark.sql.classic.column import Column as ClassicColumn + + cols_as_str = [ + cast(ClassicColumn, col)._jc.toString() for col in data_spark_columns + ] + raise KeyError("[{}] don't exist in columns".format(cols_as_str)) internal = InternalFrame( spark_frame=sdf, diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index 111bfd4630667..e17afa026c5af 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -957,6 +957,18 @@ def spark_column_equals(left: Column, right: Column) -> bool: ) return repr(left).replace("`", "") == repr(right).replace("`", "") else: + from pyspark.sql.classic.column import Column as ClassicColumn + + if not isinstance(left, ClassicColumn): + raise PySparkTypeError( + errorClass="NOT_COLUMN", + messageParameters={"arg_name": "left", "arg_type": type(left).__name__}, + ) + if not isinstance(right, ClassicColumn): + raise PySparkTypeError( + errorClass="NOT_COLUMN", + messageParameters={"arg_name": "right", "arg_type": type(right).__name__}, + ) return left._jc.equals(right._jc) diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py index fe0e440203c36..161f8ba4bb7ab 100644 --- a/python/pyspark/sql/classic/column.py +++ b/python/pyspark/sql/classic/column.py @@ -177,13 +177,11 @@ def _reverse_op( @with_origin_to_class class Column(ParentColumn): - def __new__( - cls, - jc: "JavaObject", - ) -> "Column": - self = object.__new__(cls) - self.__init__(jc) # type: ignore[misc] - return self + def __new__(cls, *args: Any, **kwargs: Any) -> "Column": + return object.__new__(cls) + + def __getnewargs__(self) -> Tuple[Any, ...]: + return (self._jc,) def __init__(self, jc: "JavaObject") -> None: self._jc = jc diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e5640dd81b1fb..a055e44564952 100644 
--- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -31,7 +31,6 @@ from pyspark.errors import PySparkValueError if TYPE_CHECKING: - from py4j.java_gateway import JavaObject from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral from pyspark.sql.window import WindowSpec @@ -72,16 +71,10 @@ class Column(TableValuedFunctionArgument): # HACK ALERT!! this is to reduce the backward compatibility concern, and returns # Spark Classic Column by default. This is NOT an API, and NOT supposed to # be directly invoked. DO NOT use this constructor. - def __new__( - cls, - jc: "JavaObject", - ) -> "Column": + def __new__(cls, *args: Any, **kwargs: Any) -> "Column": from pyspark.sql.classic.column import Column - return Column.__new__(Column, jc) - - def __init__(self, jc: "JavaObject") -> None: - self._jc = jc + return Column.__new__(Column, *args, **kwargs) # arithmetic operators @dispatch_col_method diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index c5733801814eb..e6d58aefbf2f9 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -27,6 +27,7 @@ Any, Union, Optional, + Tuple, ) from pyspark.sql.column import Column as ParentColumn @@ -109,13 +110,11 @@ def _to_expr(v: Any) -> Expression: @with_origin_to_class(["to_plan"]) class Column(ParentColumn): - def __new__( - cls, - expr: "Expression", - ) -> "Column": - self = object.__new__(cls) - self.__init__(expr) # type: ignore[misc] - return self + def __new__(cls, *args: Any, **kwargs: Any) -> "Column": + return object.__new__(cls) + + def __getnewargs__(self) -> Tuple[Any, ...]: + return (self._expr,) def __init__(self, expr: "Expression") -> None: if not isinstance(expr, Expression): diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 343a68bf010bf..e0108da34f0c2 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -566,8 +566,8 @@ def handle_expired_timers( statefulProcessorApiClient.set_implicit_key(key_obj) for pd in statefulProcessor.handleExpiredTimer( key=key_obj, - timer_values=TimerValues(batch_timestamp, watermark_timestamp), - expired_timer_info=ExpiredTimerInfo(expiry_timestamp), + timerValues=TimerValues(batch_timestamp, watermark_timestamp), + expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp), ): yield pd statefulProcessorApiClient.delete_timer(expiry_timestamp) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index b04bb955488ab..ba2707ccfb892 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -45,33 +45,33 @@ class ValueState: .. versionadded:: 4.0.0 """ - def __init__(self, value_state_client: ValueStateClient, state_name: str) -> None: - self._value_state_client = value_state_client - self._state_name = state_name + def __init__(self, valueStateClient: ValueStateClient, stateName: str) -> None: + self._valueStateClient = valueStateClient + self._stateName = stateName def exists(self) -> bool: """ Whether state exists or not. """ - return self._value_state_client.exists(self._state_name) + return self._valueStateClient.exists(self._stateName) def get(self) -> Optional[Tuple]: """ Get the state value if it exists. Returns None if the state variable does not have a value. 
""" - return self._value_state_client.get(self._state_name) + return self._valueStateClient.get(self._stateName) - def update(self, new_value: Tuple) -> None: + def update(self, newValue: Tuple) -> None: """ Update the value of the state. """ - self._value_state_client.update(self._state_name, new_value) + self._valueStateClient.update(self._stateName, newValue) def clear(self) -> None: """ Remove this state. """ - self._value_state_client.clear(self._state_name) + self._valueStateClient.clear(self._stateName) class TimerValues: @@ -81,23 +81,21 @@ class TimerValues: .. versionadded:: 4.0.0 """ - def __init__( - self, current_processing_time_in_ms: int = -1, current_watermark_in_ms: int = -1 - ) -> None: - self._current_processing_time_in_ms = current_processing_time_in_ms - self._current_watermark_in_ms = current_watermark_in_ms + def __init__(self, currentProcessingTimeInMs: int = -1, currentWatermarkInMs: int = -1) -> None: + self._currentProcessingTimeInMs = currentProcessingTimeInMs + self._currentWatermarkInMs = currentWatermarkInMs - def get_current_processing_time_in_ms(self) -> int: + def getCurrentProcessingTimeInMs(self) -> int: """ Get processing time for current batch, return timestamp in millisecond. """ - return self._current_processing_time_in_ms + return self._currentProcessingTimeInMs - def get_current_watermark_in_ms(self) -> int: + def getCurrentWatermarkInMs(self) -> int: """ Get watermark for current batch, return timestamp in millisecond. """ - return self._current_watermark_in_ms + return self._currentWatermarkInMs class ExpiredTimerInfo: @@ -106,14 +104,14 @@ class ExpiredTimerInfo: .. versionadded:: 4.0.0 """ - def __init__(self, expiry_time_in_ms: int = -1) -> None: - self._expiry_time_in_ms = expiry_time_in_ms + def __init__(self, expiryTimeInMs: int = -1) -> None: + self._expiryTimeInMs = expiryTimeInMs - def get_expiry_time_in_ms(self) -> int: + def getExpiryTimeInMs(self) -> int: """ Get the timestamp for expired timer, return timestamp in millisecond. """ - return self._expiry_time_in_ms + return self._expiryTimeInMs class ListState: @@ -124,45 +122,45 @@ class ListState: .. versionadded:: 4.0.0 """ - def __init__(self, list_state_client: ListStateClient, state_name: str) -> None: - self._list_state_client = list_state_client - self._state_name = state_name + def __init__(self, listStateClient: ListStateClient, stateName: str) -> None: + self._listStateClient = listStateClient + self._stateName = stateName def exists(self) -> bool: """ Whether list state exists or not. """ - return self._list_state_client.exists(self._state_name) + return self._listStateClient.exists(self._stateName) def get(self) -> Iterator[Tuple]: """ Get list state with an iterator. """ - return ListStateIterator(self._list_state_client, self._state_name) + return ListStateIterator(self._listStateClient, self._stateName) - def put(self, new_state: List[Tuple]) -> None: + def put(self, newState: List[Tuple]) -> None: """ Update the values of the list state. """ - self._list_state_client.put(self._state_name, new_state) + self._listStateClient.put(self._stateName, newState) - def append_value(self, new_state: Tuple) -> None: + def appendValue(self, newState: Tuple) -> None: """ Append a new value to the list state. 
""" - self._list_state_client.append_value(self._state_name, new_state) + self._listStateClient.append_value(self._stateName, newState) - def append_list(self, new_state: List[Tuple]) -> None: + def appendList(self, newState: List[Tuple]) -> None: """ Append a list of new values to the list state. """ - self._list_state_client.append_list(self._state_name, new_state) + self._listStateClient.append_list(self._stateName, newState) def clear(self) -> None: """ Remove this state. """ - self._list_state_client.clear(self._state_name) + self._listStateClient.clear(self._stateName) class MapState: @@ -175,65 +173,65 @@ class MapState: def __init__( self, - map_state_client: MapStateClient, - state_name: str, + MapStateClient: MapStateClient, + stateName: str, ) -> None: - self._map_state_client = map_state_client - self._state_name = state_name + self._mapStateClient = MapStateClient + self._stateName = stateName def exists(self) -> bool: """ Whether state exists or not. """ - return self._map_state_client.exists(self._state_name) + return self._mapStateClient.exists(self._stateName) - def get_value(self, key: Tuple) -> Optional[Tuple]: + def getValue(self, key: Tuple) -> Optional[Tuple]: """ Get the state value for given user key if it exists. """ - return self._map_state_client.get_value(self._state_name, key) + return self._mapStateClient.get_value(self._stateName, key) - def contains_key(self, key: Tuple) -> bool: + def containsKey(self, key: Tuple) -> bool: """ Check if the user key is contained in the map. """ - return self._map_state_client.contains_key(self._state_name, key) + return self._mapStateClient.contains_key(self._stateName, key) - def update_value(self, key: Tuple, value: Tuple) -> None: + def updateValue(self, key: Tuple, value: Tuple) -> None: """ Update value for given user key. """ - return self._map_state_client.update_value(self._state_name, key, value) + return self._mapStateClient.update_value(self._stateName, key, value) def iterator(self) -> Iterator[Tuple[Tuple, Tuple]]: """ Get the map associated with grouping key. """ - return MapStateKeyValuePairIterator(self._map_state_client, self._state_name) + return MapStateKeyValuePairIterator(self._mapStateClient, self._stateName) def keys(self) -> Iterator[Tuple]: """ Get the list of keys present in map associated with grouping key. """ - return MapStateIterator(self._map_state_client, self._state_name, True) + return MapStateIterator(self._mapStateClient, self._stateName, True) def values(self) -> Iterator[Tuple]: """ Get the list of values present in map associated with grouping key. """ - return MapStateIterator(self._map_state_client, self._state_name, False) + return MapStateIterator(self._mapStateClient, self._stateName, False) - def remove_key(self, key: Tuple) -> None: + def removeKey(self, key: Tuple) -> None: """ Remove user key from map state. """ - return self._map_state_client.remove_key(self._state_name, key) + return self._mapStateClient.remove_key(self._stateName, key) def clear(self) -> None: """ Remove this state. """ - self._map_state_client.clear(self._state_name) + self._mapStateClient.clear(self._stateName) class StatefulProcessorHandle: @@ -244,11 +242,11 @@ class StatefulProcessorHandle: .. 
versionadded:: 4.0.0 """ - def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: - self.stateful_processor_api_client = stateful_processor_api_client + def __init__(self, statefulProcessorApiClient: StatefulProcessorApiClient) -> None: + self._statefulProcessorApiClient = statefulProcessorApiClient def getValueState( - self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + self, stateName: str, schema: Union[StructType, str], ttlDurationMs: Optional[int] = None ) -> ValueState: """ Function to create new or return existing single value state variable of given type. @@ -257,7 +255,7 @@ def getValueState( Parameters ---------- - state_name : str + stateName : str name of the state variable schema : :class:`pyspark.sql.types.DataType` or str The schema of the state variable. The value can be either a @@ -268,11 +266,11 @@ def getValueState( resets the expiration time to current processing time plus ttlDuration. If ttl is not specified the state will never expire. """ - self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms) - return ValueState(ValueStateClient(self.stateful_processor_api_client, schema), state_name) + self._statefulProcessorApiClient.get_value_state(stateName, schema, ttlDurationMs) + return ValueState(ValueStateClient(self._statefulProcessorApiClient, schema), stateName) def getListState( - self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + self, stateName: str, schema: Union[StructType, str], ttlDurationMs: Optional[int] = None ) -> ListState: """ Function to create new or return existing single value state variable of given type. @@ -281,7 +279,7 @@ def getListState( Parameters ---------- - state_name : str + stateName : str name of the state variable schema : :class:`pyspark.sql.types.DataType` or str The schema of the state variable. The value can be either a @@ -292,15 +290,15 @@ def getListState( resets the expiration time to current processing time plus ttlDuration. If ttl is not specified the state will never expire. """ - self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms) - return ListState(ListStateClient(self.stateful_processor_api_client, schema), state_name) + self._statefulProcessorApiClient.get_list_state(stateName, schema, ttlDurationMs) + return ListState(ListStateClient(self._statefulProcessorApiClient, schema), stateName) def getMapState( self, - state_name: str, - user_key_schema: Union[StructType, str], - value_schema: Union[StructType, str], - ttl_duration_ms: Optional[int] = None, + stateName: str, + userKeySchema: Union[StructType, str], + valueSchema: Union[StructType, str], + ttlDurationMs: Optional[int] = None, ) -> MapState: """ Function to create new or return existing single map state variable of given type. @@ -309,51 +307,51 @@ def getMapState( Parameters ---------- - state_name : str + stateName : str name of the state variable - user_key_schema : :class:`pyspark.sql.types.DataType` or str + userKeySchema : :class:`pyspark.sql.types.DataType` or str The schema of the key of map state. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - value_schema : :class:`pyspark.sql.types.DataType` or str + valueSchema : :class:`pyspark.sql.types.DataType` or str The schema of the value of map state The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. 
- ttl_duration_ms: int + ttlDurationMs: int Time to live duration of the state in milliseconds. State values will not be returned past ttlDuration and will be eventually removed from the state store. Any state update resets the expiration time to current processing time plus ttlDuration. If ttl is not specified the state will never expire. """ - self.stateful_processor_api_client.get_map_state( - state_name, user_key_schema, value_schema, ttl_duration_ms + self._statefulProcessorApiClient.get_map_state( + stateName, userKeySchema, valueSchema, ttlDurationMs ) return MapState( - MapStateClient(self.stateful_processor_api_client, user_key_schema, value_schema), - state_name, + MapStateClient(self._statefulProcessorApiClient, userKeySchema, valueSchema), + stateName, ) - def registerTimer(self, expiry_time_stamp_ms: int) -> None: + def registerTimer(self, expiryTimestampMs: int) -> None: """ Register a timer for a given expiry timestamp in milliseconds for the grouping key. """ - self.stateful_processor_api_client.register_timer(expiry_time_stamp_ms) + self._statefulProcessorApiClient.register_timer(expiryTimestampMs) - def deleteTimer(self, expiry_time_stamp_ms: int) -> None: + def deleteTimer(self, expiryTimestampMs: int) -> None: """ Delete a timer for a given expiry timestamp in milliseconds for the grouping key. """ - self.stateful_processor_api_client.delete_timer(expiry_time_stamp_ms) + self._statefulProcessorApiClient.delete_timer(expiryTimestampMs) def listTimers(self) -> Iterator[int]: """ List all timers of their expiry timestamps in milliseconds for the grouping key. """ - return ListTimerIterator(self.stateful_processor_api_client) + return ListTimerIterator(self._statefulProcessorApiClient) - def deleteIfExists(self, state_name: str) -> None: + def deleteIfExists(self, stateName: str) -> None: """ Function to delete and purge state variable if defined previously """ - self.stateful_processor_api_client.delete_if_exists(state_name) + self._statefulProcessorApiClient.delete_if_exists(stateName) class StatefulProcessor(ABC): @@ -383,7 +381,7 @@ def handleInputRows( self, key: Any, rows: Iterator["PandasDataFrameLike"], - timer_values: TimerValues, + timerValues: TimerValues, ) -> Iterator["PandasDataFrameLike"]: """ Function that will allow users to interact with input data rows along with the grouping key. @@ -402,14 +400,14 @@ def handleInputRows( grouping key. rows : iterable of :class:`pandas.DataFrame` iterator of input rows associated with grouping key - timer_values: TimerValues - Timer value for the current batch that process the input rows. - Users can get the processing or event time timestamp from TimerValues. + timerValues: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. """ ... def handleExpiredTimer( - self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo + self, key: Any, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo ) -> Iterator["PandasDataFrameLike"]: """ Optional to implement. Will act return an empty iterator if not defined. @@ -420,11 +418,11 @@ def handleExpiredTimer( ---------- key : Any grouping key. - timer_values: TimerValues - Timer value for the current batch that process the input rows. - Users can get the processing or event time timestamp from TimerValues. - expired_timer_info: ExpiredTimerInfo - Instance of ExpiredTimerInfo that provides access to expired timer. 
+ timerValues: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. + expiredTimerInfo: ExpiredTimerInfo + Instance of ExpiredTimerInfo that provides access to expired timer. """ return iter([]) @@ -437,7 +435,7 @@ def close(self) -> None: ... def handleInitialState( - self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues + self, key: Any, initialState: "PandasDataFrameLike", timerValues: TimerValues ) -> None: """ Optional to implement. Will act as no-op if not defined or no initial state input. @@ -449,8 +447,8 @@ def handleInitialState( grouping key. initialState: :class:`pandas.DataFrame` One dataframe in the initial state associated with the key. - timer_values: TimerValues - Timer value for the current batch that process the input rows. - Users can get the processing or event time timestamp from TimerValues. + timerValues: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. """ pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index fec2e5d0caa2e..d554a0cb37d73 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -330,9 +330,8 @@ def check_results(batch_df, batch_id): SimpleTTLStatefulProcessor(), check_results, False, "processingTime" ) - @unittest.skipIf( - "COVERAGE_PROCESS_START" in os.environ, "Flaky with coverage enabled, skipping for now." - ) + # TODO SPARK-50908 holistic fix for TTL suite + @unittest.skip("test is flaky and it is only a timing issue, skipping until we can resolve") def test_value_state_ttl_expiration(self): def check_results(batch_df, batch_id): if batch_id == 0: @@ -1481,7 +1480,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.value_state = handle.getValueState("value_state", state_schema) self.handle = handle - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: exists = self.value_state.exists() if exists: value_row = self.value_state.get() @@ -1504,7 +1503,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: else: yield pd.DataFrame({"id": key, "value": str(accumulated_value)}) - def handleInitialState(self, key, initialState, timer_values) -> None: + def handleInitialState(self, key, initialState, timerValues) -> None: init_val = initialState.at[0, "initVal"] self.value_state.update((init_val,)) if len(key) == 1: @@ -1515,16 +1514,14 @@ def close(self) -> None: class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: + self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) str_key = f"{str(key[0])}-expired" - yield pd.DataFrame( - {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} - ) + yield pd.DataFrame({"id": (str_key,), "value": str(expiredTimerInfo.getExpiryTimeInMs())}) - def handleInitialState(self, key, initialState, timer_values) -> None: - 
super().handleInitialState(key, initialState, timer_values) - self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) + def handleInitialState(self, key, initialState, timerValues) -> None: + super().handleInitialState(key, initialState, timerValues) + self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() - 1) class StatefulProcessorWithListStateInitialState(SimpleStatefulProcessorWithInitialState): @@ -1533,9 +1530,9 @@ def init(self, handle: StatefulProcessorHandle) -> None: list_ele_schema = StructType([StructField("value", IntegerType(), True)]) self.list_state = handle.getListState("list_state", list_ele_schema) - def handleInitialState(self, key, initialState, timer_values) -> None: + def handleInitialState(self, key, initialState, timerValues) -> None: for val in initialState["initVal"].tolist(): - self.list_state.append_value((val,)) + self.list_state.appendValue((val,)) # A stateful processor that output the max event time it has seen. Register timer for @@ -1546,15 +1543,15 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.max_state = handle.getValueState("max_state", state_schema) - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: self.max_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) str_key = f"{str(key[0])}-expired" yield pd.DataFrame( - {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} + {"id": (str_key,), "timestamp": str(expiredTimerInfo.getExpiryTimeInMs())} ) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: timestamp_list = [] for pdf in rows: # int64 will represent timestamp in nanosecond, restore to second @@ -1567,7 +1564,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: max_event_time = str(max(cur_max, max(timestamp_list))) self.max_state.update((max_event_time,)) - self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) + self.handle.registerTimer(timerValues.getCurrentWatermarkInMs()) yield pd.DataFrame({"id": key, "timestamp": max_event_time}) @@ -1583,7 +1580,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.count_state = handle.getValueState("count_state", state_schema) - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: # reset count state each time the timer is expired timer_list_1 = [e for e in self.handle.listTimers()] timer_list_2 = [] @@ -1597,23 +1594,23 @@ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[ if len(timer_list_1) > 0: assert len(timer_list_1) == 2 self.count_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) yield pd.DataFrame( { "id": key, "countAsString": str("-1"), - "timeValues": str(expired_timer_info.get_expiry_time_in_ms()), + "timeValues": str(expiredTimerInfo.getExpiryTimeInMs()), } ) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> 
Iterator[pd.DataFrame]: if not self.count_state.exists(): count = 0 else: count = int(self.count_state.get()[0]) if key == ("0",): - self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1) + self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 1) rows_count = 0 for pdf in rows: @@ -1623,7 +1620,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = count + rows_count self.count_state.update((str(count),)) - timestamp = str(timer_values.get_current_processing_time_in_ms()) + timestamp = str(timerValues.getCurrentProcessingTimeInMs()) yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp}) @@ -1642,7 +1639,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.temp_state = handle.getValueState("tempState", state_schema) handle.deleteIfExists("tempState") - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"): self.temp_state.exists() new_violations = 0 @@ -1674,7 +1671,7 @@ class StatefulProcessorChainingOps(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: pass - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: timestamp_list = pdf["eventTime"].tolist() yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]}) @@ -1704,7 +1701,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: "ttl-map-state", user_key_schema, state_schema, 10000 ) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 ttl_list_state_count = 0 @@ -1719,7 +1716,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: for s in iter: ttl_list_state_count += s[0] if self.ttl_map_state.exists(): - ttl_map_state_count = self.ttl_map_state.get_value(key)[0] + ttl_map_state_count = self.ttl_map_state.getValue(key)[0] for pdf in rows: pdf_count = pdf.count().get("temperature") count += pdf_count @@ -1732,7 +1729,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: if not (ttl_count == 2 and id == "0"): self.ttl_count_state.update((ttl_count,)) self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)]) - self.ttl_map_state.update_value(key, (ttl_map_state_count,)) + self.ttl_map_state.updateValue(key, (ttl_map_state_count,)) yield pd.DataFrame( { "id": [ @@ -1754,7 +1751,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.num_violations_state = handle.getValueState("numViolations", state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 exists = self.num_violations_state.exists() assert not exists @@ -1778,16 +1775,16 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.list_state1 = handle.getListState("listState1", state_schema) self.list_state2 = handle.getListState("listState2", state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> 
Iterator[pd.DataFrame]: count = 0 for pdf in rows: list_state_rows = [(120,), (20,)] self.list_state1.put(list_state_rows) self.list_state2.put(list_state_rows) - self.list_state1.append_value((111,)) - self.list_state2.append_value((222,)) - self.list_state1.append_list(list_state_rows) - self.list_state2.append_list(list_state_rows) + self.list_state1.appendValue((111,)) + self.list_state2.appendValue((222,)) + self.list_state1.appendList(list_state_rows) + self.list_state2.appendList(list_state_rows) pdf_count = pdf.count() count += pdf_count.get("temperature") iter1 = self.list_state1.get() @@ -1832,7 +1829,7 @@ def init(self, handle: StatefulProcessorHandle): # Test string type schemas self.map_state = handle.getMapState("mapState", "name string", "count int") - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 key1 = ("key1",) key2 = ("key2",) @@ -1842,12 +1839,12 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: value1 = count value2 = count if self.map_state.exists(): - if self.map_state.contains_key(key1): - value1 += self.map_state.get_value(key1)[0] - if self.map_state.contains_key(key2): - value2 += self.map_state.get_value(key2)[0] - self.map_state.update_value(key1, (value1,)) - self.map_state.update_value(key2, (value2,)) + if self.map_state.containsKey(key1): + value1 += self.map_state.getValue(key1)[0] + if self.map_state.containsKey(key2): + value2 += self.map_state.getValue(key2)[0] + self.map_state.updateValue(key1, (value1,)) + self.map_state.updateValue(key2, (value2,)) key_iter = self.map_state.keys() assert next(key_iter)[0] == "key1" assert next(key_iter)[0] == "key2" @@ -1857,8 +1854,8 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: map_iter = self.map_state.iterator() assert next(map_iter)[0] == key1 assert next(map_iter)[1] == (value2,) - self.map_state.remove_key(key1) - assert not self.map_state.contains_key(key1) + self.map_state.removeKey(key1) + assert not self.map_state.containsKey(key1) yield pd.DataFrame({"id": key, "countAsString": str(count)}) def close(self) -> None: @@ -1884,7 +1881,7 @@ class BasicProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -1910,7 +1907,7 @@ class AddFieldsProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -1958,7 +1955,7 @@ class RemoveFieldsProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -1986,7 +1983,7 @@ class ReorderedFieldsProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, 
key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -2035,7 +2032,7 @@ class UpcastProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index 4d3465b320391..8548801266b26 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -25,6 +25,7 @@ import scala.tools.nsc.GenericRunnerSettings import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession.hiveClassesArePresent import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.Utils @@ -108,7 +109,7 @@ object Main extends Logging { val builder = SparkSession.builder().config(conf) if (conf.get(CATALOG_IMPLEMENTATION.key, "hive") == "hive") { - if (SparkSession.hiveClassesArePresent) { + if (hiveClassesArePresent) { // In the case that the property is not set at all, builder's config // does not have this value set to 'hive' yet. The original default // behavior is that when there are hive classes, we use hive catalog. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 327ef3d074207..8969bc8b5e2b9 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -25,7 +25,7 @@ import org.apache.logging.log4j.core.{Logger, LoggerContext} import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION class ReplSuite extends SparkFunSuite { diff --git a/sql/api/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/api/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index d4e1d89491f43..271b38c2ead13 100644 --- a/sql/api/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/api/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -21,7 +21,6 @@ import java.util.Iterator; import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.streaming.GroupState; /** @@ -32,7 +31,6 @@ * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 */ -@Experimental @Evolving public interface FlatMapGroupsWithStateFunction extends Serializable { Iterator call(K key, Iterator values, GroupState state) throws Exception; diff --git a/sql/api/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/api/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index f0abfde843cc5..d468a3303b122 100644 --- a/sql/api/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ 
b/sql/api/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -21,7 +21,6 @@ import java.util.Iterator; import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.streaming.GroupState; /** @@ -31,7 +30,6 @@ * MapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 */ -@Experimental @Evolving public interface MapGroupsWithStateFunction extends Serializable { R call(K key, Iterator values, GroupState state) throws Exception; diff --git a/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index 91a1231ec0303..e37a5dff01fd6 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -25,7 +25,7 @@ import org.apache.spark.sql.internal.TypedSumLong; /** - * Type-safe functions available for {@link org.apache.spark.sql.api.Dataset} operations in Java. + * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. * * Scala users should use {@link org.apache.spark.sql.expressions.scalalang.typed}. * diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index ee51ddb0e1ef5..737602b90a649 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.streaming; import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.catalyst.plans.logical.*; /** @@ -29,7 +28,6 @@ * * @since 2.2.0 */ -@Experimental @Evolving public class GroupStateTimeout { // NOTE: if you're adding new type of timeout, you should also fix the places below: diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeMode.java index a45a31bd1a05c..128519f6de883 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeMode.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.streaming; import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.catalyst.plans.logical.EventTime$; import org.apache.spark.sql.catalyst.plans.logical.NoTime$; import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime$; @@ -27,7 +26,6 @@ * Represents the time modes (used for specifying timers and ttl) possible for * the Dataset operations {@code transformWithState}. 
*/ -@Experimental @Evolving public class TimeMode { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala similarity index 85% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ef6cc64c058a4..005e418a3b29e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.util -import _root_.java.util +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.Row import org.apache.spark.util.ArrayImplicits._ /** @@ -37,7 +36,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(): Dataset[Row] = drop("any") + def drop(): DataFrame = drop("any") /** * Returns a new `DataFrame` that drops rows containing null or NaN values. @@ -47,7 +46,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(how: String): Dataset[Row] = drop(toMinNonNulls(how)) + def drop(how: String): DataFrame = drop(toMinNonNulls(how)) /** * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified @@ -55,7 +54,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(cols: Array[String]): Dataset[Row] = drop(cols.toImmutableArraySeq) + def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values @@ -63,7 +62,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(cols: Seq[String]): Dataset[Row] = drop(cols.size, cols) + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) /** * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified @@ -74,7 +73,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(how: String, cols: Array[String]): Dataset[Row] = drop(how, cols.toImmutableArraySeq) + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in @@ -85,7 +84,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(how: String, cols: Seq[String]): Dataset[Row] = drop(toMinNonNulls(how), cols) + def drop(how: String, cols: Seq[String]): DataFrame = drop(toMinNonNulls(how), cols) /** * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and @@ -93,7 +92,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(minNonNulls: Int): Dataset[Row] = drop(Option(minNonNulls)) + def drop(minNonNulls: Int): DataFrame = drop(Option(minNonNulls)) /** * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and @@ -101,7 +100,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Array[String]): Dataset[Row] = + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toImmutableArraySeq) /** @@ -110,7 +109,7 @@ abstract class DataFrameNaFunctions { * * 
@since 1.3.1 */ - def drop(minNonNulls: Int, cols: Seq[String]): Dataset[Row] = drop(Option(minNonNulls), cols) + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = drop(Option(minNonNulls), cols) private def toMinNonNulls(how: String): Option[Int] = { how.toLowerCase(util.Locale.ROOT) match { @@ -120,29 +119,29 @@ abstract class DataFrameNaFunctions { } } - protected def drop(minNonNulls: Option[Int]): Dataset[Row] + protected def drop(minNonNulls: Option[Int]): DataFrame - protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] + protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * * @since 2.2.0 */ - def fill(value: Long): Dataset[Row] + def fill(value: Long): DataFrame /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): Dataset[Row] + def fill(value: Double): DataFrame /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): Dataset[Row] + def fill(value: String): DataFrame /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a @@ -150,7 +149,7 @@ abstract class DataFrameNaFunctions { * * @since 2.2.0 */ - def fill(value: Long, cols: Array[String]): Dataset[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a @@ -158,7 +157,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: Double, cols: Array[String]): Dataset[Row] = + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** @@ -167,7 +166,7 @@ abstract class DataFrameNaFunctions { * * @since 2.2.0 */ - def fill(value: Long, cols: Seq[String]): Dataset[Row] + def fill(value: Long, cols: Seq[String]): DataFrame /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -175,7 +174,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): Dataset[Row] + def fill(value: Double, cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null values in specified string columns. If a @@ -183,7 +182,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: String, cols: Array[String]): Dataset[Row] = + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** @@ -192,14 +191,14 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): Dataset[Row] + def fill(value: String, cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. * * @since 2.3.0 */ - def fill(value: Boolean): Dataset[Row] + def fill(value: Boolean): DataFrame /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean @@ -207,7 +206,7 @@ abstract class DataFrameNaFunctions { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Seq[String]): Dataset[Row] + def fill(value: Boolean, cols: Seq[String]): DataFrame /** * Returns a new `DataFrame` that replaces null values in specified boolean columns. 
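// Illustrative sketch, not part of this patch: the DataFrameNaFunctions overloads above are
// reached through Dataset.na and, with this change, return DataFrame directly. The column
// names and fill values below are invented for the example; assumes a local SparkSession.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq((Some(1), Some("a")), (None, None)).toDF("age", "name")
val dropped  = df.na.drop("any")                                  // drop rows with any null or NaN
val filled   = df.na.fill(Map("age" -> 0, "name" -> "unknown"))   // per-column fill values
val replaced = df.na.replace("name", Map("a" -> "A"))             // value replacement in one column
dropped.show(); filled.show(); replaced.show()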
If a @@ -215,7 +214,7 @@ abstract class DataFrameNaFunctions { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Array[String]): Dataset[Row] = + def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) /** @@ -234,7 +233,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(valueMap: util.Map[String, Any]): Dataset[Row] = fillMap(valueMap.asScala.toSeq) + def fill(valueMap: util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values. @@ -254,9 +253,9 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def fill(valueMap: Map[String, Any]): Dataset[Row] = fillMap(valueMap.toSeq) + def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq) - protected def fillMap(values: Seq[(String, Any)]): Dataset[Row] + protected def fillMap(values: Seq[(String, Any)]): DataFrame /** * Replaces values matching keys in `replacement` map with the corresponding values. @@ -283,7 +282,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](col: String, replacement: util.Map[T, T]): Dataset[Row] = { + def replace[T](col: String, replacement: util.Map[T, T]): DataFrame = { replace[T](col, replacement.asScala.toMap) } @@ -309,7 +308,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](cols: Array[String], replacement: util.Map[T, T]): Dataset[Row] = { + def replace[T](cols: Array[String], replacement: util.Map[T, T]): DataFrame = { replace(cols.toImmutableArraySeq, replacement.asScala.toMap) } @@ -336,7 +335,7 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](col: String, replacement: Map[T, T]): Dataset[Row] + def replace[T](col: String, replacement: Map[T, T]): DataFrame /** * (Scala-specific) Replaces values matching keys in `replacement` map. @@ -358,5 +357,5 @@ abstract class DataFrameNaFunctions { * * @since 1.3.1 */ - def replace[T](cols: Seq[String], replacement: Map[T, T]): Dataset[Row] + def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameReader.scala similarity index 94% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 8c88387714228..95f1eca665784 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -14,16 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.util -import _root_.java.util +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} import org.apache.spark.sql.errors.DataTypeErrors @@ -37,7 +36,6 @@ import org.apache.spark.sql.types.StructType */ @Stable abstract class DataFrameReader { - type DS[U] <: Dataset[U] /** * Specifies the input data source format. 
@@ -152,7 +150,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def load(): Dataset[Row] + def load(): DataFrame /** * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a @@ -160,7 +158,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def load(path: String): Dataset[Row] + def load(path: String): DataFrame /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if @@ -169,7 +167,7 @@ abstract class DataFrameReader { * @since 1.6.0 */ @scala.annotation.varargs - def load(paths: String*): Dataset[Row] + def load(paths: String*): DataFrame /** * Construct a `DataFrame` representing the database table accessible via JDBC URL url named @@ -182,7 +180,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def jdbc(url: String, table: String, properties: util.Properties): Dataset[Row] = { + def jdbc(url: String, table: String, properties: util.Properties): DataFrame = { assertNoSpecifiedSchema("jdbc") // properties should override settings in extraOptions. this.extraOptions ++= properties.asScala @@ -226,7 +224,7 @@ abstract class DataFrameReader { lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: util.Properties): Dataset[Row] = { + connectionProperties: util.Properties): DataFrame = { // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. this.extraOptions ++= Map( "partitionColumn" -> columnName, @@ -263,7 +261,7 @@ abstract class DataFrameReader { url: String, table: String, predicates: Array[String], - connectionProperties: util.Properties): Dataset[Row] + connectionProperties: util.Properties): DataFrame /** * Loads a JSON file and returns the results as a `DataFrame`. @@ -272,7 +270,7 @@ abstract class DataFrameReader { * * @since 1.4.0 */ - def json(path: String): Dataset[Row] = { + def json(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 json(Seq(path): _*) } @@ -293,7 +291,7 @@ abstract class DataFrameReader { * @since 2.0.0 */ @scala.annotation.varargs - def json(paths: String*): Dataset[Row] = { + def json(paths: String*): DataFrame = { validateJsonSchema() format("json").load(paths: _*) } @@ -309,7 +307,7 @@ abstract class DataFrameReader { * input Dataset with one JSON object per record * @since 2.2.0 */ - def json(jsonDataset: DS[String]): Dataset[Row] + def json(jsonDataset: Dataset[String]): DataFrame /** * Loads a `JavaRDD[String]` storing JSON objects (JSON Lines @@ -325,7 +323,7 @@ abstract class DataFrameReader { * @since 1.4.0 */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") - def json(jsonRDD: JavaRDD[String]): DS[Row] + def json(jsonRDD: JavaRDD[String]): DataFrame /** * Loads an `RDD[String]` storing JSON objects (JSON Lines text @@ -341,7 +339,7 @@ abstract class DataFrameReader { * @since 1.4.0 */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") - def json(jsonRDD: RDD[String]): DS[Row] + def json(jsonRDD: RDD[String]): DataFrame /** * Loads a CSV file and returns the result as a `DataFrame`. 
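// Illustrative sketch, not part of this patch: typical DataFrameReader calls corresponding to
// the load/json/csv overloads above; each returns DataFrame. The input paths are placeholders;
// assumes a local SparkSession.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val generic = spark.read.format("json").load("/tmp/input.json")   // load(path) via a named format
val people = spark.read
  .option("header", "true")
  .schema(StructType(Seq(StructField("name", StringType))))
  .csv("/tmp/people.csv")                                         // csv(path)
generic.printSchema()
people.show()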
See the documentation on the other @@ -349,7 +347,7 @@ abstract class DataFrameReader { * * @since 2.0.0 */ - def csv(path: String): Dataset[Row] = { + def csv(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 csv(Seq(path): _*) } @@ -375,7 +373,7 @@ abstract class DataFrameReader { * input Dataset with one CSV row per record * @since 2.2.0 */ - def csv(csvDataset: DS[String]): Dataset[Row] + def csv(csvDataset: Dataset[String]): DataFrame /** * Loads CSV files and returns the result as a `DataFrame`. @@ -391,7 +389,7 @@ abstract class DataFrameReader { * @since 2.0.0 */ @scala.annotation.varargs - def csv(paths: String*): Dataset[Row] = format("csv").load(paths: _*) + def csv(paths: String*): DataFrame = format("csv").load(paths: _*) /** * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the other @@ -399,7 +397,7 @@ abstract class DataFrameReader { * * @since 4.0.0 */ - def xml(path: String): Dataset[Row] = { + def xml(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 xml(Seq(path): _*) } @@ -418,7 +416,7 @@ abstract class DataFrameReader { * @since 4.0.0 */ @scala.annotation.varargs - def xml(paths: String*): Dataset[Row] = { + def xml(paths: String*): DataFrame = { validateXmlSchema() format("xml").load(paths: _*) } @@ -433,7 +431,7 @@ abstract class DataFrameReader { * input Dataset with one XML object per record * @since 4.0.0 */ - def xml(xmlDataset: DS[String]): Dataset[Row] + def xml(xmlDataset: Dataset[String]): DataFrame /** * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the @@ -441,7 +439,7 @@ abstract class DataFrameReader { * * @since 2.0.0 */ - def parquet(path: String): Dataset[Row] = { + def parquet(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 parquet(Seq(path): _*) } @@ -456,7 +454,7 @@ abstract class DataFrameReader { * @since 1.4.0 */ @scala.annotation.varargs - def parquet(paths: String*): Dataset[Row] = format("parquet").load(paths: _*) + def parquet(paths: String*): DataFrame = format("parquet").load(paths: _*) /** * Loads an ORC file and returns the result as a `DataFrame`. @@ -465,7 +463,7 @@ abstract class DataFrameReader { * input path * @since 1.5.0 */ - def orc(path: String): Dataset[Row] = { + def orc(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 orc(Seq(path): _*) } @@ -482,7 +480,7 @@ abstract class DataFrameReader { * @since 2.0.0 */ @scala.annotation.varargs - def orc(paths: String*): Dataset[Row] = format("orc").load(paths: _*) + def orc(paths: String*): DataFrame = format("orc").load(paths: _*) /** * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch @@ -497,7 +495,7 @@ abstract class DataFrameReader { * database. Note that, the global temporary view database is also valid here. 
* @since 1.4.0 */ - def table(tableName: String): Dataset[Row] + def table(tableName: String): DataFrame /** * Loads text files and returns a `DataFrame` whose schema starts with a string column named @@ -506,7 +504,7 @@ abstract class DataFrameReader { * * @since 2.0.0 */ - def text(path: String): Dataset[Row] = { + def text(path: String): DataFrame = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 text(Seq(path): _*) } @@ -534,7 +532,7 @@ abstract class DataFrameReader { * @since 1.6.0 */ @scala.annotation.varargs - def text(paths: String*): Dataset[Row] = format("text").load(paths: _*) + def text(paths: String*): DataFrame = format("text").load(paths: _*) /** * Loads text files and returns a [[Dataset]] of String. See the documentation on the other diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala similarity index 97% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index ae7c256b30ace..db74282dc1188 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.{lang => jl, util => ju} -import _root_.java.{lang => jl, util => ju} +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BinaryEncoder import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.functions.{count_min_sketch, lit} @@ -35,7 +34,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} */ @Stable abstract class DataFrameStatFunctions { - protected def df: Dataset[Row] + protected def df: DataFrame /** * Calculates the approximate quantiles of a numerical column of a DataFrame. @@ -202,7 +201,7 @@ abstract class DataFrameStatFunctions { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): Dataset[Row] + def crosstab(col1: String, col2: String): DataFrame /** * Finding frequent items for columns, possibly with false positives. Using the frequent element @@ -246,7 +245,7 @@ abstract class DataFrameStatFunctions { * }}} * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): Dataset[Row] = + def freqItems(cols: Array[String], support: Double): DataFrame = freqItems(cols.toImmutableArraySeq, support) /** @@ -263,7 +262,7 @@ abstract class DataFrameStatFunctions { * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Array[String]): Dataset[Row] = freqItems(cols, 0.01) + def freqItems(cols: Array[String]): DataFrame = freqItems(cols, 0.01) /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the @@ -307,7 +306,7 @@ abstract class DataFrameStatFunctions { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): Dataset[Row] + def freqItems(cols: Seq[String], support: Double): DataFrame /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. 
Using the @@ -324,7 +323,7 @@ abstract class DataFrameStatFunctions { * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Seq[String]): Dataset[Row] = freqItems(cols, 0.01) + def freqItems(cols: Seq[String]): DataFrame = freqItems(cols, 0.01) /** * Returns a stratified sample without replacement based on the fraction given on each stratum. @@ -356,7 +355,7 @@ abstract class DataFrameStatFunctions { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): Dataset[Row] = { + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { sampleBy(Column(col), fractions, seed) } @@ -376,7 +375,7 @@ abstract class DataFrameStatFunctions { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } @@ -413,7 +412,7 @@ abstract class DataFrameStatFunctions { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): Dataset[Row] + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame /** * (Java-specific) Returns a stratified sample without replacement based on the fraction given @@ -432,7 +431,7 @@ abstract class DataFrameStatFunctions { * a new `DataFrame` that represents the stratified sample * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { + def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 565b64aa95472..0bde59155f3f3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.CompilationErrors /** - * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage systems - * (e.g. file systems, key-value stores, etc). Use `Dataset.write` to access this. + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage systems (e.g. + * file systems, key-value stores, etc). Use `Dataset.write` to access this. * * @since 1.4.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 37a29c2e4b66d..66a4b4232a22d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} /** - * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage using the v2 + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 * API. 
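// Illustrative sketch, not part of this patch: a DataFrameWriterV2 is obtained through
// Dataset.writeTo. The catalog/table name below is a placeholder that assumes a configured
// v2 catalog; the sample rows are invented.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val events = Seq((1, "a"), (2, "b")).toDF("id", "value")
events.writeTo("my_catalog.db.events").createOrReplace()          // v2 create-or-replace write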
* * @since 3.0.0 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala similarity index 95% rename from sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala rename to sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala index 20c181e7b9cf6..c49a5d5a50886 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql + +import java.util import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag -import _root_.java.util - import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction} +import org.apache.spark.api.java.function._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} +import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ @@ -123,9 +123,8 @@ import org.apache.spark.util.SparkClassUtils */ @Stable abstract class Dataset[T] extends Serializable { - type DS[U] <: Dataset[U] - def sparkSession: SparkSession + val sparkSession: SparkSession val encoder: Encoder[T] @@ -141,7 +140,7 @@ abstract class Dataset[T] extends Serializable { */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): Dataset[Row] + def toDF(): DataFrame /** * Returns a new Dataset where each record has been mapped on to the specified type. The method @@ -180,7 +179,7 @@ abstract class Dataset[T] extends Serializable { * @group basic * @since 3.4.0 */ - def to(schema: StructType): Dataset[Row] + def to(schema: StructType): DataFrame /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -196,7 +195,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def toDF(colNames: String*): Dataset[Row] + def toDF(colNames: String*): DataFrame /** * Returns the schema of this Dataset. @@ -611,7 +610,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_]): Dataset[Row] + def join(right: Dataset[_]): DataFrame /** * Inner equi-join with another `DataFrame` using the given column. 
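// Illustrative sketch, not part of this patch: the join overloads above now accept Dataset[_]
// instead of the DS[_] type member. Column names and rows are invented; assumes a local
// SparkSession.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val left  = Seq((1, "a"), (2, "b")).toDF("id", "l")
val right = Seq((1, "x"), (3, "y")).toDF("id", "r")
left.join(right, "id").show()                                     // equi-join on a using column
left.join(right, Seq("id"), "left_outer").show()                  // using columns plus a join type
left.join(right, left("id") === right("id")).show()               // explicit join expression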
@@ -637,7 +636,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumn: String): Dataset[Row] = { + def join(right: Dataset[_], usingColumn: String): DataFrame = { join(right, Seq(usingColumn)) } @@ -653,7 +652,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String]): Dataset[Row] = { + def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = { join(right, usingColumns.toImmutableArraySeq) } @@ -681,7 +680,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String]): Dataset[Row] = { + def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = { join(right, usingColumns, "inner") } @@ -711,7 +710,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumn: String, joinType: String): Dataset[Row] = { + def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = { join(right, Seq(usingColumn), joinType) } @@ -732,7 +731,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String], joinType: String): Dataset[Row] = { + def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = { join(right, usingColumns.toImmutableArraySeq, joinType) } @@ -762,7 +761,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String], joinType: String): Dataset[Row] + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame /** * Inner join with another `DataFrame`, using the given join expression. @@ -776,7 +775,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column): Dataset[Row] = + def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner") /** @@ -806,7 +805,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row] + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame /** * Explicit cartesian join with another `DataFrame`. @@ -818,7 +817,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: DS[_]): Dataset[Row] + def crossJoin(right: Dataset[_]): DataFrame /** * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true. 
@@ -842,7 +841,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column, joinType: String): Dataset[(T, U)] + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where @@ -855,7 +854,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column): Dataset[(T, U)] = { + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -869,7 +868,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_]): Dataset[Row] + def lateralJoin(right: Dataset[_]): DataFrame /** * Lateral join with another `DataFrame`. @@ -883,7 +882,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_], joinExprs: Column): Dataset[Row] + def lateralJoin(right: Dataset[_], joinExprs: Column): DataFrame /** * Lateral join with another `DataFrame`. @@ -896,7 +895,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_], joinType: String): Dataset[Row] + def lateralJoin(right: Dataset[_], joinType: String): DataFrame /** * Lateral join with another `DataFrame`. @@ -911,7 +910,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row] + def lateralJoin(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] @@ -1101,7 +1100,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): Dataset[Row] + def select(cols: Column*): DataFrame /** * Selects a set of columns. This is a variant of `select` that can only select existing columns @@ -1117,7 +1116,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): Dataset[Row] = select((col +: cols).map(Column(_)): _*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)): _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions. 
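// Illustrative sketch, not part of this patch: select, selectExpr and untyped agg as exposed by
// the Dataset API above; each returns DataFrame. Column names are invented; assumes a local
// SparkSession.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._
import org.apache.spark.sql.functions.{col, sum}

val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("key", "value")
df.select(col("key"), col("value") + 1).show()
df.selectExpr("key", "value * 2 AS doubled").show()
df.agg(sum("value").as("total")).show()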
@@ -1132,7 +1131,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): Dataset[Row] = select(exprs.map(functions.expr): _*) + def selectExpr(exprs: String*): DataFrame = select(exprs.map(functions.expr): _*) /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expression for @@ -1449,7 +1448,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = { + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { groupBy().agg(aggExpr, aggExprs: _*) } @@ -1464,7 +1463,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: Map[String, String]): Dataset[Row] = groupBy().agg(exprs) + def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) /** * (Java-specific) Aggregates on the entire Dataset without groups. @@ -1477,7 +1476,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: util.Map[String, String]): Dataset[Row] = groupBy().agg(exprs) + def agg(exprs: util.Map[String, String]): DataFrame = groupBy().agg(exprs) /** * Aggregates on the entire Dataset without groups. @@ -1491,7 +1490,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): Dataset[Row] = groupBy().agg(expr, exprs: _*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) /** * (Scala-specific) Reduces the elements of this Dataset using the specified binary function. @@ -1594,7 +1593,7 @@ abstract class Dataset[T] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): Dataset[Row] + valueColumnName: String): DataFrame /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns @@ -1617,10 +1616,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def unpivot( - ids: Array[Column], - variableColumnName: String, - valueColumnName: String): Dataset[Row] + def unpivot(ids: Array[Column], variableColumnName: String, valueColumnName: String): DataFrame /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns @@ -1644,7 +1640,7 @@ abstract class Dataset[T] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): Dataset[Row] = + valueColumnName: String): DataFrame = unpivot(ids, values, variableColumnName, valueColumnName) /** @@ -1666,10 +1662,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def melt( - ids: Array[Column], - variableColumnName: String, - valueColumnName: String): Dataset[Row] = + def melt(ids: Array[Column], variableColumnName: String, valueColumnName: String): DataFrame = unpivot(ids, variableColumnName, valueColumnName) /** @@ -1732,7 +1725,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def transpose(indexColumn: Column): Dataset[Row] + def transpose(indexColumn: Column): DataFrame /** * Transposes a DataFrame, switching rows to columns. 
This function transforms the DataFrame @@ -1751,7 +1744,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def transpose(): Dataset[Row] + def transpose(): DataFrame /** * Return a `Column` object for a SCALAR Subquery containing exactly one row and one column. @@ -1872,7 +1865,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def union(other: DS[T]): Dataset[T] + def union(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is @@ -1886,7 +1879,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def unionAll(other: DS[T]): Dataset[T] = union(other) + def unionAll(other: Dataset[T]): Dataset[T] = union(other) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1917,7 +1910,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.3.0 */ - def unionByName(other: DS[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) + def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1961,7 +1954,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 3.1.0 */ - def unionByName(other: DS[T], allowMissingColumns: Boolean): Dataset[T] + def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is @@ -1973,7 +1966,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 1.6.0 */ - def intersect(other: DS[T]): Dataset[T] + def intersect(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset while @@ -1986,7 +1979,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.4.0 */ - def intersectAll(other: DS[T]): Dataset[T] + def intersectAll(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. This is @@ -1998,7 +1991,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def except(other: DS[T]): Dataset[T] + def except(other: Dataset[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset while @@ -2011,7 +2004,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.4.0 */ - def exceptAll(other: DS[T]): Dataset[T] + def exceptAll(other: Dataset[T]): Dataset[T] /** * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a @@ -2096,7 +2089,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double], seed: Long): Array[_ <: Dataset[T]] + def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] /** * Returns a Java list that contains randomly split Dataset with the provided weights. @@ -2108,7 +2101,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: Dataset[T]] + def randomSplitAsList(weights: Array[Double], seed: Long): util.List[Dataset[T]] /** * Randomly splits this Dataset with the provided weights. 
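// Illustrative sketch, not part of this patch: union and randomSplit; with this change
// randomSplit returns Array[Dataset[T]] rather than Array[_ <: Dataset[T]]. The weights and
// seed are arbitrary example values; assumes a local SparkSession.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()

val ds = spark.range(0, 50).union(spark.range(50, 100))           // Dataset[java.lang.Long]
val splits = ds.randomSplit(Array(0.8, 0.2), seed = 42L)
val (train, test) = (splits(0), splits(1))
println(s"train=${train.count()} test=${test.count()}")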
@@ -2118,7 +2111,7 @@ abstract class Dataset[T] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double]): Array[_ <: Dataset[T]] + def randomSplit(weights: Array[Double]): Array[Dataset[T]] /** * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows @@ -2148,7 +2141,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): Dataset[Row] + def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or @@ -2174,7 +2167,7 @@ abstract class Dataset[T] extends Serializable { */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( - f: A => IterableOnce[B]): Dataset[Row] + f: A => IterableOnce[B]): DataFrame /** * Returns a new Dataset by adding a column or replacing the existing column that has the same @@ -2191,7 +2184,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): Dataset[Row] = withColumns(Seq(colName), Seq(col)) + def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) /** * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns @@ -2203,7 +2196,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: Map[String, Column]): Dataset[Row] = { + def withColumns(colsMap: Map[String, Column]): DataFrame = { val (colNames, newCols) = colsMap.toSeq.unzip withColumns(colNames, newCols) } @@ -2218,14 +2211,41 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: util.Map[String, Column]): Dataset[Row] = withColumns( + def withColumns(colsMap: util.Map[String, Column]): DataFrame = withColumns( colsMap.asScala.toMap) /** * Returns a new Dataset by adding columns or replacing the existing columns that has the same * names. */ - protected def withColumns(colNames: Seq[String], cols: Seq[Column]): Dataset[Row] + private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame + + /** + * Returns a new Dataset by adding columns with metadata. + */ + private[spark] def withColumns( + colNames: Seq[String], + cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = { + require( + colNames.size == cols.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of columns: ${cols.size}") + require( + colNames.size == metadata.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of metadata elements: ${metadata.size}") + val colsWithMetadata = colNames.zip(cols).zip(metadata).map { case ((colName, col), meta) => + col.as(colName, meta) + } + withColumns(colNames, colsWithMetadata) + } + + /** + * Returns a new Dataset by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = + withColumn(colName, col.as(colName, metadata)) /** * Returns a new Dataset with a column renamed. 
This is a no-op if schema doesn't contain @@ -2234,7 +2254,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def withColumnRenamed(existingName: String, newName: String): Dataset[Row] = + def withColumnRenamed(existingName: String, newName: String): DataFrame = withColumnsRenamed(Seq(existingName), Seq(newName)) /** @@ -2249,7 +2269,7 @@ abstract class Dataset[T] extends Serializable { * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): Dataset[Row] = { + def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { val (colNames, newColNames) = colsMap.toSeq.unzip withColumnsRenamed(colNames, newColNames) } @@ -2263,10 +2283,10 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def withColumnsRenamed(colsMap: util.Map[String, String]): Dataset[Row] = + def withColumnsRenamed(colsMap: util.Map[String, String]): DataFrame = withColumnsRenamed(colsMap.asScala.toMap) - protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): Dataset[Row] + protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): DataFrame /** * Returns a new Dataset by updating an existing column with metadata. @@ -2274,7 +2294,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withMetadata(columnName: String, metadata: Metadata): Dataset[Row] + def withMetadata(columnName: String, metadata: Metadata): DataFrame /** * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column @@ -2347,7 +2367,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(colName: String): Dataset[Row] = drop(colName :: Nil: _*) + def drop(colName: String): DataFrame = drop(colName :: Nil: _*) /** * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column @@ -2360,7 +2380,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def drop(colNames: String*): Dataset[Row] + def drop(colNames: String*): DataFrame /** * Returns a new Dataset with column dropped. @@ -2375,7 +2395,7 @@ abstract class Dataset[T] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(col: Column): Dataset[Row] = drop(col, Nil: _*) + def drop(col: Column): DataFrame = drop(col, Nil: _*) /** * Returns a new Dataset with columns dropped. @@ -2387,7 +2407,7 @@ abstract class Dataset[T] extends Serializable { * @since 3.4.0 */ @scala.annotation.varargs - def drop(col: Column, cols: Column*): Dataset[Row] + def drop(col: Column, cols: Column*): DataFrame /** * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias @@ -2567,7 +2587,7 @@ abstract class Dataset[T] extends Serializable { * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): Dataset[Row] + def describe(cols: String*): DataFrame /** * Computes specified statistics for numeric and string columns. Available statistics are:
    @@ -2637,7 +2657,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.3.0 */ @scala.annotation.varargs - def summary(statistics: String*): Dataset[Row] + def summary(statistics: String*): DataFrame /** * Returns the first `n` rows. @@ -3194,7 +3214,7 @@ abstract class Dataset[T] extends Serializable { * @since 3.1.0 */ @DeveloperApi - def sameSemantics(other: DS[T]): Boolean + def sameSemantics(other: Dataset[T]): Boolean /** * Returns a `hashCode` of the logical query plan against this [[Dataset]]. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index dd7e8e81a088c..6a8f2248e669a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql import org.apache.spark.annotation.Stable -import org.apache.spark.sql.api.Dataset /** - * A container for a [[org.apache.spark.sql.api.Dataset]], used for implicit conversions in Scala. + * A container for a [[org.apache.spark.sql.Dataset]], used for implicit conversions in Scala. * * To use this, import implicit conversions in SQL: * {{{ @@ -32,15 +31,15 @@ import org.apache.spark.sql.api.Dataset * @since 1.6.0 */ @Stable -class DatasetHolder[T, DS[U] <: Dataset[U]](ds: DS[T]) { +abstract class DatasetHolder[T] { // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. - def toDS(): DS[T] = ds + def toDS(): Dataset[T] // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DS[Row] = ds.toDF().asInstanceOf[DS[Row]] + def toDF(): DataFrame - def toDF(colNames: String*): DS[Row] = ds.toDF(colNames: _*).asInstanceOf[DS[Row]] + def toDF(colNames: String*): DataFrame } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala similarity index 97% rename from sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala rename to sql/api/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 81f999430a128..4117ab71ac67b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import org.apache.spark.api.java.function.{CoGroupFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, ReduceFunction} -import org.apache.spark.sql.{Column, Encoder, TypedColumn} +import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder import org.apache.spark.sql.functions.{count => cnt, lit} import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} +import org.apache.spark.sql.streaming._ /** * A [[Dataset]] has been logically grouped by a user specified grouping key. 
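// Illustrative sketch, not part of this patch: a KeyValueGroupedDataset is created via
// Dataset.groupByKey; with this change its cogroup/mapGroupsWithState-style parameters are
// typed as KeyValueGroupedDataset[K, _] rather than the KVDS type member. The sample data is
// invented; assumes a local SparkSession.
val spark = org.apache.spark.sql.SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val sales = Seq(("a", 10), ("b", 5), ("a", 7)).toDS()
val grouped = sales.groupByKey(_._1)
val totals = grouped.mapGroups((key, rows) => (key, rows.map(_._2).sum))
totals.show()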
Users should not @@ -31,7 +30,6 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * @since 2.0.0 */ abstract class KeyValueGroupedDataset[K, V] extends Serializable { - type KVDS[KL, VL] <: KeyValueGroupedDataset[KL, VL] /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -301,7 +299,8 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] + initialState: KeyValueGroupedDataset[K, S])( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] /** * (Java-specific) Applies the given function to each group of data, while maintaining a @@ -400,7 +399,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S]): Dataset[U] = { + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { val f = ToScalaUDF(func) mapGroupsWithState[S, U](timeoutConf, initialState)(f)(stateEncoder, outputEncoder) } @@ -463,7 +462,8 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] + initialState: KeyValueGroupedDataset[K, S])( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] /** * (Java-specific) Applies the given function to each group of data, while maintaining a @@ -541,7 +541,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S]): Dataset[U] = { + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(ToScalaUDF(func))( stateEncoder, outputEncoder) @@ -690,7 +690,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KVDS[K, S]): Dataset[U] + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] /** * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state @@ -723,7 +723,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], eventTimeColumnName: String, outputMode: OutputMode, - initialState: KVDS[K, S]): Dataset[U] + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] /** * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API @@ -755,7 +755,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KVDS[K, S], + initialState: KeyValueGroupedDataset[K, S], outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]): Dataset[U] = { transformWithState(statefulProcessor, timeMode, outputMode, initialState)( @@ -796,7 +796,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], outputMode: OutputMode, - initialState: 
KVDS[K, S], + initialState: KeyValueGroupedDataset[K, S], eventTimeColumnName: String, outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]): Dataset[U] = { @@ -956,7 +956,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { * * @since 1.6.0 */ - def cogroup[U, R: Encoder](other: KVDS[K, U])( + def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { cogroupSorted(other)(Nil: _*)(Nil: _*)(f) } @@ -970,7 +970,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { * @since 1.6.0 */ def cogroup[U, R]( - other: KVDS[K, U], + other: KeyValueGroupedDataset[K, U], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)(ToScalaUDF(f))(encoder) @@ -991,7 +991,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { * `org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup` * @since 3.4.0 */ - def cogroupSorted[U, R: Encoder](other: KVDS[K, U])(thisSortExprs: Column*)( + def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)( otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] /** @@ -1010,7 +1010,7 @@ abstract class KeyValueGroupedDataset[K, V] extends Serializable { * @since 3.4.0 */ def cogroupSorted[U, R]( - other: KVDS[K, U], + other: KeyValueGroupedDataset[K, U], thisSortExprs: Array[Column], otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala similarity index 92% rename from sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala rename to sql/api/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 118b8f1ecd488..040fc4f4f260f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql -import scala.jdk.CollectionConverters._ +import java.util -import _root_.java.util +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{functions, Column, Encoder, Row} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -36,12 +35,12 @@ import org.apache.spark.sql.{functions, Column, Encoder, Row} */ @Stable abstract class RelationalGroupedDataset { - protected def df: Dataset[Row] + protected def df: DataFrame /** * Create a aggregation based on the grouping column, the grouping type, and the aggregations. 
*/ - protected def toDF(aggCols: Seq[Column]): Dataset[Row] + protected def toDF(aggCols: Seq[Column]): DataFrame protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] @@ -60,7 +59,7 @@ abstract class RelationalGroupedDataset { private def aggregateNumericColumns( colNames: Seq[String], - function: Column => Column): Dataset[Row] = { + function: Column => Column): DataFrame = { toDF(selectNumericColumns(colNames).map(function)) } @@ -87,7 +86,7 @@ abstract class RelationalGroupedDataset { * * @since 1.3.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = toDF((aggExpr +: aggExprs).map(toAggCol)) /** @@ -105,7 +104,7 @@ abstract class RelationalGroupedDataset { * * @since 1.3.0 */ - def agg(exprs: Map[String, String]): Dataset[Row] = toDF(exprs.map(toAggCol).toSeq) + def agg(exprs: Map[String, String]): DataFrame = toDF(exprs.map(toAggCol).toSeq) /** * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. @@ -120,7 +119,7 @@ abstract class RelationalGroupedDataset { * * @since 1.3.0 */ - def agg(exprs: util.Map[String, String]): Dataset[Row] = { + def agg(exprs: util.Map[String, String]): DataFrame = { agg(exprs.asScala.toMap) } @@ -156,7 +155,7 @@ abstract class RelationalGroupedDataset { * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): Dataset[Row] = toDF(expr +: exprs) + def agg(expr: Column, exprs: Column*): DataFrame = toDF(expr +: exprs) /** * Count the number of rows for each group. The resulting `DataFrame` will also contain the @@ -164,7 +163,7 @@ abstract class RelationalGroupedDataset { * * @since 1.3.0 */ - def count(): Dataset[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil) + def count(): DataFrame = toDF(functions.count(functions.lit(1)).as("count") :: Nil) /** * Compute the average value for each numeric columns for each group. This is an alias for @@ -174,7 +173,7 @@ abstract class RelationalGroupedDataset { * @since 1.3.0 */ @scala.annotation.varargs - def mean(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg) + def mean(colNames: String*): DataFrame = aggregateNumericColumns(colNames, functions.avg) /** * Compute the max value for each numeric columns for each group. The resulting `DataFrame` will @@ -184,7 +183,7 @@ abstract class RelationalGroupedDataset { * @since 1.3.0 */ @scala.annotation.varargs - def max(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.max) + def max(colNames: String*): DataFrame = aggregateNumericColumns(colNames, functions.max) /** * Compute the mean value for each numeric columns for each group. The resulting `DataFrame` @@ -194,7 +193,7 @@ abstract class RelationalGroupedDataset { * @since 1.3.0 */ @scala.annotation.varargs - def avg(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg) + def avg(colNames: String*): DataFrame = aggregateNumericColumns(colNames, functions.avg) /** * Compute the min value for each numeric column for each group. The resulting `DataFrame` will @@ -204,7 +203,7 @@ abstract class RelationalGroupedDataset { * @since 1.3.0 */ @scala.annotation.varargs - def min(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.min) + def min(colNames: String*): DataFrame = aggregateNumericColumns(colNames, functions.min) /** * Compute the sum for each numeric columns for each group. 
The resulting `DataFrame` will also @@ -214,7 +213,7 @@ abstract class RelationalGroupedDataset { * @since 1.3.0 */ @scala.annotation.varargs - def sum(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.sum) + def sum(colNames: String*): DataFrame = aggregateNumericColumns(colNames, functions.sum) /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala similarity index 99% rename from sql/api/src/main/scala/org/apache/spark/sql/api/SQLContext.scala rename to sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala index 50590fffa1521..8da22791f6030 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLContext.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -15,23 +15,22 @@ * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap, Properties} import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import _root_.java.util.{List => JList, Map => JMap, Properties} - import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Encoder, Encoders, ExperimentalMethods, Row} -import org.apache.spark.sql.api.SQLImplicits import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.functions.{array_size, coalesce, col, lit, when} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQueryManager} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager @@ -637,7 +636,7 @@ abstract class SQLContext private[sql] (val sparkSession: SparkSession) /** * Returns a `StreamingQueryManager` that allows managing all the - * [[org.apache.spark.sql.api.StreamingQuery StreamingQueries]] active on `this` context. + * [[org.apache.spark.sql.streaming.StreamingQuery StreamingQueries]] active on `this` context. * * @since 2.0.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala similarity index 92% rename from sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala rename to sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 200e913b5412e..2f68d436acfcd 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -14,29 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.api +package org.apache.spark.sql import scala.collection.Map import scala.language.implicitConversions import scala.reflect.classTag import scala.reflect.runtime.universe.TypeTag -import _root_.java - import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{ColumnName, DatasetHolder, Encoder, Encoders} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, DEFAULT_SCALA_DECIMAL_ENCODER, IterableEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, StringEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ /** * A collection of implicit methods for converting common Scala objects into - * [[org.apache.spark.sql.api.Dataset]]s. + * [[org.apache.spark.sql.Dataset]]s. * * @since 1.6.0 */ abstract class SQLImplicits extends EncoderImplicits with Serializable { - type DS[U] <: Dataset[U] protected def session: SparkSession @@ -55,17 +51,14 @@ abstract class SQLImplicits extends EncoderImplicits with Serializable { * Creates a [[Dataset]] from a local Seq. * @since 1.6.0 */ - implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T, DS] = { - new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]]) - } + implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] /** * Creates a [[Dataset]] from an RDD. * * @since 1.6.0 */ - implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] = - new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]]) + implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] /** * An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]]. @@ -301,7 +294,7 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable { /** * Lower priority implicit methods for converting Scala objects into - * [[org.apache.spark.sql.api.Dataset]]s. Conflicting implicits are placed here to disambiguate + * [[org.apache.spark.sql.Dataset]]s. Conflicting implicits are placed here to disambiguate * resolution. * * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/SparkSession.scala similarity index 79% rename from sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala rename to sql/api/src/main/scala/org/apache/spark/sql/SparkSession.scala index af2144cb9eb41..e94ecb4f26fc5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -14,25 +14,28 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.api +package org.apache.spark.sql +import java.{lang, util} +import java.io.Closeable +import java.net.URI +import java.util.Locale +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.mutable import scala.concurrent.duration.NANOSECONDS import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag - -import _root_.java.io.Closeable -import _root_.java.lang -import _root_.java.net.URI -import _root_.java.util -import _root_.java.util.concurrent.atomic.AtomicReference +import scala.util.Try import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Encoder, ExperimentalMethods, Row, RuntimeConfig, SparkSessionExtensions} +import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.internal.{SessionState, SharedState} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQueryManager} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.SparkClassUtils @@ -205,14 +208,14 @@ abstract class SparkSession extends Serializable with Closeable { * @since 2.0.0 */ @transient - def emptyDataFrame: Dataset[Row] + def emptyDataFrame: DataFrame /** * Creates a `DataFrame` from a local Seq of Product. * * @since 2.0.0 */ - def createDataFrame[A <: Product: TypeTag](data: Seq[A]): Dataset[Row] + def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame /** * :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing @@ -223,7 +226,7 @@ abstract class SparkSession extends Serializable with Closeable { * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rows: util.List[Row], schema: StructType): Dataset[Row] + def createDataFrame(rows: util.List[Row], schema: StructType): DataFrame /** * Applies a schema to a List of Java Beans. @@ -233,7 +236,7 @@ abstract class SparkSession extends Serializable with Closeable { * * @since 1.6.0 */ - def createDataFrame(data: util.List[_], beanClass: Class[_]): Dataset[Row] + def createDataFrame(data: util.List[_], beanClass: Class[_]): DataFrame /** * Creates a `DataFrame` from an RDD of Product (e.g. case classes, tuples). @@ -242,7 +245,7 @@ abstract class SparkSession extends Serializable with Closeable { * this method is not supported in Spark Connect. * @since 2.0.0 */ - def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): Dataset[Row] + def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame /** * :: DeveloperApi :: Creates a `DataFrame` from an `RDD` containing @@ -277,7 +280,7 @@ abstract class SparkSession extends Serializable with Closeable { * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rowRDD: RDD[Row], schema: StructType): Dataset[Row] + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame /** * :: DeveloperApi :: Creates a `DataFrame` from a `JavaRDD` containing @@ -290,7 +293,7 @@ abstract class SparkSession extends Serializable with Closeable { * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): Dataset[Row] + def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame /** * Applies a schema to an RDD of Java Beans. 
@@ -300,7 +303,7 @@ abstract class SparkSession extends Serializable with Closeable { * * @since 2.0.0 */ - def createDataFrame(rdd: RDD[_], beanClass: Class[_]): Dataset[Row] + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame /** * Applies a schema to an RDD of Java Beans. @@ -312,7 +315,7 @@ abstract class SparkSession extends Serializable with Closeable { * this method is not supported in Spark Connect. * @since 2.0.0 */ - def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): Dataset[Row] + def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame /** * Convert a `BaseRelation` created for external data sources into a `DataFrame`. @@ -321,7 +324,7 @@ abstract class SparkSession extends Serializable with Closeable { * this method is not supported in Spark Connect. * @since 2.0.0 */ - def baseRelationToDataFrame(baseRelation: BaseRelation): Dataset[Row] + def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame /* ------------------------------- * | Methods for creating DataSets | @@ -449,7 +452,7 @@ abstract class SparkSession extends Serializable with Closeable { * database. Note that, the global temporary view database is also valid here. * @since 2.0.0 */ - def table(tableName: String): Dataset[Row] + def table(tableName: String): DataFrame /* ----------------- * | Everything else | @@ -470,7 +473,7 @@ abstract class SparkSession extends Serializable with Closeable { * is. * @since 3.5.0 */ - def sql(sqlText: String, args: Array[_]): Dataset[Row] + def sql(sqlText: String, args: Array[_]): DataFrame /** * Executes a SQL query substituting named parameters by the given arguments, returning the @@ -487,7 +490,7 @@ abstract class SparkSession extends Serializable with Closeable { * `array()`, `struct()`, in that case it is taken as is. * @since 3.4.0 */ - def sql(sqlText: String, args: Map[String, Any]): Dataset[Row] + def sql(sqlText: String, args: Map[String, Any]): DataFrame /** * Executes a SQL query substituting named parameters by the given arguments, returning the @@ -504,7 +507,7 @@ abstract class SparkSession extends Serializable with Closeable { * `array()`, `struct()`, in that case it is taken as is. * @since 3.4.0 */ - def sql(sqlText: String, args: util.Map[String, Any]): Dataset[Row] = { + def sql(sqlText: String, args: util.Map[String, Any]): DataFrame = { sql(sqlText, args.asScala.toMap) } @@ -514,7 +517,7 @@ abstract class SparkSession extends Serializable with Closeable { * * @since 2.0.0 */ - def sql(sqlText: String): Dataset[Row] = sql(sqlText, Map.empty[String, Any]) + def sql(sqlText: String): DataFrame = sql(sqlText, Map.empty[String, Any]) /** * Execute an arbitrary string command inside an external execution engine rather than Spark. @@ -537,7 +540,7 @@ abstract class SparkSession extends Serializable with Closeable { * @since 3.0.0 */ @Unstable - def executeCommand(runner: String, command: String, options: Map[String, String]): Dataset[Row] + def executeCommand(runner: String, command: String, options: Map[String, String]): DataFrame /** * Add a single artifact to the current session. @@ -776,40 +779,161 @@ abstract class SparkSession extends Serializable with Closeable { * means the connection to the server is usable. */ private[sql] def isUsable: Boolean + + /** + * Execute a block of code with this session set as the active session, and restore the previous + * session on completion. 
+ */ + @DeveloperApi + def withActive[T](block: => T): T = { + // Use the active session thread local directly to make sure we get the session that is actually + // set and not the default session. This to prevent that we promote the default session to the + // active session once we are done. + val old = SparkSession.getActiveSession.orNull + SparkSession.setActiveSession(this) + try block + finally { + SparkSession.setActiveSession(old) + } + } } object SparkSession extends SparkSessionCompanion { type Session = SparkSession - private[this] val companion: SparkSessionCompanion = { - val cls = SparkClassUtils.classForName("org.apache.spark.sql.SparkSession") + // Implementation specific companions + private lazy val CLASSIC_COMPANION = lookupCompanion( + "org.apache.spark.sql.classic.SparkSession") + private lazy val CONNECT_COMPANION = lookupCompanion( + "org.apache.spark.sql.connect.SparkSession") + private def DEFAULT_COMPANION = + Try(CLASSIC_COMPANION).orElse(Try(CONNECT_COMPANION)).getOrElse { + throw new IllegalStateException( + "Cannot find a SparkSession implementation on the Classpath.") + } + + private[this] def lookupCompanion(name: String): SparkSessionCompanion = { + val cls = SparkClassUtils.classForName(name) val mirror = scala.reflect.runtime.currentMirror val module = mirror.classSymbol(cls).companion.asModule mirror.reflectModule(module).instance.asInstanceOf[SparkSessionCompanion] } /** @inheritdoc */ - override def builder(): SparkSessionBuilder = companion.builder() + override def builder(): Builder = new Builder /** @inheritdoc */ - override def setActiveSession(session: SparkSession): Unit = - companion.setActiveSession(session.asInstanceOf[companion.Session]) + override def setActiveSession(session: SparkSession): Unit = super.setActiveSession(session) /** @inheritdoc */ - override def clearActiveSession(): Unit = companion.clearActiveSession() + override def setDefaultSession(session: SparkSession): Unit = super.setDefaultSession(session) /** @inheritdoc */ - override def setDefaultSession(session: SparkSession): Unit = - companion.setDefaultSession(session.asInstanceOf[companion.Session]) + override def getActiveSession: Option[SparkSession] = super.getActiveSession /** @inheritdoc */ - override def clearDefaultSession(): Unit = companion.clearDefaultSession() + override def getDefaultSession: Option[SparkSession] = super.getDefaultSession - /** @inheritdoc */ - override def getActiveSession: Option[SparkSession] = companion.getActiveSession + override protected def tryCastToImplementation(session: SparkSession): Option[SparkSession] = + Some(session) - /** @inheritdoc */ - override def getDefaultSession: Option[SparkSession] = companion.getDefaultSession + class Builder extends SparkSessionBuilder { + import SparkSessionBuilder._ + private val extensionModifications = mutable.Buffer.empty[SparkSessionExtensions => Unit] + private var sc: Option[SparkContext] = None + private var companion: SparkSessionCompanion = DEFAULT_COMPANION + + /** @inheritdoc */ + override def appName(name: String): this.type = super.appName(name) + + /** @inheritdoc */ + override def master(master: String): this.type = super.master(master) + + /** @inheritdoc */ + override def enableHiveSupport(): this.type = super.enableHiveSupport() + + /** @inheritdoc */ + override def config(key: String, value: String): this.type = super.config(key, value) + + /** @inheritdoc */ + override def config(key: String, value: Long): this.type = super.config(key, value) + + /** @inheritdoc */ + override 
def config(key: String, value: Double): this.type = super.config(key, value) + + /** @inheritdoc */ + override def config(key: String, value: Boolean): this.type = super.config(key, value) + + /** @inheritdoc */ + override def config(map: Map[String, Any]): this.type = super.config(map) + + /** @inheritdoc */ + override def config(map: util.Map[String, Any]): this.type = super.config(map) + + /** @inheritdoc */ + override def config(conf: SparkConf): this.type = super.config(conf) + + /** @inheritdoc */ + override def remote(connectionString: String): this.type = super.remote(connectionString) + + /** @inheritdoc */ + override def withExtensions(f: SparkSessionExtensions => Unit): this.type = synchronized { + extensionModifications += f + this + } + + /** @inheritdoc */ + override private[spark] def sparkContext(sparkContext: SparkContext): this.type = + synchronized { + sc = Option(sparkContext) + this + } + + /** + * Make the builder create a Classic SparkSession. + */ + def classic(): this.type = mode(CLASSIC_COMPANION) + + /** + * Make the builder create a Connect SparkSession. + */ + def connect(): this.type = mode(CONNECT_COMPANION) + + private def mode(companion: SparkSessionCompanion): this.type = synchronized { + this.companion = companion + this + } + + /** @inheritdoc */ + override def getOrCreate(): SparkSession = builder().getOrCreate() + + /** @inheritdoc */ + override def create(): SparkSession = builder().create() + + override protected def handleBuilderConfig(key: String, value: String): Boolean = key match { + case API_MODE_KEY => + companion = value.toLowerCase(Locale.ROOT).trim match { + case API_MODE_CLASSIC => CLASSIC_COMPANION + case API_MODE_CONNECT => CONNECT_COMPANION + case other => + throw new IllegalArgumentException(s"Unknown API mode: $other") + } + true + case _ => + false + } + + /** + * Create an API mode implementation specific builder. + */ + private def builder(): SparkSessionBuilder = synchronized { + val builder = companion.builder() + sc.foreach(builder.sparkContext) + options.foreach(kv => builder.config(kv._1, kv._2)) + extensionModifications.foreach(builder.withExtensions) + builder + } + } } /** @@ -819,6 +943,8 @@ object SparkSession extends SparkSessionCompanion { private[sql] abstract class SparkSessionCompanion { private[sql] type Session <: SparkSession + import SparkSessionCompanion._ + /** * Changes the SparkSession that will be returned in this thread and its children when * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives * * @since 2.0.0 */ - def setActiveSession(session: Session): Unit + def setActiveSession(session: Session): Unit = { + activeThreadSession.set(session) + } /** * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will * * @since 2.0.0 */ - def clearActiveSession(): Unit + def clearActiveSession(): Unit = { + activeThreadSession.remove() + } /** * Sets the default SparkSession that is returned by the builder. * * @since 2.0.0 */ - def setDefaultSession(session: Session): Unit + def setDefaultSession(session: Session): Unit = { + defaultSession.set(session) + } /** * Clears the default SparkSession that is returned by the builder.
* * @since 2.0.0 */ - def clearDefaultSession(): Unit + def clearDefaultSession(): Unit = { + defaultSession.set(null) + } /** * Returns the active SparkSession for the current thread, returned by the builder. @@ -858,7 +992,7 @@ private[sql] abstract class SparkSessionCompanion { * * @since 2.2.0 */ - def getActiveSession: Option[Session] + def getActiveSession: Option[Session] = usableSession(activeThreadSession.get()) /** * Returns the default SparkSession that is returned by the builder. @@ -868,7 +1002,7 @@ private[sql] abstract class SparkSessionCompanion { * * @since 2.2.0 */ - def getDefaultSession: Option[Session] + def getDefaultSession: Option[Session] = usableSession(defaultSession.get()) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -882,65 +1016,19 @@ private[sql] abstract class SparkSessionCompanion { throw SparkException.internalError("No active or default Spark session found"))) } - /** - * Creates a [[SparkSessionBuilder]] for constructing a [[SparkSession]]. - * - * @since 2.0.0 - */ - def builder(): SparkSessionBuilder -} - -/** - * Abstract class for [[SparkSession]] companions. This implements active and default session - * management. - */ -private[sql] abstract class BaseSparkSessionCompanion extends SparkSessionCompanion { - - /** The active SparkSession for the current thread. */ - private val activeThreadSession = new InheritableThreadLocal[Session] - - /** Reference to the root SparkSession. */ - private val defaultSession = new AtomicReference[Session] - - /** @inheritdoc */ - def setActiveSession(session: Session): Unit = { - activeThreadSession.set(session) - } - - /** @inheritdoc */ - def clearActiveSession(): Unit = { - activeThreadSession.remove() - } - - /** @inheritdoc */ - def setDefaultSession(session: Session): Unit = { - defaultSession.set(session) - } - - /** @inheritdoc */ - def clearDefaultSession(): Unit = { - defaultSession.set(null.asInstanceOf[Session]) - } - - /** @inheritdoc */ - def getActiveSession: Option[Session] = usableSession(activeThreadSession.get()) - - /** @inheritdoc */ - def getDefaultSession: Option[Session] = usableSession(defaultSession.get()) - - private def usableSession(session: Session): Option[Session] = { - if ((session ne null) && canUseSession(session)) { - Some(session) + private def usableSession(session: SparkSession): Option[Session] = { + if ((session ne null) && session.isUsable) { + tryCastToImplementation(session) } else { None } } - protected def canUseSession(session: Session): Boolean = session.isUsable + protected def tryCastToImplementation(session: SparkSession): Option[Session] /** * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when - * they are not set yet or they are not usable. + * they are not set yet, or they are not usable. */ protected def setDefaultAndActiveSession(session: Session): Unit = { val currentDefault = defaultSession.getAcquire @@ -960,11 +1048,30 @@ private[sql] abstract class BaseSparkSessionCompanion extends SparkSessionCompan * When the session is closed remove it from active and default. */ private[sql] def onSessionClose(session: Session): Unit = { - defaultSession.compareAndSet(session, null.asInstanceOf[Session]) + defaultSession.compareAndSet(session, null) if (getActiveSession.contains(session)) { clearActiveSession() } } + + /** + * Creates a [[SparkSessionBuilder]] for constructing a [[SparkSession]]. 
+ * + * @since 2.0.0 + */ + def builder(): SparkSessionBuilder +} + +/** + * This object keeps track of the global (default) and the thread-local SparkSession. + */ +object SparkSessionCompanion { + + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] } /** @@ -972,6 +1079,7 @@ private[sql] abstract class BaseSparkSessionCompanion extends SparkSessionCompan */ @Stable abstract class SparkSessionBuilder { + import SparkSessionBuilder._ protected val options = new scala.collection.mutable.HashMap[String, String] /** @@ -980,7 +1088,7 @@ abstract class SparkSessionBuilder { * * @since 2.0.0 */ - def appName(name: String): this.type = config("spark.app.name", name) + def appName(name: String): this.type = config(APP_NAME_KEY, name) /** * Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to run @@ -990,7 +1098,7 @@ abstract class SparkSessionBuilder { * this is only supported in Classic. * @since 2.0.0 */ - def master(master: String): this.type = config("spark.master", master) + def master(master: String): this.type = config(MASTER_KEY, master) /** * Enables Hive support, including connectivity to a persistent Hive metastore, support for Hive @@ -1000,7 +1108,7 @@ abstract class SparkSessionBuilder { * this is only supported in Classic. * @since 2.0.0 */ - def enableHiveSupport(): this.type = config("spark.sql.catalogImplementation", "hive") + def enableHiveSupport(): this.type = config(CATALOG_IMPL_KEY, "hive") /** * Sets the Spark Connect remote URL. @@ -1009,7 +1117,25 @@ abstract class SparkSessionBuilder { * this is only supported in Connect. * @since 3.5.0 */ - def remote(connectionString: String): this.type + def remote(connectionString: String): this.type = config(CONNECT_REMOTE_KEY, connectionString) + + private def putConfig(key: String, value: String): this.type = { + if (!handleBuilderConfig(key, value)) { + options += key -> value + } + this + } + + private def safePutConfig(key: String, value: String): this.type = + synchronized(putConfig(key, value)) + + /** + * Handle a configuration change that is only relevant to the builder. + * + * @return + * `true` when the change if builder only, otherwise it will be added to the configurations. + */ + protected def handleBuilderConfig(key: String, value: String): Boolean /** * Sets a config option. Options set using this method are automatically propagated to both @@ -1019,10 +1145,7 @@ abstract class SparkSessionBuilder { * this is only supported in Connect mode. * @since 2.0.0 */ - def config(key: String, value: String): this.type = synchronized { - options += key -> value - this - } + def config(key: String, value: String): this.type = safePutConfig(key, value) /** * Sets a config option. Options set using this method are automatically propagated to both @@ -1030,10 +1153,7 @@ abstract class SparkSessionBuilder { * * @since 2.0.0 */ - def config(key: String, value: Long): this.type = synchronized { - options += key -> value.toString - this - } + def config(key: String, value: Long): this.type = safePutConfig(key, value.toString) /** * Sets a config option. 
Options set using this method are automatically propagated to both @@ -1041,10 +1161,7 @@ abstract class SparkSessionBuilder { * * @since 2.0.0 */ - def config(key: String, value: Double): this.type = synchronized { - options += key -> value.toString - this - } + def config(key: String, value: Double): this.type = safePutConfig(key, value.toString) /** * Sets a config option. Options set using this method are automatically propagated to both @@ -1052,10 +1169,7 @@ abstract class SparkSessionBuilder { * * @since 2.0.0 */ - def config(key: String, value: Boolean): this.type = synchronized { - options += key -> value.toString - this - } + def config(key: String, value: Boolean): this.type = safePutConfig(key, value.toString) /** * Sets a config option. Options set using this method are automatically propagated to both @@ -1064,11 +1178,7 @@ abstract class SparkSessionBuilder { * @since 3.4.0 */ def config(map: Map[String, Any]): this.type = synchronized { - map.foreach { kv: (String, Any) => - { - options += kv._1 -> kv._2.toString - } - } + map.foreach(kv => putConfig(kv._1, kv._2.toString)) this } @@ -1088,7 +1198,7 @@ abstract class SparkSessionBuilder { * @since 2.0.0 */ def config(conf: SparkConf): this.type = synchronized { - conf.getAll.foreach { case (k, v) => options += k -> v } + conf.getAll.foreach(kv => putConfig(kv._1, kv._2)) this } @@ -1102,6 +1212,14 @@ abstract class SparkSessionBuilder { */ def withExtensions(f: SparkSessionExtensions => Unit): this.type + /** + * Set the [[SparkContext]] to use for the [[SparkSession]]. + * + * @note + * * this method is not supported in Spark Connect. + */ + private[spark] def sparkContext(sparkContext: SparkContext): this.type + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one based on * the options set in this builder. @@ -1129,3 +1247,14 @@ abstract class SparkSessionBuilder { */ def create(): SparkSession } + +object SparkSessionBuilder { + val MASTER_KEY = "spark.master" + val APP_NAME_KEY = "spark.app.name" + val CATALOG_IMPL_KEY = "spark.sql.catalogImplementation" + val CONNECT_REMOTE_KEY = "spark.connect.remote" + // Config key/values used to set the SparkSession API mode. + val API_MODE_KEY: String = "spark.api.mode" + val API_MODE_CLASSIC: String = "classic" + val API_MODE_CONNECT: String = "connect" +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/TableValuedFunction.scala b/sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala similarity index 98% rename from sql/api/src/main/scala/org/apache/spark/sql/api/TableValuedFunction.scala rename to sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala index c03abe0e3d97c..e00b9039e4c64 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/TableValuedFunction.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala @@ -14,11 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.api +package org.apache.spark.sql -import _root_.java.lang - -import org.apache.spark.sql.{Column, Row} +import java.lang abstract class TableValuedFunction { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/UDFRegistration.scala similarity index 97% rename from sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala rename to sql/api/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index a8e8f5c5f8556..802ca6ae39b74 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -14,13 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.util.SparkCharVarcharUtils -import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.internal.ToScalaUDF import org.apache.spark.sql.types.DataType @@ -80,6 +80,26 @@ abstract class UDFRegistration { register(name, udf, "scala_udf", validateParameterCount = false) } + /** + * Registers a user-defined aggregate function (UDAF). + * + * @param name + * the name of the UDAF. + * @param udaf + * the UDAF needs to be registered. + * @return + * the registered UDAF. + * @since 1.5.0 + * @deprecated + * this method and the use of UserDefinedAggregateFunction are deprecated. Aggregator[IN, BUF, + * OUT] should now be registered as a UDF via the functions.udaf(agg) method. + */ + @deprecated( + "Aggregator[IN, BUF, OUT] should now be registered as a UDF" + + " via the functions.udaf(agg) method.", + "3.0.0") + def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction + private def registerScalaUDF( name: String, func: AnyRef, diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala similarity index 95% rename from sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index a0f51d30dc572..3068b81c58c82 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -15,15 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.api +package org.apache.spark.sql.catalog -import scala.jdk.CollectionConverters._ +import java.util -import _root_.java.util +import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalog.{CatalogMetadata, Column, Database, Function, Table} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel @@ -280,7 +279,7 @@ abstract class Catalog { * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String): Dataset[Row] = { + def createExternalTable(tableName: String, path: String): DataFrame = { createTable(tableName, path) } @@ -293,7 +292,7 @@ abstract class Catalog { * identifier is provided, it refers to a table in the current database. * @since 2.2.0 */ - def createTable(tableName: String, path: String): Dataset[Row] + def createTable(tableName: String, path: String): DataFrame /** * Creates a table from the given path based on a data source and returns the corresponding @@ -305,7 +304,7 @@ abstract class Catalog { * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String, source: String): Dataset[Row] = { + def createExternalTable(tableName: String, path: String, source: String): DataFrame = { createTable(tableName, path, source) } @@ -318,7 +317,7 @@ abstract class Catalog { * identifier is provided, it refers to a table in the current database. * @since 2.2.0 */ - def createTable(tableName: String, path: String, source: String): Dataset[Row] + def createTable(tableName: String, path: String, source: String): DataFrame /** * Creates a table from the given path based on a data source and a set of options. Then, @@ -333,7 +332,7 @@ abstract class Catalog { def createExternalTable( tableName: String, source: String, - options: util.Map[String, String]): Dataset[Row] = { + options: util.Map[String, String]): DataFrame = { createTable(tableName, source, options) } @@ -349,7 +348,7 @@ abstract class Catalog { def createTable( tableName: String, source: String, - options: util.Map[String, String]): Dataset[Row] = { + options: util.Map[String, String]): DataFrame = { createTable(tableName, source, options.asScala.toMap) } @@ -366,7 +365,7 @@ abstract class Catalog { def createExternalTable( tableName: String, source: String, - options: Map[String, String]): Dataset[Row] = { + options: Map[String, String]): DataFrame = { createTable(tableName, source, options) } @@ -379,7 +378,7 @@ abstract class Catalog { * identifier is provided, it refers to a table in the current database. * @since 2.2.0 */ - def createTable(tableName: String, source: String, options: Map[String, String]): Dataset[Row] + def createTable(tableName: String, source: String, options: Map[String, String]): DataFrame /** * Create a table from the given path based on a data source, a schema and a set of options. 
@@ -395,7 +394,7 @@ abstract class Catalog { tableName: String, source: String, schema: StructType, - options: util.Map[String, String]): Dataset[Row] = { + options: util.Map[String, String]): DataFrame = { createTable(tableName, source, schema, options) } @@ -412,7 +411,7 @@ abstract class Catalog { tableName: String, source: String, description: String, - options: util.Map[String, String]): Dataset[Row] = { + options: util.Map[String, String]): DataFrame = { createTable( tableName, source = source, @@ -433,7 +432,7 @@ abstract class Catalog { tableName: String, source: String, description: String, - options: Map[String, String]): Dataset[Row] + options: Map[String, String]): DataFrame /** * Create a table based on the dataset in a data source, a schema and a set of options. Then, @@ -448,7 +447,7 @@ abstract class Catalog { tableName: String, source: String, schema: StructType, - options: util.Map[String, String]): Dataset[Row] = { + options: util.Map[String, String]): DataFrame = { createTable(tableName, source, schema, options.asScala.toMap) } @@ -466,7 +465,7 @@ abstract class Catalog { tableName: String, source: String, schema: StructType, - options: Map[String, String]): Dataset[Row] = { + options: Map[String, String]): DataFrame = { createTable(tableName, source, schema, options) } @@ -483,7 +482,7 @@ abstract class Catalog { tableName: String, source: String, schema: StructType, - options: Map[String, String]): Dataset[Row] + options: Map[String, String]): DataFrame /** * Create a table based on the dataset in a data source, a schema and a set of options. Then, @@ -499,7 +498,7 @@ abstract class Catalog { source: String, schema: StructType, description: String, - options: util.Map[String, String]): Dataset[Row] = { + options: util.Map[String, String]): DataFrame = { createTable( tableName, source = source, @@ -522,7 +521,7 @@ abstract class Catalog { source: String, schema: StructType, description: String, - options: Map[String, String]): Dataset[Row] + options: Map[String, String]): DataFrame /** * Drops the local temporary view with the given view name in the catalog. If the view has been diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index 563554d506c4a..9fbfb9e679e58 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -128,7 +128,7 @@ object CurrentOrigin { } private val sparkCodePattern = Pattern.compile("(org\\.apache\\.spark\\.sql\\." + - "(?:api\\.)?" + + "(?:(classic|connect)\\.)?" + "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)" + "(?:|\\..*|\\$.*))" + "|(scala\\.collection\\..*)") diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala index 533b09e82df13..2f48241eadc2a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala @@ -50,7 +50,7 @@ trait AttributeNameParser { if (tmp.nonEmpty) throw e inBacktick = true } else if (char == '.') { - if (name(i - 1) == '.' || i == name.length - 1) throw e + if (i == 0 || name(i - 1) == '.' 
|| i == name.length - 1) throw e nameParts += tmp.mkString tmp.clear() } else { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 9f509fa843a2b..5670e513287e6 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1794,7 +1794,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def broadcast[DS[U] <: api.Dataset[U]](df: DS[_]): df.type = { + def broadcast[U](df: Dataset[U]): df.type = { df.hint("broadcast").asInstanceOf[df.type] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index ef4bdb8d5bdff..463307409839d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -85,7 +85,7 @@ trait ColumnNodeLike { } } -private[internal] object ColumnNode { +private[sql] object ColumnNode { val NO_ORIGIN: Origin = Origin() def normalize[T <: ColumnNodeLike](option: Option[T]): Option[T] = option.map(_.normalize().asInstanceOf[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/api/src/main/scala/org/apache/spark/sql/package.scala similarity index 80% rename from sql/core/src/main/scala/org/apache/spark/sql/package.scala rename to sql/api/src/main/scala/org/apache/spark/sql/package.scala index 1794ac513749f..f9ad85d65fb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/package.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark import org.apache.spark.annotation.{DeveloperApi, Unstable} @@ -23,19 +22,20 @@ import org.apache.spark.sql.execution.SparkStrategy /** * Allows the execution of relational queries, including those expressed in SQL using Spark. * - * @groupname dataType Data types - * @groupdesc Spark SQL data types. - * @groupprio dataType -3 - * @groupname field Field - * @groupprio field -2 - * @groupname row Row - * @groupprio row -1 + * @groupname dataType Data types + * @groupdesc Spark + * SQL data types. + * @groupprio dataType -3 + * @groupname field Field + * @groupprio field -2 + * @groupname row Row + * @groupprio row -1 */ package object sql { /** - * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting - * with the query planner and is not designed to be stable across spark releases. Developers + * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting + * with the query planner and is not designed to be stable across spark releases. Developers * writing libraries should instead consider using the stable APIs provided in * [[org.apache.spark.sql.sources]] */ @@ -47,9 +47,9 @@ package object sql { /** * Metadata key which is used to write Spark version in the followings: - * - Parquet file metadata - * - ORC file metadata - * - Avro file metadata + * - Parquet file metadata + * - ORC file metadata + * - Avro file metadata * * Note that Hive table property `spark.sql.create.version` also has Spark version. 
*/ @@ -57,8 +57,8 @@ package object sql { /** * The metadata key which is used to write the current session time zone into: - * - Parquet file metadata - * - Avro file metadata + * - Parquet file metadata + * - Avro file metadata */ private[sql] val SPARK_TIMEZONE_METADATA_KEY = "org.apache.spark.timeZone" @@ -69,8 +69,8 @@ package object sql { private[sql] val SPARK_LEGACY_DATETIME_METADATA_KEY = "org.apache.spark.legacyDateTime" /** - * Parquet file metadata key to indicate that the file with INT96 column type was written - * with rebasing. + * Parquet file metadata key to indicate that the file with INT96 column type was written with + * rebasing. */ private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96" } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala similarity index 95% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 219ecb77d4033..9098c1af74e50 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -14,14 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql.streaming import scala.jdk.CollectionConverters._ -import _root_.java - import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Encoders} import org.apache.spark.sql.types.StructType /** @@ -111,14 +109,14 @@ abstract class DataStreamReader { * * @since 2.0.0 */ - def load(): Dataset[Row] + def load(): DataFrame /** * Loads input in as a `DataFrame`, for data streams that read from some path. * * @since 2.0.0 */ - def load(path: String): Dataset[Row] + def load(path: String): DataFrame /** * Loads a JSON file stream and returns the results as a `DataFrame`. @@ -140,7 +138,7 @@ abstract class DataStreamReader { * * @since 2.0.0 */ - def json(path: String): Dataset[Row] = { + def json(path: String): DataFrame = { validateJsonSchema() format("json").load(path) } @@ -163,7 +161,7 @@ abstract class DataStreamReader { * * @since 2.0.0 */ - def csv(path: String): Dataset[Row] = format("csv").load(path) + def csv(path: String): DataFrame = format("csv").load(path) /** * Loads a XML file stream and returns the result as a `DataFrame`. 
@@ -183,7 +181,7 @@ abstract class DataStreamReader { * * @since 4.0.0 */ - def xml(path: String): Dataset[Row] = { + def xml(path: String): DataFrame = { validateXmlSchema() format("xml").load(path) } @@ -202,7 +200,7 @@ abstract class DataStreamReader { * * @since 2.3.0 */ - def orc(path: String): Dataset[Row] = { + def orc(path: String): DataFrame = { format("orc").load(path) } @@ -220,7 +218,7 @@ abstract class DataStreamReader { * * @since 2.0.0 */ - def parquet(path: String): Dataset[Row] = { + def parquet(path: String): DataFrame = { format("parquet").load(path) } @@ -231,7 +229,7 @@ abstract class DataStreamReader { * The name of the table * @since 3.1.0 */ - def table(tableName: String): Dataset[Row] + def table(tableName: String): DataFrame /** * Loads text files and returns a `DataFrame` whose schema starts with a string column named @@ -258,7 +256,7 @@ abstract class DataStreamReader { * * @since 2.0.0 */ - def text(path: String): Dataset[Row] = format("text").load(path) + def text(path: String): DataFrame = format("text").load(path) /** * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala similarity index 87% rename from sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f627eb3e167a3..cb5ecc728c441 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -14,15 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql.streaming -import _root_.java -import _root_.java.util.concurrent.TimeoutException +import java.util.concurrent.TimeoutException import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 -import org.apache.spark.sql.{ForeachWriter, WriteConfigMethods} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.{Dataset, ForeachWriter, WriteConfigMethods} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -32,7 +30,6 @@ import org.apache.spark.sql.streaming.{OutputMode, Trigger} */ @Evolving abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T]] { - type DS[U] <: Dataset[U] /** * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
    • @@ -85,9 +82,9 @@ abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T def trigger(trigger: Trigger): this.type /** - * Specifies the name of the [[org.apache.spark.sql.api.StreamingQuery]] that can be started - * with `start()`. This name must be unique among all the currently active queries in the - * associated SparkSession. + * Specifies the name of the [[org.apache.spark.sql.streaming.StreamingQuery]] that can be + * started with `start()`. This name must be unique among all the currently active queries in + * the associated SparkSession. * * @since 2.0.0 */ @@ -153,7 +150,7 @@ abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T * @since 2.4.0 */ @Evolving - def foreachBatch(function: (DS[T], Long) => Unit): this.type + def foreachBatch(function: (Dataset[T], Long) => Unit): this.type /** * :: Experimental :: @@ -169,14 +166,15 @@ abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T * @since 2.4.0 */ @Evolving - def foreachBatch(function: VoidFunction2[DS[T], java.lang.Long]): this.type = { - foreachBatch((batchDs: DS[T], batchId: Long) => function.call(batchDs, batchId)) + def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) } /** * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[org.apache.spark.sql.api.StreamingQuery]] - * object can be used to interact with the stream. + * given path as new data arrives. The returned + * [[org.apache.spark.sql.streaming.StreamingQuery]] object can be used to interact with the + * stream. * * @since 2.0.0 */ @@ -184,9 +182,9 @@ abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T /** * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[org.apache.spark.sql.api.StreamingQuery]] - * object can be used to interact with the stream. Throws a `TimeoutException` if the following - * conditions are met: + * given path as new data arrives. The returned + * [[org.apache.spark.sql.streaming.StreamingQuery]] object can be used to interact with the + * stream. Throws a `TimeoutException` if the following conditions are met: * - Another run of the same streaming query, that is a streaming query sharing the same * checkpoint location, is already active on the same Spark Driver * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` is enabled @@ -200,8 +198,9 @@ abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T /** * Starts the execution of the streaming query, which will continually output results to the - * given table as new data arrives. The returned [[org.apache.spark.sql.api.StreamingQuery]] - * object can be used to interact with the stream. + * given table as new data arrives. The returned + * [[org.apache.spark.sql.streaming.StreamingQuery]] object can be used to interact with the + * stream. * * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the * table exists or not. A new table will be created if the table not exists. 
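The `foreachBatch` overloads above now take a concrete `Dataset[T]` instead of the abstract `DS[T]` member type, so batch-style sinks can be written once against the shared API. A minimal sketch of a caller, assuming a local session, the built-in `rate` test source, and a placeholder output path (none of which are part of this patch):

```scala
import org.apache.spark.sql.{Dataset, Row, SparkSession}

// Illustrative driver code only: exercises the Dataset[T]-based foreachBatch overload.
val spark = SparkSession.builder().appName("foreachBatchSketch").getOrCreate()

val query = spark.readStream
  .format("rate")                 // built-in test source emitting (timestamp, value) rows
  .load()
  .writeStream
  .foreachBatch { (batch: Dataset[Row], batchId: Long) =>
    // Each micro-batch arrives as an ordinary Dataset[Row] that supports the full batch API.
    batch.write.mode("append").parquet(s"/tmp/rate-sink/batch=$batchId")
  }
  .start()
```

Because the callback parameter is now the shared `Dataset[T]`, the same callback compiles against both the Classic and Connect implementations of `DataStreamWriter`.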
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala index c5d3adda8b87e..31075f00e56f1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.streaming import java.io.Serializable -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving /** * Class used to provide access to expired timer's expiry time. */ -@Experimental @Evolving -private[sql] trait ExpiredTimerInfo extends Serializable { +trait ExpiredTimerInfo extends Serializable { /** * Get the expired timer's expiry time as milliseconds in epoch time. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 146990917a3fc..edd049715acc0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** @@ -196,7 +196,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * types (see `Encoder` for more details). * @since 2.2.0 */ -@Experimental @Evolving trait GroupState[S] extends LogicalGroupState[S] { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala index 568578d1f4ff6..79b0d10072e83 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala @@ -16,14 +16,13 @@ */ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving -@Experimental @Evolving /** * Interface used for arbitrary stateful operations with the v2 API to capture list value state. */ -private[sql] trait ListState[S] extends Serializable { +trait ListState[S] extends Serializable { /** Whether state exists or not. */ def exists(): Boolean diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala index 7b01888bbac49..c514b4e375f89 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala @@ -16,9 +16,8 @@ */ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving -@Experimental @Evolving /** * Interface used for arbitrary stateful operations with the v2 API to capture map value state. 
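With the `@Experimental` tag removed above, `GroupState` is part of the stable streaming API surface reached through `KeyValueGroupedDataset.mapGroupsWithState`, whose initial-state overloads are also reworked earlier in this diff. A small, self-contained sketch; `Event`, the running-count logic, and the local session are illustrative assumptions, not code from this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// Illustrative only: keep a per-user running count in GroupState.
case class Event(user: String, action: String)

val spark = SparkSession.builder().appName("groupStateSketch").getOrCreate()
import spark.implicits._

def countEvents(user: String, events: Iterator[Event], state: GroupState[Long]): (String, Long) = {
  val count = state.getOption.getOrElse(0L) + events.size
  state.update(count)               // persist the updated count for this key
  (user, count)
}

val counts = Seq(Event("a", "click"), Event("a", "view"), Event("b", "click"))
  .toDS()
  .groupByKey(_.user)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout)(countEvents)

counts.show()                       // one row per user with its count
```

The same function can be reused with the `initialState: KeyValueGroupedDataset[K, S]` overload shown earlier in this diff when a query needs to start from previously materialized state.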
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala index f239bcff49fea..2b56c92f85491 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.streaming import java.io.Serializable import java.util.UUID -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving /** * Represents the query info provided to the stateful processor used in the arbitrary state API v2 * to easily identify task retries on the same partition. */ -@Experimental @Evolving -private[sql] trait QueryInfo extends Serializable { +trait QueryInfo extends Serializable { /** Returns the streaming query id associated with stateful operator */ def getQueryId: UUID diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index b47629cb54396..f0ea1dcd68710 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.streaming import java.io.Serializable -import org.apache.spark.annotation.{Evolving, Experimental} -import org.apache.spark.sql.api.EncoderImplicits +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.EncoderImplicits import org.apache.spark.sql.errors.ExecutionErrors /** @@ -30,9 +30,8 @@ import org.apache.spark.sql.errors.ExecutionErrors * Users can also explicitly use `import implicits._` to access the EncoderImplicits and use the * state variable APIs relying on implicit encoders. */ -@Experimental @Evolving -private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { +abstract class StatefulProcessor[K, I, O] extends Serializable { // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i @@ -123,10 +122,8 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { * initial state to be initialized in the first batch. This can be used for starting a new * streaming query with existing state from a previous streaming query. */ -@Experimental @Evolving -private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S] - extends StatefulProcessor[K, I, O] { +abstract class StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] { /** * Function that will be invoked only in the first batch for users to process initial states. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index f458f0de37cbd..5a6d9f6c76ea8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -18,16 +18,15 @@ package org.apache.spark.sql.streaming import java.io.Serializable -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving import org.apache.spark.sql.Encoder /** * Represents the operation handle provided to the stateful processor used in the arbitrary state * API v2. 
*/ -@Experimental @Evolving -private[sql] trait StatefulProcessorHandle extends Serializable { +trait StatefulProcessorHandle extends Serializable { /** * Function to create new or return existing single value state variable of given type with ttl. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala similarity index 85% rename from sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 0aeb3518facd8..03222b244fd04 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql.streaming -import _root_.java.util.UUID -import _root_.java.util.concurrent.TimeoutException +import java.util.UUID +import java.util.concurrent.TimeoutException import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} +import org.apache.spark.sql.SparkSession /** * A handle to a query that is executing continuously in the background as new data arrives. All @@ -72,8 +72,7 @@ trait StreamingQuery { def isActive: Boolean /** - * Returns the [[org.apache.spark.sql.streaming.StreamingQueryException]] if the query was - * terminated by an exception. + * Returns the [[StreamingQueryException]] if the query was terminated by an exception. * * @since 2.0.0 */ @@ -87,17 +86,16 @@ trait StreamingQuery { def status: StreamingQueryStatus /** - * Returns an array of the most recent [[org.apache.spark.sql.streaming.StreamingQueryProgress]] - * updates for this query. The number of progress updates retained for each stream is configured - * by Spark session configuration `spark.sql.streaming.numRecentProgressUpdates`. + * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. The + * number of progress updates retained for each stream is configured by Spark session + * configuration `spark.sql.streaming.numRecentProgressUpdates`. * * @since 2.1.0 */ def recentProgress: Array[StreamingQueryProgress] /** - * Returns the most recent [[org.apache.spark.sql.streaming.StreamingQueryProgress]] update of - * this streaming query. + * Returns the most recent [[StreamingQueryProgress]] update of this streaming query. * * @since 2.1.0 */ @@ -111,7 +109,7 @@ trait StreamingQuery { * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if * the query has terminated with exception). * - * @throws org.apache.spark.sql.streaming.StreamingQueryException + * @throws StreamingQueryException * if the query has terminated with an exception. * * @since 2.0.0 @@ -128,7 +126,7 @@ trait StreamingQuery { * `true` immediately (if the query was terminated by `stop()`), or throw the exception * immediately (if the query has terminated with exception). 
* - * @throws org.apache.spark.sql.streaming.StreamingQueryException + * @throws StreamingQueryException * if the query has terminated with an exception * * @since 2.0.0 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index a6684969ff1ec..148cb183d8909 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -31,8 +31,7 @@ import org.apache.spark.annotation.Evolving import org.apache.spark.scheduler.SparkListenerEvent /** - * Interface for listening to events related to - * [[org.apache.spark.sql.api.StreamingQuery StreamingQueries]]. + * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. * * @note * The methods are not thread-safe as they may be called from different threads. @@ -49,8 +48,7 @@ abstract class StreamingQueryListener extends Serializable { * @note * This is called synchronously with `DataStreamWriter.start()`, that is, `onQueryStart` will * be called on all listeners before `DataStreamWriter.start()` returns the corresponding - * [[org.apache.spark.sql.api.StreamingQuery]]. Please don't block this method as it will - * block your query. + * [[StreamingQuery]]. Please don't block this method as it will block your query. * @since 2.0.0 */ def onQueryStarted(event: QueryStartedEvent): Unit @@ -59,11 +57,10 @@ abstract class StreamingQueryListener extends Serializable { * Called when there is some status update (ingestion rate updated, etc.) * * @note - * This method is asynchronous. The status in [[org.apache.spark.sql.api.StreamingQuery]] will - * always be latest no matter when this method is called. Therefore, the status of - * [[org.apache.spark.sql.api.StreamingQuery]] may be changed before/when you process the - * event. E.g., you may find [[org.apache.spark.sql.api.StreamingQuery]] is terminated when - * you are processing `QueryProgressEvent`. + * This method is asynchronous. The status in [[StreamingQuery]] will always be latest no + * matter when this method is called. Therefore, the status of [[StreamingQuery]] may be + * changed before/when you process the event. E.g., you may find [[StreamingQuery]] is + * terminated when you are processing `QueryProgressEvent`. * @since 2.0.0 */ def onQueryProgress(event: QueryProgressEvent): Unit diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala similarity index 97% rename from sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 88ba9a493d063..d103c58152283 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -14,12 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.api +package org.apache.spark.sql.streaming -import _root_.java.util.UUID +import java.util.UUID import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryListener} /** * A class to manage all the [[StreamingQuery]] active in a `SparkSession`. 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala index 04c5f59428f7f..a3480065965e0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.streaming import java.io.Serializable -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving /** * Class used to provide access to timer values for processing and event time populated before * method invocations using the arbitrary state API v2. */ -@Experimental @Evolving -private[sql] trait TimerValues extends Serializable { +trait TimerValues extends Serializable { /** * Get the current processing time as milliseconds in epoch time. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index edb5f65365ab8..5273dd5440f39 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.streaming import java.io.Serializable -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving -@Experimental @Evolving /** * Interface used for arbitrary stateful operations with the v2 API to capture single value state. */ -private[sql] trait ValueState[S] extends Serializable { +trait ValueState[S] extends Serializable { /** Whether state exists or not. */ def exists(): Boolean diff --git a/sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala b/sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala similarity index 97% rename from sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala rename to sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala index 84b6b85f639a3..34fb507c65686 100644 --- a/sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala +++ b/sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.api +package org.apache.spark.sql // scalastyle:off funsuite import org.scalatest.BeforeAndAfterAll diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 0d2b0464251f2..a86cc4555ccb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -402,7 +402,7 @@ object CTESubstitution extends Rule[LogicalPlan] { other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { case e: SubqueryExpression => e.withNewPlan( - apply(substituteCTE(e.plan, alwaysInline, cteRelations, None))) + apply(substituteCTE(e.plan, alwaysInline, cteRelations, recursiveCTERelation))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 397f04ac984a2..c7d5c355270f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -22,6 +22,7 @@ import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.{checkForSelfReferenceInSubquery, checkIfSelfReferenceIsPlacedCorrectly} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ListAgg, Median, PercentileCont, PercentileDisc} @@ -274,10 +275,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkTrailingCommaInSelect(proj) case agg: Aggregate => checkTrailingCommaInSelect(agg) + case unionLoop: UnionLoop => + // Recursive CTEs have already substituted Union to UnionLoop at this stage. + // Here we perform additional checks for them. + checkIfSelfReferenceIsPlacedCorrectly(unionLoop, unionLoop.id) case _ => } + // Check if there is any self-reference within subqueries + checkForSelfReferenceInSubquery(plan) + // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. 
plan.foreachUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala index 3ad88514e17c7..454fcdbd38399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} @@ -49,17 +50,18 @@ object ResolveWithCTE extends Rule[LogicalPlan] { plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) { case withCTE @ WithCTE(_, cteDefs) => val newCTEDefs = cteDefs.map { - // `cteDef.recursive` means "presence of a recursive CTERelationRef under cteDef". The - // side effect of node substitution below is that after CTERelationRef substitution - // its cteDef is no more considered `recursive`. This code path is common for `cteDef` - // that were non-recursive from the get go, as well as those that are no more recursive - // due to node substitution. - case cteDef if !cteDef.recursive => + // cteDef in the first case is either recursive and all the recursive CTERelationRefs + // are already substituted to UnionLoopRef in the previous pass, or it is not recursive + // at all. In both cases we need to put it in the map in case it is resolved. + // Second case is performing the substitution of recursive CTERelationRefs. + case cteDef if !cteDef.hasSelfReferenceAsCTERef => if (cteDef.resolved) { cteDefMap.put(cteDef.id, cteDef) } cteDef case cteDef => + // Multiple self-references are not allowed within one cteDef. + checkNumberOfSelfReferences(cteDef) cteDef.child match { // If it is a supported recursive CTE query pattern (4 so far), extract the anchor and // recursive plans from the Union and rewrite Union with UnionLoop. The recursive CTE @@ -183,4 +185,72 @@ object ResolveWithCTE extends Rule[LogicalPlan] { columnNames.map(UnresolvedSubqueryColumnAliases(_, ref)).getOrElse(ref) } } + + /** + * Checks if there is any self-reference within subqueries and throws an error + * if that is the case. + */ + def checkForSelfReferenceInSubquery(plan: LogicalPlan): Unit = { + plan.subqueriesAll.foreach { subquery => + subquery.foreach { + case r: CTERelationRef if r.recursive => + throw new AnalysisException( + errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE", + messageParameters = Map.empty) + case _ => + } + } + } + + /** + * Counts number of self-references in a recursive CTE definition and throws an error + * if that number is bigger than 1. 
+ */ + private def checkNumberOfSelfReferences(cteDef: CTERelationDef): Unit = { + val numOfSelfRef = cteDef.collectWithSubqueries { + case ref: CTERelationRef if ref.cteId == cteDef.id => ref + }.length + if (numOfSelfRef > 1) { + cteDef.failAnalysis( + errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER", + messageParameters = Map.empty) + } + } + + /** + * Throws an error if a self-reference appears in a position where it is not allowed: + * the right side of left outer/semi/anti joins, the left side of right outer joins, + * in full outer joins, and in aggregates. + */ + def checkIfSelfReferenceIsPlacedCorrectly( + plan: LogicalPlan, + cteId: Long, + allowRecursiveRef: Boolean = true): Unit = plan match { + case Join(left, right, Inner, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef) + case Join(left, right, LeftOuter, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Join(left, right, RightOuter, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = false) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef) + case Join(left, right, LeftSemi, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Join(left, right, LeftAnti, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Join(left, right, _, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = false) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Aggregate(_, _, child, _) => + checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef = false) + case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId => + throw new AnalysisException( + errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE", + messageParameters = Map.empty) + case other => + other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId, allowRecursiveRef)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index ad1a1a99b8257..62d99f7854891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -61,7 +61,10 @@ case class InlineCTE( // 1) It is fine to inline a CTE if it references another CTE that is non-deterministic; // 2) Any `CTERelationRef` that contains `OuterReference` would have been inlined first. refCount == 1 || - cteDef.deterministic || + // Don't inline recursive CTEs unless necessary, as recursion is very costly. + // Whether cteDef is recursive is determined by checking whether it contains + // a UnionLoopRef with the same ID.
+ (cteDef.deterministic && !cteDef.hasSelfReferenceAsUnionLoopRef) || cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference])) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index b98cef04d911c..a13650da2472a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -41,10 +41,12 @@ object NormalizePlan extends PredicateHelper { def normalizeExpressions(plan: LogicalPlan): LogicalPlan = { val withNormalizedRuntimeReplaceable = normalizeRuntimeReplaceable(plan) withNormalizedRuntimeReplaceable transformAllExpressions { - case c: CommonExpressionDef => - c.copy(id = new CommonExpressionId(id = 0)) - case c: CommonExpressionRef => - c.copy(id = new CommonExpressionId(id = 0)) + case commonExpressionDef: CommonExpressionDef => + commonExpressionDef.copy(id = new CommonExpressionId(id = 0)) + case commonExpressionRef: CommonExpressionRef => + commonExpressionRef.copy(id = new CommonExpressionId(id = 0)) + case expressionWithRandomSeed: ExpressionWithRandomSeed => + expressionWithRandomSeed.withNewSeed(0) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala index ac4e2b77f793a..03236571eade7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala @@ -100,12 +100,14 @@ case class CTERelationDef( override def output: Seq[Attribute] = if (resolved) child.output else Nil - lazy val recursive: Boolean = child.exists{ - // If the reference is found inside the child, referencing to this CTE definition, - // and already marked as recursive, then this CTE definition is recursive. 
+ lazy val hasSelfReferenceAsCTERef: Boolean = child.exists{ case CTERelationRef(this.id, _, _, _, _, true) => true case _ => false } + lazy val hasSelfReferenceAsUnionLoopRef: Boolean = child.exists{ + case UnionLoopRef(this.id, _, _) => true + case _ => false + } } object CTERelationDef { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index b43e627c0eece..263508a9d5fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.connector.expressions +import org.apache.commons.codec.binary.Hex import org.apache.commons.lang3.StringUtils import org.apache.spark.SparkException import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType} import org.apache.spark.util.ArrayImplicits._ /** @@ -388,12 +389,13 @@ private[sql] object HoursTransform { } private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { - override def toString: String = { - if (dataType.isInstanceOf[StringType]) { - s"'${StringUtils.replace(s"$value", "'", "''")}'" - } else { - s"$value" - } + override def toString: String = dataType match { + case StringType => s"'${StringUtils.replace(s"$value", "'", "''")}'" + case BinaryType => + assert(value.isInstanceOf[Array[Byte]]) + val bytes = value.asInstanceOf[Array[Byte]] + "0x" + Hex.encodeHexString(bytes, false) + case _ => s"$value" } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala index aa2a6408faf0b..d2ac103e14f7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.plans +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{ AssertTrue, @@ -100,6 +103,18 @@ class NormalizePlanSuite extends SparkFunSuite with SQLConfHelper { assert(NormalizePlan(baselinePlanDef) == NormalizePlan(testPlanDef)) } + test("Normalize non-deterministic expressions") { + val random = new Random() + val baselineExpression = rand(random.nextLong()) + val testExpression = rand(random.nextLong()) + + val baselinePlan = LocalRelation().select(baselineExpression) + val testPlan = LocalRelation().select(testExpression) + + assert(baselinePlan != testPlan) + assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan)) + } + private def setTimezoneForAllExpression(plan: LogicalPlan): LogicalPlan = { plan.transformAllExpressions { case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala similarity index 95% rename 
from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala index 8706000ae5be0..6850ffd122608 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala @@ -15,10 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog +import org.apache.spark.sql.catalog.{CatalogMetadata, Column, Database, Function, Table} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder, StringEncoder} @@ -26,7 +27,7 @@ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevel import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel -class CatalogImpl(sparkSession: SparkSession) extends Catalog { +class Catalog(sparkSession: SparkSession) extends catalog.Catalog { /** * Returns the current default database in this session. @@ -61,7 +62,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 3.5.0 */ override def listDatabases(): Dataset[Database] = { - sparkSession.newDataset(CatalogImpl.databaseEncoder) { builder => + sparkSession.newDataset(Catalog.databaseEncoder) { builder => builder.getCatalogBuilder.getListDatabasesBuilder } } @@ -73,7 +74,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 3.5.0 */ override def listDatabases(pattern: String): Dataset[Database] = { - sparkSession.newDataset(CatalogImpl.databaseEncoder) { builder => + sparkSession.newDataset(Catalog.databaseEncoder) { builder => builder.getCatalogBuilder.getListDatabasesBuilder.setPattern(pattern) } } @@ -96,7 +97,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def listTables(dbName: String): Dataset[Table] = { - sparkSession.newDataset(CatalogImpl.tableEncoder) { builder => + sparkSession.newDataset(Catalog.tableEncoder) { builder => builder.getCatalogBuilder.getListTablesBuilder.setDbName(dbName) } } @@ -109,7 +110,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") def listTables(dbName: String, pattern: String): Dataset[Table] = { - sparkSession.newDataset(CatalogImpl.tableEncoder) { builder => + sparkSession.newDataset(Catalog.tableEncoder) { builder => builder.getCatalogBuilder.getListTablesBuilder.setDbName(dbName).setPattern(pattern) } } @@ -132,7 +133,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def listFunctions(dbName: String): Dataset[Function] = { - sparkSession.newDataset(CatalogImpl.functionEncoder) { builder => + sparkSession.newDataset(Catalog.functionEncoder) { builder => builder.getCatalogBuilder.getListFunctionsBuilder.setDbName(dbName) } } @@ -146,7 +147,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ 
@throws[AnalysisException]("database does not exist") def listFunctions(dbName: String, pattern: String): Dataset[Function] = { - sparkSession.newDataset(CatalogImpl.functionEncoder) { builder => + sparkSession.newDataset(Catalog.functionEncoder) { builder => builder.getCatalogBuilder.getListFunctionsBuilder.setDbName(dbName).setPattern(pattern) } } @@ -162,7 +163,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database or table does not exist") override def listColumns(tableName: String): Dataset[Column] = { - sparkSession.newDataset(CatalogImpl.columnEncoder) { builder => + sparkSession.newDataset(Catalog.columnEncoder) { builder => builder.getCatalogBuilder.getListColumnsBuilder.setTableName(tableName) } } @@ -182,7 +183,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def listColumns(dbName: String, tableName: String): Dataset[Column] = { - sparkSession.newDataset(CatalogImpl.columnEncoder) { builder => + sparkSession.newDataset(Catalog.columnEncoder) { builder => builder.getCatalogBuilder.getListColumnsBuilder .setTableName(tableName) .setDbName(dbName) @@ -197,7 +198,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def getDatabase(dbName: String): Database = { sparkSession - .newDataset(CatalogImpl.databaseEncoder) { builder => + .newDataset(Catalog.databaseEncoder) { builder => builder.getCatalogBuilder.getGetDatabaseBuilder.setDbName(dbName) } .head() @@ -215,7 +216,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def getTable(tableName: String): Table = { sparkSession - .newDataset(CatalogImpl.tableEncoder) { builder => + .newDataset(Catalog.tableEncoder) { builder => builder.getCatalogBuilder.getGetTableBuilder.setTableName(tableName) } .head() @@ -232,7 +233,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def getTable(dbName: String, tableName: String): Table = { sparkSession - .newDataset(CatalogImpl.tableEncoder) { builder => + .newDataset(Catalog.tableEncoder) { builder => builder.getCatalogBuilder.getGetTableBuilder .setTableName(tableName) .setDbName(dbName) @@ -252,7 +253,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def getFunction(functionName: String): Function = { sparkSession - .newDataset(CatalogImpl.functionEncoder) { builder => + .newDataset(Catalog.functionEncoder) { builder => builder.getCatalogBuilder.getGetFunctionBuilder.setFunctionName(functionName) } .head() @@ -273,7 +274,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def getFunction(dbName: String, functionName: String): Function = { sparkSession - .newDataset(CatalogImpl.functionEncoder) { builder => + .newDataset(Catalog.functionEncoder) { builder => builder.getCatalogBuilder.getGetFunctionBuilder .setFunctionName(functionName) .setDbName(dbName) @@ -688,7 +689,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def listCatalogs(): Dataset[CatalogMetadata] = sparkSession - .newDataset(CatalogImpl.catalogEncoder) { builder => + .newDataset(Catalog.catalogEncoder) { builder => builder.getCatalogBuilder.getListCatalogsBuilder } @@ -700,12 +701,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def listCatalogs(pattern: String): Dataset[CatalogMetadata] = sparkSession - .newDataset(CatalogImpl.catalogEncoder) { builder => + 
.newDataset(Catalog.catalogEncoder) { builder => builder.getCatalogBuilder.getListCatalogsBuilder.setPattern(pattern) } } -private object CatalogImpl { +private object Catalog { private val databaseEncoder: AgnosticEncoder[Database] = ScalaReflection .encoderFor(ScalaReflection.localTypeOf[Database]) .asInstanceOf[AgnosticEncoder[Database]] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala similarity index 91% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala index 5783a20348d75..6d956bd5af7e8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala @@ -53,4 +53,10 @@ private[sql] object ConnectClientUnsupportedErrors { def sparkContext(): SparkUnsupportedOperationException = unsupportedFeatureException("SESSION_SPARK_CONTEXT") + + def sqlContext(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_SQL_CONTEXT") + + def registerUdaf(): SparkUnsupportedOperationException = + unsupportedFeatureException("REGISTER_UDAF") } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala similarity index 88% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala index 0344152be86e6..1fb06bc930ce7 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -20,8 +20,9 @@ import scala.language.implicitConversions import org.apache.spark.annotation.DeveloperApi import org.apache.spark.connect.proto -import org.apache.spark.sql._ -import org.apache.spark.sql.internal.ProtoColumnNode +import org.apache.spark.sql +import org.apache.spark.sql.Column +import org.apache.spark.sql.connect.ProtoColumnNode /** * Conversions from sql interfaces to the Connect specific implementation. 
@@ -36,17 +37,17 @@ import org.apache.spark.sql.internal.ProtoColumnNode */ @DeveloperApi trait ConnectConversions { - implicit def castToImpl(session: api.SparkSession): SparkSession = + implicit def castToImpl(session: sql.SparkSession): SparkSession = session.asInstanceOf[SparkSession] - implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + implicit def castToImpl[T](ds: sql.Dataset[T]): Dataset[T] = ds.asInstanceOf[Dataset[T]] - implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + implicit def castToImpl(rgds: sql.RelationalGroupedDataset): RelationalGroupedDataset = rgds.asInstanceOf[RelationalGroupedDataset] implicit def castToImpl[K, V]( - kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = + kvds: sql.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala similarity index 96% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala index 3777f82594aae..8f6c6ef07b3df 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala @@ -15,14 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement +import org.apache.spark.sql import org.apache.spark.sql.connect.ConnectConversions._ +import org.apache.spark.sql.functions /** * Functionality for working with missing data in `DataFrame`s. @@ -30,13 +32,13 @@ import org.apache.spark.sql.connect.ConnectConversions._ * @since 3.4.0 */ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) - extends api.DataFrameNaFunctions { + extends sql.DataFrameNaFunctions { import sparkSession.RichColumn - override protected def drop(minNonNulls: Option[Int]): Dataset[Row] = + override protected def drop(minNonNulls: Option[Int]): DataFrame = buildDropDataFrame(None, minNonNulls) - override protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] = + override protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DataFrame = buildDropDataFrame(Option(cols), minNonNulls) private def buildDropDataFrame( diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameReader.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameReader.scala index 1fbc887901ecc..0af603e0f6cc9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameReader.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Properties @@ -25,6 +25,7 @@ import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.rdd.RDD +import org.apache.spark.sql import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter @@ -37,8 +38,7 @@ import org.apache.spark.sql.types.StructType * @since 3.4.0 */ @Stable -class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader { - type DS[U] = Dataset[U] +class DataFrameReader private[sql] (sparkSession: SparkSession) extends sql.DataFrameReader { /** @inheritdoc */ override def format(source: String): this.type = super.format(source) @@ -140,15 +140,15 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.Data override def json(paths: String*): DataFrame = super.json(paths: _*) /** @inheritdoc */ - def json(jsonDataset: Dataset[String]): DataFrame = + def json(jsonDataset: sql.Dataset[String]): DataFrame = parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON) /** @inheritdoc */ - override def json(jsonRDD: JavaRDD[String]): Dataset[Row] = + override def json(jsonRDD: JavaRDD[String]): DataFrame = throw ConnectClientUnsupportedErrors.rdd() /** @inheritdoc */ - override def json(jsonRDD: RDD[String]): Dataset[Row] = + override def json(jsonRDD: RDD[String]): DataFrame = throw ConnectClientUnsupportedErrors.rdd() /** @inheritdoc */ @@ -159,7 +159,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.Data override def csv(paths: String*): DataFrame = super.csv(paths: _*) /** @inheritdoc */ - def csv(csvDataset: Dataset[String]): DataFrame = + def csv(csvDataset: sql.Dataset[String]): DataFrame = parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV) /** @inheritdoc */ @@ -170,7 +170,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.Data override def xml(paths: String*): DataFrame = super.xml(paths: _*) /** @inheritdoc */ - def xml(xmlDataset: Dataset[String]): DataFrame = + def xml(xmlDataset: sql.Dataset[String]): DataFrame = parse(xmlDataset, ParseFormat.PARSE_FORMAT_UNSPECIFIED) /** @inheritdoc */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala index bb7cfa75a9ab9..f3c3f82a233ae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala @@ -15,14 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.{lang => jl, util => ju} import org.apache.spark.connect.proto.{Relation, StatSampleBy} -import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder +import org.apache.spark.sql +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} import org.apache.spark.sql.connect.ConnectConversions._ +import org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.functions.lit /** @@ -31,7 +33,7 @@ import org.apache.spark.sql.functions.lit * @since 3.4.0 */ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) - extends api.DataFrameStatFunctions { + extends sql.DataFrameStatFunctions { private def root: Relation = df.plan.getRoot private val sparkSession: SparkSession = df.sparkSession diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriter.scala similarity index 96% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriter.scala index a1ba8ba44700e..2038037d4439c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriter.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.connect.proto -import org.apache.spark.sql.{DataFrameWriter, Dataset, SaveMode} +import org.apache.spark.sql +import org.apache.spark.sql.SaveMode /** * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, key-value @@ -30,7 +31,7 @@ import org.apache.spark.sql.{DataFrameWriter, Dataset, SaveMode} * @since 3.4.0 */ @Stable -final class DataFrameWriterImpl[T] private[sql] (ds: Dataset[T]) extends DataFrameWriter[T] { +final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) extends sql.DataFrameWriter[T] { /** @inheritdoc */ override def mode(saveMode: SaveMode): this.type = super.mode(saveMode) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala similarity index 94% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala index 4afa8b6d566c5..42cf2cdfad58a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala @@ -15,13 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.connect.proto -import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset} +import org.apache.spark.sql +import org.apache.spark.sql.Column /** * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 @@ -30,8 +31,8 @@ import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset} * @since 3.4.0 */ @Experimental -final class DataFrameWriterV2Impl[T] private[sql] (table: String, ds: Dataset[T]) - extends DataFrameWriterV2[T] { +final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) + extends sql.DataFrameWriterV2[T] { import ds.sparkSession.RichColumn private val builder = proto.WriteOperationV2 diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamReader.scala similarity index 96% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamReader.scala index 2ff34a6343644..808df593b775a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamReader.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Read.DataSource -import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.errors.DataTypeErrors +import org.apache.spark.sql.streaming import org.apache.spark.sql.types.StructType /** @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType */ @Evolving final class DataStreamReader private[sql] (sparkSession: SparkSession) - extends api.DataStreamReader { + extends streaming.DataStreamReader { private val sourceBuilder = DataSource.newBuilder() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala similarity index 88% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala index b2c4fcf64e70f..a42a463e2c42a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect import java.util.Locale import java.util.concurrent.TimeoutException @@ -27,14 +27,12 @@ import com.google.protobuf.ByteString import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.Command -import org.apache.spark.connect.proto.WriteStreamOperationStart -import org.apache.spark.sql.{api, Dataset, ForeachWriter} +import org.apache.spark.connect.proto.{Command, WriteStreamOperationStart} +import org.apache.spark.sql.{Dataset => DS, ForeachWriter} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket} -import org.apache.spark.sql.execution.streaming.AvailableNowTrigger -import org.apache.spark.sql.execution.streaming.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.OneTimeTrigger -import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger +import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, ContinuousTrigger, OneTimeTrigger, ProcessingTimeTrigger} +import org.apache.spark.sql.streaming +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent import org.apache.spark.sql.types.NullType import org.apache.spark.util.SparkSerDeUtils @@ -46,8 +44,8 @@ import org.apache.spark.util.SparkSerDeUtils * @since 3.5.0 */ @Evolving -final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataStreamWriter[T] { - override type DS[U] = Dataset[U] +final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) + extends streaming.DataStreamWriter[T] { /** @inheritdoc */ def outputMode(outputMode: OutputMode): this.type = { @@ -134,7 +132,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { + def foreachBatch(function: (DS[T], Long) => Unit): this.type = { // SPARK-50661: the client should send the encoder for the input dataset together with the // function to the server. val serializedFn = @@ -192,7 +190,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt /** @inheritdoc */ @Evolving - override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + override def foreachBatch(function: VoidFunction2[DS[T], java.lang.Long]): this.type = super.foreachBatch(function) private val sinkBuilder = WriteStreamOperationStart diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala similarity index 94% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala index 75df538678a3d..36003283a3369 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util @@ -30,11 +30,12 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.rdd.RDD +import org.apache.spark.sql +import org.apache.spark.sql.{functions, AnalysisException, Column, Encoder, Observation, Row, TypedColumn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils -import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter} @@ -42,8 +43,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{struct, to_json} -import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SubqueryExpressionNode, SubqueryType, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} -import org.apache.spark.sql.streaming.DataStreamWriter +import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ @@ -139,9 +139,7 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends api.Dataset[T] { - type DS[U] = Dataset[U] - + extends sql.Dataset[T] { import sparkSession.RichColumn // Make sure we don't forget to set plan id. 
@@ -320,12 +318,12 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - def join(right: Dataset[_]): DataFrame = buildJoin(right) { builder => + def join(right: sql.Dataset[_]): DataFrame = buildJoin(right) { builder => builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER) } /** @inheritdoc */ - def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { + def join(right: sql.Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { buildJoin(right) { builder => builder .setJoinType(toJoinType(joinType)) @@ -334,7 +332,7 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + def join(right: sql.Dataset[_], joinExprs: Column, joinType: String): DataFrame = { buildJoin(right, Seq(joinExprs)) { builder => builder .setJoinType(toJoinType(joinType)) @@ -343,12 +341,12 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - def crossJoin(right: Dataset[_]): DataFrame = buildJoin(right) { builder => + def crossJoin(right: sql.Dataset[_]): DataFrame = buildJoin(right) { builder => builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_CROSS) } /** @inheritdoc */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { + def joinWith[U](other: sql.Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val joinTypeValue = toJoinType(joinType, skipSemiAnti = true) val (leftNullable, rightNullable) = joinTypeValue match { case proto.Join.JoinType.JOIN_TYPE_INNER | proto.Join.JoinType.JOIN_TYPE_CROSS => @@ -385,7 +383,7 @@ class Dataset[T] private[sql] ( } private def lateralJoin( - right: DS[_], + right: sql.Dataset[_], joinExprs: Option[Column], joinType: String): DataFrame = { val joinTypeValue = toJoinType(joinType) @@ -404,22 +402,22 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - def lateralJoin(right: DS[_]): DataFrame = { + def lateralJoin(right: sql.Dataset[_]): DataFrame = { lateralJoin(right, None, "inner") } /** @inheritdoc */ - def lateralJoin(right: DS[_], joinExprs: Column): DataFrame = { + def lateralJoin(right: sql.Dataset[_], joinExprs: Column): DataFrame = { lateralJoin(right, Some(joinExprs), "inner") } /** @inheritdoc */ - def lateralJoin(right: DS[_], joinType: String): DataFrame = { + def lateralJoin(right: sql.Dataset[_], joinType: String): DataFrame = { lateralJoin(right, None, joinType) } /** @inheritdoc */ - def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): DataFrame = { + def lateralJoin(right: sql.Dataset[_], joinExprs: Column, joinType: String): DataFrame = { lateralJoin(right, Some(joinExprs), joinType) } @@ -674,42 +672,42 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - def union(other: Dataset[T]): Dataset[T] = { + def union(other: sql.Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_UNION) { builder => builder.setIsAll(true) } } /** @inheritdoc */ - def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = { + def unionByName(other: sql.Dataset[T], allowMissingColumns: Boolean): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_UNION) { builder => builder.setByName(true).setIsAll(true).setAllowMissingColumns(allowMissingColumns) } } /** @inheritdoc */ - def intersect(other: Dataset[T]): Dataset[T] = { + def intersect(other: sql.Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_INTERSECT) { builder => builder.setIsAll(false) } } /** 
@inheritdoc */ - def intersectAll(other: Dataset[T]): Dataset[T] = { + def intersectAll(other: sql.Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_INTERSECT) { builder => builder.setIsAll(true) } } /** @inheritdoc */ - def except(other: Dataset[T]): Dataset[T] = { + def except(other: sql.Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT) { builder => builder.setIsAll(false) } } /** @inheritdoc */ - def exceptAll(other: Dataset[T]): Dataset[T] = { + def exceptAll(other: sql.Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT) { builder => builder.setIsAll(true) } @@ -728,7 +726,7 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + def randomSplit(weights: Array[Double], seed: Long): Array[sql.Dataset[T]] = { require( weights.forall(_ >= 0), s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") @@ -767,15 +765,15 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - override def randomSplitAsList(weights: Array[Double], seed: Long): util.List[Dataset[T]] = + override def randomSplitAsList(weights: Array[Double], seed: Long): util.List[sql.Dataset[T]] = util.Arrays.asList(randomSplit(weights, seed): _*) /** @inheritdoc */ - override def randomSplit(weights: Array[Double]): Array[Dataset[T]] = + override def randomSplit(weights: Array[Double]): Array[sql.Dataset[T]] = randomSplit(weights, SparkClassUtils.random.nextLong()) /** @inheritdoc */ - protected def withColumns(names: Seq[String], values: Seq[Column]): DataFrame = { + private[spark] def withColumns(names: Seq[String], values: Seq[Column]): DataFrame = { require( names.size == values.size, s"The size of column names: ${names.size} isn't equal to " + @@ -968,7 +966,7 @@ class Dataset[T] private[sql] ( UDFAdaptors.iterableOnceToSeq(f), Nil, ScalaReflection.encoderFor[Seq[B]]) - select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn))) + select(col("*"), functions.explode(generator(col(inputColumn))).as(outputColumn)) } /** @inheritdoc */ @@ -1032,7 +1030,7 @@ class Dataset[T] private[sql] ( buildRepartition(numPartitions, shuffle = true) } - protected def repartitionByExpression( + protected[this] def repartitionByExpression( numPartitions: Option[Int], partitionExprs: Seq[Column]): Dataset[T] = { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. 
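Reviewer note on the Dataset.scala hunks above: the join/set-operation/randomSplit signatures are widened from the Connect-specific `Dataset[_]` to the shared `org.apache.spark.sql.Dataset[_]` type, with the Connect implementation narrowing the argument internally. A minimal caller-side sketch, not part of this patch; it only assumes the renamed `org.apache.spark.sql.connect` module is on the compile classpath, everything else is illustrative:

import org.apache.spark.sql.{Dataset => ApiDataset}
import org.apache.spark.sql.connect.{Dataset => ConnectDataset}

// After this change, Connect code can accept the shared API type directly;
// the Connect Dataset is expected to narrow `right` to its own implementation
// before building the proto plan.
def unionBoth[T](left: ConnectDataset[T], right: ApiDataset[T]): ApiDataset[T] =
  left.union(right) // compiles against union(other: sql.Dataset[T]): Dataset[T]

The same pattern applies to the widened `sql.KeyValueGroupedDataset` parameters further down, where the diff narrows explicitly via `castToImpl(initialState)`.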
@@ -1074,12 +1072,12 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def write: DataFrameWriter[T] = { - new DataFrameWriterImpl[T](this) + new DataFrameWriter[T](this) } /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { - new DataFrameWriterV2Impl[T](table, this) + new DataFrameWriterV2[T](table, this) } /** @inheritdoc */ @@ -1090,7 +1088,7 @@ class Dataset[T] private[sql] ( messageParameters = Map("methodName" -> toSQLId("mergeInto"))) } - new MergeIntoWriterImpl[T](table, this, condition) + new MergeIntoWriter[T](table, this, condition) } /** @inheritdoc */ @@ -1208,7 +1206,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ @DeveloperApi - def sameSemantics(other: Dataset[T]): Boolean = { + def sameSemantics(other: sql.Dataset[T]): Boolean = { sparkSession.sameSemantics(this.plan, other.plan) } @@ -1264,27 +1262,30 @@ class Dataset[T] private[sql] ( override def drop(col: Column): DataFrame = super.drop(col) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumn: String): DataFrame = + override def join(right: sql.Dataset[_], usingColumn: String): DataFrame = super.join(right, usingColumn) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = + override def join(right: sql.Dataset[_], usingColumns: Array[String]): DataFrame = super.join(right, usingColumns) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = + override def join(right: sql.Dataset[_], usingColumns: Seq[String]): DataFrame = super.join(right, usingColumns) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = + override def join(right: sql.Dataset[_], usingColumn: String, joinType: String): DataFrame = super.join(right, usingColumn, joinType) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = + override def join( + right: sql.Dataset[_], + usingColumns: Array[String], + joinType: String): DataFrame = super.join(right, usingColumns, joinType) /** @inheritdoc */ - override def join(right: Dataset[_], joinExprs: Column): DataFrame = + override def join(right: sql.Dataset[_], joinExprs: Column): DataFrame = super.join(right, joinExprs) /** @inheritdoc */ @@ -1374,7 +1375,7 @@ class Dataset[T] private[sql] ( super.localCheckpoint(eager, storageLevel) /** @inheritdoc */ - override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = + override def joinWith[U](other: sql.Dataset[U], condition: Column): Dataset[(T, U)] = super.joinWith(other, condition) /** @inheritdoc */ @@ -1428,10 +1429,10 @@ class Dataset[T] private[sql] ( override def where(conditionExpr: String): Dataset[T] = super.where(conditionExpr) /** @inheritdoc */ - override def unionAll(other: Dataset[T]): Dataset[T] = super.unionAll(other) + override def unionAll(other: sql.Dataset[T]): Dataset[T] = super.unionAll(other) /** @inheritdoc */ - override def unionByName(other: Dataset[T]): Dataset[T] = super.unionByName(other) + override def unionByName(other: sql.Dataset[T]): Dataset[T] = super.unionByName(other) /** @inheritdoc */ override def sample(fraction: Double, seed: Long): Dataset[T] = super.sample(fraction, seed) @@ -1535,6 +1536,20 @@ class Dataset[T] private[sql] ( encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + /** @inheritdoc */ + override private[spark] def withColumns( + colNames: Seq[String], 
+ cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = + super.withColumns(colNames, cols, metadata) + + /** @inheritdoc */ + override private[spark] def withColumn( + colName: String, + col: Column, + metadata: Metadata): DataFrame = + super.withColumn(colName, col, metadata) + /** @inheritdoc */ override def rdd: RDD[T] = throw ConnectClientUnsupportedErrors.rdd() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala similarity index 94% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala index d5505d2222c4f..c984582ed6ae1 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.Arrays @@ -24,13 +24,15 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto +import org.apache.spark.sql +import org.apache.spark.sql.{Column, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils} import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.col -import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr import org.apache.spark.sql.internal.UDFAdaptors import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} @@ -41,8 +43,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 3.5.0 */ -class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] { - type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] +class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDataset[K, V] { private def unsupported(): Nothing = throw new UnsupportedOperationException() @@ -74,8 +75,9 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = unsupported() /** @inheritdoc */ - def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)( - otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + def cogroupSorted[U, R: Encoder](other: sql.KeyValueGroupedDataset[K, U])( + thisSortExprs: Column*)(otherSortExprs: Column*)( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = unsupported() protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( @@ -101,12 +103,12 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S])( + initialState: 
sql.KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { flatMapGroupsWithStateHelper( None, timeoutConf, - Some(initialState), + Some(castToImpl(initialState)), isMapGroupWithState = true)(UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } @@ -126,12 +128,12 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S])( + initialState: sql.KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { flatMapGroupsWithStateHelper( Some(outputMode), timeoutConf, - Some(initialState), + Some(castToImpl(initialState)), isMapGroupWithState = false)(func) } @@ -147,7 +149,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() /** @inheritdoc */ @@ -161,7 +163,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], eventTimeColumnName: String, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() // Overrides... /** @inheritdoc */ @@ -212,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState) /** @inheritdoc */ @@ -231,7 +233,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState( + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState( func, outputMode, stateEncoder, @@ -260,7 +262,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S], + initialState: sql.KeyValueGroupedDataset[K, S], outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]) = super.transformWithState( statefulProcessor, @@ -274,7 +276,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S], + initialState: sql.KeyValueGroupedDataset[K, S], eventTimeColumnName: String, outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]) = super.transformWithState( @@ -355,19 +357,19 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDa override def count(): Dataset[(K, Long)] = super.count() /** @inheritdoc */ - override def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( 
+ override def cogroup[U, R: Encoder](other: sql.KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = super.cogroup(other)(f) /** @inheritdoc */ override def cogroup[U, R]( - other: KeyValueGroupedDataset[K, U], + other: sql.KeyValueGroupedDataset[K, U], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder) /** @inheritdoc */ override def cogroupSorted[U, R]( - other: KeyValueGroupedDataset[K, U], + other: sql.KeyValueGroupedDataset[K, U], thisSortExprs: Array[Column], otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], @@ -440,7 +442,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( } } - override def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( + override def cogroupSorted[U, R: Encoder](other: sql.KeyValueGroupedDataset[K, U])( thisSortExprs: Column*)(otherSortExprs: Column*)( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala index fba3c6343558b..c245a8644a3cb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala @@ -15,14 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Experimental import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCommand} import org.apache.spark.connect.proto.MergeAction.ActionType._ -import org.apache.spark.sql.{Column, Dataset, MergeIntoWriter} +import org.apache.spark.sql +import org.apache.spark.sql.Column import org.apache.spark.sql.functions.expr /** @@ -41,8 +42,8 @@ import org.apache.spark.sql.functions.expr * @since 4.0.0 */ @Experimental -class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column) - extends MergeIntoWriter[T] { +class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column) + extends sql.MergeIntoWriter[T] { import ds.sparkSession.RichColumn private val builder = MergeIntoTableCommand diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala similarity index 97% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala index 0944c88a67906..00dc1fb6906f7 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala @@ -15,11 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql +import org.apache.spark.sql.{functions, Column, Encoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.ConnectConversions._ @@ -41,7 +43,7 @@ class RelationalGroupedDataset private[sql] ( groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) - extends api.RelationalGroupedDataset { + extends sql.RelationalGroupedDataset { import df.sparkSession.RichColumn protected def toDF(aggExprs: Seq[Column]): DataFrame = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RuntimeConfig.scala similarity index 96% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RuntimeConfig.scala index 74348e8e015e2..6f258ec2e08e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RuntimeConfig.scala @@ -14,12 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import org.apache.spark.connect.proto.{ConfigRequest, ConfigResponse, KeyValue} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{ConfigEntry, ConfigReader, OptionalConfigEntry} -import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql import org.apache.spark.sql.connect.client.SparkConnectClient /** @@ -27,8 +27,8 @@ import org.apache.spark.sql.connect.client.SparkConnectClient * * @since 3.4.0 */ -class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) - extends RuntimeConfig +class RuntimeConfig private[sql] (client: SparkConnectClient) + extends sql.RuntimeConfig with Logging { self => /** @inheritdoc */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala similarity index 97% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLContext.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala index 3603eb6ea508d..e38179e232d05 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.util.{List => JList, Map => JMap, Properties} @@ -26,7 +26,8 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors +import org.apache.spark.sql +import org.apache.spark.sql.{Encoder, ExperimentalMethods, Row} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQueryManager} @@ -35,7 +36,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager @Stable class SQLContext private[sql] (override val sparkSession: SparkSession) - extends api.SQLContext(sparkSession) { + extends sql.SQLContext(sparkSession) { /** @inheritdoc */ def newSession(): SQLContext = sparkSession.newSession().sqlContext @@ -58,11 +59,7 @@ class SQLContext private[sql] (override val sparkSession: SparkSession) // Disable style checker so "implicits" object can start with lowercase i /** @inheritdoc */ - object implicits extends SQLImplicits { - - /** @inheritdoc */ - override protected def session: SparkSession = sparkSession - } + object implicits extends SQLImplicits(sparkSession) // scalastyle:on @@ -308,7 +305,7 @@ class SQLContext private[sql] (override val sparkSession: SparkSession) super.jdbc(url, table, theParts) } } -object SQLContext extends api.SQLContextCompanion { +object SQLContext extends sql.SQLContextCompanion { override private[sql] type SQLContextImpl = SQLContext override private[sql] type SparkContextImpl = SparkContext diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLImplicits.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLImplicits.scala new file mode 100644 index 0000000000000..c24390a68554f --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLImplicits.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect + +import scala.language.implicitConversions + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql +import org.apache.spark.sql.Encoder + +/** @inheritdoc */ +abstract class SQLImplicits private[sql] (override val session: SparkSession) + extends sql.SQLImplicits { + + override implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = + new DatasetHolder[T](session.createDataset(s)) + + override implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] = + new DatasetHolder[T](session.createDataset(rdd)) +} + +class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U] { + override def toDS(): Dataset[U] = ds + override def toDF(): DataFrame = ds.toDF() + override def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*) +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SessionCleaner.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SessionCleaner.scala index 21e4f4d141a89..c7ad6151c49b5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SessionCleaner.scala @@ -15,13 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import java.lang.ref.Cleaner import org.apache.spark.connect.proto import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession private[sql] class SessionCleaner(session: SparkSession) extends Logging { private val cleaner = Cleaner.create() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala similarity index 93% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala index 110ecde5f99fd..f7998cf60ecac 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import java.net.URI import java.nio.file.{Files, Paths} @@ -35,22 +35,21 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, MDC} +import org.apache.spark.internal.LogKeys.CONFIG import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalog.Catalog +import org.apache.spark.sql +import org.apache.spark.sql.{Column, Encoder, ExperimentalMethods, Observation, Row, SparkSessionBuilder, SparkSessionCompanion, SparkSessionExtensions} import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors +import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SessionState, SharedState, SqlApiConf, SubqueryExpressionNode} -import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} +import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf} import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.streaming.DataStreamReader -import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.ArrayImplicits._ @@ -76,7 +75,7 @@ import org.apache.spark.util.ArrayImplicits._ class SparkSession private[sql] ( private[sql] val client: SparkConnectClient, private val planIdGenerator: AtomicLong) - extends api.SparkSession + extends sql.SparkSession with Logging { private[this] val allocator = new RootAllocator() @@ -100,7 +99,7 @@ class SparkSession private[sql] ( throw ConnectClientUnsupportedErrors.sparkContext() /** @inheritdoc */ - val conf: RuntimeConfig = new ConnectRuntimeConfig(client) + val conf: RuntimeConfig = new RuntimeConfig(client) /** @inheritdoc */ @transient @@ -199,7 +198,7 @@ class SparkSession private[sql] ( throw ConnectClientUnsupportedErrors.experimental() /** @inheritdoc */ - override def baseRelationToDataFrame(baseRelation: BaseRelation): api.Dataset[Row] = + override def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = throw ConnectClientUnsupportedErrors.baseRelationToDataFrame() /** @inheritdoc */ @@ -270,7 +269,7 @@ class SparkSession private[sql] ( lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) /** @inheritdoc */ - lazy val catalog: Catalog = new CatalogImpl(this) + lazy val catalog: Catalog = new Catalog(this) /** @inheritdoc */ def table(tableName: String): DataFrame = { @@ -300,9 +299,7 @@ class SparkSession private[sql] ( // scalastyle:off /** @inheritdoc */ - object implicits extends SQLImplicits { - override protected def session: SparkSession = 
SparkSession.this - } + object implicits extends SQLImplicits(this) // scalastyle:on /** @inheritdoc */ @@ -665,7 +662,7 @@ class SparkSession private[sql] ( // The minimal builder needed to create a spark session. // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]] -object SparkSession extends api.BaseSparkSessionCompanion with Logging { +object SparkSession extends SparkSessionCompanion with Logging { override private[sql] type Session = SparkSession private val MAX_CACHED_SESSIONS = 100 @@ -742,7 +739,9 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { */ def builder(): Builder = new Builder() - class Builder() extends api.SparkSessionBuilder { + class Builder() extends SparkSessionBuilder { + import SparkSessionBuilder._ + // Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE // by default, if it exists. The connection string can be overridden using // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable. @@ -750,8 +749,8 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { private var client: SparkConnectClient = _ /** @inheritdoc */ - def remote(connectionString: String): this.type = { - builder.connectionString(connectionString) + @deprecated("sparkContext does not work in Spark Connect") + override private[spark] def sparkContext(sparkContext: SparkContext): this.type = { this } @@ -795,15 +794,30 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { /** @inheritdoc */ @deprecated("enableHiveSupport does not work in Spark Connect") - override def enableHiveSupport(): this.type = this + override def enableHiveSupport(): this.type = super.enableHiveSupport() /** @inheritdoc */ @deprecated("master does not work in Spark Connect, please use remote instead") - override def master(master: String): this.type = this + override def master(master: String): this.type = super.master(master) /** @inheritdoc */ @deprecated("appName does not work in Spark Connect") - override def appName(name: String): this.type = this + override def appName(name: String): this.type = super.appName(name) + + /** @inheritdoc */ + override def remote(connectionString: String): Builder.this.type = + super.remote(connectionString) + + override protected def handleBuilderConfig(key: String, value: String): Boolean = key match { + case CONNECT_REMOTE_KEY => + builder.connectionString(value) + true + case APP_NAME_KEY | MASTER_KEY | CATALOG_IMPL_KEY | API_MODE_KEY => + logWarning(log"${MDC(CONFIG, key)} configuration is not supported in Connect mode.") + true + case _ => + false + } /** @inheritdoc */ @deprecated("withExtensions does not work in Spark Connect") @@ -896,4 +910,10 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { /** @inheritdoc */ override def active: SparkSession = super.active + + override protected def tryCastToImplementation( + session: sql.SparkSession): Option[SparkSession] = session match { + case impl: SparkSession => Some(impl) + case _ => None + } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQuery.scala similarity index 93% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQuery.scala index 29fbcc443deb9..f74a107531de5 100644 
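Reviewer note on the SparkSession Builder hunks above: `remote`, `appName`, `master`, and `enableHiveSupport` now defer to the shared `SparkSessionBuilder`, and configuration is funneled through `handleBuilderConfig`, where `CONNECT_REMOTE_KEY` is applied to the client builder and unsupported keys (`APP_NAME_KEY`, `MASTER_KEY`, `CATALOG_IMPL_KEY`, `API_MODE_KEY`) only log a warning. A hedged usage sketch, not part of this patch; it assumes `super.remote` records the connection string under `CONNECT_REMOTE_KEY`, that `getOrCreate()` is still available on this builder, and that a Connect endpoint is reachable at the illustrative address:

import org.apache.spark.sql.connect.SparkSession

val spark = SparkSession.builder()
  .remote("sc://localhost:15002") // picked up by handleBuilderConfig and passed to the client builder
  .appName("my-app")              // deprecated in Connect mode; logged as unsupported rather than failing
  .getOrCreate()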
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQuery.scala @@ -15,21 +15,19 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect import java.util.UUID import scala.jdk.CollectionConverters._ -import org.apache.spark.connect.proto.Command -import org.apache.spark.connect.proto.ExecutePlanResponse -import org.apache.spark.connect.proto.StreamingQueryCommand -import org.apache.spark.connect.proto.StreamingQueryCommandResult +import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, StreamingQueryCommand, StreamingQueryCommandResult} import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance -import org.apache.spark.sql.{api, SparkSession} +import org.apache.spark.sql.streaming +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} /** @inheritdoc */ -trait StreamingQuery extends api.StreamingQuery { +trait StreamingQuery extends streaming.StreamingQuery { /** @inheritdoc */ override def sparkSession: SparkSession diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala similarity index 98% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala index c2934bcfa7058..30cdf2b5cadb4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect import java.util.concurrent.CopyOnWriteArrayList @@ -23,8 +23,8 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, StreamingQueryEventType} import org.apache.spark.internal.{Logging, LogKeys, MDC} -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.client.CloseableIterator +import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.sql.streaming.StreamingQueryListener.{Event, QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent} class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryManager.scala similarity index 93% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryManager.scala index 647d29c714dbb..ac864a1292c81 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryManager.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.connect import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} @@ -23,12 +23,11 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving -import org.apache.spark.connect.proto.Command -import org.apache.spark.connect.proto.StreamingQueryManagerCommand -import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult +import org.apache.spark.connect.proto.{Command, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{api, SparkSession} import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.streaming +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryListener} /** * A class to manage all the [[StreamingQuery]] active in a `SparkSession`. @@ -37,7 +36,7 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput */ @Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) - extends api.StreamingQueryManager + extends streaming.StreamingQueryManager with Logging { // Mapping from id to StreamingQueryListener. There's another mapping from id to diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala similarity index 91% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala index 2a5afd1d58717..05fc4b441f98e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala @@ -14,13 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr +import org.apache.spark.sql +import org.apache.spark.sql.{Column, Row} -class TableValuedFunction(sparkSession: SparkSession) extends api.TableValuedFunction { +class TableValuedFunction(sparkSession: SparkSession) extends sql.TableValuedFunction { /** @inheritdoc */ override def range(end: Long): Dataset[java.lang.Long] = { @@ -50,7 +51,7 @@ class TableValuedFunction(sparkSession: SparkSession) extends api.TableValuedFun sparkSession.newDataFrame(args) { builder => builder.getUnresolvedTableValuedFunctionBuilder .setFunctionName(name) - .addAllArguments(args.map(toExpr).asJava) + .addAllArguments(args.map(ColumnNodeToProtoConverter.toExpr).asJava) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/UDFRegistration.scala similarity index 78% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/UDFRegistration.scala index 93d085a25c7b5..67471f9cf5231 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/UDFRegistration.scala @@ -15,10 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.connect -import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.internal.UdfToProtoUtils +import org.apache.spark.sql +import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType /** @@ -30,12 +30,7 @@ import org.apache.spark.sql.types.DataType * * @since 3.5.0 */ -class UDFRegistration(session: SparkSession) extends api.UDFRegistration { - override def registerJava(name: String, className: String, returnDataType: DataType): Unit = { - throw new UnsupportedOperationException( - "registerJava is currently not supported in Spark Connect.") - } - +class UDFRegistration(session: SparkSession) extends sql.UDFRegistration { override protected def register( name: String, udf: UserDefinedFunction, @@ -45,4 +40,15 @@ class UDFRegistration(session: SparkSession) extends api.UDFRegistration { session.registerUdf(UdfToProtoUtils.toProto(named)) named } + + override def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + throw new UnsupportedOperationException( + "registerJava is currently not supported in Spark Connect.") + } + + /** @inheritdoc */ + override def register( + name: String, + udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = + throw ConnectClientUnsupportedErrors.registerUdaf() } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/UdfToProtoUtils.scala similarity index 99% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/UdfToProtoUtils.scala index 409c43f480b8e..487b165c7fc23 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala +++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/UdfToProtoUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import scala.collection.mutable import scala.jdk.CollectionConverters._ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala index b03c925878719..f44ec5b2d5046 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.connect import scala.jdk.CollectionConverters._ @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} /** * Converter for [[ColumnNode]] to [[proto.Expression]] conversions. @@ -264,7 +265,7 @@ case class ProtoColumnNode( override val origin: Origin = CurrentOrigin.get) extends ColumnNode { override def sql: String = expr.toString - override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty + override def children: Seq[ColumnNodeLike] = Seq.empty } sealed trait SubqueryType @@ -283,5 +284,5 @@ case class SubqueryExpressionNode( case SubqueryType.SCALAR => s"($relation)" case _ => s"$subqueryType ($relation)" } - override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty + override def children: Seq[ColumnNodeLike] = Seq.empty } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/package.scala similarity index 93% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/package.scala index ada94b76fcbcd..7be49a0ecdb83 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/package.scala @@ -15,8 +15,8 @@ * limitations under the License. 
*/ -package org.apache.spark +package org.apache.spark.sql -package object sql { +package object connect { type DataFrame = Dataset[Row] } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/apply.json b/sql/connect/common/src/test/resources/query-tests/queries/apply.json index 47ce6f98d6cc9..7f3635d36fc26 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/apply.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/apply.json @@ -21,7 +21,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.Dataset", + "declaringClass": "org.apache.spark.sql.Dataset", "methodName": "apply", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/apply.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/apply.proto.bin index 3bfad0b299103..72abcf5b9e396 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/apply.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/apply.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/between_expr.json b/sql/connect/common/src/test/resources/query-tests/queries/between_expr.json index a4632a1743322..1ac2b81895ec9 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/between_expr.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/between_expr.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "selectExpr", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin index 7fee2b59dce2c..078ffc5b1dab9 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/col.json b/sql/connect/common/src/test/resources/query-tests/queries/col.json index c211abfe4f176..dbbcc0ca9cbb0 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/col.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/col.json @@ -21,7 +21,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { @@ -43,7 +43,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/col.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/col.proto.bin index cf8f550158dad..8ffba7bdd249e 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/col.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/col.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/colRegex.json b/sql/connect/common/src/test/resources/query-tests/queries/colRegex.json index 60295eb26bcd8..b80879c5fbe43 100644 --- 
a/sql/connect/common/src/test/resources/query-tests/queries/colRegex.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/colRegex.json @@ -21,7 +21,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "colRegex", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin index d0534ebfc8e24..dfcf10215ebc5 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/cube_column.json b/sql/connect/common/src/test/resources/query-tests/queries/cube_column.json index b5f207007eedc..202a93e525f3b 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/cube_column.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/cube_column.json @@ -74,7 +74,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -94,7 +94,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -114,7 +114,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin index 4e637ae45f5dc..ded63b9ebc93c 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json index 5ddb641cb9f75..ff1ae2c6359ae 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json @@ -22,7 +22,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "cube", "fileName": "Dataset.scala" }, { @@ -44,7 +44,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "cube", "fileName": "Dataset.scala" }, { @@ -76,7 +76,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": 
"RelationalGroupedDataset.scala" }] @@ -96,7 +96,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -116,7 +116,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin index a24f9d5fc9751..398755055fa8d 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.json b/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.json index 9cfc549702095..b5b382afbfe3f 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "filter", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.proto.bin index 2b6a1b487f0f1..de3653b4d718d 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/filter_expr.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.json b/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.json index 94adb7a40fe61..cf2afe7af40d6 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.json @@ -27,7 +27,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "withMetadata", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.proto.bin index fe19255c876df..f150261892de5 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_window_time.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.json index 79c32f8f424f5..f1aa3d89b5554 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.json @@ -47,12 +47,12 @@ "jvmOrigin": { 
"stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -137,12 +137,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -162,7 +162,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -182,12 +182,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -207,7 +207,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -227,12 +227,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": 
"RelationalGroupedDataset.scala" }] @@ -252,7 +252,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -272,12 +272,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -297,7 +297,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -316,12 +316,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -341,7 +341,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -361,12 +361,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -386,7 +386,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin index a3bfa7d2e6510..f92c9b2a33d88 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json index 697e5f1aa5920..1ac119c11e68b 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json @@ -22,7 +22,7 @@ "jvmOrigin": { "stackTrace": [{ 
"classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "groupBy", "fileName": "Dataset.scala" }, { @@ -44,7 +44,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "groupBy", "fileName": "Dataset.scala" }, { @@ -70,12 +70,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -95,7 +95,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -115,12 +115,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -140,7 +140,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin index 4b7a9b3fce3af..c3ae596f791f2 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json index 44129079ad438..dc31cb60f182a 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json @@ -47,12 +47,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": 
"~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.proto.bin index d4619035d4b30..83b29ed88b787 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.json index 8c06597e2e3ee..83cb74b026e4e 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.json @@ -53,7 +53,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -73,7 +73,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -93,7 +93,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.proto.bin index 1ead8adaad500..de2cfc63c492e 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_count.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.json index 8127946781a41..9b7364bc4350d 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.json @@ -47,12 +47,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": 
"Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin index ed45ab91f49e4..d674edc85c677 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.json index 44129079ad438..dc31cb60f182a 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.json @@ -47,12 +47,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": "functions.scala" }, { "classLoaderName": 
"app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.proto.bin index d4619035d4b30..83b29ed88b787 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_mean.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.json index f3b3ae567d6d1..e182da194134a 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.json @@ -47,12 +47,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.proto.bin index d387a56ee1c49..81f8c0f44989e 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_min.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.json index eb3544e228fd4..d6fdbd3d09efc 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.json @@ -47,12 +47,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": 
"org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.connect.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "~~trimmed~anonfun~~", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.proto.bin index 44d74ad1da8b1..5672ce6528498 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupby_sum.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.json b/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.json index f7baa13392364..5168a3ef515f8 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.json @@ -47,12 +47,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -72,7 +72,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -92,12 +92,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] @@ -117,7 +117,7 @@ "fileName": 
"functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "toAggCol", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin index ee8d758767ee8..35ebfa8d7b0f3 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.json b/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.json index 8e4a131b7b352..023f25e26ba4b 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.json @@ -22,7 +22,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "cube", "fileName": "Dataset.scala" }, { @@ -44,7 +44,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "cube", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.proto.bin index 9187ea3fe2721..8fc4d0fe6cd70 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/grouping_and_grouping_id.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.json b/sql/connect/common/src/test/resources/query-tests/queries/hint.json index 1a7920b8f78cf..2ac930c0a3a71 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/hint.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/hint.json @@ -21,12 +21,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "~~trimmed~anonfun~~", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.SparkSession", + "declaringClass": "org.apache.spark.sql.connect.SparkSession", "methodName": "newDataset", "fileName": "SparkSession.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin index 28acd75953559..06459ee5b765c 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.json b/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.json index 6e4f144da3532..0353d1408d4a6 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.json +++ 
b/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.json @@ -21,7 +21,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sort", "fileName": "Dataset.scala" }, { @@ -46,7 +46,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sort", "fileName": "Dataset.scala" }, { @@ -71,7 +71,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sort", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.proto.bin index eebeb6c65d91b..ce04a01972235 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/orderBy_strings.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/pivot.json b/sql/connect/common/src/test/resources/query-tests/queries/pivot.json index 78724786d468a..2692412739e2d 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/pivot.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/pivot.json @@ -90,12 +90,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "pivot", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin index 8ca2970bc4a9b..cdcd5e0c8c20b 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.json b/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.json index 2e7614b28d486..289f9e03b0b80 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.json @@ -90,12 +90,12 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "col", "fileName": "Dataset.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "pivot", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.proto.bin 
b/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.proto.bin index 1bdcfd31236f0..6746b466e1190 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/pivot_without_column_values.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.json b/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.json index f4f0daeb26a18..89340866b5a11 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.json @@ -43,7 +43,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "repartitionByRange", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.proto.bin index 5952ce04a27c5..ab25f6afc7a75 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/repartitionByRange_num_partitions_expressions.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.json b/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.json index 598b11cb27790..bf44814871780 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.json @@ -74,7 +74,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -94,7 +94,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -114,7 +114,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin index 1ac872441fb11..ea9801f2e4dc1 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.json b/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.json index 798402c639149..d6b6c236e5f43 100644 --- 
a/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.json @@ -22,7 +22,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "rollup", "fileName": "Dataset.scala" }, { @@ -44,7 +44,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "rollup", "fileName": "Dataset.scala" }, { @@ -76,7 +76,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -96,7 +96,7 @@ "fileName": "functions.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] @@ -116,7 +116,7 @@ "fileName": "Column.scala" }, { "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.api.RelationalGroupedDataset", + "declaringClass": "org.apache.spark.sql.RelationalGroupedDataset", "methodName": "count", "fileName": "RelationalGroupedDataset.scala" }] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin index af379d569392b..f68dcf33a5403 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.json b/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.json index 23d222ea4059a..6db49dd9248e2 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.DataFrameStatFunctions", + "declaringClass": "org.apache.spark.sql.connect.DataFrameStatFunctions", "methodName": "sampleBy", "fileName": "DataFrameStatFunctions.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin index 2ec5211e9abba..1df38ce7d56f9 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/sampleBy.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.json b/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.json index ae4b29daef3ed..a60ec0979c1b1 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": 
"selectExpr", "fileName": "Dataset.scala" }, { @@ -41,7 +41,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "selectExpr", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.proto.bin index dcb1efde331f7..d70101dda47f2 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/selectExpr.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.json b/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.json index cfc6f57e8a55b..0673695a3bb92 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "select", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.proto.bin index ca64246359ae3..7df8af61f82eb 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/select_collated_string.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/select_strings.json b/sql/connect/common/src/test/resources/query-tests/queries/select_strings.json index 53e8959b02c79..a642beee6325b 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/select_strings.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/select_strings.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "select", "fileName": "Dataset.scala" }, { @@ -41,7 +41,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "select", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/select_strings.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/select_strings.proto.bin index 434f2371e78f8..294ebb939ee81 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/select_strings.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/select_strings.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json b/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json index 7a02546136b6a..e622334615657 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json @@ -90,7 +90,7 @@ "jvmOrigin": { "stackTrace": [{ 
"classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "select", "fileName": "Dataset.scala" }, { @@ -110,7 +110,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "select", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin index f83aed8d47241..cedff5b9af78d 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.json b/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.json index 96469705f6fff..043a3b69f579b 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.json @@ -21,7 +21,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sortWithinPartitions", "fileName": "Dataset.scala" }, { @@ -46,7 +46,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sortWithinPartitions", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.proto.bin index 1604cb6a2decf..864ee10e4f3b8 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/sortWithinPartitions_strings.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.json b/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.json index 978878ed504e9..22737633c60a3 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.json @@ -21,7 +21,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sort", "fileName": "Dataset.scala" }, { @@ -46,7 +46,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "sort", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.proto.bin index 6ac68f86ad212..1a2f899a9d523 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.proto.bin and 
b/sql/connect/common/src/test/resources/query-tests/queries/sort_strings.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/toJSON.json b/sql/connect/common/src/test/resources/query-tests/queries/toJSON.json index 6d25bbea1af67..8ea8d319e974d 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/toJSON.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/toJSON.json @@ -26,7 +26,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "toJSON", "fileName": "Dataset.scala" }, { @@ -46,7 +46,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "toJSON", "fileName": "Dataset.scala" }, { @@ -66,7 +66,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "toJSON", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/toJSON.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/toJSON.proto.bin index 47fa8e56de0c2..0ad2a18105b05 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/toJSON.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/toJSON.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/where_expr.json b/sql/connect/common/src/test/resources/query-tests/queries/where_expr.json index 589869a084af7..2817c23b45aab 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/where_expr.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/where_expr.json @@ -20,7 +20,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "where", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/where_expr.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/where_expr.proto.bin index 2277175d687b6..570729cc6dbbc 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/where_expr.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/where_expr.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.json b/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.json index 8628da7d9f202..660d481e3586c 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.json @@ -22,7 +22,7 @@ "jvmOrigin": { "stackTrace": [{ "classLoaderName": "app", - "declaringClass": "org.apache.spark.sql.Dataset", + "declaringClass": "org.apache.spark.sql.connect.Dataset", "methodName": "withMetadata", "fileName": "Dataset.scala" }, { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.proto.bin index 4c776b3de558f..94bf58f513b38 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.proto.bin and 
b/sql/connect/common/src/test/resources/query-tests/queries/withMetadata.proto.bin differ diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 5e3499573e9d9..497576a6630d3 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -27,8 +27,8 @@ import io.grpc.stub.StreamObserver import org.apache.spark.SparkEnv import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse -import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index dd961a3415cb5..c468c6e443335 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -37,8 +37,9 @@ import org.apache.spark.ml.recommendation._ import org.apache.spark.ml.regression._ import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.{HasTrainingSummary, Identifiable, MLWritable} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.classic.Dataset +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry import org.apache.spark.sql.connect.service.SessionHolder @@ -147,13 +148,11 @@ private[ml] object MLUtils { val value = literal.getLiteralTypeCase match { case proto.Expression.Literal.LiteralTypeCase.STRUCT => val s = literal.getStruct - val schema = DataTypeProtoConverter.toCatalystType(s.getStructType) - if (schema == VectorUDT.sqlType) { - deserializeVector(s) - } else if (schema == MatrixUDT.sqlType) { - deserializeMatrix(s) - } else { - throw MlUnsupportedException(s"Unsupported parameter struct ${schema} for ${name}") + s.getStructType.getUdt.getJvmClass match { + case "org.apache.spark.ml.linalg.VectorUDT" => deserializeVector(s) + case "org.apache.spark.ml.linalg.MatrixUDT" => deserializeMatrix(s) + case _ => + throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct} for ${name}") } case _ => @@ -189,6 +188,8 @@ private[ml] object MLUtils { array.map(_.asInstanceOf[Double]) } else if (elementType == classOf[String]) { array.map(_.asInstanceOf[String]) + } else if (elementType.isArray && elementType.getComponentType == classOf[Double]) { + array.map(_.asInstanceOf[Array[_]].map(_.asInstanceOf[Double])) } else { throw MlUnsupportedException( s"array element type unsupported, " + @@ -230,14 +231,10 @@ private[ml] object MLUtils { value.asInstanceOf[String] } else if 
(paramType.isArray) { val compType = paramType.getComponentType - if (compType.isArray) { - throw MlUnsupportedException(s"Array of array unsupported") - } else { - val array = value.asInstanceOf[Array[_]].map { e => - reconcileParam(compType, e) - } - reconcileArray(compType, array) + val array = value.asInstanceOf[Array[_]].map { e => + reconcileParam(compType, e) } + reconcileArray(compType, array) } else { throw MlUnsupportedException(s"Unsupported parameter type, found ${paramType.getName}") } @@ -526,6 +523,7 @@ private[ml] object MLUtils { (classOf[GBTRegressionModel], Set("featureImportances", "evaluateEachIteration")), // Classification Models + (classOf[LinearSVCModel], Set("intercept", "coefficients", "evaluate")), ( classOf[LogisticRegressionModel], Set("intercept", "coefficients", "interceptVector", "coefficientMatrix", "evaluate")), @@ -563,6 +561,10 @@ private[ml] object MLUtils { classOf[BisectingKMeansModel], Set("predict", "numFeatures", "clusterCenters", "computeCost")), (classOf[BisectingKMeansSummary], Set("trainingCost")), + ( + classOf[GaussianMixtureModel], + Set("predict", "numFeatures", "weights", "gaussians", "predictProbability", "gaussiansDF")), + (classOf[GaussianMixtureSummary], Set("probability", "probabilityCol", "logLikelihood")), // Recommendation Models ( @@ -584,8 +586,13 @@ private[ml] object MLUtils { (classOf[MaxAbsScalerModel], Set("maxAbs")), (classOf[MinMaxScalerModel], Set("originalMax", "originalMin")), (classOf[RobustScalerModel], Set("range", "median")), + (classOf[ChiSqSelectorModel], Set("selectedFeatures")), + (classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")), + (classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")), (classOf[PCAModel], Set("pc", "explainedVariance")), - (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray"))) + (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")), + (classOf[CountVectorizerModel], Set("vocabulary")), + (classOf[OneHotEncoderModel], Set("categorySizes"))) private def validate(obj: Any, method: String): Unit = { assert(obj != null) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala index ee0812a1a98ca..df3e97398012b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala @@ -21,7 +21,7 @@ import org.apache.spark.connect.proto import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.Params import org.apache.spark.sql.Dataset -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter, ProtoDataTypes} +import org.apache.spark.sql.connect.common.{LiteralValueProtoConverter, ProtoDataTypes} import org.apache.spark.sql.connect.service.SessionHolder private[ml] object Serializer { @@ -37,7 +37,7 @@ private[ml] object Serializer { data match { case v: SparseVector => val builder = proto.Expression.Literal.Struct.newBuilder() - builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType)) + builder.setStructType(ProtoDataTypes.VectorUDT) // type = 0 builder.addElements(proto.Expression.Literal.newBuilder().setByte(0)) // size @@ -50,7 +50,7 @@ private[ml] object Serializer { case v: DenseVector => val builder = proto.Expression.Literal.Struct.newBuilder() - 
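The MLUtils hunks above drop the old "Array of array unsupported" rejection and instead recurse on the component type, so nested parameters such as Array[Array[Double]] can be reconciled. A minimal standalone sketch of that idea, using java.lang.reflect.Array as an illustrative stand-in for the real reconcileParam/reconcileArray helpers (the name `reconcile` and the reflection-based allocation are assumptions of this sketch, not the patched code):

```scala
import java.lang.reflect.{Array => JArray}

object NestedArraySketch {
  // Recurse on the component type so each nesting level is handled the same way.
  def reconcile(paramType: Class[_], value: Any): AnyRef = {
    if (paramType.isArray) {
      val compType = paramType.getComponentType
      val in = value.asInstanceOf[Array[_]]
      // Allocate an array whose runtime component type matches the expected parameter type,
      // then reconcile every element recursively (one level per nesting depth).
      val out = JArray.newInstance(compType, in.length)
      in.indices.foreach { i => JArray.set(out, i, reconcile(compType, in(i))) }
      out
    } else {
      // Scalars (Double, String, ...) already arrive as suitably boxed values.
      value.asInstanceOf[AnyRef]
    }
  }

  def main(args: Array[String]): Unit = {
    val nested: Any = Array(Array(1.0, 2.0), Array(3.0, 4.0))
    val out = reconcile(classOf[Array[Array[Double]]], nested).asInstanceOf[Array[Array[Double]]]
    println(out.map(_.mkString(",")).mkString("; ")) // 1.0,2.0; 3.0,4.0
  }
}
```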
builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType)) + builder.setStructType(ProtoDataTypes.VectorUDT) // type = 1 builder.addElements(proto.Expression.Literal.newBuilder().setByte(1)) // size = null @@ -65,7 +65,7 @@ private[ml] object Serializer { case m: SparseMatrix => val builder = proto.Expression.Literal.Struct.newBuilder() - builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType)) + builder.setStructType(ProtoDataTypes.MatrixUDT) // type = 0 builder.addElements(proto.Expression.Literal.newBuilder().setByte(0)) // numRows @@ -84,7 +84,7 @@ private[ml] object Serializer { case m: DenseMatrix => val builder = proto.Expression.Literal.Struct.newBuilder() - builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType)) + builder.setStructType(ProtoDataTypes.MatrixUDT) // type = 1 builder.addElements(proto.Expression.Literal.newBuilder().setByte(1)) // numRows @@ -146,13 +146,13 @@ private[ml] object Serializer { literal.getLiteralTypeCase match { case proto.Expression.Literal.LiteralTypeCase.STRUCT => val struct = literal.getStruct - val schema = DataTypeProtoConverter.toCatalystType(struct.getStructType) - if (schema == VectorUDT.sqlType) { - (MLUtils.deserializeVector(struct), classOf[Vector]) - } else if (schema == MatrixUDT.sqlType) { - (MLUtils.deserializeMatrix(struct), classOf[Matrix]) - } else { - throw MlUnsupportedException(s"$schema not supported") + struct.getStructType.getUdt.getJvmClass match { + case "org.apache.spark.ml.linalg.VectorUDT" => + (MLUtils.deserializeVector(struct), classOf[Vector]) + case "org.apache.spark.ml.linalg.MatrixUDT" => + (MLUtils.deserializeMatrix(struct), classOf[Matrix]) + case _ => + throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct}") } case proto.Expression.Literal.LiteralTypeCase.INTEGER => (literal.getInteger.asInstanceOf[Object], classOf[Int]) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 56824bbb4a417..8e683d89d5b4a 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -43,7 +43,7 @@ import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} -import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession} +import org.apache.spark.sql.{Column, Encoders, ForeachWriter, Observation, Row} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, 
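Both MLUtils and Serializer now branch on the UDT's JVM class name carried in the Connect proto (`struct.getStructType.getUdt.getJvmClass`) instead of rebuilding a Catalyst schema and comparing it against VectorUDT.sqlType / MatrixUDT.sqlType. A self-contained sketch of that dispatch pattern; the proto message is modelled by a hypothetical `UdtStruct` case class, and `decodeVector`/`decodeMatrix` are stand-ins rather than Spark APIs:

```scala
object UdtDispatchSketch {
  final case class UdtStruct(jvmClass: String, elements: Seq[Any])

  def decodeVector(s: UdtStruct): String = s"vector(${s.elements.mkString(",")})"
  def decodeMatrix(s: UdtStruct): String = s"matrix(${s.elements.mkString(",")})"

  // Matching on the class name keeps the server independent of the UDT's sqlType layout:
  // only the JVM class recorded by the client has to line up.
  def deserialize(s: UdtStruct): String = s.jvmClass match {
    case "org.apache.spark.ml.linalg.VectorUDT" => decodeVector(s)
    case "org.apache.spark.ml.linalg.MatrixUDT" => decodeMatrix(s)
    case other => throw new IllegalArgumentException(s"Unsupported struct UDT: $other")
  }

  def main(args: Array[String]): Unit = {
    println(deserialize(UdtStruct("org.apache.spark.ml.linalg.VectorUDT", Seq(1, 3, 1.0, 2.0, 3.0))))
  }
}
```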
UnresolvedTranspose} import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder} @@ -58,6 +58,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} +import org.apache.spark.sql.classic.{Catalog, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils} import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidCommandInput, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE @@ -78,7 +79,6 @@ import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, TypedAggUtils, UserDefinedFunctionUtils} import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -3524,7 +3524,7 @@ class SparkConnectPlanner( val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan)) val mergeInto = sourceDs .mergeInto(cmd.getTargetTableName, Column(transformExpression(cmd.getMergeCondition))) - .asInstanceOf[MergeIntoWriterImpl[Row]] + .asInstanceOf[MergeIntoWriter[Row]] mergeInto.matchedActions ++= matchedActions mergeInto.notMatchedActions ++= notMatchedActions mergeInto.notMatchedBySourceActions ++= notMatchedBySourceActions @@ -3600,14 +3600,14 @@ class SparkConnectPlanner( } private def transformGetDatabase(getGetDatabase: proto.GetDatabase): LogicalPlan = { - CatalogImpl + Catalog .makeDataset(session.catalog.getDatabase(getGetDatabase.getDbName) :: Nil, session) .logicalPlan } private def transformGetTable(getGetTable: proto.GetTable): LogicalPlan = { if (getGetTable.hasDbName) { - CatalogImpl + Catalog .makeDataset( session.catalog.getTable( dbName = getGetTable.getDbName, @@ -3615,7 +3615,7 @@ class SparkConnectPlanner( session) .logicalPlan } else { - CatalogImpl + Catalog .makeDataset(session.catalog.getTable(getGetTable.getTableName) :: Nil, session) .logicalPlan } @@ -3623,7 +3623,7 @@ class SparkConnectPlanner( private def transformGetFunction(getGetFunction: proto.GetFunction): LogicalPlan = { if (getGetFunction.hasDbName) { - CatalogImpl + Catalog .makeDataset( session.catalog.getFunction( dbName = getGetFunction.getDbName, @@ -3631,7 +3631,7 @@ class SparkConnectPlanner( session) .logicalPlan } else { - CatalogImpl + Catalog .makeDataset(session.catalog.getFunction(getGetFunction.getFunctionName) :: Nil, session) .logicalPlan } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 
ab6bed7152c09..07c5da9744cc6 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -108,7 +108,7 @@ object StreamingForeachBatchHelper extends Logging { args.df.asInstanceOf[Dataset[Any]] } else { // Recover the Dataset from the DataFrame using the encoder. - Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder) + args.df.as(encoder) } fn(ds, args.batchId) } catch { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 5b56b7079a897..631885a5d741c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -33,14 +33,13 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator import org.apache.spark.connect.proto import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.ml.MLCache import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper -import org.apache.spark.sql.connect.service.ExecuteKey import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock, Utils} diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 8ca021c5be39e..f751546cf7053 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -23,7 +23,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.internal.Logging -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter} import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode} diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index c59fd02a829ae..8581bb7b98f05 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -30,7 +30,7 @@ import com.google.common.cache.CacheBuilder import org.apache.spark.{SparkEnv, SparkSQLException} import 
org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE, CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT, CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL} import org.apache.spark.util.ThreadUtils diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala index 1a656605b4154..d06c93cc1cad6 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connect import java.util.UUID -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService} object SparkConnectTestUtils { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 55c492f511049..2a09d5f8e8bd5 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -24,13 +24,14 @@ import com.google.protobuf.ByteString import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar} -import org.apache.spark.sql.{AnalysisException, Dataset, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connect.SparkConnectTestUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerTestUtils.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerTestUtils.scala index f700fd67d37fa..ee830c3b96729 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerTestUtils.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerTestUtils.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.connect.proto -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.SparkConnectTestUtils import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionStatus, 
SparkConnectService} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 1a86ced3a2ac9..2bbd6863b1105 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -26,12 +26,14 @@ import org.apache.spark.{SparkClassNotFoundException, SparkIllegalArgumentExcept import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression import org.apache.spark.connect.proto.Join.JoinType -import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Observation, Row, SaveMode} +import org.apache.spark.sql.{AnalysisException, Column, Observation, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connect.common.{InvalidCommandInput, InvalidPlanInput} import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.dsl.MockRemoteSession diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala index e8b955cf33ebc..63d623cd2779b 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto -import org.apache.spark.sql._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic import org.apache.spark.sql.connect.SparkConnectTestUtils import org.apache.spark.sql.types.{DataType, StructType} @@ -57,7 +58,7 @@ class SparkConnectWithSessionExtensionSuite extends SparkFunSuite { } test("Parse table name with test parser") { - val spark = SparkSession + val spark = classic.SparkSession .builder() .master("local[1]") .withExtensions(extension => extension.injectParser(MyParser)) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index 2b768875c6e20..32dbc9595eab2 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -24,9 +24,9 @@ import com.google.protobuf import org.apache.spark.{SparkContext, SparkEnv, SparkException} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Relation -import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connect.ConnectProtoUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index a9843e261fff8..a17c76ae95286 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanRequest, Plan, UserContext} import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.planner.SparkConnectPlanTest import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.util.{JsonProtocol, ManualClock} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala index 7025146b0295b..dc4340aa18ebd 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala @@ -22,7 +22,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.planner.SparkConnectPlanTest import org.apache.spark.util.ManualClock diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala index 729a995f46145..6a133f87c1b11 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala @@ -27,9 +27,7 @@ import org.scalatest.concurrent.Futures.timeout import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.streaming.StreamingQuery -import org.apache.spark.sql.streaming.StreamingQueryManager +import org.apache.spark.sql.classic.{SparkSession, StreamingQuery, StreamingQueryManager} import org.apache.spark.util.ManualClock 
class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSugar { diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala index 9c5fb515580a7..8753d17f35cb2 100644 --- a/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala +++ b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala @@ -35,6 +35,7 @@ package sql { package execution { class QueryExecution + class SparkStrategy } package internal { class SharedState diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TableArg.scala b/sql/core/src/main/scala/org/apache/spark/sql/TableArg.scala index 133775c0b666c..bdacdcca24fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/TableArg.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/TableArg.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, Functio class TableArg( private[sql] val expression: FunctionTableSubqueryArgumentExpression, - sparkSession: SparkSession) + sparkSession: classic.SparkSession) extends TableValuedFunctionArgument { import sparkSession.toRichColumn diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 49fe494903cdc..374d38db371a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.internal.ExpressionUtils.expression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{MutableURLClassLoader, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index ecbc57f25ad44..ad58fc0c2fcf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -30,9 +30,10 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.CONFIG import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.classic.{DataFrame, RelationalGroupedDataset, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala index f68385140af8a..ca2b01f9eca0a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala @@ -34,7 +34,8 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv, SparkException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.config.{CONNECT_SCALA_UDF_STUB_PREFIXES, EXECUTOR_USER_CLASS_PATH_FIRST} -import org.apache.spark.sql.{Artifact, SparkSession} +import org.apache.spark.sql.Artifact +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.ArtifactUtils import org.apache.spark.storage.{BlockManager, CacheId, StorageLevel} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 594ebb4716c41..3299b34bcc933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.jdk.CollectionConverters._ -import org.apache.avro.Schema +import org.apache.avro.{Schema, SchemaFormatter} import org.apache.avro.file.{DataFileReader, FileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} @@ -44,6 +44,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils private[sql] object AvroUtils extends Logging { + + val JSON_INLINE_FORMAT: String = "json/inline" + val JSON_PRETTY_FORMAT: String = "json/pretty" + def inferSchema( spark: SparkSession, options: Map[String, String], @@ -71,7 +75,7 @@ private[sql] object AvroUtils extends Logging { case _ => throw new RuntimeException( s"""Avro schema cannot be converted to a Spark SQL StructType: | - |${avroSchema.toString(true)} + |${SchemaFormatter.format(JSON_PRETTY_FORMAT, avroSchema)} |""".stripMargin) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 6939c15e042f9..c936099ef3764 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.mutable import scala.jdk.CollectionConverters._ -import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder, SchemaFormatter} import org.apache.avro.LogicalTypes.{Decimal, _} import org.apache.avro.Schema.Type._ import org.apache.avro.SchemaBuilder.FieldAssembler @@ -148,9 +148,10 @@ object SchemaConverters extends Logging { case RECORD => val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0) if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) { + val formattedAvroSchema = SchemaFormatter.format(AvroUtils.JSON_PRETTY_FORMAT, avroSchema) throw new IncompatibleSchemaException(s""" |Found recursive reference in Avro schema, which can not be processed by Spark by - | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth` + | default: $formattedAvroSchema. Try setting the option `recursiveFieldMaxDepth` | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT. 
""".stripMargin) } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala deleted file mode 100644 index c39018ff06fca..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalog - -import java.util - -import org.apache.spark.sql.{api, DataFrame, Dataset} -import org.apache.spark.sql.classic.ClassicConversions._ -import org.apache.spark.sql.types.StructType - -/** @inheritdoc */ -abstract class Catalog extends api.Catalog { - /** @inheritdoc */ - override def listDatabases(): Dataset[Database] - - /** @inheritdoc */ - override def listDatabases(pattern: String): Dataset[Database] - - /** @inheritdoc */ - override def listTables(): Dataset[Table] - - /** @inheritdoc */ - override def listTables(dbName: String): Dataset[Table] - - /** @inheritdoc */ - override def listTables(dbName: String, pattern: String): Dataset[Table] - - /** @inheritdoc */ - override def listFunctions(): Dataset[Function] - - /** @inheritdoc */ - override def listFunctions(dbName: String): Dataset[Function] - - /** @inheritdoc */ - override def listFunctions(dbName: String, pattern: String): Dataset[Function] - - /** @inheritdoc */ - override def listColumns(tableName: String): Dataset[Column] - - /** @inheritdoc */ - override def listColumns(dbName: String, tableName: String): Dataset[Column] - - /** @inheritdoc */ - override def createTable(tableName: String, path: String): DataFrame - - /** @inheritdoc */ - override def createTable(tableName: String, path: String, source: String): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - description: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - description: String, - options: Map[String, String]): DataFrame - - /** @inheritdoc */ - override def listCatalogs(): Dataset[CatalogMetadata] - - /** @inheritdoc */ - override def listCatalogs(pattern: String): Dataset[CatalogMetadata] - - /** @inheritdoc */ - override def createExternalTable(tableName: String, path: String): DataFrame = - super.createExternalTable(tableName, path) - - /** @inheritdoc */ - override def 
createExternalTable(tableName: String, path: String, source: String): DataFrame = - super.createExternalTable(tableName, path, source) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - options: util.Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: util.Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - description: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, description, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = - super.createExternalTable(tableName, source, schema, options) - - /** @inheritdoc */ - override def createTable( - tableName: String, - source: String, - schema: StructType, - description: String, - options: util.Map[String, String]): DataFrame = - super.createTable(tableName, source, schema, description, options) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/EvalSubqueriesForTimeTravel.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/EvalSubqueriesForTimeTravel.scala index 036ecda4f7c85..7ce2b284684b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/EvalSubqueriesForTimeTravel.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/EvalSubqueriesForTimeTravel.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Literal, ScalarSubquery, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.RELATION_TIME_TRAVEL +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.{QueryExecution, ScalarSubquery => ScalarSubqueryExec, SubqueryExec} class EvalSubqueriesForTimeTravel extends Rule[LogicalPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala index 3969057a5ce6f..249ea6e6d04cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala @@ -21,12 +21,12 @@ import java.util.Locale import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.SparkSession import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnresolvedDataSource} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 5fd88b417ac44..aff65496b763b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.sql._ -import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog +import org.apache.spark.sql.catalog.{CatalogMetadata, Column, Database, Function, Table} import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -47,7 +48,7 @@ import org.apache.spark.util.ArrayImplicits._ /** * Internal implementation of the user-facing `Catalog`. 
*/ -class CatalogImpl(sparkSession: SparkSession) extends Catalog { +class Catalog(sparkSession: SparkSession) extends catalog.Catalog { private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog @@ -114,7 +115,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { val dbName = row.getString(0) makeDatabase(Some(catalog.name()), dbName) } - CatalogImpl.makeDataset(databases.toImmutableArraySeq, sparkSession) + Catalog.makeDataset(databases.toImmutableArraySeq, sparkSession) } /** @@ -160,7 +161,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { sparkSession.sessionState.catalogManager.v2SessionCatalog }.get val tables = qe.toRdd.collect().flatMap { row => resolveTable(row, catalog.name()) } - CatalogImpl.makeDataset(tables.toImmutableArraySeq, sparkSession) + Catalog.makeDataset(tables.toImmutableArraySeq, sparkSession) } private[sql] def resolveTable(row: InternalRow, catalogName: String): Option[Table] = { @@ -297,7 +298,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { functions += makeFunction(parseIdent(row.getString(0))) } - CatalogImpl.makeDataset(functions.result().toImmutableArraySeq, sparkSession) + Catalog.makeDataset(functions.result().toImmutableArraySeq, sparkSession) } private def toFunctionIdent(functionName: String): Seq[String] = { @@ -315,7 +316,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { private def functionExists(ident: Seq[String]): Boolean = { val plan = - UnresolvedFunctionName(ident, CatalogImpl.FUNCTION_EXISTS_COMMAND_NAME, false, None) + UnresolvedFunctionName(ident, Catalog.FUNCTION_EXISTS_COMMAND_NAME, false, None) try { sparkSession.sessionState.executePlan(plan).analyzed match { case _: ResolvedPersistentFunc => true @@ -413,7 +414,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { case _ => throw QueryCompilationErrors.tableOrViewNotFound(ident) } - CatalogImpl.makeDataset(columns, sparkSession) + Catalog.makeDataset(columns, sparkSession) } private def schemaToColumns( @@ -921,7 +922,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def listCatalogs(): Dataset[CatalogMetadata] = { val catalogs = sparkSession.sessionState.catalogManager.listCatalogs(None) - CatalogImpl.makeDataset(catalogs.map(name => makeCatalog(name)), sparkSession) + Catalog.makeDataset(catalogs.map(name => makeCatalog(name)), sparkSession) } /** @@ -931,7 +932,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def listCatalogs(pattern: String): Dataset[CatalogMetadata] = { val catalogs = sparkSession.sessionState.catalogManager.listCatalogs(Some(pattern)) - CatalogImpl.makeDataset(catalogs.map(name => makeCatalog(name)), sparkSession) + Catalog.makeDataset(catalogs.map(name => makeCatalog(name)), sparkSession) } private def makeCatalog(name: String): CatalogMetadata = { @@ -942,7 +943,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } -private[sql] object CatalogImpl { +private[sql] object Catalog { def makeDataset[T <: DefinedByConstructorParams: TypeTag]( data: Seq[T], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameNaFunctions.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameNaFunctions.scala index 0d49e850b4637..b4da4aae07c4a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameNaFunctions.scala @@ -15,11 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.{lang => jl} import org.apache.spark.annotation.Stable +import org.apache.spark.sql +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryExecutionErrors @@ -33,14 +35,14 @@ import org.apache.spark.sql.types._ */ @Stable final class DataFrameNaFunctions private[sql](df: DataFrame) - extends api.DataFrameNaFunctions { + extends sql.DataFrameNaFunctions { import df.sparkSession.toRichColumn - protected def drop(minNonNulls: Option[Int]): Dataset[Row] = { + protected def drop(minNonNulls: Option[Int]): DataFrame = { drop0(minNonNulls, outputAttributes) } - override protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] = { + override protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DataFrame = { drop0(minNonNulls, cols.map(df.resolve)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala index 8d9813218745d..489ac31e5291d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.util.Properties @@ -25,6 +25,8 @@ import org.apache.spark.Partition import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD +import org.apache.spark.sql +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils @@ -53,8 +55,7 @@ import org.apache.spark.unsafe.types.UTF8String */ @Stable class DataFrameReader private[sql](sparkSession: SparkSession) - extends api.DataFrameReader { - override type DS[U] = Dataset[U] + extends sql.DataFrameReader { format(sparkSession.sessionState.conf.defaultDataSourceName) @@ -158,7 +159,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) } /** @inheritdoc */ - def json(jsonDataset: Dataset[String]): DataFrame = { + def json(jsonDataset: sql.Dataset[String]): DataFrame = { val parsedOptions = new JSONOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, @@ -194,7 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) override def csv(path: String): DataFrame = super.csv(path) /** @inheritdoc */ - def csv(csvDataset: Dataset[String]): DataFrame = { + def csv(csvDataset: sql.Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, sparkSession.sessionState.conf.csvColumnPruning, @@ -264,7 +265,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) override def xml(paths: String*): DataFrame = super.xml(paths: _*) /** @inheritdoc */ - def xml(xmlDataset: Dataset[String]): DataFrame = { + def xml(xmlDataset: sql.Dataset[String]): DataFrame = { val parsedOptions: XmlOptions = new XmlOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameStatFunctions.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameStatFunctions.scala index 9f7180d8dfd6a..f18073d39c8eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameStatFunctions.scala @@ -15,13 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.{lang => jl, util => ju} import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable +import org.apache.spark.sql +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.stat._ @@ -35,7 +37,7 @@ import org.apache.spark.util.ArrayImplicits._ */ @Stable final class DataFrameStatFunctions private[sql](protected val df: DataFrame) - extends api.DataFrameStatFunctions { + extends sql.DataFrameStatFunctions { /** @inheritdoc */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index 06219e04db591..b423c89fff3db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -15,14 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import java.util.Locale import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{DataFrameWriter, Dataset, SaveMode, SparkSession} +import org.apache.spark.sql +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ @@ -38,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -50,7 +52,7 @@ import org.apache.spark.util.ArrayImplicits._ * @since 1.4.0 */ @Stable -final class DataFrameWriterImpl[T] private[sql](ds: Dataset[T]) extends DataFrameWriter[T] { +final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFrameWriter[T] { format(ds.sparkSession.sessionState.conf.defaultDataSourceName) private val df = ds.toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala index 86ea55bc59b7b..e4efee93d2a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import java.util @@ -23,7 +23,8 @@ import scala.collection.mutable import scala.jdk.CollectionConverters.MapHasAsScala import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{Column, DataFrame, DataFrameWriterV2, Dataset} +import org.apache.spark.sql +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical._ @@ -34,13 +35,14 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types.IntegerType /** - * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. + * Interface used to write a [[org.apache.spark.sql.classic.Dataset]] to external storage using + * the v2 API. * * @since 3.0.0 */ @Experimental -final class DataFrameWriterV2Impl[T] private[sql](table: String, ds: Dataset[T]) - extends DataFrameWriterV2[T] { +final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) + extends sql.DataFrameWriterV2[T] { private val df: DataFrame = ds.toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamReader.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamReader.scala index 0a731c450baf5..f6148623ebfae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamReader.scala @@ -15,12 +15,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.classic import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.UnresolvedDataSource import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} @@ -28,6 +27,7 @@ import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.json.JsonUtils.checkJsonSchema import org.apache.spark.sql.execution.datasources.xml.XmlUtils.checkXmlSchema +import org.apache.spark.sql.streaming import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -38,7 +38,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * @since 2.0.0 */ @Evolving -final class DataStreamReader private[sql](sparkSession: SparkSession) extends api.DataStreamReader { +final class DataStreamReader private[sql](sparkSession: SparkSession) + extends streaming.DataStreamReader { /** @inheritdoc */ def format(source: String): this.type = { this.source = source diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala index d41933c6a135c..96e8755577542 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.classic import java.util.Locale import java.util.concurrent.TimeoutException @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 -import org.apache.spark.sql._ +import org.apache.spark.sql.{streaming, Dataset => DS, ForeachWriter} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDat import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2 import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, Trigger} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -55,8 +56,7 @@ import org.apache.spark.util.Utils * @since 2.0.0 */ @Evolving -final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStreamWriter[T] { - type DS[U] = Dataset[U] +final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.DataStreamWriter[T] { /** @inheritdoc */ def outputMode(outputMode: OutputMode): this.type = { @@ -139,7 +139,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStr @Evolving @throws[TimeoutException] def toTable(tableName: String): StreamingQuery = { - import ds.sparkSession.sessionState.analyzer.CatalogAndIdentifier import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val parser = ds.sparkSession.sessionState.sqlParser @@ -355,7 +354,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStr /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { + def foreachBatch(function: (DS[T], Long) => Unit): this.type = { this.source = DataStreamWriter.SOURCE_NAME_FOREACH_BATCH if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") this.foreachBatchWriter = function @@ -410,7 +409,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStr /** @inheritdoc */ @Evolving - override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + override def foreachBatch(function: VoidFunction2[DS[T], java.lang.Long]): this.type = super.foreachBatch(function) /////////////////////////////////////////////////////////////////////////////////////// @@ -430,7 +429,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStr private var foreachWriterEncoder: ExpressionEncoder[Any] = ds.exprEnc.asInstanceOf[ExpressionEncoder[Any]] - private var foreachBatchWriter: (Dataset[T], Long) => Unit = _ + private var foreachBatchWriter: (DS[T], Long) => Unit = _ private var partitioningColumns: Option[Seq[String]] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index e41521cba533a..d78a3a391edb6 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream} import java.util @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils import org.apache.commons.text.StringEscapeUtils -import org.apache.spark.TaskContext +import org.apache.spark.{sql, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ @@ -38,6 +38,7 @@ import org.apache.spark.api.r.RRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.resource.ResourceProfile +import org.apache.spark.sql.{AnalysisException, Column, Encoder, Encoders, Observation, Row, SQLContext, TableArg, TypedColumn} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation @@ -53,6 +54,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.TypedAggUtils.withInputType import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -62,9 +64,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelationWithTable import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, ExpressionColumnNode, MergeIntoWriterImpl, SQLConf} -import org.apache.spark.sql.internal.TypedAggUtils.withInputType -import org.apache.spark.sql.streaming.DataStreamWriter +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel @@ -232,8 +232,7 @@ private[sql] object Dataset { class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @transient encoderGenerator: () => Encoder[T]) - extends api.Dataset[T] { - type DS[U] = Dataset[U] + extends sql.Dataset[T] { @transient lazy val sparkSession: SparkSession = { if (queryExecution == null || queryExecution.sparkSession == null) { @@ -623,12 +622,12 @@ class Dataset[T] private[sql]( def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) /** @inheritdoc */ - def join(right: Dataset[_]): DataFrame = withPlan { + def join(right: sql.Dataset[_]): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) } /** @inheritdoc */ - def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { + def join(right: sql.Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. 
The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( @@ -687,19 +686,19 @@ class Dataset[T] private[sql]( } /** @inheritdoc */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + def join(right: sql.Dataset[_], joinExprs: Column, joinType: String): DataFrame = { withPlan { resolveSelfJoinCondition(right, Some(joinExprs), joinType) } } /** @inheritdoc */ - def crossJoin(right: Dataset[_]): DataFrame = withPlan { + def crossJoin(right: sql.Dataset[_]): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) } /** @inheritdoc */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { + def joinWith[U](other: sql.Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. val joined = sparkSession.sessionState.executePlan( @@ -724,7 +723,7 @@ class Dataset[T] private[sql]( } private[sql] def lateralJoin( - right: DS[_], joinExprs: Option[Column], joinType: JoinType): DataFrame = { + right: sql.Dataset[_], joinExprs: Option[Column], joinType: JoinType): DataFrame = { withPlan { LateralJoin( logicalPlan, @@ -736,22 +735,22 @@ class Dataset[T] private[sql]( } /** @inheritdoc */ - def lateralJoin(right: DS[_]): DataFrame = { + def lateralJoin(right: sql.Dataset[_]): DataFrame = { lateralJoin(right, None, Inner) } /** @inheritdoc */ - def lateralJoin(right: DS[_], joinExprs: Column): DataFrame = { + def lateralJoin(right: sql.Dataset[_], joinExprs: Column): DataFrame = { lateralJoin(right, Some(joinExprs), Inner) } /** @inheritdoc */ - def lateralJoin(right: DS[_], joinType: String): DataFrame = { + def lateralJoin(right: sql.Dataset[_], joinType: String): DataFrame = { lateralJoin(right, None, LateralJoinType(joinType)) } /** @inheritdoc */ - def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): DataFrame = { + def lateralJoin(right: sql.Dataset[_], joinExprs: Column, joinType: String): DataFrame = { lateralJoin(right, Some(joinExprs), LateralJoinType(joinType)) } @@ -1142,12 +1141,12 @@ class Dataset[T] private[sql]( } /** @inheritdoc */ - def union(other: Dataset[T]): Dataset[T] = withSetOperator { + def union(other: sql.Dataset[T]): Dataset[T] = withSetOperator { combineUnions(Union(logicalPlan, other.logicalPlan)) } /** @inheritdoc */ - def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = { + def unionByName(other: sql.Dataset[T], allowMissingColumns: Boolean): Dataset[T] = { withSetOperator { // We need to resolve the by-name Union first, as the underlying Unions are already resolved // and we can only combine adjacent Unions if they are all resolved. 
@@ -1158,22 +1157,22 @@ class Dataset[T] private[sql]( } /** @inheritdoc */ - def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { + def intersect(other: sql.Dataset[T]): Dataset[T] = withSetOperator { Intersect(logicalPlan, other.logicalPlan, isAll = false) } /** @inheritdoc */ - def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator { + def intersectAll(other: sql.Dataset[T]): Dataset[T] = withSetOperator { Intersect(logicalPlan, other.logicalPlan, isAll = true) } /** @inheritdoc */ - def except(other: Dataset[T]): Dataset[T] = withSetOperator { + def except(other: sql.Dataset[T]): Dataset[T] = withSetOperator { Except(logicalPlan, other.logicalPlan, isAll = false) } /** @inheritdoc */ - def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { + def exceptAll(other: sql.Dataset[T]): Dataset[T] = withSetOperator { Except(logicalPlan, other.logicalPlan, isAll = true) } @@ -1185,7 +1184,7 @@ class Dataset[T] private[sql]( } /** @inheritdoc */ - def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + def randomSplit(weights: Array[Double], seed: Long): Array[sql.Dataset[T]] = { require(weights.forall(_ >= 0), s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") require(weights.sum > 0, @@ -1215,7 +1214,7 @@ class Dataset[T] private[sql]( } /** @inheritdoc */ - override def randomSplit(weights: Array[Double]): Array[Dataset[T]] = + override def randomSplit(weights: Array[Double]): Array[sql.Dataset[T]] = randomSplit(weights, Utils.random.nextLong()) /** @@ -1224,12 +1223,12 @@ class Dataset[T] private[sql]( * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. */ - private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = { + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[sql.Dataset[T]] = { randomSplit(weights.toArray, seed) } /** @inheritdoc */ - override def randomSplitAsList(weights: Array[Double], seed: Long): util.List[Dataset[T]] = + override def randomSplitAsList(weights: Array[Double], seed: Long): util.List[sql.Dataset[T]] = util.Arrays.asList(randomSplit(weights, seed): _*) /** @inheritdoc */ @@ -1285,24 +1284,6 @@ class Dataset[T] private[sql]( } } - /** @inheritdoc */ - private[spark] def withColumns( - colNames: Seq[String], - cols: Seq[Column], - metadata: Seq[Metadata]): DataFrame = { - require(colNames.size == metadata.size, - s"The size of column names: ${colNames.size} isn't equal to " + - s"the size of metadata elements: ${metadata.size}") - val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) => - col.as(colName, metadata) - } - withColumns(colNames, newCols) - } - - /** @inheritdoc */ - private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = - withColumns(Seq(colName), Seq(col), Seq(metadata)) - protected[spark] def withColumnsRenamed( colNames: Seq[String], newColNames: Seq[String]): DataFrame = { @@ -1645,7 +1626,7 @@ class Dataset[T] private[sql]( errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", messageParameters = Map("methodName" -> toSQLId("write"))) } - new DataFrameWriterImpl[T](this) + new DataFrameWriter[T](this) } /** @inheritdoc */ @@ -1656,7 +1637,7 @@ class Dataset[T] private[sql]( errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", messageParameters = Map("methodName" -> toSQLId("writeTo"))) } - new DataFrameWriterV2Impl[T](table, this) + new DataFrameWriterV2[T](table, this) } /** @inheritdoc */ 
@@ -1667,7 +1648,7 @@ class Dataset[T] private[sql]( messageParameters = Map("methodName" -> toSQLId("mergeInto"))) } - new MergeIntoWriterImpl[T](table, this, condition) + new MergeIntoWriter[T](table, this, condition) } /** @inheritdoc */ @@ -1728,7 +1709,7 @@ class Dataset[T] private[sql]( /** @inheritdoc */ @DeveloperApi - def sameSemantics(other: Dataset[T]): Boolean = { + def sameSemantics(other: sql.Dataset[T]): Boolean = { queryExecution.analyzed.sameResult(other.queryExecution.analyzed) } @@ -1758,30 +1739,30 @@ class Dataset[T] private[sql]( override def drop(col: Column): DataFrame = super.drop(col) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumn: String): DataFrame = + override def join(right: sql.Dataset[_], usingColumn: String): DataFrame = super.join(right, usingColumn) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = + override def join(right: sql.Dataset[_], usingColumns: Array[String]): DataFrame = super.join(right, usingColumns) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = + override def join(right: sql.Dataset[_], usingColumns: Seq[String]): DataFrame = super.join(right, usingColumns) /** @inheritdoc */ - override def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = + override def join(right: sql.Dataset[_], usingColumn: String, joinType: String): DataFrame = super.join(right, usingColumn, joinType) /** @inheritdoc */ override def join( - right: Dataset[_], + right: sql.Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = super.join(right, usingColumns, joinType) /** @inheritdoc */ - override def join(right: Dataset[_], joinExprs: Column): DataFrame = + override def join(right: sql.Dataset[_], joinExprs: Column): DataFrame = super.join(right, joinExprs) /** @inheritdoc */ @@ -1868,7 +1849,7 @@ class Dataset[T] private[sql]( super.localCheckpoint(eager, storageLevel) /** @inheritdoc */ - override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = + override def joinWith[U](other: sql.Dataset[U], condition: Column): Dataset[(T, U)] = super.joinWith(other, condition) /** @inheritdoc */ @@ -1922,10 +1903,10 @@ class Dataset[T] private[sql]( override def where(conditionExpr: String): Dataset[T] = super.where(conditionExpr) /** @inheritdoc */ - override def unionAll(other: Dataset[T]): Dataset[T] = super.unionAll(other) + override def unionAll(other: sql.Dataset[T]): Dataset[T] = super.unionAll(other) /** @inheritdoc */ - override def unionByName(other: Dataset[T]): Dataset[T] = super.unionByName(other) + override def unionByName(other: sql.Dataset[T]): Dataset[T] = super.unionByName(other) /** @inheritdoc */ override def sample(fraction: Double, seed: Long): Dataset[T] = super.sample(fraction, seed) @@ -2027,6 +2008,20 @@ class Dataset[T] private[sql]( encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + /** @inheritdoc */ + override private[spark] def withColumns( + colNames: Seq[String], + cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = + super.withColumns(colNames, cols, metadata) + + /** @inheritdoc */ + override private[spark] def withColumn( + colName: String, + col: Column, + metadata: Metadata): DataFrame = + super.withColumn(colName, col, metadata) + //////////////////////////////////////////////////////////////////////////// // For Python API 
//////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/KeyValueGroupedDataset.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/KeyValueGroupedDataset.scala index 6dcf01d3a9db2..89158928dd55a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/KeyValueGroupedDataset.scala @@ -15,18 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic import org.apache.spark.api.java.function._ +import org.apache.spark.sql +import org.apache.spark.sql.{Column, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.TypedAggUtils.{aggKeyColumn, withInputType} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType} import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} /** @@ -42,15 +44,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( @transient private[sql] val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) - extends api.KeyValueGroupedDataset[K, V] { - type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] + extends sql.KeyValueGroupedDataset[K, V] { private implicit def kEncoderImpl: Encoder[K] = kEncoder private implicit def vEncoderImpl: Encoder[V] = vEncoder private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession - import queryExecution.sparkSession._ + import queryExecution.sparkSession.toRichColumn /** @inheritdoc */ def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = @@ -135,9 +136,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S])( + initialState: sql.KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) + val initialStateImpl = castToImpl(initialState) Dataset[U]( sparkSession, @@ -149,9 +151,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( isMapGroupsWithState = true, timeoutConf, child = logicalPlan, - initialState.groupingAttributes, - initialState.dataAttributes, - initialState.queryExecution.analyzed + initialStateImpl.groupingAttributes, + initialStateImpl.dataAttributes, + initialStateImpl.queryExecution.analyzed )) } @@ -179,11 +181,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S])( 
+ initialState: sql.KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { throw new IllegalArgumentException("The output mode of function should be append or update") } + val initialStateImpl = castToImpl(initialState) Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( @@ -194,14 +197,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( isMapGroupsWithState = false, timeoutConf, child = logicalPlan, - initialState.groupingAttributes, - initialState.dataAttributes, - initialState.queryExecution.analyzed + initialStateImpl.groupingAttributes, + initialStateImpl.dataAttributes, + initialStateImpl.queryExecution.analyzed )) } /** @inheritdoc */ - private[sql] def transformWithState[U: Encoder]( + def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, outputMode: OutputMode): Dataset[U] = { @@ -219,7 +222,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( } /** @inheritdoc */ - private[sql] def transformWithState[U: Encoder]( + def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], eventTimeColumnName: String, outputMode: OutputMode): Dataset[U] = { @@ -235,11 +238,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( } /** @inheritdoc */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = { + val initialStateImpl = castToImpl(initialState) Dataset[U]( sparkSession, TransformWithState[K, V, U, S]( @@ -249,19 +253,20 @@ class KeyValueGroupedDataset[K, V] private[sql]( timeMode, outputMode, child = logicalPlan, - initialState.groupingAttributes, - initialState.dataAttributes, - initialState.queryExecution.analyzed + initialStateImpl.groupingAttributes, + initialStateImpl.dataAttributes, + initialStateImpl.queryExecution.analyzed ) ) } /** @inheritdoc */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], eventTimeColumnName: String, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = { + val initialStateImpl = castToImpl(initialState) val transformWithState = TransformWithState[K, V, U, S]( groupingAttributes, dataAttributes, @@ -269,9 +274,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( TimeMode.EventTime(), outputMode, child = logicalPlan, - initialState.groupingAttributes, - initialState.dataAttributes, - initialState.queryExecution.analyzed + initialStateImpl.groupingAttributes, + initialStateImpl.dataAttributes, + initialStateImpl.queryExecution.analyzed ) updateEventTimeColumnAfterTransformWithState(transformWithState, eventTimeColumnName) @@ -317,23 +322,24 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** @inheritdoc */ def cogroupSorted[U, R : Encoder]( - other: KeyValueGroupedDataset[K, U])( + other: sql.KeyValueGroupedDataset[K, U])( thisSortExprs: Column*)( otherSortExprs: Column*)( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - implicit val uEncoder = other.vEncoderImpl + val otherImpl = castToImpl(other) + implicit val 
uEncoder: Encoder[U] = otherImpl.vEncoderImpl Dataset[R]( sparkSession, CoGroup( f, this.groupingAttributes, - other.groupingAttributes, + otherImpl.groupingAttributes, this.dataAttributes, - other.dataAttributes, + otherImpl.dataAttributes, MapGroups.sortOrder(thisSortExprs.map(_.expr)), MapGroups.sortOrder(otherSortExprs.map(_.expr)), this.logicalPlan, - other.logicalPlan)) + otherImpl.logicalPlan)) } override def toString: String = { @@ -409,7 +415,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState) /** @inheritdoc */ @@ -428,7 +434,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState( func, outputMode, @@ -458,7 +464,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S], + initialState: sql.KeyValueGroupedDataset[K, S], outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]) = super.transformWithState( statefulProcessor, @@ -472,7 +478,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S], + initialState: sql.KeyValueGroupedDataset[K, S], eventTimeColumnName: String, outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]) = super.transformWithState( @@ -554,20 +560,20 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** @inheritdoc */ override def cogroup[U, R: Encoder]( - other: KeyValueGroupedDataset[K, U])( + other: sql.KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = super.cogroup(other)(f) /** @inheritdoc */ override def cogroup[U, R]( - other: KeyValueGroupedDataset[K, U], + other: sql.KeyValueGroupedDataset[K, U], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder) /** @inheritdoc */ override def cogroupSorted[U, R]( - other: KeyValueGroupedDataset[K, U], + other: sql.KeyValueGroupedDataset[K, U], thisSortExprs: Array[Column], otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala index 2f1a34648a470..0269b15061c97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala @@ -15,13 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import scala.collection.mutable import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{Column, DataFrame, Dataset, MergeIntoWriter} +import org.apache.spark.sql +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.functions.expr @@ -38,8 +39,8 @@ import org.apache.spark.sql.functions.expr * @since 4.0.0 */ @Experimental -class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column) - extends MergeIntoWriter[T] { +class MergeIntoWriter[T] private[sql](table: String, ds: Dataset[T], on: Column) + extends sql.MergeIntoWriter[T] { private val df: DataFrame = ds.toDF() @@ -75,28 +76,28 @@ class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Co qe.assertCommandExecuted() } - override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = { + override protected[sql] def insertAll(condition: Option[Column]): this.type = { this.notMatchedActions += InsertStarAction(condition.map(_.expr)) this } override protected[sql] def insert( condition: Option[Column], - map: Map[String, Column]): MergeIntoWriter[T] = { + map: Map[String, Column]): this.type = { this.notMatchedActions += InsertAction(condition.map(_.expr), mapToAssignments(map)) this } override protected[sql] def updateAll( condition: Option[Column], - notMatchedBySource: Boolean): MergeIntoWriter[T] = { + notMatchedBySource: Boolean): this.type = { appendUpdateDeleteAction(UpdateStarAction(condition.map(_.expr)), notMatchedBySource) } override protected[sql] def update( condition: Option[Column], map: Map[String, Column], - notMatchedBySource: Boolean): MergeIntoWriter[T] = { + notMatchedBySource: Boolean): this.type = { appendUpdateDeleteAction( UpdateAction(condition.map(_.expr), mapToAssignments(map)), notMatchedBySource) @@ -104,13 +105,13 @@ class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Co override protected[sql] def delete( condition: Option[Column], - notMatchedBySource: Boolean): MergeIntoWriter[T] = { + notMatchedBySource: Boolean): this.type = { appendUpdateDeleteAction(DeleteAction(condition.map(_.expr)), notMatchedBySource) } private def appendUpdateDeleteAction( action: MergeAction, - notMatchedBySource: Boolean): MergeIntoWriter[T] = { + notMatchedBySource: Boolean): this.type = { if (notMatchedBySource) { notMatchedBySourceActions += action } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala index 4fa1f0f4962a0..1fde17ace3067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala @@ -14,11 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.sql.{Dataset, Observation, SparkSession} +import org.apache.spark.sql.Observation import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala index b8c4b03fc13d2..082292145e858 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql +import org.apache.spark.sql.{AnalysisException, Column, Encoder} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -29,10 +31,10 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ExpressionUtils.generateAlias +import org.apache.spark.sql.classic.TypedAggUtils.withInputType import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.internal.ExpressionUtils.generateAlias -import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -53,13 +55,13 @@ class RelationalGroupedDataset protected[sql]( protected[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) - extends api.RelationalGroupedDataset { + extends sql.RelationalGroupedDataset { import RelationalGroupedDataset._ - import df.sparkSession._ + import df.sparkSession.toRichColumn override protected def toDF(aggCols: Seq[Column]): DataFrame = { - val aggExprs = aggCols.map(expression).map { e => + val aggExprs = aggCols.map(_.expr).map { e => withInputType(e, df.exprEnc, df.logicalPlan.output) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/RuntimeConfig.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/RuntimeConfig.scala index b2004215a99f6..d3b498bd16068 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/RuntimeConfig.scala @@ -15,15 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import scala.jdk.CollectionConverters._ import org.apache.spark.SPARK_DOC_ROOT import org.apache.spark.annotation.Stable import org.apache.spark.internal.config.{ConfigEntry, DEFAULT_PARALLELISM, OptionalConfigEntry} -import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf /** * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. @@ -33,7 +34,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors * @since 2.0.0 */ @Stable -class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends RuntimeConfig { +class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) extends sql.RuntimeConfig { /** @inheritdoc */ def set(key: String, value: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala index 1318563f8c93b..2d5d26fe6016e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.util.{List => JList, Map => JMap, Properties} @@ -26,9 +26,12 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD +import org.apache.spark.sql +import org.apache.spark.sql.{Encoder, ExperimentalMethods, Row} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.SparkSession.{builder => newSparkSessionBuilder} import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQueryManager} @@ -54,9 +57,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager */ @Stable class SQLContext private[sql] (override val sparkSession: SparkSession) - extends api.SQLContext(sparkSession) { - - self => + extends sql.SQLContext(sparkSession) { self => sparkSession.sparkContext.assertNotStopped() @@ -65,7 +66,7 @@ class SQLContext private[sql] (override val sparkSession: SparkSession) @deprecated("Use SparkSession.builder instead", "2.0.0") def this(sc: SparkContext) = { - this(SparkSession.builder().sparkContext(sc).getOrCreate()) + this(newSparkSessionBuilder().sparkContext(sc).getOrCreate()) } @deprecated("Use SparkSession.builder instead", "2.0.0") @@ -106,7 +107,6 @@ class SQLContext private[sql] (override val sparkSession: SparkSession) /** @inheritdoc */ object implicits extends SQLImplicits { - /** @inheritdoc */ override protected def session: SparkSession = sparkSession } @@ -377,14 +377,14 @@ class SQLContext private[sql] (override val 
sparkSession: SparkSession) super.jdbc(url, table, theParts) } -object SQLContext extends api.SQLContextCompanion { +object SQLContext extends sql.SQLContextCompanion { override private[sql] type SQLContextImpl = SQLContext override private[sql] type SparkContextImpl = SparkContext /** @inheritdoc */ def getOrCreate(sparkContext: SparkContext): SQLContext = { - SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext + newSparkSessionBuilder().sparkContext(sparkContext).getOrCreate().sqlContext } /** @inheritdoc */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLImplicits.scala similarity index 53% rename from sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/SQLImplicits.scala index b6ed50447109d..165deeb1bf65f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLImplicits.scala @@ -15,11 +15,28 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql +import org.apache.spark.sql.Encoder /** @inheritdoc */ -abstract class SQLImplicits extends api.SQLImplicits { - type DS[U] = Dataset[U] +abstract class SQLImplicits extends sql.SQLImplicits { protected def session: SparkSession + + override implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = + new DatasetHolder[T](session.createDataset(s)) + + override implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] = + new DatasetHolder[T](session.createDataset(rdd)) +} + +class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U] { + override def toDS(): Dataset[U] = ds + override def toDF(): DataFrame = ds.toDF() + override def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index 3b36f6b59cb38..4ef7547f544e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.net.URI import java.nio.file.Paths @@ -32,13 +32,13 @@ import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, SparkException, import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.{Logging, MDC} -import org.apache.spark.internal.LogKeys.{CALL_SITE_LONG_FORM, CLASS_NAME} +import org.apache.spark.internal.LogKeys.{CALL_SITE_LONG_FORM, CLASS_NAME, CONFIG} import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql.SparkSession.applyAndLoadExtensions +import org.apache.spark.sql +import org.apache.spark.sql.{Artifact, DataSourceRegistration, Encoder, Encoders, ExperimentalMethods, Row, SparkSessionBuilder, SparkSessionCompanion, SparkSessionExtensions, SparkSessionExtensionsProvider, UDTFRegistration} import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ @@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.errors.{QueryCompilationErrors, SqlScriptingErrors} import org.apache.spark.sql.execution._ @@ -58,7 +59,6 @@ import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.scripting.SqlScriptingExecution import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{CallSite, SparkFileUtils, ThreadUtils, Utils} @@ -98,7 +98,7 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions, @transient private[sql] val initialSessionOptions: Map[String, String], @transient private val parentManagedJobTags: Map[String, String]) - extends api.SparkSession with Logging with classic.ColumnConversions { self => + extends sql.SparkSession with Logging with ColumnConversions { self => // The call site where this SparkSession was constructed. 
private val creationSite: CallSite = Utils.getCallSite() @@ -197,7 +197,7 @@ class SparkSession private( val sqlContext: SQLContext = new SQLContext(this) /** @inheritdoc */ - @transient lazy val conf: RuntimeConfig = new RuntimeConfigImpl(sessionState.conf) + @transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf) /** @inheritdoc */ def listenerManager: ExecutionListenerManager = sessionState.listenerManager @@ -418,7 +418,7 @@ class SparkSession private( * ------------------------- */ /** @inheritdoc */ - @transient lazy val catalog: Catalog = new CatalogImpl(self) + @transient lazy val catalog: Catalog = new Catalog(self) /** @inheritdoc */ def table(tableName: String): DataFrame = { @@ -778,21 +778,6 @@ class SparkSession private( }.toImmutableArraySeq } - /** - * Execute a block of code with this session set as the active session, and restore the - * previous session on completion. - */ - private[sql] def withActive[T](block: => T): T = { - // Use the active session thread local directly to make sure we get the session that is actually - // set and not the default session. This to prevent that we promote the default session to the - // active session once we are done. - val old = SparkSession.getActiveSession.orNull - SparkSession.setActiveSession(this) - try block finally { - SparkSession.setActiveSession(old) - } - } - private[sql] def leafNodeDefaultParallelism: Int = { sessionState.conf.getConf(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM) .getOrElse(sparkContext.defaultParallelism) @@ -811,14 +796,15 @@ class SparkSession private( @Stable -object SparkSession extends api.BaseSparkSessionCompanion with Logging { +object SparkSession extends SparkSessionCompanion with Logging { override private[sql] type Session = SparkSession /** * Builder for [[SparkSession]]. */ @Stable - class Builder extends api.SparkSessionBuilder { + class Builder extends SparkSessionBuilder { + import SparkSessionBuilder._ private[this] val extensions = new SparkSessionExtensions @@ -860,18 +846,26 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { override def master(master: String): this.type = super.master(master) /** @inheritdoc */ - override def enableHiveSupport(): this.type = synchronized { - if (hiveClassesArePresent) { - // TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR` command. This will - // break an existing Hive use case (one session adds JARs and another session uses them). - // We need to decide whether/how to enable isolation for Hive. - super.enableHiveSupport() - .config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, false) - } else { - throw new IllegalArgumentException( - "Unable to instantiate SparkSession with Hive support because " + - "Hive classes are not found.") - } + override def enableHiveSupport(): this.type = super.enableHiveSupport() + + override protected def handleBuilderConfig(key: String, value: String): Boolean = key match { + case CONNECT_REMOTE_KEY | API_MODE_KEY => + logWarning(log"${MDC(CONFIG, key)} configuration is not supported in Classic mode.") + true + case CATALOG_IMPL_KEY if value == "hive" => + if (hiveClassesArePresent) { + // TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR` command. This will + // break an existing Hive use case (one session adds JARs and another session uses + // them). We need to decide whether/how to enable isolation for Hive. 
+ config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, false) + } else { + throw new IllegalArgumentException( + "Unable to instantiate SparkSession with Hive support because " + + "Hive classes are not found.") + } + false + case _ => + false } /** @inheritdoc */ @@ -958,8 +952,11 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { /** @inheritdoc */ override def active: SparkSession = super.active - override protected def canUseSession(session: SparkSession): Boolean = - session.isUsable && !Utils.isInRunningSparkTask + override protected def tryCastToImplementation( + session: sql.SparkSession): Option[SparkSession] = session match { + case impl: SparkSession if !Utils.isInRunningSparkTask => Some(impl) + case _ => None + } /** * Apply modifiable settings to an existing [[SparkSession]]. This method are used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQuery.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQuery.scala index 7cf92db59067c..68ea9ce0fdea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQuery.scala @@ -14,12 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.classic -import org.apache.spark.sql.{api, SparkSession} +import org.apache.spark.sql.streaming /** @inheritdoc */ -trait StreamingQuery extends api.StreamingQuery { +trait StreamingQuery extends streaming.StreamingQuery { /** @inheritdoc */ override def sparkSession: SparkSession } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 42f6d04466b08..6ce6f06de113d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.classic import java.util.UUID import java.util.concurrent.{TimeoutException, TimeUnit} @@ -27,7 +27,6 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CLASS_NAME, QUERY_ID, RUN_ID} -import org.apache.spark.sql.{api, Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.streaming.{WriteToStream, WriteToStreamStatement} import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog} @@ -37,6 +36,8 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS +import org.apache.spark.sql.streaming +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamingQueryListener, Trigger} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -48,7 +49,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} class StreamingQueryManager private[sql] ( sparkSession: SparkSession, sqlConf: SQLConf) - extends api.StreamingQueryManager + extends streaming.StreamingQueryManager with Logging { private[sql] val stateStoreCoordinator = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala index 406b67e6f3b8a..d2034033fee7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala @@ -14,12 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic +import org.apache.spark.sql +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedTableValuedFunction class TableValuedFunction(sparkSession: SparkSession) - extends api.TableValuedFunction { + extends sql.TableValuedFunction { /** @inheritdoc */ override def range(end: Long): Dataset[java.lang.Long] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/TypedAggUtils.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/TypedAggUtils.scala index 23ceb8135fa8a..d474522fe12f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/TypedAggUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFo import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.internal.SQLConf private[sql] object TypedAggUtils { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/UDFRegistration.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/UDFRegistration.scala index 6715673cf3d1c..a35c97c455a5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/UDFRegistration.scala @@ -15,22 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic import java.lang.reflect.ParameterizedType import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging +import org.apache.spark.sql import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.classic.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.DataType /** @@ -44,7 +45,7 @@ import org.apache.spark.sql.types.DataType */ @Stable class UDFRegistration private[sql] (session: SparkSession, functionRegistry: FunctionRegistry) - extends api.UDFRegistration + extends sql.UDFRegistration with Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( @@ -136,7 +137,7 @@ class UDFRegistration private[sql] (session: SparkSession, functionRegistry: Fun } // scalastyle:off line.size.limit - + /** @inheritdoc */ override def registerJava(name: String, className: String, returnDataType: DataType): Unit = { try { val clazz = session.artifactManager.classloader.loadClass(className) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/UserDefinedFunctionUtils.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/UserDefinedFunctionUtils.scala index bd8735d15be13..124372057c2c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/UserDefinedFunctionUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions 
and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder import org.apache.spark.sql.catalyst.encoders.encoderFor diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala index 8f37f5c32de34..5766535ac5dac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala @@ -14,12 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.SparkException -import org.apache.spark.sql.{Column, Dataset, SparkSession} +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAlias} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Generator, NamedExpression, Unevaluable} @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, TypedAggregateExpression} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator} +import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} import org.apache.spark.sql.types.{DataType, NullType} /** @@ -274,7 +275,7 @@ private[sql] case class ExpressionColumnNode private( override def sql: String = expression.sql - override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty + override def children: Seq[ColumnNodeLike] = Seq.empty } private[sql] object ExpressionColumnNode { @@ -284,7 +285,7 @@ private[sql] object ExpressionColumnNode { } } -private[internal] case class ColumnNodeExpression private(node: ColumnNode) extends Unevaluable { +private[classic] case class ColumnNodeExpression private(node: ColumnNode) extends Unevaluable { override def nullable: Boolean = true override def dataType: DataType = NullType override def children: Seq[Expression] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala index e90fd4b6a6032..3cfdace73c458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.classic import scala.language.implicitConversions import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql._ +import org.apache.spark.sql +import 
org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
-import org.apache.spark.sql.internal.{ColumnNodeToExpressionConverter, ExpressionUtils}
 
 /**
  * Conversions from sql interfaces to the Classic specific implementation.
@@ -36,18 +36,20 @@ import org.apache.spark.sql.internal.{ColumnNodeToExpressionConverter, Expressio
  */
 @DeveloperApi
 trait ClassicConversions {
-  implicit def castToImpl(session: api.SparkSession): SparkSession =
+  implicit def castToImpl(session: sql.SparkSession): SparkSession =
     session.asInstanceOf[SparkSession]
 
-  implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] =
+  implicit def castToImpl[T](ds: sql.Dataset[T]): Dataset[T] =
     ds.asInstanceOf[Dataset[T]]
 
-  implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset =
+  implicit def castToImpl(rgds: sql.RelationalGroupedDataset): RelationalGroupedDataset =
     rgds.asInstanceOf[RelationalGroupedDataset]
 
-  implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V])
+  implicit def castToImpl[K, V](kvds: sql.KeyValueGroupedDataset[K, V])
     : KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
 
+  implicit def castToImpl(context: sql.SQLContext): SQLContext = context.asInstanceOf[SQLContext]
+
   /**
    * Helper that makes it easy to construct a Column from an Expression.
    */
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/package.scala
similarity index 70%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/classic/package.scala
index 993b09ace9139..f4c44a013f944 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/package.scala
@@ -14,11 +14,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql
 
-/** @inheritdoc */
-abstract class SQLImplicits extends api.SQLImplicits {
-  type DS[U] = Dataset[U]
+package org.apache.spark.sql
 
-  protected def session: SparkSession
+/**
+ * Allows the execution of relational queries, including those expressed in SQL using Spark.
+ *
+ * @groupname dataType Data types
+ * @groupdesc Spark SQL data types.
+ * @groupprio dataType -3 + * @groupname field Field + * @groupprio field -2 + * @groupname row Row + * @groupprio row -1 + */ +package object classic { + type DataFrame = Dataset[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a3382c83e1f20..8fe7565c902a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -21,13 +21,13 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.{LogEntry, Logging, MDC} import org.apache.spark.internal.LogKeys._ -import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression} import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan, ResolvedHint, View} import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.command.CommandUtils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index fd8f0b85edd26..1ac7aa00d98c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.execution import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{LOGICAL_PLAN_COLUMNS, OPTIMIZED_PLAN_COLUMNS} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Encoder, SparkSession} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index d9b1a2136a5d3..82f102a145a12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import 
org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{LazyExpression, UnsupportedOperationChecker} import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, Command import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 242149010ceef..cb3f382914321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION, SPARK_JOB_INTERRUPT import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PREFIX} import org.apache.spark.internal.config.Tests.IS_TESTING -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fb3ec3ad41812..d9bb057282dff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -27,13 +27,14 @@ import org.apache.spark.{broadcast, SparkEnv, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.WriteFilesSpec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 5dfe85548349c..04132c9f4c301 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql._ +import 
org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.execution.{SparkStrategy => Strategy} import org.apache.spark.sql.execution.adaptive.LogicalQueryStageStrategy import org.apache.spark.sql.execution.command.v2.V2CommandStrategy import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} @@ -92,7 +94,7 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition: Option[Expression] = - prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) + prunePushedDownFilters(filterPredicates).reduceLeftOption(And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 36e25773f8342..31d86bdf8c529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, AnalysisException, Strategy} +import org.apache.spark.sql.{execution, AnalysisException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.execution.{SparkStrategy => Strategy} import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index a0a0991429309..886ef12f32d73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -31,7 +31,6 @@ import org.apache.spark.broadcast import org.apache.spark.internal.{MDC, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} @@ -39,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDi import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} 
import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala index 25c8f695689c0..c7f4ef43e4810 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.analysis import scala.collection.mutable import org.apache.spark.SparkException -import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Cast, Equality, Expression, ExprId} import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.MetadataBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 9e6a99ef9fb28..7ea1bd6ff7dc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,11 +32,11 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, Message import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1268b14a32fb5..66548404684a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.Dataset import 
org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DatetimeType, _} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 4efd94e442e4a..8f1e05c87c8f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.util.PartitioningUtils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 157554e821811..37234656864eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier - +import org.apache.spark.sql.classic.ClassicConversions.castToImpl /** * Analyzes the given table to generate statistics, which will be used in query optimizations. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala index 1650af74bc242..26192551632e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala @@ -22,7 +22,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.LogKeys.{DATABASE_NAME, ERROR, TABLE_NAME} import org.apache.spark.internal.MDC import org.apache.spark.sql.{Row, SparkSession} - +import org.apache.spark.sql.classic.ClassicConversions.castToImpl /** * Analyzes all tables in the given database to generate statistics. 
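The command hunks above (AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, AnalyzeTablesCommand) all follow the same pattern: the public `org.apache.spark.sql.{Row, SparkSession}` types stay in the method signatures, and `org.apache.spark.sql.classic.ClassicConversions.castToImpl` is imported so the body can keep reaching classic-only session internals. A minimal sketch of that pattern, not part of this patch, with a hypothetical helper name and assuming `castToImpl` is an implicit view from the shared `SparkSession` API onto `org.apache.spark.sql.classic.SparkSession`:

```scala
// Minimal sketch (not part of this patch), placed in the same package as the
// commands above so private[sql] session internals stay reachable.
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.classic.ClassicConversions.castToImpl

// Hypothetical helper, not an actual Spark command.
object CurrentDatabaseSketch {
  // The caller hands over the shared API type org.apache.spark.sql.SparkSession...
  def run(sparkSession: SparkSession): Seq[Row] = {
    // ...and the implicit castToImpl view treats it as
    // org.apache.spark.sql.classic.SparkSession, so members that now live only
    // on the classic implementation (for example sessionState) remain usable.
    Seq(Row(sparkSession.sessionState.catalog.getCurrentDatabase))
  }
}
```

The other hunks in this group make the same choice in one of two ways: either they keep the shared API type and add the `castToImpl` conversion as above, or they switch the import to `org.apache.spark.sql.classic.{SparkSession, Dataset}` outright when the code never needs the shared interface.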
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 48d98c14c3889..7cbba170cd1e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{COUNT, DATABASE_NAME, ERROR, TABLE_NAME, TIME} -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils} @@ -35,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 592ae04a055d1..90050b25e9543 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -22,9 +22,10 @@ import java.net.URI import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext -import org.apache.spark.sql.{Row, SaveMode, SparkSession} +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, UnaryCommand} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala index 4440a8889c05c..33b0e5e794ade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ Iso8601TimestampFormatter, LegacyDateFormats } +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.V1Table import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index 7c690c8ccc08d..d3a72f915c47b 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC -import org.apache.spark.sql._ +import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, LogicalPlan, WithCTE} import org.apache.spark.sql.errors.QueryExecutionErrors diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 4e513fc3e8c1d..87cd9376b77b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.errors.QueryCompilationErrors.toSQLId import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 539d8346a5cad..f29d2267f75fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,10 +21,11 @@ import java.net.URI import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC -import org.apache.spark.sql._ +import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, LogicalPlan, WithCTE} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.CommandExecutionMode import org.apache.spark.sql.execution.datasources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 9dfe5c3e4c301..6eb81e6ec670b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} import 
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.SupportsNamespaces._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index a58e8fac6e36d..092e6669338ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.DataSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index a583f1e4650b6..dbf98c70504d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression, VariableReference} import org.apache.spark.sql.catalyst.plans.logical.{AnalysisOnlyCommand, CTEInChildren, CTERelationDef, LogicalPlan, Project, View, WithCTE} import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 58bbd91a8cc77..ad5da35d2f50c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,11 +29,13 @@ import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CLASS_NAME, DATA_SOURCE, DATA_SOURCES, PATHS} -import org.apache.spark.sql._ +import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connector.catalog.TableProvider import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors} import org.apache.spark.sql.execution.command.DataWritingCommand diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala index 3a2a3207a01f9..4cd75736ea9eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.resolver.{ ExplicitlyUnsupportedResolverFeature, ResolverExtension @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.resolver.{ import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.streaming.StreamingRelation /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 95746218e8792..e2e72a9e36953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PREDICATES import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, SaveMode, Strategy} import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow, QualifiedTableName, SQLConfHelper} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ @@ -41,15 +41,18 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, PushedDownOperators} import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.sources import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala index 66b5971eef27f..8c7203bca625f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.datasources import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 7788f3287ac4b..f82da44e73031 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompres import org.apache.hadoop.mapreduce.Job import org.apache.spark.paths.SparkPath -import org.apache.spark.sql._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 5e6107c4f49c7..b4cffa59c98d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -29,13 +29,13 @@ import org.apache.spark._ import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala index 8a47a28de845c..50af845c37cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala @@ -24,6 +24,7 @@ import org.apache.spark.Partition import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CONFIG, 
DESIRED_NUM_PARTITIONS, MAX_NUM_PARTITIONS, NUM_PARTITIONS} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.ScanFileListing import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala index 44102da752c2e..0728054625aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LogicalPlan} +import org.apache.spark.sql.classic.SparkSession /** * The [[FileResolver]] is a [[MetadataResolver]] extension that resolves [[UnresolvedRelation]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 128f6acdeaa69..3fec6b4a274e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, LogicalPlan, WithCTE} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.sources.InsertableRelation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 8a795f0748811..03d4f09ab3379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.io.FileCommitProtocol -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.SparkPlan import 
org.apache.spark.sql.execution.command._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 51fed315439ef..3c8ed907fd439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.datasources import scala.util.control.NonFatal import org.apache.spark.SparkThrowable -import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} +import org.apache.spark.sql.{Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, LogicalPlan, WithCTE} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 655632aa6d9b5..6196bef106fa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -37,6 +37,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 20d9e0b872017..e4cecf7d6ead7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index e9b31875bd7b0..3a4ca99fc95a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala index a288b5ebf8b38..43fc69a52286d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.SparkException import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e5fbf8be1f0c2..1748105318b4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -34,7 +34,7 @@ import org.apache.parquet.hadoop._ import org.apache.spark.TaskContext import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{PATH, SCHEMA} -import org.apache.spark.sql._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 663182d8d1820..bad883534115c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -35,8 +35,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CLASS_NAME, CONFIG} -import org.apache.spark.sql.Row -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1aee23f36313c..6420d3ab374e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{HashMap, HashSet} import scala.jdk.CollectionConverters._ import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.TypeUtils._ +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala index 86fa0c8523f1e..c7f47d2eaaaad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala @@ -21,12 +21,12 @@ import java.util.Locale import org.apache.spark.internal.LogKeys.OPTIONS import org.apache.spark.internal.MDC -import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{LocalTempView, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.storage.StorageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index ce863791659bb..bca3146df2766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,7 +25,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.EXPR -import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions @@ -34,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import 
org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan, SparkStrategy => Strategy} import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala index 98994332d160f..1a755cbbb7b85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} import org.apache.spark.storage.StorageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 801151c51206d..358f35e11d655 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connector.catalog.SupportsWrite import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala index da86dd63cfd62..fafde89001aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlInferSchema, XmlOptions} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index 6907061d67703..5f5a9e188532e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.dynamicpruning -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeSeq, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala index c1c5f4a2a2b59..1cff218229b87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala @@ -258,7 +258,7 @@ case class TransformWithStateInPandasExec( schemaForValueRow, NoPrefixKeyStateEncoderSpec(schemaForKeyRow), session.sqlContext.sessionState, - Some(session.sqlContext.streams.stateStoreCoordinator), + Some(session.streams.stateStoreCoordinator), useColumnFamilies = true, useMultipleValuesPerKey = true ) { @@ -287,7 +287,7 @@ case class TransformWithStateInPandasExec( initialState.execute(), getStateInfo, storeNames = Seq(), - session.sqlContext.streams.stateStoreCoordinator) { + session.streams.stateStoreCoordinator) { // The state store aware zip partitions will provide us with two iterators, // child data iterator and the initial state iterator per partition. 
case (partitionId, childDataIterator, initStateIterator) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 575e3d4072b8c..c6cb12b1a59f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -24,13 +24,14 @@ import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.Pickler import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths} -import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession, TableArg, TableValuedFunctionArgument} +import org.apache.spark.sql.{Column, TableArg, TableValuedFunctionArgument} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, PythonUDTFSelectedExpression, SortOrder, UnresolvedPolymorphicPythonUDTF} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation} +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.ExpressionUtils.expression import org.apache.spark.sql.types.{DataType, StructType} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 221ca17ddf19d..d9109e0d9e9ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,13 +22,15 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.{functions, Column, DataFrame} +import org.apache.spark.sql.{functions, Column} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.DataFrame +import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -52,7 +54,6 @@ object FrequentItems extends Logging { df: DataFrame, cols: Seq[String], support: Double): DataFrame = { - import df.sparkSession.expression require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.") // number of max items to keep counts for diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index dd7fee455b4df..89fc69cd2bdd9 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.stat import java.util.Locale import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala index 4a7cb5b71a77f..0f3ae844808e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala @@ -22,8 +22,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.internal.LogKeys.{BATCH_ID, PRETTY_ID_STRING} import org.apache.spark.internal.MDC -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.streaming.WriteToStream +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.streaming.Trigger import org.apache.spark.util.{Clock, ThreadUtils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index d561ee1ef730f..c1711ffe2923e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.LogKeys.{BATCH_ID, ERROR, PATH} import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 25c6d454dfd58..f0debce44e376 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.paths.SparkPath -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connector.read.streaming import 
org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit, ReadMaxBytes, ReadMaxFiles, SupportsAdmissionControl, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index d5ad77a441107..8fe6178744ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -26,14 +26,14 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR} -import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SerializeFromObjectExec, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SerializeFromObjectExec, SparkPlan, SparkPlanner, SparkStrategy => Strategy, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index cfaa7c5993741..fe06cbb19c3a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -21,13 +21,14 @@ import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable import org.apache.spark.internal.{LogKeys, MDC} -import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan} import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream} import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.classic.{Dataset, SparkSession} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow} import 
org.apache.spark.sql.errors.QueryExecutionErrors diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index fdb4f2813dba2..3ac07cf1d7308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -26,11 +26,12 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.spark.internal.{Logging, LogKeys, MDC} -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.optimizer.InlineCTE import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, WithCTE} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream} import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index 36c7796ec4399..a88d1654487be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same * batch. * - * Note that, we extends `Table` here, to make the v1 streaming sink API be compatible with + * Note that, we extend `Table` here, to make the v1 streaming sink API be compatible with * data source v2. 
*/ trait Sink extends Table { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 44202bb0d2944..708b7ee7e6f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -35,9 +35,9 @@ import org.apache.logging.log4j.CloseableThreadContext import org.apache.spark.{JobArtifactSet, SparkContext, SparkException, SparkThrowable} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CHECKPOINT_PATH, CHECKPOINT_ROOT, LOGICAL_PLAN, PATH, PRETTY_ID_STRING, QUERY_ID, RUN_ID, SPARK_DATA_STREAM} -import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.classic.{SparkSession, StreamingQuery} import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} @@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncExc import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2FileManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend -import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamingQueryListener, StreamingQueryProgress, StreamingQueryStatus, Trigger} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala index 3f2cdadfbaeee..c59b9584f5383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} +import org.apache.spark.sql.classic.{SparkSession, StreamingQuery} +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} /** * Wrap non-serializable StreamExecution to make the query serializable as it's easy for it to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2f75150f7a2c2..a13c00ee20576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -27,11 +27,11 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.SparkEnv import org.apache.spark.internal.LogKeys._ import 
org.apache.spark.internal.MDC -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{CurrentDate, CurrentTimestampLike, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream} import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 2ff478ef98e4e..48af1972e581c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,7 +24,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ListBuffer import org.apache.spark.internal.Logging -import org.apache.spark.sql._ +import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -32,6 +32,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.classic.{DataFrame, Dataset} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsTriggerAvailableNow} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala index f1839ccceee1d..ceca5e8e1eaa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala 
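
The streaming hunks above (memory.scala, ConsoleStreamingWrite.scala, ContinuousExecution.scala) all follow the same shape: internal code now imports the concrete org.apache.spark.sql.classic types and pulls in ClassicConversions.castToImpl to narrow values that are still typed against the public API. A minimal, self-contained sketch of that interface/implementation split, using simplified stand-in names rather than the real Spark classes:

import scala.language.implicitConversions

object ClassicSplitSketch {
  // Public, interface-only surface (plays the role of org.apache.spark.sql.SparkSession).
  trait ApiSession {
    def version: String
  }

  // Concrete implementation with internal-only members
  // (plays the role of org.apache.spark.sql.classic.SparkSession).
  final class ClassicSession extends ApiSession {
    override def version: String = "4.x-sketch"
    def sessionState: String = "internal session state"
  }

  // Stand-in for ClassicConversions.castToImpl: narrow an API-typed value so
  // internal code can reach implementation-only members.
  object Conversions {
    implicit def castToImpl(session: ApiSession): ClassicSession =
      session.asInstanceOf[ClassicSession]
  }

  def main(args: Array[String]): Unit = {
    import Conversions.castToImpl
    val api: ApiSession = new ClassicSession
    // The implicit conversion applies where implementation members are needed.
    println(api.sessionState)
  }
}
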
index c687caafdef37..139845c13e11c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution.streaming.sources import scala.util.control.NonFatal import org.apache.spark.{SparkException, SparkThrowable} -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.Dataset.ofRows import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.streaming.DataStreamWriter @@ -33,7 +35,7 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr val node = LogicalRDD.fromDataset(rdd = data.queryExecution.toRdd, originDataset = data, isStreaming = false) implicit val enc = encoder - val ds = Dataset.ofRows(data.sparkSession, node).as[T] + val ds = ofRows(data.sparkSession, node).as[T] // SPARK-47329 - for stateful queries that perform multiple operations on the dataframe, it is // highly recommended to persist the dataframe to prevent state stores from reloading // state multiple times in each batch. We cannot however always call `persist` on the dataframe @@ -82,7 +84,7 @@ trait PythonForeachBatchFunction { object PythonForeachBatchHelper { def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { - dsw.foreachBatch(pythonFunc.call _) + dsw.foreachBatch((df: DataFrame, id: Long) => pythonFunc.call(castToImpl(df), id)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 81b2308a9e3f0..a82eff4812953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -43,8 +44,8 @@ package object state { keySchema, valueSchema, keyStateEncoderSpec, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator))( + sqlContext.sparkSession.sessionState, + Some(castToImpl(sqlContext.sparkSession).streams.stateStoreCoordinator))( storeUpdateFunction) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 771c743f70629..d5f258a8084be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution import org.apache.spark.QueryContext import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression, Predicate, 
SupportQueryContext} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index fd7effca0f03a..1936a9aab0de6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql.{DataSourceRegistration, ExperimentalMethods, SparkSessionExtensions, Strategy, UDTFRegistration} import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDataSource, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.classic.{SparkSession, StreamingQueryManager, UDFRegistration} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} @@ -38,7 +39,6 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog} import org.apache.spark.sql.execution.streaming.ResolveWriteToStream import org.apache.spark.sql.expressions.UserDefinedAggregateFunction -import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index bc6710e6cbdb8..440148989ffb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Unstable -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataSourceRegistration, ExperimentalMethods, UDTFRegistration} import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ @@ -32,11 +32,11 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.classic.{SparkSession, StreamingQueryManager, UDFRegistration} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder import org.apache.spark.sql.execution.datasources.DataSourceManager -import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{DependencyUtils, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 2b15a6c55fa97..fabd47422daf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.scripting -import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody} +import org.apache.spark.sql.classic.{DataFrame, SparkSession} /** * SQL scripting executor - executes script and returns result statements. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 58cbfb0feb015..d11bf14be6546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -21,11 +21,12 @@ import java.util import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 7d00bbb3538df..d98588278956f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.scripting -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, 
SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.classic.SparkSession /** * SQL scripting interpreter - builds SQL script execution plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala index 7b8847dc35856..b92e361ef805b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.Optional import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.execution.streaming.GroupStateImpl._ @@ -114,7 +114,6 @@ import org.apache.spark.sql.execution.streaming.GroupStateImpl._ * Spark SQL types (see `Encoder` for more details). * @since 3.2.0 */ -@Experimental @Evolving trait TestGroupState[S] extends GroupState[S] { /** Whether the state has been marked for removing */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 309853abbd39a..079c6d286b9ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,7 +22,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java index 8ecd4ae4cb658..f33196890553c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -18,7 +18,7 @@ package test.org.apache.spark.sql; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.classic.SparkSession; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 924f6c26e21dc..c1f48a922b727 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.classic.SparkSession; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.sql.types.DataTypes; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala index 
53662c65609d7..b84f6e608205d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala @@ -82,7 +82,9 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_)))) val dtAggExpr = dtAggFunc.toAggregateExpression() val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)() - val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation)) + val result = classic.Dataset.ofRows( + spark, + Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation)) checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index ca3f282d8cd40..ff455d28e8680 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -92,7 +92,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils maybeBlock.nonEmpty && isExpectLevel } - private def getNumInMemoryRelations(ds: Dataset[_]): Int = { + private def getNumInMemoryRelations(ds: classic.Dataset[_]): Int = { val plan = ds.queryExecution.withCachedData var sum = plan.collect { case _: InMemoryRelation => 1 }.sum plan.transformAllExpressions { @@ -515,7 +515,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils /** * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ - private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { + private def verifyNumExchanges(df: classic.DataFrame, expected: Int): Unit = { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { df.collect() } @@ -1052,7 +1052,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils } test("Cache should respect the hint") { - def testHint(df: Dataset[_], expectedHint: JoinStrategyHint): Unit = { + def testHint(df: classic.Dataset[_], expectedHint: JoinStrategyHint): Unit = { val df2 = spark.range(2000).cache() df2.count() @@ -1097,7 +1097,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils } test("analyzes column statistics in cached query") { - def query(): DataFrame = { + def query(): classic.DataFrame = { spark.range(100) .selectExpr("id % 3 AS c0", "id % 5 AS c1", "2 AS c2") .groupBy("c0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 4602c1a9cc1a2..8ed9670217859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -462,7 +462,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi */ def generateTableData( inputTypes: Seq[AbstractDataType], - collationType: CollationType): DataFrame = { + collationType: CollationType): classic.DataFrame = { val tblName = collationType match { case Utf8Binary => "tbl" case Utf8Lcase => "tbl_lcase" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala index 5f0ae918524e9..e140d81ce89e5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala @@ -31,7 +31,7 @@ class DataFrameAsOfJoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper { - def prepareForAsOfJoin(): (DataFrame, DataFrame) = { + def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = { val schema1 = StructType( StructField("a", IntegerType, false) :: StructField("b", StringType, false) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index e2bdf1c732078..6e634af9c2fe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -251,7 +251,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } } - def createDFsWithSameFieldsName(): (DataFrame, DataFrame) = { + def createDFsWithSameFieldsName(): (classic.DataFrame, classic.DataFrame) = { val df1 = Seq( ("f1-1", "f2", null), ("f1-2", null, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 0e9b1c9d2104e..b60b10d68e86e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} +import org.apache.spark.sql.classic.{Dataset => DatasetImpl} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{col, count, explode, sum, year} import org.apache.spark.sql.internal.SQLConf @@ -247,9 +248,9 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { TestData(2, "personnel"), TestData(3, "develop"), TestData(4, "IT")).toDF() - val ds_id1 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG) + val ds_id1 = df.logicalPlan.getTagValue(DatasetImpl.DATASET_ID_TAG) df.show(0) - val ds_id2 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG) + val ds_id2 = df.logicalPlan.getTagValue(DatasetImpl.DATASET_ID_TAG) assert(ds_id1 === ds_id2) } @@ -268,27 +269,27 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { TestData(2, "personnel"), TestData(3, "develop"), TestData(4, "IT")).toDS() - var dsIdSetOpt = ds.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG) + var dsIdSetOpt = ds.logicalPlan.getTagValue(DatasetImpl.DATASET_ID_TAG) assert(dsIdSetOpt.get.size === 1) var col1DsId = -1L val col1 = ds.col("key") col1.expr.foreach { case a: AttributeReference => - col1DsId = a.metadata.getLong(Dataset.DATASET_ID_KEY) + col1DsId = a.metadata.getLong(DatasetImpl.DATASET_ID_KEY) assert(dsIdSetOpt.get.contains(col1DsId)) - assert(a.metadata.getLong(Dataset.COL_POS_KEY) === 0) + assert(a.metadata.getLong(DatasetImpl.COL_POS_KEY) === 0) } val df = ds.toDF() - dsIdSetOpt = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG) + dsIdSetOpt = df.logicalPlan.getTagValue(DatasetImpl.DATASET_ID_TAG) assert(dsIdSetOpt.get.size === 2) var col2DsId = -1L val col2 = df.col("key") 
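
Earlier in this diff, the ForeachBatchSink hunk replaces the eta-expansion pythonFunc.call _ with an explicit lambda that narrows the incoming frame via castToImpl before invoking the Python function, which suggests the callback now expects the classic DataFrame while foreachBatch hands back the public one. A stand-alone sketch of that parameter-type mismatch and the adapting lambda, with placeholder types rather than the real DataStreamWriter API:

object CallbackAdapterSketch {
  // Stand-ins for the public API DataFrame and the classic implementation.
  trait ApiFrame { def rows: Long }
  final class ClassicFrame(val rows: Long) extends ApiFrame {
    def executedPlanSummary: String = s"executed plan over $rows rows"
  }

  // The callback only understands the implementation type
  // (plays the role of PythonForeachBatchFunction.call).
  trait ImplCallback { def call(df: ClassicFrame, batchId: Long): Unit }

  // The writer API supplies the public type (plays the role of foreachBatch).
  def foreachBatch(f: (ApiFrame, Long) => Unit): Unit =
    f(new ClassicFrame(rows = 3L), 0L)

  // Narrowing helper in the spirit of ClassicConversions.castToImpl.
  private def castToImpl(df: ApiFrame): ClassicFrame = df.asInstanceOf[ClassicFrame]

  def main(args: Array[String]): Unit = {
    val callback: ImplCallback =
      (df: ClassicFrame, id: Long) => println(s"batch $id: ${df.executedPlanSummary}")
    // `callback.call _` has type (ClassicFrame, Long) => Unit, which does not
    // conform to (ApiFrame, Long) => Unit, so an explicit lambda narrows first.
    foreachBatch((df: ApiFrame, id: Long) => callback.call(castToImpl(df), id))
  }
}
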
col2.expr.foreach { case a: AttributeReference => - col2DsId = a.metadata.getLong(Dataset.DATASET_ID_KEY) - assert(dsIdSetOpt.get.contains(a.metadata.getLong(Dataset.DATASET_ID_KEY))) - assert(a.metadata.getLong(Dataset.COL_POS_KEY) === 0) + col2DsId = a.metadata.getLong(DatasetImpl.DATASET_ID_KEY) + assert(dsIdSetOpt.get.contains(a.metadata.getLong(DatasetImpl.DATASET_ID_KEY))) + assert(a.metadata.getLong(DatasetImpl.COL_POS_KEY) === 0) } assert(col1DsId !== col2DsId) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 6c1ca94a03079..6d118a7fd98e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -480,7 +480,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession df <- Seq(df1, df2) nullable <- Seq(true, false) } { - val dfWithDesiredNullability = new DataFrame(df.queryExecution, ExpressionEncoder( + val dfWithDesiredNullability = new classic.DataFrame(df.queryExecution, ExpressionEncoder( StructType(df.schema.fields.map(_.copy(nullable = nullable))))) // session window without dynamic gap val windowedProject = dfWithDesiredNullability diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 9c182be0f7dd6..332be4c7bbc6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -959,7 +959,7 @@ class DataFrameSetOperationsSuite extends QueryTest } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr") { - def nestedDf(depth: Int, numColsAtEachDepth: Int): DataFrame = { + def nestedDf(depth: Int, numColsAtEachDepth: Int): classic.DataFrame = { val initialNestedStructType = StructType( (0 to numColsAtEachDepth).map(i => StructField(s"nested${depth}Col$i", IntegerType, nullable = false)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f424233050510..9b8400f0e3a15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1352,7 +1352,7 @@ class DataFrameSuite extends QueryTest ) // error case: insert into an OneRowRelation - Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") + classic.Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") checkError( exception = intercept[AnalysisException] { insertion.write.insertInto("one_row") @@ -1470,7 +1470,7 @@ class DataFrameSuite extends QueryTest /** * Verifies that there is no Exchange between the Aggregations for `df` */ - private def verifyNonExchangingAgg(df: DataFrame) = { + private def verifyNonExchangingAgg(df: classic.DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { case agg: HashAggregateExec => @@ -1485,7 +1485,7 @@ class DataFrameSuite extends QueryTest /** * Verifies that there is an Exchange between the Aggregations for `df` */ - private def verifyExchangingAgg(df: DataFrame) = { + private def verifyExchangingAgg(df: classic.DataFrame) = { var atFirstAgg: Boolean = false 
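
The DataFrameSelfJoinSuite hunk above renames the implementation class on import ({Dataset => DatasetImpl}) so that references to companion-object constants such as DATASET_ID_TAG stay unambiguous next to the public Dataset already in scope. A small self-contained sketch of that import-rename idiom, with made-up packages and constants in place of the Spark ones:

object ImportRenameSketch {
  object api {
    class Dataset(val name: String)
  }

  object classic {
    class Dataset(name: String) extends api.Dataset(name)
    object Dataset {
      // Companion constant, playing the role of Dataset.DATASET_ID_TAG.
      val DATASET_ID_TAG: String = "__dataset_id"
    }
  }

  // Keep the public name free; refer to the implementation as DatasetImpl.
  import api.Dataset
  import classic.{Dataset => DatasetImpl}

  def describe(ds: Dataset): String =
    s"${ds.name} tagged with ${DatasetImpl.DATASET_ID_TAG}"

  def main(args: Array[String]): Unit =
    println(describe(new DatasetImpl("people")))
}
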
df.queryExecution.executedPlan.foreach { case agg: HashAggregateExec => @@ -1623,7 +1623,7 @@ class DataFrameSuite extends QueryTest val statsPlan = OutputListAwareConstraintsTestPlan(outputList = outputList) - val df = Dataset.ofRows(spark, statsPlan) + val df = classic.Dataset.ofRows(spark, statsPlan) // add some map-like operations which optimizer will optimize away, and make a divergence // for output between logical plan and optimized plan // logical plan @@ -1791,7 +1791,7 @@ class DataFrameSuite extends QueryTest } private def verifyNullabilityInFilterExec( - df: DataFrame, + df: classic.DataFrame, expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) @@ -2711,6 +2711,16 @@ class DataFrameSuite extends QueryTest val expected = getQueryResult(false).map(_.getTimestamp(0).toString).sorted assert(actual == expected) } + + test("SPARK-50962: Avoid StringIndexOutOfBoundsException in AttributeNameParser") { + checkError( + exception = intercept[AnalysisException] { + spark.emptyDataFrame.colRegex(".whatever") + }, + condition = "INVALID_ATTRIBUTE_NAME_SYNTAX", + parameters = Map("name" -> ".whatever") + ) + } } case class GroupByKey(a: Int, b: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index c52d428cd5dd4..a2a3f9d74fc05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -596,8 +596,9 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { df <- Seq(df1, df2) nullable <- Seq(true, false) } { - val dfWithDesiredNullability = new DataFrame(df.queryExecution, ExpressionEncoder( - StructType(df.schema.fields.map(_.copy(nullable = nullable))))) + val dfWithDesiredNullability = new classic.DataFrame( + df.queryExecution, + ExpressionEncoder(StructType(df.schema.fields.map(_.copy(nullable = nullable))))) // tumbling windows val windowedProject = dfWithDesiredNullability .select(window($"time", "10 seconds").as("window"), $"value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index d03288d7dbcdf..09a53edf9909e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{Literal, NonFoldableLiteral} import org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode} +import org.apache.spark.sql.classic.ExpressionColumnNode import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.{ExpressionColumnNode, SQLConf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.CalendarIntervalType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ca584d6b9ce88..403d2b697a9e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1461,7 +1461,7 @@ class DatasetSuite extends QueryTest checkAnswer(df.map(row => row)(ExpressionEncoder(df.schema)).select("b", "a"), Row(2, 1)) } - private def checkShowString[T](ds: Dataset[T], expected: String): Unit = { + private def checkShowString[T](ds: classic.Dataset[T], expected: String): Unit = { val numRows = expected.split("\n").length - 4 val actual = ds.showString(numRows, truncate = 20) @@ -2429,9 +2429,9 @@ class DatasetSuite extends QueryTest } test("SparkSession.active should be the same instance after dataset operations") { - val active = SparkSession.getActiveSession.get + val active = classic.SparkSession.getActiveSession.get val clone = active.cloneSession() - val ds = new Dataset(clone, spark.range(10).queryExecution.logical, Encoders.INT) + val ds = new classic.Dataset(clone, spark.range(10).queryExecution.logical, Encoders.INT) ds.queryExecution.analyzed @@ -2922,7 +2922,7 @@ object JavaData { /** Used to test importing dataset.spark.implicits._ */ object DatasetTransform { - def addOne(ds: Dataset[Int]): Dataset[Int] = { + def addOne(ds: classic.Dataset[Int]): Dataset[Int] = { import ds.sparkSession.implicits._ ds.map(_ + 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala index 401039b0f9c0f..d17456d33cfea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala @@ -129,10 +129,10 @@ class DeprecatedAPISuite extends QueryTest with SharedSparkSession { test("SQLContext.setActive/clearActive") { val sc = spark.sparkContext - val sqlContext = new SQLContext(sc) - SQLContext.setActive(sqlContext) + val sqlContext = new classic.SQLContext(sc) + classic.SQLContext.setActive(sqlContext) assert(SparkSession.getActiveSession === Some(spark)) - SQLContext.clearActive() + classic.SQLContext.clearActive() assert(SparkSession.getActiveSession === None) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 22f55819d1d4c..6d5456462d8d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -31,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ClassicConversions.ColumnConstructorExt +import org.apache.spark.sql.classic.ExpressionUtils.expression +import org.apache.spark.sql.classic.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.internal.ExpressionUtils.expression -import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType, VariantType} import org.apache.spark.util.ArrayImplicits._ @@ -425,7 +425,7 @@ 
object IntegratedUDFTestUtils extends SQLHelper { } sealed trait TestUDTF { - def apply(session: SparkSession, exprs: Column*): DataFrame = + def apply(session: classic.SparkSession, exprs: Column*): DataFrame = udtf.apply(session, exprs: _*) val name: String = getClass.getSimpleName.stripSuffix("$") @@ -1601,7 +1601,7 @@ object IntegratedUDFTestUtils extends SQLHelper { /** * Register UDFs used in this test case. */ - def registerTestUDF(testUDF: TestUDF, session: SparkSession): Unit = testUDF match { + def registerTestUDF(testUDF: TestUDF, session: classic.SparkSession): Unit = testUDF match { case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf) @@ -1615,7 +1615,7 @@ object IntegratedUDFTestUtils extends SQLHelper { * Register UDTFs used in the test cases. */ case class TestUDTFSet(udtfs: Seq[TestUDTF]) - def registerTestUDTFs(testUDTFSet: TestUDTFSet, session: SparkSession): Unit = { + def registerTestUDTFs(testUDTFSet: TestUDTFSet, session: classic.SparkSession): Unit = { testUDTFSet.udtfs.foreach { _ match { case udtf: TestUDTF => session.udtf.registerPython(udtf.name, udtf.udtf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0f5582def82da..41f2e5c9a406e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -46,7 +46,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan setupTestData() - def statisticSizeInByte(df: DataFrame): BigInt = { + def statisticSizeInByte(df: classic.DataFrame): BigInt = { df.queryExecution.optimizedPlan.stats.sizeInBytes } @@ -1768,7 +1768,7 @@ class ThreadLeakInSortMergeJoinSuite setupTestData() override protected def createSparkSession: TestSparkSession = { - SparkSession.cleanupAnyExistingSession() + classic.SparkSession.cleanupAnyExistingSession() new TestSparkSession( sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b59c83c23d3c3..7ae29a6a17126 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -27,6 +27,7 @@ import org.scalatest.Assertions import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.util.QueryExecutionListener diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index ce3ac9b8834bf..8c9e3eae49816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.internal.config import org.apache.spark.internal.config.DEFAULT_PARALLELISM -import 
org.apache.spark.sql.internal.{RuntimeConfigImpl, SQLConf} +import org.apache.spark.sql.classic.RuntimeConfig +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.CHECKPOINT_LOCATION import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE class RuntimeConfigSuite extends SparkFunSuite { - private def newConf(): RuntimeConfig = new RuntimeConfigImpl() + private def newConf(): RuntimeConfig = new RuntimeConfig() test("set and get") { val conf = newConf() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 003f5bc835d5f..540ca2b1ec887 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -33,11 +33,11 @@ class SessionStateSuite extends SparkFunSuite { * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared * with all Hive test suites. */ - protected var activeSession: SparkSession = _ + protected var activeSession: classic.SparkSession = _ override def beforeAll(): Unit = { super.beforeAll() - activeSession = SparkSession.builder() + activeSession = classic.SparkSession.builder() .master("local") .config("default-config", "default") .getOrCreate() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index a30b13df74ae8..bbc396c879de1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -61,7 +61,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { } (1 to 10).foreach { _ => - spark.cloneSession() + spark.asInstanceOf[classic.SparkSession].cloneSession() SparkSession.clearActiveSession() } @@ -330,8 +330,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { .set(wh, "./data1") .set(td, "bob") - val sc = new SparkContext(conf) - + // This creates an active SparkContext, which will be picked up by the session below. 
+ new SparkContext(conf) val spark = SparkSession.builder() .config(wh, "./data2") .config(td, "alice") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 986d547b798e8..920a0872ee4f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AggregateHint, Co import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index 89500fe51f3ac..5ba69c8f9d929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -44,10 +44,10 @@ class SparkSessionJobTaggingAndCancellationSuite override def afterEach(): Unit = { try { // This suite should not interfere with the other test suites. - SparkSession.getActiveSession.foreach(_.stop()) - SparkSession.clearActiveSession() - SparkSession.getDefaultSession.foreach(_.stop()) - SparkSession.clearDefaultSession() + classic.SparkSession.getActiveSession.foreach(_.stop()) + classic.SparkSession.clearActiveSession() + classic.SparkSession.getDefaultSession.foreach(_.stop()) + classic.SparkSession.clearDefaultSession() resetSparkContext() } finally { super.afterEach() @@ -55,7 +55,7 @@ class SparkSessionJobTaggingAndCancellationSuite } test("Tags are not inherited by new sessions") { - val session = SparkSession.builder().master("local").getOrCreate() + val session = classic.SparkSession.builder().master("local").getOrCreate() assert(session.getTags() == Set()) session.addTag("one") @@ -66,7 +66,7 @@ class SparkSessionJobTaggingAndCancellationSuite } test("Tags are inherited by cloned sessions") { - val session = SparkSession.builder().master("local").getOrCreate() + val session = classic.SparkSession.builder().master("local").getOrCreate() assert(session.getTags() == Set()) session.addTag("one") @@ -83,7 +83,7 @@ class SparkSessionJobTaggingAndCancellationSuite test("Tags set from session are prefixed with session UUID") { sc = new SparkContext("local[2]", "test") - val session = SparkSession.builder().sparkContext(sc).getOrCreate() + val session = classic.SparkSession.builder().sparkContext(sc).getOrCreate() import session.implicits._ val sem = new Semaphore(0) @@ -116,9 +116,10 @@ class SparkSessionJobTaggingAndCancellationSuite // TODO(SPARK-50205): Re-enable this test case. 
ignore("Cancellation APIs in SparkSession are isolated") { sc = new SparkContext("local[2]", "test") - val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate() - var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) = - (null, null, null) + val globalSession = classic.SparkSession.builder().sparkContext(sc).getOrCreate() + var sessionA: classic.SparkSession = null + var sessionB: classic.SparkSession = null + var sessionC: classic.SparkSession = null var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null) // global ExecutionContext has only 2 threads in Apache Spark CI diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 8d7ada15381bf..5222d5ce26658 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -919,7 +919,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared size = Some(expectedSize)) withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { - val df = Dataset.ofRows(spark, statsPlan) + val df = classic.Dataset.ofRows(spark, statsPlan) // add some map-like operations which optimizer will optimize away, and make a divergence // for output between logical plan and optimized plan // logical plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 18af2fcb0ee73..d736e9494bd36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -1109,7 +1109,8 @@ class UDFSuite extends QueryTest with SharedSparkSession { spark.udf.register("dummyUDF", (x: Int) => x + 1) val expressionInfo = spark.sessionState.catalog .lookupFunctionInfo(FunctionIdentifier("dummyUDF")) - assert(expressionInfo.getClassName.contains("org.apache.spark.sql.UDFRegistration$$Lambda")) + assert(expressionInfo.getClassName.contains( + "org.apache.spark.sql.classic.UDFRegistration$$Lambda")) } test("SPARK-11725: correctly handle null inputs for ScalaUDF") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala index a24982aea1585..8c03ea3ae3c25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala @@ -24,7 +24,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.metrics.source.CodegenMetrics -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/ColumnNodeToExpressionConverterSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/classic/ColumnNodeToExpressionConverterSuite.scala index d72e86450de22..9aef714a0b35a 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/ColumnNodeToExpressionConverterSuite.scala @@ -14,10 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.internal +package org.apache.spark.sql.classic import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.{analysis, expressions, InternalRow} import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ @@ -28,6 +27,7 @@ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator} +import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, Literal, SortOrder, SQLConf, SqlExpression, TypedSumLong, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame, WindowSpec} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -401,9 +401,9 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { } } -private[internal] case class Nope(override val origin: Origin = CurrentOrigin.get) +private[classic] case class Nope(override val origin: Origin = CurrentOrigin.get) extends ColumnNode { - override private[internal] def normalize(): Nope = this + override private[sql] def normalize(): Nope = this override def sql: String = "nope" - override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty + override def children: Seq[ColumnNodeLike] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SQLContextSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/classic/SQLContextSuite.scala index ea0d405d2a8f7..7fe2eca4c267a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SQLContextSuite.scala @@ -15,16 +15,17 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.classic import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -@deprecated("This suite is deprecated to silent compiler deprecation warnings", "2.0.0") +@deprecated("This suite is deprecated to silence compiler deprecation warnings", "2.0.0") class SQLContextSuite extends SparkFunSuite with SharedSparkContext { object DummyRule extends Rule[LogicalPlan] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBuilderImplementationBindingSuite.scala similarity index 89% rename from sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBuilderImplementationBindingSuite.scala index c4fd16ca5ce59..5bfc60bef2bfd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBuilderImplementationBindingSuite.scala @@ -14,8 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.classic +import org.apache.spark.sql import org.apache.spark.sql.test.SharedSparkSession /** @@ -23,4 +24,4 @@ import org.apache.spark.sql.test.SharedSparkSession */ class SparkSessionBuilderImplementationBindingSuite extends SharedSparkSession - with api.SparkSessionBuilderImplementationBindingSuite + with sql.SparkSessionBuilderImplementationBindingSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index 8aa8fb21f4ae3..524c0e138721d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row +import org.apache.spark.sql.classic.MergeIntoWriter import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.MergeIntoWriterImpl class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { @@ -960,9 +960,9 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { .insertAll() .whenNotMatchedBySource(col("col") === 1) .delete() - .asInstanceOf[MergeIntoWriterImpl[Row]] + .asInstanceOf[MergeIntoWriter[Row]] val writer2 = writer1.withSchemaEvolution() - .asInstanceOf[MergeIntoWriterImpl[Row]] + .asInstanceOf[MergeIntoWriter[Row]] assert(writer1 eq writer2) assert(writer1.matchedActions.length === 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index 6a3d6054301e9..95301adb9b686 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connector.catalog.{Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.datasources.DataSource diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala index a5f0285bf2eff..49997b5b0c18a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -65,7 +65,8 @@ object V2FunctionBenchmark extends SqlBasedBenchmark { N: Long, codegenEnabled: Boolean, resultNullable: Boolean): Unit = { - import spark.toRichColumn + val classicSession = castToImpl(spark) + import classicSession.toRichColumn withSQLConf(s"spark.sql.catalog.$catalogName" -> classOf[InMemoryCatalog].getName) { createFunction("java_long_add_default", new JavaLongAdd(new JavaLongAddDefault(resultNullable))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index fde5a32e722f4..753f612238e94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.errors import org.apache.spark._ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} -import org.apache.spark.sql.{QueryTest, SparkSession} +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime} import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala index c43149133d65d..60a74a553bc45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.broadcast.TorrentBroadcast import org.apache.spark.scheduler._ 
-import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.HashedRelation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index d670b3d8c77d3..c461f41c9104c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -20,7 +20,7 @@ import scala.collection.mutable import scala.io.Source import scala.util.Try -import org.apache.spark.sql.{AnalysisException, Dataset, ExtendedExplainGenerator, FastOperator} +import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, FastOperator} import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CurrentNamespace, UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Alias, UnsafeRow} @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan, OneRowRelation, Project, ShowTables, SubqueryAlias} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 059a4c9b83763..8fb8b302e03b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.classic import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart import org.apache.spark.sql.types._ import org.apache.spark.util.ThreadUtils @@ -136,7 +137,7 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper { val executor1 = Executors.newSingleThreadExecutor() val executor2 = Executors.newSingleThreadExecutor() var session: SparkSession = null - SparkSession.cleanupAnyExistingSession() + classic.SparkSession.cleanupAnyExistingSession() withTempDir { tempDir => try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index f1fcf3bc5125e..95d7c4cd3caf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} import 
org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.test.SQLTestUtils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index c90b1d3ca5978..33e5d46ee2333 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -29,10 +29,11 @@ import org.apache.arrow.vector.ipc.JsonFileReader import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BinaryType, Decimal, IntegerType, NullType, StringType, StructField, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 4ea945d105e77..4f07d3d1c0300 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,10 +22,11 @@ import java.sql.{Date, Timestamp} import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution.{FilterExec, InputAdapter, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala index 2e3a4bbafb8f5..3e354751ae14a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala @@ -22,8 +22,9 @@ import java.util.zip.GZIPOutputStream import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite -import org.apache.spark.sql.{Dataset, Encoders, FakeFileSystemRequiringDSOption, SparkSession} +import org.apache.spark.sql.{Encoders, FakeFileSystemRequiringDSOption} import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.classic.{Dataset, SparkSession} /** * The trait contains tests for all file-based data sources.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala index 05872d411311a..d0dd8e03e58ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala @@ -21,9 +21,10 @@ import java.io.File import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, FileSourceConstantMetadataStructField, FileSourceGeneratedMetadataStructField, Literal} +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.functions.{col, lit, when} import org.apache.spark.sql.test.SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 76267ad4a8054..9676c42688920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -28,11 +28,12 @@ import org.apache.hadoop.mapreduce.Job import org.apache.spark.SparkException import org.apache.spark.internal.config import org.apache.spark.paths.SparkPath.{fromUrlString => sp} -import org.apache.spark.sql._ +import org.apache.spark.sql.{execution, DataFrame, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet} import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.functions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala index 140b30a9ef9da..adcc771fe33ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.noop -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest, Trigger} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala index 06ea12f83ce75..1ab63bd2f3110 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala @@ -116,7 +116,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession { val partitionValues = new GenericInternalRow(Array(v)) val file = new File(TestUtils.listDirectory(dir).head) val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty) - val taskConf = sqlContext.sessionState.newHadoopConf() + val taskConf = spark.sessionState.newHadoopConf() val orcFileSchema = TypeDescription.fromString(schema.simpleString) val vectorizedReader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 9fbc872ad262b..b96a61962a70f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql._ +import org.apache.spark.sql.{Column, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.expressions.{Attribute, Predicate} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.classic.ClassicConversions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 25414bfc299a7..20124bb82b7e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -23,13 +23,14 @@ import java.util.{List => JList, Map => JMap} import scala.jdk.CollectionConverters._ -import org.apache.avro.Schema +import org.apache.avro.{Schema, SchemaFormatter} import org.apache.avro.generic.IndexedRecord import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter import org.apache.parquet.hadoop.ParquetWriter import org.apache.spark.sql.Row +import org.apache.spark.sql.avro.AvroUtils import org.apache.spark.sql.execution.datasources.parquet.test.avro._ import org.apache.spark.sql.test.SharedSparkSession @@ -40,7 +41,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logInfo( s"""Writing Avro records with the following Avro schema into Parquet file: | - |${schema.toString(true)} + |${SchemaFormatter.format(AvroUtils.JSON_PRETTY_FORMAT, schema)} """.stripMargin) val writer = AvroParquetWriter.builder[T](new Path(path)).withSchema(schema).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index fb435e3639fde..458138a6213de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -24,7 +24,7 @@ import 
org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import org.apache.spark.{LocalSparkContext, SparkFunSuite} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.tags.ExtendedSQLTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 5f7a0c9e7e749..6080a5e8e4bb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath import org.apache.spark.sql.execution.ExplainMode import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelationWithTable, PushableColumnAndNestedColumn} @@ -2262,7 +2263,7 @@ class ParquetV1FilterSuite extends ParquetFilterSuite { SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key -> pushdownDsList) { val query = df - .select(output.map(Column(_)): _*) + .select(output.map(e => Column(e)): _*) .where(Column(predicate)) val nestedOrAttributes = predicate.collectFirst { @@ -2343,7 +2344,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName, SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df - .select(output.map(Column(_)): _*) + .select(output.map(e => Column(e)): _*) .where(Column(predicate)) query.queryExecution.optimizedPlan.collectFirst { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 851dceeb8ac88..3900d6eb97eb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -953,7 +953,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { test("createNamespace: basic behavior") { val catalog = newCatalog() - val sessionCatalog = sqlContext.sessionState.catalog + val sessionCatalog = spark.sessionState.catalog val expectedPath = new Path(spark.sessionState.conf.warehousePath, sessionCatalog.getDefaultDBPath(testNs(0)).toString).toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 5479be86e9f4c..69dd04e07d551 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -21,12 
+21,13 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.internal.config.EXECUTOR_MEMORY -import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, Cast, Expression, Literal, ShiftLeft} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} +import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.execution.{DummySparkPlan, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index b704790e4296b..b697c2cf4f8b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 5de106415ec68..fc01bf89993eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, EqualTo, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 7ba93ee13e182..6ec9c2e99b622 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, LessThan} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala index a318769af6871..6af5925c9ec46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.BuildRight import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, Project} +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala index d042a04a0e56f..3c2c75e5c7f9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -24,8 +24,8 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.should._ import org.scalatest.time.{Seconds, Span} -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index d039c72bb7d18..2ebc533f71375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import 
org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 947fccdfce72c..d108c309cb597 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.util.{CompletionIterator, Utils} @@ -168,7 +169,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { withSparkSession(SparkSession.builder().config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val coordinatorRef = sqlContext.streams.stateStoreCoordinator + val coordinatorRef = castToImpl(spark).streams.stateStoreCoordinator val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1", Seq.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index e63ff019a2b6c..45fce208c0783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -35,12 +35,12 @@ import org.apache.spark.internal.config.Status._ import org.apache.spark.rdd.RDD import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ -import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.classic.{DataFrame, SparkSession} import org.apache.spark.sql.connector.{CSVDataWriter, CSVDataWriterFactory, RangeInputPartition, SimpleScanBuilder, SimpleWritableDataSource, TestLocalScanTable} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 401b17d2b24a9..cd52b52a3e1f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -30,6 +30,7 @@ import 
org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.classic.Catalog import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper @@ -526,10 +527,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf functionFields(5)) == ("nama", "cataloa", "descripta", "classa", false)) assert(functionFields(2).asInstanceOf[Array[String]].sameElements(Array("databasa"))) assert(columnFields == Seq("nama", "descripta", "typa", false, true, true, true)) - val dbString = CatalogImpl.makeDataset(Seq(db), spark).showString(10) - val tableString = CatalogImpl.makeDataset(Seq(table), spark).showString(10) - val functionString = CatalogImpl.makeDataset(Seq(function), spark).showString(10) - val columnString = CatalogImpl.makeDataset(Seq(column), spark).showString(10) + val dbString = Catalog.makeDataset(Seq(db), spark).showString(10) + val tableString = Catalog.makeDataset(Seq(table), spark).showString(10) + val functionString = Catalog.makeDataset(Seq(function), spark).showString(10) + val columnString = Catalog.makeDataset(Seq(column), spark).showString(10) dbFields.foreach { f => assert(dbString.contains(f.toString)) } tableFields.foreach { f => assert(tableString.contains(f.toString) || tableString.contains(f.asInstanceOf[Array[String]].mkString(""))) } @@ -1117,7 +1118,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf } test("SPARK-46145: listTables does not throw exception when the table or view is not found") { - val impl = spark.catalog.asInstanceOf[CatalogImpl] + val impl = spark.catalog.asInstanceOf[Catalog] for ((isTemp, dbName) <- Seq((true, ""), (false, "non_existing_db"))) { val row = new GenericInternalRow( Array(UTF8String.fromString(dbName), UTF8String.fromString("non_existing_table"), isTemp)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala index 7bf70695a9854..8376433e4427d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.internal import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{functions, Dataset} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode} +import org.apache.spark.sql.functions import org.apache.spark.sql.types.{IntegerType, LongType, Metadata, MetadataBuilder, StringType} class ColumnNodeSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 8d7b7817b3a39..b51ecd55d12d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -23,10 +23,10 @@ import 
org.scalatest.Assertions._ import org.apache.spark.{SparkFunSuite, SparkNoSuchElementException, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution import org.apache.spark.sql.execution.debug.codegenStringSeq diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 2b58440baf852..c42599ceb7e12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -23,9 +23,10 @@ import org.apache.hadoop.fs.Path import org.apache.logging.log4j.Level import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkNoSuchElementException} -import org.apache.spark.sql._ +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.MIT +import org.apache.spark.sql.classic.{SparkSession, SQLContext} import org.apache.spark.sql.execution.datasources.parquet.ParquetCompressionCodec.{GZIP, LZO} import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.{SharedSparkSession, TestSQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 541b2975da1e5..bf9e091c5296d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -3107,13 +3107,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel sql(s"CREATE TABLE $tableName (binary_col BINARY)") sql(s"INSERT INTO $tableName VALUES ($binary)") - val select = s"SELECT * FROM $tableName WHERE binary_col = $binary" - val df = sql(select) - val filter = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filter.isEmpty, "Filter is not pushed") - assert(df.collect().length === 1, s"Binary literal test failed: $select") + val df = sql(s"SELECT * FROM $tableName WHERE binary_col = $binary") + checkFiltersRemoved(df) + checkPushedInfo(df, "PushedFilters: [binary_col IS NOT NULL, binary_col = 0x123456]") + checkAnswer(df, Row(Array(18, 52, 86))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 325c8ce380c63..2189a0a280ca3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LeafNode, 
OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.classic.{DataFrame, SparkSession} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index c7439a8934d73..601548a2e6bd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.scripting import org.apache.spark.{SparkConf, SparkException, SparkNumberFormatException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.CompoundBody +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala index b0c585ffda5e7..2dadbbe7cb237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfter -import org.apache.spark.sql._ +import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream} import org.apache.spark.sql.execution.streaming._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 214381f960300..b0967d5ffdf10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -33,11 +33,13 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, TaskContext, TestUtils} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} -import org.apache.spark.sql._ +import org.apache.spark.sql.{AnalysisException, Encoders, Row, SQLContext, TestStrategy} import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpression} import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.classic.{DataFrame, Dataset} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} import org.apache.spark.sql.execution.command.ExplainCommand 
import org.apache.spark.sql.execution.streaming._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 6d9731fa63b58..a6efc2d8fa9c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2ScanRelation import org.apache.spark.sql.execution.streaming._ @@ -355,7 +356,8 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with // and it may not work correctly when multiple `testStream`s run concurrently. val stream = _stream.toDF() - val sparkSession = stream.sparkSession // use the session in DF, not the default session + // use the session in DF, not the default session + val sparkSession = castToImpl(stream.sparkSession) var pos = 0 var currentStream: StreamExecution = null var lastStream: StreamExecution = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 0f382f4ed77de..9cecd16364759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -27,11 +27,12 @@ import org.scalatest.Assertions import org.apache.spark.{SparkEnv, SparkException, SparkUnsupportedOperationException} import org.apache.spark.rdd.BlockRDD -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 19ab272827441..d3c44dcead3ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -564,7 +564,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5, None) implicit val sqlContext = spark.sqlContext - val coordinatorRef = sqlContext.streams.stateStoreCoordinator + val coordinatorRef = spark.streams.stateStoreCoordinator val numPartitions = 5 val storeNames = Seq("name1", "name2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c12846d7512d9..ff23e00336a40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -36,11 +36,12 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkException, SparkUnsupportedOperationException, TestUtils} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Row, SaveMode} +import org.apache.spark.sql.{AnalysisException, Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LocalRelation} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit} import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeExec} @@ -1496,7 +1497,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi override def schema: StructType = triggerDF.schema override def getOffset: Option[Offset] = Some(LongOffset(0)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - sqlContext.internalCreateDataFrame( + sqlContext.sparkSession.internalCreateDataFrame( triggerDF.queryExecution.toRdd, triggerDF.schema, isStreaming = true) } override def stop(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala index 3de6273ffb7b5..659e2198f4377 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.streaming import java.io.File -import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.{ReadLimit, SupportsAdmissionControl} import org.apache.spark.sql.execution.streaming.{LongOffset, MemoryStream, MicroBatchExecution, MultiBatchExecutor, Offset, SerializedOffset, SingleBatchExecutor, Source, StreamingExecutionRelation, StreamingQueryWrapper} @@ -59,7 +59,7 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest { start.map(getOffsetValue).getOrElse(0L) + 1L, getOffsetValue(end) + 1L, 1, None, // Intentionally set isStreaming to false; we only use RDD plan in below. 
isStreaming = false) - sqlContext.internalCreateDataFrame( + sqlContext.sparkSession.internalCreateDataFrame( plan.queryExecution.toRdd, plan.schema, isStreaming = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index e74627f3f51e9..224dec72c79b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -30,6 +30,9 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.Dataset.ofRows import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -91,8 +94,8 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { override def getOffset: Option[Offset] = Some(new LongOffset(0)) - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - spark.internalCreateDataFrame(spark.sparkContext.emptyRDD, schema, isStreaming = true) + override def getBatch(start: Option[Offset], end: Offset): classic.DataFrame = { + ofRows(spark.sparkSession, LocalRelation(schema).copy(isStreaming = true)) } override def stop(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala index a720cc94ecd9d..daa8ca7bd9f51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala @@ -22,7 +22,9 @@ import java.util.concurrent.CountDownLatch import org.apache.zookeeper.KeeperException.UnimplementedException -import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{Row, SparkSession, SQLContext} +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability.CONTINUOUS_READ import org.apache.spark.sql.connector.read.{streaming, InputPartition, Scan, ScanBuilder} @@ -85,7 +87,7 @@ class BlockOnStopSource(spark: SparkSession, latch: CountDownLatch) extends Sour override val schema: StructType = BlockOnStopSourceProvider.schema override def getOffset: Option[Offset] = Some(LongOffset(0)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema) + spark.createDataFrame(util.Collections.emptyList[Row](), schema) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala index b083d180d9911..6865e5e0269b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming.util import 
java.util.concurrent.CountDownLatch
-import org.apache.spark.sql.{SQLContext, _}
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
 import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
 import org.apache.spark.sql.streaming.OutputMode
@@ -49,8 +49,7 @@ class BlockingSource extends StreamSourceProvider with StreamSinkProvider {
       override def schema: StructType = fakeSchema
       override def getOffset: Option[Offset] = Some(new LongOffset(0))
       override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-        import spark.sparkSession.implicits._
-        Seq[Int]().toDS().toDF()
+        spark.sparkSession.emptyDataFrame
       }
       override def stop(): Unit = {}
     }
@@ -64,5 +63,5 @@ class BlockingSource extends StreamSourceProvider with StreamSinkProvider {
 }
 
 object BlockingSource {
-  var latch: CountDownLatch = null
+  var latch: CountDownLatch = _
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 90432dea3a017..477da731b81b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
 import java.time.{Duration, Period}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SparkSession, SQLImplicits}
+import org.apache.spark.sql.classic.{DataFrame, SparkSession, SQLImplicits}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index c93f17701c620..3ceffc74adc28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql._
+import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
@@ -41,11 +41,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.classic.{ClassicConversions, ColumnConversions}
+import org.apache.spark.sql.classic.{ClassicConversions, ColumnConversions, ColumnNodeToExpressionConverter, DataFrame, Dataset, SparkSession, SQLImplicits}
 import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
-import org.apache.spark.sql.internal.{ColumnNodeToExpressionConverter, SQLConf}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.UninterruptibleThread
 import org.apache.spark.util.Utils
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 4d4cc44eb3e72..b8348cefe7c9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -24,9 +24,10 @@ import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.{DebugFilesystem, SparkConf}
 import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 
 trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 14cec000ab02a..91c6ac6f96ef3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.test
 
 import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf}
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 7e6f10bcc46f0..099a09d7784d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -22,9 +22,10 @@ import java.lang.{Long => JLong}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, Dataset, Encoder, Encoders, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.{functions, Encoder, Encoders, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
@@ -448,6 +449,6 @@ case class ErrorTestCommand(foo: String) extends LeafRunnableCommand {
 
   override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)())
 
-  override def run(sparkSession: SparkSession): Seq[Row] =
+  override def run(sparkSession: org.apache.spark.sql.SparkSession): Seq[Row] =
     throw new java.lang.Error(foo)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
index 56219766f7095..0ddb3614270ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark._
 import org.apache.spark.sql.{LocalSparkSession, SparkSession}
+import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf._
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
index 0801bffed8e52..7cc181ea6945a 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.COMMAND
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.CommandResult
+import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, SQLExecution}
 import org.apache.spark.sql.execution.HiveResult.hiveResultString
 import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
index 83f1824c26d2b..5abf034c1dea1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
@@ -30,7 +30,7 @@ import org.mockito.Mockito.{doReturn, mock, spy, when, RETURNS_DEEP_STUBS}
 import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
 import org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2EventManager
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala
index 842faba66cc30..acbc0cee0e301 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.sql.hive
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution.datasources.{DataSourceResolver, LogicalRelation}
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2d72443a8b661..a1a0044ea5282 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
 
 import org.apache.spark.SparkContext
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.classic.{SparkSession, SQLContext}
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 5947309b87983..bd35342b909f3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -27,11 +27,12 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 169d6b70cb50b..ff2605b0b3b92 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -24,13 +24,14 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.hive.ql.exec.{UDAF, UDF}
 import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
 
-import org.apache.spark.sql._
+import org.apache.spark.sql.{AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDataSource, ResolveSessionCatalog, ResolveTranspose}
 import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
 import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.SparkPlanner
 import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
index 478f486eeb213..a8e91dc1c1e85 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
@@ -38,8 +38,8 @@ import org.apache.spark.SparkConf
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys
-import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.classic.SQLContext
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.hive.client._
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
index f69eab7f837f7..91918fe62362b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
@@ -25,10 +25,11 @@ import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.hive.client.HiveClientImpl
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index cf296e8be4f14..b5d3fb699d62e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -23,13 +23,14 @@ import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.ql.plan.FileSinkDesc
 import org.apache.hadoop.hive.ql.plan.TableDesc
 
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.CommandUtils
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
index 47d402c2e8b1a..b921bf28fd91c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
@@ -20,10 +20,10 @@ package org.apache.spark.sql.hive.execution
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.internal.io.FileCommitProtocol
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 35c4b476e1125..2d72e0a7fd55f 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -60,7 +60,8 @@ public void setUp() throws IOException {
     if (path.exists()) {
       path.delete();
     }
-    HiveSessionCatalog catalog = (HiveSessionCatalog) sqlContext.sessionState().catalog();
+    HiveSessionCatalog catalog =
+      (HiveSessionCatalog) sqlContext.sparkSession().sessionState().catalog();
     hiveManagedPath = new Path(catalog.defaultTablePath(new TableIdentifier("javaSavedTable")));
     fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration());
     fs.delete(hiveManagedPath, true);
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
index f5bf49439d3f9..f04f153801e31 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
@@ -24,10 +24,10 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox
 
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.apache.spark.sql.classic.ClassicConversions._
+import org.apache.spark.sql.classic.ExpressionUtils.expression
 import org.apache.spark.sql.functions.{lit, percentile_approx => pa}
 import org.apache.spark.sql.hive.execution.TestingTypedCount
 import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.internal.ExpressionUtils.expression
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.LongType
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index d7918f8cbf4f0..97de3809eb045 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.hive
 
 import java.io.File
 
-import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
index 95baffdee06cb..684e4c53e4ef5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
@@ -34,9 +34,10 @@ import org.apache.spark.deploy.SparkSubmitTestUtils
 import org.apache.spark.internal.config.MASTER_REST_SERVER_ENABLED
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.launcher.JavaModuleOptions
-import org.apache.spark.sql.{QueryTest, Row, SparkSession}
+import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala
index 8c6113fb5569d..803ff5154b104 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.hive.common.FileUtils
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.internal.config.UI
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.util.Utils
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7c9b0b7781427..13cf2916c05a1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -22,8 +22,10 @@ import scala.util.Random
 import test.org.apache.spark.sql.MyDoubleAvg
 import test.org.apache.spark.sql.MyDoubleSum
 
-import org.apache.spark.sql._
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.classic.ClassicConversions.castToImpl
+import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 73dda42568a71..c4ccb07cb10b1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -27,10 +27,10 @@ import scala.util.control.NonFatal
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.execution.HiveResult.hiveResultString
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command._
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 5431066c30a9f..9dcc6abe20271 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -27,10 +27,11 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SparkFiles, TestUtils}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
 import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.hive.HiveUtils.{builtinHiveVersion => hiveVersion}
 import org.apache.spark.sql.hive.test.{HiveTestJars, TestHive}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 008a324f73dac..339b63c570901 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -22,8 +22,9 @@ import scala.util.Random
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
 import org.scalatest.matchers.must.Matchers._
 
-import org.apache.spark.sql._
+import org.apache.spark.sql.{Column, QueryTest, RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
+import org.apache.spark.sql.classic.DataFrame
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.functions.{col, count_distinct, first, lit, max, percentile_approx => pa}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 247a1c7096cb7..220d965d28602 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -32,12 +32,12 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession, SQLContext}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.execution.{CommandExecutionMode, QueryExecution, SQLExecution}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
index 7a0599cda2fe7..c8d72b78c64c4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.test
 
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.hive.HiveExternalCatalog
 import org.apache.spark.sql.hive.client.HiveClient
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 4060cc7172e60..a89ea2424696e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -21,12 +21,13 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.sql.{sources, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedProjection, JoinedRow, Literal, Predicate}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.SerializableConfiguration