This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e4bbdff80a1 [SPARK-48440][SQL] Fix StringTranslate behaviour for 
non-UTF8_BINARY collations
8e4bbdff80a1 is described below

commit 8e4bbdff80a1c069ccce71060751987e9e6c0b6b
Author: Uros Bojanic <[email protected]>
AuthorDate: Fri Jul 12 22:30:18 2024 +0800

    [SPARK-48440][SQL] Fix StringTranslate behaviour for non-UTF8_BINARY 
collations
    
    ### What changes were proposed in this pull request?
    String searching in UTF8_LCASE now works at the character level, rather than at the
    byte level. For example, `translate("İ", "i", "x")` now returns `"İ"`, because there
    exists no **single character** in `"İ"` whose lowercased form equals `"i"`. Note,
    however, that there _is_ a byte subsequence of `"İ"` whose lowercased UTF-8 form
    equals `"i"` (so the new behaviour differs from the old behaviour).
    
    Also, translation for ICU collations works by repeatedly translating the 
longest possible substring that matches a key in the dictionary (under the 
specified collation), starting from the left side of the input string, until 
the entire string is translated.
    
    ### Why are the changes needed?
    Fix functions that give unusable results due to one-to-many case mapping
    when performing string search under UTF8_LCASE (see example above).
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the behaviour of the `translate` expression changes for edge cases involving
    one-to-many case mappings.
    
    ### How was this patch tested?
    New unit tests in `CollationStringExpressionsSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46761 from uros-db/alter-translate.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/util/CollationAwareUTF8String.java    | 218 +++++++++++++++++----
 .../spark/sql/catalyst/util/CollationSupport.java  |  25 +--
 .../spark/unsafe/types/CollationSupportSuite.java  | 192 ++++++++++++++++--
 .../catalyst/expressions/stringExpressions.scala   |  30 ++-
 .../sql/CollationStringExpressionsSuite.scala      |  51 +----
 5 files changed, 402 insertions(+), 114 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 23adc772b7f3..af152c87f88c 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.util;
 
 import com.ibm.icu.lang.UCharacter;
 import com.ibm.icu.text.BreakIterator;
+import com.ibm.icu.text.Collator;
+import com.ibm.icu.text.RuleBasedCollator;
 import com.ibm.icu.text.StringSearch;
 import com.ibm.icu.util.ULocale;
 
@@ -26,8 +28,12 @@ import org.apache.spark.unsafe.types.UTF8String;
 
 import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
 import static org.apache.spark.unsafe.Platform.copyMemory;
+import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType;
 
+import java.text.CharacterIterator;
+import java.text.StringCharacterIterator;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.Map;
 
 /**
@@ -424,19 +430,50 @@ public class CollationAwareUTF8String {
    * @param codePoint The code point to convert to lowercase.
    * @param sb The StringBuilder to append the lowercase character to.
    */
-  private static void lowercaseCodePoint(final int codePoint, final 
StringBuilder sb) {
-    if (codePoint == 0x0130) {
+  private static void appendLowercaseCodePoint(final int codePoint, final 
StringBuilder sb) {
+    int lowercaseCodePoint = getLowercaseCodePoint(codePoint);
+    if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) {
       // Latin capital letter I with dot above is mapped to 2 lowercase 
characters.
       sb.appendCodePoint(0x0069);
       sb.appendCodePoint(0x0307);
+    } else {
+      // All other characters should follow context-unaware ICU single-code 
point case mapping.
+      sb.appendCodePoint(lowercaseCodePoint);
+    }
+  }
+
+  // Code points of the ASCII lowercase letter i (U+0069) and of the combining dot above (U+0307).
+  private static final int CODE_POINT_LOWERCASE_I = 0x69;
+  private static final int CODE_POINT_COMBINING_DOT = 0x307;
+  /**
+   * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal marker for the two lowercase characters
+   * ("i" followed by combining dot U+0307) produced by lowercasing the Turkish dotted capital
+   * letter I (U+0130). Its value, 0x690307, packs both code points into one int; it exceeds
+   * U+10FFFF, so it is not a valid code point and never equals a real character.
+   */
+  private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT =
+    CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT;
+
+  /**
+   * Returns the lowercase version of the provided code point in a context-unaware way. U+0130
+   * (one-to-many: lowercases to "i" + U+0307) yields the `CODE_POINT_COMBINED_LOWERCASE_I_DOT`
+   * marker, which is not a valid code point; the context-sensitive Greek final sigma (U+03C2) is
+   * mapped to non-final sigma (U+03C3); all other code points use ICU single-code-point mapping.
+   */
+  private static int getLowercaseCodePoint(final int codePoint) {
+    if (codePoint == 0x0130) {
+      // Latin capital letter I with dot above maps to the marker for its 2 lowercase characters.
+      return CODE_POINT_COMBINED_LOWERCASE_I_DOT;
     }
     else if (codePoint == 0x03C2) {
-      // Greek final and non-final capital letter sigma should be mapped the same.
-      sb.appendCodePoint(0x03C3);
+      // Greek final and non-final letter sigma should be mapped the same. This is achieved by
+      // mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital
+      // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in the `else` branch.
+      return 0x03C3;
     }
     else {
       // All other characters should follow context-unaware ICU single-code point case mapping.
-      sb.appendCodePoint(UCharacter.toLowerCase(codePoint));
+      return UCharacter.toLowerCase(codePoint);
     }
   }
 
@@ -444,7 +481,7 @@ public class CollationAwareUTF8String {
    * Converts an entire string to lowercase using ICU rules, code point by 
code point, with
    * special handling for one-to-many case mappings (i.e. characters that map 
to multiple
    * characters in lowercase). Also, this method omits information about 
context-sensitive case
-   * mappings using special handling in the `lowercaseCodePoint` method.
+   * mappings using special handling in the `appendLowercaseCodePoint` method.
    *
    * @param target The target string to convert to lowercase.
    * @return The string converted to lowercase in a context-unaware manner.
@@ -455,10 +492,11 @@ public class CollationAwareUTF8String {
   }
 
   private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
-    String targetString = target.toValidString();
+    Iterator<Integer> targetIter = target.codePointIterator(
+      CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
     StringBuilder sb = new StringBuilder();
-    for (int i = 0; i < targetString.length(); ++i) {
-      lowercaseCodePoint(targetString.codePointAt(i), sb);
+    while (targetIter.hasNext()) {
+      appendLowercaseCodePoint(targetIter.next(), sb);
     }
     return UTF8String.fromString(sb.toString());
   }
@@ -655,38 +693,152 @@ public class CollationAwareUTF8String {
     }
   }
 
-  public static Map<String, String> getCollationAwareDict(UTF8String string,
-      Map<String, String> dict, int collationId) {
-    // TODO(SPARK-48715): All UTF8String -> String conversions should use 
`makeValid`
-    String srcStr = string.toString();
+  /**
+   * Converts the original translation dictionary (`dict`) to one keyed by lowercased code points,
+   * for use with the UTF8_LCASE collation; values are left unchanged. `StringTranslate.buildDict`
+   * is expected to build (and validate) `dict` character by character, so only the first code
+   * point of each key is used; U+0130 is keyed by `CODE_POINT_COMBINED_LOWERCASE_I_DOT`.
+   *
+   * NOTE(review): when several keys share a lowercase form (e.g. "A" and "a", or U+03C2 and
+   * U+03C3), `putIfAbsent` keeps the value of whichever key `dict.entrySet()` yields first. For a
+   * `HashMap` that order is unspecified and need not follow the original `matching` string, so
+   * the replacement that wins may be surprising; confirm whether this is intended.
+   */
+  private static Map<Integer, String> getLowercaseDict(final Map<String, String> dict) {
+    // Replace all the keys in the dict with lowercased code points.
+    Map<Integer, String> lowercaseDict = new HashMap<>();
+    for (Map.Entry<String, String> entry : dict.entrySet()) {
+      int codePoint = entry.getKey().codePointAt(0);
+      lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint), entry.getValue());
+    }
+    return lowercaseDict;
+  }
+
+  /**
+   * Translates the `input` string using the translation map `dict`, for UTF8_LCASE collation.
+   * String translation is performed by iterating over the input string, from left to right, and
+   * repeatedly translating the longest possible substring that matches a key in the dictionary.
+   * For UTF8_LCASE, the method uses the lowercased substring to perform the lookup in the
+   * lowercased version of the translation map.
+   *
+   * @param input the string to be translated
+   * @param dict the translation dictionary (single-character keys; lowercased internally)
+   * @return the translated string
+   */
+  public static UTF8String lowercaseTranslate(final UTF8String input,
+      final Map<String, String> dict) {
+    // Iterator for the input string.
+    Iterator<Integer> inputIter = input.codePointIterator(
+      CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
+    // Lowercased translation dictionary.
+    Map<Integer, String> lowercaseDict = getLowercaseDict(dict);
+    // StringBuilder to store the translated string.
+    StringBuilder sb = new StringBuilder();
 
-    Map<String, String> collationAwareDict = new HashMap<>();
-    for (String key : dict.keySet()) {
-      StringSearch stringSearch =
-        CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);
+    // We use buffered code point iteration to handle one-to-many case mappings. We need to handle
+    // at most two code points at a time (for `CODE_POINT_COMBINED_LOWERCASE_I_DOT`), a buffer of
+    // size 1 enables us to match two codepoints in the input string with a single codepoint in
+    // the lowercase translation dictionary.
+    int codePointBuffer = -1, codePoint;
+    // Loop while either the iterator or the buffer still holds a code point, so that a buffered
+    // trailing code point is looked up in the dictionary instead of being copied verbatim.
+    while (inputIter.hasNext() || codePointBuffer != -1) {
+      if (codePointBuffer != -1) {
+        codePoint = codePointBuffer;
+        codePointBuffer = -1;
+      } else {
+        codePoint = inputIter.next();
+      }
+      // Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring
+      // that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match.
+      if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) &&
+          codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) {
+        int nextCodePoint = inputIter.next();
+        if (nextCodePoint == CODE_POINT_COMBINING_DOT) {
+          codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT;
+        } else {
+          codePointBuffer = nextCodePoint;
+        }
+      }
+      // Translate the code point using the lowercased dictionary.
+      String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint));
+      if (translated == null) {
+        // Append the original code point if no translation is found.
+        sb.appendCodePoint(codePoint);
+      } else if (!"\0".equals(translated)) {
+        // Append the translated code point if the translation is not the null character.
+        sb.append(translated);
+      }
+      // Skip the code point if it maps to the null character.
+    }
 
-      int pos = 0;
-      while ((pos = stringSearch.next()) != StringSearch.DONE) {
-        int codePoint = srcStr.codePointAt(pos);
-        int charCount = Character.charCount(codePoint);
-        String newKey = srcStr.substring(pos, pos + charCount);
+    // Return the translated string.
+    return UTF8String.fromString(sb.toString());
+  }
 
-        boolean exists = false;
-        for (String existingKey : collationAwareDict.keySet()) {
-          if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
-            collationAwareDict.put(newKey, 
collationAwareDict.get(existingKey));
-            exists = true;
-            break;
+  /**
+   * Translates the `input` string using the translation map `dict`, for all 
ICU collations.
+   * String translation is performed by iterating over the input string, from 
left to right, and
+   * repeatedly translating the longest possible substring that matches a key 
in the dictionary.
+   * For ICU collations, the method uses the ICU `StringSearch` class to 
perform the lookup in
+   * the translation map, while respecting the rules of the specified ICU 
collation.
+   *
+   * @param input the string to be translated
+   * @param dict the collation aware translation dictionary
+   * @param collationId the collation ID to use for string translation
+   * @return the translated string
+   */
+  public static UTF8String translate(final UTF8String input,
+      final Map<String, String> dict, final int collationId) {
+    // Replace invalid UTF-8 sequences with the Unicode replacement character 
U+FFFD.
+    String inputString = input.toValidString();
+    // Create a character iterator for the validated input string. This will 
be used for searching
+    // inside the string using ICU `StringSearch` class. We only need to do it 
once before the
+    // main loop of the translate algorithm.
+    CharacterIterator target = new StringCharacterIterator(inputString);
+    Collator collator = CollationFactory.fetchCollation(collationId).collator;
+    StringBuilder sb = new StringBuilder();
+    // Index for the current character in the (validated) input string. This 
is the character we
+    // want to determine if we need to replace or not.
+    int charIndex = 0;
+    while (charIndex < inputString.length()) {
+      // We search the replacement dictionary to find a match. If there are 
more than one matches
+      // (which is possible for collated strings), we want to choose the match 
of largest length.
+      int longestMatchLen = 0;
+      String longestMatch = "";
+      for (String key : dict.keySet()) {
+        StringSearch stringSearch = new StringSearch(key, target, 
(RuleBasedCollator) collator);
+        // Point `stringSearch` to start at the current character.
+        stringSearch.setIndex(charIndex);
+        int matchIndex = stringSearch.next();
+        if (matchIndex == charIndex) {
+          // We have found a match (that is the current position matches with 
one of the characters
+          // in the dictionary). However, there might be other matches of 
larger length, so we need
+          // to continue searching against the characters in the dictionary 
and keep track of the
+          // match of largest length.
+          int matchLen = stringSearch.getMatchLength();
+          if (matchLen > longestMatchLen) {
+            longestMatchLen = matchLen;
+            longestMatch = key;
           }
         }
-
-        if (!exists) {
-          collationAwareDict.put(newKey, dict.get(key));
+      }
+      if (longestMatchLen == 0) {
+        // No match was found, so output the current character.
+        sb.append(inputString.charAt(charIndex));
+        // Move on to the next character in the input string.
+        ++charIndex;
+      } else {
+        // We have found at least one match. Append the match of longest match 
length to the output.
+        if (!"\0".equals(dict.get(longestMatch))) {
+          sb.append(dict.get(longestMatch));
         }
+        // Skip as many characters as the longest match.
+        charIndex += longestMatchLen;
       }
     }
-
-    return collationAwareDict;
+    // Return the translated string.
+    return UTF8String.fromString(sb.toString());
   }
 
   public static UTF8String lowercaseTrim(
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 450a3eea1a3a..f9ccd22f3f5c 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -212,7 +212,7 @@ public final class CollationSupport {
         return useICU ? execBinaryICU(v) : execBinary(v);
       } else if (collation.supportsLowercaseEquality) {
         return execLowercase(v);
-      }  else {
+      } else {
         return execICU(v, collationId);
       }
     }
@@ -224,7 +224,7 @@ public final class CollationSupport {
         return String.format(expr + "%s(%s)", funcName, v);
       } else if (collation.supportsLowercaseEquality) {
         return String.format(expr + "Lowercase(%s)", v);
-      }  else {
+      } else {
         return String.format(expr + "ICU(%s, %d)", v, collationId);
       }
     }
@@ -261,7 +261,7 @@ public final class CollationSupport {
         return String.format(expr + "%s(%s)", funcName, v);
       } else if (collation.supportsLowercaseEquality) {
         return String.format(expr + "Lowercase(%s)", v);
-      }  else {
+      } else {
         return String.format(expr + "ICU(%s, %d)", v, collationId);
       }
     }
@@ -522,26 +522,11 @@ public final class CollationSupport {
       return source.translate(dict);
     }
     public static UTF8String execLowercase(final UTF8String source, 
Map<String, String> dict) {
-      String srcStr = source.toString();
-      StringBuilder sb = new StringBuilder();
-      int charCount = 0;
-      for (int k = 0; k < srcStr.length(); k += charCount) {
-        int codePoint = srcStr.codePointAt(k);
-        charCount = Character.charCount(codePoint);
-        String subStr = srcStr.substring(k, k + charCount);
-        String translated = dict.get(subStr.toLowerCase());
-        if (null == translated) {
-          sb.append(subStr);
-        } else if (!"\0".equals(translated)) {
-          sb.append(translated);
-        }
-      }
-      return UTF8String.fromString(sb.toString());
+      return CollationAwareUTF8String.lowercaseTranslate(source, dict);
     }
     public static UTF8String execICU(final UTF8String source, Map<String, 
String> dict,
         final int collationId) {
-      return source.translate(CollationAwareUTF8String.getCollationAwareDict(
-        source, dict, collationId));
+      return CollationAwareUTF8String.translate(source, dict, collationId);
     }
   }
 
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 9438484344d6..ce0cef3fef30 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -22,6 +22,9 @@ import org.apache.spark.sql.catalyst.util.CollationFactory;
 import org.apache.spark.sql.catalyst.util.CollationSupport;
 import org.junit.jupiter.api.Test;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import static org.junit.jupiter.api.Assertions.*;
 
 // checkstyle.off: AvoidEscapedUnicodeCharacters
@@ -1378,19 +1381,186 @@ public class CollationSupportSuite {
     assertStringTrimRight("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa");
   }
 
-  // TODO: Test more collation-aware string expressions.
-
-  /**
-   * Collation-aware regexp expressions.
-   */
-
-  // TODO: Test more collation-aware regexp expressions.
+  private void assertStringTranslate(
+      String inputString,
+      String matchingString,
+      String replaceString,
+      String collationName,
+      String expectedResultString) throws SparkException {
+    int collationId = CollationFactory.collationNameToId(collationName);
+    Map<String, String> dict = buildDict(matchingString, replaceString);
+    UTF8String source = UTF8String.fromString(inputString);
+    UTF8String result = CollationSupport.StringTranslate.exec(source, dict, 
collationId);
+    assertEquals(expectedResultString, result.toString());
+  }
 
-  /**
-   * Other collation-aware expressions.
-   */
+  @Test
+  public void testStringTranslate() throws SparkException {
+    // Basic tests - UTF8_BINARY.
+    assertStringTranslate("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae");
+    assertStringTranslate("Translate", "Rn", "1234", "UTF8_BINARY", 
"Tra2slate");
+    assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_BINARY", 
"Tra2s3a4e");
+    assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", 
"TRaxsXaxe");
+    assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", 
"TxaxsXaxeX");
+    assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", 
"TXaxsXaxex");
+    assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", 
"test大千世AX大千世A");
+    assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", 
"大千世界test大千世界");
+    assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", 
"Oeso大千世界大千世界");
+    assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", 
"大千世界大千世界oesO");
+    assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", 
"世世世界世世世界tesT");
+    assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", 
"Tr4234e");
+    assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", 
"Tra2s3a4e");
+    assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f");
+    // Basic tests - UTF8_LCASE.
+    assertStringTranslate("Translate", "Rnlt", "12", "UTF8_LCASE", "1a2sae");
+    assertStringTranslate("Translate", "Rn", "1234", "UTF8_LCASE", 
"T1a2slate");
+    assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", 
"41a2s3a4e");
+    assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", 
"xXaxsXaxe");
+    assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", 
"xxaxsXaxex");
+    assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", 
"xXaxsXaxeX");
+    assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", 
"test大千世AB大千世A");
+    assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", 
"大千世界abca大千世界");
+    assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", 
"oeso大千世界大千世界");
+    assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", 
"大千世界大千世界OesO");
+    assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", 
"世世世界世世世界tesT");
+    assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", 
"14234e");
+    assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", 
"41a2s3a4e");
+    assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f");
+    // Basic tests - UNICODE.
+    assertStringTranslate("Translate", "Rnlt", "12", "UNICODE", "Tra2sae");
+    assertStringTranslate("Translate", "Rn", "1234", "UNICODE", "Tra2slate");
+    assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e");
+    assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe");
+    assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE", 
"TxaxsXaxeX");
+    assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE", 
"TXaxsXaxex");
+    assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE", 
"test大千世AX大千世A");
+    assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE", 
"大千世界test大千世界");
+    assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE", 
"Oeso大千世界大千世界");
+    assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", 
"大千世界大千世界oesO");
+    assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE", 
"世世世界世世世界tesT");
+    assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", 
"Tr4234e");
+    assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", 
"Tra2s3a4e");
+    assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f");
+    // Basic tests - UNICODE_CI.
+    assertStringTranslate("Translate", "Rnlt", "12", "UNICODE_CI", "1a2sae");
+    assertStringTranslate("Translate", "Rn", "1234", "UNICODE_CI", 
"T1a2slate");
+    assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE_CI", 
"41a2s3a4e");
+    assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE_CI", 
"xXaxsXaxe");
+    assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", 
"xxaxsXaxex");
+    assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", 
"xXaxsXaxeX");
+    assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", 
"test大千世AB大千世A");
+    assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", 
"大千世界abca大千世界");
+    assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", 
"oeso大千世界大千世界");
+    assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", 
"大千世界大千世界OesO");
+    assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", 
"世世世界世世世界tesT");
+    assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", 
"14234e");
+    assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE_CI", 
"41a2s3a4e");
+    assertStringTranslate("abcdef", "abcde", "123", "UNICODE_CI", "123f");
+
+    // One-to-many case mapping - UTF8_BINARY.
+    assertStringTranslate("İ", "i\u0307", "xy", "UTF8_BINARY", "İ");
+    assertStringTranslate("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307");
+    assertStringTranslate("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A");
+    assertStringTranslate("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A");
+    assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123");
+    assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307");
+    assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23");
+    assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", 
"12bc3");
+    assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", 
"a2bcå");
+    assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", 
"3\u030Aβφδ1\u0307");
+    // One-to-many case mapping - UTF8_LCASE.
+    assertStringTranslate("İ", "i\u0307", "xy", "UTF8_LCASE", "İ");
+    assertStringTranslate("i\u0307", "İ", "xy", "UTF8_LCASE", "x");
+    assertStringTranslate("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A");
+    assertStringTranslate("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A");
+    assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11");
+    assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_LCASE", "11");
+    assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23");
+    assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", 
"12bc3");
+    assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", 
"12bc3");
+    assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", 
"3\u030Aβφδ2");
+    // One-to-many case mapping - UNICODE.
+    assertStringTranslate("İ", "i\u0307", "xy", "UNICODE", "İ");
+    assertStringTranslate("i\u0307", "İ", "xy", "UNICODE", "i\u0307");
+    assertStringTranslate("i\u030A", "İ", "x", "UNICODE", "i\u030A");
+    assertStringTranslate("i\u030A", "İi", "xy", "UNICODE", "i\u030A");
+    assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE", 
"1i\u0307");
+    assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307");
+    assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE", 
"İi\u0307");
+    assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3");
+    assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", 
"a\u030Abcå");
+    assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", 
"4βφδ2");
+    // One-to-many case mapping - UNICODE_CI.
+    assertStringTranslate("İ", "i\u0307", "xy", "UNICODE_CI", "İ");
+    assertStringTranslate("i\u0307", "İ", "xy", "UNICODE_CI", "x");
+    assertStringTranslate("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A");
+    assertStringTranslate("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A");
+    assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11");
+    assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE_CI", "11");
+    assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE_CI", 
"İi\u0307");
+    assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", 
"3bc3");
+    assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", 
"3bc3");
+    assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", 
"4βφδ2");
+
+    // Greek sigmas - UTF8_BINARY.
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_BINARY", 
"σΥσΤΗΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_BINARY", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_BINARY", 
"ςΥςΤΗΜΑΤΙΚΟς");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_BINARY", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY", 
"σιστιματικος");
+    assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY", 
"σιστιματικος");
+    assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY", 
"σιστιματικοσ");
+    // Greek sigmas - UTF8_LCASE.
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_LCASE", 
"σισΤιΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE", 
"σισΤιΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE", 
"σισΤιΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_LCASE", 
"ςιςΤιΜΑΤΙΚΟς");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_LCASE", 
"ςιςΤιΜΑΤΙΚΟς");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_LCASE", 
"ςιςΤιΜΑΤΙΚΟς");
+    assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE", 
"σιστιματικοσ");
+    assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE", 
"σιστιματικοσ");
+    assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE", 
"σιστιματικοσ");
+    // Greek sigmas - UNICODE.
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE", 
"σΥσΤΗΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE", 
"ςΥςΤΗΜΑΤΙΚΟς");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE", 
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+    assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE", 
"σιστιματικος");
+    assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE", 
"σιστιματικος");
+    assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE", 
"σιστιματικοσ");
+    // Greek sigmas - UNICODE_CI.
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE_CI", 
"σισΤιΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE_CI", 
"σισΤιΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE_CI", 
"σισΤιΜΑΤΙΚΟσ");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE_CI", 
"ςιςΤιΜΑΤΙΚΟς");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE_CI", 
"ςιςΤιΜΑΤΙΚΟς");
+    assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE_CI", 
"ςιςΤιΜΑΤΙΚΟς");
+    assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE_CI", 
"σιστιματικοσ");
+    assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE_CI", 
"σιστιματικοσ");
+    assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE_CI", 
"σιστιματικοσ");
+  }
 
-  // TODO: Test other collation-aware expressions.
+  private Map<String, String> buildDict(String matching, String replace) {
+    Map<String, String> dict = new HashMap<>();
+    int i = 0, j = 0;
+    while (i < matching.length()) {
+      String rep = "\u0000";
+      if (j < replace.length()) {
+        int repCharCount = Character.charCount(replace.codePointAt(j));
+        rep = replace.substring(j, j + repCharCount);
+        j += repCharCount;
+      }
+      int matchCharCount = Character.charCount(matching.codePointAt(i));
+      String matchStr = matching.substring(i, i + matchCharCount);
+      dict.putIfAbsent(matchStr, rep);
+      i += matchCharCount;
+    }
+    return dict;
+  }
 
 }
 // checkstyle.on: AvoidEscapedUnicodeCharacters
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b188b9c2630f..1302ca80e51a 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1050,15 +1050,35 @@ case class Overlay(input: Expression, replace: 
Expression, pos: Expression, len:
 
 object StringTranslate {
 
-  def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int)
+  /**
+   * Build a translation dictionary from UTF8Strings. The inputs are first converted to Java
+   * Strings: UTF8_BINARY keeps the legacy `UTF8String.toString` conversion (no behavior change),
+   * while every other collation uses `UTF8String.toValidString`.
+   */
+  def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Integer)
     : JMap[String, String] = {
-    val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
-      matchingString.toString().toLowerCase()
+    val isUtf8Binary = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
+    val matching: String = if (isUtf8Binary) {
+      matchingString.toString
+    } else {
+      matchingString.toValidString
+    }
+    val replace: String = if (isUtf8Binary) {
+      replaceString.toString
     } else {
-      matchingString.toString()
+      replaceString.toValidString
     }
+    buildDict(matching, replace)
+  }
 
-    val replace = replaceString.toString()
+  /**
+   * Build a translation dictionary from Strings. This method assumes that the input strings are
+   * already valid. The result dictionary maps each character in `matching` to the corresponding
+   * character in `replace`. If `replace` is shorter than `matching`, the extra characters in
+   * `matching` are mapped to the NUL character (U+0000), which causes them to be deleted during
+   * translation. If `replace` is longer than `matching`, its extra characters are ignored.
+   */
+  private def buildDict(matching: String, replace: String): JMap[String, 
String] = {
     val dict = new HashMap[String, String]()
     var i = 0
     var j = 0
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 78aee5b80e54..5f722b2f01fb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -252,55 +252,16 @@ class CollationStringExpressionsSuite
     }
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
-  test("TRANSLATE check result on explicitly collated string") {
+
+  test("Support StringTranslate string expression with collation") {
     // Supported collations
     case class TranslateTestCase[R](input: String, matchExpression: String,
-        replaceExpression: String, collation: String, result: R)
+      replaceExpression: String, collation: String, result: R)
     val testCases = Seq(
+      TranslateTestCase("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"),
       TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", 
"41a2s3a4e"),
-      TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", 
"41a2s3a4e"),
-      TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", 
"xXaxsXaxe"),
-      TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", 
"xxaxsXaxex"),
-      TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", 
"xXaxsXaxeX"),
-      // scalastyle:off
-      TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", 
"test大千世AB大千世A"),
-      TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", 
"大千世界abca大千世界"),
-      TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", 
"oeso大千世界大千世界"),
-      TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", 
"大千世界大千世界OesO"),
-      TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", 
"世世世界世世世界tesT"),
-      // scalastyle:on
-      TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"),
-      TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"),
-      TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", 
"TxaxsXaxeX"),
-      TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", 
"TXaxsXaxex"),
-      // scalastyle:off
-      TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", 
"test大千世AX大千世A"),
-      TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"),
-      TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"),
-      // scalastyle:on
-      TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", 
"41a2s3a4e"),
-      TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", 
"xXaxsXaxe"),
-      TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", 
"xxaxsXaxex"),
-      TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", 
"xXaxsXaxeX"),
-      // scalastyle:off
-      TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", 
"test大千世AB大千世A"),
-      TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", 
"大千世界abca大千世界"),
-      TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", 
"oeso大千世界大千世界"),
-      TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", 
"大千世界大千世界OesO"),
-      TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", 
"世世世界世世世界tesT"),
-      // scalastyle:on
-      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", 
"14234e"),
-      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", 
"14234e"),
-      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", 
"Tr4234e"),
-      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", 
"Tr4234e"),
-      TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_LCASE", 
"41a2s3a4e"),
-      TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE", 
"Tra2s3a4e"),
-      TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", 
"41a2s3a4e"),
-      TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", 
"Tra2s3a4e"),
-      TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"),
-      TranslateTestCase("abcdef", "abcde", "123", "UTF8_LCASE", "123f"),
-      TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"),
-      TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f")
+      TranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", 
"Traslate"),
+      TranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate")
     )
 
     testCases.foreach(t => {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to