From bdea09128760a5d36410cae2074f70a77e95c12b Mon Sep 17 00:00:00 2001
From: Daniel Tenedorio
Date: Fri, 20 Sep 2024 15:34:17 -0700
Subject: [PATCH] [SPARK-49557][SQL] Add SQL pipe syntax for the WHERE operator

### What changes were proposed in this pull request?

This PR adds SQL pipe syntax support for the WHERE operator. For example:

```
CREATE TABLE t(x INT, y STRING) USING CSV;
INSERT INTO t VALUES (0, 'abc'), (1, 'def');
CREATE TABLE other(a INT, b INT) USING JSON;
INSERT INTO other VALUES (1, 1), (1, 2), (2, 4);

TABLE t
|> WHERE x + LENGTH(y) < 4;

0 abc

TABLE t
|> WHERE (SELECT ANY_VALUE(a) FROM other WHERE x = a LIMIT 1) = 1;

1 def

TABLE t
|> WHERE SUM(x) = 1;

Error: aggregate functions are not allowed in the pipe operator |> WHERE clause
```

### Why are the changes needed?

The SQL pipe operator syntax will let users compose queries in a more flexible fashion.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds a few unit test cases, but mostly relies on golden file test coverage. I did
this to make sure the answers are correct as this feature is implemented, and also so we
can inspect the analyzer output plans to make sure they look right as well.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48091 from dtenedor/pipe-where.

Authored-by: Daniel Tenedorio
Signed-off-by: Gengliang Wang
---
 .../sql/catalyst/parser/SqlBaseParser.g4      |   1 +
 .../sql/catalyst/parser/AstBuilder.scala      |  15 +-
 .../analyzer-results/pipe-operators.sql.out   | 272 ++++++++++++++++++
 .../sql-tests/inputs/pipe-operators.sql       |  94 +++++-
 .../sql-tests/results/pipe-operators.sql.out  | 268 +++++++++++++++++
 .../sql/execution/SparkSqlParserSuite.scala   |  12 +-
 6 files changed, 658 insertions(+), 4 deletions(-)

diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index e591a43b84d1a..094f7f5315b80 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1492,6 +1492,7 @@ version
 operatorPipeRightSide
     : selectClause
+    | whereClause
     ;

 // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
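As context for the grammar change above: with the new `whereClause` alternative, each `|> WHERE` behaves as if the preceding pipe input were wrapped in a table subquery (this is the behavior the AstBuilder change below implements). A rough sketch of the intended equivalence, illustrative only and not part of this patch:

```
-- These two queries should produce the same result:
TABLE t
|> WHERE x + LENGTH(y) < 4;

SELECT * FROM (SELECT * FROM t) WHERE x + LENGTH(y) < 4;
```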
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 52529bb4b789b..674005caaf1b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -5876,7 +5876,20 @@ class AstBuilder extends DataTypeAstBuilder
         windowClause = null,
         relation = left,
         isPipeOperatorSelect = true)
-    }.get
+    }.getOrElse(Option(ctx.whereClause).map { c =>
+      // Add a table subquery boundary between the new filter and the input plan if one does not
+      // already exist. This helps the analyzer behave as if we had added the WHERE clause after a
+      // table subquery containing the input plan.
+      val withSubqueryAlias = left match {
+        case s: SubqueryAlias =>
+          s
+        case u: UnresolvedRelation =>
+          u
+        case _ =>
+          SubqueryAlias(SubqueryAlias.generateSubqueryName(), left)
+      }
+      withWhereClause(c, withSubqueryAlias)
+    }.get)
   }

   /**
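The subquery boundary added above is visible in the analyzer golden results that follow: a second consecutive `|> WHERE` filters over an auto-generated subquery alias rather than merging into the previous filter. For example (abridged from the golden file below):

```
table t
|> where x + length(y) < 4
|> where x + length(y) < 3;

-- Analyzed plan (abridged):
-- Filter ((x + length(y)) < 3)
-- +- SubqueryAlias __auto_generated_subquery_name
--    +- Filter ((x + length(y)) < 4)
--       +- SubqueryAlias spark_catalog.default.t
```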
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out
index ab0635fef048b..c44ce153a2f41 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out
@@ -255,6 +255,55 @@ Distinct
    +- Relation spark_catalog.default.t[x#x,y#x] csv


+-- !query
+table t
+|> select *
+-- !query analysis
+Project [x#x, y#x]
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> select * except (y)
+-- !query analysis
+Project [x#x]
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> select /*+ repartition(3) */ *
+-- !query analysis
+Repartition 3, true
++- Project [x#x, y#x]
+   +- SubqueryAlias spark_catalog.default.t
+      +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> select /*+ repartition(3) */ distinct x
+-- !query analysis
+Repartition 3, true
++- Distinct
+   +- Project [x#x]
+      +- SubqueryAlias spark_catalog.default.t
+         +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> select /*+ repartition(3) */ all x
+-- !query analysis
+Repartition 3, true
++- Project [x#x]
+   +- SubqueryAlias spark_catalog.default.t
+      +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
 -- !query
 table t
 |> select sum(x) as result
@@ -297,6 +346,229 @@ org.apache.spark.sql.AnalysisException
 }


+-- !query
+table t
+|> where true
+-- !query analysis
+Filter true
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> where x + length(y) < 4
+-- !query analysis
+Filter ((x#x + length(y#x)) < 4)
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> where x + length(y) < 4
+|> where x + length(y) < 3
+-- !query analysis
+Filter ((x#x + length(y#x)) < 3)
++- SubqueryAlias __auto_generated_subquery_name
+   +- Filter ((x#x + length(y#x)) < 4)
+      +- SubqueryAlias spark_catalog.default.t
+         +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+(select x, sum(length(y)) as sum_len from t group by x)
+|> where x = 1
+-- !query analysis
+Filter (x#x = 1)
++- SubqueryAlias __auto_generated_subquery_name
+   +- Aggregate [x#x], [x#x, sum(length(y#x)) AS sum_len#xL]
+      +- SubqueryAlias spark_catalog.default.t
+         +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> where t.x = 1
+-- !query analysis
+Filter (x#x = 1)
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> where spark_catalog.default.t.x = 1
+-- !query analysis
+Filter (x#x = 1)
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+(select col from st)
+|> where col.i1 = 1
+-- !query analysis
+Filter (col#x.i1 = 1)
++- SubqueryAlias __auto_generated_subquery_name
+   +- Project [col#x]
+      +- SubqueryAlias spark_catalog.default.st
+         +- Relation spark_catalog.default.st[x#x,col#x] parquet
+
+
+-- !query
+table st
+|> where st.col.i1 = 2
+-- !query analysis
+Filter (col#x.i1 = 2)
++- SubqueryAlias spark_catalog.default.st
+   +- Relation spark_catalog.default.st[x#x,col#x] parquet
+
+
+-- !query
+table t
+|> where exists (select a from other where x = a limit 1)
+-- !query analysis
+Filter exists#x [x#x]
+:  +- GlobalLimit 1
+:     +- LocalLimit 1
+:        +- Project [a#x]
+:           +- Filter (outer(x#x) = a#x)
+:              +- SubqueryAlias spark_catalog.default.other
+:                 +- Relation spark_catalog.default.other[a#x,b#x] json
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> where (select any_value(a) from other where x = a limit 1) = 1
+-- !query analysis
+Filter (scalar-subquery#x [x#x] = 1)
+:  +- GlobalLimit 1
+:     +- LocalLimit 1
+:        +- Aggregate [any_value(a#x, false) AS any_value(a)#x]
+:           +- Filter (outer(x#x) = a#x)
+:              +- SubqueryAlias spark_catalog.default.other
+:                 +- Relation spark_catalog.default.other[a#x,b#x] json
++- SubqueryAlias spark_catalog.default.t
+   +- Relation spark_catalog.default.t[x#x,y#x] csv
+
+
+-- !query
+table t
+|> where sum(x) = 1
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "INVALID_WHERE_CONDITION",
+  "sqlState" : "42903",
+  "messageParameters" : {
+    "condition" : "\"(sum(x) = 1)\"",
+    "expressionList" : "sum(spark_catalog.default.t.x)"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 1,
+    "stopIndex" : 27,
+    "fragment" : "table t\n|> where sum(x) = 1"
+  } ]
+}
+
+
+-- !query
+table t
+|> where y = 'abc' or length(y) + sum(x) = 1
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "INVALID_WHERE_CONDITION",
+  "sqlState" : "42903",
+  "messageParameters" : {
+    "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"",
+    "expressionList" : "sum(spark_catalog.default.t.x)"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 1,
+    "stopIndex" : 52,
+    "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1"
+  } ]
+}
+
+
+-- !query
+table t
+|> where first_value(x) over (partition by y) = 1
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "_LEGACY_ERROR_TEMP_1034",
+  "messageParameters" : {
+    "clauseName" : "WHERE"
+  }
+}
+
+
+-- !query
+select * from t where first_value(x) over (partition by y) = 1
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "_LEGACY_ERROR_TEMP_1034",
+  "messageParameters" : {
+    "clauseName" : "WHERE"
+  }
+}
+
+
+-- !query
+table t
+|> select x, length(y) as z
+|> where x + length(y) < 4
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+  "sqlState" : "42703",
+  "messageParameters" : {
+    "objectName" : "`y`",
+    "proposal" : "`x`, `z`"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 57,
+    "stopIndex" : 57,
+    "fragment" : "y"
+  } ]
+}
+
+
+-- !query
+(select x, sum(length(y)) as sum_len from t group by x)
+|> where sum(length(y)) = 3
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+  "sqlState" : "42703",
+  "messageParameters" : {
+    "objectName" : "`y`",
+    "proposal" : "`x`, `sum_len`"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 77,
+    "stopIndex" : 77,
+    "fragment" : "y"
+  } ]
+}
+
+
 -- !query
 drop table t
 -- !query analysis
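One takeaway from the plans above: a pipe `WHERE` applied to an aggregated table subquery plays the role that `HAVING` plays in regular SQL. The golden query and its conventional equivalent (the second form is a paraphrase for illustration, not part of the patch):

```
(select x, sum(length(y)) as sum_len from t group by x)
|> where x = 1;

-- Roughly equivalent conventional SQL:
select x, sum(length(y)) as sum_len from t group by x having x = 1;
```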
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql
index 7d0966e7f2095..49a72137ee047 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql
@@ -12,7 +12,7 @@ drop table if exists st;
 create table st(x int, col struct<i1:int, i2:int>) using parquet;
 insert into st values (1, (2, 3));

--- Selection operators: positive tests.
+-- SELECT operators: positive tests.
 ---------------------------------------

 -- Selecting a constant.
@@ -85,7 +85,24 @@ table t
 table t
 |> select distinct x, y;

--- Selection operators: negative tests.
+-- SELECT * is supported.
+table t
+|> select *;
+
+table t
+|> select * except (y);
+
+-- Hints are supported.
+table t
+|> select /*+ repartition(3) */ *;
+
+table t
+|> select /*+ repartition(3) */ distinct x;
+
+table t
+|> select /*+ repartition(3) */ all x;
+
+-- SELECT operators: negative tests.
 ---------------------------------------

 -- Aggregate functions are not allowed in the pipe operator SELECT list.
@@ -95,6 +112,79 @@ table t
 table t
 |> select y, length(y) + sum(x) as result;

+-- WHERE operators: positive tests.
+-----------------------------------
+
+-- Filtering with a constant predicate.
+table t
+|> where true;
+
+-- Filtering with a predicate based on attributes from the input relation.
+table t
+|> where x + length(y) < 4;
+
+-- Two consecutive filters are allowed.
+table t
+|> where x + length(y) < 4
+|> where x + length(y) < 3;
+
+-- It is possible to use the WHERE operator instead of the HAVING clause when processing the result
+-- of aggregations. For example, this WHERE operator is equivalent to the normal SQL "HAVING x = 1".
+(select x, sum(length(y)) as sum_len from t group by x)
+|> where x = 1;
+
+-- Filtering by referring to the table or table subquery alias.
+table t
+|> where t.x = 1;
+
+table t
+|> where spark_catalog.default.t.x = 1;
+
+-- Filtering using struct fields.
+(select col from st)
+|> where col.i1 = 1;
+
+table st
+|> where st.col.i1 = 2;
+
+-- Expression subqueries in the WHERE clause.
+table t
+|> where exists (select a from other where x = a limit 1);
+
+-- Aggregations are allowed within expression subqueries in the pipe operator WHERE clause, as long
+-- as no aggregate functions exist in the top-level expression predicate.
+table t
+|> where (select any_value(a) from other where x = a limit 1) = 1;
+
+-- WHERE operators: negative tests.
+-----------------------------------
+
+-- Aggregate functions are not allowed in the top-level WHERE predicate.
+-- (Note: to implement this behavior, perform the aggregation first separately and then add a
+-- pipe-operator WHERE clause referring to the result of the aggregate expression(s) therein.)
+table t
+|> where sum(x) = 1;
+
+table t
+|> where y = 'abc' or length(y) + sum(x) = 1;
+
+-- Window functions are not allowed in the WHERE clause (pipe operators or otherwise).
+table t
+|> where first_value(x) over (partition by y) = 1;
+
+select * from t where first_value(x) over (partition by y) = 1;
+
+-- Pipe operators may only refer to attributes produced as output from the directly-preceding
+-- pipe operator, not from earlier ones.
+table t
+|> select x, length(y) as z
+|> where x + length(y) < 4;
+
+-- If the WHERE clause wants to filter rows produced by an aggregation, it is not valid to try to
+-- refer to the aggregate functions directly; it is necessary to use aliases instead.
+(select x, sum(length(y)) as sum_len from t group by x)
+|> where sum(length(y)) = 3;
+
 -- Cleanup.
 -----------
 drop table t;
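To make the last negative test above concrete: since aggregate functions cannot be re-evaluated in the pipe `WHERE` predicate, the workaround described in the comment is to filter on the alias assigned inside the aggregating subquery. A sketch of the corrected query (not included in the patch, so its output is not verified by these golden files):

```
-- Instead of |> where sum(length(y)) = 3, refer to the alias:
(select x, sum(length(y)) as sum_len from t group by x)
|> where sum_len = 3;
```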
diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out
index 7e0b7912105c2..38436b0941034 100644
--- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out
@@ -238,6 +238,56 @@ struct<x:int,y:string>
 1 def


+-- !query
+table t
+|> select *
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+0 abc
+1 def
+
+
+-- !query
+table t
+|> select * except (y)
+-- !query schema
+struct<x:int>
+-- !query output
+0
+1
+
+
+-- !query
+table t
+|> select /*+ repartition(3) */ *
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+0 abc
+1 def
+
+
+-- !query
+table t
+|> select /*+ repartition(3) */ distinct x
+-- !query schema
+struct<x:int>
+-- !query output
+0
+1
+
+
+-- !query
+table t
+|> select /*+ repartition(3) */ all x
+-- !query schema
+struct<x:int>
+-- !query output
+0
+1
+
+
 -- !query
 table t
 |> select sum(x) as result
@@ -284,6 +334,224 @@ org.apache.spark.sql.AnalysisException
 }


+-- !query
+table t
+|> where true
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+0 abc
+1 def
+
+
+-- !query
+table t
+|> where x + length(y) < 4
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+0 abc
+
+
+-- !query
+table t
+|> where x + length(y) < 4
+|> where x + length(y) < 3
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+
+
+
+-- !query
+(select x, sum(length(y)) as sum_len from t group by x)
+|> where x = 1
+-- !query schema
+struct<x:int,sum_len:bigint>
+-- !query output
+1 3
+
+
+-- !query
+table t
+|> where t.x = 1
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+1 def
+
+
+-- !query
+table t
+|> where spark_catalog.default.t.x = 1
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+1 def
+
+
+-- !query
+(select col from st)
+|> where col.i1 = 1
+-- !query schema
+struct<col:struct<i1:int,i2:int>>
+-- !query output
+
+
+
+-- !query
+table st
+|> where st.col.i1 = 2
+-- !query schema
+struct<x:int,col:struct<i1:int,i2:int>>
+-- !query output
+1 {"i1":2,"i2":3}
+
+
+-- !query
+table t
+|> where exists (select a from other where x = a limit 1)
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+1 def
+
+
+-- !query
+table t
+|> where (select any_value(a) from other where x = a limit 1) = 1
+-- !query schema
+struct<x:int,y:string>
+-- !query output
+1 def
+
+
+-- !query
+table t
+|> where sum(x) = 1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "INVALID_WHERE_CONDITION",
+  "sqlState" : "42903",
+  "messageParameters" : {
+    "condition" : "\"(sum(x) = 1)\"",
+    "expressionList" : "sum(spark_catalog.default.t.x)"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 1,
+    "stopIndex" : 27,
+    "fragment" : "table t\n|> where sum(x) = 1"
+  } ]
+}
+
+
+-- !query
+table t
+|> where y = 'abc' or length(y) + sum(x) = 1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "INVALID_WHERE_CONDITION",
+  "sqlState" : "42903",
+  "messageParameters" : {
+    "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"",
+    "expressionList" : "sum(spark_catalog.default.t.x)"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 1,
+    "stopIndex" : 52,
+    "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1"
+  } ]
+}
+
+
+-- !query
+table t
+|> where first_value(x) over (partition by y) = 1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "_LEGACY_ERROR_TEMP_1034",
+  "messageParameters" : {
+    "clauseName" : "WHERE"
+  }
+}
+
+
+-- !query
+select * from t where first_value(x) over (partition by y) = 1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "_LEGACY_ERROR_TEMP_1034",
+  "messageParameters" : {
+    "clauseName" : "WHERE"
+  }
+}
+
+
+-- !query
+table t
+|> select x, length(y) as z
+|> where x + length(y) < 4
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+  "sqlState" : "42703",
+  "messageParameters" : {
+    "objectName" : "`y`",
+    "proposal" : "`x`, `z`"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 57,
+    "stopIndex" : 57,
+    "fragment" : "y"
+  } ]
+}
+
+
+-- !query
+(select x, sum(length(y)) as sum_len from t group by x)
+|> where sum(length(y)) = 3
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+  "sqlState" : "42703",
+  "messageParameters" : {
+    "objectName" : "`y`",
+    "proposal" : "`x`, `sum_len`"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 77,
+    "stopIndex" : 77,
+    "fragment" : "y"
+  } ]
+}
+
+
 -- !query
 drop table t
 -- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index a80444feb68ae..ab949c5a21e44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un
 import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION}
 import org.apache.spark.sql.connector.catalog.TableCatalog
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource}
@@ -895,6 +895,16 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession {
     checkPipeSelect("TABLE t |> SELECT 1 AS X")
     checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z")
     checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result")
+    // Basic WHERE operators.
+    def checkPipeWhere(query: String): Unit = {
+      val plan: LogicalPlan = parser.parsePlan(query)
+      assert(plan.containsPattern(FILTER))
+      assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION))
+    }
+    checkPipeWhere("TABLE t |> WHERE X = 1")
+    checkPipeWhere("TABLE t |> SELECT X, LENGTH(Y) AS Z |> WHERE X + LENGTH(Y) < 4")
+    checkPipeWhere("TABLE t |> WHERE X = 1 AND Y = 2 |> WHERE X + Y = 3")
+    checkPipeWhere("VALUES (0), (1) tab(col) |> WHERE col < 1")
   }
 }
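A final usage note on the chaining behavior tested above: each pipe operator only sees the columns output by the directly-preceding operator, so a `|> WHERE` after a `|> SELECT` must reference the new aliases. A hypothetical corrected form of the failing golden query, following the error message's own `x`, `z` suggestion (not part of the patch):

```
TABLE t
|> SELECT x, LENGTH(y) AS z
|> WHERE x + z < 4;  -- refer to the alias z rather than length(y)
```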