mkaravel commented on code in PR #46180:
URL: https://github.com/apache/spark/pull/46180#discussion_r1607446081


##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -288,13 +338,24 @@ private static int collationNameToId(String 
collationName) throws SparkException
 
     private static class CollationSpecUTF8Binary extends CollationSpec {
 
-      private static final int CASE_SENSITIVITY_OFFSET = 0;
-      private static final int CASE_SENSITIVITY_MASK = 0b1;
-
+      /**
+       * Bit 0 in collation id having value
+       * 0 for plain UTF8_BINARY and 1 for UTF8_BINARY_LCASE collation.

Review Comment:
   Please use the full 100 characters per line that are allowed for comments.
   I am not suggesting packing everything together all the time, but in cases 
like the one above, I would expect the first line to be "filled" before 
moving on to the second line.



##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -221,41 +231,80 @@ public Collation(
      */
     private abstract static class CollationSpec {
 
+      /**
+       * Bit 30 in collation id having value 0 for predefined and 1 for 
user-defined collation.
+       */
       private enum DefinitionOrigin {
         PREDEFINED, USER_DEFINED
       }
 
+      /**
+       * Bit 29 in collation id having value
+       * 0 for UTF8_BINARY family and 1 for ICU family of collations.
+       */
       protected enum ImplementationProvider {
         UTF8_BINARY, ICU
       }
 
+      /**
+       * Offset in binary collation id layout.
+       */
       private static final int DEFINITION_ORIGIN_OFFSET = 30;
+
+      /**
+       * Bitmask corresponding to width in bits in binary collation id layout.
+       */
       private static final int DEFINITION_ORIGIN_MASK = 0b1;
+
+      /**
+       * Offset in binary collation id layout.
+       */
       protected static final int IMPLEMENTATION_PROVIDER_OFFSET = 29;
+
+      /**
+       * Bitmask corresponding to width in bits in binary collation id layout.
+       */
       protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1;
 
       private static final int INDETERMINATE_COLLATION_ID = -1;
 
+      /**
+       * Thread-safe cache mapping collation ids to corresponding `Collation` 
instances.
+       * We add entries to this cache lazily as new `Collation` instances are 
requested.
+       */
       private static final Map<Integer, Collation> collationMap = new 
ConcurrentHashMap<>();
 
+      /**
+       * Utility function to retrieve `ImplementationProvider` enum instance 
from collation id.
+       */
       private static ImplementationProvider getImplementationProvider(int 
collationId) {
         return 
ImplementationProvider.values()[SpecifierUtils.getSpecValue(collationId,
           IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK)];
       }
 
+      /**
+       * Utility function to retrieve `DefinitionOrigin` enum instance from 
collation id.
+       */
       private static DefinitionOrigin getDefinitionOrigin(int collationId) {
         return 
DefinitionOrigin.values()[SpecifierUtils.getSpecValue(collationId,
           DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)];
       }
 
+      /**
+       * Main entry point for retrieving `Collation` instance from collation 
id.
+       */
       private static Collation fetchCollation(int collationId) {
+        // User-defined collations and INDETERMINATE collations cannot produce 
`Collation` instance

Review Comment:
   ```suggestion
           // User-defined collations and INDETERMINATE collations cannot produce a `Collation` instance.
   ```
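
The hunk above reads single-bit fields out of the collation ID through `SpecifierUtils.getSpecValue`, whose body is not shown in this diff. Below is a minimal sketch of what such a shift-and-mask helper might look like, assuming the layout from the Javadoc (bit 30 for definition origin, bit 29 for implementation provider). The class name `CollationIdBits` is hypothetical and the helper body is an assumption, not the PR's code.

```java
// Hypothetical stand-alone sketch; not the PR's actual SpecifierUtils.
final class CollationIdBits {
  enum DefinitionOrigin { PREDEFINED, USER_DEFINED }
  enum ImplementationProvider { UTF8_BINARY, ICU }

  static final int DEFINITION_ORIGIN_OFFSET = 30;
  static final int DEFINITION_ORIGIN_MASK = 0b1;
  static final int IMPLEMENTATION_PROVIDER_OFFSET = 29;
  static final int IMPLEMENTATION_PROVIDER_MASK = 0b1;

  // Assumed behavior: shift the field down to bit 0, then keep only the masked bits.
  static int getSpecValue(int collationId, int offset, int mask) {
    return (collationId >>> offset) & mask;
  }

  static DefinitionOrigin getDefinitionOrigin(int collationId) {
    return DefinitionOrigin.values()[
      getSpecValue(collationId, DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)];
  }

  static ImplementationProvider getImplementationProvider(int collationId) {
    return ImplementationProvider.values()[
      getSpecValue(collationId, IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK)];
  }

  public static void main(String[] args) {
    int icuId = 1 << 29;         // only the provider bit is set
    int userDefinedId = 1 << 30; // only the definition-origin bit is set
    System.out.println(getImplementationProvider(icuId));   // ICU
    System.out.println(getDefinitionOrigin(userDefinedId)); // USER_DEFINED
    System.out.println(getDefinitionOrigin(icuId));         // PREDEFINED
  }
}
```

Because each mask is a single bit, the extracted value is always 0 or 1, so indexing `values()` on a two-element enum cannot go out of bounds in this sketch.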



##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -333,77 +397,136 @@ private static CollationSpecUTF8Binary 
fromCollationId(int collationId) {
       @Override
       protected Collation buildCollation() {
         if (collationId == UTF8_BINARY_COLLATION_ID) {
-          return new Collation("UTF8_BINARY", null, UTF8String::binaryCompare, 
"1.0",
-            s -> (long) s.hashCode(), true, true, false);
+          return new Collation(
+            "UTF8_BINARY",
+            PROVIDER_SPARK,
+            null,
+            UTF8String::binaryCompare,
+            "1.0",
+            s -> (long) s.hashCode(),
+            true,
+            true,
+            false);

Review Comment:
   Way better! Thank you!
   Since this is Java and not Scala (so there are no named arguments), I 
suggest annotating the boolean constants with inline comments. Something like:
   ```java
   /*isBinaryCollation=*/ true,
   ```
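
Applied to all three trailing booleans, the convention could look like the sketch below. The parameter names come from the removed `Collation` constructor overload quoted elsewhere in this review (`supportsBinaryEquality`, `supportsBinaryOrdering`, `supportsLowercaseEquality`). It is an assumption that the new constructor keeps them in this order, and `CollationFlags` is a hypothetical stand-in so that the example compiles on its own.

```java
// Hypothetical stand-in for the boolean tail of the Collation constructor.
final class CollationFlags {
  final boolean supportsBinaryEquality;
  final boolean supportsBinaryOrdering;
  final boolean supportsLowercaseEquality;

  CollationFlags(
      boolean supportsBinaryEquality,
      boolean supportsBinaryOrdering,
      boolean supportsLowercaseEquality) {
    this.supportsBinaryEquality = supportsBinaryEquality;
    this.supportsBinaryOrdering = supportsBinaryOrdering;
    this.supportsLowercaseEquality = supportsLowercaseEquality;
  }

  // Each literal is preceded by a comment naming the parameter it binds to,
  // so a reader does not have to look up the constructor signature.
  static CollationFlags utf8Binary() {
    return new CollationFlags(
      /*supportsBinaryEquality=*/ true,
      /*supportsBinaryOrdering=*/ true,
      /*supportsLowercaseEquality=*/ false);
  }
}
```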



##########
common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala:
##########
@@ -152,4 +219,238 @@ class CollationFactorySuite extends AnyFunSuite with 
Matchers { // scalastyle:ig
       }
     })
   }
+
+  test("test collation caching") {
+    Seq(
+      "UTF8_BINARY",
+      "UTF8_BINARY_LCASE",
+      "UNICODE",
+      "UNICODE_CI",
+      "UNICODE_AI",
+      "UNICODE_CI_AI",
+      "UNICODE_AI_CI"
+    ).foreach(collationId => {
+      val col1 = fetchCollation(collationId)
+      val col2 = fetchCollation(collationId)
+      assert(col1 eq col2) // Check for reference equality.
+    })
+  }
+
+  test("collations with ICU non-root localization") {
+    Seq(
+      // Language only.
+      "en",
+      "en_CS",
+      "en_CI",
+      "en_AS",
+      "en_AI",
+      // Language + 3-letter country code.
+      "en_USA",
+      "en_USA_CS",
+      "en_USA_CI",
+      "en_USA_AS",
+      "en_USA_AI",
+      // Language + script code.
+      "sr_Cyrl",
+      "sr_Cyrl_CS",
+      "sr_Cyrl_CI",
+      "sr_Cyrl_AS",
+      "sr_Cyrl_AI",
+      // Language + script code + 3-letter country code.
+      "sr_Cyrl_SRB",
+      "sr_Cyrl_SRB_CS",
+      "sr_Cyrl_SRB_CI",
+      "sr_Cyrl_SRB_AS",
+      "sr_Cyrl_SRB_AI"
+    ).foreach(collationICU => {
+      val col = fetchCollation(collationICU)
+      assert(col.collator.getLocale(ULocale.VALID_LOCALE) != ULocale.ROOT)
+    })
+  }
+
+  test("invalid names of collations with ICU non-root localization") {
+    Seq(
+      "en_US", // Must use 3-letter country code
+      "enn",
+      "en_AAA",
+      "en_Something",
+      "en_Something_USA",
+      "en_LCASE",
+      "en_UCASE",
+      "en_CI_LCASE",
+      "en_CI_UCASE",
+      "en_CI_UNSPECIFIED",
+      "en_USA_UNSPECIFIED",
+      "en_USA_UNSPECIFIED_CI",
+      "en_INDETERMINATE",
+      "en_USA_INDETERMINATE",
+      "en_Latn_USA", // Use en_USA instead.
+      "en_Cyrl_USA",
+      "en_USA_AAA",
+      "sr_Cyrl_SRB_AAA",
+      // Invalid ordering of language, script and country code.
+      "USA_en",
+      "sr_SRB_Cyrl",
+      "SRB_sr",
+      "SRB_sr_Cyrl",
+      "SRB_Cyrl_sr",
+      "Cyrl_sr",
+      "Cyrl_sr_SRB",
+      "Cyrl_SRB_sr",
+      // Collation specifiers in the middle of locale.
+      "CI_en",
+      "USA_CI_en",
+      "en_CI_USA",
+      "CI_sr_Cyrl_SRB",
+      "sr_CI_Cyrl_SRB",
+      "sr_Cyrl_CI_SRB",
+      "CI_Cyrl_sr",
+      "Cyrl_CI_sr",
+      "Cyrl_CI_sr_SRB",
+      "Cyrl_sr_CI_SRB"
+    ).foreach(collationName => {
+      val error = intercept[SparkException] {
+        fetchCollation(collationName)
+      }
+
+      assert(error.getErrorClass === "COLLATION_INVALID_NAME")
+      assert(error.getMessageParameters.asScala === Map("collationName" -> 
collationName))
+    })
+  }
+
+  test("collations name normalization for ICU non-root localization") {
+    Seq(
+      ("en_USA", "en_USA"),
+      ("en_CS", "en"),
+      ("en_AS", "en"),
+      ("en_CS_AS", "en"),
+      ("en_AS_CS", "en"),
+      ("en_CI", "en_CI"),
+      ("en_AI", "en_AI"),
+      ("en_AI_CI", "en_CI_AI"),
+      ("en_CI_AI", "en_CI_AI"),
+      ("en_CS_AI", "en_AI"),
+      ("en_AI_CS", "en_AI"),
+      ("en_CI_AS", "en_CI"),
+      ("en_AS_CI", "en_CI"),
+      ("en_USA_AI_CI", "en_USA_CI_AI"),
+      // Randomized case.
+      ("EN_USA", "en_USA"),
+      ("SR_CYRL", "sr_Cyrl"),
+      ("sr_cyrl_srb", "sr_Cyrl_SRB"),
+      ("sR_cYRl_sRb", "sr_Cyrl_SRB")
+    ).foreach {
+      case (name, normalized) =>
+        val col = fetchCollation(name)
+        assert(col.collationName == normalized)
+    }
+  }
+
+  test("invalid collationId") {
+    val badCollationIds = Seq(
+      INDETERMINATE_COLLATION_ID, // Indeterminate collation.
+      1 << 30, // User-defined collation range.
+      (1 << 30) | 1, // User-defined collation range.
+      (1 << 30) | (1 << 29), // User-defined collation range.
+      1 << 1, // UTF8_BINARY mandatory zero bit 1 breach.
+      1 << 2, // UTF8_BINARY mandatory zero bit 2 breach.
+      1 << 3, // UTF8_BINARY mandatory zero bit 3 breach.
+      1 << 4, // UTF8_BINARY mandatory zero bit 4 breach.
+      1 << 5, // UTF8_BINARY mandatory zero bit 5 breach.
+      1 << 6, // UTF8_BINARY mandatory zero bit 6 breach.
+      1 << 7, // UTF8_BINARY mandatory zero bit 7 breach.
+      1 << 8, // UTF8_BINARY mandatory zero bit 8 breach.
+      1 << 9, // UTF8_BINARY mandatory zero bit 9 breach.
+      1 << 10, // UTF8_BINARY mandatory zero bit 10 breach.
+      1 << 11, // UTF8_BINARY mandatory zero bit 11 breach.
+      1 << 12, // UTF8_BINARY mandatory zero bit 12 breach.
+      1 << 13, // UTF8_BINARY mandatory zero bit 13 breach.
+      1 << 14, // UTF8_BINARY mandatory zero bit 14 breach.
+      1 << 15, // UTF8_BINARY mandatory zero bit 15 breach.
+      1 << 16, // UTF8_BINARY mandatory zero bit 16 breach.
+      1 << 17, // UTF8_BINARY mandatory zero bit 17 breach.
+      1 << 18, // UTF8_BINARY mandatory zero bit 18 breach.
+      1 << 19, // UTF8_BINARY mandatory zero bit 19 breach.
+      1 << 20, // UTF8_BINARY mandatory zero bit 20 breach.
+      1 << 23, // UTF8_BINARY mandatory zero bit 23 breach.
+      1 << 24, // UTF8_BINARY mandatory zero bit 24 breach.
+      1 << 25, // UTF8_BINARY mandatory zero bit 25 breach.
+      1 << 26, // UTF8_BINARY mandatory zero bit 26 breach.
+      1 << 27, // UTF8_BINARY mandatory zero bit 27 breach.
+      1 << 28, // UTF8_BINARY mandatory zero bit 28 breach.
+      (1 << 29) | (1 << 12), // ICU mandatory zero bit 12 breach.
+      (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach.
+      (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach.
+      (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach.
+      (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach.
+      (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach.
+      (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach.
+      (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach.
+      (1 << 29) | (1 << 22), // ICU mandatory zero bit 22 breach.
+      (1 << 29) | (1 << 23), // ICU mandatory zero bit 23 breach.
+      (1 << 29) | (1 << 24), // ICU mandatory zero bit 24 breach.
+      (1 << 29) | (1 << 25), // ICU mandatory zero bit 25 breach.
+      (1 << 29) | (1 << 26), // ICU mandatory zero bit 26 breach.
+      (1 << 29) | (1 << 27), // ICU mandatory zero bit 27 breach.
+      (1 << 29) | (1 << 28), // ICU mandatory zero bit 28 breach.
+      (1 << 29) | 0xFFFF // ICU with invalid locale id.
+    )
+    badCollationIds.foreach(collationId => {
+      // Assumptions about collation id will break and assert statement will 
fail.
+      intercept[AssertionError](fetchCollation(collationId))
+    })
+  }
+
+  test("repeated specifiers in collation name") {

Review Comment:
   ```suggestion
     test("repeated and/or incompatible specifiers in collation name") {
   ```
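
The caching test earlier in this hunk asserts reference equality (`eq`) between two lookups of the same collation. Below is a minimal sketch of a lazily populated, thread-safe cache with that property, using `ConcurrentHashMap.computeIfAbsent`. The names `CollationCacheSketch` and `FakeCollation` are hypothetical, and the PR's own `fetchCollation` may populate its map differently.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class CollationCacheSketch {
  // Hypothetical placeholder for the real Collation class.
  static final class FakeCollation {
    final int collationId;

    FakeCollation(int collationId) {
      this.collationId = collationId;
    }
  }

  private static final Map<Integer, FakeCollation> cache = new ConcurrentHashMap<>();

  // Builds the instance on first request; later calls return the cached object.
  static FakeCollation fetch(int collationId) {
    return cache.computeIfAbsent(collationId, id -> new FakeCollation(id));
  }

  public static void main(String[] args) {
    FakeCollation first = fetch(1 << 29);
    FakeCollation second = fetch(1 << 29);
    System.out.println(first == second); // true: both refer to the cached instance
  }
}
```

In Java, `==` on object references plays the role of Scala's `eq` in the test.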



##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -179,35 +179,45 @@ public Collation(
      * Classes of collations are differentiated by most significant 3 bits 
(bit 31, 30 and 29),
      * bit 31 being most significant and bit 0 being least significant.
      * ---
+     * General collation id binary layout:

Review Comment:
   ```suggestion
        * General collation ID binary layout:
   ```
   Nit: it would be good to use "ID" consistently in the comments (identifiers 
in code are a different matter).



##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -117,76 +119,445 @@ public Collation(
     }
 
     /**
-     * Constructor with comparators that are inherited from the given collator.
+     * collation id (32-bit integer) layout:
+     * bit 31:    0 = predefined collation, 1 = user-defined collation
+     * bit 30-29: 00 = utf8-binary, 01 = ICU, 10 = indeterminate (without spec 
implementation)
+     * bit 28:    0 for utf8-binary / 0 = case-sensitive, 1 = case-insensitive 
for ICU
+     * bit 27:    0 for utf8-binary / 0 = accent-sensitive, 1 = 
accent-insensitive for ICU
+     * bit 26-25: zeroes, reserved for punctuation sensitivity
+     * bit 24-23: zeroes, reserved for first letter preference
+     * bit 22-21: 00 = unspecified, 01 = to-lower, 10 = to-upper
+     * bit 20-19: zeroes, reserved for space trimming
+     * bit 18-17: zeroes, reserved for version
+     * bit 16-12: zeroes
+     * bit 11-0:  zeroes for utf8-binary / locale id for ICU
      */
-    public Collation(
-        String collationName,
-        Collator collator,
-        String version,
-        boolean supportsBinaryEquality,
-        boolean supportsBinaryOrdering,
-        boolean supportsLowercaseEquality) {
-      this(
-        collationName,
-        collator,
-        (s1, s2) -> collator.compare(s1.toString(), s2.toString()),
-        version,
-        s -> (long)collator.getCollationKey(s.toString()).hashCode(),
-        supportsBinaryEquality,
-        supportsBinaryOrdering,
-        supportsLowercaseEquality);
+    private abstract static class CollationSpec {
+      protected enum ImplementationProvider {
+        UTF8_BINARY, ICU, INDETERMINATE
+      }
+
+      protected enum CaseSensitivity {
+        CS, CI
+      }
+
+      protected enum AccentSensitivity {
+        AS, AI
+      }
+
+      protected enum CaseConversion {
+        UNSPECIFIED, LCASE, UCASE
+      }
+
+      protected static final int IMPLEMENTATION_PROVIDER_OFFSET = 29;
+      protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b11;
+      protected static final int CASE_SENSITIVITY_OFFSET = 28;
+      protected static final int CASE_SENSITIVITY_MASK = 0b1;
+      protected static final int ACCENT_SENSITIVITY_OFFSET = 27;
+      protected static final int ACCENT_SENSITIVITY_MASK = 0b1;
+      protected static final int CASE_CONVERSION_OFFSET = 21;
+      protected static final int CASE_CONVERSION_MASK = 0b11;
+      protected static final int LOCALE_OFFSET = 0;
+      protected static final int LOCALE_MASK = 0x0FFF;
+
+      protected static final int INDETERMINATE_COLLATION_ID =
+        ImplementationProvider.INDETERMINATE.ordinal() << 
IMPLEMENTATION_PROVIDER_OFFSET;
+
+      protected final CaseSensitivity caseSensitivity;
+      protected final AccentSensitivity accentSensitivity;
+      protected final CaseConversion caseConversion;
+      protected final String locale;
+      protected final int collationId;
+
+      protected CollationSpec(
+          String locale,
+          CaseSensitivity caseSensitivity,
+          AccentSensitivity accentSensitivity,
+          CaseConversion caseConversion) {
+        this.locale = locale;
+        this.caseSensitivity = caseSensitivity;
+        this.accentSensitivity = accentSensitivity;
+        this.caseConversion = caseConversion;
+        this.collationId = getCollationId();
+      }
+
+      private static final Map<Integer, Collation> collationMap = new 
ConcurrentHashMap<>();
+
+      public static Collation fetchCollation(int collationId) throws 
SparkException {
+        if (collationId == UTF8_BINARY_COLLATION_ID) {
+          return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION;
+        } else if (collationMap.containsKey(collationId)) {
+          return collationMap.get(collationId);
+        } else {
+          CollationSpec spec;
+          int implementationProviderOrdinal =
+            (collationId >> IMPLEMENTATION_PROVIDER_OFFSET) & 
IMPLEMENTATION_PROVIDER_MASK;
+          if (implementationProviderOrdinal >= 
ImplementationProvider.values().length) {
+            throw SparkException.internalError("Invalid collation 
implementation provider");
+          } else {
+            ImplementationProvider implementationProvider = 
ImplementationProvider.values()[
+              implementationProviderOrdinal];
+            if (implementationProvider == ImplementationProvider.UTF8_BINARY) {
+              spec = CollationSpecUTF8Binary.fromCollationId(collationId);
+            } else if (implementationProvider == ImplementationProvider.ICU) {
+              spec = CollationSpecICU.fromCollationId(collationId);
+            } else {
+              throw SparkException.internalError("Cannot instantiate 
indeterminate collation");
+            }
+            Collation collation = spec.buildCollation();
+            collationMap.put(collationId, collation);
+            return collation;
+          }
+        }
+      }
+
+      protected static SparkException collationInvalidNameException(String 
collationName) {
+        return new SparkException("COLLATION_INVALID_NAME",
+          SparkException.constructMessageParams(Map.of("collationName", 
collationName)), null);
+      }
+
+      public static int collationNameToId(String collationName) throws 
SparkException {
+        String collationNameUpper = collationName.toUpperCase();

Review Comment:
   Sounds good. Let's keep this in mind for user-defined collations.
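
For reference, the offsets and masks in this hunk imply how a collation ID is assembled from its fields. Below is a small sketch under that assumption, reusing the constants quoted above. `CollationIdLayoutSketch` and `icuId` are hypothetical names, and the locale ID value used in `main` is purely illustrative.

```java
final class CollationIdLayoutSketch {
  static final int IMPLEMENTATION_PROVIDER_OFFSET = 29;
  static final int CASE_SENSITIVITY_OFFSET = 28;
  static final int ACCENT_SENSITIVITY_OFFSET = 27;
  static final int LOCALE_OFFSET = 0;
  static final int LOCALE_MASK = 0x0FFF;

  // Ordinals follow the enum order in the hunk: UTF8_BINARY = 0, ICU = 1, INDETERMINATE = 2.
  static final int ICU_ORDINAL = 1;
  static final int INDETERMINATE_ORDINAL = 2;

  // Packs ICU specifiers into one int by shifting each field to its offset.
  static int icuId(boolean caseInsensitive, boolean accentInsensitive, int localeId) {
    return (ICU_ORDINAL << IMPLEMENTATION_PROVIDER_OFFSET)
      | ((caseInsensitive ? 1 : 0) << CASE_SENSITIVITY_OFFSET)
      | ((accentInsensitive ? 1 : 0) << ACCENT_SENSITIVITY_OFFSET)
      | ((localeId & LOCALE_MASK) << LOCALE_OFFSET);
  }

  public static void main(String[] args) {
    // Case-insensitive, accent-sensitive ICU collation with illustrative locale ID 5.
    System.out.println(Integer.toBinaryString(icuId(true, false, 5)));
    // Same shift as INDETERMINATE_COLLATION_ID in the hunk: ordinal 2 at offset 29.
    System.out.println(INDETERMINATE_ORDINAL << IMPLEMENTATION_PROVIDER_OFFSET); // 1073741824
  }
}
```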



##########
common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala:
##########
@@ -152,4 +219,218 @@ class CollationFactorySuite extends AnyFunSuite with 
Matchers { // scalastyle:ig
       }
     })
   }
+
+  test("test collation caching") {
+    Seq(
+      "UTF8_BINARY",
+      "UTF8_BINARY_LCASE",
+      "UNICODE",
+      "UNICODE_CI",
+      "UNICODE_AI",
+      "UNICODE_CI_AI",
+      "UNICODE_AI_CI"
+    ).foreach(collationId => {
+      val col1 = fetchCollation(collationId)
+      val col2 = fetchCollation(collationId)
+      assert(col1 eq col2) // reference equality
+    })
+  }
+
+  test("collations with ICU non-root localization") {
+    Seq(
+      // language only
+      "en",
+      "en_CS",
+      "en_CI",
+      "en_AS",
+      "en_AI",
+      // language + 3-letter country code
+      "en_USA",
+      "en_USA_CS",
+      "en_USA_CI",
+      "en_USA_AS",
+      "en_USA_AI",
+      // language + script code
+      "sr_Cyrl",
+      "sr_Cyrl_CS",
+      "sr_Cyrl_CI",
+      "sr_Cyrl_AS",
+      "sr_Cyrl_AI",
+      // language + script code + 3-letter country code
+      "sr_Cyrl_SRB",
+      "sr_Cyrl_SRB_CS",
+      "sr_Cyrl_SRB_CI",
+      "sr_Cyrl_SRB_AS",
+      "sr_Cyrl_SRB_AI"
+    ).foreach(collationICU => {
+      val col = fetchCollation(collationICU)
+      assert(col.collator.getLocale(ULocale.VALID_LOCALE) != ULocale.ROOT)
+    })
+  }
+
+  test("invalid names of collations with ICU non-root localization") {

Review Comment:
   I missed these test cases. I suggested a test name that makes it clearer 
that we test not only repeated specifiers but also incompatible ones.



##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -179,35 +179,45 @@ public Collation(
      * Classes of collations are differentiated by most significant 3 bits 
(bit 31, 30 and 29),
      * bit 31 being most significant and bit 0 being least significant.
      * ---
+     * General collation id binary layout:

Review Comment:
   The same applies below, of course.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


