This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e8227a8e2442 [SPARK-49018][SQL] Fix approx_count_distinct not working correctly with collation
e8227a8e2442 is described below

commit e8227a8e24422f9bcabd1b18c87cc4f3b78a72b3
Author: viktorluc-db <viktor.lu...@databricks.com>
AuthorDate: Mon Aug 5 20:55:49 2024 +0800

    [SPARK-49018][SQL] Fix approx_count_distinct not working correctly with collation

    ### What changes were proposed in this pull request?

    Fix approx_count_distinct so that it works correctly with collated strings.

    ### Why are the changes needed?

    approx_count_distinct was not working correctly with any collation.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New test added to CollationSQLExpressionsSuite.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #47503 from viktorluc-db/bugfix.

    Authored-by: viktorluc-db <viktor.lu...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/util/HyperLogLogPlusPlusHelper.scala  |  3 ++
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 39 ++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
index 6471a746f2ed..fc947386487a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.XxHash64Function
 import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 // A helper class for HyperLogLogPlusPlus.
 class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
@@ -93,6 +94,8 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
     val value = dataType match {
       case FloatType => FLOAT_NORMALIZER.apply(_value)
       case DoubleType => DOUBLE_NORMALIZER.apply(_value)
+      case st: StringType if !st.supportsBinaryEquality =>
+        CollationFactory.getCollationKeyBytes(_value.asInstanceOf[UTF8String], st.collationId)
      case _ => _value
     }
     // Create the hashed value 'x'.
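The core of the patch is the new StringType branch above: when a collation does not support binary equality, the value is first mapped to its collation key with CollationFactory.getCollationKeyBytes before XxHash64Function hashes it, so strings the collation treats as equal feed identical hashes into the HLL++ registers. Below is a minimal standalone sketch of that idea; it does not use Spark, the lower-cased key and java.util.Arrays.hashCode are only stand-ins for getCollationKeyBytes and XxHash64Function, and the object name CollationAwareHashSketch is made up for illustration.

    import java.nio.charset.StandardCharsets

    object CollationAwareHashSketch {
      // Stand-in for CollationFactory.getCollationKeyBytes: for a hypothetical
      // case-insensitive collation, use the lower-cased UTF-8 bytes as the key.
      def collationKeyBytes(s: String): Array[Byte] =
        s.toLowerCase.getBytes(StandardCharsets.UTF_8)

      // Stand-in for XxHash64Function: any deterministic hash over the key bytes
      // illustrates the point; the sketch only needs equal inputs to hash equally.
      def hash(bytes: Array[Byte]): Int = java.util.Arrays.hashCode(bytes)

      def main(args: Array[String]): Unit = {
        val raw = Seq("a", "a", "A", "aa", "aA", "Aa", "AA")

        // Old behaviour: hashing the raw bytes keeps "a" and "A" distinct, so the
        // estimate over-counts under a case-insensitive collation.
        val rawDistinct = raw.map(s => hash(s.getBytes(StandardCharsets.UTF_8))).distinct.size

        // Fixed behaviour: hashing the collation key collapses case variants,
        // matching what '=' considers equal under the collation.
        val keyedDistinct = raw.map(s => hash(collationKeyBytes(s))).distinct.size

        println(s"distinct hashes of raw bytes:      $rawDistinct")   // 6
        println(s"distinct hashes of collation keys: $keyedDistinct") // 2
      }
    }

Note that the patch computes the key only when st.supportsBinaryEquality is false, so binary collations such as UTF8_BINARY keep hashing the raw UTF8String and their results are unchanged.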
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index e31411ea212f..2473a9228194 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -2319,6 +2319,45 @@ class CollationSQLExpressionsSuite
     )
   }
 
+  test("Support HyperLogLogPlusPlus expression with collation") {
+    case class HyperLogLogPlusPlusTestCase(
+      collation: String,
+      input: Seq[String],
+      output: Seq[Row]
+    )
+
+    val testCases = Seq(
+      HyperLogLogPlusPlusTestCase("utf8_binary", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(10))),
+      HyperLogLogPlusPlusTestCase("utf8_lcase", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(5))),
+      HyperLogLogPlusPlusTestCase("UNICODE", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(10))),
+      HyperLogLogPlusPlusTestCase("UNICODE_CI", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA",
+        "aA", "Aa", "aa"), Seq(Row(5)))
+    )
+
+    testCases.foreach(t => {
+      // Using explicit collate clause
+      val query =
+        s"""
+           |SELECT approx_count_distinct(col) FROM VALUES
+           |${t.input.map(s => s"('${s}' collate ${t.collation})").mkString(", ")} tab(col)
+           |""".stripMargin
+      checkAnswer(sql(query), t.output)
+
+      // Using default collation
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collation) {
+        val query =
+          s"""
+             |SELECT approx_count_distinct(col) FROM VALUES
+             |${t.input.map(s => s"('${s}')").mkString(", ")} tab(col)
+             |""".stripMargin
+        checkAnswer(sql(query), t.output)
+      }
+    })
+  }
+
   // TODO: Add more tests for other SQL expressions
 
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org