diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index cffdd28722241..3affd91dd3b82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay, StringLPad, StringRPad} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType} @@ -52,6 +52,11 @@ object CollationTypeCasts extends TypeCoercionRule { overlay.withNewChildren(collateToSingleType(Seq(overlay.input, overlay.replace)) ++ Seq(overlay.pos, overlay.len)) + case stringPadExpr @ (_: StringRPad | _: StringLPad) => + val Seq(str, len, pad) = stringPadExpr.children + val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad)) + stringPadExpr.withNewChildren(Seq(newStr, len, newPad)) + case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2b7703ed82b37..cd21a6f5fdc21 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1586,7 +1586,8 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) override def third: Expression = pad override def dataType: DataType = str.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1665,7 +1666,8 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def third: Expression = pad override def dataType: DataType = str.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 123d642ed4cd7..9c207df95dadb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -534,6 +534,80 @@ class CollationStringExpressionsSuite } } + test("Support StringRPad string expressions with collation") { + // Supported collations + case class StringRPadTestCase[R](s: String, len: Int, pad: String, c: String, result: R) + val testCases = Seq( + StringRPadTestCase("", 5, " ", "UTF8_BINARY", " "), + StringRPadTestCase("abc", 5, " ", "UNICODE", "abc "), + StringRPadTestCase("Hello", 7, "Wörld", "UTF8_BINARY_LCASE", "HelloWö"), + StringRPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"), + StringRPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"), + StringRPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", "UTF8_BINARY_LCASE", "ÀÃ"), + StringRPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", "ĂȦÄäåäáÀÃÂĀĂȦÄäåäáâã"), + StringRPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "aȦÄäa1a1") + ) + testCases.foreach(t => { + val query = s"SELECT rpad(collate('${t.s}', '${t.c}')," + + s" ${t.len}, collate('${t.pad}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // Implicit casting + checkAnswer( + sql(s"SELECT rpad(collate('${t.s}', '${t.c}'), ${t.len}, '${t.pad}')"), + Row(t.result)) + checkAnswer( + sql(s"SELECT rpad('${t.s}', ${t.len}, collate('${t.pad}', '${t.c}'))"), + Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT rpad(collate('abcde', 'UNICODE_CI'),1,collate('C', 'UTF8_BINARY_LCASE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + + test("Support StringLPad string expressions with collation") { + // Supported collations + case class StringLPadTestCase[R](s: String, len: Int, pad: String, c: String, result: R) + val testCases = Seq( + StringLPadTestCase("", 5, " ", "UTF8_BINARY", " "), + StringLPadTestCase("abc", 5, " ", "UNICODE", " abc"), + StringLPadTestCase("Hello", 7, "Wörld", "UTF8_BINARY_LCASE", "WöHello"), + StringLPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"), + StringLPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"), + StringLPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", "UTF8_BINARY_LCASE", "ÀÃ"), + StringLPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", "ÀÃÂĀĂȦÄäåäáâãĂȦÄäåäá"), + StringLPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "a1a1aȦÄä") + ) + testCases.foreach(t => { + val query = s"SELECT lpad(collate('${t.s}', '${t.c}')," + + s" ${t.len}, collate('${t.pad}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // Implicit casting + checkAnswer( + sql(s"SELECT lpad(collate('${t.s}', '${t.c}'), ${t.len}, '${t.pad}')"), + Row(t.result)) + checkAnswer( + sql(s"SELECT lpad('${t.s}', ${t.len}, collate('${t.pad}', '${t.c}'))"), + Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT lpad(collate('abcde', 'UNICODE_CI'),1,collate('C', 'UTF8_BINARY_LCASE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + + test("Support StringLPad string expressions with explicit collation on second parameter") { + val query = "SELECT lpad('abc', collate('5', 'unicode_ci'), ' ')" + checkAnswer(sql(query), Row(" abc")) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(0))) + } + // TODO: Add more tests for other string expressions }