This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e8227a8e2442 [SPARK-49018][SQL] Fix approx_count_distinct not working correctly with collation
e8227a8e2442 is described below

commit e8227a8e24422f9bcabd1b18c87cc4f3b78a72b3
Author: viktorluc-db <viktor.lu...@databricks.com>
AuthorDate: Mon Aug 5 20:55:49 2024 +0800

    [SPARK-49018][SQL] Fix approx_count_distinct not working correctly with collation
    
    ### What changes were proposed in this pull request?
    Fix for approx_count_distinct not working correctly with collated strings.
    
    ### Why are the changes needed?
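    approx_count_distinct ignored the collation of string inputs and hashed values by their raw bytes, so strings that are equal under a collation without binary equality (e.g. UTF8_LCASE, UNICODE_CI) were counted as distinct.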
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    New test added to CollationSQLExpressionsSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47503 from viktorluc-db/bugfix.
    
    Authored-by: viktorluc-db <viktor.lu...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/util/HyperLogLogPlusPlusHelper.scala  |  3 ++
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 39 ++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
index 6471a746f2ed..fc947386487a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.XxHash64Function
 import 
org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER,
 FLOAT_NORMALIZER}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 // A helper class for HyperLogLogPlusPlus.
 class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
@@ -93,6 +94,8 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
     val value = dataType match {
       case FloatType => FLOAT_NORMALIZER.apply(_value)
       case DoubleType => DOUBLE_NORMALIZER.apply(_value)
+      case st: StringType if !st.supportsBinaryEquality =>
+        CollationFactory.getCollationKeyBytes(_value.asInstanceOf[UTF8String], 
st.collationId)
       case _ => _value
     }
     // Create the hashed value 'x'.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index e31411ea212f..2473a9228194 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -2319,6 +2319,45 @@ class CollationSQLExpressionsSuite
     )
   }
 
+  test("Support HyperLogLogPlusPlus expression with collation") {
+    case class HyperLogLogPlusPlusTestCase(
+      collation: String,
+      input: Seq[String],
+      output: Seq[Row]
+    )
+
+    val testCases = Seq(
+      HyperLogLogPlusPlusTestCase("utf8_binary", Seq("a", "a", "A", "z", "zz", 
"ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(10))),
+      HyperLogLogPlusPlusTestCase("utf8_lcase", Seq("a", "a", "A", "z", "zz", 
"ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(5))),
+      HyperLogLogPlusPlusTestCase("UNICODE", Seq("a", "a", "A", "z", "zz", 
"ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(10))),
+      HyperLogLogPlusPlusTestCase("UNICODE_CI", Seq("a", "a", "A", "z", "zz", 
"ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(5)))
+    )
+
+    testCases.foreach( t => {
+      // Using explicit collate clause
+      val query =
+        s"""
+           |SELECT approx_count_distinct(col) FROM VALUES
+           |${t.input.map(s => s"('${s}' collate ${t.collation})").mkString(", 
") } tab(col)
+           |""".stripMargin
+      checkAnswer(sql(query), t.output)
+
+      // Using default collation
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collation) {
+        val query =
+          s"""
+             |SELECT approx_count_distinct(col) FROM VALUES
+             |${t.input.map(s => s"('${s}')").mkString(", ") } tab(col)
+             |""".stripMargin
+        checkAnswer(sql(query), t.output)
+      }
+    })
+  }
+
   // TODO: Add more tests for other SQL expressions
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to