This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4dad2170b05c [SPARK-47356][SQL] Add support for ConcatWs & Elt (all collations)
4dad2170b05c is described below

commit 4dad2170b05c04faf1da550ab3fb8c52a61b8be7
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Tue Apr 16 21:21:24 2024 +0800

    [SPARK-47356][SQL] Add support for ConcatWs & Elt (all collations)
    
    ### What changes were proposed in this pull request?
    Adds collation support to the ConcatWs and Elt expressions.
    
    ### Why are the changes needed?
    These functions must support collations as part of extending collation
    support to all string functions.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, both expressions no longer return an error when called with collated
    strings.
    
    ### How was this patch tested?
    Added tests to `CollationStringExpressionsSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46061 from mihailom-db/SPARK-47356.
    
    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CollationTypeCasts.scala |  5 ++-
 .../catalyst/expressions/stringExpressions.scala   | 25 ++++++------
 .../sql/CollationStringExpressionsSuite.scala      | 46 ++++++++++++++++------
 3 files changed, 51 insertions(+), 25 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 1a14b4227de8..795e8a696b01 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -22,7 +22,7 @@ import javax.annotation.Nullable
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, 
haveSameType}
-import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, 
CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, 
Greatest, If, In, InSubquery, Least}
+import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, 
CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, 
Expression, Greatest, If, In, InSubquery, Least}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
@@ -45,6 +45,9 @@ object CollationTypeCasts extends TypeCoercionRule {
           caseWhenExpr.elseValue.map(e => castStringType(e, 
outputStringType).getOrElse(e))
         CaseWhen(newBranches, newElseValue)
 
+    case eltExpr: Elt =>
+      eltExpr.withNewChildren(eltExpr.children.head +: 
collateToSingleType(eltExpr.children.tail))
+
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: 
Greatest | _: Least |
       _: Coalesce | _: BinaryExpression | _: ConcatWs) =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 34e8f3f40859..4fe57b4f8f02 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,7 +37,7 @@ import 
org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
 import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, 
GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.{AbstractArrayType, 
StringTypeAnyCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -79,11 +79,12 @@ case class ConcatWs(children: Seq[Expression])
 
   /** The 1st child (separator) is str, and rest are either str or array of 
str. */
   override def inputTypes: Seq[AbstractDataType] = {
-    val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
-    StringType +: Seq.fill(children.size - 1)(arrayOrStr)
+    val arrayOrStr =
+      TypeCollection(AbstractArrayType(StringTypeAnyCollation), 
StringTypeAnyCollation)
+    StringTypeAnyCollation +: Seq.fill(children.size - 1)(arrayOrStr)
   }
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = children.head.dataType
 
   override def nullable: Boolean = children.head.nullable
   override def foldable: Boolean = children.forall(_.foldable)
@@ -102,7 +103,8 @@ case class ConcatWs(children: Seq[Expression])
     val flatInputs = children.flatMap { child =>
       child.eval(input) match {
         case s: UTF8String => Iterator(s)
-        case arr: ArrayData => arr.toArray[UTF8String](StringType)
+        case arr: ArrayData =>
+          
arr.toArray[UTF8String](child.dataType.asInstanceOf[ArrayType].elementType)
         case null => Iterator(null.asInstanceOf[UTF8String])
       }
     }
@@ -110,7 +112,7 @@ case class ConcatWs(children: Seq[Expression])
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
-    if (children.forall(_.dataType == StringType)) {
+    if (children.forall(_.dataType.isInstanceOf[StringType])) {
       // All children are strings. In that case we can construct a fixed size 
array.
       val evals = children.map(_.genCode(ctx))
       val separator = evals.head
@@ -163,7 +165,7 @@ case class ConcatWs(children: Seq[Expression])
            """
 
         val (varCount, varBuild) = child.dataType match {
-          case StringType =>
+          case _: StringType =>
             val reprForValueCast = s"((UTF8String) $reprForValue)"
             ("", // we count all the StringType arguments num at once below.
               if (eval.isNull == TrueLiteral) {
@@ -171,7 +173,7 @@ case class ConcatWs(children: Seq[Expression])
               } else {
                 s"$array[$idxVararg ++] = $reprForIsNull ? (UTF8String) null : 
$reprForValueCast;"
               })
-          case _: ArrayType =>
+          case arr: ArrayType =>
             val reprForValueCast = s"((ArrayData) $reprForValue)"
             val size = ctx.freshName("n")
             if (eval.isNull == TrueLiteral) {
@@ -187,7 +189,7 @@ case class ConcatWs(children: Seq[Expression])
                 if (!$reprForIsNull) {
                   final int $size = $reprForValueCast.numElements();
                   for (int j = 0; j < $size; j ++) {
-                    $array[$idxVararg ++] = 
${CodeGenerator.getValue(reprForValueCast, StringType, "j")};
+                    $array[$idxVararg ++] = 
${CodeGenerator.getValue(reprForValueCast, arr.elementType, "j")};
                   }
                 }
                 """)
@@ -235,7 +237,7 @@ case class ConcatWs(children: Seq[Expression])
         boolean[] $isNullArgs = new boolean[${children.length - 1}];
         Object[] $valueArgs = new Object[${children.length - 1}];
         $argBuilds
-        int $varargNum = ${children.count(_.dataType == StringType) - 1};
+        int $varargNum = ${children.count(_.dataType.isInstanceOf[StringType]) 
- 1};
         int $idxVararg = 0;
         $varargCounts
         UTF8String[] $array = new UTF8String[$varargNum];
@@ -287,7 +289,8 @@ case class Elt(
   /** This expression is always nullable because it returns null if index is 
out of range. */
   override def nullable: Boolean = true
 
-  override def dataType: DataType = 
inputExprs.map(_.dataType).headOption.getOrElse(StringType)
+  override def dataType: DataType =
+    
inputExprs.map(_.dataType).headOption.getOrElse(SQLConf.get.defaultStringType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size < 2) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 0dbd4c0ba713..9237c8a25a5d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -32,7 +32,10 @@ class CollationStringExpressionsSuite
     // Supported collations
     case class ConcatWsTestCase[R](s: String, a: Array[String], c: String, 
result: R)
     val testCases = Seq(
-      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL")
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL"),
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY_LCASE", "Spark 
SQL"),
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE", "Spark SQL"),
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_CI", "Spark SQL")
     )
     testCases.foreach(t => {
       val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ")
@@ -49,22 +52,39 @@ class CollationStringExpressionsSuite
       checkAnswer(sql(query), Row(t.result))
       assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
     })
-    // Unsupported collations
-    case class ConcatWsTestFail(s: String, a: Array[String], c: String)
-    val failCases = Seq(
-      ConcatWsTestFail(" ", Array("ABC", "%b%"), "UTF8_BINARY_LCASE"),
-      ConcatWsTestFail(" ", Array("ABC", "%B%"), "UNICODE"),
-      ConcatWsTestFail(" ", Array("ABC", "%b%"), "UNICODE_CI")
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT concat_ws(' ',collate('Spark', 
'UTF8_BINARY_LCASE'),collate('SQL', 'UNICODE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  test("Support Elt string expression with collation") {
+    // Supported collations
+    case class EltTestCase[R](index: Int, inputs: Array[String], c: String, 
result: R)
+    val testCases = Seq(
+      EltTestCase(1, Array("Spark", "SQL"), "UTF8_BINARY", "Spark"),
+      EltTestCase(1, Array("Spark", "SQL"), "UTF8_BINARY_LCASE", "Spark"),
+      EltTestCase(2, Array("Spark", "SQL"), "UNICODE", "SQL"),
+      EltTestCase(2, Array("Spark", "SQL"), "UNICODE_CI", "SQL")
     )
-    failCases.foreach(t => {
-      val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ")
-      val query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), 
$arrCollated)"
-      val unsupportedCollation = intercept[AnalysisException] { sql(query) }
-      assert(unsupportedCollation.getErrorClass === 
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+    testCases.foreach(t => {
+      var query = s"SELECT elt(${t.index}, collate('${t.inputs(0)}', 
'${t.c}')," +
+        s" collate('${t.inputs(1)}', '${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+      // Implicit casting
+      query = s"SELECT elt(${t.index}, collate('${t.inputs(0)}', '${t.c}'), 
'${t.inputs(1)}')"
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+      query = s"SELECT elt(${t.index}, '${t.inputs(0)}', 
collate('${t.inputs(1)}', '${t.c}'))"
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
     })
     // Collation mismatch
     val collationMismatch = intercept[AnalysisException] {
-      sql("SELECT concat_ws(' ',collate('Spark', 
'UTF8_BINARY_LCASE'),collate('SQL', 'UNICODE'))")
+      sql("SELECT elt(0 ,collate('Spark', 'UTF8_BINARY_LCASE'), collate('SQL', 
'UNICODE'))")
     }
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to