This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ec0ee863562e [SPARK-48699][SQL] Refine collation API
ec0ee863562e is described below
commit ec0ee863562e7ad6c27a8632f9de804bd012e9ec
Author: Uros Bojanic <[email protected]>
AuthorDate: Wed Jun 26 15:46:51 2024 +0800
[SPARK-48699][SQL] Refine collation API
### What changes were proposed in this pull request?
- All collation-related public APIs (e.g. `CollationFactory`,
`CollationAwareUTF8String`) should generally use `UTF8String` instead of
(Java) `String`.
- The comparator for the `UTF8_LCASE` collation should use
`UTF8String.binaryCompare` instead of Java's `String.compareTo`.
- Added an ASCII “fast” path for the `Lower` and `Upper` expressions.
### Why are the changes needed?
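- Fix the collation API so that it consistently operates on `UTF8String`.
- Fix a bug in the `UTF8_LCASE` comparator: Java's `String.compareTo` compares
UTF-16 code units, so characters in the range U+E000–U+FFFF were ordered after
supplementary characters (which are encoded as surrogate pairs in U+D800–U+DFFF).
This is contrary to the code-point order given by comparing UTF-8 bytes.
- Enhance performance for the `Lower` and `Upper` expressions.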
### Does this PR introduce _any_ user-facing change?
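Yes. The `UTF8_LCASE` collation now compares lowercased strings in UTF-8 byte
(code-point) order, which changes the result of some comparisons involving
supplementary characters.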
### How was this patch tested?
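- Added tests to `CollationSupportSuite` verifying that the `UTF8_BINARY` and
`UTF8_LCASE` comparators order every code point below `Character.MAX_CODE_POINT`
and above `Character.MIN_CODE_POINT` correctly relative to those bounds.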
- Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47014 from uros-db/fix-collation-api.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/util/CollationAwareUTF8String.java | 79 ++++++++++++++--------
.../spark/sql/catalyst/util/CollationFactory.java | 28 ++++----
.../org/apache/spark/unsafe/types/UTF8String.java | 39 ++++++-----
.../spark/unsafe/types/CollationSupportSuite.java | 19 +++++-
4 files changed, 103 insertions(+), 62 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 934572cd0d67..7c0ffc73aa9d 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -228,7 +228,7 @@ public class CollationAwareUTF8String {
* @return An integer representing the comparison result.
*/
private static int compareLowerCaseSlow(final UTF8String left, final
UTF8String right) {
- return
lowerCaseCodePoints(left.toString()).compareTo(lowerCaseCodePoints(right.toString()));
+ return lowerCaseCodePoints(left).binaryCompare(lowerCaseCodePoints(right));
}
public static UTF8String replace(final UTF8String src, final UTF8String
search,
@@ -339,11 +339,15 @@ public class CollationAwareUTF8String {
* @return the uppercase string
*/
public static UTF8String toUpperCase(final UTF8String target) {
- return UTF8String.fromString(toUpperCase(target.toString()));
+ if (target.isFullAscii()) return target.toUpperCaseAscii();
+ return toUpperCaseSlow(target);
}
- public static String toUpperCase(final String target) {
- return UCharacter.toUpperCase(target);
+ private static UTF8String toUpperCaseSlow(final UTF8String target) {
+ // Note: In order to achieve the desired behaviour, we use the ICU
UCharacter class to
+ // convert the string to uppercase, which only accepts a Java strings as
input.
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ return UTF8String.fromString(UCharacter.toUpperCase(target.toString()));
}
/**
@@ -353,13 +357,17 @@ public class CollationAwareUTF8String {
* @return the uppercase string
*/
public static UTF8String toUpperCase(final UTF8String target, final int
collationId) {
- return UTF8String.fromString(toUpperCase(target.toString(), collationId));
+ if (target.isFullAscii()) return target.toUpperCaseAscii();
+ return toUpperCaseSlow(target, collationId);
}
- public static String toUpperCase(final String target, final int collationId)
{
+ private static UTF8String toUpperCaseSlow(final UTF8String target, final int
collationId) {
+ // Note: In order to achieve the desired behaviour, we use the ICU
UCharacter class to
+ // convert the string to uppercase, which only accepts a Java strings as
input.
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
- return UCharacter.toUpperCase(locale, target);
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ return UTF8String.fromString(UCharacter.toUpperCase(locale,
target.toString()));
}
/**
@@ -369,10 +377,15 @@ public class CollationAwareUTF8String {
* @return the lowercase string
*/
public static UTF8String toLowerCase(final UTF8String target) {
- return UTF8String.fromString(toLowerCase(target.toString()));
+ if (target.isFullAscii()) return target.toLowerCaseAscii();
+ return toLowerCaseSlow(target);
}
- public static String toLowerCase(final String target) {
- return UCharacter.toLowerCase(target);
+
+ private static UTF8String toLowerCaseSlow(final UTF8String target) {
+ // Note: In order to achieve the desired behaviour, we use the ICU
UCharacter class to
+ // convert the string to lowercase, which only accepts a Java strings as
input.
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ return UTF8String.fromString(UCharacter.toLowerCase(target.toString()));
}
/**
@@ -382,12 +395,17 @@ public class CollationAwareUTF8String {
* @return the lowercase string
*/
public static UTF8String toLowerCase(final UTF8String target, final int
collationId) {
- return UTF8String.fromString(toLowerCase(target.toString(), collationId));
+ if (target.isFullAscii()) return target.toLowerCaseAscii();
+ return toLowerCaseSlow(target, collationId);
}
- public static String toLowerCase(final String target, final int collationId)
{
+
+ private static UTF8String toLowerCaseSlow(final UTF8String target, final int
collationId) {
+ // Note: In order to achieve the desired behaviour, we use the ICU
UCharacter class to
+ // convert the string to lowercase, which only accepts a Java strings as
input.
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
- return UCharacter.toLowerCase(locale, target);
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ return UTF8String.fromString(UCharacter.toLowerCase(locale,
target.toString()));
}
/**
@@ -424,36 +442,41 @@ public class CollationAwareUTF8String {
* @param target The target string to convert to lowercase.
* @return The string converted to lowercase in a context-unaware manner.
*/
- public static String lowerCaseCodePoints(final String target) {
+ public static UTF8String lowerCaseCodePoints(final UTF8String target) {
+ if (target.isFullAscii()) return target.toLowerCaseAscii();
+ return lowerCaseCodePointsSlow(target);
+ }
+
+ private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ String targetString = target.toString();
StringBuilder sb = new StringBuilder();
- for (int i = 0; i < target.length(); ++i) {
- lowercaseCodePoint(target.codePointAt(i), sb);
+ for (int i = 0; i < targetString.length(); ++i) {
+ lowercaseCodePoint(targetString.codePointAt(i), sb);
}
- return sb.toString();
+ return UTF8String.fromString(sb.toString());
}
/**
* Convert the input string to titlecase using the ICU root locale rules.
*/
public static UTF8String toTitleCase(final UTF8String target) {
- return UTF8String.fromString(toTitleCase(target.toString()));
- }
-
- public static String toTitleCase(final String target) {
- return UCharacter.toTitleCase(target, BreakIterator.getWordInstance());
+ // Note: In order to achieve the desired behaviour, we use the ICU
UCharacter class to
+ // convert the string to titlecase, which only accepts a Java strings as
input.
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ return UTF8String.fromString(UCharacter.toTitleCase(target.toString(),
+ BreakIterator.getWordInstance()));
}
/**
* Convert the input string to titlecase using the specified ICU collation
rules.
*/
public static UTF8String toTitleCase(final UTF8String target, final int
collationId) {
- return UTF8String.fromString(toTitleCase(target.toString(), collationId));
- }
-
- public static String toTitleCase(final String target, final int collationId)
{
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
- return UCharacter.toTitleCase(locale, target,
BreakIterator.getWordInstance(locale));
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
+ return UTF8String.fromString(UCharacter.toTitleCase(locale,
target.toString(),
+ BreakIterator.getWordInstance(locale)));
}
public static int findInSet(final UTF8String match, final UTF8String set,
int collationId) {
@@ -461,6 +484,7 @@ public class CollationAwareUTF8String {
return 0;
}
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
String setString = set.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(setString,
match.toString(),
collationId);
@@ -623,6 +647,7 @@ public class CollationAwareUTF8String {
public static Map<String, String> getCollationAwareDict(UTF8String string,
Map<String, String> dict, int collationId) {
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
String srcStr = string.toString();
Map<String, String> collationAwareDict = new HashMap<>();
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 61ec6f7da215..b0f6c5c22991 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -299,7 +299,7 @@ public final class CollationFactory {
== DefinitionOrigin.PREDEFINED);
if (collationId == UTF8_BINARY_COLLATION_ID) {
// Skip cache.
- return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION;
+ return CollationSpecUTF8.UTF8_BINARY_COLLATION;
} else if (collationMap.containsKey(collationId)) {
// Already in cache.
return collationMap.get(collationId);
@@ -308,7 +308,7 @@ public final class CollationFactory {
CollationSpec spec;
ImplementationProvider implementationProvider =
getImplementationProvider(collationId);
if (implementationProvider == ImplementationProvider.UTF8_BINARY) {
- spec = CollationSpecUTF8Binary.fromCollationId(collationId);
+ spec = CollationSpecUTF8.fromCollationId(collationId);
} else {
spec = CollationSpecICU.fromCollationId(collationId);
}
@@ -327,7 +327,7 @@ public final class CollationFactory {
// Collation names provided by user are treated as case-insensitive.
String collationNameUpper = collationName.toUpperCase();
if (collationNameUpper.startsWith("UTF8_")) {
- return CollationSpecUTF8Binary.collationNameToId(collationName,
collationNameUpper);
+ return CollationSpecUTF8.collationNameToId(collationName,
collationNameUpper);
} else {
return CollationSpecICU.collationNameToId(collationName,
collationNameUpper);
}
@@ -336,7 +336,7 @@ public final class CollationFactory {
protected abstract Collation buildCollation();
}
- private static class CollationSpecUTF8Binary extends CollationSpec {
+ private static class CollationSpecUTF8 extends CollationSpec {
/**
* Bit 0 in collation ID having value 0 for plain UTF8_BINARY and 1 for
UTF8_LCASE
@@ -357,17 +357,17 @@ public final class CollationFactory {
private static final int CASE_SENSITIVITY_MASK = 0b1;
private static final int UTF8_BINARY_COLLATION_ID =
- new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).collationId;
+ new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId;
private static final int UTF8_LCASE_COLLATION_ID =
- new CollationSpecUTF8Binary(CaseSensitivity.LCASE).collationId;
+ new CollationSpecUTF8(CaseSensitivity.LCASE).collationId;
protected static Collation UTF8_BINARY_COLLATION =
- new
CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).buildCollation();
+ new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).buildCollation();
protected static Collation UTF8_LCASE_COLLATION =
- new CollationSpecUTF8Binary(CaseSensitivity.LCASE).buildCollation();
+ new CollationSpecUTF8(CaseSensitivity.LCASE).buildCollation();
private final int collationId;
- private CollationSpecUTF8Binary(CaseSensitivity caseSensitivity) {
+ private CollationSpecUTF8(CaseSensitivity caseSensitivity) {
this.collationId =
SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET,
caseSensitivity);
}
@@ -384,14 +384,14 @@ public final class CollationFactory {
}
}
- private static CollationSpecUTF8Binary fromCollationId(int collationId) {
+ private static CollationSpecUTF8 fromCollationId(int collationId) {
// Extract case sensitivity from collation ID.
int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId,
CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK);
// Verify only case sensitivity bits were set settable in UTF8_BINARY
family of collations.
assert (SpecifierUtils.removeSpec(collationId,
CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0);
- return new
CollationSpecUTF8Binary(CaseSensitivity.values()[caseConversionOrdinal]);
+ return new
CollationSpecUTF8(CaseSensitivity.values()[caseConversionOrdinal]);
}
@Override
@@ -414,7 +414,7 @@ public final class CollationFactory {
null,
CollationAwareUTF8String::compareLowerCase,
"1.0",
- s -> (long)
CollationAwareUTF8String.lowerCaseCodePoints(s.toString()).hashCode(),
+ s -> (long)
CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(),
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ true);
@@ -727,9 +727,9 @@ public final class CollationFactory {
public static final List<String> SUPPORTED_PROVIDERS =
List.of(PROVIDER_SPARK, PROVIDER_ICU);
public static final int UTF8_BINARY_COLLATION_ID =
- Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION_ID;
+ Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID;
public static final int UTF8_LCASE_COLLATION_ID =
- Collation.CollationSpecUTF8Binary.UTF8_LCASE_COLLATION_ID;
+ Collation.CollationSpecUTF8.UTF8_LCASE_COLLATION_ID;
public static final int UNICODE_COLLATION_ID =
Collation.CollationSpecICU.UNICODE_COLLATION_ID;
public static final int UNICODE_CI_COLLATION_ID =
diff --git
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e7f16988c537..12a7b06232ee 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.function.Function;
import java.util.Map;
import java.util.regex.Pattern;
@@ -495,6 +496,18 @@ public final class UTF8String implements
Comparable<UTF8String>, Externalizable,
return matchAt(suffix, numBytes - suffix.numBytes);
}
+ /**
+ * Method for ASCII character conversion using a functional interface for
chars.
+ */
+
+ private UTF8String convertAscii(Function<Character, Character>
charConverter) {
+ byte[] bytes = new byte[numBytes];
+ for (int i = 0; i < numBytes; i++) {
+ bytes[i] = (byte) charConverter.apply((char) getByte(i)).charValue();
+ }
+ return fromBytes(bytes);
+ }
+
/**
* Returns the upper case of this string
*/
@@ -502,18 +515,12 @@ public final class UTF8String implements
Comparable<UTF8String>, Externalizable,
if (numBytes == 0) {
return EMPTY_UTF8;
}
- // Optimization - do char level uppercase conversion in case of chars in
ASCII range
- for (int i = 0; i < numBytes; i++) {
- if (getByte(i) < 0) {
- // non-ASCII
- return toUpperCaseSlow();
- }
- }
- byte[] bytes = new byte[numBytes];
- for (int i = 0; i < numBytes; i++) {
- bytes[i] = (byte) Character.toUpperCase(getByte(i));
- }
- return fromBytes(bytes);
+
+ return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlow();
+ }
+
+ public UTF8String toUpperCaseAscii() {
+ return convertAscii(Character::toUpperCase);
}
private UTF8String toUpperCaseSlow() {
@@ -544,12 +551,8 @@ public final class UTF8String implements
Comparable<UTF8String>, Externalizable,
return fromString(toString().toLowerCase());
}
- private UTF8String toLowerCaseAscii() {
- final var bytes = new byte[numBytes];
- for (var i = 0; i < numBytes; i++) {
- bytes[i] = (byte) Character.toLowerCase(getByte(i));
- }
- return fromBytes(bytes);
+ public UTF8String toLowerCaseAscii() {
+ return convertAscii(Character::toLowerCase);
}
/**
diff --git
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 436dff1db0e0..9602c83c6c80 100644
---
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -41,7 +41,7 @@ public class CollationSupportSuite {
*/
private void assertStringCompare(String s1, String s2, String collationName,
int expected)
- throws SparkException {
+ throws SparkException {
UTF8String l = UTF8String.fromString(s1);
UTF8String r = UTF8String.fromString(s2);
int compare =
CollationFactory.fetchCollation(collationName).comparator.compare(l, r);
@@ -129,13 +129,26 @@ public class CollationSupportSuite {
assertStringCompare("ς", "σ", "UNICODE_CI", 0);
assertStringCompare("ς", "Σ", "UNICODE_CI", 0);
assertStringCompare("σ", "Σ", "UNICODE_CI", 0);
+ // Maximum code point.
+ int maxCodePoint = Character.MAX_CODE_POINT;
+ String maxCodePointStr = new String(Character.toChars(maxCodePoint));
+ for (int i = 0; i < maxCodePoint && Character.isValidCodePoint(i); ++i) {
+ assertStringCompare(new String(Character.toChars(i)), maxCodePointStr,
"UTF8_BINARY", -1);
+ assertStringCompare(new String(Character.toChars(i)), maxCodePointStr,
"UTF8_LCASE", -1);
+ }
+ // Minimum code point.
+ int minCodePoint = Character.MIN_CODE_POINT;
+ String minCodePointStr = new String(Character.toChars(minCodePoint));
+ for (int i = minCodePoint + 1; i <= maxCodePoint &&
Character.isValidCodePoint(i); ++i) {
+ assertStringCompare(new String(Character.toChars(i)), minCodePointStr,
"UTF8_BINARY", 1);
+ assertStringCompare(new String(Character.toChars(i)), minCodePointStr,
"UTF8_LCASE", 1);
+ }
}
private void assertLowerCaseCodePoints(UTF8String target, UTF8String
expected,
Boolean useCodePoints) {
if (useCodePoints) {
- assertEquals(expected.toString(),
- CollationAwareUTF8String.lowerCaseCodePoints(target.toString()));
+ assertEquals(expected,
CollationAwareUTF8String.lowerCaseCodePoints(target));
} else {
assertEquals(expected, target.toLowerCase());
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]