This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8e4bbdff80a1 [SPARK-48440][SQL] Fix StringTranslate behaviour for
non-UTF8_BINARY collations
8e4bbdff80a1 is described below
commit 8e4bbdff80a1c069ccce71060751987e9e6c0b6b
Author: Uros Bojanic <[email protected]>
AuthorDate: Fri Jul 12 22:30:18 2024 +0800
[SPARK-48440][SQL] Fix StringTranslate behaviour for non-UTF8_BINARY
collations
### What changes were proposed in this pull request?
String searching in UTF8_LCASE now works at the character level rather than at
the byte level. For example, `translate("İ", "i")` now returns `"İ"`, because there
exists no **single character** in `"İ"` whose lowercased version equals `"i"`.
Note, however, that there _is_ a byte subsequence of `"İ"` whose lowercased
UTF-8 byte sequence equals `"i"` (so the new behaviour differs from the old
behaviour).
Also, translation for ICU collations now works by repeatedly translating the
longest possible substring that matches a key in the dictionary (under the
specified collation), starting from the left side of the input string, until
the entire string is translated.
### Why are the changes needed?
Fix functions that give unusable results due to one-to-many case mapping
when performing string search under UTF8_LCASE (see the example above).
### Does this PR introduce _any_ user-facing change?
Yes, the behaviour of the `translate` expression changes for edge cases
involving one-to-many case mapping.
### How was this patch tested?
New unit tests in `CollationSupportSuite`; `CollationStringExpressionsSuite` was also updated.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46761 from uros-db/alter-translate.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/util/CollationAwareUTF8String.java | 218 +++++++++++++++++----
.../spark/sql/catalyst/util/CollationSupport.java | 25 +--
.../spark/unsafe/types/CollationSupportSuite.java | 192 ++++++++++++++++--
.../catalyst/expressions/stringExpressions.scala | 30 ++-
.../sql/CollationStringExpressionsSuite.scala | 51 +----
5 files changed, 402 insertions(+), 114 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 23adc772b7f3..af152c87f88c 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.util;
import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.text.BreakIterator;
+import com.ibm.icu.text.Collator;
+import com.ibm.icu.text.RuleBasedCollator;
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;
@@ -26,8 +28,12 @@ import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
import static org.apache.spark.unsafe.Platform.copyMemory;
+import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType;
+import java.text.CharacterIterator;
+import java.text.StringCharacterIterator;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.Map;
/**
@@ -424,19 +430,50 @@ public class CollationAwareUTF8String {
* @param codePoint The code point to convert to lowercase.
* @param sb The StringBuilder to append the lowercase character to.
*/
- private static void lowercaseCodePoint(final int codePoint, final
StringBuilder sb) {
- if (codePoint == 0x0130) {
+ private static void appendLowercaseCodePoint(final int codePoint, final
StringBuilder sb) {
+ int lowercaseCodePoint = getLowercaseCodePoint(codePoint);
+ if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) {
// Latin capital letter I with dot above is mapped to 2 lowercase
characters.
sb.appendCodePoint(0x0069);
sb.appendCodePoint(0x0307);
+ } else {
+ // All other characters should follow context-unaware ICU single-code
point case mapping.
+ sb.appendCodePoint(lowercaseCodePoint);
+ }
+ }
+
+ /**
+ * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal representation of
the combined lowercase
+ * code point for ASCII lowercase letter i with an additional combining dot
character (U+0307).
+ * This integer value is not a valid code point itself, but rather an
artificial code point
+ * marker used to represent the two lowercase characters that are the result
of converting the
+ * uppercase Turkish dotted letter I (U+0130, a single precomposed code point)
to lowercase.
+ */
+ private static final int CODE_POINT_LOWERCASE_I = 0x69;
+ private static final int CODE_POINT_COMBINING_DOT = 0x307;
+ private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT =
+ CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT;
+
+ /**
+ * Returns the lowercase version of the provided code point, with special
handling for
+ * one-to-many case mappings (i.e. characters that map to multiple
characters in lowercase) and
+ * context-sensitive case mappings (i.e. characters that map to different
characters based on
+ * the position in the string relative to other characters in lowercase).
+ */
+ private static int getLowercaseCodePoint(final int codePoint) {
+ if (codePoint == 0x0130) {
+ // Latin capital letter I with dot above is mapped to 2 lowercase
characters.
+ return CODE_POINT_COMBINED_LOWERCASE_I_DOT;
}
else if (codePoint == 0x03C2) {
- // Greek final and non-final capital letter sigma should be mapped the
same.
- sb.appendCodePoint(0x03C3);
+ // Greek final and non-final letter sigma should be mapped the same.
This is achieved by
+ // mapping Greek small final sigma (U+03C2) to Greek small non-final
sigma (U+03C3). Capital
+ // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in
the `else` branch.
+ return 0x03C3;
}
else {
// All other characters should follow context-unaware ICU single-code
point case mapping.
- sb.appendCodePoint(UCharacter.toLowerCase(codePoint));
+ return UCharacter.toLowerCase(codePoint);
}
}
@@ -444,7 +481,7 @@ public class CollationAwareUTF8String {
* Converts an entire string to lowercase using ICU rules, code point by
code point, with
* special handling for one-to-many case mappings (i.e. characters that map
to multiple
* characters in lowercase). Also, this method omits information about
context-sensitive case
- * mappings using special handling in the `lowercaseCodePoint` method.
+ * mappings using special handling in the `appendLowercaseCodePoint` method.
*
* @param target The target string to convert to lowercase.
* @return The string converted to lowercase in a context-unaware manner.
@@ -455,10 +492,11 @@ public class CollationAwareUTF8String {
}
private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
- String targetString = target.toValidString();
+ Iterator<Integer> targetIter = target.codePointIterator(
+ CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
StringBuilder sb = new StringBuilder();
- for (int i = 0; i < targetString.length(); ++i) {
- lowercaseCodePoint(targetString.codePointAt(i), sb);
+ while (targetIter.hasNext()) {
+ appendLowercaseCodePoint(targetIter.next(), sb);
}
return UTF8String.fromString(sb.toString());
}
@@ -655,38 +693,152 @@ public class CollationAwareUTF8String {
}
}
- public static Map<String, String> getCollationAwareDict(UTF8String string,
- Map<String, String> dict, int collationId) {
- // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
- String srcStr = string.toString();
+ /**
+ * Converts the original translation dictionary (`dict`) to a dictionary
with lowercased keys.
+ * This method is used to create a dictionary that can be used for the
UTF8_LCASE collation.
+ * Note that `StringTranslate.buildDict` will ensure that all strings are
validated properly.
+ *
+ * The method returns a map with lowercased code points as keys, while the
values remain
+ * unchanged. Note that `dict` is constructed on a character by character
basis, and the
+ * original keys are stored as strings. Keys in the resulting lowercase
dictionary are stored
+ * as integers, which correspond only to single characters from the original
`dict`. Also,
+ * there is special handling for the Turkish dotted uppercase letter I
(U+0130).
+ */
+ private static Map<Integer, String> getLowercaseDict(final Map<String,
String> dict) {
+ // Replace all the keys in the dict with lowercased code points.
+ Map<Integer, String> lowercaseDict = new HashMap<>();
+ for (Map.Entry<String, String> entry : dict.entrySet()) {
+ int codePoint = entry.getKey().codePointAt(0);
+ lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint),
entry.getValue());
+ }
+ return lowercaseDict;
+ }
+
+ /**
+ * Translates the `input` string using the translation map `dict`, for
UTF8_LCASE collation.
+ * String translation is performed by iterating over the input string, from
left to right, and
+ * repeatedly translating the longest possible substring that matches a key
in the dictionary.
+ * For UTF8_LCASE, the method uses the lowercased substring to perform the
lookup in the
+ * lowercased version of the translation map.
+ *
+ * @param input the string to be translated
+ * @param dict the original translation dictionary (its keys are lowercased internally)
+ * @return the translated string
+ */
+ public static UTF8String lowercaseTranslate(final UTF8String input,
+ final Map<String, String> dict) {
+ // Iterator for the input string.
+ Iterator<Integer> inputIter = input.codePointIterator(
+ CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
+ // Lowercased translation dictionary.
+ Map<Integer, String> lowercaseDict = getLowercaseDict(dict);
+ // StringBuilder to store the translated string.
+ StringBuilder sb = new StringBuilder();
- Map<String, String> collationAwareDict = new HashMap<>();
- for (String key : dict.keySet()) {
- StringSearch stringSearch =
- CollationFactory.getStringSearch(string, UTF8String.fromString(key),
collationId);
+ // We use buffered code point iteration to handle one-to-many case
mappings. We need to handle
+ // at most two code points at a time (for
`CODE_POINT_COMBINED_LOWERCASE_I_DOT`), so a buffer of
+ // size 1 enables us to match two codepoints in the input string with a
single codepoint in
+ // the lowercase translation dictionary.
+ int codePointBuffer = -1, codePoint;
+ while (inputIter.hasNext()) {
+ if (codePointBuffer != -1) {
+ codePoint = codePointBuffer;
+ codePointBuffer = -1;
+ } else {
+ codePoint = inputIter.next();
+ }
+ // Special handling for letter i (U+0069) followed by a combining dot
(U+0307). By ensuring
+ // that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a
max-length match.
+ if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) &&
+ codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) {
+ int nextCodePoint = inputIter.next();
+ if (nextCodePoint == CODE_POINT_COMBINING_DOT) {
+ codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT;
+ } else {
+ codePointBuffer = nextCodePoint;
+ }
+ }
+ // Translate the code point using the lowercased dictionary.
+ String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint));
+ if (translated == null) {
+ // Append the original code point if no translation is found.
+ sb.appendCodePoint(codePoint);
+ } else if (!"\0".equals(translated)) {
+ // Append the translated code point if the translation is not the null
character.
+ sb.append(translated);
+ }
+ // Skip the code point if it maps to the null character.
+ }
+ // Append the last code point if it was buffered.
+ if (codePointBuffer != -1) sb.appendCodePoint(codePointBuffer);
- int pos = 0;
- while ((pos = stringSearch.next()) != StringSearch.DONE) {
- int codePoint = srcStr.codePointAt(pos);
- int charCount = Character.charCount(codePoint);
- String newKey = srcStr.substring(pos, pos + charCount);
+ // Return the translated string.
+ return UTF8String.fromString(sb.toString());
+ }
- boolean exists = false;
- for (String existingKey : collationAwareDict.keySet()) {
- if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
- collationAwareDict.put(newKey,
collationAwareDict.get(existingKey));
- exists = true;
- break;
+ /**
+ * Translates the `input` string using the translation map `dict`, for all
ICU collations.
+ * String translation is performed by iterating over the input string, from
left to right, and
+ * repeatedly translating the longest possible substring that matches a key
in the dictionary.
+ * For ICU collations, the method uses the ICU `StringSearch` class to
perform the lookup in
+ * the translation map, while respecting the rules of the specified ICU
collation.
+ *
+ * @param input the string to be translated
+ * @param dict the translation dictionary (keys are matched under the given collation)
+ * @param collationId the collation ID to use for string translation
+ * @return the translated string
+ */
+ public static UTF8String translate(final UTF8String input,
+ final Map<String, String> dict, final int collationId) {
+ // Replace invalid UTF-8 sequences with the Unicode replacement character
U+FFFD.
+ String inputString = input.toValidString();
+ // Create a character iterator for the validated input string. This will
be used for searching
+ // inside the string using ICU `StringSearch` class. We only need to do it
once before the
+ // main loop of the translate algorithm.
+ CharacterIterator target = new StringCharacterIterator(inputString);
+ Collator collator = CollationFactory.fetchCollation(collationId).collator;
+ StringBuilder sb = new StringBuilder();
+ // Index for the current character in the (validated) input string. This
is the character we
+ // want to determine if we need to replace or not.
+ int charIndex = 0;
+ while (charIndex < inputString.length()) {
+ // We search the replacement dictionary to find a match. If there is
more than one match
+ // (which is possible for collated strings), we want to choose the match
of largest length.
+ int longestMatchLen = 0;
+ String longestMatch = "";
+ for (String key : dict.keySet()) {
+ StringSearch stringSearch = new StringSearch(key, target,
(RuleBasedCollator) collator);
+ // Point `stringSearch` to start at the current character.
+ stringSearch.setIndex(charIndex);
+ int matchIndex = stringSearch.next();
+ if (matchIndex == charIndex) {
+ // We have found a match (that is the current position matches with
one of the characters
+ // in the dictionary). However, there might be other matches of
larger length, so we need
+ // to continue searching against the characters in the dictionary
and keep track of the
+ // match of largest length.
+ int matchLen = stringSearch.getMatchLength();
+ if (matchLen > longestMatchLen) {
+ longestMatchLen = matchLen;
+ longestMatch = key;
}
}
-
- if (!exists) {
- collationAwareDict.put(newKey, dict.get(key));
+ }
+ if (longestMatchLen == 0) {
+ // No match was found, so output the current character.
+ sb.append(inputString.charAt(charIndex));
+ // Move on to the next character in the input string.
+ ++charIndex;
+ } else {
+ // We have found at least one match. Append the match of longest match
length to the output.
+ if (!"\0".equals(dict.get(longestMatch))) {
+ sb.append(dict.get(longestMatch));
}
+ // Skip as many characters as the longest match.
+ charIndex += longestMatchLen;
}
}
-
- return collationAwareDict;
+ // Return the translated string.
+ return UTF8String.fromString(sb.toString());
}
public static UTF8String lowercaseTrim(
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 450a3eea1a3a..f9ccd22f3f5c 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -212,7 +212,7 @@ public final class CollationSupport {
return useICU ? execBinaryICU(v) : execBinary(v);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(v);
- } else {
+ } else {
return execICU(v, collationId);
}
}
@@ -224,7 +224,7 @@ public final class CollationSupport {
return String.format(expr + "%s(%s)", funcName, v);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s)", v);
- } else {
+ } else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
@@ -261,7 +261,7 @@ public final class CollationSupport {
return String.format(expr + "%s(%s)", funcName, v);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s)", v);
- } else {
+ } else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
@@ -522,26 +522,11 @@ public final class CollationSupport {
return source.translate(dict);
}
public static UTF8String execLowercase(final UTF8String source,
Map<String, String> dict) {
- String srcStr = source.toString();
- StringBuilder sb = new StringBuilder();
- int charCount = 0;
- for (int k = 0; k < srcStr.length(); k += charCount) {
- int codePoint = srcStr.codePointAt(k);
- charCount = Character.charCount(codePoint);
- String subStr = srcStr.substring(k, k + charCount);
- String translated = dict.get(subStr.toLowerCase());
- if (null == translated) {
- sb.append(subStr);
- } else if (!"\0".equals(translated)) {
- sb.append(translated);
- }
- }
- return UTF8String.fromString(sb.toString());
+ return CollationAwareUTF8String.lowercaseTranslate(source, dict);
}
public static UTF8String execICU(final UTF8String source, Map<String,
String> dict,
final int collationId) {
- return source.translate(CollationAwareUTF8String.getCollationAwareDict(
- source, dict, collationId));
+ return CollationAwareUTF8String.translate(source, dict, collationId);
}
}
diff --git
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 9438484344d6..ce0cef3fef30 100644
---
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -22,6 +22,9 @@ import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.sql.catalyst.util.CollationSupport;
import org.junit.jupiter.api.Test;
+import java.util.HashMap;
+import java.util.Map;
+
import static org.junit.jupiter.api.Assertions.*;
// checkstyle.off: AvoidEscapedUnicodeCharacters
@@ -1378,19 +1381,186 @@ public class CollationSupportSuite {
assertStringTrimRight("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa");
}
- // TODO: Test more collation-aware string expressions.
-
- /**
- * Collation-aware regexp expressions.
- */
-
- // TODO: Test more collation-aware regexp expressions.
+ private void assertStringTranslate(
+ String inputString,
+ String matchingString,
+ String replaceString,
+ String collationName,
+ String expectedResultString) throws SparkException {
+ int collationId = CollationFactory.collationNameToId(collationName);
+ Map<String, String> dict = buildDict(matchingString, replaceString);
+ UTF8String source = UTF8String.fromString(inputString);
+ UTF8String result = CollationSupport.StringTranslate.exec(source, dict,
collationId);
+ assertEquals(expectedResultString, result.toString());
+ }
- /**
- * Other collation-aware expressions.
- */
+ @Test
+ public void testStringTranslate() throws SparkException {
+ // Basic tests - UTF8_BINARY.
+ assertStringTranslate("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae");
+ assertStringTranslate("Translate", "Rn", "1234", "UTF8_BINARY",
"Tra2slate");
+ assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_BINARY",
"Tra2s3a4e");
+ assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_BINARY",
"TRaxsXaxe");
+ assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY",
"TxaxsXaxeX");
+ assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY",
"TXaxsXaxex");
+ assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY",
"test大千世AX大千世A");
+ assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY",
"大千世界test大千世界");
+ assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY",
"Oeso大千世界大千世界");
+ assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY",
"大千世界大千世界oesO");
+ assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY",
"世世世界世世世界tesT");
+ assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY",
"Tr4234e");
+ assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY",
"Tra2s3a4e");
+ assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f");
+ // Basic tests - UTF8_LCASE.
+ assertStringTranslate("Translate", "Rnlt", "12", "UTF8_LCASE", "1a2sae");
+ assertStringTranslate("Translate", "Rn", "1234", "UTF8_LCASE",
"T1a2slate");
+ assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE",
"41a2s3a4e");
+ assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE",
"xXaxsXaxe");
+ assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE",
"xxaxsXaxex");
+ assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE",
"xXaxsXaxeX");
+ assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE",
"test大千世AB大千世A");
+ assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE",
"大千世界abca大千世界");
+ assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE",
"oeso大千世界大千世界");
+ assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE",
"大千世界大千世界OesO");
+ assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE",
"世世世界世世世界tesT");
+ assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE",
"14234e");
+ assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE",
"41a2s3a4e");
+ assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f");
+ // Basic tests - UNICODE.
+ assertStringTranslate("Translate", "Rnlt", "12", "UNICODE", "Tra2sae");
+ assertStringTranslate("Translate", "Rn", "1234", "UNICODE", "Tra2slate");
+ assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e");
+ assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe");
+ assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE",
"TxaxsXaxeX");
+ assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE",
"TXaxsXaxex");
+ assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE",
"test大千世AX大千世A");
+ assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE",
"大千世界test大千世界");
+ assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE",
"Oeso大千世界大千世界");
+ assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE",
"大千世界大千世界oesO");
+ assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE",
"世世世界世世世界tesT");
+ assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE",
"Tr4234e");
+ assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE",
"Tra2s3a4e");
+ assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f");
+ // Basic tests - UNICODE_CI.
+ assertStringTranslate("Translate", "Rnlt", "12", "UNICODE_CI", "1a2sae");
+ assertStringTranslate("Translate", "Rn", "1234", "UNICODE_CI",
"T1a2slate");
+ assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE_CI",
"41a2s3a4e");
+ assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE_CI",
"xXaxsXaxe");
+ assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI",
"xxaxsXaxex");
+ assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI",
"xXaxsXaxeX");
+ assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE_CI",
"test大千世AB大千世A");
+ assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI",
"大千世界abca大千世界");
+ assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE_CI",
"oeso大千世界大千世界");
+ assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI",
"大千世界大千世界OesO");
+ assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI",
"世世世界世世世界tesT");
+ assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI",
"14234e");
+ assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE_CI",
"41a2s3a4e");
+ assertStringTranslate("abcdef", "abcde", "123", "UNICODE_CI", "123f");
+
+ // One-to-many case mapping - UTF8_BINARY.
+ assertStringTranslate("İ", "i\u0307", "xy", "UTF8_BINARY", "İ");
+ assertStringTranslate("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307");
+ assertStringTranslate("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A");
+ assertStringTranslate("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A");
+ assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123");
+ assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307");
+ assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23");
+ assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY",
"12bc3");
+ assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY",
"a2bcå");
+ assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY",
"3\u030Aβφδ1\u0307");
+ // One-to-many case mapping - UTF8_LCASE.
+ assertStringTranslate("İ", "i\u0307", "xy", "UTF8_LCASE", "İ");
+ assertStringTranslate("i\u0307", "İ", "xy", "UTF8_LCASE", "x");
+ assertStringTranslate("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A");
+ assertStringTranslate("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A");
+ assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11");
+ assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_LCASE", "11");
+ assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23");
+ assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE",
"12bc3");
+ assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE",
"12bc3");
+ assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE",
"3\u030Aβφδ2");
+ // One-to-many case mapping - UNICODE.
+ assertStringTranslate("İ", "i\u0307", "xy", "UNICODE", "İ");
+ assertStringTranslate("i\u0307", "İ", "xy", "UNICODE", "i\u0307");
+ assertStringTranslate("i\u030A", "İ", "x", "UNICODE", "i\u030A");
+ assertStringTranslate("i\u030A", "İi", "xy", "UNICODE", "i\u030A");
+ assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE",
"1i\u0307");
+ assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307");
+ assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE",
"İi\u0307");
+ assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3");
+ assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE",
"a\u030Abcå");
+ assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE",
"4βφδ2");
+ // One-to-many case mapping - UNICODE_CI.
+ assertStringTranslate("İ", "i\u0307", "xy", "UNICODE_CI", "İ");
+ assertStringTranslate("i\u0307", "İ", "xy", "UNICODE_CI", "x");
+ assertStringTranslate("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A");
+ assertStringTranslate("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A");
+ assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11");
+ assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE_CI", "11");
+ assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE_CI",
"İi\u0307");
+ assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI",
"3bc3");
+ assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI",
"3bc3");
+ assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI",
"4βφδ2");
+
+ // Greek sigmas - UTF8_BINARY.
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_BINARY",
"σΥσΤΗΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_BINARY",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_BINARY",
"ςΥςΤΗΜΑΤΙΚΟς");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_BINARY",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY",
"σιστιματικος");
+ assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY",
"σιστιματικος");
+ assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY",
"σιστιματικοσ");
+ // Greek sigmas - UTF8_LCASE.
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_LCASE",
"σισΤιΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE",
"σισΤιΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE",
"σισΤιΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_LCASE",
"ςιςΤιΜΑΤΙΚΟς");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_LCASE",
"ςιςΤιΜΑΤΙΚΟς");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_LCASE",
"ςιςΤιΜΑΤΙΚΟς");
+ assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE",
"σιστιματικοσ");
+ assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE",
"σιστιματικοσ");
+ assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE",
"σιστιματικοσ");
+ // Greek sigmas - UNICODE.
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE",
"σΥσΤΗΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE",
"ςΥςΤΗΜΑΤΙΚΟς");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE",
"ΣΥΣΤΗΜΑΤΙΚΟΣ");
+ assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE",
"σιστιματικος");
+ assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE",
"σιστιματικος");
+ assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE",
"σιστιματικοσ");
+ // Greek sigmas - UNICODE_CI.
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE_CI",
"σισΤιΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE_CI",
"σισΤιΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE_CI",
"σισΤιΜΑΤΙΚΟσ");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE_CI",
"ςιςΤιΜΑΤΙΚΟς");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE_CI",
"ςιςΤιΜΑΤΙΚΟς");
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE_CI",
"ςιςΤιΜΑΤΙΚΟς");
+ assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE_CI",
"σιστιματικοσ");
+ assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE_CI",
"σιστιματικοσ");
+ assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE_CI",
"σιστιματικοσ");
+ }
- // TODO: Test other collation-aware expressions.
+ private Map<String, String> buildDict(String matching, String replace) {
+ Map<String, String> dict = new HashMap<>();
+ int i = 0, j = 0;
+ while (i < matching.length()) {
+ String rep = "\u0000";
+ if (j < replace.length()) {
+ int repCharCount = Character.charCount(replace.codePointAt(j));
+ rep = replace.substring(j, j + repCharCount);
+ j += repCharCount;
+ }
+ int matchCharCount = Character.charCount(matching.codePointAt(i));
+ String matchStr = matching.substring(i, i + matchCharCount);
+ dict.putIfAbsent(matchStr, rep);
+ i += matchCharCount;
+ }
+ return dict;
+ }
}
// checkstyle.on: AvoidEscapedUnicodeCharacters
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b188b9c2630f..1302ca80e51a 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1050,15 +1050,35 @@ case class Overlay(input: Expression, replace:
Expression, pos: Expression, len:
object StringTranslate {
- def buildDict(matchingString: UTF8String, replaceString: UTF8String,
collationId: Int)
+  /**
+   * Build a translation dictionary from UTF8Strings. Both inputs are first converted to Java
+   * Strings: UTF8_BINARY keeps the plain `UTF8String.toString` conversion (unchanged behavior),
+   * while every other collation uses `UTF8String.toValidString`.
+   */
+  def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Integer)
     : JMap[String, String] = {
-    val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
-      matchingString.toString().toLowerCase()
+    val isUtf8Binary = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
+    val matching: String = if (isUtf8Binary) {
+      matchingString.toString
+    } else {
+      matchingString.toValidString
+    }
+    val replace: String = if (isUtf8Binary) {
+      replaceString.toString
     } else {
-      matchingString.toString()
+      replaceString.toValidString
     }
+    buildDict(matching, replace)
+  }
- val replace = replaceString.toString()
+ /**
+ * Build a translation dictionary from Strings. This method assumes that the
input strings are
+ * already valid. The result dictionary maps each character in `matching` to
the corresponding
+ * character in `replace`. If `replace` is shorter than `matching`, the
extra characters in
+ * `matching` will be mapped to null terminator, which causes characters to
get deleted during
+ * translation. If `replace` is longer than `matching`, the extra characters
will be ignored.
+ */
+ private def buildDict(matching: String, replace: String): JMap[String,
String] = {
val dict = new HashMap[String, String]()
var i = 0
var j = 0
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 78aee5b80e54..5f722b2f01fb 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -252,55 +252,16 @@ class CollationStringExpressionsSuite
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}
- test("TRANSLATE check result on explicitly collated string") {
+
+ test("Support StringTranslate string expression with collation") {
// Supported collations
case class TranslateTestCase[R](input: String, matchExpression: String,
- replaceExpression: String, collation: String, result: R)
+ replaceExpression: String, collation: String, result: R)
val testCases = Seq(
+ TranslateTestCase("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"),
TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE",
"41a2s3a4e"),
- TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE",
"41a2s3a4e"),
- TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_LCASE",
"xXaxsXaxe"),
- TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE",
"xxaxsXaxex"),
- TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE",
"xXaxsXaxeX"),
- // scalastyle:off
- TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE",
"test大千世AB大千世A"),
- TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE",
"大千世界abca大千世界"),
- TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE",
"oeso大千世界大千世界"),
- TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE",
"大千世界大千世界OesO"),
- TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE",
"世世世界世世世界tesT"),
- // scalastyle:on
- TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"),
- TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"),
- TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE",
"TxaxsXaxeX"),
- TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE",
"TXaxsXaxex"),
- // scalastyle:off
- TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE",
"test大千世AX大千世A"),
- TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"),
- TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"),
- // scalastyle:on
- TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI",
"41a2s3a4e"),
- TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI",
"xXaxsXaxe"),
- TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI",
"xxaxsXaxex"),
- TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI",
"xXaxsXaxeX"),
- // scalastyle:off
- TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI",
"test大千世AB大千世A"),
- TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI",
"大千世界abca大千世界"),
- TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI",
"oeso大千世界大千世界"),
- TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI",
"大千世界大千世界OesO"),
- TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI",
"世世世界世世世界tesT"),
- // scalastyle:on
- TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE",
"14234e"),
- TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI",
"14234e"),
- TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE",
"Tr4234e"),
- TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY",
"Tr4234e"),
- TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_LCASE",
"41a2s3a4e"),
- TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE",
"Tra2s3a4e"),
- TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI",
"41a2s3a4e"),
- TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY",
"Tra2s3a4e"),
- TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"),
- TranslateTestCase("abcdef", "abcde", "123", "UTF8_LCASE", "123f"),
- TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"),
- TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f")
+ TranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE",
"Traslate"),
+ TranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate")
)
testCases.foreach(t => {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]