From 471cf23bf4b2c9f1e37a2df7e818636746f54447 Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Fri, 20 Dec 2024 11:26:14 +0800 Subject: [PATCH 1/3] [POC][DON'T MERGE] Transform ANTLR4 to SqlNode Signed-off-by: Heng Qian --- build.sbt | 4 + .../opensearch/flint/spark/ppl/Function.scala | 16 ++ .../flint/spark/ppl/PPLFunctionResolver.scala | 17 ++ .../flint/spark/ppl/PPLSyntaxParser.scala | 4 +- .../flint/spark/ppl/SqlNodeBuilder.scala | 191 ++++++++++++++++++ .../flint/spark/ppl/PPLSqlNodeTestSuite.scala | 96 +++++++++ 6 files changed, 326 insertions(+), 2 deletions(-) create mode 100644 ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/Function.scala create mode 100644 ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala create mode 100644 ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala diff --git a/build.sbt b/build.sbt index 365b88aa3..35f821e14 100644 --- a/build.sbt +++ b/build.sbt @@ -15,6 +15,7 @@ lazy val jacksonVersion = "2.15.2" lazy val opensearchVersion = "2.6.0" lazy val opensearchMavenVersion = "2.6.0.0" lazy val icebergVersion = "1.5.0" +lazy val calciteVersion = "1.37.0" val scalaMinorVersion = scala212.split("\\.").take(2).mkString(".") val sparkMinorVersion = sparkVersion.split("\\.").take(2).mkString(".") @@ -120,6 +121,7 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-core") exclude ("org.apache.httpcomponents.client5", "httpclient5"), "org.opensearch" % "opensearch-job-scheduler-spi" % opensearchMavenVersion, + "org.apache.calcite" % "calcite-core" % calciteVersion, "dev.failsafe" % "failsafe" % "3.3.2", "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided" exclude ("com.fasterxml.jackson.core", "jackson-databind"), @@ -193,6 +195,7 @@ lazy val pplSparkIntegration = (project in 
file("ppl-spark-integration")) "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", "com.github.sbt" % "junit-interface" % "0.13.3" % "test", "org.projectlombok" % "lombok" % "1.18.30", + "org.apache.calcite" % "calcite-core" % calciteVersion, "com.github.seancfoley" % "ipaddress" % "5.5.1", ), libraryDependencies ++= deps(sparkVersion), @@ -228,6 +231,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) libraryDependencies ++= Seq( "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided" exclude ("com.fasterxml.jackson.core", "jackson-databind"), + "org.apache.calcite" % "calcite-core" % calciteVersion, "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/Function.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/Function.scala new file mode 100644 index 000000000..35ad534a9 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/Function.scala @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.util + +import org.apache.calcite.sql.parser.SqlParserPos.ZERO +import org.apache.calcite.sql.{SqlCall, SqlLiteral, SqlNode, SqlOperator} + +case class Function(functionName: String, sqlOperator: SqlOperator) { + + def createCall(function: SqlNode, operands: util.List[SqlNode], qualifier: SqlLiteral): SqlCall = + sqlOperator.createCall(qualifier, ZERO, operands) +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala new file mode 100644 index 000000000..6682de9f2 --- /dev/null +++ 
b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import org.apache.calcite.sql.SqlOperator +import org.apache.calcite.sql.fun.SqlStdOperatorTable + +case class PPLFunctionResolver() { + def resolve(name: String): SqlOperator = { + name match { + case "=" => SqlStdOperatorTable.EQUALS + case "avg" => SqlStdOperatorTable.AVG + } + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala index ed498e98b..f21b02d8c 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala @@ -4,11 +4,11 @@ */ package org.opensearch.flint.spark.ppl -import org.antlr.v4.runtime.{CommonTokenStream, Lexer} import org.antlr.v4.runtime.tree.ParseTree +import org.antlr.v4.runtime.{CommonTokenStream, Lexer} import org.opensearch.sql.ast.statement.Statement import org.opensearch.sql.common.antlr.{CaseInsensitiveCharStream, Parser, SyntaxAnalysisErrorListener} -import org.opensearch.sql.ppl.parser.{AstBuilder, AstExpressionBuilder, AstStatementBuilder} +import org.opensearch.sql.ppl.parser.{AstBuilder, AstStatementBuilder} class PPLSyntaxParser extends Parser { // Analyze the query syntax diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala new file mode 100644 index 000000000..a4cf4f0bc --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: 
Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import scala.collection.JavaConverters._ + +import org.antlr.v4.runtime.CommonTokenStream +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException +import org.opensearch.sql.common.antlr.{CaseInsensitiveCharStream, SyntaxAnalysisErrorListener} + +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.sql.parser.SqlParserPos.ZERO +import org.apache.calcite.sql.{SqlBasicCall, SqlIdentifier, SqlLiteral, SqlNode, SqlNodeList, SqlSelect} + + +class PPLParser { + val astBuilder = new PPLAstBuilder() + + def parseQuery(query: String): SqlNode = parse(query) { parser => + val ctx = parser.root().pplStatement() + val a = astBuilder.visit(ctx) + a + } + + protected def parse[T](command: String)(toResult: OpenSearchPPLParser => T): T = { + val lexer = new OpenSearchPPLLexer(new CaseInsensitiveCharStream(command)) + // lexer.removeErrorListeners() + // lexer.addErrorListener(ParseErrorListener) + lexer.addErrorListener(new SyntaxAnalysisErrorListener()) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new OpenSearchPPLParser(tokenStream) + parser.addErrorListener(new SyntaxAnalysisErrorListener()) + // parser.addParseListener(PostProcessor) + // parser.addParseListener(UnclosedCommentProcessor(command, tokenStream)) + // parser.removeErrorListeners() + // parser.addErrorListener(ParseErrorListener) + /* + parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced + parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled + parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords + parser.double_quoted_identifiers = conf.doubleQuotedIdentifiers + */ + + // https://github.com/antlr/antlr4/issues/192#issuecomment-15238595 + // Save a great deal of time on correct inputs by using a two-stage parsing strategy. 
+ try { + try { + // first, try parsing with potentially faster SLL mode w/ SparkParserBailErrorStrategy + // parser.setErrorHandler(new SparkParserBailErrorStrategy()) + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + val a = toResult(parser) + a + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode w/ SparkParserErrorStrategy + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. + // parser.setErrorHandler(new SparkParserErrorStrategy()) + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + } +} + +class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { + val functionResolver = PPLFunctionResolver(); + + override def visitDmlStatement(ctx: OpenSearchPPLParser.DmlStatementContext): SqlNode = { + visit(ctx.queryStatement()) + } + + override def visitQueryStatement(ctx: OpenSearchPPLParser.QueryStatementContext): SqlNode = { + val source = visit(ctx.pplCommands()).asInstanceOf[SqlSelect] + val commands = ctx.commands().asScala.map(visit).map(_.asInstanceOf[SqlSelect]) + val result: SqlSelect = commands.foldLeft(source) {(pre: SqlNode, cur: SqlSelect) => + cur.setFrom(pre) + cur + } + result + } + + override def visitPplCommands(ctx: OpenSearchPPLParser.PplCommandsContext): SqlNode = { + val from = visit(ctx.searchCommand()) + new SqlSelect(ZERO, null, SqlNodeList.SINGLETON_STAR, from, null, null, null, null, null, null, null, null) + } + + override def visitFromClause(ctx: OpenSearchPPLParser.FromClauseContext): SqlNode = { + super.visitFromClause(ctx) + } + + override def visitTableOrSubqueryClause(ctx: OpenSearchPPLParser.TableOrSubqueryClauseContext): SqlNode = { + if (ctx.subSearch() != null) { + null + } else { + visitTableSourceClause(ctx.tableSourceClause()); + } + } + + override def visitTableSourceClause(ctx: OpenSearchPPLParser.TableSourceClauseContext): SqlNode = { + var sqlNodes = Seq[SqlNode]() + for (i <- 0 until 
ctx.tableSource().size) { + sqlNodes :+= visitTableSource(ctx.tableSource(i)).asInstanceOf[SqlNode] + } + // val sqlNodes = ctx.tableSource().stream().map(a => visitTableSource(a)).collect(Collectors.toList) + if (ctx.alias == null) { + sqlNodes.head + // sqlNodes.get(0) + //} else new SqlBasicCall(SqlStdOperatorTable.AS, sqlNodes.toArray(new Array[SqlNode](0)), ZERO) + } else new SqlBasicCall(SqlStdOperatorTable.AS, sqlNodes.toArray, ZERO) + } + + override def visitIdentsAsTableQualifiedName(ctx: OpenSearchPPLParser.IdentsAsTableQualifiedNameContext): SqlNode = { + new SqlIdentifier(ctx.tableIdent().ident().getText, ZERO) + } + + override def visitWhereCommand(ctx: OpenSearchPPLParser.WhereCommandContext): SqlNode = { + val where = visitChildren(ctx) + new SqlSelect(ZERO, null, SqlNodeList.SINGLETON_STAR, null, where, null, null, null, null, null, null, null) + } + + override def visitComparsion(ctx: OpenSearchPPLParser.ComparsionContext): SqlNode = { + super.visitComparsion(ctx) + } + + override def visitCompareExpr(ctx: OpenSearchPPLParser.CompareExprContext): SqlNode = { + functionResolver.resolve(ctx.comparisonOperator.getText).createCall(null, ZERO, visit(ctx.left), visit(ctx.right)) + } + + override def visitIdentsAsQualifiedName(ctx: OpenSearchPPLParser.IdentsAsQualifiedNameContext): SqlNode = { + new SqlIdentifier(ctx.ident().asScala.map(_.getText).reduce((a, b) => a + "." 
+ b), ZERO) + } + + override def visitIdent(ctx: OpenSearchPPLParser.IdentContext): SqlNode = { + new SqlIdentifier(ctx.getText, ZERO) + } + + override def visitIntegerLiteral(ctx: OpenSearchPPLParser.IntegerLiteralContext): SqlNode = { + SqlLiteral.createExactNumeric(ctx.getText, ZERO) + } + + override def visitFieldsCommand(ctx: OpenSearchPPLParser.FieldsCommandContext): SqlNode = { + val selectExpr = visitFieldList(ctx.fieldList()) + new SqlSelect(ZERO, null, selectExpr, null, null, null, null, null, null, null, null, null) + } + + override def visitFieldList(ctx: OpenSearchPPLParser.FieldListContext): SqlNodeList = { + val fields = ctx.fieldExpression.asScala.map(visit) + SqlNodeList.of(ZERO, fields.asJava) + } + + override def visitSortCommand(ctx: OpenSearchPPLParser.SortCommandContext): SqlNode = { + val orderByList = visitSortbyClause(ctx.sortbyClause()) + new SqlSelect(ZERO, null, SqlNodeList.SINGLETON_STAR, null, null, null, null, null, orderByList, null, null, null) + } + + override def visitSortbyClause(ctx: OpenSearchPPLParser.SortbyClauseContext): SqlNodeList = { + val fields = ctx.sortField().asScala.map(visit) + SqlNodeList.of(ZERO, fields.asJava) + } + + override def visitStatsCommand(ctx: OpenSearchPPLParser.StatsCommandContext): SqlNode = { + val aggList = ctx.statsAggTerm.asScala.map(visit) + val groupByList = visitStatsByClause(ctx.statsByClause()) + new SqlSelect(ZERO, null, SqlNodeList.of(ZERO, (groupByList.getList.asScala ++ aggList).asJava), null, null, groupByList, null, null, null, null, null, null) + } + + override def visitStatsAggTerm(ctx: OpenSearchPPLParser.StatsAggTermContext): SqlNode = { + val agg = visit(ctx.statsFunction()) + + if (ctx.alias == null) agg else { + val alias = visit(ctx.alias) + new SqlBasicCall(SqlStdOperatorTable.AS, Seq(agg, alias).asJava.toArray(new Array[SqlNode](0)), ZERO) + } + } + + override def visitStatsFunctionCall(ctx: OpenSearchPPLParser.StatsFunctionCallContext): SqlNode = { + 
functionResolver.resolve(ctx.statsFunctionName.getText).createCall(null, ZERO, visit(ctx.valueExpression)) + } + + override def visitStatsByClause(ctx: OpenSearchPPLParser.StatsByClauseContext): SqlNodeList = { + SqlNodeList.of(ZERO, ctx.fieldList().fieldExpression().asScala.map(visit).asJava) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala new file mode 100644 index 000000000..968821df8 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.util +import java.util.Objects.requireNonNull + +import org.scalatest.matchers.should.Matchers + +import org.apache.calcite.adapter.java.AbstractQueryableTable +import org.apache.calcite.config.{CalciteConnectionConfig, Lex} +import org.apache.calcite.jdbc.{CalciteSchema, JavaTypeFactoryImpl} +import org.apache.calcite.linq4j.{Enumerable, Linq4j, QueryProvider, Queryable} +import org.apache.calcite.plan.RelOptCluster +import org.apache.calcite.plan.volcano.VolcanoPlanner +import org.apache.calcite.prepare.{CalciteCatalogReader, PlannerImpl} +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} +import org.apache.calcite.rel.rel2sql.RelToSqlConverter +import org.apache.calcite.rex.RexBuilder +import org.apache.calcite.schema.SchemaPlus +import org.apache.calcite.schema.impl.AbstractTable +import org.apache.calcite.sql.SqlDialect.DatabaseProduct +import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.sql.parser.SqlParser +import org.apache.calcite.sql2rel.SqlToRelConverter +import org.apache.calcite.tools.{FrameworkConfig, Frameworks, Programs} +import org.apache.spark.SparkFunSuite +import 
org.apache.spark.sql.catalyst.plans.PlanTest + +class PPLSqlNodeTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + val t: AbstractTable = new AbstractQueryableTable(classOf[Integer]) { + val enumerable: Enumerable[Integer] = Linq4j.asEnumerable(new util.ArrayList[Integer]()) + + override def asQueryable[E](queryProvider: QueryProvider, schema: SchemaPlus, tableName: String): Queryable[E] = enumerable.asQueryable.asInstanceOf[Queryable[E]] + + override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { + val builder: RelDataTypeFactory.Builder = typeFactory.builder + builder.add("a", SqlTypeName.INTEGER) + builder.add("b", SqlTypeName.INTEGER) + builder.add("c", SqlTypeName.INTEGER) + for (i <- 0 until 3) { + builder.add(s"c$i", SqlTypeName.INTEGER) + } + builder.build + } + } + + private def createCatalogReader = { + val defaultSchema = requireNonNull(config.getDefaultSchema, "defaultSchema") + val rootSchema = defaultSchema + new CalciteCatalogReader(CalciteSchema.from(rootSchema), CalciteSchema.from(defaultSchema).path(null), typeFactory, CalciteConnectionConfig.DEFAULT) + } + + val schema: SchemaPlus = Frameworks.createRootSchema(true) + schema.add("table", t) + val config: FrameworkConfig = Frameworks.newConfigBuilder + .parserConfig(SqlParser.config.withLex(Lex.MYSQL)) + .defaultSchema(schema) + .programs(Programs.ofRules(Programs.RULE_SET)) + .build + val typeFactory = new JavaTypeFactoryImpl(config.getTypeSystem) + val pplParser = new PPLParser() + val planner = Frameworks.getPlanner(config) + val cluster: RelOptCluster = RelOptCluster.create(requireNonNull(new VolcanoPlanner(config.getCostFactory, config.getContext), "planner"), new RexBuilder(typeFactory)) + val sqlToRelConverter = new SqlToRelConverter(planner.asInstanceOf[PlannerImpl], null, createCatalogReader, cluster, config.getConvertletTable, config.getSqlToRelConverterConfig) + val relToSqlConverter = new 
RelToSqlConverter(DatabaseProduct.CALCITE.getDialect) + val pplParserOld = new PPLSyntaxParser() + + test("test") { + val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | fields c, avg_b") + val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) + + val sqlNode2 = planner.parse(sqlNode.toString()) + planner.validate(sqlNode2) + val relNode2 = planner.rel(sqlNode2) + val sqlNode3 = relToSqlConverter.visitRoot(relNode.rel).asStatement() + + // val relNode = planner.rel(sqlNode) + // val osPlan = plan(pplParserOld, "source=t") + //scalastyle:off + println(sqlNode) + println(relNode2) + println(sqlNode3) + // println(osPlan) + //scalastyle:on + } + +} From 7fb4493c55ede2c4d39d43f09b25a99856419c06 Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Wed, 25 Dec 2024 16:26:08 +0800 Subject: [PATCH 2/3] [POC] Support Eval Signed-off-by: Heng Qian --- .../flint/spark/ppl/MyValidator.scala | 34 ++++++++++ .../flint/spark/ppl/PPLFunctionResolver.scala | 2 + .../flint/spark/ppl/SqlIdentifierExtend.scala | 27 ++++++++ .../flint/spark/ppl/SqlNodeBuilder.scala | 21 +++++++ .../flint/spark/ppl/PPLSqlNodeTestSuite.scala | 62 +++++++++++++++---- 5 files changed, 134 insertions(+), 12 deletions(-) create mode 100644 ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala create mode 100644 ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala new file mode 100644 index 000000000..f48be2ea7 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import 
scala.collection.JavaConverters._ + +import org.apache.calcite.adapter.java.JavaTypeFactory +import org.apache.calcite.prepare.CalciteCatalogReader +import org.apache.calcite.sql.parser.SqlParserPos.ZERO +import org.apache.calcite.sql.validate.SqlValidator.Config +import org.apache.calcite.sql.validate.SqlValidatorImpl +import org.apache.calcite.sql.{SqlNodeList, SqlOperatorTable, SqlSelect} + +class MyValidator(opTab: SqlOperatorTable, catalogReader: CalciteCatalogReader, typeFactory: JavaTypeFactory, config: Config) + extends SqlValidatorImpl(opTab, catalogReader, typeFactory, config) { + + override def expandStar(selectList: SqlNodeList, select: SqlSelect, includeSystemVars: Boolean): SqlNodeList = { + val (starExcepts, others) = selectList.asScala.partition(_.isInstanceOf[StarExcept]) + val starExceptList = starExcepts.flatMap(starExcept => { + val originList = super.expandStar(SqlNodeList.of(starExcept), select, includeSystemVars) + val exceptList = super.expandStar(starExcept.asInstanceOf[StarExcept].exceptList, select, includeSystemVars) + val exceptListStr = exceptList.asScala.map(_.toString) + originList.removeIf(item => exceptListStr.contains(item.toString)) + originList.asScala + }) + val otherList = super.expandStar(SqlNodeList.of(ZERO, others.asJava), select, includeSystemVars) + val expandedList = SqlNodeList.of(ZERO, (otherList.asScala ++ starExceptList).asJava) + getRawSelectScope(select).setExpandedSelectList(expandedList) + expandedList + } + +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala index 6682de9f2..1b421e9a2 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala @@ -11,6 +11,8 @@ case class PPLFunctionResolver() { def resolve(name: 
String): SqlOperator = { name match { case "=" => SqlStdOperatorTable.EQUALS + case "+" => SqlStdOperatorTable.PLUS + case "-" => SqlStdOperatorTable.MINUS case "avg" => SqlStdOperatorTable.AVG } } diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala new file mode 100644 index 000000000..3dc092020 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + + +import com.google.common.collect.ImmutableList +import lombok.Getter + +import org.apache.calcite.sql.parser.SqlParserPos +import org.apache.calcite.sql.{SqlIdentifier, SqlNodeList, SqlWriter} + +@Getter +case class StarExcept(exceptList: SqlNodeList)(pos: SqlParserPos) + extends SqlIdentifier(ImmutableList.of(""), pos) { + + override def toString: String = { + super.toString + " EXCEPT " + exceptList.toString + } + + override def unparse(writer: SqlWriter, leftPrec: Int, rightPrec: Int): Unit = { + super.unparse(writer, leftPrec, rightPrec) + writer.keyword("EXCEPT") + writer.list(SqlWriter.FrameTypeEnum.PARENTHESES, SqlWriter.COMMA, exceptList) + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala index a4cf4f0bc..990a6078f 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala @@ -188,4 +188,25 @@ class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { override def visitStatsByClause(ctx: OpenSearchPPLParser.StatsByClauseContext): SqlNodeList = { SqlNodeList.of(ZERO, 
ctx.fieldList().fieldExpression().asScala.map(visit).asJava) } + + override def visitEvalCommand(ctx: OpenSearchPPLParser.EvalCommandContext): SqlNode = { + val evalClause = ctx.evalClause().asScala + val (identList, fieldExprList) = evalClause.map(clause => { + val fieldExpr = visit(clause.fieldExpression().qualifiedName()) + val expr = visit(clause.expression()) + (new SqlBasicCall(SqlStdOperatorTable.AS, Seq(expr, fieldExpr).asJava, ZERO).asInstanceOf[SqlNode], fieldExpr) + }).unzip + identList.append(StarExcept(SqlNodeList.of(ZERO, fieldExprList.asJava))(ZERO)) + new SqlSelect(ZERO, null, SqlNodeList.of(ZERO, identList.asJava), null, null, null, null, null, null, null, null, null) + } + + override def visitBinaryArithmetic(ctx: OpenSearchPPLParser.BinaryArithmeticContext): SqlNode = { + functionResolver.resolve(ctx.binaryOperator.getText).createCall(null, ZERO, visit(ctx.left), visit(ctx.right)) + } + + + override def visitLookupCommand(ctx: OpenSearchPPLParser.LookupCommandContext): SqlNode = { + super.visitLookupCommand(ctx) + } + } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala index 968821df8..81dc82c4f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala @@ -25,6 +25,7 @@ import org.apache.calcite.schema.impl.AbstractTable import org.apache.calcite.sql.SqlDialect.DatabaseProduct import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.parser.SqlParser +import org.apache.calcite.sql.util.SqlOperatorTables import org.apache.calcite.sql2rel.SqlToRelConverter import org.apache.calcite.tools.{FrameworkConfig, Frameworks, Programs} import org.apache.spark.SparkFunSuite @@ -46,9 +47,6 @@ class PPLSqlNodeTestSuite builder.add("a", 
SqlTypeName.INTEGER) builder.add("b", SqlTypeName.INTEGER) builder.add("c", SqlTypeName.INTEGER) - for (i <- 0 until 3) { - builder.add(s"c$i", SqlTypeName.INTEGER) - } builder.build } } @@ -61,6 +59,7 @@ class PPLSqlNodeTestSuite val schema: SchemaPlus = Frameworks.createRootSchema(true) schema.add("table", t) + schema.add("table2", t) val config: FrameworkConfig = Frameworks.newConfigBuilder .parserConfig(SqlParser.config.withLex(Lex.MYSQL)) .defaultSchema(schema) @@ -70,26 +69,65 @@ class PPLSqlNodeTestSuite val pplParser = new PPLParser() val planner = Frameworks.getPlanner(config) val cluster: RelOptCluster = RelOptCluster.create(requireNonNull(new VolcanoPlanner(config.getCostFactory, config.getContext), "planner"), new RexBuilder(typeFactory)) - val sqlToRelConverter = new SqlToRelConverter(planner.asInstanceOf[PlannerImpl], null, createCatalogReader, cluster, config.getConvertletTable, config.getSqlToRelConverterConfig) + val catalogReader = createCatalogReader + val opTab = SqlOperatorTables.chain(config.getOperatorTable, catalogReader) + val validator = new MyValidator(opTab, catalogReader, typeFactory, config.getSqlValidatorConfig) + val sqlToRelConverter = new SqlToRelConverter(planner.asInstanceOf[PlannerImpl], validator, catalogReader, cluster, config.getConvertletTable, config.getSqlToRelConverterConfig) val relToSqlConverter = new RelToSqlConverter(DatabaseProduct.CALCITE.getDialect) val pplParserOld = new PPLSyntaxParser() - test("test") { - val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | fields c, avg_b") + test("test basic command") { + val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | fields c, avg_b") + val validatedSqlNode = validator.validate(sqlNode) val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) + val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() + //scalastyle:off + println(sqlNode) + 
println(validatedSqlNode) + println(relNode) + println(convertedSqlNode) + // println(osPlan) + //scalastyle:on + val sqlNode2 = planner.parse(sqlNode.toString()) - planner.validate(sqlNode2) - val relNode2 = planner.rel(sqlNode2) - val sqlNode3 = relToSqlConverter.visitRoot(relNode.rel).asStatement() + val validatedSqlNode2 = planner.validate(sqlNode2) + val relNode2 = planner.rel(validatedSqlNode2) + val convertedSqlNode2 = relToSqlConverter.visitRoot(relNode2.rel).asStatement() // val relNode = planner.rel(sqlNode) // val osPlan = plan(pplParserOld, "source=t") //scalastyle:off - println(sqlNode) + println(sqlNode2) + println(validatedSqlNode2) println(relNode2) - println(sqlNode3) - // println(osPlan) + println(convertedSqlNode2) + //scalastyle:on + } + + test("test eval") { + val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | eval avg_b = avg_b + 1 | fields c, avg_b") + val validatedSqlNode = validator.validate(sqlNode) + val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) + val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() + //scalastyle:off + println(sqlNode) + println(validatedSqlNode) + println(relNode) + println(convertedSqlNode) + //scalastyle:on + } + + test("test eval") { + val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | eval avg_b = avg_b + 1 | fields c, avg_b") + val validatedSqlNode = validator.validate(sqlNode) + val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) + val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() + //scalastyle:off + println(sqlNode) + println(validatedSqlNode) + println(relNode) + println(convertedSqlNode) //scalastyle:on } From 9c9142add7c5da7b8615ede1c6e637d36c6aa0a5 Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Tue, 31 Dec 2024 10:54:18 +0800 Subject: [PATCH 3/3] [POC] Support LookUp and EventStats Signed-off-by: Heng Qian --- 
.../flint/spark/ppl/MyValidator.scala | 25 +++++- .../flint/spark/ppl/PPLFunctionResolver.scala | 3 + .../flint/spark/ppl/SqlIdentifierExtend.scala | 6 +- .../flint/spark/ppl/SqlNodeBuilder.scala | 84 +++++++++++++++++-- .../flint/spark/ppl/PPLSqlNodeTestSuite.scala | 62 +++++++++++--- 5 files changed, 159 insertions(+), 21 deletions(-) diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala index f48be2ea7..7b5beb3f8 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/MyValidator.scala @@ -8,6 +8,7 @@ import scala.collection.JavaConverters._ import org.apache.calcite.adapter.java.JavaTypeFactory import org.apache.calcite.prepare.CalciteCatalogReader +import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.sql.parser.SqlParserPos.ZERO import org.apache.calcite.sql.validate.SqlValidator.Config import org.apache.calcite.sql.validate.SqlValidatorImpl @@ -22,8 +23,7 @@ class MyValidator(opTab: SqlOperatorTable, catalogReader: CalciteCatalogReader, val originList = super.expandStar(SqlNodeList.of(starExcept), select, includeSystemVars) val exceptList = super.expandStar(starExcept.asInstanceOf[StarExcept].exceptList, select, includeSystemVars) val exceptListStr = exceptList.asScala.map(_.toString) - originList.removeIf(item => exceptListStr.contains(item.toString)) - originList.asScala + originList.asScala.filter(item => !exceptListStr.contains(item.toString)) }) val otherList = super.expandStar(SqlNodeList.of(ZERO, others.asJava), select, includeSystemVars) val expandedList = SqlNodeList.of(ZERO, (otherList.asScala ++ starExceptList).asJava) @@ -31,4 +31,25 @@ class MyValidator(opTab: SqlOperatorTable, catalogReader: CalciteCatalogReader, expandedList } + override def validateSelectList (selectItems: 
SqlNodeList, select: SqlSelect, targetRowType: RelDataType): RelDataType = { + val (starExcepts, others) = selectItems.asScala.partition(_.isInstanceOf[StarExcept]) + val (starExceptList, starExceptExpandedList) = starExcepts.map(starExcept => { + val originList = super.validateSelectList(SqlNodeList.of(starExcept), select, targetRowType) + val originExpandedList = getRawSelectScope(select).getExpandedSelectList + val exceptList = super.validateSelectList(starExcept.asInstanceOf[StarExcept].exceptList, select, targetRowType) + val exceptListStr = exceptList.getFieldNames.asScala + val exceptExpandedList = getRawSelectScope(select).getExpandedSelectList + val exceptExpandedListStr = exceptExpandedList.asScala.map(_.toString) + (originList.getFieldList.asScala.filter(field => !exceptListStr.contains(field.getName)), + originExpandedList.asScala.filter(item => !exceptExpandedListStr.contains(item.toString))) + }).unzip + val otherList = super.validateSelectList(SqlNodeList.of(ZERO, others.asJava), select, targetRowType) + val newSelectItems = (getRawSelectScope(select).getExpandedSelectList.asScala ++ starExceptExpandedList.flatten).asJava + if (config.identifierExpansion) { + select.setSelectList(SqlNodeList.of(ZERO, newSelectItems)) + } + getRawSelectScope(select).setExpandedSelectList(newSelectItems) + typeFactory.createStructType((otherList.getFieldList.asScala ++ starExceptList.flatten).asJava) + } + } diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala index 1b421e9a2..ab9e34332 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLFunctionResolver.scala @@ -14,6 +14,9 @@ case class PPLFunctionResolver() { case "+" => SqlStdOperatorTable.PLUS case "-" => SqlStdOperatorTable.MINUS case 
"avg" => SqlStdOperatorTable.AVG + case "min" => SqlStdOperatorTable.MIN + case "max" => SqlStdOperatorTable.MAX + case "count" => SqlStdOperatorTable.COUNT } } } diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala index 3dc092020..46a3d1818 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlIdentifierExtend.scala @@ -7,13 +7,15 @@ package org.opensearch.flint.spark.ppl import com.google.common.collect.ImmutableList import lombok.Getter +import scala.collection.JavaConverters._ import org.apache.calcite.sql.parser.SqlParserPos +import org.apache.calcite.sql.parser.SqlParserPos.ZERO import org.apache.calcite.sql.{SqlIdentifier, SqlNodeList, SqlWriter} @Getter -case class StarExcept(exceptList: SqlNodeList)(pos: SqlParserPos) - extends SqlIdentifier(ImmutableList.of(""), pos) { +case class StarExcept(exceptList: SqlNodeList)(names: Seq[String] = Seq(""), pos: SqlParserPos = ZERO) + extends SqlIdentifier(names.asJava, pos) { override def toString: String = { super.toString + " EXCEPT " + exceptList.toString diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala index 990a6078f..cce0d9934 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/SqlNodeBuilder.scala @@ -4,6 +4,7 @@ */ package org.opensearch.flint.spark.ppl +import scala.+: import scala.collection.JavaConverters._ import org.antlr.v4.runtime.CommonTokenStream @@ -13,7 +14,7 @@ import org.opensearch.sql.common.antlr.{CaseInsensitiveCharStream, SyntaxAnalysi import 
org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.sql.parser.SqlParserPos.ZERO -import org.apache.calcite.sql.{SqlBasicCall, SqlIdentifier, SqlLiteral, SqlNode, SqlNodeList, SqlSelect} +import org.apache.calcite.sql._ class PPLParser { @@ -72,6 +73,7 @@ class PPLParser { class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { val functionResolver = PPLFunctionResolver(); + var subquery_count = 0; override def visitDmlStatement(ctx: OpenSearchPPLParser.DmlStatementContext): SqlNode = { visit(ctx.queryStatement()) @@ -81,7 +83,13 @@ class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { val source = visit(ctx.pplCommands()).asInstanceOf[SqlSelect] val commands = ctx.commands().asScala.map(visit).map(_.asInstanceOf[SqlSelect]) val result: SqlSelect = commands.foldLeft(source) {(pre: SqlNode, cur: SqlSelect) => - cur.setFrom(pre) + cur.getFrom match { + case null => + cur.setFrom(pre) + case join: SqlJoin if join.getLeft.isInstanceOf[SqlBasicCall] => + join.getLeft.asInstanceOf[SqlBasicCall].setOperand(0, pre) + case _ => + } cur } result @@ -168,8 +176,18 @@ class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { override def visitStatsCommand(ctx: OpenSearchPPLParser.StatsCommandContext): SqlNode = { val aggList = ctx.statsAggTerm.asScala.map(visit) - val groupByList = visitStatsByClause(ctx.statsByClause()) - new SqlSelect(ZERO, null, SqlNodeList.of(ZERO, (groupByList.getList.asScala ++ aggList).asJava), null, null, groupByList, null, null, null, null, null, null) + val statsByList = visitStatsByClause(ctx.statsByClause()) + if (ctx.EVENTSTATS != null) { + val windowDecl = SqlWindow.create(null, null, statsByList, SqlNodeList.EMPTY, SqlLiteral.createBoolean(false, ZERO), null, null, null, ZERO) + val newAggList = aggList.map { case agg: SqlBasicCall => + if (agg.getKind == SqlKind.AS) { + val aggExpr = agg.getOperandList.get(0) + agg.setOperand(0, new SqlBasicCall(SqlStdOperatorTable.OVER, 
Seq(aggExpr, windowDecl).asJava.toArray(new Array[SqlNode](0)), ZERO)) + agg + } else new SqlBasicCall(SqlStdOperatorTable.OVER, Seq(agg, windowDecl).asJava.toArray(new Array[SqlNode](0)), ZERO) + } + new SqlSelect(ZERO, null, SqlNodeList.of(ZERO, ((SqlIdentifier.STAR +: newAggList).asJava)), null, null, null, null, null, null, null, null, null) + } else new SqlSelect(ZERO, null, SqlNodeList.of(ZERO, (statsByList.getList.asScala ++ aggList).asJava), null, null, statsByList, null, null, null, null, null, null) } override def visitStatsAggTerm(ctx: OpenSearchPPLParser.StatsAggTermContext): SqlNode = { @@ -196,7 +214,7 @@ class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { val expr = visit(clause.expression()) (new SqlBasicCall(SqlStdOperatorTable.AS, Seq(expr, fieldExpr).asJava, ZERO).asInstanceOf[SqlNode], fieldExpr) }).unzip - identList.append(StarExcept(SqlNodeList.of(ZERO, fieldExprList.asJava))(ZERO)) + identList.append(StarExcept(SqlNodeList.of(ZERO, fieldExprList.asJava))()) new SqlSelect(ZERO, null, SqlNodeList.of(ZERO, identList.asJava), null, null, null, null, null, null, null, null, null) } @@ -206,7 +224,59 @@ class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] { override def visitLookupCommand(ctx: OpenSearchPPLParser.LookupCommandContext): SqlNode = { - super.visitLookupCommand(ctx) + val right = visit(ctx.tableSource) + require(right.isInstanceOf[SqlIdentifier], "LOOKUP table is not an ident") + val rightTableIdent = right.asInstanceOf[SqlIdentifier] + val leftSubqueryName = s"TEMP_SUBQUERY_$subquery_count" + subquery_count += 1 + val leftTableIdent = new SqlIdentifier(leftSubqueryName, ZERO) + val left = new SqlBasicCall(SqlStdOperatorTable.AS, Seq(null, leftTableIdent).asJava, ZERO) + val conditionList = ctx.lookupMappingList.lookupPair.asScala.map { pair => { + val rightKey = visit(pair.inputField) + val leftKey = if (pair.outputField != null) { + visit(pair.outputField) + } else rightKey.clone(ZERO) + 
require(leftKey.isInstanceOf[SqlIdentifier], "left join key is not an ident") + require(rightKey.isInstanceOf[SqlIdentifier], "right join key is not an ident") + val leftKeyIdent = leftKey.asInstanceOf[SqlIdentifier] + val rightKeyIdent = rightKey.asInstanceOf[SqlIdentifier] + val newLeftKey = new SqlIdentifier((leftSubqueryName +: leftKeyIdent.names.asScala).asJava, ZERO) + val newRightKey = new SqlIdentifier((rightTableIdent.names.asScala ++ rightKeyIdent.names.asScala).asJava, ZERO) + SqlStdOperatorTable.EQUALS.createCall(null, ZERO, newLeftKey, newRightKey) + }} + val conditionCombine = conditionList.reduce((a, b) => SqlStdOperatorTable.AND.createCall(null, ZERO, a, b)) + val sqlJoin = new SqlJoin(ZERO, + left, + SqlLiteral.createBoolean(false, ZERO), + JoinType.LEFT.symbol(ZERO), + right, + JoinConditionType.ON.symbol(ZERO), + conditionCombine) + + val fieldsPair = ctx.outputCandidateList().lookupPair.asScala.map { pair => { + val rightField = visit(pair.inputField) + val leftField = if (pair.outputField != null) { + visit(pair.outputField) + } else rightField.clone(ZERO) + require(leftField.isInstanceOf[SqlIdentifier], "left output field is not an ident") + require(rightField.isInstanceOf[SqlIdentifier], "right output field is not an ident") + val leftKeyIdent = leftField.asInstanceOf[SqlIdentifier] + val rightKeyIdent = rightField.asInstanceOf[SqlIdentifier] + val newLeftField = new SqlIdentifier((leftSubqueryName +: leftKeyIdent.names.asScala).asJava, ZERO) + val newRightField = new SqlIdentifier((rightTableIdent.names.asScala ++ rightKeyIdent.names.asScala).asJava, ZERO) + (newLeftField.asInstanceOf[SqlNode], newRightField.asInstanceOf[SqlNode]) + }} + val leftFields = fieldsPair.map(_._1) + + val selectItems = if (ctx.APPEND != null) { + val newFields = fieldsPair.map { case (leftField, rightField) => new SqlBasicCall(SqlStdOperatorTable.AS, Seq(rightField, leftField).asJava, ZERO) } + SqlNodeList.of(ZERO, (newFields :+ StarExcept(SqlNodeList.of(ZERO, 
leftFields.asJava))(leftTableIdent.names.asScala :+ "")).asJava) + } else if (ctx.REPLACE != null) { + val newFields = fieldsPair.map { case (leftField, rightField) => new SqlBasicCall(SqlStdOperatorTable.COALESCE, Seq(leftField, rightField).asJava, ZERO) } + SqlNodeList.of(ZERO, (newFields :+ StarExcept(SqlNodeList.of(ZERO, leftFields.asJava))(leftTableIdent.names.asScala :+ "")).asJava) + } else { + SqlNodeList.SINGLETON_STAR + } + new SqlSelect(ZERO, null, selectItems, sqlJoin, null, null, null, null, null, null, null, null) } - } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala index 81dc82c4f..e6636850f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLSqlNodeTestSuite.scala @@ -37,7 +37,7 @@ class PPLSqlNodeTestSuite with LogicalPlanTestUtils with Matchers { - val t: AbstractTable = new AbstractQueryableTable(classOf[Integer]) { + val t1: AbstractTable = new AbstractQueryableTable(classOf[Integer]) { val enumerable: Enumerable[Integer] = Linq4j.asEnumerable(new util.ArrayList[Integer]()) override def asQueryable[E](queryProvider: QueryProvider, schema: SchemaPlus, tableName: String): Queryable[E] = enumerable.asQueryable.asInstanceOf[Queryable[E]] @@ -46,7 +46,21 @@ class PPLSqlNodeTestSuite val builder: RelDataTypeFactory.Builder = typeFactory.builder builder.add("a", SqlTypeName.INTEGER) builder.add("b", SqlTypeName.INTEGER) - builder.add("c", SqlTypeName.INTEGER) + builder.add("c", SqlTypeName.INTEGER).nullable(true) + builder.build + } + } + + val t2: AbstractTable = new AbstractQueryableTable(classOf[Integer]) { + val enumerable: Enumerable[Integer] = Linq4j.asEnumerable(new util.ArrayList[Integer]()) + + override def asQueryable[E](queryProvider: QueryProvider, schema: 
SchemaPlus, tableName: String): Queryable[E] = enumerable.asQueryable.asInstanceOf[Queryable[E]] + + override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { + val builder: RelDataTypeFactory.Builder = typeFactory.builder + builder.add("a", SqlTypeName.INTEGER) + builder.add("a1", SqlTypeName.INTEGER) + builder.add("a2", SqlTypeName.INTEGER).nullable(true) builder.build } } @@ -58,8 +72,8 @@ class PPLSqlNodeTestSuite } val schema: SchemaPlus = Frameworks.createRootSchema(true) - schema.add("table", t) - schema.add("table2", t) + schema.add("table", t1) + schema.add("table2", t2) val config: FrameworkConfig = Frameworks.newConfigBuilder .parserConfig(SqlParser.config.withLex(Lex.MYSQL)) .defaultSchema(schema) @@ -71,18 +85,21 @@ class PPLSqlNodeTestSuite val cluster: RelOptCluster = RelOptCluster.create(requireNonNull(new VolcanoPlanner(config.getCostFactory, config.getContext), "planner"), new RexBuilder(typeFactory)) val catalogReader = createCatalogReader val opTab = SqlOperatorTables.chain(config.getOperatorTable, catalogReader) - val validator = new MyValidator(opTab, catalogReader, typeFactory, config.getSqlValidatorConfig) + val validator = new MyValidator(opTab, catalogReader, typeFactory, config.getSqlValidatorConfig.withIdentifierExpansion(true)) val sqlToRelConverter = new SqlToRelConverter(planner.asInstanceOf[PlannerImpl], validator, catalogReader, cluster, config.getConvertletTable, config.getSqlToRelConverterConfig) val relToSqlConverter = new RelToSqlConverter(DatabaseProduct.CALCITE.getDialect) val pplParserOld = new PPLSyntaxParser() test("test basic command") { val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | fields c, avg_b") + //scalastyle:off + println(sqlNode) + //scalastyle:on + val validatedSqlNode = validator.validate(sqlNode) val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() 
//scalastyle:off - println(sqlNode) println(validatedSqlNode) println(relNode) println(convertedSqlNode) @@ -91,6 +108,10 @@ class PPLSqlNodeTestSuite val sqlNode2 = planner.parse(sqlNode.toString()) + //scalastyle:off + println(sqlNode2) + //scalastyle:on + val validatedSqlNode2 = planner.validate(sqlNode2) val relNode2 = planner.rel(validatedSqlNode2) val convertedSqlNode2 = relToSqlConverter.visitRoot(relNode2.rel).asStatement() @@ -98,7 +119,6 @@ class PPLSqlNodeTestSuite // val relNode = planner.rel(sqlNode) // val osPlan = plan(pplParserOld, "source=t") //scalastyle:off - println(sqlNode2) println(validatedSqlNode2) println(relNode2) println(convertedSqlNode2) @@ -107,24 +127,46 @@ class PPLSqlNodeTestSuite test("test eval") { val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | eval avg_b = avg_b + 1 | fields c, avg_b") + //scalastyle:off + println(sqlNode) + //scalastyle:on + val validatedSqlNode = validator.validate(sqlNode) val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() //scalastyle:off - println(sqlNode) println(validatedSqlNode) println(relNode) println(convertedSqlNode) //scalastyle:on } - test("test eval") { - val sqlNode = pplParser.parseQuery("source=table | where a = 1| stats avg(b) as avg_b by c | sort c | eval avg_b = avg_b + 1 | fields c, avg_b") + test("test lookup") { + val sqlNode = pplParser.parseQuery("source = table | LOOKUP table2 a, a1 as b replace a2 as c") + //scalastyle:off + println(sqlNode) + //scalastyle:on + val validatedSqlNode = validator.validate(sqlNode) val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() //scalastyle:off + println(validatedSqlNode) + println(relNode) + println(convertedSqlNode) + //scalastyle:on + } + + test("test window") { + val sqlNode = pplParser.parseQuery("source = 
table | eventstats min(a) as min_a, max(a) as max_a, count(1) by b") + //scalastyle:off println(sqlNode) + //scalastyle:on + + val validatedSqlNode = validator.validate(sqlNode) + val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true) + val convertedSqlNode = relToSqlConverter.visitRoot(relNode.rel).asStatement() + //scalastyle:off println(validatedSqlNode) println(relNode) println(convertedSqlNode)