This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7fe1b93884aa [SPARK-46841][SQL] Add collation support for ICU locales
and collation specifiers
7fe1b93884aa is described below
commit 7fe1b93884aa8e9ba20f19351b8537c687b8f59c
Author: Nikola Mandic <[email protected]>
AuthorDate: Tue May 28 09:56:16 2024 -0700
[SPARK-46841][SQL] Add collation support for ICU locales and collation
specifiers
### What changes were proposed in this pull request?
Languages and localization for collations are supported by ICU library.
Collation naming format is as follows:
```
<2-letter language code>[_<4-letter script>][_<3-letter country
code>][_specifier_specifier...]
```
Locale specifier consists of the first part of collation name (language +
script + country). Locale specifiers need to be stable across ICU versions; to
keep existing ids and names invariant we introduce a golden file with the locale
table which should cause a CI failure on any silent changes.
Currently supported optional specifiers:
- `CS`/`CI` - case sensitivity, default is case-sensitive; supported by
configuring ICU collation levels
- `AS`/`AI` - accent sensitivity, default is accent-sensitive; supported by
configuring ICU collation levels
Users can use collation specifiers in any order except for the locale, which is
mandatory and must go first. There is a one-to-one mapping between collation
ids and collation names defined in `CollationFactory`.
### Why are the changes needed?
To add languages and localization support for collations.
### Does this PR introduce _any_ user-facing change?
Yes, it adds new predefined collations.
### How was this patch tested?
Added checks to `CollationFactorySuite` and ICU locale map golden file.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46180 from nikolamand-db/SPARK-46841.
Authored-by: Nikola Mandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/util/CollationFactory.java | 678 +++++++++++++++++----
.../spark/unsafe/types/CollationFactorySuite.scala | 323 +++++++++-
.../src/main/resources/error/error-conditions.json | 4 +-
.../apache/spark/sql/PlanGenerationTestSuite.scala | 4 +-
.../src/main/protobuf/spark/connect/types.proto | 2 +-
.../connect/common/DataTypeProtoConverter.scala | 9 +-
.../query-tests/queries/csv_from_dataset.json | 2 +-
.../query-tests/queries/csv_from_dataset.proto.bin | Bin 158 -> 169 bytes
.../query-tests/queries/function_lit_array.json | 4 +-
.../queries/function_lit_array.proto.bin | Bin 889 -> 911 bytes
.../query-tests/queries/function_typedLit.json | 32 +-
.../queries/function_typedLit.proto.bin | Bin 1199 -> 1381 bytes
.../query-tests/queries/json_from_dataset.json | 2 +-
.../queries/json_from_dataset.proto.bin | Bin 169 -> 180 bytes
python/pyspark/sql/connect/proto/types_pb2.py | 78 +--
python/pyspark/sql/connect/proto/types_pb2.pyi | 11 +-
python/pyspark/sql/connect/types.py | 5 +-
python/pyspark/sql/types.py | 27 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 15 +-
.../expressions/CollationExpressionSuite.scala | 33 +-
.../resources/collations/ICU-collations-map.md | 143 +++++
.../sql-tests/analyzer-results/collations.sql.out | 77 +++
.../test/resources/sql-tests/inputs/collations.sql | 13 +
.../resources/sql-tests/results/collations.sql.out | 88 +++
.../org/apache/spark/sql/CollationSuite.scala | 2 +-
.../apache/spark/sql/ICUCollationsMapSuite.scala | 69 +++
.../apache/spark/sql/internal/SQLConfSuite.scala | 3 +-
27 files changed, 1388 insertions(+), 236 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 0133c3feb611..fce12510afaf 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util;
import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
import java.util.function.ToLongFunction;
@@ -173,26 +174,546 @@ public final class CollationFactory {
}
/**
- * Constructor with comparators that are inherited from the given collator.
+ * Collation ID is defined as 32-bit integer. We specify binary layouts
for different classes of
+ * collations. Classes of collations are differentiated by most
significant 3 bits (bit 31, 30
+ * and 29), bit 31 being most significant and bit 0 being least
significant.
+ * ---
+ * General collation ID binary layout:
+ * bit 31: 1 for INDETERMINATE (requires all other bits to be 1 as
well), 0 otherwise.
+ * bit 30: 0 for predefined, 1 for user-defined.
+ * Following bits are specified for predefined collations:
+ * bit 29: 0 for UTF8_BINARY, 1 for ICU collations.
+ * bit 28-24: Reserved.
+ * bit 23-22: Reserved for version.
+ * bit 21-18: Reserved for space trimming.
+ * bit 17-0: Depend on collation family.
+ * ---
+ * INDETERMINATE collation ID binary layout:
+ * bit 31-0: 1
+ * INDETERMINATE collation ID is equal to -1.
+ * ---
+ * User-defined collation ID binary layout:
+ * bit 31: 0
+ * bit 30: 1
+ * bit 29-0: Undefined, reserved for future use.
+ * ---
+ * UTF8_BINARY collation ID binary layout:
+ * bit 31-24: Zeroes.
+ * bit 23-22: Zeroes, reserved for version.
+ * bit 21-18: Zeroes, reserved for space trimming.
+ * bit 17-3: Zeroes.
+ * bit 2: 0, reserved for accent sensitivity.
+ * bit 1: 0, reserved for uppercase and case-insensitive.
+ * bit 0: 0 = case-sensitive, 1 = lowercase.
+ * ---
+ * ICU collation ID binary layout:
+ * bit 31-30: Zeroes.
+ * bit 29: 1
+ * bit 28-24: Zeroes.
+ * bit 23-22: Zeroes, reserved for version.
+ * bit 21-18: Zeroes, reserved for space trimming.
+ * bit 17: 0 = case-sensitive, 1 = case-insensitive.
+ * bit 16: 0 = accent-sensitive, 1 = accent-insensitive.
+ * bit 15-14: Zeroes, reserved for punctuation sensitivity.
+ * bit 13-12: Zeroes, reserved for first letter preference.
+ * bit 11-0: Locale ID as specified in `ICULocaleToId` mapping.
+ * ---
+ * Some illustrative examples of collation name to ID mapping:
+ * - UTF8_BINARY -> 0
+ * - UTF8_BINARY_LCASE -> 1
+ * - UNICODE -> 0x20000000
+ * - UNICODE_AI -> 0x20010000
+ * - UNICODE_CI -> 0x20020000
+ * - UNICODE_CI_AI -> 0x20030000
+ * - af -> 0x20000001
+ * - af_CI_AI -> 0x20030001
*/
- public Collation(
- String collationName,
- String provider,
- Collator collator,
- String version,
- boolean supportsBinaryEquality,
- boolean supportsBinaryOrdering,
- boolean supportsLowercaseEquality) {
- this(
- collationName,
- provider,
- collator,
- (s1, s2) -> collator.compare(s1.toString(), s2.toString()),
- version,
- s -> (long)collator.getCollationKey(s.toString()).hashCode(),
- supportsBinaryEquality,
- supportsBinaryOrdering,
- supportsLowercaseEquality);
+ private abstract static class CollationSpec {
+
+ /**
+ * Bit 30 in collation ID having value 0 for predefined and 1 for
user-defined collation.
+ */
+ private enum DefinitionOrigin {
+ PREDEFINED, USER_DEFINED
+ }
+
+ /**
+ * Bit 29 in collation ID having value 0 for UTF8_BINARY family and 1
for ICU family of
+ * collations.
+ */
+ protected enum ImplementationProvider {
+ UTF8_BINARY, ICU
+ }
+
+ /**
+ * Offset in binary collation ID layout.
+ */
+ private static final int DEFINITION_ORIGIN_OFFSET = 30;
+
+ /**
+ * Bitmask corresponding to width in bits in binary collation ID layout.
+ */
+ private static final int DEFINITION_ORIGIN_MASK = 0b1;
+
+ /**
+ * Offset in binary collation ID layout.
+ */
+ protected static final int IMPLEMENTATION_PROVIDER_OFFSET = 29;
+
+ /**
+ * Bitmask corresponding to width in bits in binary collation ID layout.
+ */
+ protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1;
+
+ private static final int INDETERMINATE_COLLATION_ID = -1;
+
+ /**
+ * Thread-safe cache mapping collation IDs to corresponding `Collation`
instances.
+ * We add entries to this cache lazily as new `Collation` instances are
requested.
+ */
+ private static final Map<Integer, Collation> collationMap = new
ConcurrentHashMap<>();
+
+ /**
+ * Utility function to retrieve `ImplementationProvider` enum instance
from collation ID.
+ */
+ private static ImplementationProvider getImplementationProvider(int
collationId) {
+ return
ImplementationProvider.values()[SpecifierUtils.getSpecValue(collationId,
+ IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK)];
+ }
+
+ /**
+ * Utility function to retrieve `DefinitionOrigin` enum instance from
collation ID.
+ */
+ private static DefinitionOrigin getDefinitionOrigin(int collationId) {
+ return
DefinitionOrigin.values()[SpecifierUtils.getSpecValue(collationId,
+ DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)];
+ }
+
+ /**
+ * Main entry point for retrieving `Collation` instance from collation
ID.
+ */
+ private static Collation fetchCollation(int collationId) {
+ // User-defined collations and INDETERMINATE collations cannot produce
a `Collation`
+ // instance.
+ assert (collationId >= 0 && getDefinitionOrigin(collationId)
+ == DefinitionOrigin.PREDEFINED);
+ if (collationId == UTF8_BINARY_COLLATION_ID) {
+ // Skip cache.
+ return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION;
+ } else if (collationMap.containsKey(collationId)) {
+ // Already in cache.
+ return collationMap.get(collationId);
+ } else {
+ // Build `Collation` instance and put into cache.
+ CollationSpec spec;
+ ImplementationProvider implementationProvider =
getImplementationProvider(collationId);
+ if (implementationProvider == ImplementationProvider.UTF8_BINARY) {
+ spec = CollationSpecUTF8Binary.fromCollationId(collationId);
+ } else {
+ spec = CollationSpecICU.fromCollationId(collationId);
+ }
+ Collation collation = spec.buildCollation();
+ collationMap.put(collationId, collation);
+ return collation;
+ }
+ }
+
+ protected static SparkException collationInvalidNameException(String
collationName) {
+ return new SparkException("COLLATION_INVALID_NAME",
+ SparkException.constructMessageParams(Map.of("collationName",
collationName)), null);
+ }
+
+ private static int collationNameToId(String collationName) throws
SparkException {
+ // Collation names provided by user are treated as case-insensitive.
+ String collationNameUpper = collationName.toUpperCase();
+ if (collationNameUpper.startsWith("UTF8_BINARY")) {
+ return CollationSpecUTF8Binary.collationNameToId(collationName,
collationNameUpper);
+ } else {
+ return CollationSpecICU.collationNameToId(collationName,
collationNameUpper);
+ }
+ }
+
+ protected abstract Collation buildCollation();
+ }
+
+ private static class CollationSpecUTF8Binary extends CollationSpec {
+
+ /**
+ * Bit 0 in collation ID having value 0 for plain UTF8_BINARY and 1 for
UTF8_BINARY_LCASE
+ * collation.
+ */
+ private enum CaseSensitivity {
+ UNSPECIFIED, LCASE
+ }
+
+ /**
+ * Offset in binary collation ID layout.
+ */
+ private static final int CASE_SENSITIVITY_OFFSET = 0;
+
+ /**
+ * Bitmask corresponding to width in bits in binary collation ID layout.
+ */
+ private static final int CASE_SENSITIVITY_MASK = 0b1;
+
+ private static final int UTF8_BINARY_COLLATION_ID =
+ new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).collationId;
+ private static final int UTF8_BINARY_LCASE_COLLATION_ID =
+ new CollationSpecUTF8Binary(CaseSensitivity.LCASE).collationId;
+ protected static Collation UTF8_BINARY_COLLATION =
+ new
CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).buildCollation();
+ protected static Collation UTF8_BINARY_LCASE_COLLATION =
+ new CollationSpecUTF8Binary(CaseSensitivity.LCASE).buildCollation();
+
+ private final int collationId;
+
+ private CollationSpecUTF8Binary(CaseSensitivity caseSensitivity) {
+ this.collationId =
+ SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET,
caseSensitivity);
+ }
+
+ private static int collationNameToId(String originalName, String
collationName)
+ throws SparkException {
+ if (UTF8_BINARY_COLLATION.collationName.equals(collationName)) {
+ return UTF8_BINARY_COLLATION_ID;
+ } else if
(UTF8_BINARY_LCASE_COLLATION.collationName.equals(collationName)) {
+ return UTF8_BINARY_LCASE_COLLATION_ID;
+ } else {
+ // Throw exception with original (before case conversion) collation
name.
+ throw collationInvalidNameException(originalName);
+ }
+ }
+
+ private static CollationSpecUTF8Binary fromCollationId(int collationId) {
+ // Extract case sensitivity from collation ID.
+ int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId,
+ CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK);
+ // Verify only case sensitivity bits were set settable in UTF8_BINARY
family of collations.
+ assert (SpecifierUtils.removeSpec(collationId,
+ CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0);
+ return new
CollationSpecUTF8Binary(CaseSensitivity.values()[caseConversionOrdinal]);
+ }
+
+ @Override
+ protected Collation buildCollation() {
+ if (collationId == UTF8_BINARY_COLLATION_ID) {
+ return new Collation(
+ "UTF8_BINARY",
+ PROVIDER_SPARK,
+ null,
+ UTF8String::binaryCompare,
+ "1.0",
+ s -> (long) s.hashCode(),
+ /* supportsBinaryEquality = */ true,
+ /* supportsBinaryOrdering = */ true,
+ /* supportsLowercaseEquality = */ false);
+ } else {
+ return new Collation(
+ "UTF8_BINARY_LCASE",
+ PROVIDER_SPARK,
+ null,
+ UTF8String::compareLowerCase,
+ "1.0",
+ s -> (long) s.toLowerCase().hashCode(),
+ /* supportsBinaryEquality = */ false,
+ /* supportsBinaryOrdering = */ false,
+ /* supportsLowercaseEquality = */ true);
+ }
+ }
+ }
+
+ private static class CollationSpecICU extends CollationSpec {
+
+ /**
+ * Bit 17 in collation ID having value 0 for case-sensitive and 1 for
case-insensitive
+ * collation.
+ */
+ private enum CaseSensitivity {
+ CS, CI
+ }
+
+ /**
+ * Bit 16 in collation ID having value 0 for accent-sensitive and 1 for
accent-insensitive
+ * collation.
+ */
+ private enum AccentSensitivity {
+ AS, AI
+ }
+
+ /**
+ * Offset in binary collation ID layout.
+ */
+ private static final int CASE_SENSITIVITY_OFFSET = 17;
+
+ /**
+ * Bitmask corresponding to width in bits in binary collation ID layout.
+ */
+ private static final int CASE_SENSITIVITY_MASK = 0b1;
+
+ /**
+ * Offset in binary collation ID layout.
+ */
+ private static final int ACCENT_SENSITIVITY_OFFSET = 16;
+
+ /**
+ * Bitmask corresponding to width in bits in binary collation ID layout.
+ */
+ private static final int ACCENT_SENSITIVITY_MASK = 0b1;
+
+ /**
+ * Array of locale names, each locale ID corresponds to the index in
this array.
+ */
+ private static final String[] ICULocaleNames;
+
+ /**
+ * Mapping of locale names to corresponding `ULocale` instance.
+ */
+ private static final Map<String, ULocale> ICULocaleMap = new HashMap<>();
+
+ /**
+ * Used to parse user input collation names which are converted to
uppercase.
+ */
+ private static final Map<String, String> ICULocaleMapUppercase = new
HashMap<>();
+
+ /**
+ * Reverse mapping of `ICULocaleNames`.
+ */
+ private static final Map<String, Integer> ICULocaleToId = new
HashMap<>();
+
+ /**
+ * ICU library Collator version passed to `Collation` instance.
+ */
+ private static final String ICU_COLLATOR_VERSION = "153.120.0.0";
+
+ static {
+ ICULocaleMap.put("UNICODE", ULocale.ROOT);
+ // ICU-implemented `ULocale`s which have corresponding `Collator`
installed.
+ ULocale[] locales = Collator.getAvailableULocales();
+ // Build locale names in format: language["_" optional script]["_"
optional country code].
+ // Examples: en, en_USA, sr_Cyrl_SRB
+ for (ULocale locale : locales) {
+ // Skip variants.
+ if (locale.getVariant().isEmpty()) {
+ String language = locale.getLanguage();
+ // Require non-empty language as first component of locale name.
+ assert (!language.isEmpty());
+ StringBuilder builder = new StringBuilder(language);
+ // Script tag.
+ String script = locale.getScript();
+ if (!script.isEmpty()) {
+ builder.append('_');
+ builder.append(script);
+ }
+ // 3-letter country code.
+ String country = locale.getISO3Country();
+ if (!country.isEmpty()) {
+ builder.append('_');
+ builder.append(country);
+ }
+ String localeName = builder.toString();
+ // Verify locale names are unique.
+ assert (!ICULocaleMap.containsKey(localeName));
+ ICULocaleMap.put(localeName, locale);
+ }
+ }
+ // Construct uppercase-normalized locale name mapping.
+ for (String localeName : ICULocaleMap.keySet()) {
+ String localeUppercase = localeName.toUpperCase();
+ // Locale names are unique case-insensitively.
+ assert (!ICULocaleMapUppercase.containsKey(localeUppercase));
+ ICULocaleMapUppercase.put(localeUppercase, localeName);
+ }
+ // Construct locale name to ID mapping. Locale ID is defined as index
in `ICULocaleNames`.
+ ICULocaleNames = ICULocaleMap.keySet().toArray(new String[0]);
+ Arrays.sort(ICULocaleNames);
+ // Maximum number of locale IDs as defined by binary layout.
+ assert (ICULocaleNames.length <= (1 << 12));
+ for (int i = 0; i < ICULocaleNames.length; ++i) {
+ ICULocaleToId.put(ICULocaleNames[i], i);
+ }
+ }
+
+ private static final int UNICODE_COLLATION_ID =
+ new CollationSpecICU("UNICODE", CaseSensitivity.CS,
AccentSensitivity.AS).collationId;
+ private static final int UNICODE_CI_COLLATION_ID =
+ new CollationSpecICU("UNICODE", CaseSensitivity.CI,
AccentSensitivity.AS).collationId;
+
+ private final CaseSensitivity caseSensitivity;
+ private final AccentSensitivity accentSensitivity;
+ private final String locale;
+ private final int collationId;
+
+ private CollationSpecICU(String locale, CaseSensitivity caseSensitivity,
+ AccentSensitivity accentSensitivity) {
+ this.locale = locale;
+ this.caseSensitivity = caseSensitivity;
+ this.accentSensitivity = accentSensitivity;
+ // Construct collation ID from locale, case-sensitivity and
accent-sensitivity specifiers.
+ int collationId = ICULocaleToId.get(locale);
+ // Mandatory ICU implementation provider.
+ collationId = SpecifierUtils.setSpecValue(collationId,
IMPLEMENTATION_PROVIDER_OFFSET,
+ ImplementationProvider.ICU);
+ collationId = SpecifierUtils.setSpecValue(collationId,
CASE_SENSITIVITY_OFFSET,
+ caseSensitivity);
+ collationId = SpecifierUtils.setSpecValue(collationId,
ACCENT_SENSITIVITY_OFFSET,
+ accentSensitivity);
+ this.collationId = collationId;
+ }
+
+ private static int collationNameToId(
+ String originalName, String collationName) throws SparkException {
+ // Search for the longest locale match because specifiers are designed
to be different from
+ // script tag and country code, meaning the only valid locale name
match can be the longest
+ // one.
+ int lastPos = -1;
+ for (int i = 1; i <= collationName.length(); i++) {
+ String localeName = collationName.substring(0, i);
+ if (ICULocaleMapUppercase.containsKey(localeName)) {
+ lastPos = i;
+ }
+ }
+ if (lastPos == -1) {
+ throw collationInvalidNameException(originalName);
+ } else {
+ String locale = collationName.substring(0, lastPos);
+ int collationId =
ICULocaleToId.get(ICULocaleMapUppercase.get(locale));
+
+ // Try all combinations of AS/AI and CS/CI.
+ CaseSensitivity caseSensitivity;
+ AccentSensitivity accentSensitivity;
+ if (collationName.equals(locale) ||
+ collationName.equals(locale + "_AS") ||
+ collationName.equals(locale + "_CS") ||
+ collationName.equals(locale + "_AS_CS") ||
+ collationName.equals(locale + "_CS_AS")
+ ) {
+ caseSensitivity = CaseSensitivity.CS;
+ accentSensitivity = AccentSensitivity.AS;
+ } else if (collationName.equals(locale + "_CI") ||
+ collationName.equals(locale + "_AS_CI") ||
+ collationName.equals(locale + "_CI_AS")) {
+ caseSensitivity = CaseSensitivity.CI;
+ accentSensitivity = AccentSensitivity.AS;
+ } else if (collationName.equals(locale + "_AI") ||
+ collationName.equals(locale + "_CS_AI") ||
+ collationName.equals(locale + "_AI_CS")) {
+ caseSensitivity = CaseSensitivity.CS;
+ accentSensitivity = AccentSensitivity.AI;
+ } else if (collationName.equals(locale + "_AI_CI") ||
+ collationName.equals(locale + "_CI_AI")) {
+ caseSensitivity = CaseSensitivity.CI;
+ accentSensitivity = AccentSensitivity.AI;
+ } else {
+ throw collationInvalidNameException(originalName);
+ }
+
+ // Build collation ID from computed specifiers.
+ collationId = SpecifierUtils.setSpecValue(collationId,
+ IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU);
+ collationId = SpecifierUtils.setSpecValue(collationId,
+ CASE_SENSITIVITY_OFFSET, caseSensitivity);
+ collationId = SpecifierUtils.setSpecValue(collationId,
+ ACCENT_SENSITIVITY_OFFSET, accentSensitivity);
+ return collationId;
+ }
+ }
+
+ private static CollationSpecICU fromCollationId(int collationId) {
+ // Parse specifiers from collation ID.
+ int caseSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId,
+ CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK);
+ int accentSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId,
+ ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK);
+ collationId = SpecifierUtils.removeSpec(collationId,
+ IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK);
+ collationId = SpecifierUtils.removeSpec(collationId,
+ CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK);
+ collationId = SpecifierUtils.removeSpec(collationId,
+ ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK);
+ // Locale ID remains after removing all other specifiers.
+ int localeId = collationId;
+ // Verify locale ID is valid against `ICULocaleNames` array.
+ assert (localeId < ICULocaleNames.length);
+ CaseSensitivity caseSensitivity =
CaseSensitivity.values()[caseSensitivityOrdinal];
+ AccentSensitivity accentSensitivity =
AccentSensitivity.values()[accentSensitivityOrdinal];
+ String locale = ICULocaleNames[localeId];
+ return new CollationSpecICU(locale, caseSensitivity,
accentSensitivity);
+ }
+
+ @Override
+ protected Collation buildCollation() {
+ ULocale.Builder builder = new ULocale.Builder();
+ builder.setLocale(ICULocaleMap.get(locale));
+ // Compute unicode locale keyword for all combinations of case/accent
sensitivity.
+ if (caseSensitivity == CaseSensitivity.CS &&
+ accentSensitivity == AccentSensitivity.AS) {
+ builder.setUnicodeLocaleKeyword("ks", "level3");
+ } else if (caseSensitivity == CaseSensitivity.CS &&
+ accentSensitivity == AccentSensitivity.AI) {
+ builder
+ .setUnicodeLocaleKeyword("ks", "level1")
+ .setUnicodeLocaleKeyword("kc", "true");
+ } else if (caseSensitivity == CaseSensitivity.CI &&
+ accentSensitivity == AccentSensitivity.AS) {
+ builder.setUnicodeLocaleKeyword("ks", "level2");
+ } else if (caseSensitivity == CaseSensitivity.CI &&
+ accentSensitivity == AccentSensitivity.AI) {
+ builder.setUnicodeLocaleKeyword("ks", "level1");
+ }
+ ULocale resultLocale = builder.build();
+ Collator collator = Collator.getInstance(resultLocale);
+ // Freeze ICU collator to ensure thread safety.
+ collator.freeze();
+ return new Collation(
+ collationName(),
+ PROVIDER_ICU,
+ collator,
+ (s1, s2) -> collator.compare(s1.toString(), s2.toString()),
+ ICU_COLLATOR_VERSION,
+ s -> (long) collator.getCollationKey(s.toString()).hashCode(),
+ /* supportsBinaryEquality = */ collationId == UNICODE_COLLATION_ID,
+ /* supportsBinaryOrdering = */ false,
+ /* supportsLowercaseEquality = */ false);
+ }
+
+ /**
+ * Compute normalized collation name. Components of collation name are
given in order:
+ * - Locale name
+ * - Optional case sensitivity when non-default preceded by underscore
+ * - Optional accent sensitivity when non-default preceded by underscore
+ * Examples: en, en_USA_CI_AI, sr_Cyrl_SRB_AI.
+ */
+ private String collationName() {
+ StringBuilder builder = new StringBuilder();
+ builder.append(locale);
+ if (caseSensitivity != CaseSensitivity.CS) {
+ builder.append('_');
+ builder.append(caseSensitivity.toString());
+ }
+ if (accentSensitivity != AccentSensitivity.AS) {
+ builder.append('_');
+ builder.append(accentSensitivity.toString());
+ }
+ return builder.toString();
+ }
+ }
+
+ /**
+ * Utility class for manipulating conversions between collation IDs and
specifier enums/locale
+ * IDs. Scope bitwise operations here to avoid confusion.
+ */
+ private static class SpecifierUtils {
+ private static int getSpecValue(int collationId, int offset, int mask) {
+ return (collationId >> offset) & mask;
+ }
+
+ private static int removeSpec(int collationId, int offset, int mask) {
+ return collationId & ~(mask << offset);
+ }
+
+ private static int setSpecValue(int collationId, int offset, Enum spec) {
+ return collationId | (spec.ordinal() << offset);
+ }
}
/** Returns the collation identifier. */
@@ -201,75 +722,20 @@ public final class CollationFactory {
}
}
- private static final Collation[] collationTable = new Collation[4];
- private static final HashMap<String, Integer> collationNameToIdMap = new
HashMap<>();
-
- public static final int UTF8_BINARY_COLLATION_ID = 0;
- public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1;
-
public static final String PROVIDER_SPARK = "spark";
public static final String PROVIDER_ICU = "icu";
public static final List<String> SUPPORTED_PROVIDERS =
List.of(PROVIDER_SPARK, PROVIDER_ICU);
- static {
- // Binary comparison. This is the default collation.
- // No custom comparators will be used for this collation.
- // Instead, we rely on byte for byte comparison.
- collationTable[0] = new Collation(
- "UTF8_BINARY",
- PROVIDER_SPARK,
- null,
- UTF8String::binaryCompare,
- "1.0",
- s -> (long)s.hashCode(),
- true,
- true,
- false);
-
- // Case-insensitive UTF8 binary collation.
- // TODO: Do in place comparisons instead of creating new strings.
- collationTable[1] = new Collation(
- "UTF8_BINARY_LCASE",
- PROVIDER_SPARK,
- null,
- UTF8String::compareLowerCase,
- "1.0",
- (s) -> (long)s.toLowerCase().hashCode(),
- false,
- false,
- true);
-
- // UNICODE case sensitive comparison (ROOT locale, in ICU).
- collationTable[2] = new Collation(
- "UNICODE",
- PROVIDER_ICU,
- Collator.getInstance(ULocale.ROOT),
- "153.120.0.0",
- true,
- false,
- false
- );
-
- collationTable[2].collator.setStrength(Collator.TERTIARY);
- collationTable[2].collator.freeze();
-
- // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary
strength).
- collationTable[3] = new Collation(
- "UNICODE_CI",
- PROVIDER_ICU,
- Collator.getInstance(ULocale.ROOT),
- "153.120.0.0",
- false,
- false,
- false
- );
- collationTable[3].collator.setStrength(Collator.SECONDARY);
- collationTable[3].collator.freeze();
-
- for (int i = 0; i < collationTable.length; i++) {
- collationNameToIdMap.put(collationTable[i].collationName, i);
- }
- }
+ public static final int UTF8_BINARY_COLLATION_ID =
+ Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION_ID;
+ public static final int UTF8_BINARY_LCASE_COLLATION_ID =
+ Collation.CollationSpecUTF8Binary.UTF8_BINARY_LCASE_COLLATION_ID;
+ public static final int UNICODE_COLLATION_ID =
+ Collation.CollationSpecICU.UNICODE_COLLATION_ID;
+ public static final int UNICODE_CI_COLLATION_ID =
+ Collation.CollationSpecICU.UNICODE_CI_COLLATION_ID;
+ public static final int INDETERMINATE_COLLATION_ID =
+ Collation.CollationSpec.INDETERMINATE_COLLATION_ID;
/**
* Returns a StringSearch object for the given pattern and target strings,
under collation
@@ -297,23 +763,6 @@ public final class CollationFactory {
return new StringSearch(patternString, target, (RuleBasedCollator)
collator);
}
- /**
- * Returns if the given collationName is valid one.
- */
- public static boolean isValidCollation(String collationName) {
- return collationNameToIdMap.containsKey(collationName.toUpperCase());
- }
-
- /**
- * Returns closest valid name to collationName
- */
- public static String getClosestCollation(String collationName) {
- Collation suggestion = Collections.min(List.of(collationTable),
Comparator.comparingInt(
- c -> UTF8String.fromString(c.collationName).levenshteinDistance(
- UTF8String.fromString(collationName.toUpperCase()))));
- return suggestion.collationName;
- }
-
/**
* Returns a collation-unaware StringSearch object for the given pattern and
target strings.
* While this object does not respect collation, it can be used to find
occurrences of the pattern
@@ -326,24 +775,10 @@ public final class CollationFactory {
}
/**
- * Returns the collation id for the given collation name.
+ * Returns the collation ID for the given collation name.
*/
public static int collationNameToId(String collationName) throws
SparkException {
- String normalizedName = collationName.toUpperCase();
- if (collationNameToIdMap.containsKey(normalizedName)) {
- return collationNameToIdMap.get(normalizedName);
- } else {
- Collation suggestion = Collections.min(List.of(collationTable),
Comparator.comparingInt(
- c -> UTF8String.fromString(c.collationName).levenshteinDistance(
- UTF8String.fromString(normalizedName))));
-
- Map<String, String> params = new HashMap<>();
- params.put("collationName", collationName);
- params.put("proposal", suggestion.collationName);
-
- throw new SparkException(
- "COLLATION_INVALID_NAME",
SparkException.constructMessageParams(params), null);
- }
+ return Collation.CollationSpec.collationNameToId(collationName);
}
public static void assertValidProvider(String provider) throws
SparkException {
@@ -359,12 +794,15 @@ public final class CollationFactory {
}
public static Collation fetchCollation(int collationId) {
- return collationTable[collationId];
+ return Collation.CollationSpec.fetchCollation(collationId);
}
public static Collation fetchCollation(String collationName) throws
SparkException {
- int collationId = collationNameToId(collationName);
- return collationTable[collationId];
+ return fetchCollation(collationNameToId(collationName));
+ }
+
+ public static String[] getICULocaleNames() {
+ return Collation.CollationSpecICU.ICULocaleNames;
}
public static UTF8String getCollationKey(UTF8String input, int collationId) {
diff --git
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
index 768d26bf0e11..69104dea0e99 100644
---
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
+++
b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
@@ -20,7 +20,10 @@ package org.apache.spark.unsafe.types
import scala.collection.parallel.immutable.ParSeq
import scala.jdk.CollectionConverters.MapHasAsScala
+import com.ibm.icu.util.ULocale
+
import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.util.CollationFactory.fetchCollation
// scalastyle:off
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.must.Matchers
@@ -30,31 +33,95 @@ import org.apache.spark.unsafe.types.UTF8String.{fromString
=> toUTF8}
class CollationFactorySuite extends AnyFunSuite with Matchers { //
scalastyle:ignore funsuite
test("collationId stability") {
- val utf8Binary = fetchCollation(0)
+ assert(INDETERMINATE_COLLATION_ID == -1)
+
+ assert(UTF8_BINARY_COLLATION_ID == 0)
+ val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID)
assert(utf8Binary.collationName == "UTF8_BINARY")
assert(utf8Binary.supportsBinaryEquality)
- val utf8BinaryLcase = fetchCollation(1)
+ assert(UTF8_BINARY_LCASE_COLLATION_ID == 1)
+ val utf8BinaryLcase = fetchCollation(UTF8_BINARY_LCASE_COLLATION_ID)
assert(utf8BinaryLcase.collationName == "UTF8_BINARY_LCASE")
assert(!utf8BinaryLcase.supportsBinaryEquality)
- val unicode = fetchCollation(2)
+ assert(UNICODE_COLLATION_ID == (1 << 29))
+ val unicode = fetchCollation(UNICODE_COLLATION_ID)
assert(unicode.collationName == "UNICODE")
- assert(unicode.supportsBinaryEquality);
+ assert(unicode.supportsBinaryEquality)
- val unicodeCi = fetchCollation(3)
+ assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17)))
+ val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID)
assert(unicodeCi.collationName == "UNICODE_CI")
assert(!unicodeCi.supportsBinaryEquality)
}
- test("fetch invalid collation name") {
- val error = intercept[SparkException] {
- fetchCollation("UTF8_BS")
+ test("UTF8_BINARY and ICU root locale collation names") {
+ // Collation name already normalized.
+ Seq(
+ "UTF8_BINARY",
+ "UTF8_BINARY_LCASE",
+ "UNICODE",
+ "UNICODE_CI",
+ "UNICODE_AI",
+ "UNICODE_CI_AI"
+ ).foreach(collationName => {
+ val col = fetchCollation(collationName)
+ assert(col.collationName == collationName)
+ })
+ // Collation name normalization.
+ Seq(
+ // ICU root locale.
+ ("UNICODE_CS", "UNICODE"),
+ ("UNICODE_CS_AS", "UNICODE"),
+ ("UNICODE_CI_AS", "UNICODE_CI"),
+ ("UNICODE_AI_CS", "UNICODE_AI"),
+ ("UNICODE_AI_CI", "UNICODE_CI_AI"),
+ // Randomized case collation names.
+ ("utf8_binary", "UTF8_BINARY"),
+ ("UtF8_binARy_LcasE", "UTF8_BINARY_LCASE"),
+ ("unicode", "UNICODE"),
+ ("UnICoDe_cs_aI", "UNICODE_AI")
+ ).foreach{
+ case (name, normalized) =>
+ val col = fetchCollation(name)
+ assert(col.collationName == normalized)
}
+ }
+
+ test("fetch invalid UTF8_BINARY and ICU root locale collation names") {
+ Seq(
+ "UTF8_BINARY_CS",
+ "UTF8_BINARY_AS",
+ "UTF8_BINARY_CS_AS",
+ "UTF8_BINARY_AS_CS",
+ "UTF8_BINARY_CI",
+ "UTF8_BINARY_AI",
+ "UTF8_BINARY_CI_AI",
+ "UTF8_BINARY_AI_CI",
+ "UTF8_BS",
+ "BINARY_UTF8",
+ "UTF8_BINARY_A",
+ "UNICODE_X",
+ "UNICODE_CI_X",
+ "UNICODE_LCASE_X",
+ "UTF8_UNICODE",
+ "UTF8_BINARY_UNICODE",
+ "CI_UNICODE",
+ "LCASE_UNICODE",
+ "UNICODE_UNSPECIFIED",
+ "UNICODE_CI_UNSPECIFIED",
+ "UNICODE_UNSPECIFIED_CI_UNSPECIFIED",
+ "UNICODE_INDETERMINATE",
+ "UNICODE_CI_INDETERMINATE"
+ ).foreach(collationName => {
+ val error = intercept[SparkException] {
+ fetchCollation(collationName)
+ }
- assert(error.getErrorClass === "COLLATION_INVALID_NAME")
- assert(error.getMessageParameters.asScala ===
- Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS"))
+ assert(error.getErrorClass === "COLLATION_INVALID_NAME")
+ assert(error.getMessageParameters.asScala === Map("collationName" ->
collationName))
+ })
}
case class CollationTestCase[R](collationName: String, s1: String, s2:
String, expectedResult: R)
@@ -152,4 +219,238 @@ class CollationFactorySuite extends AnyFunSuite with
Matchers { // scalastyle:ig
}
})
}
+
+ test("test collation caching") {
+ Seq(
+ "UTF8_BINARY",
+ "UTF8_BINARY_LCASE",
+ "UNICODE",
+ "UNICODE_CI",
+ "UNICODE_AI",
+ "UNICODE_CI_AI",
+ "UNICODE_AI_CI"
+ ).foreach(collationId => {
+ val col1 = fetchCollation(collationId)
+ val col2 = fetchCollation(collationId)
+ assert(col1 eq col2) // Check for reference equality.
+ })
+ }
+
+ test("collations with ICU non-root localization") {
+ Seq(
+ // Language only.
+ "en",
+ "en_CS",
+ "en_CI",
+ "en_AS",
+ "en_AI",
+ // Language + 3-letter country code.
+ "en_USA",
+ "en_USA_CS",
+ "en_USA_CI",
+ "en_USA_AS",
+ "en_USA_AI",
+ // Language + script code.
+ "sr_Cyrl",
+ "sr_Cyrl_CS",
+ "sr_Cyrl_CI",
+ "sr_Cyrl_AS",
+ "sr_Cyrl_AI",
+ // Language + script code + 3-letter country code.
+ "sr_Cyrl_SRB",
+ "sr_Cyrl_SRB_CS",
+ "sr_Cyrl_SRB_CI",
+ "sr_Cyrl_SRB_AS",
+ "sr_Cyrl_SRB_AI"
+ ).foreach(collationICU => {
+ val col = fetchCollation(collationICU)
+ assert(col.collator.getLocale(ULocale.VALID_LOCALE) != ULocale.ROOT)
+ })
+ }
+
+ test("invalid names of collations with ICU non-root localization") {
+ Seq(
+ "en_US", // Must use 3-letter country code
+ "enn",
+ "en_AAA",
+ "en_Something",
+ "en_Something_USA",
+ "en_LCASE",
+ "en_UCASE",
+ "en_CI_LCASE",
+ "en_CI_UCASE",
+ "en_CI_UNSPECIFIED",
+ "en_USA_UNSPECIFIED",
+ "en_USA_UNSPECIFIED_CI",
+ "en_INDETERMINATE",
+ "en_USA_INDETERMINATE",
+ "en_Latn_USA", // Use en_USA instead.
+ "en_Cyrl_USA",
+ "en_USA_AAA",
+ "sr_Cyrl_SRB_AAA",
+ // Invalid ordering of language, script and country code.
+ "USA_en",
+ "sr_SRB_Cyrl",
+ "SRB_sr",
+ "SRB_sr_Cyrl",
+ "SRB_Cyrl_sr",
+ "Cyrl_sr",
+ "Cyrl_sr_SRB",
+ "Cyrl_SRB_sr",
+ // Collation specifiers in the middle of locale.
+ "CI_en",
+ "USA_CI_en",
+ "en_CI_USA",
+ "CI_sr_Cyrl_SRB",
+ "sr_CI_Cyrl_SRB",
+ "sr_Cyrl_CI_SRB",
+ "CI_Cyrl_sr",
+ "Cyrl_CI_sr",
+ "Cyrl_CI_sr_SRB",
+ "Cyrl_sr_CI_SRB"
+ ).foreach(collationName => {
+ val error = intercept[SparkException] {
+ fetchCollation(collationName)
+ }
+
+ assert(error.getErrorClass === "COLLATION_INVALID_NAME")
+ assert(error.getMessageParameters.asScala === Map("collationName" ->
collationName))
+ })
+ }
+
+ test("collations name normalization for ICU non-root localization") {
+ Seq(
+ ("en_USA", "en_USA"),
+ ("en_CS", "en"),
+ ("en_AS", "en"),
+ ("en_CS_AS", "en"),
+ ("en_AS_CS", "en"),
+ ("en_CI", "en_CI"),
+ ("en_AI", "en_AI"),
+ ("en_AI_CI", "en_CI_AI"),
+ ("en_CI_AI", "en_CI_AI"),
+ ("en_CS_AI", "en_AI"),
+ ("en_AI_CS", "en_AI"),
+ ("en_CI_AS", "en_CI"),
+ ("en_AS_CI", "en_CI"),
+ ("en_USA_AI_CI", "en_USA_CI_AI"),
+ // Randomized case.
+ ("EN_USA", "en_USA"),
+ ("SR_CYRL", "sr_Cyrl"),
+ ("sr_cyrl_srb", "sr_Cyrl_SRB"),
+ ("sR_cYRl_sRb", "sr_Cyrl_SRB")
+ ).foreach {
+ case (name, normalized) =>
+ val col = fetchCollation(name)
+ assert(col.collationName == normalized)
+ }
+ }
+
+ test("invalid collationId") {
+ val badCollationIds = Seq(
+ INDETERMINATE_COLLATION_ID, // Indeterminate collation.
+ 1 << 30, // User-defined collation range.
+ (1 << 30) | 1, // User-defined collation range.
+ (1 << 30) | (1 << 29), // User-defined collation range.
+ 1 << 1, // UTF8_BINARY mandatory zero bit 1 breach.
+ 1 << 2, // UTF8_BINARY mandatory zero bit 2 breach.
+ 1 << 3, // UTF8_BINARY mandatory zero bit 3 breach.
+ 1 << 4, // UTF8_BINARY mandatory zero bit 4 breach.
+ 1 << 5, // UTF8_BINARY mandatory zero bit 5 breach.
+ 1 << 6, // UTF8_BINARY mandatory zero bit 6 breach.
+ 1 << 7, // UTF8_BINARY mandatory zero bit 7 breach.
+ 1 << 8, // UTF8_BINARY mandatory zero bit 8 breach.
+ 1 << 9, // UTF8_BINARY mandatory zero bit 9 breach.
+ 1 << 10, // UTF8_BINARY mandatory zero bit 10 breach.
+ 1 << 11, // UTF8_BINARY mandatory zero bit 11 breach.
+ 1 << 12, // UTF8_BINARY mandatory zero bit 12 breach.
+ 1 << 13, // UTF8_BINARY mandatory zero bit 13 breach.
+ 1 << 14, // UTF8_BINARY mandatory zero bit 14 breach.
+ 1 << 15, // UTF8_BINARY mandatory zero bit 15 breach.
+ 1 << 16, // UTF8_BINARY mandatory zero bit 16 breach.
+ 1 << 17, // UTF8_BINARY mandatory zero bit 17 breach.
+ 1 << 18, // UTF8_BINARY mandatory zero bit 18 breach.
+ 1 << 19, // UTF8_BINARY mandatory zero bit 19 breach.
+ 1 << 20, // UTF8_BINARY mandatory zero bit 20 breach.
+ 1 << 23, // UTF8_BINARY mandatory zero bit 23 breach.
+ 1 << 24, // UTF8_BINARY mandatory zero bit 24 breach.
+ 1 << 25, // UTF8_BINARY mandatory zero bit 25 breach.
+ 1 << 26, // UTF8_BINARY mandatory zero bit 26 breach.
+ 1 << 27, // UTF8_BINARY mandatory zero bit 27 breach.
+ 1 << 28, // UTF8_BINARY mandatory zero bit 28 breach.
+ (1 << 29) | (1 << 12), // ICU mandatory zero bit 12 breach.
+ (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach.
+ (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach.
+ (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach.
+ (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach.
+ (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach.
+ (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach.
+ (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach.
+ (1 << 29) | (1 << 22), // ICU mandatory zero bit 22 breach.
+ (1 << 29) | (1 << 23), // ICU mandatory zero bit 23 breach.
+ (1 << 29) | (1 << 24), // ICU mandatory zero bit 24 breach.
+ (1 << 29) | (1 << 25), // ICU mandatory zero bit 25 breach.
+ (1 << 29) | (1 << 26), // ICU mandatory zero bit 26 breach.
+ (1 << 29) | (1 << 27), // ICU mandatory zero bit 27 breach.
+ (1 << 29) | (1 << 28), // ICU mandatory zero bit 28 breach.
+ (1 << 29) | 0xFFFF // ICU with invalid locale id.
+ )
+ badCollationIds.foreach(collationId => {
+ // Assumptions about collation id will break and assert statement will
fail.
+ intercept[AssertionError](fetchCollation(collationId))
+ })
+ }
+
+ test("repeated and/or incompatible specifiers in collation name") {
+ Seq(
+ "UTF8_BINARY_LCASE_LCASE",
+ "UNICODE_CS_CS",
+ "UNICODE_CI_CI",
+ "UNICODE_CI_CS",
+ "UNICODE_CS_CI",
+ "UNICODE_AS_AS",
+ "UNICODE_AI_AI",
+ "UNICODE_AS_AI",
+ "UNICODE_AI_AS",
+ "UNICODE_AS_CS_AI",
+ "UNICODE_CS_AI_CI",
+ "UNICODE_CS_AS_CI_AI"
+ ).foreach(collationName => {
+ val error = intercept[SparkException] {
+ fetchCollation(collationName)
+ }
+
+ assert(error.getErrorClass === "COLLATION_INVALID_NAME")
+ assert(error.getMessageParameters.asScala === Map("collationName" ->
collationName))
+ })
+ }
+
+ test("basic ICU collator checks") {
+ Seq(
+ CollationTestCase("UNICODE_CI", "a", "A", true),
+ CollationTestCase("UNICODE_CI", "a", "å", false),
+ CollationTestCase("UNICODE_CI", "a", "Å", false),
+ CollationTestCase("UNICODE_AI", "a", "A", false),
+ CollationTestCase("UNICODE_AI", "a", "å", true),
+ CollationTestCase("UNICODE_AI", "a", "Å", false),
+ CollationTestCase("UNICODE_CI_AI", "a", "A", true),
+ CollationTestCase("UNICODE_CI_AI", "a", "å", true),
+ CollationTestCase("UNICODE_CI_AI", "a", "Å", true)
+ ).foreach(testCase => {
+ val collation = fetchCollation(testCase.collationName)
+ assert(collation.equalsFunction(toUTF8(testCase.s1),
toUTF8(testCase.s2)) ==
+ testCase.expectedResult)
+ })
+ Seq(
+ CollationTestCase("en", "a", "A", -1),
+ CollationTestCase("en_CI", "a", "A", 0),
+ CollationTestCase("en_AI", "a", "å", 0),
+ CollationTestCase("sv", "Kypper", "Köpfe", -1),
+ CollationTestCase("de", "Kypper", "Köpfe", 1)
+ ).foreach(testCase => {
+ val collation = fetchCollation(testCase.collationName)
+ val result = collation.comparator.compare(toUTF8(testCase.s1),
toUTF8(testCase.s2))
+ assert(Integer.signum(result) == testCase.expectedResult)
+ })
+ }
}
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 883c51bffade..b19b05859f78 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -469,7 +469,7 @@
},
"COLLATION_INVALID_NAME" : {
"message" : [
- "The value <collationName> does not represent a correct collation name.
Suggested valid collation name: [<proposal>]."
+ "The value <collationName> does not represent a correct collation name."
],
"sqlState" : "42704"
},
@@ -1921,7 +1921,7 @@
"subClass" : {
"DEFAULT_COLLATION" : {
"message" : [
- "Cannot resolve the given default collation. Did you mean
'<proposal>'?"
+ "Cannot resolve the given default collation."
]
},
"TIME_ZONE" : {
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 49b1a5312fda..e0ad8f7078ca 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.{functions => fn}
import org.apache.spark.sql.avro.{functions => avroFn}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
@@ -699,7 +700,8 @@ class PlanGenerationTestSuite
}
test("select collated string") {
- val schema = StructType(StructField("s", StringType(1)) :: Nil)
+ val schema = StructType(
+ StructField("s",
StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID)) :: Nil)
createLocalRelation(schema.catalogString).select("s")
}
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/types.proto
b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
index 48f7385330c8..4f768f201575 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
@@ -101,7 +101,7 @@ message DataType {
message String {
uint32 type_variation_reference = 1;
- uint32 collation_id = 2;
+ string collation = 2;
}
message Binary {
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 1f580a0ffc0a..f63692717947 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.common
import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SparkClassUtils
@@ -80,7 +81,7 @@ object DataTypeProtoConverter {
}
private def toCatalystStringType(t: proto.DataType.String): StringType =
- StringType(t.getCollationId)
+ StringType(if (t.getCollation.nonEmpty) t.getCollation else "UTF8_BINARY")
private def toCatalystYearMonthIntervalType(t:
proto.DataType.YearMonthInterval) = {
(t.hasStartField, t.hasEndField) match {
@@ -177,7 +178,11 @@ object DataTypeProtoConverter {
case s: StringType =>
proto.DataType
.newBuilder()
-
.setString(proto.DataType.String.newBuilder().setCollationId(s.collationId).build())
+ .setString(
+ proto.DataType.String
+ .newBuilder()
+
.setCollation(CollationFactory.fetchCollation(s.collationId).collationName)
+ .build())
.build()
case CharType(length) =>
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json
b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json
index 33f6007ec68a..e4b31258f984 100644
---
a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json
+++
b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json
@@ -18,7 +18,7 @@
"name": "c1",
"dataType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"nullable": true
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin
b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin
index da4ad9bf9a4e..c39243a10a8e 100644
Binary files
a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin
and
b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin
differ
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
index adf8cabd97b1..2a5a0ddd15f8 100644
---
a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
+++
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
@@ -305,7 +305,7 @@
"array": {
"elementType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"elements": [{
@@ -324,7 +324,7 @@
"array": {
"elementType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"elements": [{
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
index d8b4407f6cfa..359ddd61d8b7 100644
Binary files
a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
and
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
differ
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index 1e651f0455c7..aaf3a91c4fe1 100644
---
a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++
b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -200,7 +200,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -228,7 +228,7 @@
"name": "_1",
"dataType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"nullable": true
@@ -404,7 +404,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -417,7 +417,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -439,7 +439,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -461,7 +461,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -493,7 +493,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -511,7 +511,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -533,7 +533,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -576,7 +576,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -594,7 +594,7 @@
"name": "_1",
"dataType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"nullable": true
@@ -608,7 +608,7 @@
},
"valueType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueContainsNull": true
@@ -640,7 +640,7 @@
"map": {
"keyType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueType": {
@@ -666,7 +666,7 @@
"name": "_1",
"dataType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"nullable": true
@@ -680,7 +680,7 @@
},
"valueType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"valueContainsNull": true
@@ -700,7 +700,7 @@
},
"valueType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"keys": [{
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
index b3f61830bee0..71640717c12e 100644
Binary files
a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
and
b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
differ
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json
b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json
index 537c218952a4..f29245374e6e 100644
---
a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json
+++
b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json
@@ -18,7 +18,7 @@
"name": "c1",
"dataType": {
"string": {
- "collationId": 0
+ "collation": "UTF8_BINARY"
}
},
"nullable": true
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin
b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin
index 297ab2bf0262..1ce2e676ce30 100644
Binary files
a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin
and
b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin
differ
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py
b/python/pyspark/sql/connect/proto/types_pb2.py
index 65e5860b5dc6..1022605fb160 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -29,7 +29,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xec!\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01
\x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02
\x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03
\x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04
\x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05
\x01(\x [...]
+
b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xe7!\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01
\x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02
\x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03
\x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04
\x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05
\x01(\x [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -42,7 +42,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
)
_DATATYPE._serialized_start = 45
- _DATATYPE._serialized_end = 4377
+ _DATATYPE._serialized_end = 4372
_DATATYPE_BOOLEAN._serialized_start = 1595
_DATATYPE_BOOLEAN._serialized_end = 1662
_DATATYPE_BYTE._serialized_start = 1664
@@ -58,41 +58,41 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_DATATYPE_DOUBLE._serialized_start = 1999
_DATATYPE_DOUBLE._serialized_end = 2065
_DATATYPE_STRING._serialized_start = 2067
- _DATATYPE_STRING._serialized_end = 2168
- _DATATYPE_BINARY._serialized_start = 2170
- _DATATYPE_BINARY._serialized_end = 2236
- _DATATYPE_NULL._serialized_start = 2238
- _DATATYPE_NULL._serialized_end = 2302
- _DATATYPE_TIMESTAMP._serialized_start = 2304
- _DATATYPE_TIMESTAMP._serialized_end = 2373
- _DATATYPE_DATE._serialized_start = 2375
- _DATATYPE_DATE._serialized_end = 2439
- _DATATYPE_TIMESTAMPNTZ._serialized_start = 2441
- _DATATYPE_TIMESTAMPNTZ._serialized_end = 2513
- _DATATYPE_CALENDARINTERVAL._serialized_start = 2515
- _DATATYPE_CALENDARINTERVAL._serialized_end = 2591
- _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2594
- _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2773
- _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2776
- _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2953
- _DATATYPE_CHAR._serialized_start = 2955
- _DATATYPE_CHAR._serialized_end = 3043
- _DATATYPE_VARCHAR._serialized_start = 3045
- _DATATYPE_VARCHAR._serialized_end = 3136
- _DATATYPE_DECIMAL._serialized_start = 3139
- _DATATYPE_DECIMAL._serialized_end = 3292
- _DATATYPE_STRUCTFIELD._serialized_start = 3295
- _DATATYPE_STRUCTFIELD._serialized_end = 3456
- _DATATYPE_STRUCT._serialized_start = 3458
- _DATATYPE_STRUCT._serialized_end = 3585
- _DATATYPE_ARRAY._serialized_start = 3588
- _DATATYPE_ARRAY._serialized_end = 3750
- _DATATYPE_MAP._serialized_start = 3753
- _DATATYPE_MAP._serialized_end = 3972
- _DATATYPE_VARIANT._serialized_start = 3974
- _DATATYPE_VARIANT._serialized_end = 4041
- _DATATYPE_UDT._serialized_start = 4044
- _DATATYPE_UDT._serialized_end = 4315
- _DATATYPE_UNPARSED._serialized_start = 4317
- _DATATYPE_UNPARSED._serialized_end = 4369
+ _DATATYPE_STRING._serialized_end = 2163
+ _DATATYPE_BINARY._serialized_start = 2165
+ _DATATYPE_BINARY._serialized_end = 2231
+ _DATATYPE_NULL._serialized_start = 2233
+ _DATATYPE_NULL._serialized_end = 2297
+ _DATATYPE_TIMESTAMP._serialized_start = 2299
+ _DATATYPE_TIMESTAMP._serialized_end = 2368
+ _DATATYPE_DATE._serialized_start = 2370
+ _DATATYPE_DATE._serialized_end = 2434
+ _DATATYPE_TIMESTAMPNTZ._serialized_start = 2436
+ _DATATYPE_TIMESTAMPNTZ._serialized_end = 2508
+ _DATATYPE_CALENDARINTERVAL._serialized_start = 2510
+ _DATATYPE_CALENDARINTERVAL._serialized_end = 2586
+ _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2589
+ _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2768
+ _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2771
+ _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2948
+ _DATATYPE_CHAR._serialized_start = 2950
+ _DATATYPE_CHAR._serialized_end = 3038
+ _DATATYPE_VARCHAR._serialized_start = 3040
+ _DATATYPE_VARCHAR._serialized_end = 3131
+ _DATATYPE_DECIMAL._serialized_start = 3134
+ _DATATYPE_DECIMAL._serialized_end = 3287
+ _DATATYPE_STRUCTFIELD._serialized_start = 3290
+ _DATATYPE_STRUCTFIELD._serialized_end = 3451
+ _DATATYPE_STRUCT._serialized_start = 3453
+ _DATATYPE_STRUCT._serialized_end = 3580
+ _DATATYPE_ARRAY._serialized_start = 3583
+ _DATATYPE_ARRAY._serialized_end = 3745
+ _DATATYPE_MAP._serialized_start = 3748
+ _DATATYPE_MAP._serialized_end = 3967
+ _DATATYPE_VARIANT._serialized_start = 3969
+ _DATATYPE_VARIANT._serialized_end = 4036
+ _DATATYPE_UDT._serialized_start = 4039
+ _DATATYPE_UDT._serialized_end = 4310
+ _DATATYPE_UNPARSED._serialized_start = 4312
+ _DATATYPE_UNPARSED._serialized_end = 4364
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi
b/python/pyspark/sql/connect/proto/types_pb2.pyi
index e6b34d3485c2..b37621104537 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -178,22 +178,19 @@ class DataType(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
- COLLATION_ID_FIELD_NUMBER: builtins.int
+ COLLATION_FIELD_NUMBER: builtins.int
type_variation_reference: builtins.int
- collation_id: builtins.int
+ collation: builtins.str
def __init__(
self,
*,
type_variation_reference: builtins.int = ...,
- collation_id: builtins.int = ...,
+ collation: builtins.str = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "collation_id",
- b"collation_id",
- "type_variation_reference",
- b"type_variation_reference",
+ "collation", b"collation", "type_variation_reference",
b"type_variation_reference"
],
) -> None: ...
diff --git a/python/pyspark/sql/connect/types.py
b/python/pyspark/sql/connect/types.py
index 351fa0165965..885ce62e7db6 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -129,7 +129,7 @@ def pyspark_types_to_proto_types(data_type: DataType) ->
pb2.DataType:
if isinstance(data_type, NullType):
ret.null.CopyFrom(pb2.DataType.NULL())
elif isinstance(data_type, StringType):
- ret.string.collation_id = data_type.collationId
+ ret.string.collation = data_type.collation
elif isinstance(data_type, BooleanType):
ret.boolean.CopyFrom(pb2.DataType.Boolean())
elif isinstance(data_type, BinaryType):
@@ -229,7 +229,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType)
-> DataType:
s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
return DecimalType(precision=p, scale=s)
elif schema.HasField("string"):
- return StringType.fromCollationId(schema.string.collation_id)
+ collation = schema.string.collation if schema.string.collation != ""
else "UTF8_BINARY"
+ return StringType(collation)
elif schema.HasField("char"):
return CharType(schema.char.length)
elif schema.HasField("var_char"):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 563c63f5dfb1..c72ff72ce426 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -280,26 +280,13 @@ class StringType(AtomicType):
name of the collation, default is UTF8_BINARY.
"""
- collationNames = ["UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE",
"UNICODE_CI"]
providerSpark = "spark"
providerICU = "icu"
providers = [providerSpark, providerICU]
- def __init__(self, collation: Optional[str] = None):
+ def __init__(self, collation: str = "UTF8_BINARY"):
self.typeName = self._type_name # type: ignore[method-assign]
- self.collationId = 0 if collation is None else
self.collationNameToId(collation)
-
- @classmethod
- def fromCollationId(self, collationId: int) -> "StringType":
- return StringType(StringType.collationNames[collationId])
-
- @classmethod
- def collationIdToName(cls, collationId: int) -> str:
- return StringType.collationNames[collationId]
-
- @classmethod
- def collationNameToId(cls, collationName: str) -> int:
- return StringType.collationNames.index(collationName)
+ self.collation = collation
@classmethod
def collationProvider(cls, collationName: str) -> str:
@@ -312,7 +299,7 @@ class StringType(AtomicType):
if self.isUTF8BinaryCollation():
return "string"
- return f"string collate ${self.collationIdToName(self.collationId)}"
+ return f"string collate ${self.collation}"
# For backwards compatibility and compatibility with other readers all
string types
# are serialized in json as regular strings and the collation info is
written to
@@ -322,13 +309,11 @@ class StringType(AtomicType):
def __repr__(self) -> str:
return (
- "StringType('%s')" % StringType.collationNames[self.collationId]
- if self.collationId != 0
- else "StringType()"
+ "StringType()" if self.isUTF8BinaryCollation() else
"StringType('%s')" % self.collation
)
def isUTF8BinaryCollation(self) -> bool:
- return self.collationId == 0
+ return self.collation == "UTF8_BINARY"
class CharType(AtomicType):
@@ -1046,7 +1031,7 @@ class StructField(DataType):
def schemaCollationValue(self, dt: DataType) -> str:
assert isinstance(dt, StringType)
- collationName = StringType.collationIdToName(dt.collationId)
+ collationName = dt.collation
provider = StringType.collationProvider(collationName)
return f"{provider}.{collationName}"
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 06e0c6eda589..f6f5b23b7f10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -772,12 +772,17 @@ object SQLConf {
" produced by a builtin function such as to_char or CAST")
.version("4.0.0")
.stringConf
- .checkValue(CollationFactory.isValidCollation,
+ .checkValue(
+ collationName => {
+ try {
+ CollationFactory.fetchCollation(collationName)
+ true
+ } catch {
+ case e: SparkException if e.getErrorClass ==
"COLLATION_INVALID_NAME" => false
+ }
+ },
"DEFAULT_COLLATION",
- name =>
- Map(
- "proposal" -> CollationFactory.getClosestCollation(name)
- ))
+ _ => Map())
.createWithDefault("UTF8_BINARY")
val FETCH_SHUFFLE_BLOCKS_IN_BATCH =
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
index 537bac9aae9b..c3495a0c112c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
@@ -62,7 +62,7 @@ class CollationExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
exception = intercept[SparkException] { Collate(Literal("abc"),
"UTF8_BS") },
errorClass = "COLLATION_INVALID_NAME",
sqlState = "42704",
- parameters = Map("proposal" -> "UTF8_BINARY", "collationName" ->
"UTF8_BS"))
+ parameters = Map("collationName" -> "UTF8_BS"))
}
test("collation on non-explicit default collation") {
@@ -71,7 +71,8 @@ class CollationExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
test("collation on explicitly collated string") {
checkEvaluation(
- Collation(Literal.create("abc", StringType(1))).replacement,
+ Collation(Literal.create("abc",
+
StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID))).replacement,
"UTF8_BINARY_LCASE")
checkEvaluation(
Collation(Collate(Literal("abc"), "UTF8_BINARY_LCASE")).replacement,
@@ -161,4 +162,32 @@ class CollationExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(ArrayExcept(left, right), out)
}
}
+
+ test("collation name normalization in collation expression") {
+ Seq(
+ ("en_USA", "en_USA"),
+ ("en_CS", "en"),
+ ("en_AS", "en"),
+ ("en_CS_AS", "en"),
+ ("en_AS_CS", "en"),
+ ("en_CI", "en_CI"),
+ ("en_AI", "en_AI"),
+ ("en_AI_CI", "en_CI_AI"),
+ ("en_CI_AI", "en_CI_AI"),
+ ("en_CS_AI", "en_AI"),
+ ("en_AI_CS", "en_AI"),
+ ("en_CI_AS", "en_CI"),
+ ("en_AS_CI", "en_CI"),
+ ("en_USA_AI_CI", "en_USA_CI_AI"),
+ // randomized case
+ ("EN_USA", "en_USA"),
+ ("SR_CYRL", "sr_Cyrl"),
+ ("sr_cyrl_srb", "sr_Cyrl_SRB"),
+ ("sR_cYRl_sRb", "sr_Cyrl_SRB")
+ ).foreach {
+ case (collation, normalized) =>
+ checkEvaluation(Collation(Literal.create("abc",
StringType(collation))).replacement,
+ normalized)
+ }
+ }
}
diff --git a/sql/core/src/test/resources/collations/ICU-collations-map.md
b/sql/core/src/test/resources/collations/ICU-collations-map.md
new file mode 100644
index 000000000000..598c3c4b4024
--- /dev/null
+++ b/sql/core/src/test/resources/collations/ICU-collations-map.md
@@ -0,0 +1,143 @@
+<!-- Automatically generated by ICUCollationsMapSuite -->
+## ICU locale ids to name map
+| Locale id | Locale name |
+| --------- | ----------- |
+| 0 | UNICODE |
+| 1 | af |
+| 2 | am |
+| 3 | ar |
+| 4 | ar_SAU |
+| 5 | as |
+| 6 | az |
+| 7 | be |
+| 8 | bg |
+| 9 | bn |
+| 10 | bo |
+| 11 | br |
+| 12 | bs |
+| 13 | bs_Cyrl |
+| 14 | ca |
+| 15 | ceb |
+| 16 | chr |
+| 17 | cs |
+| 18 | cy |
+| 19 | da |
+| 20 | de |
+| 21 | de_AUT |
+| 22 | dsb |
+| 23 | dz |
+| 24 | ee |
+| 25 | el |
+| 26 | en |
+| 27 | en_USA |
+| 28 | eo |
+| 29 | es |
+| 30 | et |
+| 31 | fa |
+| 32 | fa_AFG |
+| 33 | ff |
+| 34 | ff_Adlm |
+| 35 | fi |
+| 36 | fil |
+| 37 | fo |
+| 38 | fr |
+| 39 | fr_CAN |
+| 40 | fy |
+| 41 | ga |
+| 42 | gl |
+| 43 | gu |
+| 44 | ha |
+| 45 | haw |
+| 46 | he |
+| 47 | he_ISR |
+| 48 | hi |
+| 49 | hr |
+| 50 | hsb |
+| 51 | hu |
+| 52 | hy |
+| 53 | id |
+| 54 | id_IDN |
+| 55 | ig |
+| 56 | is |
+| 57 | it |
+| 58 | ja |
+| 59 | ka |
+| 60 | kk |
+| 61 | kl |
+| 62 | km |
+| 63 | kn |
+| 64 | ko |
+| 65 | kok |
+| 66 | ku |
+| 67 | ky |
+| 68 | lb |
+| 69 | lkt |
+| 70 | ln |
+| 71 | lo |
+| 72 | lt |
+| 73 | lv |
+| 74 | mk |
+| 75 | ml |
+| 76 | mn |
+| 77 | mr |
+| 78 | ms |
+| 79 | mt |
+| 80 | my |
+| 81 | nb |
+| 82 | nb_NOR |
+| 83 | ne |
+| 84 | nl |
+| 85 | nn |
+| 86 | no |
+| 87 | om |
+| 88 | or |
+| 89 | pa |
+| 90 | pa_Guru |
+| 91 | pa_Guru_IND |
+| 92 | pl |
+| 93 | ps |
+| 94 | pt |
+| 95 | ro |
+| 96 | ru |
+| 97 | sa |
+| 98 | se |
+| 99 | si |
+| 100 | sk |
+| 101 | sl |
+| 102 | smn |
+| 103 | sq |
+| 104 | sr |
+| 105 | sr_Cyrl |
+| 106 | sr_Cyrl_BIH |
+| 107 | sr_Cyrl_MNE |
+| 108 | sr_Cyrl_SRB |
+| 109 | sr_Latn |
+| 110 | sr_Latn_BIH |
+| 111 | sr_Latn_SRB |
+| 112 | sv |
+| 113 | sw |
+| 114 | ta |
+| 115 | te |
+| 116 | th |
+| 117 | tk |
+| 118 | to |
+| 119 | tr |
+| 120 | ug |
+| 121 | uk |
+| 122 | ur |
+| 123 | uz |
+| 124 | vi |
+| 125 | wae |
+| 126 | wo |
+| 127 | xh |
+| 128 | yi |
+| 129 | yo |
+| 130 | zh |
+| 131 | zh_Hans |
+| 132 | zh_Hans_CHN |
+| 133 | zh_Hans_SGP |
+| 134 | zh_Hant |
+| 135 | zh_Hant_HKG |
+| 136 | zh_Hant_MAC |
+| 137 | zh_Hant_TWN |
+| 138 | zu |
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
index d242a60a17c1..9a1f4ed1f8e5 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
@@ -312,3 +312,80 @@ select array_except(array('aaa' collate
utf8_binary_lcase), array('AAA' collate
-- !query analysis
Project [array_except(array(collate(aaa, utf8_binary_lcase)),
array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)),
array(collate(AAA)))#x]
+- OneRowRelation
+
+
+-- !query
+select 'a' collate unicode < 'A'
+-- !query analysis
+Project [(collate(a, unicode) < cast(A as string collate UNICODE)) AS
(collate(a) < A)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate unicode_ci = 'A'
+-- !query analysis
+Project [(collate(a, unicode_ci) = cast(A as string collate UNICODE_CI)) AS
(collate(a) = A)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate unicode_ai = 'å'
+-- !query analysis
+Project [(collate(a, unicode_ai) = cast(å as string collate UNICODE_AI)) AS
(collate(a) = å)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate unicode_ci_ai = 'Å'
+-- !query analysis
+Project [(collate(a, unicode_ci_ai) = cast(Å as string collate UNICODE_CI_AI))
AS (collate(a) = Å)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate en < 'A'
+-- !query analysis
+Project [(collate(a, en) < cast(A as string collate en)) AS (collate(a) < A)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate en_ci = 'A'
+-- !query analysis
+Project [(collate(a, en_ci) = cast(A as string collate en_CI)) AS (collate(a)
= A)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate en_ai = 'å'
+-- !query analysis
+Project [(collate(a, en_ai) = cast(å as string collate en_AI)) AS (collate(a)
= å)#x]
++- OneRowRelation
+
+
+-- !query
+select 'a' collate en_ci_ai = 'Å'
+-- !query analysis
+Project [(collate(a, en_ci_ai) = cast(Å as string collate en_CI_AI)) AS
(collate(a) = Å)#x]
++- OneRowRelation
+
+
+-- !query
+select 'Kypper' collate sv < 'Köpfe'
+-- !query analysis
+Project [(collate(Kypper, sv) < cast(Köpfe as string collate sv)) AS
(collate(Kypper) < Köpfe)#x]
++- OneRowRelation
+
+
+-- !query
+select 'Kypper' collate de > 'Köpfe'
+-- !query analysis
+Project [(collate(Kypper, de) > cast(Köpfe as string collate de)) AS
(collate(Kypper) > Köpfe)#x]
++- OneRowRelation
+
+
+-- !query
+select 'I' collate tr_ci = 'ı'
+-- !query analysis
+Project [(collate(I, tr_ci) = cast(ı as string collate tr_CI)) AS (collate(I)
= ı)#x]
++- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
index 619eb4470e9a..6bb0a0163443 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -77,3 +77,16 @@ select array_distinct(array('aaa' collate utf8_binary_lcase,
'AAA' collate utf8_
select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate
utf8_binary_lcase));
select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA'
collate utf8_binary_lcase));
select array_except(array('aaa' collate utf8_binary_lcase), array('AAA'
collate utf8_binary_lcase));
+
+-- ICU collations (all statements return true)
+select 'a' collate unicode < 'A';
+select 'a' collate unicode_ci = 'A';
+select 'a' collate unicode_ai = 'å';
+select 'a' collate unicode_ci_ai = 'Å';
+select 'a' collate en < 'A';
+select 'a' collate en_ci = 'A';
+select 'a' collate en_ai = 'å';
+select 'a' collate en_ci_ai = 'Å';
+select 'Kypper' collate sv < 'Köpfe';
+select 'Kypper' collate de > 'Köpfe';
+select 'I' collate tr_ci = 'ı';
diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
index 4485191ba1f3..96c875306d35 100644
--- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
@@ -339,3 +339,91 @@ select array_except(array('aaa' collate
utf8_binary_lcase), array('AAA' collate
struct<array_except(array(collate(aaa)), array(collate(AAA))):array<string
collate UTF8_BINARY_LCASE>>
-- !query output
[]
+
+
+-- !query
+select 'a' collate unicode < 'A'
+-- !query schema
+struct<(collate(a) < A):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate unicode_ci = 'A'
+-- !query schema
+struct<(collate(a) = A):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate unicode_ai = 'å'
+-- !query schema
+struct<(collate(a) = å):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate unicode_ci_ai = 'Å'
+-- !query schema
+struct<(collate(a) = Å):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate en < 'A'
+-- !query schema
+struct<(collate(a) < A):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate en_ci = 'A'
+-- !query schema
+struct<(collate(a) = A):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate en_ai = 'å'
+-- !query schema
+struct<(collate(a) = å):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'a' collate en_ci_ai = 'Å'
+-- !query schema
+struct<(collate(a) = Å):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'Kypper' collate sv < 'Köpfe'
+-- !query schema
+struct<(collate(Kypper) < Köpfe):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'Kypper' collate de > 'Köpfe'
+-- !query schema
+struct<(collate(Kypper) > Köpfe):boolean>
+-- !query output
+true
+
+
+-- !query
+select 'I' collate tr_ci = 'ı'
+-- !query schema
+struct<(collate(I) = ı):boolean>
+-- !query output
+true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 657fd4504cac..4f8587395b3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -152,7 +152,7 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
exception = intercept[SparkException] { sql("select 'aaa' collate
UTF8_BS") },
errorClass = "COLLATION_INVALID_NAME",
sqlState = "42704",
- parameters = Map("proposal" -> "UTF8_BINARY", "collationName" ->
"UTF8_BS"))
+ parameters = Map("collationName" -> "UTF8_BS"))
}
test("disable bucketing on collated string column") {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala
new file mode 100644
index 000000000000..42d486bd7545
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile,
CollationFactory}
+
+// scalastyle:off line.size.limit
+/**
+ * Guard against breaking changes in ICU locale names and codes supported by
the Collator class and provided by the CollationFactory.
+ * The map consists of rows of pairs (locale name, locale id); a locale name
consists of three parts:
+ * - 2-letter lowercase language code
+ * - 4-letter script code (optional)
+ * - 3-letter uppercase country code (optional)
+ *
+ * To re-generate collations map golden file, run:
+ * {{{
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly
org.apache.spark.sql.ICUCollationsMapSuite"
+ * }}}
+ */
+// scalastyle:on line.size.limit
+class ICUCollationsMapSuite extends SparkFunSuite {
+
+ private val collationsMapFile = {
+ getWorkspaceFilePath("sql", "core", "src", "test", "resources",
+ "collations", "ICU-collations-map.md").toFile
+ }
+
+ if (regenerateGoldenFiles) {
+ val map = CollationFactory.getICULocaleNames
+ val mapOutput = map.zipWithIndex.map {
+ case (localeName, idx) => s"| $idx | $localeName |" }.mkString("\n")
+ val goldenOutput = {
+ s"<!-- Automatically generated by ${getClass.getSimpleName} -->\n" +
+ "## ICU locale ids to name map\n" +
+ "| Locale id | Locale name |\n" +
+ "| --------- | ----------- |\n" +
+ mapOutput + "\n"
+ }
+ val parent = collationsMapFile.getParentFile
+ if (!parent.exists()) {
+ assert(parent.mkdirs(), "Could not create directory: " + parent)
+ }
+ stringToFile(collationsMapFile, goldenOutput)
+ }
+
+ test("ICU locales map breaking change") {
+ val goldenLines = fileToString(collationsMapFile).split('\n')
+ val goldenRelevantLines = goldenLines.slice(4, goldenLines.length) // skip
header
+ val input = goldenRelevantLines.map(
+ s => (s.split('|')(2).strip(), s.split('|')(1).strip().toInt))
+ assert(input sameElements CollationFactory.getICULocaleNames.zipWithIndex)
+ }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 213dfd32c869..8d291591c5f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -518,8 +518,7 @@ class SQLConfSuite extends QueryTest with
SharedSparkSession {
errorClass = "INVALID_CONF_VALUE.DEFAULT_COLLATION",
parameters = Map(
"confValue" -> "UNICODE_C",
- "confName" -> "spark.sql.session.collation.default",
- "proposal" -> "UNICODE_CI"
+ "confName" -> "spark.sql.session.collation.default"
))
withSQLConf(SQLConf.COLLATION_ENABLED.key -> "false") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]