This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3aea6c258bf3 [SPARK-48003][SQL] Add collation support for hll sketch aggregate
3aea6c258bf3 is described below

commit 3aea6c258bf3541d7f53cd3914244f817ed36ff6
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Tue Apr 30 20:58:43 2024 +0800

    [SPARK-48003][SQL] Add collation support for hll sketch aggregate
    
    ### What changes were proposed in this pull request?
    Introduce collation awareness for hll sketch aggregate.
    
    ### Why are the changes needed?
    Add collation support for hyperloglog expressions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use collated strings within arguments for hyperloglog function: hll_sketch_agg.
    
    ### How was this patch tested?
    E2e sql tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46241 from uros-db/hll-agg.
    
    Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationFactory.java     | 14 ++++++++++++++
 .../aggregate/datasketchesAggregates.scala            |  8 ++++++--
 .../scala/org/apache/spark/sql/CollationSuite.scala   | 19 +++++++++++++++++++
 3 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 93691e28c692..8ffff63445b6 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -25,6 +25,7 @@ import java.util.function.ToLongFunction;
 import com.ibm.icu.text.RuleBasedCollator;
 import com.ibm.icu.text.StringSearch;
 import com.ibm.icu.util.ULocale;
+import com.ibm.icu.text.CollationKey;
 import com.ibm.icu.text.Collator;
 
 import org.apache.spark.SparkException;
@@ -270,4 +271,17 @@ public final class CollationFactory {
     int collationId = collationNameToId(collationName);
     return collationTable[collationId];
   }
+
+  public static UTF8String getCollationKey(UTF8String input, int collationId) {
+    Collation collation = fetchCollation(collationId);
+    if (collation.supportsBinaryEquality) {
+      return input;
+    } else if (collation.supportsLowercaseEquality) {
+      return input.toLowerCase();
+    } else {
+      CollationKey collationKey = collation.collator.getCollationKey(input.toString());
+      return UTF8String.fromBytes(collationKey.toByteArray());
+    }
+  }
+
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala
index 02925f3625d2..2102428131f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala
@@ -25,7 +25,9 @@ import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
 import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -103,7 +105,7 @@ case class HllSketchAgg(
   override def prettyName: String = "hll_sketch_agg"
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(IntegerType, LongType, StringType, BinaryType), IntegerType)
+    Seq(TypeCollection(IntegerType, LongType, StringTypeAnyCollation, BinaryType), IntegerType)
 
   override def dataType: DataType = BinaryType
 
@@ -137,7 +139,9 @@ case class HllSketchAgg(
         // TODO: implement support for decimal/datetime/interval types
         case IntegerType => sketch.update(v.asInstanceOf[Int])
         case LongType => sketch.update(v.asInstanceOf[Long])
-        case StringType => sketch.update(v.asInstanceOf[UTF8String].toString)
+        case st: StringType =>
+          val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
+          sketch.update(cKey.toString)
         case BinaryType => sketch.update(v.asInstanceOf[Array[Byte]])
         case dataType => throw new SparkUnsupportedOperationException(
           errorClass = "_LEGACY_ERROR_TEMP_3121",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 26f7726c3964..fce9ad3cc184 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -1040,4 +1040,23 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       checkAnswer(dfNonBinary, dfBinary)
     }
   }
+
+  test("hll sketch aggregate should respect collation") {
+    case class HllSketchAggTestCase[R](c: String, result: R)
+    val testCases = Seq(
+      HllSketchAggTestCase("UTF8_BINARY", 4),
+      HllSketchAggTestCase("UTF8_BINARY_LCASE", 3),
+      HllSketchAggTestCase("UNICODE", 4),
+      HllSketchAggTestCase("UNICODE_CI", 3)
+    )
+    testCases.foreach(t => {
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.c) {
+        val q = "SELECT hll_sketch_estimate(hll_sketch_agg(col)) FROM " +
+          "VALUES ('a'), ('A'), ('b'), ('b'), ('c') tab(col)"
+        val df = sql(q)
+        checkAnswer(df, Seq(Row(t.result)))
+      }
+    })
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to