This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fc5b0b9383b [SPARK-47353][SQL] Enable collation support for the Mode expression
0fc5b0b9383b is described below

commit 0fc5b0b9383b7963835eeaafdd1c5f02417eedde
Author: GideonPotok <[email protected]>
AuthorDate: Wed Jun 26 16:08:31 2024 +0800

    [SPARK-47353][SQL] Enable collation support for the Mode expression
    
    ### What changes were proposed in this pull request?
    
    [SPARK-47353](https://issues.apache.org/jira/browse/SPARK-47353)
    
    #### Pull requests
    
    - [Scala TreeMap (RB Tree)](https://github.com/apache/spark/pull/46404)
    - [GroupMapReduce](https://github.com/apache/spark/pull/46526) <- Most performant
    - [GroupMapReduce (Cleaned up) (This PR)](https://github.com/apache/spark/pull/46597) <- Most performant
    - [Comparing Experimental Approaches](https://github.com/apache/spark/pull/46488)
    
    Related review discussion: https://github.com/apache/spark/pull/46597/files#r1626058908 -> https://github.com/apache/spark/pull/46917#discussion_r1640081873
    
    #### Central Change to Mode `eval` Algorithm:
    - Update to the `eval` method: `eval` now checks whether the column being evaluated is a string type with a non-default collation and, if so, groups the buffer entries by their collation keys:
    ```
    buff.toSeq.groupMapReduce {
      case (key: String, _) =>
        CollationFactory.getCollationKey(UTF8String.fromString(key), collationId)
      case (key: UTF8String, _) =>
        CollationFactory.getCollationKey(key, collationId)
      case (key, _) => key
    }(x => x)((x, y) => (x._1, x._2 + y._2)).values
    ```
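The grouping above merges buffer counts whose collation keys are equal before picking the maximum. The same idea in a minimal, hedged Python sketch (lower-casing stands in for `CollationFactory.getCollationKey`; the names here are illustrative, not Spark APIs):

```python
from collections import defaultdict

def collation_aware_mode(counts, collation_key=str.lower):
    # Merge counts of values whose collation keys are equal, keeping the
    # first value seen as the representative for each key, then return
    # the representative with the highest merged count.
    merged = defaultdict(lambda: [None, 0])
    for value, count in counts.items():
        entry = merged[collation_key(value)]
        if entry[0] is None:
            entry[0] = value
        entry[1] += count
    return max(merged.values(), key=lambda e: e[1])[0]

# Mirrors the UTF8_LCASE test case below: "b" and "B" merge to a count of 4.
print(collation_aware_mode({"a": 3, "b": 2, "B": 2}))  # -> b
```

Under binary collation the key function would be the identity, and the same input yields "a" instead.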
    
    #### Minor Change to Mode:
    - Introduction of `collationId`: a new lazy value `collationId` is computed from the `dataType` of the `child` expression and used to fetch the appropriate collation comparator when `collationEnabled` is true.
    
    This PR does not support complex types that contain collated strings; `Mode` fails type checking for them.
    A follow-up PR will implement that support.
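The resulting support matrix (binary-stable types and top-level strings accepted, collated strings inside complex types rejected) can be sketched as a predicate. This is a hedged illustration with a made-up tuple encoding of data types, not Spark's actual `checkInputDataTypes`:

```python
def is_binary_stable(dtype):
    # Illustrative type encoding: ("string", collation) for strings,
    # ("struct", child, ...) / ("array", child) for complex types,
    # ("int",) etc. for primitives. A type is binary-stable when no
    # nested string carries a non-binary collation.
    kind = dtype[0]
    if kind == "string":
        return dtype[1] == "UTF8_BINARY"
    return all(is_binary_stable(child) for child in dtype[1:])

def mode_accepts(dtype):
    # Mode accepts binary-stable input, plus top-level collated strings
    # (handled via collation keys); collated strings nested inside
    # complex types are rejected until the follow-up.
    return is_binary_stable(dtype) or dtype[0] == "string"

print(mode_accepts(("string", "UNICODE_CI")))              # True
print(mode_accepts(("struct", ("string", "UNICODE_CI"))))  # False
```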
    
    #### Unit Test Enhancements:
    Significant additions to `CollationStringExpressionsSuite` to test the new functionality, including:
    - Tests for the `Mode` function when handling strings with different 
collation settings.
    
    #### Benchmark Updates:
    - Enhanced the `CollationBenchmark` classes to include benchmarks for the 
new mode functionality with and without collation settings, as well as 
numerical types.
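As a rough, hedged illustration of what such a benchmark compares (not the actual `CollationBenchmark` harness), one can time mode with a binary key function against a collation-style key function:

```python
import time
from collections import Counter

def mode(values, key=lambda s: s):
    # Count values under a key function; return the first-seen value
    # for the most frequent key.
    counts = Counter()
    first_seen = {}
    for v in values:
        k = key(v)
        counts[k] += 1
        first_seen.setdefault(k, v)
    return first_seen[max(counts, key=counts.get)]

values = ["Def", "def", "DEF", "abc", "abc"] * 10_000

for name, key in [("binary", lambda s: s),
                  ("lowercase (collation stand-in)", str.lower)]:
    start = time.perf_counter()
    result = mode(values, key)
    elapsed = time.perf_counter() - start
    print(f"{name}: mode={result} in {elapsed:.4f}s")
```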
    
    ### Why are the changes needed?
    
    1. Ensures consistency in handling string comparisons under various 
collation settings.
    2. Improves global usability by enabling compatibility with different 
collation standards.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this PR introduces the following user-facing changes:
    
    1. Adds a new `collationEnabled` property to the `Mode` expression.
    2. Users can now specify collation settings for the `Mode` expression to 
customize its behavior.
    
    ### How was this patch tested?
    
    This patch was tested through a combination of new and existing unit and 
end-to-end SQL tests.
    
    1. **Unit Tests:**
       - **CollationStringExpressionsSuite:**
         - Aligned the newly added tests with the design pattern of the existing tests.
         - Added multiple test cases to verify that the `Mode` function correctly handles strings with different collation settings.
    
    Out of scope: special Unicode cases (characters in higher planes).
    
    Tests do not need to include null handling.
    
    2. **Benchmark Tests:**
    
    3. **Manual Testing:**
    
    ```
    ./build/mvn -DskipTests clean package
    export SPARK_HOME=/Users/gideon/repos/spark
    $SPARK_HOME/bin/spark-shell
    spark.sqlContext.setConf("spark.sql.collation.enabled", "true")
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.sql.functions
    import spark.implicits._
    val data = Seq(("Def"), ("def"), ("DEF"), ("abc"), ("abc"))
    val df = data.toDF("word")
    val dfLC = df.withColumn("word",
      col("word").cast(StringType("UTF8_BINARY_LCASE")))
    val dfLCA = dfLC.agg(org.apache.spark.sql.functions.mode(functions.col("word")).as("count"))
    dfLCA.show()
    /*
    BEFORE:
    +-----+
    |count|
    +-----+
    |  abc|
    +-----+
    
    AFTER:
    +-----+
    |count|
    +-----+
    |  Def|
    +-----+
    
    */
    ```
    
    4. **Continuous Integration (CI):**
       - The patch passed all relevant Continuous Integration (CI) checks, including:
         - Unit test suite
         - Benchmark suite
       - Follow-up: consider moving the new benchmark to the catalyst module.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Nope!
    
    Closes #46597 from GideonPotok/spark_47353_3_clean.
    
    Lead-authored-by: GideonPotok <[email protected]>
    Co-authored-by: Gideon Potok <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/expressions/aggregate/Mode.scala  |  51 ++++-
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 214 +++++++++++++++++++++
 2 files changed, 259 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index d1a9cafdf61f..5977eff4526d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -18,13 +18,14 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.OpenHashMap
 
 case class Mode(
@@ -48,6 +49,21 @@ case class Mode(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
+      /*
+        * The Mode class uses collation awareness logic to handle string data.
+        * Complex types with collated fields are not yet supported.
+       */
+      // TODO: SPARK-48700: Mode expression for complex types (all collations)
+      super.checkInputDataTypes()
+    } else {
+      TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
+        " a type of binary-unstable type that is " +
+        s"not currently supported by ${prettyName}.")
+    }
+  }
+
   override def prettyName: String = "mode"
 
   override def update(
@@ -74,7 +90,29 @@ case class Mode(
     if (buffer.isEmpty) {
       return null
     }
-
+    /*
+      * The Mode class uses special collation awareness logic
+      *  to handle string data types with various collations.
+      *
+      * For string types that don't support binary equality,
+      * we create a new map where the keys are the collation keys of the original strings.
+      *
+      * Keys from the original map are aggregated based on the corresponding collation keys.
+      *  The groupMapReduce method groups the entries by collation key and maps each group
+      *  to a single value (the sum of the counts), and finally reduces the groups to a single map.
+      *
+      * The new map is then used in the rest of the Mode evaluation logic.
+      */
+    val collationAwareBuffer = child.dataType match {
+      case c: StringType if
+        !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
+        val collationId = c.collationId
+        val modeMap = buffer.toSeq.groupMapReduce {
+         case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
+        }(x => x)((x, y) => (x._1, x._2 + y._2)).values
+        modeMap
+      case _ => buffer
+    }
     reverseOpt.map { reverse =>
       val defaultKeyOrdering = if (reverse) {
         PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
@@ -82,8 +120,8 @@
         PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]]
       }
       val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering)
-      buffer.maxBy { case (key, count) => (count, key) }(ordering)
-    }.getOrElse(buffer.maxBy(_._2))._1
+      collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering)
+    }.getOrElse(collationAwareBuffer.maxBy(_._2))._1
   }
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode =
@@ -128,6 +166,7 @@ case class Mode(
     copy(child = newChild)
 }
 
+// TODO: SPARK-48701: PandasMode (all collations)
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 993e6cc35a79..7994c496cb65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -24,9 +24,13 @@ import scala.collection.immutable.Seq
 
 import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.OpenHashMap
 
 // scalastyle:off nonascii
 class CollationSQLExpressionsSuite
@@ -1646,6 +1650,216 @@ class CollationSQLExpressionsSuite
     }
   }
 
+  test("Support mode for string expression with collation - Basic Test") {
+    Seq("utf8_binary", "UTF8_LCASE", "unicode_ci", "unicode").foreach { collationId =>
+      val query = s"SELECT mode(collate('abc', '${collationId}'))"
+      checkAnswer(sql(query), Row("abc"))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(collationId)))
+    }
+  }
+
+  test("Support mode for string expression with collation - Advanced Test") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"('$elt')").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode"
+      withTable(s"${tableName}") {
+        sql(s"CREATE TABLE ${tableName}(i STRING) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT mode(collate(i, '${t.collationId}')) FROM ${tableName}"
+        checkAnswer(sql(query), Row(t.result))
+        assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collationId)))
+
+      }
+    })
+  }
+
+  test("Support Mode.eval(buffer)") {
+    case class UTF8StringModeTestCase[R](
+        collationId: String,
+        bufferValues: Map[UTF8String, Long],
+        result: R)
+
+    val bufferValuesUTF8String = Map(
+      UTF8String.fromString("a") -> 5L,
+      UTF8String.fromString("b") -> 4L,
+      UTF8String.fromString("B") -> 3L,
+      UTF8String.fromString("d") -> 2L,
+      UTF8String.fromString("e") -> 1L)
+
+    val testCasesUTF8String = Seq(
+      UTF8StringModeTestCase("utf8_binary", bufferValuesUTF8String, "a"),
+      UTF8StringModeTestCase("UTF8_LCASE", bufferValuesUTF8String, "b"),
+      UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"),
+      UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a"))
+
+    testCasesUTF8String.foreach(t => {
+      val buffer = new OpenHashMap[AnyRef, Long](5)
+      val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId)))
+      t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
+      assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase())
+    })
+  }
+
+  test("Support mode for string expression with collated strings in struct") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"named_struct('f1'," +
+          s" collate('$elt', '${t.collationId}'), 'f2', 1)").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_struct"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRING COLLATE " +
+          t.collationId + ", f2: INT>) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
+        if(t.collationId == "UTF8_LCASE" ||
+          t.collationId == "unicode_ci" ||
+          t.collationId == "unicode") {
+          // Cannot resolve "mode(i)" due to data type mismatch:
+          // Input to function mode was a complex type with strings collated on non-binary
+          // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
+          val params = Seq(("sqlExpr", "\"mode(i)\""),
+            ("msg", "The input to the function 'mode'" +
+              " was a type of binary-unstable type that is not currently supported by mode."),
+            ("hint", "")).toMap
+          checkError(
+            exception = intercept[AnalysisException] {
+              sql(query)
+            },
+            errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+            parameters = params,
+            queryContext = Array(
+              ExpectedContext(objectType = "",
+                objectName = "",
+                startIndex = 13,
+                stopIndex = 19,
+                fragment = "mode(i)")
+            )
+          )
+        } else {
+          checkAnswer(sql(query), Row(t.result))
+        }
+      }
+    })
+  }
+
+  test("Support mode for string expression with collated strings in recursively nested struct") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"named_struct('f1', " +
+          s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_nested_struct"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE " +
+          t.collationId + ">, f3: INT>) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
+        if(t.collationId == "UTF8_LCASE" ||
+          t.collationId == "unicode_ci" ||
+          t.collationId == "unicode") {
+          // Cannot resolve "mode(i)" due to data type mismatch:
+          // Input to function mode was a complex type with strings collated on non-binary
+          // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
+          val params = Seq(("sqlExpr", "\"mode(i)\""),
+            ("msg", "The input to the function 'mode' " +
+              "was a type of binary-unstable type that is not currently supported by mode."),
+            ("hint", "")).toMap
+          checkError(
+            exception = intercept[AnalysisException] {
+              sql(query)
+            },
+            errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+            parameters = params,
+            queryContext = Array(
+              ExpectedContext(objectType = "",
+                objectName = "",
+                startIndex = 13,
+                stopIndex = 19,
+                fragment = "mode(i)")
+            )
+          )
+        } else {
+          checkAnswer(sql(query), Row(t.result))
+        }
+      }
+    })
+  }
+
+  test("Support mode for string expression with collated strings in array complex type") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " +
+          s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_nested_struct"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(" +
+          s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE ${t.collationId}>>, f3: INT>>)" +
+          s" USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}"
+        if(t.collationId == "UTF8_LCASE" ||
+          t.collationId == "unicode_ci" || t.collationId == "unicode") {
+          val params = Seq(("sqlExpr", "\"mode(i)\""),
+            ("msg", "The input to the function 'mode' was a type" +
+              " of binary-unstable type that is not currently supported by mode."),
+            ("hint", "")).toMap
+          checkError(
+            exception = intercept[AnalysisException] {
+              sql(query)
+            },
+            errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+            parameters = params,
+            queryContext = Array(
+              ExpectedContext(objectType = "",
+                objectName = "",
+                startIndex = 35,
+                stopIndex = 41,
+                fragment = "mode(i)")
+            )
+          )
+        } else {
+          checkAnswer(sql(query), Row(t.result))
+        }
+      }
+    })
+  }
+
   test("SPARK-48430: Map value extraction with collations") {
     for {
       collateKey <- Seq(true, false)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
