This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b9f2270f5b0b [SPARK-47352][SQL] Fix Upper, Lower, InitCap collation 
awareness
b9f2270f5b0b is described below

commit b9f2270f5b0ba6ea1fb1cdf3225fa626ab91540b
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Tue Apr 23 16:28:33 2024 +0800

    [SPARK-47352][SQL] Fix Upper, Lower, InitCap collation awareness
    
    ### What changes were proposed in this pull request?
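    Make the Upper, Lower and InitCap expressions collation-aware. Collations that support binary or lowercase equality keep using UTF8String case conversion; all other collations use ICU case mapping with the collator's locale.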
    
    ### Why are the changes needed?
    This is needed because some future collations might use a locale other than the default.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. For collations that support neither binary nor lowercase equality, upper, lower and initcap now follow ICU's locale-aware case mapping.
    
    ### How was this patch tested?
    Existing tests for Upper, Lower and InitCap, plus new unit tests (testUpper, testLower, testInitCap) in CollationSupportSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46104 from mihailom-db/SPARK-47352.
    
    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationSupport.java  | 108 +++++++++++++++
 .../spark/unsafe/types/CollationSupportSuite.java  | 151 +++++++++++++++++++++
 .../catalyst/expressions/stringExpressions.scala   |  24 ++--
 3 files changed, 271 insertions(+), 12 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index d54e297413f4..b28321230840 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql.catalyst.util;
 
+import com.ibm.icu.lang.UCharacter;
+import com.ibm.icu.text.BreakIterator;
 import com.ibm.icu.text.StringSearch;
+import com.ibm.icu.util.ULocale;
 
 import org.apache.spark.unsafe.types.UTF8String;
 
@@ -144,6 +147,93 @@ public final class CollationSupport {
     }
   }
 
+  public static class Upper {
+    public static UTF8String exec(final UTF8String v, final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality || 
collation.supportsLowercaseEquality) {
+        return execUTF8(v);
+      } else {
+        return execICU(v, collationId);
+      }
+    }
+    public static String genCode(final String v, final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.Upper.exec";
+      if (collation.supportsBinaryEquality || 
collation.supportsLowercaseEquality) {
+        return String.format(expr + "UTF8(%s)", v);
+      } else {
+        return String.format(expr + "ICU(%s, %d)", v, collationId);
+      }
+    }
+    public static UTF8String execUTF8(final UTF8String v) {
+      return v.toUpperCase();
+    }
+    public static UTF8String execICU(final UTF8String v, final int 
collationId) {
+      return 
UTF8String.fromString(CollationAwareUTF8String.toUpperCase(v.toString(), 
collationId));
+    }
+  }
+
+  public static class Lower {
+    public static UTF8String exec(final UTF8String v, final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality || 
collation.supportsLowercaseEquality) {
+        return execUTF8(v);
+      } else {
+        return execICU(v, collationId);
+      }
+    }
+    public static String genCode(final String v, final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+        String expr = "CollationSupport.Lower.exec";
+      if (collation.supportsBinaryEquality || 
collation.supportsLowercaseEquality) {
+        return String.format(expr + "UTF8(%s)", v);
+      } else {
+        return String.format(expr + "ICU(%s, %d)", v, collationId);
+      }
+    }
+    public static UTF8String execUTF8(final UTF8String v) {
+      return v.toLowerCase();
+    }
+    public static UTF8String execICU(final UTF8String v, final int 
collationId) {
+      return 
UTF8String.fromString(CollationAwareUTF8String.toLowerCase(v.toString(), 
collationId));
+    }
+  }
+
+  public static class InitCap {
+    public static UTF8String exec(final UTF8String v, final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality || 
collation.supportsLowercaseEquality) {
+        return execUTF8(v);
+      } else {
+        return execICU(v, collationId);
+      }
+    }
+
+    public static String genCode(final String v, final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.InitCap.exec";
+      if (collation.supportsBinaryEquality || 
collation.supportsLowercaseEquality) {
+        return String.format(expr + "UTF8(%s)", v);
+      } else {
+        return String.format(expr + "ICU(%s, %d)", v, collationId);
+      }
+    }
+
+    public static UTF8String execUTF8(final UTF8String v) {
+      return v.toLowerCase().toTitleCase();
+    }
+
+    public static UTF8String execICU(final UTF8String v, final int 
collationId) {
+      return UTF8String.fromString(
+              CollationAwareUTF8String.toTitleCase(
+                      CollationAwareUTF8String.toLowerCase(
+                              v.toString(),
+                              collationId
+                      ),
+                      collationId));
+    }
+  }
+
   public static class FindInSet {
     public static int exec(final UTF8String word, final UTF8String set, final 
int collationId) {
       CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
@@ -234,6 +324,24 @@ public final class CollationSupport {
 
   private static class CollationAwareUTF8String {
 
+    private static String toUpperCase(final String target, final int 
collationId) {
+      ULocale locale = CollationFactory.fetchCollation(collationId)
+              .collator.getLocale(ULocale.ACTUAL_LOCALE);
+      return UCharacter.toUpperCase(locale, target);
+    }
+
+    private static String toLowerCase(final String target, final int 
collationId) {
+      ULocale locale = CollationFactory.fetchCollation(collationId)
+              .collator.getLocale(ULocale.ACTUAL_LOCALE);
+      return UCharacter.toLowerCase(locale, target);
+    }
+
+    private static String toTitleCase(final String target, final int 
collationId) {
+      ULocale locale = CollationFactory.fetchCollation(collationId)
+              .collator.getLocale(ULocale.ACTUAL_LOCALE);
+      return UCharacter.toTitleCase(locale, target, 
BreakIterator.getWordInstance(locale));
+    }
+
     private static int findInSet(final UTF8String match, final UTF8String set, 
int collationId) {
       if (match.contains(UTF8String.fromString(","))) {
         return 0;
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 36acf1c9b7a6..3fca7296b832 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -261,6 +261,157 @@ public class CollationSupportSuite {
     assertEndsWith("The i̇o", "İo", "UNICODE_CI", true);
   }
 
+
+  private void assertUpper(String target, String collationName, String 
expected)
+          throws SparkException {
+    UTF8String target_utf8 = UTF8String.fromString(target);
+    UTF8String expected_utf8 = UTF8String.fromString(expected);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertEquals(expected_utf8, CollationSupport.Upper.exec(target_utf8, 
collationId));
+  }
+
+  @Test
+  public void testUpper() throws SparkException {
+    // Edge cases
+    assertUpper("", "UTF8_BINARY", "");
+    assertUpper("", "UTF8_BINARY_LCASE", "");
+    assertUpper("", "UNICODE", "");
+    assertUpper("", "UNICODE_CI", "");
+    // Basic tests
+    assertUpper("abcde", "UTF8_BINARY", "ABCDE");
+    assertUpper("abcde", "UTF8_BINARY_LCASE", "ABCDE");
+    assertUpper("abcde", "UNICODE", "ABCDE");
+    assertUpper("abcde", "UNICODE_CI", "ABCDE");
+    // Uppercase present
+    assertUpper("AbCdE", "UTF8_BINARY", "ABCDE");
+    assertUpper("aBcDe", "UTF8_BINARY", "ABCDE");
+    assertUpper("AbCdE", "UTF8_BINARY_LCASE", "ABCDE");
+    assertUpper("aBcDe", "UTF8_BINARY_LCASE", "ABCDE");
+    assertUpper("AbCdE", "UNICODE", "ABCDE");
+    assertUpper("aBcDe", "UNICODE", "ABCDE");
+    assertUpper("AbCdE", "UNICODE_CI", "ABCDE");
+    assertUpper("aBcDe", "UNICODE_CI", "ABCDE");
+    // Accent letters
+    assertUpper("aBćDe","UTF8_BINARY", "ABĆDE");
+    assertUpper("aBćDe","UTF8_BINARY_LCASE", "ABĆDE");
+    assertUpper("aBćDe","UNICODE", "ABĆDE");
+    assertUpper("aBćDe","UNICODE_CI", "ABĆDE");
+    // Variable byte length characters
+    assertUpper("ab世De", "UTF8_BINARY", "AB世DE");
+    assertUpper("äbćδe", "UTF8_BINARY", "ÄBĆΔE");
+    assertUpper("ab世De", "UTF8_BINARY_LCASE", "AB世DE");
+    assertUpper("äbćδe", "UTF8_BINARY_LCASE", "ÄBĆΔE");
+    assertUpper("ab世De", "UNICODE", "AB世DE");
+    assertUpper("äbćδe", "UNICODE", "ÄBĆΔE");
+    assertUpper("ab世De", "UNICODE_CI", "AB世DE");
+    assertUpper("äbćδe", "UNICODE_CI", "ÄBĆΔE");
+    // Case-variable character length
+    assertUpper("i̇o", "UTF8_BINARY","İO");
+    assertUpper("i̇o", "UTF8_BINARY_LCASE","İO");
+    assertUpper("i̇o", "UNICODE","İO");
+    assertUpper("i̇o", "UNICODE_CI","İO");
+  }
+
+  private void assertLower(String target, String collationName, String 
expected)
+          throws SparkException {
+    UTF8String target_utf8 = UTF8String.fromString(target);
+    UTF8String expected_utf8 = UTF8String.fromString(expected);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertEquals(expected_utf8, CollationSupport.Lower.exec(target_utf8, 
collationId));
+  }
+
+  @Test
+  public void testLower() throws SparkException {
+    // Edge cases
+    assertLower("", "UTF8_BINARY", "");
+    assertLower("", "UTF8_BINARY_LCASE", "");
+    assertLower("", "UNICODE", "");
+    assertLower("", "UNICODE_CI", "");
+    // Basic tests
+    assertLower("ABCDE", "UTF8_BINARY", "abcde");
+    assertLower("ABCDE", "UTF8_BINARY_LCASE", "abcde");
+    assertLower("ABCDE", "UNICODE", "abcde");
+    assertLower("ABCDE", "UNICODE_CI", "abcde");
+    // Uppercase present
+    assertLower("AbCdE", "UTF8_BINARY", "abcde");
+    assertLower("aBcDe", "UTF8_BINARY", "abcde");
+    assertLower("AbCdE", "UTF8_BINARY_LCASE", "abcde");
+    assertLower("aBcDe", "UTF8_BINARY_LCASE", "abcde");
+    assertLower("AbCdE", "UNICODE", "abcde");
+    assertLower("aBcDe", "UNICODE", "abcde");
+    assertLower("AbCdE", "UNICODE_CI", "abcde");
+    assertLower("aBcDe", "UNICODE_CI", "abcde");
+    // Accent letters
+    assertLower("AbĆdE","UTF8_BINARY", "abćde");
+    assertLower("AbĆdE","UTF8_BINARY_LCASE", "abćde");
+    assertLower("AbĆdE","UNICODE", "abćde");
+    assertLower("AbĆdE","UNICODE_CI", "abćde");
+    // Variable byte length characters
+    assertLower("aB世De", "UTF8_BINARY", "ab世de");
+    assertLower("ÄBĆΔE", "UTF8_BINARY", "äbćδe");
+    assertLower("aB世De", "UTF8_BINARY_LCASE", "ab世de");
+    assertLower("ÄBĆΔE", "UTF8_BINARY_LCASE", "äbćδe");
+    assertLower("aB世De", "UNICODE", "ab世de");
+    assertLower("ÄBĆΔE", "UNICODE", "äbćδe");
+    assertLower("aB世De", "UNICODE_CI", "ab世de");
+    assertLower("ÄBĆΔE", "UNICODE_CI", "äbćδe");
+    // Case-variable character length
+    assertLower("İo", "UTF8_BINARY","i̇o");
+    assertLower("İo", "UTF8_BINARY_LCASE","i̇o");
+    assertLower("İo", "UNICODE","i̇o");
+    assertLower("İo", "UNICODE_CI","i̇o");
+  }
+
+  private void assertInitCap(String target, String collationName, String 
expected)
+          throws SparkException {
+    UTF8String target_utf8 = UTF8String.fromString(target);
+    UTF8String expected_utf8 = UTF8String.fromString(expected);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertEquals(expected_utf8, CollationSupport.InitCap.exec(target_utf8, 
collationId));
+  }
+
+  @Test
+  public void testInitCap() throws SparkException {
+    // Edge cases
+    assertInitCap("", "UTF8_BINARY", "");
+    assertInitCap("", "UTF8_BINARY_LCASE", "");
+    assertInitCap("", "UNICODE", "");
+    assertInitCap("", "UNICODE_CI", "");
+    // Basic tests
+    assertInitCap("ABCDE", "UTF8_BINARY", "Abcde");
+    assertInitCap("ABCDE", "UTF8_BINARY_LCASE", "Abcde");
+    assertInitCap("ABCDE", "UNICODE", "Abcde");
+    assertInitCap("ABCDE", "UNICODE_CI", "Abcde");
+    // Uppercase present
+    assertInitCap("AbCdE", "UTF8_BINARY", "Abcde");
+    assertInitCap("aBcDe", "UTF8_BINARY", "Abcde");
+    assertInitCap("AbCdE", "UTF8_BINARY_LCASE", "Abcde");
+    assertInitCap("aBcDe", "UTF8_BINARY_LCASE", "Abcde");
+    assertInitCap("AbCdE", "UNICODE", "Abcde");
+    assertInitCap("aBcDe", "UNICODE", "Abcde");
+    assertInitCap("AbCdE", "UNICODE_CI", "Abcde");
+    assertInitCap("aBcDe", "UNICODE_CI", "Abcde");
+    // Accent letters
+    assertInitCap("AbĆdE", "UTF8_BINARY", "Abćde");
+    assertInitCap("AbĆdE", "UTF8_BINARY_LCASE", "Abćde");
+    assertInitCap("AbĆdE", "UNICODE", "Abćde");
+    assertInitCap("AbĆdE", "UNICODE_CI", "Abćde");
+    // Variable byte length characters
+    assertInitCap("aB 世 De", "UTF8_BINARY", "Ab 世 De");
+    assertInitCap("ÄBĆΔE", "UTF8_BINARY", "Äbćδe");
+    assertInitCap("aB 世 De", "UTF8_BINARY_LCASE", "Ab 世 De");
+    assertInitCap("ÄBĆΔE", "UTF8_BINARY_LCASE", "Äbćδe");
+    assertInitCap("aB 世 De", "UNICODE", "Ab 世 De");
+    assertInitCap("ÄBĆΔE", "UNICODE", "Äbćδe");
+    assertInitCap("aB 世 de", "UNICODE_CI", "Ab 世 De");
+    assertInitCap("ÄBĆΔE", "UNICODE_CI", "Äbćδe");
+    // Case-variable character length
+    assertInitCap("İo", "UTF8_BINARY", "İo");
+    assertInitCap("İo", "UTF8_BINARY_LCASE", "İo");
+    assertInitCap("İo", "UNICODE", "İo");
+    assertInitCap("İo", "UNICODE_CI", "İo");
+  }
+
   private void assertStringInstr(String string, String substring, String 
collationName,
           Integer expected) throws SparkException {
     UTF8String str = UTF8String.fromString(string);
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index cd21a6f5fdc2..fd4fc7a54229 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -453,14 +453,14 @@ trait String2StringExpression extends 
ImplicitCastInputTypes {
 case class Upper(child: Expression)
   extends UnaryExpression with String2StringExpression with NullIntolerant {
 
-  // scalastyle:off caselocale
-  override def convert(v: UTF8String): UTF8String = v.toUpperCase
-  // scalastyle:on caselocale
+  final lazy val collationId: Int = 
child.dataType.asInstanceOf[StringType].collationId
+
+  override def convert(v: UTF8String): UTF8String = 
CollationSupport.Upper.exec(v, collationId)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
+    defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId))
   }
 
   override protected def withNewChildInternal(newChild: Expression): Upper = 
copy(child = newChild)
@@ -481,14 +481,14 @@ case class Upper(child: Expression)
 case class Lower(child: Expression)
   extends UnaryExpression with String2StringExpression with NullIntolerant {
 
-  // scalastyle:off caselocale
-  override def convert(v: UTF8String): UTF8String = v.toLowerCase
-  // scalastyle:on caselocale
+  final lazy val collationId: Int = 
child.dataType.asInstanceOf[StringType].collationId
+
+  override def convert(v: UTF8String): UTF8String = 
CollationSupport.Lower.exec(v, collationId)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
+    defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId))
   }
 
   override def prettyName: String =
@@ -1824,16 +1824,16 @@ case class FormatString(children: Expression*) extends 
Expression with ImplicitC
 case class InitCap(child: Expression)
   extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
+  final lazy val collationId: Int = 
child.dataType.asInstanceOf[StringType].collationId
+
   override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
   override def dataType: DataType = child.dataType
 
   override def nullSafeEval(string: Any): Any = {
-    // scalastyle:off caselocale
-    string.asInstanceOf[UTF8String].toLowerCase.toTitleCase
-    // scalastyle:on caselocale
+    CollationSupport.InitCap.exec(string.asInstanceOf[UTF8String], collationId)
   }
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
+    defineCodeGen(ctx, ev, str => CollationSupport.InitCap.genCode(str, 
collationId))
   }
 
   override protected def withNewChildInternal(newChild: Expression): InitCap =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to