This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0fc5b0b9383b [SPARK-47353][SQL] Enable collation support for the Mode expression
0fc5b0b9383b is described below
commit 0fc5b0b9383b7963835eeaafdd1c5f02417eedde
Author: GideonPotok <[email protected]>
AuthorDate: Wed Jun 26 16:08:31 2024 +0800
[SPARK-47353][SQL] Enable collation support for the Mode expression
### What changes were proposed in this pull request?
[SPARK-47353](https://issues.apache.org/jira/browse/SPARK-47353)
#### Pull requests
- [Scala TreeMap (RB Tree)](https://github.com/apache/spark/pull/46404)
- [GroupMapReduce](https://github.com/apache/spark/pull/46526) <- Most performant
- [GroupMapReduce (Cleaned up) (This PR)](https://github.com/apache/spark/pull/46597) <- Most performant
- [Comparing Experimental Approaches](https://github.com/apache/spark/pull/46488)

Related discussion: https://github.com/apache/spark/pull/46597/files#r1626058908 -> https://github.com/apache/spark/pull/46917#discussion_r1640081873
#### Central Change to Mode `eval` Algorithm:
- Updated `eval` method: the `eval` method now checks whether the column being evaluated is a string type with a non-default collation and, if so, regroups the buffer by collation key:
```
buff.toSeq.groupMapReduce {
  case (key: String, _) =>
    CollationFactory.getCollationKey(UTF8String.fromString(key), collationId)
  case (key: UTF8String, _) =>
    CollationFactory.getCollationKey(key, collationId)
  case (key, _) => key
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
```
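As a self-contained sketch of the same idea (outside Spark, with a hypothetical `collationKey` that simply lowercases strings, standing in for `CollationFactory.getCollationKey` under a case-insensitive collation), `groupMapReduce` merges the counts of buffer entries whose values share a collation key:

```scala
// Hypothetical stand-in for CollationFactory.getCollationKey under a
// case-insensitive collation: strings that are equal under the collation
// map to the same key (here, simply the lowercased string).
def collationKey(s: String): String = s.toLowerCase

// A mode buffer of (value, count) pairs, as eval would see it.
val buffer = Seq("a" -> 3L, "b" -> 2L, "B" -> 2L)

// Group by collation key, then reduce each group by summing counts,
// keeping the first-seen original value as the representative.
val merged = buffer
  .groupMapReduce { case (k, _) => collationKey(k) }(identity) {
    case ((k1, c1), (_, c2)) => (k1, c1 + c2)
  }
  .values

// "b" and "B" collapse into one entry with count 4, which now beats "a".
val mode = merged.maxBy(_._2)
println(mode) // prints (b,4)
```

Mode then takes the max by count over this collation-aware buffer rather than the raw one.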
#### Minor Change to Mode:
- Introduction of `collationId`: a new lazy value `collationId` is computed from the `dataType` of the `child` expression and is used to fetch the appropriate collation comparator when `collationEnabled` is true.

This PR does not support complex types containing collated strings; a follow-up PR will implement that.
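The resulting type-check rule can be illustrated with a toy model (illustrative types only, not Spark's API): top-level strings are always accepted (binary-unstable ones are handled via collation keys), while complex types are accepted only when they are binary-stable.

```scala
// Illustrative miniature of the type hierarchy (not the real Spark API).
sealed trait DT
case class StringT(binaryStable: Boolean) extends DT
case class StructT(fields: List[DT]) extends DT
case object IntT extends DT

// Stand-in for UnsafeRowUtils.isBinaryStable: a type is binary-stable
// when none of its (nested) string fields use a non-binary collation.
def isBinaryStable(dt: DT): Boolean = dt match {
  case StringT(stable) => stable
  case StructT(fields) => fields.forall(isBinaryStable)
  case IntT            => true
}

// The PR's gating rule: accept binary-stable inputs, or top-level strings
// (whose collation Mode now handles); reject collated complex types.
def modeSupports(dt: DT): Boolean =
  isBinaryStable(dt) || dt.isInstanceOf[StringT]

println(modeSupports(StringT(binaryStable = false)))       // prints true
println(modeSupports(StructT(List(StringT(false), IntT)))) // prints false
```

The rejected case is exactly what the new `checkInputDataTypes` turns into a `TypeCheckFailure`, as exercised by the struct/array tests below.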
#### Unit Test Enhancements: Significant additions to
`CollationSQLExpressionsSuite` to test the new functionality, including:
- Tests for the `Mode` function when handling strings with different collation settings.
#### Benchmark Updates:
- Enhanced the `CollationBenchmark` classes to include benchmarks for the
new mode functionality with and without collation settings, as well as
numerical types.
### Why are the changes needed?
1. Ensures consistency in handling string comparisons under various
collation settings.
2. Improves global usability by enabling compatibility with different
collation standards.
### Does this PR introduce _any_ user-facing change?
Yes, this PR introduces the following user-facing changes:
1. Adds a new `collationEnabled` property to the `Mode` expression.
2. Users can now specify collation settings for the `Mode` expression to
customize its behavior.
### How was this patch tested?
This patch was tested through a combination of new and existing unit and
end-to-end SQL tests.
1. **Unit Tests:**
   - **CollationSQLExpressionsSuite:**
     - Aligned the newly added tests with the design patterns of the existing tests.
     - Added multiple test cases to verify that the `Mode` function correctly handles strings with different collation settings.
   - Out of scope: special Unicode cases (higher planes).
   - Tests do not need to include null handling.
2. **Benchmark Tests:**
3. **Manual Testing:**
```
./build/mvn -DskipTests clean package
export SPARK_HOME=/Users/gideon/repos/spark
$SPARK_HOME/bin/spark-shell

spark.sqlContext.setConf("spark.sql.collation.enabled", "true")
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.col
import spark.implicits._
val data = Seq(("Def"), ("def"), ("DEF"), ("abc"), ("abc"))
val df = data.toDF("word")
val dfLC = df.withColumn("word",
  col("word").cast(StringType("UTF8_BINARY_LCASE")))
val dfLCA =
  dfLC.agg(org.apache.spark.sql.functions.mode(functions.col("word")).as("count"))
dfLCA.show()
/*
BEFORE:
+-----+
|count|
+-----+
|  abc|
+-----+

AFTER:
+-----+
|count|
+-----+
|  Def|
+-----+
*/
```
4. **Continuous Integration (CI):**
   - The patch passed all relevant Continuous Integration (CI) checks, including:
     - Unit test suite
     - Benchmark suite
   - Follow-up consideration: moving the new benchmark to the catalyst module.
### Was this patch authored or co-authored using generative AI tooling?
Nope!
Closes #46597 from GideonPotok/spark_47353_3_clean.
Lead-authored-by: GideonPotok <[email protected]>
Co-authored-by: Gideon Potok <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/aggregate/Mode.scala | 51 ++++-
.../spark/sql/CollationSQLExpressionsSuite.scala | 214 +++++++++++++++++++++
2 files changed, 259 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index d1a9cafdf61f..5977eff4526d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -18,13 +18,14 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.OpenHashMap
case class Mode(
@@ -48,6 +49,21 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
+      /*
+       * The Mode class uses collation awareness logic to handle string data.
+       * Complex types with collated fields are not yet supported.
+       */
+      // TODO: SPARK-48700: Mode expression for complex types (all collations)
+      super.checkInputDataTypes()
+    } else {
+      TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
+        " a type of binary-unstable type that is " +
+        s"not currently supported by ${prettyName}.")
+    }
+  }
+
+
override def prettyName: String = "mode"
override def update(
@@ -74,7 +90,29 @@ case class Mode(
     if (buffer.isEmpty) {
       return null
     }
-
+    /*
+     * The Mode class uses special collation awareness logic
+     * to handle string data types with various collations.
+     *
+     * For string types that don't support binary equality,
+     * we create a new map where the keys are the collation keys of the original strings.
+     *
+     * Keys from the original map are aggregated based on the corresponding collation keys.
+     * The groupMapReduce method groups the entries by collation key and maps each group
+     * to a single value (the sum of the counts), and finally reduces the groups to a single map.
+     *
+     * The new map is then used in the rest of the Mode evaluation logic.
+     */
+    val collationAwareBuffer = child.dataType match {
+      case c: StringType if
+          !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
+        val collationId = c.collationId
+        val modeMap = buffer.toSeq.groupMapReduce {
+          case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
+        }(x => x)((x, y) => (x._1, x._2 + y._2)).values
+        modeMap
+      case _ => buffer
+    }
     reverseOpt.map { reverse =>
       val defaultKeyOrdering = if (reverse) {
         PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
@@ -82,8 +120,8 @@
         PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]]
       }
       val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering)
-      buffer.maxBy { case (key, count) => (count, key) }(ordering)
-    }.getOrElse(buffer.maxBy(_._2))._1
+      collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering)
+    }.getOrElse(collationAwareBuffer.maxBy(_._2))._1
   }
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int):
Mode =
@@ -128,6 +166,7 @@ case class Mode(
copy(child = newChild)
}
+// TODO: SPARK-48701: PandasMode (all collations)
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 993e6cc35a79..7994c496cb65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -24,9 +24,13 @@ import scala.collection.immutable.Seq
 import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.OpenHashMap
// scalastyle:off nonascii
class CollationSQLExpressionsSuite
@@ -1646,6 +1650,216 @@ class CollationSQLExpressionsSuite
}
}
+  test("Support mode for string expression with collation - Basic Test") {
+    Seq("utf8_binary", "UTF8_LCASE", "unicode_ci", "unicode").foreach { collationId =>
+      val query = s"SELECT mode(collate('abc', '${collationId}'))"
+      checkAnswer(sql(query), Row("abc"))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(collationId)))
+    }
+  }
+
+  test("Support mode for string expression with collation - Advanced Test") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"('$elt')").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode"
+      withTable(s"${tableName}") {
+        sql(s"CREATE TABLE ${tableName}(i STRING) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT mode(collate(i, '${t.collationId}')) FROM ${tableName}"
+        checkAnswer(sql(query), Row(t.result))
+        assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collationId)))
+
+      }
+    })
+  }
+
+  test("Support Mode.eval(buffer)") {
+    case class UTF8StringModeTestCase[R](
+      collationId: String,
+      bufferValues: Map[UTF8String, Long],
+      result: R)
+
+    val bufferValuesUTF8String = Map(
+      UTF8String.fromString("a") -> 5L,
+      UTF8String.fromString("b") -> 4L,
+      UTF8String.fromString("B") -> 3L,
+      UTF8String.fromString("d") -> 2L,
+      UTF8String.fromString("e") -> 1L)
+
+    val testCasesUTF8String = Seq(
+      UTF8StringModeTestCase("utf8_binary", bufferValuesUTF8String, "a"),
+      UTF8StringModeTestCase("UTF8_LCASE", bufferValuesUTF8String, "b"),
+      UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"),
+      UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a"))
+
+    testCasesUTF8String.foreach(t => {
+      val buffer = new OpenHashMap[AnyRef, Long](5)
+      val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId)))
+      t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
+      assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase())
+    })
+  }
+
+  test("Support mode for string expression with collated strings in struct") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"named_struct('f1'," +
+          s" collate('$elt', '${t.collationId}'), 'f2', 1)").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_struct"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRING COLLATE " +
+          t.collationId + ", f2: INT>) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
+        if(t.collationId == "UTF8_LCASE" ||
+          t.collationId == "unicode_ci" ||
+          t.collationId == "unicode") {
+          // Cannot resolve "mode(i)" due to data type mismatch:
+          // Input to function mode was a complex type with strings collated on non-binary
+          // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
+          val params = Seq(("sqlExpr", "\"mode(i)\""),
+            ("msg", "The input to the function 'mode'" +
+              " was a type of binary-unstable type that is not currently supported by mode."),
+            ("hint", "")).toMap
+          checkError(
+            exception = intercept[AnalysisException] {
+              sql(query)
+            },
+            errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+            parameters = params,
+            queryContext = Array(
+              ExpectedContext(objectType = "",
+                objectName = "",
+                startIndex = 13,
+                stopIndex = 19,
+                fragment = "mode(i)")
+            )
+          )
+        } else {
+          checkAnswer(sql(query), Row(t.result))
+        }
+      }
+    })
+  }
+
+  test("Support mode for string expression with collated strings in recursively nested struct") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"named_struct('f1', " +
+          s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_nested_struct"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE " +
+          t.collationId + ">, f3: INT>) USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
+        if(t.collationId == "UTF8_LCASE" ||
+          t.collationId == "unicode_ci" ||
+          t.collationId == "unicode") {
+          // Cannot resolve "mode(i)" due to data type mismatch:
+          // Input to function mode was a complex type with strings collated on non-binary
+          // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
+          val params = Seq(("sqlExpr", "\"mode(i)\""),
+            ("msg", "The input to the function 'mode' " +
+              "was a type of binary-unstable type that is not currently supported by mode."),
+            ("hint", "")).toMap
+          checkError(
+            exception = intercept[AnalysisException] {
+              sql(query)
+            },
+            errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+            parameters = params,
+            queryContext = Array(
+              ExpectedContext(objectType = "",
+                objectName = "",
+                startIndex = 13,
+                stopIndex = 19,
+                fragment = "mode(i)")
+            )
+          )
+        } else {
+          checkAnswer(sql(query), Row(t.result))
+        }
+      }
+    })
+  }
+
+  test("Support mode for string expression with collated strings in array complex type") {
+    case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
+    val testCases = Seq(
+      ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+      ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+      ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+    )
+    testCases.foreach(t => {
+      val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+        (0L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " +
+          s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",")
+      }.mkString(",")
+
+      val tableName = s"t_${t.collationId}_mode_nested_struct"
+      withTable(tableName) {
+        sql(s"CREATE TABLE ${tableName}(" +
+          s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE ${t.collationId}>>, f3: INT>>)" +
+          s" USING parquet")
+        sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+        val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}"
+        if(t.collationId == "UTF8_LCASE" ||
+          t.collationId == "unicode_ci" || t.collationId == "unicode") {
+          val params = Seq(("sqlExpr", "\"mode(i)\""),
+            ("msg", "The input to the function 'mode' was a type" +
+              " of binary-unstable type that is not currently supported by mode."),
+            ("hint", "")).toMap
+          checkError(
+            exception = intercept[AnalysisException] {
+              sql(query)
+            },
+            errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+            parameters = params,
+            queryContext = Array(
+              ExpectedContext(objectType = "",
+                objectName = "",
+                startIndex = 35,
+                stopIndex = 41,
+                fragment = "mode(i)")
+            )
+          )
+        } else {
+          checkAnswer(sql(query), Row(t.result))
+        }
+      }
+    })
+  }
+
test("SPARK-48430: Map value extraction with collations") {
for {
collateKey <- Seq(true, false)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]