This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ec0ee863562e [SPARK-48699][SQL] Refine collation API
ec0ee863562e is described below

commit ec0ee863562e7ad6c27a8632f9de804bd012e9ec
Author: Uros Bojanic <[email protected]>
AuthorDate: Wed Jun 26 15:46:51 2024 +0800

    [SPARK-48699][SQL] Refine collation API
    
    ### What changes were proposed in this pull request?
    - All collation-related public APIs (e.g. `CollationFactory`,
      `CollationAwareUTF8String`) should generally use `UTF8String` instead of
      Java `String`.
    - The comparator for the `UTF8_LCASE` collation should use
      `UTF8String.binaryCompare` instead of `String.compareTo`.
    - Add an ASCII fast path for the `Lower` and `Upper` expressions.
    
    ### Why are the changes needed?
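    - Make the collation API consistent by operating on `UTF8String` directly
      rather than converting to and from Java `String`.
    - Fix a bug in the `UTF8_LCASE` comparator. `String.compareTo` compares
      UTF-16 code units, so supplementary characters (encoded as surrogates
      U+D800-U+DFFF) sorted before BMP characters in U+E000-U+FFFF. This
      disagreed with the code-point (UTF-8 byte) order used by `UTF8_BINARY`.
    - Improve the performance of the `Lower` and `Upper` expressions for
      ASCII-only input.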
    
    ### Does this PR introduce _any_ user-facing change?
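    Yes. The `String`-based overloads of the case-conversion methods in
    `CollationAwareUTF8String` are removed. `lowerCaseCodePoints` now takes and
    returns `UTF8String`. `UTF8_LCASE` ordering of supplementary characters now
    follows code-point order.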
    
    ### How was this patch tested?
    - Added tests to `CollationSupportSuite` verifying that, under both the
      `UTF8_BINARY` and `UTF8_LCASE` comparators, every code point orders
      correctly against the minimum and maximum code points.
    - Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47014 from uros-db/fix-collation-api.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/util/CollationAwareUTF8String.java    | 79 ++++++++++++++--------
 .../spark/sql/catalyst/util/CollationFactory.java  | 28 ++++----
 .../org/apache/spark/unsafe/types/UTF8String.java  | 39 ++++++-----
 .../spark/unsafe/types/CollationSupportSuite.java  | 19 +++++-
 4 files changed, 103 insertions(+), 62 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 934572cd0d67..7c0ffc73aa9d 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -228,7 +228,7 @@ public class CollationAwareUTF8String {
    * @return An integer representing the comparison result.
    */
   private static int compareLowerCaseSlow(final UTF8String left, final 
UTF8String right) {
-    return 
lowerCaseCodePoints(left.toString()).compareTo(lowerCaseCodePoints(right.toString()));
+    return lowerCaseCodePoints(left).binaryCompare(lowerCaseCodePoints(right));
   }
 
   public static UTF8String replace(final UTF8String src, final UTF8String 
search,
@@ -339,11 +339,15 @@ public class CollationAwareUTF8String {
    * @return the uppercase string
    */
   public static UTF8String toUpperCase(final UTF8String target) {
-    return UTF8String.fromString(toUpperCase(target.toString()));
+    if (target.isFullAscii()) return target.toUpperCaseAscii();
+    return toUpperCaseSlow(target);
   }
 
-  public static String toUpperCase(final String target) {
-    return UCharacter.toUpperCase(target);
+  private static UTF8String toUpperCaseSlow(final UTF8String target) {
+    // Note: In order to achieve the desired behaviour, we use the ICU 
UCharacter class to
+    // convert the string to uppercase, which only accepts a Java strings as 
input.
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    return UTF8String.fromString(UCharacter.toUpperCase(target.toString()));
   }
 
   /**
@@ -353,13 +357,17 @@ public class CollationAwareUTF8String {
    * @return the uppercase string
    */
   public static UTF8String toUpperCase(final UTF8String target, final int 
collationId) {
-    return UTF8String.fromString(toUpperCase(target.toString(), collationId));
+    if (target.isFullAscii()) return target.toUpperCaseAscii();
+    return toUpperCaseSlow(target, collationId);
   }
 
-  public static String toUpperCase(final String target, final int collationId) 
{
+  private static UTF8String toUpperCaseSlow(final UTF8String target, final int 
collationId) {
+    // Note: In order to achieve the desired behaviour, we use the ICU 
UCharacter class to
+    // convert the string to uppercase, which only accepts a Java strings as 
input.
     ULocale locale = CollationFactory.fetchCollation(collationId)
       .collator.getLocale(ULocale.ACTUAL_LOCALE);
-    return UCharacter.toUpperCase(locale, target);
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    return UTF8String.fromString(UCharacter.toUpperCase(locale, 
target.toString()));
   }
 
   /**
@@ -369,10 +377,15 @@ public class CollationAwareUTF8String {
    * @return the lowercase string
    */
   public static UTF8String toLowerCase(final UTF8String target) {
-    return UTF8String.fromString(toLowerCase(target.toString()));
+    if (target.isFullAscii()) return target.toLowerCaseAscii();
+    return toLowerCaseSlow(target);
   }
-  public static String toLowerCase(final String target) {
-    return UCharacter.toLowerCase(target);
+
+  private static UTF8String toLowerCaseSlow(final UTF8String target) {
+    // Note: In order to achieve the desired behaviour, we use the ICU 
UCharacter class to
+    // convert the string to lowercase, which only accepts a Java strings as 
input.
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    return UTF8String.fromString(UCharacter.toLowerCase(target.toString()));
   }
 
   /**
@@ -382,12 +395,17 @@ public class CollationAwareUTF8String {
    * @return the lowercase string
    */
   public static UTF8String toLowerCase(final UTF8String target, final int 
collationId) {
-    return UTF8String.fromString(toLowerCase(target.toString(), collationId));
+    if (target.isFullAscii()) return target.toLowerCaseAscii();
+    return toLowerCaseSlow(target, collationId);
   }
-  public static String toLowerCase(final String target, final int collationId) 
{
+
+  private static UTF8String toLowerCaseSlow(final UTF8String target, final int 
collationId) {
+    // Note: In order to achieve the desired behaviour, we use the ICU 
UCharacter class to
+    // convert the string to lowercase, which only accepts a Java strings as 
input.
     ULocale locale = CollationFactory.fetchCollation(collationId)
       .collator.getLocale(ULocale.ACTUAL_LOCALE);
-    return UCharacter.toLowerCase(locale, target);
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    return UTF8String.fromString(UCharacter.toLowerCase(locale, 
target.toString()));
   }
 
   /**
@@ -424,36 +442,41 @@ public class CollationAwareUTF8String {
    * @param target The target string to convert to lowercase.
    * @return The string converted to lowercase in a context-unaware manner.
    */
-  public static String lowerCaseCodePoints(final String target) {
+  public static UTF8String lowerCaseCodePoints(final UTF8String target) {
+    if (target.isFullAscii()) return target.toLowerCaseAscii();
+    return lowerCaseCodePointsSlow(target);
+  }
+
+  private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    String targetString = target.toString();
     StringBuilder sb = new StringBuilder();
-    for (int i = 0; i < target.length(); ++i) {
-      lowercaseCodePoint(target.codePointAt(i), sb);
+    for (int i = 0; i < targetString.length(); ++i) {
+      lowercaseCodePoint(targetString.codePointAt(i), sb);
     }
-    return sb.toString();
+    return UTF8String.fromString(sb.toString());
   }
 
   /**
    * Convert the input string to titlecase using the ICU root locale rules.
    */
   public static UTF8String toTitleCase(final UTF8String target) {
-    return UTF8String.fromString(toTitleCase(target.toString()));
-  }
-
-  public static String toTitleCase(final String target) {
-    return UCharacter.toTitleCase(target, BreakIterator.getWordInstance());
+    // Note: In order to achieve the desired behaviour, we use the ICU 
UCharacter class to
+    // convert the string to titlecase, which only accepts a Java strings as 
input.
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    return UTF8String.fromString(UCharacter.toTitleCase(target.toString(),
+      BreakIterator.getWordInstance()));
   }
 
   /**
    * Convert the input string to titlecase using the specified ICU collation 
rules.
    */
   public static UTF8String toTitleCase(final UTF8String target, final int 
collationId) {
-    return UTF8String.fromString(toTitleCase(target.toString(), collationId));
-  }
-
-  public static String toTitleCase(final String target, final int collationId) 
{
     ULocale locale = CollationFactory.fetchCollation(collationId)
       .collator.getLocale(ULocale.ACTUAL_LOCALE);
-    return UCharacter.toTitleCase(locale, target, 
BreakIterator.getWordInstance(locale));
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
+    return UTF8String.fromString(UCharacter.toTitleCase(locale, 
target.toString(),
+      BreakIterator.getWordInstance(locale)));
   }
 
   public static int findInSet(final UTF8String match, final UTF8String set, 
int collationId) {
@@ -461,6 +484,7 @@ public class CollationAwareUTF8String {
       return 0;
     }
 
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
     String setString = set.toString();
     StringSearch stringSearch = CollationFactory.getStringSearch(setString, 
match.toString(),
       collationId);
@@ -623,6 +647,7 @@ public class CollationAwareUTF8String {
 
   public static Map<String, String> getCollationAwareDict(UTF8String string,
       Map<String, String> dict, int collationId) {
+    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
     String srcStr = string.toString();
 
     Map<String, String> collationAwareDict = new HashMap<>();
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 61ec6f7da215..b0f6c5c22991 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -299,7 +299,7 @@ public final class CollationFactory {
           == DefinitionOrigin.PREDEFINED);
         if (collationId == UTF8_BINARY_COLLATION_ID) {
           // Skip cache.
-          return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION;
+          return CollationSpecUTF8.UTF8_BINARY_COLLATION;
         } else if (collationMap.containsKey(collationId)) {
           // Already in cache.
           return collationMap.get(collationId);
@@ -308,7 +308,7 @@ public final class CollationFactory {
           CollationSpec spec;
           ImplementationProvider implementationProvider = 
getImplementationProvider(collationId);
           if (implementationProvider == ImplementationProvider.UTF8_BINARY) {
-            spec = CollationSpecUTF8Binary.fromCollationId(collationId);
+            spec = CollationSpecUTF8.fromCollationId(collationId);
           } else {
             spec = CollationSpecICU.fromCollationId(collationId);
           }
@@ -327,7 +327,7 @@ public final class CollationFactory {
         // Collation names provided by user are treated as case-insensitive.
         String collationNameUpper = collationName.toUpperCase();
         if (collationNameUpper.startsWith("UTF8_")) {
-          return CollationSpecUTF8Binary.collationNameToId(collationName, 
collationNameUpper);
+          return CollationSpecUTF8.collationNameToId(collationName, 
collationNameUpper);
         } else {
           return CollationSpecICU.collationNameToId(collationName, 
collationNameUpper);
         }
@@ -336,7 +336,7 @@ public final class CollationFactory {
       protected abstract Collation buildCollation();
     }
 
-    private static class CollationSpecUTF8Binary extends CollationSpec {
+    private static class CollationSpecUTF8 extends CollationSpec {
 
       /**
        * Bit 0 in collation ID having value 0 for plain UTF8_BINARY and 1 for 
UTF8_LCASE
@@ -357,17 +357,17 @@ public final class CollationFactory {
       private static final int CASE_SENSITIVITY_MASK = 0b1;
 
       private static final int UTF8_BINARY_COLLATION_ID =
-        new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).collationId;
+        new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId;
       private static final int UTF8_LCASE_COLLATION_ID =
-        new CollationSpecUTF8Binary(CaseSensitivity.LCASE).collationId;
+        new CollationSpecUTF8(CaseSensitivity.LCASE).collationId;
       protected static Collation UTF8_BINARY_COLLATION =
-        new 
CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).buildCollation();
+        new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).buildCollation();
       protected static Collation UTF8_LCASE_COLLATION =
-        new CollationSpecUTF8Binary(CaseSensitivity.LCASE).buildCollation();
+        new CollationSpecUTF8(CaseSensitivity.LCASE).buildCollation();
 
       private final int collationId;
 
-      private CollationSpecUTF8Binary(CaseSensitivity caseSensitivity) {
+      private CollationSpecUTF8(CaseSensitivity caseSensitivity) {
         this.collationId =
           SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, 
caseSensitivity);
       }
@@ -384,14 +384,14 @@ public final class CollationFactory {
         }
       }
 
-      private static CollationSpecUTF8Binary fromCollationId(int collationId) {
+      private static CollationSpecUTF8 fromCollationId(int collationId) {
         // Extract case sensitivity from collation ID.
         int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId,
           CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK);
         // Verify only case sensitivity bits were set settable in UTF8_BINARY 
family of collations.
         assert (SpecifierUtils.removeSpec(collationId,
           CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0);
-        return new 
CollationSpecUTF8Binary(CaseSensitivity.values()[caseConversionOrdinal]);
+        return new 
CollationSpecUTF8(CaseSensitivity.values()[caseConversionOrdinal]);
       }
 
       @Override
@@ -414,7 +414,7 @@ public final class CollationFactory {
             null,
             CollationAwareUTF8String::compareLowerCase,
             "1.0",
-            s -> (long) 
CollationAwareUTF8String.lowerCaseCodePoints(s.toString()).hashCode(),
+            s -> (long) 
CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(),
             /* supportsBinaryEquality = */ false,
             /* supportsBinaryOrdering = */ false,
             /* supportsLowercaseEquality = */ true);
@@ -727,9 +727,9 @@ public final class CollationFactory {
   public static final List<String> SUPPORTED_PROVIDERS = 
List.of(PROVIDER_SPARK, PROVIDER_ICU);
 
   public static final int UTF8_BINARY_COLLATION_ID =
-    Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION_ID;
+    Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID;
   public static final int UTF8_LCASE_COLLATION_ID =
-    Collation.CollationSpecUTF8Binary.UTF8_LCASE_COLLATION_ID;
+    Collation.CollationSpecUTF8.UTF8_LCASE_COLLATION_ID;
   public static final int UNICODE_COLLATION_ID =
     Collation.CollationSpecICU.UNICODE_COLLATION_ID;
   public static final int UNICODE_CI_COLLATION_ID =
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e7f16988c537..12a7b06232ee 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.function.Function;
 import java.util.Map;
 import java.util.regex.Pattern;
 
@@ -495,6 +496,18 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
     return matchAt(suffix, numBytes - suffix.numBytes);
   }
 
+  /**
+   * Method for ASCII character conversion using a functional interface for 
chars.
+   */
+
+  private UTF8String convertAscii(Function<Character, Character> 
charConverter) {
+    byte[] bytes = new byte[numBytes];
+    for (int i = 0; i < numBytes; i++) {
+        bytes[i] = (byte) charConverter.apply((char) getByte(i)).charValue();
+    }
+    return fromBytes(bytes);
+  }
+
   /**
    * Returns the upper case of this string
    */
@@ -502,18 +515,12 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
     if (numBytes == 0) {
       return EMPTY_UTF8;
     }
-    // Optimization - do char level uppercase conversion in case of chars in 
ASCII range
-    for (int i = 0; i < numBytes; i++) {
-      if (getByte(i) < 0) {
-        // non-ASCII
-        return toUpperCaseSlow();
-      }
-    }
-    byte[] bytes = new byte[numBytes];
-    for (int i = 0; i < numBytes; i++) {
-      bytes[i] = (byte) Character.toUpperCase(getByte(i));
-    }
-    return fromBytes(bytes);
+
+    return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlow();
+  }
+
+  public UTF8String toUpperCaseAscii() {
+    return convertAscii(Character::toUpperCase);
   }
 
   private UTF8String toUpperCaseSlow() {
@@ -544,12 +551,8 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
     return fromString(toString().toLowerCase());
   }
 
-  private UTF8String toLowerCaseAscii() {
-    final var bytes = new byte[numBytes];
-    for (var i = 0; i < numBytes; i++) {
-      bytes[i] = (byte) Character.toLowerCase(getByte(i));
-    }
-    return fromBytes(bytes);
+  public UTF8String toLowerCaseAscii() {
+    return convertAscii(Character::toLowerCase);
   }
 
   /**
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 436dff1db0e0..9602c83c6c80 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -41,7 +41,7 @@ public class CollationSupportSuite {
    */
 
   private void assertStringCompare(String s1, String s2, String collationName, 
int expected)
-          throws SparkException {
+      throws SparkException {
     UTF8String l = UTF8String.fromString(s1);
     UTF8String r = UTF8String.fromString(s2);
     int compare = 
CollationFactory.fetchCollation(collationName).comparator.compare(l, r);
@@ -129,13 +129,26 @@ public class CollationSupportSuite {
     assertStringCompare("ς", "σ", "UNICODE_CI", 0);
     assertStringCompare("ς", "Σ", "UNICODE_CI", 0);
     assertStringCompare("σ", "Σ", "UNICODE_CI", 0);
+    // Maximum code point.
+    int maxCodePoint = Character.MAX_CODE_POINT;
+    String maxCodePointStr = new String(Character.toChars(maxCodePoint));
+    for (int i = 0; i < maxCodePoint && Character.isValidCodePoint(i); ++i) {
+      assertStringCompare(new String(Character.toChars(i)), maxCodePointStr, 
"UTF8_BINARY", -1);
+      assertStringCompare(new String(Character.toChars(i)), maxCodePointStr, 
"UTF8_LCASE", -1);
+    }
+    // Minimum code point.
+    int minCodePoint = Character.MIN_CODE_POINT;
+    String minCodePointStr = new String(Character.toChars(minCodePoint));
+    for (int i = minCodePoint + 1; i <= maxCodePoint && 
Character.isValidCodePoint(i); ++i) {
+      assertStringCompare(new String(Character.toChars(i)), minCodePointStr, 
"UTF8_BINARY", 1);
+      assertStringCompare(new String(Character.toChars(i)), minCodePointStr, 
"UTF8_LCASE", 1);
+    }
   }
 
   private void assertLowerCaseCodePoints(UTF8String target, UTF8String 
expected,
       Boolean useCodePoints) {
     if (useCodePoints) {
-      assertEquals(expected.toString(),
-        CollationAwareUTF8String.lowerCaseCodePoints(target.toString()));
+      assertEquals(expected, 
CollationAwareUTF8String.lowerCaseCodePoints(target));
     } else {
       assertEquals(expected, target.toLowerCase());
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to