This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 542faa70cefe [SPARK-52828][SQL] Make hashing for collated strings collation agnostic
542faa70cefe is described below

commit 542faa70cefe6526387f2bfaebde546e4e562ace
Author: Uros Bojanic <uros.boja...@databricks.com>
AuthorDate: Thu Aug 7 16:37:22 2025 +0800

    [SPARK-52828][SQL] Make hashing for collated strings collation agnostic
    
    ### What changes were proposed in this pull request?
    We change the behavior of the `Murmur3Hash` and `XxHash64` catalyst
    expressions to be collation agnostic (i.e. collation-unaware). We also
    introduce two new internal catalyst expressions, `CollationAwareMurmur3Hash`
    and `CollationAwareXxHash64`, which are collation aware: they take the
    collation of the string into account when hashing collated strings.
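
    As an illustration, here is a hedged sketch of the intended semantics,
    modeled on the updated HashExpressionsSuite below (the literal strings and
    the seed 42 are arbitrary choices, not part of the change itself):

        import org.apache.spark.sql.catalyst.expressions.{Collate,
          CollationAwareMurmur3Hash, Literal, Murmur3Hash, ResolvedCollation}

        // "aaa" and "AAA" compare as equal under UTF8_LCASE.
        val s1 = Collate(Literal("aaa"), ResolvedCollation("UTF8_LCASE"))
        val s2 = Collate(Literal("AAA"), ResolvedCollation("UTF8_LCASE"))

        // Collation agnostic: hashes the raw UTF8 bytes, so the hashes differ.
        assert(Murmur3Hash(Seq(s1), 42).eval() != Murmur3Hash(Seq(s2), 42).eval())

        // Collation aware: hashes the collation sort key, so the hashes match.
        assert(CollationAwareMurmur3Hash(Seq(s1), 42).eval() ==
          CollationAwareMurmur3Hash(Seq(s2), 42).eval())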
    
    Furthermore, we replace `Murmur3Hash` and `XxHash64` with
    `CollationAwareMurmur3Hash` and `CollationAwareXxHash64` in expressions
    where hashing must remain collation aware, for example in hash
    partitioning. We also change how the internal `HiveHash` expression hashes
    collated strings, making it consistent with the rest of the hashing
    expressions (`HiveHash` is meant to always be collation aware).
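
    For the partitioning point, a hedged sketch (mirroring the
    DistributionSuite change below; `keyExprs` is a hypothetical sequence of
    partitioning key expressions):

        import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

        // The partition ID is now derived from a collation aware hash, so rows
        // whose keys are equal under the column's collation (e.g. "aaa" and
        // "AAA" with UTF8_LCASE) still land in the same partition.
        val partitioning = HashPartitioning(keyExprs, 10)
        // Equivalent to Pmod(CollationAwareMurmur3Hash(keyExprs, 42), Literal(10)).
        val partitionId = partitioning.partitionIdExpression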
    
    Finally, we add a kill switch (the SQL config
    `COLLATION_AWARE_HASHING_ENABLED`, i.e.
    `spark.sql.legacy.collationAwareHashFunctions`) that allows users to
    restore the previous behavior of `Murmur3Hash` and `XxHash64` as
    user-facing expressions. The kill switch has no effect on the new
    collation-aware hashing expressions or on the `HiveHash` expression, which
    are internal and must follow the new collation-aware behavior.
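
    A hedged usage sketch of the kill switch (the config key is taken from the
    SQLConf change below; it is an internal conf that defaults to false):

        // Restore the old, collation aware behavior of the user-facing
        // Murmur3Hash / XxHash64 expressions on collated strings:
        spark.conf.set("spark.sql.legacy.collationAwareHashFunctions", "true")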
    
    ### Why are the changes needed?
    The `Murmur3Hash` and `XxHash64` catalyst expressions, when applied to
    collated strings, currently always take the collation of the string into
    consideration, i.e. they are collation aware. This is not the intended
    behavior; these expressions should be collation agnostic by default.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, see the detailed explanation above.
    
    ### How was this patch tested?
    Updated existing tests in the relevant suites: CollationFactorySuite,
    DistributionSuite, and HashExpressionsSuite. Also verified that
    CollationSuite passes.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51521 from uros-db/collation-hashing.
    
    Lead-authored-by: Uros Bojanic <uros.boja...@databricks.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationFactory.java  |  51 +++--
 .../spark/unsafe/types/CollationFactorySuite.scala |   8 +-
 .../spark/sql/catalyst/expressions/hash.scala      | 228 ++++++++++++++++++---
 .../sql/catalyst/plans/physical/partitioning.scala |   4 +-
 .../catalyst/util/HyperLogLogPlusPlusHelper.scala  |  11 +-
 .../util/InternalRowComparableWrapper.scala        |   8 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |   9 +
 .../spark/sql/catalyst/DistributionSuite.scala     |   4 +-
 .../expressions/HashExpressionsSuite.scala         | 131 +++++++++++-
 .../execution/benchmark/CollationBenchmark.scala   |   4 +-
 10 files changed, 388 insertions(+), 70 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 4bcd75a73105..59c23064858d 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -22,7 +22,6 @@ import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 import java.util.function.BiFunction;
-import java.util.function.ToLongFunction;
 import java.util.stream.Stream;
 
 import com.ibm.icu.text.CollationKey;
@@ -125,10 +124,19 @@ public final class CollationFactory {
     public final String version;
 
     /**
-     * Collation sensitive hash function. Output for two UTF8Strings will be the same if they are
-     * equal according to the collation.
+     * Returns the sort key of the input UTF8String. Two UTF8String values are equal iff their
+     * sort keys are equal (compared as byte arrays).
+     * The sort key is defined as follows for collations without the RTRIM modifier:
+     * - UTF8_BINARY: It is the bytes of the string.
+     * - UTF8_LCASE: It is the byte array we get by replacing all invalid UTF8 sequences with the
+     *   Unicode replacement character and then replacing all characters of the resulting string
+     *   with their lowercase equivalents (the Greek capital and Greek small sigma both map to
+     *   the Greek final sigma).
+     * - ICU collations: It is the byte array returned by the ICU library for the collated string.
+     * For strings with the RTRIM modifier, we right-trim the string and return the collation key
+     * of the resulting right-trimmed string.
      */
-    public final ToLongFunction<UTF8String> hashFunction;
+    public final Function<UTF8String, byte[]> sortKeyFunction;
 
     /**
     * Potentially faster way than using comparator to compare two UTF8Strings for equality.
@@ -182,7 +190,7 @@ public final class CollationFactory {
         Collator collator,
         Comparator<UTF8String> comparator,
         String version,
-        ToLongFunction<UTF8String> hashFunction,
+        Function<UTF8String, byte[]> sortKeyFunction,
         BiFunction<UTF8String, UTF8String, Boolean> equalsFunction,
         boolean isUtf8BinaryType,
         boolean isUtf8LcaseType,
@@ -192,7 +200,7 @@ public final class CollationFactory {
       this.collator = collator;
       this.comparator = comparator;
       this.version = version;
-      this.hashFunction = hashFunction;
+      this.sortKeyFunction = sortKeyFunction;
       this.isUtf8BinaryType = isUtf8BinaryType;
       this.isUtf8LcaseType = isUtf8LcaseType;
       this.equalsFunction = equalsFunction;
@@ -581,18 +589,18 @@ public final class CollationFactory {
       protected Collation buildCollation() {
         if (caseSensitivity == CaseSensitivity.UNSPECIFIED) {
           Comparator<UTF8String> comparator;
-          ToLongFunction<UTF8String> hashFunction;
+          Function<UTF8String, byte[]> sortKeyFunction;
           BiFunction<UTF8String, UTF8String, Boolean> equalsFunction;
           boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE;
 
           if (spaceTrimming == SpaceTrimming.NONE) {
             comparator = UTF8String::binaryCompare;
-            hashFunction = s -> (long) s.hashCode();
+            sortKeyFunction = s -> s.getBytes();
             equalsFunction = UTF8String::equals;
           } else {
            comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare(
              applyTrimmingPolicy(s2, spaceTrimming));
-            hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode();
+            sortKeyFunction = s -> applyTrimmingPolicy(s, spaceTrimming).getBytes();
            equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals(
              applyTrimmingPolicy(s2, spaceTrimming));
           }
@@ -603,25 +611,25 @@ public final class CollationFactory {
             null,
             comparator,
             CollationSpecICU.ICU_VERSION,
-            hashFunction,
+            sortKeyFunction,
             equalsFunction,
             /* isUtf8BinaryType = */ true,
             /* isUtf8LcaseType = */ false,
             spaceTrimming != SpaceTrimming.NONE);
         } else {
           Comparator<UTF8String> comparator;
-          ToLongFunction<UTF8String> hashFunction;
+          Function<UTF8String, byte[]> sortKeyFunction;
 
           if (spaceTrimming == SpaceTrimming.NONE) {
             comparator = CollationAwareUTF8String::compareLowerCase;
-            hashFunction = s ->
-              (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode();
+            sortKeyFunction = s ->
+              CollationAwareUTF8String.lowerCaseCodePoints(s).getBytes();
           } else {
             comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase(
               applyTrimmingPolicy(s1, spaceTrimming),
               applyTrimmingPolicy(s2, spaceTrimming));
-            hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(
-              applyTrimmingPolicy(s, spaceTrimming)).hashCode();
+            sortKeyFunction = s -> CollationAwareUTF8String.lowerCaseCodePoints(
+              applyTrimmingPolicy(s, spaceTrimming)).getBytes();
           }
 
           return new Collation(
@@ -630,7 +638,7 @@ public final class CollationFactory {
             null,
             comparator,
             CollationSpecICU.ICU_VERSION,
-            hashFunction,
+            sortKeyFunction,
             (s1, s2) -> comparator.compare(s1, s2) == 0,
             /* isUtf8BinaryType = */ false,
             /* isUtf8LcaseType = */ true,
@@ -1013,19 +1021,18 @@ public final class CollationFactory {
         collator.freeze();
 
         Comparator<UTF8String> comparator;
-        ToLongFunction<UTF8String> hashFunction;
+        Function<UTF8String, byte[]> sortKeyFunction;
 
         if (spaceTrimming == SpaceTrimming.NONE) {
-          hashFunction = s -> (long) collator.getCollationKey(
-            s.toValidString()).hashCode();
           comparator = (s1, s2) ->
             collator.compare(s1.toValidString(), s2.toValidString());
+          sortKeyFunction = s -> collator.getCollationKey(s.toValidString()).toByteArray();
         } else {
           comparator = (s1, s2) -> collator.compare(
             applyTrimmingPolicy(s1, spaceTrimming).toValidString(),
             applyTrimmingPolicy(s2, spaceTrimming).toValidString());
-          hashFunction = s -> (long) collator.getCollationKey(
-            applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode();
+          sortKeyFunction = s -> collator.getCollationKey(
+            applyTrimmingPolicy(s, spaceTrimming).toValidString()).toByteArray();
         }
 
         return new Collation(
@@ -1034,7 +1041,7 @@ public final class CollationFactory {
           collator,
           comparator,
           ICU_VERSION,
-          hashFunction,
+          sortKeyFunction,
           (s1, s2) -> comparator.compare(s1, s2) == 0,
           /* isUtf8BinaryType = */ false,
           /* isUtf8LcaseType = */ false,
diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
index b20baa6050f4..e56c6002a88e 100644
--- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
+++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
@@ -138,7 +138,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
 
   case class CollationTestCase[R](collationName: String, s1: String, s2: String, expectedResult: R)
 
-  test("collation aware equality and hash") {
+  test("collation aware equality and sort key") {
     val checks = Seq(
       CollationTestCase("UTF8_BINARY", "aaa", "aaa", true),
       CollationTestCase("UTF8_BINARY", "aaa", "AAA", false),
@@ -193,9 +193,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
       assert(collation.equalsFunction(toUTF8(testCase.s1), toUTF8(testCase.s2)) ==
         testCase.expectedResult)
 
-      val hash1 = collation.hashFunction.applyAsLong(toUTF8(testCase.s1))
-      val hash2 = collation.hashFunction.applyAsLong(toUTF8(testCase.s2))
-      assert((hash1 == hash2) == testCase.expectedResult)
+      val sortKey1 = collation.sortKeyFunction.apply(toUTF8(testCase.s1)).asInstanceOf[Array[Byte]]
+      val sortKey2 = collation.sortKeyFunction.apply(toUTF8(testCase.s2)).asInstanceOf[Array[Byte]]
+      assert(sortKey1.sameElements(sortKey2) == testCase.expectedResult)
     })
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index ac493d19df1b..d71effa45463 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -273,6 +274,11 @@ abstract class HashExpression[E] extends Expression {
 
   override def nullable: Boolean = false
 
+  protected def isCollationAware: Boolean
+
+  protected lazy val legacyCollationAwareHashing: Boolean =
+    SQLConf.get.getConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED)
+
   private def hasMapType(dt: DataType): Boolean = {
     dt.existsRecursively(_.isInstanceOf[MapType])
   }
@@ -427,14 +433,43 @@ abstract class HashExpression[E] extends Expression {
       val numBytes = s"$input.numBytes()"
       s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, 
$numBytes, $result);"
     } else {
-      val stringHash = ctx.freshName("stringHash")
-      s"""
-        long $stringHash = CollationFactory.fetchCollation(${stringType.collationId})
-          .hashFunction.applyAsLong($input);
-        $result = $hasherClassName.hashLong($stringHash, $result);
-      """
+      if (isCollationAware) {
+        val key = ctx.freshName("key")
+        val offset = "Platform.BYTE_ARRAY_OFFSET"
+        s"""
+          byte[] $key = (byte[]) CollationFactory.fetchCollation(${stringType.collationId})
+            .sortKeyFunction.apply($input);
+          $result = $hasherClassName.hashUnsafeBytes($key, $offset, $key.length, $result);
+        """
+      } else if (legacyCollationAwareHashing) {
+        val collation = CollationFactory.fetchCollation(stringType.collationId)
+        val stringHash = ctx.freshName("stringHash")
+        if (collation.isUtf8BinaryType || collation.isUtf8LcaseType) {
+          s"""
+            long $stringHash = UTF8String.fromBytes((byte[]) CollationFactory
+              .fetchCollation(${stringType.collationId}).sortKeyFunction.apply($input)).hashCode();
+            $result = $hasherClassName.hashLong($stringHash, $result);
+          """
+        } else if (collation.supportsSpaceTrimming) {
+          s"""
+            long $stringHash = CollationFactory.fetchCollation(${stringType.collationId})
+              .getCollator().getCollationKey($input.trimRight().toValidString()).hashCode();
+            $result = $hasherClassName.hashLong($stringHash, $result);
+          """
+        } else {
+          s"""
+            long $stringHash = CollationFactory.fetchCollation(${stringType.collationId})
+              .getCollator().getCollationKey($input.toValidString()).hashCode();
+            $result = $hasherClassName.hashLong($stringHash, $result);
+          """
+        }
+      } else {
+        val baseObject = s"$input.getBaseObject()"
+        val baseOffset = s"$input.getBaseOffset()"
+        val numBytes = s"$input.numBytes()"
+        s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
+      }
     }
-
   }
 
   protected def genHashForMap(
@@ -544,10 +579,38 @@ abstract class InterpretedHashFunction {
   protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long
 
   /**
-   * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
-   * of input `value`.
+   * This method is intended for callers using the old hash API and preserves compatibility for
+   * supported data types. It must only be used for data types that do not include collated strings
+   * or complex types (e.g., structs, arrays, maps) that may contain collated strings.
+   *
+   * The caller is responsible for ensuring that `dataType` does not involve collation-aware fields.
+   * This is validated via an internal assertion.
+   *
+   * @throws IllegalArgumentException if `dataType` contains non-UTF8 binary collation.
    */
   def hash(value: Any, dataType: DataType, seed: Long): Long = {
+    require(!SchemaUtils.hasNonUTF8BinaryCollation(dataType))
+    // For UTF8_BINARY, hashing behavior is the same regardless of the isCollationAware flag.
+    hash(
+      value = value,
+      dataType = dataType,
+      seed = seed,
+      isCollationAware = false,
+      legacyCollationAwareHashing = false)
+  }
+
+  /**
+   * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
+   * of input `value`. The `isCollationAware` boolean flag indicates whether hashing should take
+   * a string's collation into account. If not, the bytes of the string are hashed, otherwise the
+   * collation key of the string is hashed.
+   */
+  def hash(
+      value: Any,
+      dataType: DataType,
+      seed: Long,
+      isCollationAware: Boolean,
+      legacyCollationAwareHashing: Boolean): Long = {
     value match {
       case null => seed
       case b: Boolean => hashInt(if (b) 1 else 0, seed)
@@ -573,12 +636,25 @@ abstract class InterpretedHashFunction {
       case s: UTF8String =>
         val st = dataType.asInstanceOf[StringType]
         if (st.supportsBinaryEquality) {
-          hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+          hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes, seed)
         } else {
-          val stringHash = CollationFactory
-            .fetchCollation(st.collationId)
-            .hashFunction.applyAsLong(s)
-          hashLong(stringHash, seed)
+          if (isCollationAware) {
+            val key = CollationFactory.fetchCollation(st.collationId).sortKeyFunction.apply(s)
+              .asInstanceOf[Array[Byte]]
+            hashUnsafeBytes(key, Platform.BYTE_ARRAY_OFFSET, key.length, seed)
+          } else if (legacyCollationAwareHashing) {
+            val collation = CollationFactory.fetchCollation(st.collationId)
+            val stringHash = if (collation.isUtf8BinaryType || collation.isUtf8LcaseType) {
+              UTF8String.fromBytes(collation.sortKeyFunction.apply(s)).hashCode
+            } else if (collation.supportsSpaceTrimming) {
+              collation.getCollator.getCollationKey(s.trimRight.toValidString).hashCode
+            } else {
+              collation.getCollator.getCollationKey(s.toValidString).hashCode
+            }
+            hashLong(stringHash, seed)
+          } else {
+            hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes, seed)
+          }
         }
 
       case array: ArrayData =>
@@ -589,7 +665,12 @@ abstract class InterpretedHashFunction {
         var result = seed
         var i = 0
         while (i < array.numElements()) {
-          result = hash(array.get(i, elementType), elementType, result)
+          result = hash(
+            array.get(i, elementType),
+            elementType,
+            result,
+            isCollationAware,
+            legacyCollationAwareHashing)
           i += 1
         }
         result
@@ -606,8 +687,18 @@ abstract class InterpretedHashFunction {
         var result = seed
         var i = 0
         while (i < map.numElements()) {
-          result = hash(keys.get(i, kt), kt, result)
-          result = hash(values.get(i, vt), vt, result)
+          result = hash(
+            keys.get(i, kt),
+            kt,
+            result,
+            isCollationAware,
+            legacyCollationAwareHashing)
+          result = hash(
+            values.get(i, vt),
+            vt,
+            result,
+            isCollationAware,
+            legacyCollationAwareHashing)
           i += 1
         }
         result
@@ -622,7 +713,12 @@ abstract class InterpretedHashFunction {
         var i = 0
         val len = struct.numFields
         while (i < len) {
-          result = hash(struct.get(i, types(i)), types(i), result)
+          result = hash(
+            struct.get(i, types(i)),
+            types(i),
+            result,
+            isCollationAware,
+            legacyCollationAwareHashing)
           i += 1
         }
         result
@@ -654,8 +750,12 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress
 
   override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName
 
+  override protected def isCollationAware: Boolean = false
+
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    Murmur3HashFunction.hash(value, dataType, seed).toInt
+    Murmur3HashFunction.hash(
+      value, dataType, seed, isCollationAware, legacyCollationAwareHashing
+    ).toInt
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash =
@@ -676,6 +776,29 @@ object Murmur3HashFunction extends InterpretedHashFunction {
   }
 }
 
+case class CollationAwareMurmur3Hash(children: Seq[Expression], seed: Int)
+  extends HashExpression[Int]
+{
+  def this(arguments: Seq[Expression]) = this(arguments, 42)
+
+  override def dataType: DataType = IntegerType
+
+  override def prettyName: String = "collation_aware_hash"
+
+  override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName
+
+  override protected def isCollationAware: Boolean = true
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+    Murmur3HashFunction.hash(
+      value, dataType, seed, isCollationAware, legacyCollationAwareHashing
+    ).toInt
+  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
+    CollationAwareMurmur3Hash = copy(children = newChildren)
+}
+
 /**
  * A xxHash64 64-bit hash expression.
  */
@@ -698,8 +821,10 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio
 
   override protected def hasherClassName: String = classOf[XXH64].getName
 
+  override protected def isCollationAware: Boolean = false
+
   override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
-    XxHash64Function.hash(value, dataType, seed)
+    XxHash64Function.hash(value, dataType, seed, isCollationAware, legacyCollationAwareHashing)
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 =
@@ -716,6 +841,28 @@ object XxHash64Function extends InterpretedHashFunction {
   }
 }
 
+case class CollationAwareXxHash64(children: Seq[Expression], seed: Long)
+  extends HashExpression[Long]
+{
+  def this(arguments: Seq[Expression]) = this(arguments, 42L)
+
+  override def dataType: DataType = LongType
+
+  override def prettyName: String = "collation_aware_xxhash64"
+
+  override protected def hasherClassName: String = classOf[XXH64].getName
+
+  override protected def isCollationAware: Boolean = true
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
+    XxHash64Function.hash(
+      value, dataType, seed, isCollationAware, legacyCollationAwareHashing)
+  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
+    CollationAwareXxHash64 = copy(children = newChildren)
+}
+
 /**
  * Simulates Hive's hashing function from Hive v1.2.1 at
  * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
@@ -736,8 +883,12 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
 
   override protected def hasherClassName: String = classOf[HiveHasher].getName
 
+  override protected def isCollationAware: Boolean = true
+
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    HiveHashFunction.hash(value, dataType, this.seed).toInt
+    HiveHashFunction.hash(
+      value, dataType, this.seed, isCollationAware, legacyCollationAwareHashing
+    ).toInt
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -823,17 +974,18 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
 
   override protected def genHashString(
       ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
-    if (stringType.supportsBinaryEquality) {
+    if (stringType.supportsBinaryEquality || !isCollationAware) {
       val baseObject = s"$input.getBaseObject()"
       val baseOffset = s"$input.getBaseOffset()"
       val numBytes = s"$input.numBytes()"
       s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, 
$numBytes);"
     } else {
-      val stringHash = ctx.freshName("stringHash")
+      val key = ctx.freshName("key")
+      val offset = Platform.BYTE_ARRAY_OFFSET
       s"""
-        long $stringHash = CollationFactory.fetchCollation(${stringType.collationId})
-          .hashFunction.applyAsLong($input);
-        $result = $hasherClassName.hashLong($stringHash);
+        byte[] $key = (byte[]) CollationFactory.fetchCollation(${stringType.collationId})
+          .sortKeyFunction.apply($input);
+        $result = $hasherClassName.hashUnsafeBytes($key, $offset, $key.length, $result);
       """
     }
   }
@@ -1016,7 +1168,12 @@ object HiveHashFunction extends InterpretedHashFunction {
      (result * 37) + nanoSeconds
   }
 
-  override def hash(value: Any, dataType: DataType, seed: Long): Long = {
+  override def hash(
+      value: Any,
+      dataType: DataType,
+      seed: Long,
+      isCollationAware: Boolean,
+      legacyCollationAwareHashing: Boolean): Long = {
     value match {
       case null => 0
       case array: ArrayData =>
@@ -1029,7 +1186,9 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = array.numElements()
         while (i < length) {
-          result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt
+          result = (31 * result) + hash(
+            array.get(i, elementType), elementType, 0, isCollationAware, legacyCollationAwareHashing
+          ).toInt
           i += 1
         }
         result
@@ -1048,7 +1207,11 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = map.numElements()
         while (i < length) {
-          result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt
+          result += hash(
+            keys.get(i, kt), kt, 0, isCollationAware, legacyCollationAwareHashing
+          ).toInt ^ hash(
+            values.get(i, vt), vt, 0, isCollationAware, legacyCollationAwareHashing
+          ).toInt
           i += 1
         }
         result
@@ -1064,7 +1227,10 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = struct.numFields
         while (i < length) {
-          result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
+          result = (31 * result) +
+            hash(
+              struct.get(i, types(i)), types(i), 0, isCollationAware, legacyCollationAwareHashing
+            ).toInt
           i += 1
         }
         result
@@ -1072,7 +1238,7 @@ object HiveHashFunction extends InterpretedHashFunction {
       case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode()
       case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp)
       case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval)
-      case _ => super.hash(value, dataType, 0)
+      case _ => super.hash(value, dataType, 0, isCollationAware, legacyCollationAwareHashing)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 6e19a1d6bbc8..038105f9bfdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -316,7 +316,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
    * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less
    * than numPartitions) based on hashing expressions.
    */
-  def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
+  def partitionIdExpression: Expression = Pmod(
+    new CollationAwareMurmur3Hash(expressions), Literal(numPartitions)
+  )
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
index fc947386487a..38425f721236 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.XxHash64Function
 import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 // A helper class for HyperLogLogPlusPlus.
 class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
@@ -94,12 +93,16 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
     val value = dataType match {
       case FloatType => FLOAT_NORMALIZER.apply(_value)
       case DoubleType => DOUBLE_NORMALIZER.apply(_value)
-      case st: StringType if !st.supportsBinaryEquality =>
-        CollationFactory.getCollationKeyBytes(_value.asInstanceOf[UTF8String], st.collationId)
       case _ => _value
     }
     // Create the hashed value 'x'.
-    val x = XxHash64Function.hash(value, dataType, 42L)
+    val x = XxHash64Function.hash(
+      value,
+      dataType,
+      42L,
+      isCollationAware = true,
+      // legacyCollationAwareHashing only matters when isCollationAware is false.
+      legacyCollationAwareHashing = false)
 
     // Determine the index of the register we are going to use.
     val idx = (x >>> idxShift).toInt
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
index d2bdad2d880d..ba3d65fea027 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -39,7 +39,13 @@ class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[Data
   private val structType = structTypeCache.get(dataTypes)
   private val ordering = orderingCache.get(dataTypes)
 
-  override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
+  override def hashCode(): Int = Murmur3HashFunction.hash(
+    row,
+    structType,
+    42L,
+    isCollationAware = true,
+    // legacyCollationAwareHashing only matters when isCollationAware is false.
+    legacyCollationAwareHashing = false).toInt
 
   override def equals(other: Any): Boolean = {
     if (!other.isInstanceOf[InternalRowComparableWrapper]) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 436b81309d93..da39b24813f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -881,6 +881,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(Utils.isTesting)
 
+  lazy val COLLATION_AWARE_HASHING_ENABLED =
+    buildConf("spark.sql.legacy.collationAwareHashFunctions")
+      .internal()
+      .doc("Enables collation aware hashing (legacy behavior) for collated 
strings in " +
+        "Murmur3Hash and XxHash64 user-facing expressions.")
+      .version("4.0.1")
+      .booleanConf
+      .createWithDefault(false)
+
   val ICU_CASE_MAPPINGS_ENABLED =
     buildConf("spark.sql.icu.caseMappings.enabled")
       .doc("When enabled we use the ICU library (instead of the JVM) to 
implement case mappings" +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index 7cb4d5f12325..4f3efca4ad0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.SparkFunSuite
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Murmur3Hash, Pmod}
+import org.apache.spark.sql.catalyst.expressions.{CollationAwareMurmur3Hash, Expression, Literal, Pmod}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.IntegerType
 
@@ -322,7 +322,7 @@ class DistributionSuite extends SparkFunSuite {
     val expressions = Seq($"a", $"b", $"c")
     val hashPartitioning = HashPartitioning(expressions, 10)
     hashPartitioning.partitionIdExpression match {
-      case Pmod(Murmur3Hash(es, 42), Literal(10, IntegerType), _) =>
+      case Pmod(CollationAwareMurmur3Hash(es, 42), Literal(10, IntegerType), _) =>
         assert(es.length == expressions.length && es.zip(expressions).forall {
           case (l, r) => l.semanticEquals(r)
         })
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 019c953a3b0a..3104094e8543 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CollationFactory, DateTimeUtils, GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, StructType, _}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -91,7 +92,14 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
     // Note : All expected hashes need to be computed using Hive 1.2.1
-    val actual = HiveHashFunction.hash(input, dataType, seed = 0)
+    val actual = HiveHashFunction.hash(
+      input,
+      dataType,
+      seed = 0,
+      isCollationAware = true,
+      // legacyCollationAwareHashing only matters when isCollationAware is false.
+      legacyCollationAwareHashing = false
+    )
 
     withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
       assert(actual == expected)
@@ -621,12 +629,18 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   for (collation <- Seq("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY")) {
-    test(s"hash check for collated $collation strings") {
+    test(s"hash check for collated $collation strings - collation aware") {
       val s1 = "aaa"
       val s2 = "AAA"
 
-      val murmur3Hash1 = Murmur3Hash(Seq(Collate(Literal(s1), ResolvedCollation(collation))), 42)
-      val murmur3Hash2 = Murmur3Hash(Seq(Collate(Literal(s2), ResolvedCollation(collation))), 42)
+      val murmur3Hash1 = CollationAwareMurmur3Hash(
+        Seq(Collate(Literal(s1), ResolvedCollation(collation))),
+        42
+      )
+      val murmur3Hash2 = CollationAwareMurmur3Hash(
+        Seq(Collate(Literal(s2), ResolvedCollation(collation))),
+        42
+      )
 
       // Interpreted hash values for s1 and s2
       val interpretedHash1 = murmur3Hash1.eval()
@@ -644,6 +658,115 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  for (collation <- Seq("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY")) {
+    test(s"hash check for collated $collation strings - collation agnostic") {
+      val s1 = "aaa"
+      val s2 = "AAA"
+
+      val murmur3Hash1 = Murmur3Hash(Seq(Collate(Literal(s1), ResolvedCollation(collation))), 42)
+      val murmur3Hash2 = Murmur3Hash(Seq(Collate(Literal(s2), ResolvedCollation(collation))), 42)
+
+      // Interpreted hash values for s1 and s2
+      val interpretedHash1 = murmur3Hash1.eval()
+      val interpretedHash2 = murmur3Hash2.eval()
+
+      // Check that interpreted and codegen hashes are equal
+      checkEvaluation(murmur3Hash1, interpretedHash1)
+      checkEvaluation(murmur3Hash2, interpretedHash2)
+
+      assert(interpretedHash1 != interpretedHash2)
+
+      // Check that the hash computed is the same as the UTF8_BINARY version of it.
+      if (!CollationFactory.fetchCollation(collation).isUtf8BinaryType) {
+        Seq[String](s1, s2).foreach { s =>
+          val utf8BinaryStringExpr = Collate(Literal(s), ResolvedCollation("UTF8_BINARY"))
+          val murmur3HashBinary = Murmur3Hash(Seq(utf8BinaryStringExpr), 42)
+          val hashBinary = murmur3HashBinary.eval()
+          val murmur3Hash = Murmur3Hash(Seq(Collate(Literal(s), ResolvedCollation(collation))), 42)
+          val interpretedHash = murmur3Hash.eval()
+          assert(interpretedHash == hashBinary)
+        }
+      }
+    }
+  }
+
+  // Below we test the `Murmur3Hash` and `XxHash64` expressions for the old behavior before the fix.
+  // The expected values have been computed using the old implementation of the expression.
+  test("SPARK-52828: always collation aware hash expression") {
+    withSQLConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED.key -> "true") {
+      val testCases = Seq[(String, String, Int, Long)](
+        // UTF8_BINARY
+        ("AAA", "UTF8_BINARY", 22125783, 3965631622972380050L),
+        ("AAA  ", "UTF8_BINARY", 399014599, 196039582279068044L),
+        ("aaa", "UTF8_BINARY", -1689629761, 2465751751477118478L),
+        ("aaa   ", "UTF8_BINARY", -1721438718, -2249763606958050730L),
+        // UTF8_BINARY_RTRIM
+        ("AAA", "UTF8_BINARY_RTRIM", -1493064582, 982928955165138586L),
+        ("AAA  ", "UTF8_BINARY_RTRIM", -1493064582, 982928955165138586L),
+        ("aaa", "UTF8_BINARY_RTRIM", 2132077201, -4940759280126763524L),
+        ("aaa   ", "UTF8_BINARY_RTRIM", 2132077201, -4940759280126763524L),
+        // UTF8_LCASE
+        ("AAA", "UTF8_LCASE", 2132077201, -4940759280126763524L),
+        ("AAA  ", "UTF8_LCASE", -619073595, -1146641051608991690L),
+        ("aaa", "UTF8_LCASE", 2132077201, -4940759280126763524L),
+        ("aaa   ", "UTF8_LCASE", -1498994355, -739345240752106297L),
+        // UTF8_LCASE_RTRIM
+        ("AAA", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        ("AAA  ", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        ("aaa", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        ("aaa   ", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        // UNICODE
+        ("AAA", "UNICODE", 128537619, 49663227161197117L),
+        ("AAA  ", "UNICODE", 82814175, 3618364417906061797L),
+        ("aaa", "UNICODE", -1822783942, 290910714161494507L),
+        ("aaa   ", "UNICODE", -896289340, 1025563887784400925L),
+        // UNICODE_RTRIM
+        ("AAA", "UNICODE_RTRIM", 128537619, 49663227161197117L),
+        ("AAA  ", "UNICODE_RTRIM", 128537619, 49663227161197117L),
+        ("aaa", "UNICODE_RTRIM", -1822783942, 290910714161494507L),
+        ("aaa   ", "UNICODE_RTRIM", -1822783942, 290910714161494507L),
+        // UNICODE_CI
+        ("AAA", "UNICODE_CI", -443043098, -6629915645815515868L),
+        ("AAA  ", "UNICODE_CI", 667473856, -3263604567598338200L),
+        ("aaa", "UNICODE_CI", -443043098, -6629915645815515868L),
+        ("aaa   ", "UNICODE_CI", -390983808, -5159733933636691741L),
+        // UNICODE_CI_RTRIM
+        ("AAA", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L),
+        ("AAA  ", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L),
+        ("aaa", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L),
+        ("aaa   ", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L)
+      )
+      testCases.foreach { case (str, collationName, expectedMurmur3, expectedXxHash64) =>
+        val stringExpr = Collate(Literal(str), ResolvedCollation(collationName))
+        val murmur3Expr = Murmur3Hash(Seq(stringExpr), 42)
+        checkEvaluation(murmur3Expr, expectedMurmur3)
+        val xxHash64Expr = XxHash64(Seq(stringExpr), 42L)
+        checkEvaluation(xxHash64Expr, expectedXxHash64)
+      }
+    }
+  }
+
+  test("SPARK-52828: backward-compatible hash API should reject UTF8_LCASE 
collation") {
+    // This test verifies that the legacy hash API throws an exception when 
used with
+    // collation-aware strings such as UTF8_LCASE. The assertion ensures we 
catch unsupported
+    // usage early via the internal assertion 
(SchemaUtils.hasNonUTF8BinaryCollation).
+    val expr_lcase = Collate(Literal("AAA"), ResolvedCollation("UTF8_LCASE"))
+    intercept[IllegalArgumentException] {
+      Murmur3HashFunction.hash(expr_lcase.eval(null), expr_lcase.dataType, 42)
+    }
+    intercept[IllegalArgumentException] {
+      XxHash64Function.hash(expr_lcase.eval(null), expr_lcase.dataType, 42)
+    }
+    intercept[IllegalArgumentException] {
+      HiveHashFunction.hash(expr_lcase.eval(null), expr_lcase.dataType, 42)
+    }
+
+    val expr_utf8bin = Collate(Literal("AAA"), ResolvedCollation("UTF8_BINARY"))
+    Murmur3HashFunction.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42)
+    XxHash64Function.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42)
+    HiveHashFunction.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42)
+  }
+
   test("SPARK-18207: Compute hash for a lot of expressions") {
     def checkResult(schema: StructType, input: InternalRow): Unit = {
       val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala
index 6069127a0df9..0836823a994a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.benchmark
 import scala.concurrent.duration._
 
 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.sql.catalyst.expressions.Murmur3HashFunction
 import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport}
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.unsafe.types.UTF8String
 
 abstract class CollationBenchmarkBase extends BenchmarkBase {
@@ -92,7 +94,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase {
         sublistStrings.foreach { _ =>
           utf8Strings.foreach { s =>
             (0 to 3).foreach { _ =>
-              collation.hashFunction.applyAsLong(s)
+              Murmur3HashFunction.hash(s, StringType(collationType), 42L, true, false).toInt
             }
           }
         }

