This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f62724d2d8dc [SPARK-52828][SQL] Make hashing for collated strings collation agnostic f62724d2d8dc is described below commit f62724d2d8dc0f405c9d0dedcd0136bf4a4aa3b7 Author: Uros Bojanic <uros.boja...@databricks.com> AuthorDate: Thu Aug 7 16:37:22 2025 +0800 [SPARK-52828][SQL] Make hashing for collated strings collation agnostic ### What changes were proposed in this pull request? We change the behavior of the `Murmur3Hash` and `XxHash64` catalyst expressions to be collation agnostic (i.e., collation-unaware). Also, we introduce two new internal catalyst expressions, `CollationAwareMurmur3Hash` and `CollationAwareXxHash64`, which are collation aware: they take the collation of the string into consideration when hashing collated strings. Furthermore, in expressions where hashing must be collation aware (for example, hash partitioning), we replace `Murmur3Hash` and `XxHash64` with `CollationAwareMurmur3Hash` and `CollationAwareXxHash64`. Moreover, we change how the internal HiveHash expression hashes collated strings so that it is consistent with the rest of the hashing expressions (the HiveHash expression is meant to always be collation aware). Finally, we add a kill switch (the internal SQL config `COLLATION_AWARE_HASHING_ENABLED`, key `spark.sql.legacy.collationAwareHashFunctions`, default `false`) that, when enabled, restores the previous collation-aware behavior of the user-facing `Murmur3Hash` and `XxHash64` expressions. The kill switch has no effect on the new collation-aware hashing expressions or on the HiveHash expression, which are internal and always follow the new collation-aware behavior. ### Why are the changes needed? The `Murmur3Hash` and `XxHash64` catalyst expressions, when applied to collated strings, currently always take the collation of the string into consideration; that is, they are collation aware. This is not the correct behavior: these expressions should be collation agnostic by default. 
### Does this PR introduce _any_ user-facing change? Yes, see the detailed explanation above. ### How was this patch tested? Updated existing tests in relevant suites: CollationFactorySuite, DistributionSuite, and HashExpressionsSuite. Also verified that the CollationSuite suite passes. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51521 from uros-db/collation-hashing. Lead-authored-by: Uros Bojanic <uros.boja...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationFactory.java | 51 +++-- .../spark/unsafe/types/CollationFactorySuite.scala | 8 +- .../spark/sql/catalyst/expressions/hash.scala | 228 ++++++++++++++++++--- .../sql/catalyst/plans/physical/partitioning.scala | 4 +- .../catalyst/util/HyperLogLogPlusPlusHelper.scala | 11 +- .../util/InternalRowComparableWrapper.scala | 8 +- .../org/apache/spark/sql/internal/SQLConf.scala | 9 + .../spark/sql/catalyst/DistributionSuite.scala | 4 +- .../expressions/HashExpressionsSuite.scala | 131 +++++++++++- .../execution/benchmark/CollationBenchmark.scala | 4 +- 10 files changed, 388 insertions(+), 70 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 4bcd75a73105..59c23064858d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -22,7 +22,6 @@ import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.function.BiFunction; -import java.util.function.ToLongFunction; import java.util.stream.Stream; import com.ibm.icu.text.CollationKey; @@ -125,10 +124,19 @@ public final class CollationFactory { public final String version; /** - * 
Collation sensitive hash function. Output for two UTF8Strings will be the same if they are - * equal according to the collation. + * Returns the sort key of the input UTF8String. Two UTF8String values are equal iff their + * sort keys are equal (compared as byte arrays). + * The sort key is defined as follows for collations without the RTRIM modifier: + * - UTF8_BINARY: It is the bytes of the string. + * - UTF8_LCASE: It is byte array we get by replacing all invalid UTF8 sequences with the + * Unicode replacement character and then converting all characters of the replaced string + * with their lowercase equivalents (the Greek capital and Greek small sigma both map to + * the Greek final sigma). + * - ICU collations: It is the byte array returned by the ICU library for the collated string. + * For strings with the RTRIM modifier, we right-trim the string and return the collation key + * of the resulting right-trimmed string. */ - public final ToLongFunction<UTF8String> hashFunction; + public final Function<UTF8String, byte[]> sortKeyFunction; /** * Potentially faster way than using comparator to compare two UTF8Strings for equality. 
@@ -182,7 +190,7 @@ public final class CollationFactory { Collator collator, Comparator<UTF8String> comparator, String version, - ToLongFunction<UTF8String> hashFunction, + Function<UTF8String, byte[]> sortKeyFunction, BiFunction<UTF8String, UTF8String, Boolean> equalsFunction, boolean isUtf8BinaryType, boolean isUtf8LcaseType, @@ -192,7 +200,7 @@ public final class CollationFactory { this.collator = collator; this.comparator = comparator; this.version = version; - this.hashFunction = hashFunction; + this.sortKeyFunction = sortKeyFunction; this.isUtf8BinaryType = isUtf8BinaryType; this.isUtf8LcaseType = isUtf8LcaseType; this.equalsFunction = equalsFunction; @@ -581,18 +589,18 @@ public final class CollationFactory { protected Collation buildCollation() { if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { Comparator<UTF8String> comparator; - ToLongFunction<UTF8String> hashFunction; + Function<UTF8String, byte[]> sortKeyFunction; BiFunction<UTF8String, UTF8String, Boolean> equalsFunction; boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; if (spaceTrimming == SpaceTrimming.NONE) { comparator = UTF8String::binaryCompare; - hashFunction = s -> (long) s.hashCode(); + sortKeyFunction = s -> s.getBytes(); equalsFunction = UTF8String::equals; } else { comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( applyTrimmingPolicy(s2, spaceTrimming)); - hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); + sortKeyFunction = s -> applyTrimmingPolicy(s, spaceTrimming).getBytes(); equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( applyTrimmingPolicy(s2, spaceTrimming)); } @@ -603,25 +611,25 @@ public final class CollationFactory { null, comparator, CollationSpecICU.ICU_VERSION, - hashFunction, + sortKeyFunction, equalsFunction, /* isUtf8BinaryType = */ true, /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } else { Comparator<UTF8String> comparator; - 
ToLongFunction<UTF8String> hashFunction; + Function<UTF8String, byte[]> sortKeyFunction; if (spaceTrimming == SpaceTrimming.NONE) { comparator = CollationAwareUTF8String::compareLowerCase; - hashFunction = s -> - (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + sortKeyFunction = s -> + CollationAwareUTF8String.lowerCaseCodePoints(s).getBytes(); } else { comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( applyTrimmingPolicy(s1, spaceTrimming), applyTrimmingPolicy(s2, spaceTrimming)); - hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints( - applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + sortKeyFunction = s -> CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).getBytes(); } return new Collation( @@ -630,7 +638,7 @@ public final class CollationFactory { null, comparator, CollationSpecICU.ICU_VERSION, - hashFunction, + sortKeyFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, /* isUtf8BinaryType = */ false, /* isUtf8LcaseType = */ true, @@ -1013,19 +1021,18 @@ public final class CollationFactory { collator.freeze(); Comparator<UTF8String> comparator; - ToLongFunction<UTF8String> hashFunction; + Function<UTF8String, byte[]> sortKeyFunction; if (spaceTrimming == SpaceTrimming.NONE) { - hashFunction = s -> (long) collator.getCollationKey( - s.toValidString()).hashCode(); comparator = (s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()); + sortKeyFunction = s -> collator.getCollationKey(s.toValidString()).toByteArray(); } else { comparator = (s1, s2) -> collator.compare( applyTrimmingPolicy(s1, spaceTrimming).toValidString(), applyTrimmingPolicy(s2, spaceTrimming).toValidString()); - hashFunction = s -> (long) collator.getCollationKey( - applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + sortKeyFunction = s -> collator.getCollationKey( + applyTrimmingPolicy(s, spaceTrimming).toValidString()).toByteArray(); } return new Collation( @@ 
-1034,7 +1041,7 @@ public final class CollationFactory { collator, comparator, ICU_VERSION, - hashFunction, + sortKeyFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, /* isUtf8BinaryType = */ false, /* isUtf8LcaseType = */ false, diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 8e9d33efe7a6..ef1687f4376d 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -139,7 +139,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig case class CollationTestCase[R](collationName: String, s1: String, s2: String, expectedResult: R) - test("collation aware equality and hash") { + test("collation aware equality and sort key") { val checks = Seq( CollationTestCase("UTF8_BINARY", "aaa", "aaa", true), CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), @@ -194,9 +194,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig assert(collation.equalsFunction(toUTF8(testCase.s1), toUTF8(testCase.s2)) == testCase.expectedResult) - val hash1 = collation.hashFunction.applyAsLong(toUTF8(testCase.s1)) - val hash2 = collation.hashFunction.applyAsLong(toUTF8(testCase.s2)) - assert((hash1 == hash2) == testCase.expectedResult) + val sortKey1 = collation.sortKeyFunction.apply(toUTF8(testCase.s1)).asInstanceOf[Array[Byte]] + val sortKey2 = collation.sortKeyFunction.apply(toUTF8(testCase.s2)).asInstanceOf[Array[Byte]] + assert(sortKey1.sameElements(sortKey2) == testCase.expectedResult) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index c7a693919096..88e22a91a64a 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -285,6 +286,11 @@ abstract class HashExpression[E] extends Expression { override def nullable: Boolean = false + protected def isCollationAware: Boolean + + protected lazy val legacyCollationAwareHashing: Boolean = + SQLConf.get.getConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED) + private def hasMapType(dt: DataType): Boolean = { dt.existsRecursively(_.isInstanceOf[MapType]) } @@ -439,14 +445,43 @@ abstract class HashExpression[E] extends Expression { val numBytes = s"$input.numBytes()" s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" } else { - val stringHash = ctx.freshName("stringHash") - s""" - long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) - .hashFunction.applyAsLong($input); - $result = $hasherClassName.hashLong($stringHash, $result); - """ + if (isCollationAware) { + val key = ctx.freshName("key") + val offset = "Platform.BYTE_ARRAY_OFFSET" + s""" + byte[] $key = (byte[]) CollationFactory.fetchCollation(${stringType.collationId}) + .sortKeyFunction.apply($input); + $result = $hasherClassName.hashUnsafeBytes($key, $offset, $key.length, $result); + """ + } else if (legacyCollationAwareHashing) { + val collation = CollationFactory.fetchCollation(stringType.collationId) + val stringHash = ctx.freshName("stringHash") + if (collation.isUtf8BinaryType || collation.isUtf8LcaseType) { + s""" + long $stringHash 
= UTF8String.fromBytes((byte[]) CollationFactory + .fetchCollation(${stringType.collationId}).sortKeyFunction.apply($input)).hashCode(); + $result = $hasherClassName.hashLong($stringHash, $result); + """ + } else if (collation.supportsSpaceTrimming) { + s""" + long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) + .getCollator().getCollationKey($input.trimRight().toValidString()).hashCode(); + $result = $hasherClassName.hashLong($stringHash, $result); + """ + } else { + s""" + long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) + .getCollator().getCollationKey($input.toValidString()).hashCode(); + $result = $hasherClassName.hashLong($stringHash, $result); + """ + } + } else { + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + } } - } protected def genHashForMap( @@ -556,10 +591,38 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long /** - * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity - * of input `value`. + * This method is intended for callers using the old hash API and preserves compatibility for + * supported data types. It must only be used for data types that do not include collated strings + * or complex types (e.g., structs, arrays, maps) that may contain collated strings. + * + * The caller is responsible for ensuring that `dataType` does not involve collation-aware fields. + * This is validated via an internal assertion. + * + * @throws IllegalArgumentException if `dataType` contains non-UTF8 binary collation. 
*/ def hash(value: Any, dataType: DataType, seed: Long): Long = { + require(!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) + // For UTF8_BINARY, hashing behavior is the same regardless of the isCollationAware flag. + hash( + value = value, + dataType = dataType, + seed = seed, + isCollationAware = false, + legacyCollationAwareHashing = false) + } + + /** + * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity + * of input `value`. The `isCollationAware` boolean flag indicates whether hashing should take + * a string's collation into account. If not, the bytes of the string are hashed, otherwise the + * collation key of the string is hashed. + */ + def hash( + value: Any, + dataType: DataType, + seed: Long, + isCollationAware: Boolean, + legacyCollationAwareHashing: Boolean): Long = { value match { case null => seed case b: Boolean => hashInt(if (b) 1 else 0, seed) @@ -585,12 +648,25 @@ abstract class InterpretedHashFunction { case s: UTF8String => val st = dataType.asInstanceOf[StringType] if (st.supportsBinaryEquality) { - hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes, seed) } else { - val stringHash = CollationFactory - .fetchCollation(st.collationId) - .hashFunction.applyAsLong(s) - hashLong(stringHash, seed) + if (isCollationAware) { + val key = CollationFactory.fetchCollation(st.collationId).sortKeyFunction.apply(s) + .asInstanceOf[Array[Byte]] + hashUnsafeBytes(key, Platform.BYTE_ARRAY_OFFSET, key.length, seed) + } else if (legacyCollationAwareHashing) { + val collation = CollationFactory.fetchCollation(st.collationId) + val stringHash = if (collation.isUtf8BinaryType || collation.isUtf8LcaseType) { + UTF8String.fromBytes(collation.sortKeyFunction.apply(s)).hashCode + } else if (collation.supportsSpaceTrimming) { + collation.getCollator.getCollationKey(s.trimRight.toValidString).hashCode + } else { + 
collation.getCollator.getCollationKey(s.toValidString).hashCode + } + hashLong(stringHash, seed) + } else { + hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes, seed) + } } case array: ArrayData => @@ -601,7 +677,12 @@ abstract class InterpretedHashFunction { var result = seed var i = 0 while (i < array.numElements()) { - result = hash(array.get(i, elementType), elementType, result) + result = hash( + array.get(i, elementType), + elementType, + result, + isCollationAware, + legacyCollationAwareHashing) i += 1 } result @@ -618,8 +699,18 @@ abstract class InterpretedHashFunction { var result = seed var i = 0 while (i < map.numElements()) { - result = hash(keys.get(i, kt), kt, result) - result = hash(values.get(i, vt), vt, result) + result = hash( + keys.get(i, kt), + kt, + result, + isCollationAware, + legacyCollationAwareHashing) + result = hash( + values.get(i, vt), + vt, + result, + isCollationAware, + legacyCollationAwareHashing) i += 1 } result @@ -634,7 +725,12 @@ abstract class InterpretedHashFunction { var i = 0 val len = struct.numFields while (i < len) { - result = hash(struct.get(i, types(i)), types(i), result) + result = hash( + struct.get(i, types(i)), + types(i), + result, + isCollationAware, + legacyCollationAwareHashing) i += 1 } result @@ -666,8 +762,12 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName + override protected def isCollationAware: Boolean = false + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { - Murmur3HashFunction.hash(value, dataType, seed).toInt + Murmur3HashFunction.hash( + value, dataType, seed, isCollationAware, legacyCollationAwareHashing + ).toInt } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash = @@ -688,6 +788,29 @@ object Murmur3HashFunction extends InterpretedHashFunction { } } +case class 
CollationAwareMurmur3Hash(children: Seq[Expression], seed: Int) + extends HashExpression[Int] +{ + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def prettyName: String = "collation_aware_hash" + + override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName + + override protected def isCollationAware: Boolean = true + + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { + Murmur3HashFunction.hash( + value, dataType, seed, isCollationAware, legacyCollationAwareHashing + ).toInt + } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): + CollationAwareMurmur3Hash = copy(children = newChildren) +} + /** * A xxHash64 64-bit hash expression. */ @@ -710,8 +833,10 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio override protected def hasherClassName: String = classOf[XXH64].getName + override protected def isCollationAware: Boolean = false + override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { - XxHash64Function.hash(value, dataType, seed) + XxHash64Function.hash(value, dataType, seed, isCollationAware, legacyCollationAwareHashing) } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 = @@ -728,6 +853,28 @@ object XxHash64Function extends InterpretedHashFunction { } } +case class CollationAwareXxHash64(children: Seq[Expression], seed: Long) + extends HashExpression[Long] +{ + def this(arguments: Seq[Expression]) = this(arguments, 42L) + + override def dataType: DataType = LongType + + override def prettyName: String = "collation_aware_xxhash64" + + override protected def hasherClassName: String = classOf[XXH64].getName + + override protected def isCollationAware: Boolean = true + + override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { + XxHash64Function.hash( + 
value, dataType, seed, isCollationAware, legacyCollationAwareHashing) + } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): + CollationAwareXxHash64 = copy(children = newChildren) +} + /** * Simulates Hive's hashing function from Hive v1.2.1 at * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() @@ -748,8 +895,12 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def hasherClassName: String = classOf[HiveHasher].getName + override protected def isCollationAware: Boolean = true + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { - HiveHashFunction.hash(value, dataType, this.seed).toInt + HiveHashFunction.hash( + value, dataType, this.seed, isCollationAware, legacyCollationAwareHashing + ).toInt } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -835,17 +986,18 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality || !isCollationAware) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" } else { - val stringHash = ctx.freshName("stringHash") + val key = ctx.freshName("key") + val offset = Platform.BYTE_ARRAY_OFFSET s""" - long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) - .hashFunction.applyAsLong($input); - $result = $hasherClassName.hashLong($stringHash); + byte[] $key = (byte[]) CollationFactory.fetchCollation(${stringType.collationId}) + .sortKeyFunction.apply($input); + $result = $hasherClassName.hashUnsafeBytes($key, $offset, $key.length, $result); """ } } @@ -1028,7 
+1180,12 @@ object HiveHashFunction extends InterpretedHashFunction { (result * 37) + nanoSeconds } - override def hash(value: Any, dataType: DataType, seed: Long): Long = { + override def hash( + value: Any, + dataType: DataType, + seed: Long, + isCollationAware: Boolean, + legacyCollationAwareHashing: Boolean): Long = { value match { case null => 0 case array: ArrayData => @@ -1041,7 +1198,9 @@ object HiveHashFunction extends InterpretedHashFunction { var i = 0 val length = array.numElements() while (i < length) { - result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt + result = (31 * result) + hash( + array.get(i, elementType), elementType, 0, isCollationAware, legacyCollationAwareHashing + ).toInt i += 1 } result @@ -1060,7 +1219,11 @@ object HiveHashFunction extends InterpretedHashFunction { var i = 0 val length = map.numElements() while (i < length) { - result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt + result += hash( + keys.get(i, kt), kt, 0, isCollationAware, legacyCollationAwareHashing + ).toInt ^ hash( + values.get(i, vt), vt, 0, isCollationAware, legacyCollationAwareHashing + ).toInt i += 1 } result @@ -1076,7 +1239,10 @@ object HiveHashFunction extends InterpretedHashFunction { var i = 0 val length = struct.numFields while (i < length) { - result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt + result = (31 * result) + + hash( + struct.get(i, types(i)), types(i), 0, isCollationAware, legacyCollationAwareHashing + ).toInt i += 1 } result @@ -1084,7 +1250,7 @@ object HiveHashFunction extends InterpretedHashFunction { case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode() case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp) case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval) - case _ => super.hash(value, dataType, 0) + case _ => super.hash(value, dataType, 0, isCollationAware, 
legacyCollationAwareHashing) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 6e19a1d6bbc8..038105f9bfdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -316,7 +316,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less * than numPartitions) based on hashing expressions. */ - def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) + def partitionIdExpression: Expression = Pmod( + new CollationAwareMurmur3Hash(expressions), Literal(numPartitions) + ) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala index fc947386487a..38425f721236 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.XxHash64Function import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String // A helper class for HyperLogLogPlusPlus. 
class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable { @@ -94,12 +93,16 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable { val value = dataType match { case FloatType => FLOAT_NORMALIZER.apply(_value) case DoubleType => DOUBLE_NORMALIZER.apply(_value) - case st: StringType if !st.supportsBinaryEquality => - CollationFactory.getCollationKeyBytes(_value.asInstanceOf[UTF8String], st.collationId) case _ => _value } // Create the hashed value 'x'. - val x = XxHash64Function.hash(value, dataType, 42L) + val x = XxHash64Function.hash( + value, + dataType, + 42L, + isCollationAware = true, + // legacyCollationAwareHashing only matters when isCollationAware is false. + legacyCollationAwareHashing = false) // Determine the index of the register we are going to use. val idx = (x >>> idxShift).toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index d2bdad2d880d..ba3d65fea027 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -39,7 +39,13 @@ class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[Data private val structType = structTypeCache.get(dataTypes) private val ordering = orderingCache.get(dataTypes) - override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt + override def hashCode(): Int = Murmur3HashFunction.hash( + row, + structType, + 42L, + isCollationAware = true, + // legacyCollationAwareHashing only matters when isCollationAware is false. 
+ legacyCollationAwareHashing = false).toInt override def equals(other: Any): Boolean = { if (!other.isInstanceOf[InternalRowComparableWrapper]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0138770e3242..0427c84bb8d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1025,6 +1025,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + lazy val COLLATION_AWARE_HASHING_ENABLED = + buildConf("spark.sql.legacy.collationAwareHashFunctions") + .internal() + .doc("Enables collation aware hashing (legacy behavior) for collated strings in " + + "Murmur3Hash and XxHash64 user-facing expressions.") + .version("4.0.1") + .booleanConf + .createWithDefault(false) + val ICU_CASE_MAPPINGS_ENABLED = buildConf("spark.sql.icu.caseMappings.enabled") .doc("When enabled we use the ICU library (instead of the JVM) to implement case mappings" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 7cb4d5f12325..4f3efca4ad0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Murmur3Hash, Pmod} +import org.apache.spark.sql.catalyst.expressions.{CollationAwareMurmur3Hash, Expression, Literal, Pmod} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType @@ -322,7 +322,7 @@ class DistributionSuite extends 
SparkFunSuite { val expressions = Seq($"a", $"b", $"c") val hashPartitioning = HashPartitioning(expressions, 10) hashPartitioning.partitionIdExpression match { - case Pmod(Murmur3Hash(es, 42), Literal(10, IntegerType), _) => + case Pmod(CollationAwareMurmur3Hash(es, 42), Literal(10, IntegerType), _) => assert(es.length == expressions.length && es.zip(expressions).forall { case (l, r) => l.semanticEquals(r) }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index c64b94703288..c084b67d4d57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CollationFactory, DateTimeUtils, GenericArrayData, IntervalUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, StructType, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -91,7 +92,14 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { // Note : All expected hashes need to be computed using Hive 1.2.1 - val actual = HiveHashFunction.hash(input, dataType, seed = 0) + val actual = HiveHashFunction.hash( + input, + dataType, + seed = 0, + isCollationAware = true, + // legacyCollationAwareHashing only matters when isCollationAware is false. 
+ legacyCollationAwareHashing = false + ) withClue(s"hash mismatch for input = `$input` of type `$dataType`.") { assert(actual == expected) @@ -621,12 +629,18 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } for (collation <- Seq("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY")) { - test(s"hash check for collated $collation strings") { + test(s"hash check for collated $collation strings - collation aware") { val s1 = "aaa" val s2 = "AAA" - val murmur3Hash1 = Murmur3Hash(Seq(Collate(Literal(s1), ResolvedCollation(collation))), 42) - val murmur3Hash2 = Murmur3Hash(Seq(Collate(Literal(s2), ResolvedCollation(collation))), 42) + val murmur3Hash1 = CollationAwareMurmur3Hash( + Seq(Collate(Literal(s1), ResolvedCollation(collation))), + 42 + ) + val murmur3Hash2 = CollationAwareMurmur3Hash( + Seq(Collate(Literal(s2), ResolvedCollation(collation))), + 42 + ) // Interpreted hash values for s1 and s2 val interpretedHash1 = murmur3Hash1.eval() @@ -644,6 +658,115 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + for (collation <- Seq("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY")) { + test(s"hash check for collated $collation strings - collation agnostic") { + val s1 = "aaa" + val s2 = "AAA" + + val murmur3Hash1 = Murmur3Hash(Seq(Collate(Literal(s1), ResolvedCollation(collation))), 42) + val murmur3Hash2 = Murmur3Hash(Seq(Collate(Literal(s2), ResolvedCollation(collation))), 42) + + // Interpreted hash values for s1 and s2 + val interpretedHash1 = murmur3Hash1.eval() + val interpretedHash2 = murmur3Hash2.eval() + + // Check that interpreted and codegen hashes are equal + checkEvaluation(murmur3Hash1, interpretedHash1) + checkEvaluation(murmur3Hash2, interpretedHash2) + + assert(interpretedHash1 != interpretedHash2) + + // Check that the hash computed is the same as the UTF8_BINARY version of it. 
+ if (!CollationFactory.fetchCollation(collation).isUtf8BinaryType) { + Seq[String](s1, s2).foreach { s => + val utf8BinaryStringExpr = Collate(Literal(s), ResolvedCollation("UTF8_BINARY")) + val murmur3HashBinary = Murmur3Hash(Seq(utf8BinaryStringExpr), 42) + val hashBinary = murmur3HashBinary.eval() + val murmur3Hash = Murmur3Hash(Seq(Collate(Literal(s), ResolvedCollation(collation))), 42) + val interpretedHash = murmur3Hash.eval() + assert(interpretedHash == hashBinary) + } + } + } + } + + // Below we test the `Murmur3Hash` and `XxHash64` expressions for the old behavior before the fix. + // The expected values have been computed using the old implementation of the expression. + test("SPARK-52828: always collation aware hash expression") { + withSQLConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED.key -> "true") { + val testCases = Seq[(String, String, Int, Long)]( + // UTF8_BINARY + ("AAA", "UTF8_BINARY", 22125783, 3965631622972380050L), + ("AAA ", "UTF8_BINARY", 399014599, 196039582279068044L), + ("aaa", "UTF8_BINARY", -1689629761, 2465751751477118478L), + ("aaa ", "UTF8_BINARY", -1721438718, -2249763606958050730L), + // UTF8_BINARY_RTRIM + ("AAA", "UTF8_BINARY_RTRIM", -1493064582, 982928955165138586L), + ("AAA ", "UTF8_BINARY_RTRIM", -1493064582, 982928955165138586L), + ("aaa", "UTF8_BINARY_RTRIM", 2132077201, -4940759280126763524L), + ("aaa ", "UTF8_BINARY_RTRIM", 2132077201, -4940759280126763524L), + // UTF8_LCASE + ("AAA", "UTF8_LCASE", 2132077201, -4940759280126763524L), + ("AAA ", "UTF8_LCASE", -619073595, -1146641051608991690L), + ("aaa", "UTF8_LCASE", 2132077201, -4940759280126763524L), + ("aaa ", "UTF8_LCASE", -1498994355, -739345240752106297L), + // UTF8_LCASE_RTRIM + ("AAA", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L), + ("AAA ", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L), + ("aaa", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L), + ("aaa ", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L), + // UNICODE + ("AAA", 
"UNICODE", 128537619, 49663227161197117L), + ("AAA ", "UNICODE", 82814175, 3618364417906061797L), + ("aaa", "UNICODE", -1822783942, 290910714161494507L), + ("aaa ", "UNICODE", -896289340, 1025563887784400925L), + // UNICODE_RTRIM + ("AAA", "UNICODE_RTRIM", 128537619, 49663227161197117L), + ("AAA ", "UNICODE_RTRIM", 128537619, 49663227161197117L), + ("aaa", "UNICODE_RTRIM", -1822783942, 290910714161494507L), + ("aaa ", "UNICODE_RTRIM", -1822783942, 290910714161494507L), + // UNICODE_CI + ("AAA", "UNICODE_CI", -443043098, -6629915645815515868L), + ("AAA ", "UNICODE_CI", 667473856, -3263604567598338200L), + ("aaa", "UNICODE_CI", -443043098, -6629915645815515868L), + ("aaa ", "UNICODE_CI", -390983808, -5159733933636691741L), + // UNICODE_CI_RTRIM + ("AAA", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L), + ("AAA ", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L), + ("aaa", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L), + ("aaa ", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L) + ) + testCases.foreach { case (str, collationName, expectedMurmur3, expectedXxHash64) => + val stringExpr = Collate(Literal(str), ResolvedCollation(collationName)) + val murmur3Expr = Murmur3Hash(Seq(stringExpr), 42) + checkEvaluation(murmur3Expr, expectedMurmur3) + val xxHash64Expr = XxHash64(Seq(stringExpr), 42L) + checkEvaluation(xxHash64Expr, expectedXxHash64) + } + } + } + + test("SPARK-52828: backward-compatible hash API should reject UTF8_LCASE collation") { + // This test verifies that the legacy hash API throws an exception when used with + // collation-aware strings such as UTF8_LCASE. The assertion ensures we catch unsupported + // usage early via the internal assertion (SchemaUtils.hasNonUTF8BinaryCollation). 
+ val expr_lcase = Collate(Literal("AAA"), ResolvedCollation("UTF8_LCASE")) + intercept[IllegalArgumentException] { + Murmur3HashFunction.hash(expr_lcase.eval(null), expr_lcase.dataType, 42) + } + intercept[IllegalArgumentException] { + XxHash64Function.hash(expr_lcase.eval(null), expr_lcase.dataType, 42) + } + intercept[IllegalArgumentException] { + HiveHashFunction.hash(expr_lcase.eval(null), expr_lcase.dataType, 42) + } + + val expr_utf8bin = Collate(Literal("AAA"), ResolvedCollation("UTF8_BINARY")) + Murmur3HashFunction.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42) + XxHash64Function.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42) + HiveHashFunction.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42) + } + test("SPARK-18207: Compute hash for a lot of expressions") { def checkResult(schema: StructType, input: InternalRow): Unit = { val exprs = schema.fields.zipWithIndex.map { case (f, i) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala index 6069127a0df9..0836823a994a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.benchmark import scala.concurrent.duration._ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.sql.catalyst.expressions.Murmur3HashFunction import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport} +import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String abstract class CollationBenchmarkBase extends BenchmarkBase { @@ -92,7 +94,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { sublistStrings.foreach { _ => utf8Strings.foreach { s => (0 to 3).foreach { _ => - 
collation.hashFunction.applyAsLong(s) + Murmur3HashFunction.hash(s, StringType(collationType), 42L, true, false).toInt } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org