uros-db commented on code in PR #46761:
URL: https://github.com/apache/spark/pull/46761#discussion_r1670942962
##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java:
##########
@@ -655,38 +691,130 @@ public static UTF8String lowercaseSubStringIndex(final
UTF8String string,
}
}
- public static Map<String, String> getCollationAwareDict(UTF8String string,
- Map<String, String> dict, int collationId) {
- // TODO(SPARK-48715): All UTF8String -> String conversions should use
`makeValid`
- String srcStr = string.toString();
+ /**
+ * Converts the original translation dictionary (`dict`) to a dictionary
with lowercased keys.
+ * This method is used to create a dictionary that can be used for the
UTF8_LCASE collation.
+ * Note that `StringTranslate.buildDict` will ensure that all strings are
validated properly.
+ *
+ * The method returns a map with lowercased code points as keys, while the
values remain
+ * unchanged. Note that `dict` is constructed on a character by character
basis, and the
+ * original keys are stored as strings. Keys in the resulting lowercase
dictionary are stored
+ * as integers, which correspond only to single characters from the original
`dict`. Also,
+ * there is special handling for the Turkish dotted uppercase letter I
(U+0130).
+ */
+ // NOTE(review): no U+0130 special case is visible in this body; presumably it
+ // is handled inside getLowercaseCodePoint (not shown here) — confirm.
+ private static Map<Integer, String> getLowercaseDict(final Map<String,
String> dict) {
+ // Replace all the keys in the dict with lowercased code points.
+ Map<Integer, String> lowercaseDict = new HashMap<>();
+ for (Map.Entry<String, String> entry : dict.entrySet()) {
+ // Only the first code point of each key is used; keys are expected to be
+ // single characters (see the Javadoc above).
+ int codePoint = entry.getKey().codePointAt(0);
+ // putIfAbsent: when several keys lowercase to the same code point, the entry
+ // seen first in dict's iteration order wins. NOTE(review): confirm that this
+ // iteration order is deterministic for the callers' map type.
+ lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint),
entry.getValue());
+ }
+ return lowercaseDict;
+ }
+
+ /**
+ * Translates the `input` string using the translation map `dict`, for UTF8_LCASE collation.
+ * The input is scanned left to right, one code point at a time. The sequence "i" + U+0307
+ * is treated as a single unit when the lowercased dictionary contains
+ * CODE_POINT_COMBINED_LOWERCASE_I_DOT. Each unit is lowercased and looked up in a lowercased
+ * copy of `dict`:
+ * - units with no mapping are kept as-is;
+ * - units mapped to "\0" are deleted;
+ * - all other units are replaced by their mapped string.
+ *
+ * @param input the string to be translated
+ * @param dict the original translation dictionary; a lowercased copy is built internally
+ * @return the translated string
+ */
+ public static UTF8String lowercaseTranslate(final UTF8String input,
+ final Map<String, String> dict) {
+ // Iterator for the input string.
+ Iterator<Integer> inputIter = input.codePointIterator(
+ CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
+ // Lowercased translation dictionary.
+ Map<Integer, String> lowercaseDict = getLowercaseDict(dict);
+ // StringBuilder to store the translated string.
+ StringBuilder sb = new StringBuilder();
- Map<String, String> collationAwareDict = new HashMap<>();
- for (String key : dict.keySet()) {
- StringSearch stringSearch =
- CollationFactory.getStringSearch(string, UTF8String.fromString(key),
collationId);
+ // Buffered code point iteration to handle one-to-many case mappings.
+ // The loop also runs while a code point is still buffered. A code point
+ // read ahead after "i" at the very end of the input must be looked up in
+ // the dictionary too; it must not be appended untranslated after the loop.
+ int codePointBuffer = -1, codePoint;
+ while (inputIter.hasNext() || codePointBuffer != -1) {
+ if (codePointBuffer != -1) {
+ codePoint = codePointBuffer;
+ codePointBuffer = -1;
+ } else {
+ codePoint = inputIter.next();
+ }
+ // Special handling for letter i (U+0069) followed by a combining dot
(U+0307).
+ if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) &&
+ codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) {
+ int nextCodePoint = inputIter.next();
+ if (nextCodePoint == CODE_POINT_COMBINING_DOT) {
+ codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT;
+ } else {
+ codePointBuffer = nextCodePoint;
+ }
+ }
+ // Translate the code point using the lowercased dictionary.
+ String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint));
+ if (translated == null) {
+ // Append the original code point if no translation is found.
+ sb.appendCodePoint(codePoint);
+ } else if (!"\0".equals(translated)) {
+ // Append the translated code point if the translation is not the null
character.
+ sb.append(translated);
+ }
+ // Skip the code point if it maps to the null character.
+ }
- int pos = 0;
- while ((pos = stringSearch.next()) != StringSearch.DONE) {
- int codePoint = srcStr.codePointAt(pos);
- int charCount = Character.charCount(codePoint);
- String newKey = srcStr.substring(pos, pos + charCount);
+ // Return the translated string.
+ return UTF8String.fromString(sb.toString());
+ }
- boolean exists = false;
- for (String existingKey : collationAwareDict.keySet()) {
- if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
- collationAwareDict.put(newKey,
collationAwareDict.get(existingKey));
- exists = true;
- break;
+ /**
+ * Translates the `input` string using the translation map `dict`, for all
ICU collations.
+ * String translation is performed by iterating over the input string, from
left to right, and
+ * repeatedly translating the longest possible substring that matches a key
in the dictionary.
+ * For ICU collations, the method uses the collation key of the substring to
perform the lookup
+ * in the collation aware version of the translation map.
+ *
+ * @param input the string to be translated
+ * @param dict the collation aware translation dictionary
+ * @param collationId the collation ID to use for string translation
+ * @return the translated string
+ */
+ public static UTF8String translate(final UTF8String input,
+ final Map<String, String> dict, final int collationId) {
+ String inputString = input.toValidString();
+ CharacterIterator target = new StringCharacterIterator(inputString);
+ Collator collator = CollationFactory.fetchCollation(collationId).collator;
+ StringBuilder sb = new StringBuilder();
+ int charIndex = 0;
+ while (charIndex < inputString.length()) {
+ int longestMatchLen = 0;
+ String longestMatch = "";
+ for (String key : dict.keySet()) {
+ StringSearch stringSearch = new StringSearch(key, target,
(RuleBasedCollator) collator);
+ stringSearch.setIndex(charIndex);
+ int matchIndex = stringSearch.next();
+ if (matchIndex == charIndex) {
+ int matchLen = stringSearch.getMatchLength();
Review Comment:
Yes, `getMatchLength` returns the length of the match as a number of chars (UTF-16 code units), not code points.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]