This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 38f067dfcef9 [SPARK-49358][SQL] Mode expression for map types with 
collated strings
38f067dfcef9 is described below

commit 38f067dfcef9ae53330fdd73ea89ebba614c965b
Author: Uros Bojanic <[email protected]>
AuthorDate: Thu Oct 3 11:36:26 2024 +0200

    [SPARK-49358][SQL] Mode expression for map types with collated strings
    
    ### What changes were proposed in this pull request?
    Introduce support for collated strings in map types for `mode` expression.
    
    ### Why are the changes needed?
    Complete complex type handling for `mode` expression.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `mode` expression can now handle map types with collated strings.
    
    ### How was this patch tested?
    New tests in `CollationSQLExpressionsSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48326 from uros-db/mode-map.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  5 --
 .../sql/catalyst/expressions/aggregate/Mode.scala  | 38 ++++++---------
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 54 ++--------------------
 3 files changed, 19 insertions(+), 78 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 3786643125a9..12666fe4ff62 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1016,11 +1016,6 @@
           "The input of <functionName> can't be <dataType> type data."
         ]
       },
-      "UNSUPPORTED_MODE_DATA_TYPE" : {
-        "message" : [
-          "The <mode> does not support the <child> data type, because there is 
a \"MAP\" type with keys and/or values that have collated sub-fields."
-        ]
-      },
       "UNSUPPORTED_UDF_INPUT_TYPE" : {
         "message" : [
           "UDFs do not support '<dataType>' as an input data type."
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index 8998348f0571..97add0b8e45b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, 
TypeCheckResult, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, 
UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, 
Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
GenericArrayData, UnsafeRowUtils}
-import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
GenericArrayData, MapData, UnsafeRowUtils}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, 
BooleanType, DataType, MapType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -52,24 +52,6 @@ case class Mode(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    // TODO: SPARK-49358: Mode expression for map type with collated fields
-    if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
-      !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
-        !UnsafeRowUtils.isBinaryStable(f))) {
-      /*
-        * The Mode class uses collation awareness logic to handle string data.
-        * All complex types except MapType with collated fields are supported.
-       */
-      super.checkInputDataTypes()
-    } else {
-      TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
-        messageParameters =
-          Map("child" -> toSQLType(child.dataType),
-            "mode" -> toSQLId(prettyName)))
-    }
-  }
-
   override def prettyName: String = "mode"
 
   override def update(
@@ -115,6 +97,7 @@ case class Mode(
       case st: StructType =>
         
processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
       case at: ArrayType => processArrayTypeWithBuffer(at, 
data.asInstanceOf[ArrayData])
+      case mt: MapType => processMapTypeWithBuffer(mt, 
data.asInstanceOf[MapData])
       case st: StringType =>
         CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], 
st.collationId)
       case _ =>
@@ -140,6 +123,16 @@ case class Mode(
       collationAwareTransform(data.get(i, a.elementType), a.elementType))
   }
 
+  private def processMapTypeWithBuffer(mt: MapType, data: MapData): Map[Any, 
Any] = {
+    val transformedKeys = (0 until data.numElements()).map { i =>
+      collationAwareTransform(data.keyArray().get(i, mt.keyType), mt.keyType)
+    }
+    val transformedValues = (0 until data.numElements()).map { i =>
+      collationAwareTransform(data.valueArray().get(i, mt.valueType), 
mt.valueType)
+    }
+    transformedKeys.zip(transformedValues).toMap
+  }
+
   override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     if (buffer.isEmpty) {
       return null
@@ -157,8 +150,7 @@ case class Mode(
       *
       * The new map is then used in the rest of the Mode evaluation logic.
       *
-      * It is expected to work for all simple and complex types with
-      *  collated fields, except for MapType (temporarily).
+      * It is expected to work for all simple and complex types with collated 
fields.
       */
     val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 9930709cd8bf..851160d2fbb9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.util.Locale
 
 import scala.collection.immutable.Seq
 
-import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable}
+import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
@@ -1924,7 +1923,7 @@ class CollationSQLExpressionsSuite
     }
   }
 
-  test("Support mode expression with collated in recursively nested struct 
with map with keys") {
+  test("Support mode for string expression with collated complex type - nested 
map") {
     case class ModeTestCase(collationId: String, bufferValues: Map[String, 
Long], result: String)
     Seq(
       ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 
1}"),
@@ -1932,22 +1931,6 @@ class CollationSQLExpressionsSuite
       ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 
1}"),
       ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 
1}")
     ).foreach { t1 =>
-      def checkThisError(t: ModeTestCase, query: String): Any = {
-        val c = s"STRUCT<m1: MAP<STRING COLLATE 
${t.collationId.toUpperCase(Locale.ROOT)}, INT>>"
-        val c1 = s"\"${c}\""
-        checkError(
-          exception = intercept[SparkThrowable] {
-            sql(query).collect()
-          },
-          condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE",
-          parameters = Map(
-            ("sqlExpr", "\"mode(i)\""),
-            ("child", c1),
-            ("mode", "`mode`")),
-          queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray
-        )
-      }
-
       def getValuesToAdd(t: ModeTestCase): String = {
         val valuesToAdd = t.bufferValues.map {
           case (elt, numRepeats) =>
@@ -1964,41 +1947,12 @@ class CollationSQLExpressionsSuite
         sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
         val query = "SELECT lower(cast(mode(i).m1 as string))" +
           s" FROM ${tableName}"
-        if (t1.collationId == "utf8_binary") {
-          checkAnswer(sql(query), Row(t1.result))
-        } else {
-          checkThisError(t1, query)
-        }
+        val queryResult = sql(query)
+        checkAnswer(queryResult, Row(t1.result))
       }
     }
   }
 
-  test("UDT with collation  - Mode (throw exception)") {
-    case class ModeTestCase(collationId: String, bufferValues: Map[String, 
Long], result: String)
-    Seq(
-      ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
-      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
-      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
-    ).foreach { t1 =>
-        checkError(
-          exception = intercept[SparkIllegalArgumentException] {
-          Mode(
-              child = Literal.create(null,
-                MapType(StringType(t1.collationId), IntegerType))
-            ).collationAwareTransform(
-              data = Map.empty[String, Any],
-              dataType = MapType(StringType(t1.collationId), IntegerType)
-            )
-          },
-          condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
-          parameters = Map(
-            "expression" -> "\"mode(NULL)\"",
-            "functionName" -> "\"MODE\"",
-            "dataType" -> s"\"MAP<STRING COLLATE 
${t1.collationId.toUpperCase()}, INT>\"")
-         )
-      }
-  }
-
   test("SPARK-48430: Map value extraction with collations") {
     for {
       collateKey <- Seq(true, false)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to