This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 97e9bb3ac4b6 [SPARK-48700][SQL] Mode expression for complex types (all collations)
97e9bb3ac4b6 is described below

commit 97e9bb3ac4b66711ced640ea466eeea5da6d1fd2
Author: Gideon P <[email protected]>
AuthorDate: Tue Oct 1 15:09:35 2024 +0200

    [SPARK-48700][SQL] Mode expression for complex types (all collations)
    
    ### What changes were proposed in this pull request?
    
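    Add support in the mode operator for complex types (structs and arrays) whose
    subfields are collated strings. Map types with collated keys or values are not
    yet supported (tracked in SPARK-49358).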
    
    ### Why are the changes needed?
    
    To provide full collation support in the mode operator, as tracked by SPARK-48700.
    
    ### Does this PR introduce _any_ user-facing change?
    
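    Yes. `mode` now accepts struct and array inputs containing strings with non-binary
    collations, which were previously rejected during analysis. Inputs containing a MAP
    with collated keys and/or values now fail with the new
    DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE error.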
    
    ### How was this patch tested?
    
    New and updated unit tests in CollationSQLExpressionsSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47154 from GideonPotok/collationmodecomplex.
    
    Lead-authored-by: Gideon P <[email protected]>
    Co-authored-by: Gideon Potok <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  10 +
 .../sql/catalyst/expressions/aggregate/Mode.scala  |  85 +++++--
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 257 ++++++++++++++-------
 3 files changed, 250 insertions(+), 102 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index fcaf2b1d9d30..3786643125a9 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -631,6 +631,11 @@
       "Cannot process input data types for the expression: <expression>."
     ],
     "subClass" : {
+      "BAD_INPUTS" : {
+        "message" : [
+          "The input data types to <functionName> must be valid, but found the 
input types <dataType>."
+        ]
+      },
       "MISMATCHED_TYPES" : {
         "message" : [
           "All input types must be the same except nullable, containsNull, 
valueContainsNull flags, but found the input types <inputTypes>."
@@ -1011,6 +1016,11 @@
           "The input of <functionName> can't be <dataType> type data."
         ]
       },
+      "UNSUPPORTED_MODE_DATA_TYPE" : {
+        "message" : [
+          "The <mode> does not support the <child> data type, because there is 
a \"MAP\" type with keys and/or values that have collated sub-fields."
+        ]
+      },
       "UNSUPPORTED_UDF_INPUT_TYPE" : {
         "message" : [
           "UDFs do not support '<dataType>' as an input data type."
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index e254a670991a..8998348f0571 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, 
TypeCheckResult, UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, 
Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, 
UnsafeRowUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
GenericArrayData, UnsafeRowUtils}
+import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, 
BooleanType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, 
BooleanType, DataType, MapType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.OpenHashMap
 
@@ -50,17 +53,20 @@ case class Mode(
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (UnsafeRowUtils.isBinaryStable(child.dataType) || 
child.dataType.isInstanceOf[StringType]) {
+    // TODO: SPARK-49358: Mode expression for map type with collated fields
+    if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
+      !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
+        !UnsafeRowUtils.isBinaryStable(f))) {
       /*
         * The Mode class uses collation awareness logic to handle string data.
-        * Complex types with collated fields are not yet supported.
+        * All complex types except MapType with collated fields are supported.
        */
-      // TODO: SPARK-48700: Mode expression for complex types (all collations)
       super.checkInputDataTypes()
     } else {
-      TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
-        " a type of binary-unstable type that is " +
-        s"not currently supported by ${prettyName}.")
+      TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
+        messageParameters =
+          Map("child" -> toSQLType(child.dataType),
+            "mode" -> toSQLId(prettyName)))
     }
   }
 
@@ -86,6 +92,54 @@ case class Mode(
     buffer
   }
 
+  private def getCollationAwareBuffer(
+      childDataType: DataType,
+      buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
+    def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, 
Long)] = {
+      buffer.groupMapReduce(t =>
+        groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
+    }
+    def determineBufferingFunction(
+        childDataType: DataType): Option[AnyRef => _] = {
+      childDataType match {
+        case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
+        case _ => Some(collationAwareTransform(_, childDataType))
+      }
+    }
+    
determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer)
+  }
+
+  protected[sql] def collationAwareTransform(data: AnyRef, dataType: 
DataType): AnyRef = {
+    dataType match {
+      case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
+      case st: StructType =>
+        
processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
+      case at: ArrayType => processArrayTypeWithBuffer(at, 
data.asInstanceOf[ArrayData])
+      case st: StringType =>
+        CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], 
st.collationId)
+      case _ =>
+        throw new SparkIllegalArgumentException(
+          errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
+          messageParameters = Map(
+            "expression" -> toSQLExpr(this),
+            "functionName" -> toSQLType(prettyName),
+            "dataType" -> toSQLType(child.dataType))
+        )
+    }
+  }
+
+  private def processStructTypeWithBuffer(
+      tuples: Seq[(Any, StructField)]): Seq[Any] = {
+    tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef], 
t._2.dataType))
+  }
+
+  private def processArrayTypeWithBuffer(
+      a: ArrayType,
+      data: ArrayData): Seq[Any] = {
+    (0 until data.numElements()).map(i =>
+      collationAwareTransform(data.get(i, a.elementType), a.elementType))
+  }
+
   override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     if (buffer.isEmpty) {
       return null
@@ -102,17 +156,12 @@ case class Mode(
       *  to a single value (the sum of the counts), and finally reduces the 
groups to a single map.
       *
       * The new map is then used in the rest of the Mode evaluation logic.
+      *
+      * It is expected to work for all simple and complex types with
+      *  collated fields, except for MapType (temporarily).
       */
-    val collationAwareBuffer = child.dataType match {
-      case c: StringType if
-        !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality 
=>
-        val collationId = c.collationId
-        val modeMap = buffer.toSeq.groupMapReduce {
-         case (k, _) => 
CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
-        }(x => x)((x, y) => (x._1, x._2 + y._2)).values
-        modeMap
-      case _ => buffer
-    }
+    val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)
+
     reverseOpt.map { reverse =>
       val defaultKeyOrdering = if (reverse) {
         
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 941d5cd31db4..9930709cd8bf 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
+import java.util.Locale
 
 import scala.collection.immutable.Seq
 
-import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable}
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
@@ -1752,7 +1753,7 @@ class CollationSQLExpressionsSuite
       UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"),
       UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a"))
 
-    testCasesUTF8String.foreach(t => {
+    testCasesUTF8String.foreach ( t => {
       val buffer = new OpenHashMap[AnyRef, Long](5)
       val myMode = Mode(child = Literal.create("some_column_name", 
StringType(t.collationId)))
       t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
@@ -1760,6 +1761,40 @@ class CollationSQLExpressionsSuite
     })
   }
 
+  test("Support Mode.eval(buffer) with complex types") {
+    case class UTF8StringModeTestCase[R](
+        collationId: String,
+        bufferValues: Map[InternalRow, Long],
+        result: R)
+
+    val bufferValuesUTF8String: Map[Any, Long] = Map(
+      UTF8String.fromString("a") -> 5L,
+      UTF8String.fromString("b") -> 4L,
+      UTF8String.fromString("B") -> 3L,
+      UTF8String.fromString("d") -> 2L,
+      UTF8String.fromString("e") -> 1L)
+
+    val bufferValuesComplex = bufferValuesUTF8String.map{
+      case (k, v) => (InternalRow.fromSeq(Seq(k, k, k)), v)
+    }
+    val testCasesUTF8String = Seq(
+      UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"),
+      UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"),
+      UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"),
+      UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]"))
+
+    testCasesUTF8String.foreach { t =>
+      val buffer = new OpenHashMap[AnyRef, Long](5)
+      val myMode = Mode(child = Literal.create(null, StructType(Seq(
+        StructField("f1", StringType(t.collationId), true),
+        StructField("f2", StringType(t.collationId), true),
+        StructField("f3", StringType(t.collationId), true)
+      ))))
+      t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
+      assert(myMode.eval(buffer).toString.toLowerCase() == 
t.result.toLowerCase())
+    }
+  }
+
   test("Support mode for string expression with collated strings in struct") {
     case class ModeTestCase[R](collationId: String, bufferValues: Map[String, 
Long], result: R)
     val testCases = Seq(
@@ -1780,33 +1815,7 @@ class CollationSQLExpressionsSuite
           t.collationId + ", f2: INT>) USING parquet")
         sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
         val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
-        if(t.collationId == "UTF8_LCASE" ||
-          t.collationId == "unicode_ci" ||
-          t.collationId == "unicode") {
-          // Cannot resolve "mode(i)" due to data type mismatch:
-          // Input to function mode was a complex type with strings collated 
on non-binary
-          // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 
pos 13;
-          val params = Seq(("sqlExpr", "\"mode(i)\""),
-            ("msg", "The input to the function 'mode'" +
-              " was a type of binary-unstable type that is not currently 
supported by mode."),
-            ("hint", "")).toMap
-          checkError(
-            exception = intercept[AnalysisException] {
-              sql(query)
-            },
-            condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
-            parameters = params,
-            queryContext = Array(
-              ExpectedContext(objectType = "",
-                objectName = "",
-                startIndex = 13,
-                stopIndex = 19,
-                fragment = "mode(i)")
-            )
-          )
-        } else {
-          checkAnswer(sql(query), Row(t.result))
-        }
+        checkAnswer(sql(query), Row(t.result))
       }
     })
   }
@@ -1819,47 +1828,21 @@ class CollationSQLExpressionsSuite
       ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
       ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
     )
-    testCases.foreach(t => {
+    testCases.foreach { t =>
       val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
         (0L to numRepeats).map(_ => s"named_struct('f1', " +
           s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 
1)").mkString(",")
       }.mkString(",")
 
-      val tableName = s"t_${t.collationId}_mode_nested_struct"
+      val tableName = s"t_${t.collationId}_mode_nested_struct1"
       withTable(tableName) {
         sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE 
" +
           t.collationId + ">, f3: INT>) USING parquet")
         sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
         val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
-        if(t.collationId == "UTF8_LCASE" ||
-          t.collationId == "unicode_ci" ||
-          t.collationId == "unicode") {
-          // Cannot resolve "mode(i)" due to data type mismatch:
-          // Input to function mode was a complex type with strings collated 
on non-binary
-          // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 
pos 13;
-          val params = Seq(("sqlExpr", "\"mode(i)\""),
-            ("msg", "The input to the function 'mode' " +
-              "was a type of binary-unstable type that is not currently 
supported by mode."),
-            ("hint", "")).toMap
-          checkError(
-            exception = intercept[AnalysisException] {
-              sql(query)
-            },
-            condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
-            parameters = params,
-            queryContext = Array(
-              ExpectedContext(objectType = "",
-                objectName = "",
-                startIndex = 13,
-                stopIndex = 19,
-                fragment = "mode(i)")
-            )
-          )
-        } else {
-          checkAnswer(sql(query), Row(t.result))
-        }
+        checkAnswer(sql(query), Row(t.result))
       }
-    })
+    }
   }
 
   test("Support mode for string expression with collated strings in array 
complex type") {
@@ -1870,44 +1853,150 @@ class CollationSQLExpressionsSuite
       ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
       ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
     )
-    testCases.foreach(t => {
+    testCases.foreach { t =>
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"array(named_struct('f2', " +
+          s"collate('$elt', '${t.collationId}'), 'f3', 1))").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_nested_struct2"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(" +
+          s"i ARRAY< STRUCT<f2: STRING COLLATE ${t.collationId}, f3: INT>>)" +
+          s" USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(element_at(mode(i).f2, 1)) FROM 
${tableName}"
+        checkAnswer(sql(query), Row(t.result))
+      }
+    }
+  }
+
+  test("Support mode for string expression with collated strings in 3D array 
type") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, 
Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach { t =>
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ =>
+          s"array(array(array(collate('$elt', 
'${t.collationId}'))))").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_nested_3d_array"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(i ARRAY<ARRAY<ARRAY" +
+          s"<STRING COLLATE ${t.collationId}>>>) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(" +
+          s"element_at(element_at(element_at(mode(i),1),1),1)) FROM 
${tableName}"
+        checkAnswer(sql(query), Row(t.result))
+      }
+    }
+  }
+
+  test("Support mode for string expression with collated complex type - Highly 
nested") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, 
Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach { t =>
       val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
         (0L to numRepeats).map(_ => s"array(named_struct('s1', 
named_struct('a2', " +
           s"array(collate('$elt', '${t.collationId}'))), 'f3', 
1))").mkString(",")
       }.mkString(",")
 
-      val tableName = s"t_${t.collationId}_mode_nested_struct"
+      val tableName = s"t_${t.collationId}_mode_highly_nested_struct"
       withTable(tableName) {
         sql(s"CREATE TABLE ${tableName}(" +
           s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE 
${t.collationId}>>, f3: INT>>)" +
           s" USING parquet")
         sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
         val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 
1)) FROM ${tableName}"
-        if(t.collationId == "UTF8_LCASE" ||
-          t.collationId == "unicode_ci" || t.collationId == "unicode") {
-          val params = Seq(("sqlExpr", "\"mode(i)\""),
-            ("msg", "The input to the function 'mode' was a type" +
-              " of binary-unstable type that is not currently supported by 
mode."),
-            ("hint", "")).toMap
-          checkError(
-            exception = intercept[AnalysisException] {
-              sql(query)
-            },
-            condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
-            parameters = params,
-            queryContext = Array(
-              ExpectedContext(objectType = "",
-                objectName = "",
-                startIndex = 35,
-                stopIndex = 41,
-                fragment = "mode(i)")
-            )
-          )
-        } else {
+
           checkAnswer(sql(query), Row(t.result))
+      }
+    }
+  }
+
+  test("Support mode expression with collated in recursively nested struct 
with map with keys") {
+    case class ModeTestCase(collationId: String, bufferValues: Map[String, 
Long], result: String)
+    Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 
1}"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 
1}"),
+      ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 
1}"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 
1}")
+    ).foreach { t1 =>
+      def checkThisError(t: ModeTestCase, query: String): Any = {
+        val c = s"STRUCT<m1: MAP<STRING COLLATE 
${t.collationId.toUpperCase(Locale.ROOT)}, INT>>"
+        val c1 = s"\"${c}\""
+        checkError(
+          exception = intercept[SparkThrowable] {
+            sql(query).collect()
+          },
+          condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE",
+          parameters = Map(
+            ("sqlExpr", "\"mode(i)\""),
+            ("child", c1),
+            ("mode", "`mode`")),
+          queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray
+        )
+      }
+
+      def getValuesToAdd(t: ModeTestCase): String = {
+        val valuesToAdd = t.bufferValues.map {
+          case (elt, numRepeats) =>
+            (0L to numRepeats).map(i =>
+              s"named_struct('m1', map(collate('$elt', '${t.collationId}'), 
1))"
+            ).mkString(",")
+        }.mkString(",")
+        valuesToAdd
+      }
+      val tableName = s"t_${t1.collationId}_mode_nested_map_struct1"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(" +
+          s"i STRUCT<m1: MAP<STRING COLLATE ${t1.collationId}, INT>>) USING 
parquet")
+        sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
+        val query = "SELECT lower(cast(mode(i).m1 as string))" +
+          s" FROM ${tableName}"
+        if (t1.collationId == "utf8_binary") {
+          checkAnswer(sql(query), Row(t1.result))
+        } else {
+          checkThisError(t1, query)
         }
       }
-    })
+    }
+  }
+
+  test("UDT with collation  - Mode (throw exception)") {
+    case class ModeTestCase(collationId: String, bufferValues: Map[String, 
Long], result: String)
+    Seq(
+      ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    ).foreach { t1 =>
+        checkError(
+          exception = intercept[SparkIllegalArgumentException] {
+          Mode(
+              child = Literal.create(null,
+                MapType(StringType(t1.collationId), IntegerType))
+            ).collationAwareTransform(
+              data = Map.empty[String, Any],
+              dataType = MapType(StringType(t1.collationId), IntegerType)
+            )
+          },
+          condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
+          parameters = Map(
+            "expression" -> "\"mode(NULL)\"",
+            "functionName" -> "\"MODE\"",
+            "dataType" -> s"\"MAP<STRING COLLATE 
${t1.collationId.toUpperCase()}, INT>\"")
+         )
+      }
   }
 
   test("SPARK-48430: Map value extraction with collations") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to