This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a1608bbc3f1 [SPARK-48670][SQL] Providing suggestion as part of error 
message when invalid collation name is given
7a1608bbc3f1 is described below

commit 7a1608bbc3f1dfd7ffd1f9dc762cb369f47a8d43
Author: Aleksandar Tomic <[email protected]>
AuthorDate: Wed Jun 26 16:02:03 2024 +0800

    [SPARK-48670][SQL] Providing suggestion as part of error message when 
invalid collation name is given
    
    ### What changes were proposed in this pull request?
    
    This PR improves error reporting for collations. Currently, when an 
invalid collation name is provided, the caller is only told that the collation 
name can't be accepted. With this PR, the error also returns suggestions for 
valid collation names that are similar to the invalid one that was provided.
    
    We propose the following rules for generating the suggestion:
    1) Find the closest valid locale, as measured by Levenshtein 
distance.
    2) Remove duplicate modifiers (e.g. CS_AI_AI_CS becomes CS_AI).
    3) Remove invalid combinations (e.g. CS_CI becomes CS).
    
    ### Why are the changes needed?
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Existing tests for invalid collation names are extended to also cover 
suggestion checks.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #47040 from dbatomic/collation_suggestion_on_error.
    
    Authored-by: Aleksandar Tomic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/util/CollationFactory.java  |  92 ++++++++++-
 .../spark/unsafe/types/CollationFactorySuite.scala | 184 +++++++++++----------
 .../src/main/resources/error/error-conditions.json |   4 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |   3 +-
 .../expressions/CollationExpressionSuite.scala     |   2 +-
 .../org/apache/spark/sql/CollationSuite.scala      |   2 +-
 .../apache/spark/sql/internal/SQLConfSuite.scala   |   3 +-
 7 files changed, 194 insertions(+), 96 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index b0f6c5c22991..50b218431c0c 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -20,6 +20,7 @@ import java.text.CharacterIterator;
 import java.text.StringCharacterIterator;
 import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
 import java.util.function.BiFunction;
 import java.util.function.ToLongFunction;
 
@@ -318,9 +319,16 @@ public final class CollationFactory {
         }
       }
 
+      /**
+       * Method for constructing errors thrown on providing invalid collation 
name.
+       */
       protected static SparkException collationInvalidNameException(String 
collationName) {
+        Map<String, String> params = new HashMap<>();
+        final int maxSuggestions = 3;
+        params.put("collationName", collationName);
+        params.put("proposals", 
getClosestSuggestionsOnInvalidName(collationName, maxSuggestions));
         return new SparkException("COLLATION_INVALID_NAME",
-          SparkException.constructMessageParams(Map.of("collationName", 
collationName)), null);
+          SparkException.constructMessageParams(params), null);
       }
 
       private static int collationNameToId(String collationName) throws 
SparkException {
@@ -828,4 +836,86 @@ public final class CollationFactory {
     }
   }
 
+  /**
+   * Returns same string if collation name is valid or the closest suggestion 
if it is invalid.
+   */
+  public static String getClosestSuggestionsOnInvalidName(
+      String collationName, int maxSuggestions) {
+    String[] validRootNames;
+    String[] validModifiers;
+    if (collationName.startsWith("UTF8_")) {
+      validRootNames = new String[]{
+        Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION.collationName,
+        Collation.CollationSpecUTF8Binary.UTF8_LCASE_COLLATION.collationName
+      };
+      validModifiers = new String[0];
+    } else {
+      validRootNames = getICULocaleNames();
+      validModifiers = new String[]{"_CI", "_AI", "_CS", "_AS"};
+    }
+
+    // Split modifiers and locale name.
+    final int MODIFIER_LENGTH = 3;
+    String localeName = collationName.toUpperCase();
+    List<String> modifiers = new ArrayList<>();
+    while (Arrays.stream(validModifiers).anyMatch(localeName::endsWith)) {
+      modifiers.add(localeName.substring(localeName.length() - 
MODIFIER_LENGTH));
+      localeName = localeName.substring(0, localeName.length() - 
MODIFIER_LENGTH);
+    }
+
+    // Suggest version with unique modifiers.
+    Collections.reverse(modifiers);
+    modifiers = modifiers.stream().distinct().toList();
+
+    // Remove conflicting settings.
+    if (modifiers.contains("_CI") && modifiers.contains(("_CS"))) {
+      modifiers = modifiers.stream().filter(m -> !m.equals("_CI")).toList();
+    }
+
+    if (modifiers.contains("_AI") && modifiers.contains(("_AS"))) {
+      modifiers = modifiers.stream().filter(m -> !m.equals("_AI")).toList();
+    }
+
+    final String finalLocaleName = localeName;
+    Comparator<String> distanceComparator = (c1, c2) -> {
+      int distance1 = UTF8String.fromString(c1.toUpperCase())
+              .levenshteinDistance(UTF8String.fromString(finalLocaleName));
+      int distance2 = UTF8String.fromString(c2.toUpperCase())
+              .levenshteinDistance(UTF8String.fromString(finalLocaleName));
+      return Integer.compare(distance1, distance2);
+    };
+
+    String[] rootNamesByDistance = Arrays.copyOf(validRootNames, 
validRootNames.length);
+    Arrays.sort(rootNamesByDistance, distanceComparator);
+    Function<String, Boolean> isCollationNameValid = name -> {
+      try {
+        collationNameToId(name);
+        return true;
+      } catch (SparkException e) {
+        return false;
+      }
+    };
+
+    final int suggestionThreshold = 3;
+    final ArrayList<String> suggestions = new ArrayList<>(maxSuggestions);
+    for (int i = 0; i < maxSuggestions; i++) {
+      // Add at least one suggestion.
+      // Add others if distance from the original is lower than threshold.
+      String suggestion = rootNamesByDistance[i] + String.join("", modifiers);
+      assert(isCollationNameValid.apply(suggestion));
+      if (suggestions.isEmpty()) {
+        suggestions.add(suggestion);
+      } else {
+        int distance = UTF8String.fromString(suggestion.toUpperCase())
+          
.levenshteinDistance(UTF8String.fromString(collationName.toUpperCase()));
+        if (distance < suggestionThreshold) {
+          suggestions.add(suggestion);
+        } else {
+          break;
+        }
+      }
+    }
+
+    return String.join(", ", suggestions);
+  }
 }
diff --git 
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
 
b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
index c539e859b550..3c29daeff168 100644
--- 
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
+++ 
b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
@@ -91,37 +91,35 @@ class CollationFactorySuite extends AnyFunSuite with 
Matchers { // scalastyle:ig
 
   test("fetch invalid UTF8_BINARY and ICU root locale collation names") {
     Seq(
-      "UTF8_BINARY_CS",
-      "UTF8_BINARY_AS",
-      "UTF8_BINARY_CS_AS",
-      "UTF8_BINARY_AS_CS",
-      "UTF8_BINARY_CI",
-      "UTF8_BINARY_AI",
-      "UTF8_BINARY_CI_AI",
-      "UTF8_BINARY_AI_CI",
-      "UTF8_BS",
-      "BINARY_UTF8",
-      "UTF8_BINARY_A",
-      "UNICODE_X",
-      "UNICODE_CI_X",
-      "UNICODE_LCASE_X",
-      "UTF8_UNICODE",
-      "UTF8_BINARY_UNICODE",
-      "CI_UNICODE",
-      "LCASE_UNICODE",
-      "UNICODE_UNSPECIFIED",
-      "UNICODE_CI_UNSPECIFIED",
-      "UNICODE_UNSPECIFIED_CI_UNSPECIFIED",
-      "UNICODE_INDETERMINATE",
-      "UNICODE_CI_INDETERMINATE"
-    ).foreach(collationName => {
-      val error = intercept[SparkException] {
-        fetchCollation(collationName)
-      }
-
+      ("UTF8_BINARY_CS", "UTF8_BINARY"),
+      ("UTF8_BINARY_AS", "UTF8_BINARY"), // this should be UNICODE_AS
+      ("UTF8_BINARY_CS_AS","UTF8_BINARY"), // this should be UNICODE_CS_AS
+      ("UTF8_BINARY_AS_CS","UTF8_BINARY"),
+      ("UTF8_BINARY_CI","UTF8_BINARY"),
+      ("UTF8_BINARY_AI","UTF8_BINARY"),
+      ("UTF8_BINARY_CI_AI","UTF8_BINARY"),
+      ("UTF8_BINARY_AI_CI","UTF8_BINARY"),
+      ("UTF8_BS","UTF8_LCASE"),
+      ("BINARY_UTF8","ar_SAU"),
+      ("UTF8_BINARY_A","UTF8_BINARY"),
+      ("UNICODE_X","UNICODE"),
+      ("UNICODE_CI_X","UNICODE"),
+      ("UNICODE_LCASE_X","UNICODE"),
+      ("UTF8_UNICODE","UTF8_LCASE"),
+      ("UTF8_BINARY_UNICODE","UTF8_BINARY"),
+      ("CI_UNICODE", "UNICODE"),
+      ("LCASE_UNICODE", "UNICODE"),
+      ("UNICODE_UNSPECIFIED", "UNICODE"),
+      ("UNICODE_CI_UNSPECIFIED", "UNICODE"),
+      ("UNICODE_UNSPECIFIED_CI_UNSPECIFIED", "UNICODE"),
+      ("UNICODE_INDETERMINATE", "UNICODE"),
+      ("UNICODE_CI_INDETERMINATE", "UNICODE")
+    ).foreach{case (collationName, proposals) =>
+      val error = intercept[SparkException] { fetchCollation(collationName) }
       assert(error.getErrorClass === "COLLATION_INVALID_NAME")
-      assert(error.getMessageParameters.asScala === Map("collationName" -> 
collationName))
-    })
+      assert(error.getMessageParameters.asScala === Map(
+        "collationName" -> collationName, "proposals" -> proposals))
+    }
   }
 
   case class CollationTestCase[R](collationName: String, s1: String, s2: 
String, expectedResult: R)
@@ -276,52 +274,55 @@ class CollationFactorySuite extends AnyFunSuite with 
Matchers { // scalastyle:ig
 
   test("invalid names of collations with ICU non-root localization") {
     Seq(
-      "en_US", // Must use 3-letter country code
-      "enn",
-      "en_AAA",
-      "en_Something",
-      "en_Something_USA",
-      "en_LCASE",
-      "en_UCASE",
-      "en_CI_LCASE",
-      "en_CI_UCASE",
-      "en_CI_UNSPECIFIED",
-      "en_USA_UNSPECIFIED",
-      "en_USA_UNSPECIFIED_CI",
-      "en_INDETERMINATE",
-      "en_USA_INDETERMINATE",
-      "en_Latn_USA", // Use en_USA instead.
-      "en_Cyrl_USA",
-      "en_USA_AAA",
-      "sr_Cyrl_SRB_AAA",
+      ("en_US", "en_USA"), // Must use 3-letter country code
+      ("eN_US", "en_USA"), // verify that proper casing is captured in error.
+      ("enn", "en, nn, bn"),
+      ("en_AAA", "en_USA"),
+      ("en_Something", "UNICODE"),
+      ("en_Something_USA", "en_USA"),
+      ("en_LCASE", "en_USA"),
+      ("en_UCASE", "en_USA"),
+      ("en_CI_LCASE", "UNICODE"),
+      ("en_CI_UCASE", "en_USA"),
+      ("en_CI_UNSPECIFIED", "en_USA"),
+      ("en_USA_UNSPECIFIED", "en_USA"),
+      ("en_USA_UNSPECIFIED_CI", "en_USA_CI"),
+      ("en_INDETERMINATE", "en_USA"),
+      ("en_USA_INDETERMINATE", "en_USA"),
+      ("en_Latn_USA", "en_USA"),
+      ("en_Cyrl_USA", "en_USA"),
+      ("en_USA_AAA", "en_USA"),
+      ("sr_Cyrl_SRB_AAA", "sr_Cyrl_SRB"),
       // Invalid ordering of language, script and country code.
-      "USA_en",
-      "sr_SRB_Cyrl",
-      "SRB_sr",
-      "SRB_sr_Cyrl",
-      "SRB_Cyrl_sr",
-      "Cyrl_sr",
-      "Cyrl_sr_SRB",
-      "Cyrl_SRB_sr",
+      ("USA_en", "en"),
+      ("sr_SRB_Cyrl", "sr_Cyrl"),
+      ("SRB_sr", "ar_SAU"),
+      ("SRB_sr_Cyrl", "bs_Cyrl"),
+      ("SRB_Cyrl_sr", "sr_Cyrl_SRB"),
+      ("Cyrl_sr", "sr_Cyrl_SRB"),
+      ("Cyrl_sr_SRB", "sr_Cyrl_SRB"),
+      ("Cyrl_SRB_sr", "sr_Cyrl_SRB"),
       // Collation specifiers in the middle of locale.
-      "CI_en",
-      "USA_CI_en",
-      "en_CI_USA",
-      "CI_sr_Cyrl_SRB",
-      "sr_CI_Cyrl_SRB",
-      "sr_Cyrl_CI_SRB",
-      "CI_Cyrl_sr",
-      "Cyrl_CI_sr",
-      "Cyrl_CI_sr_SRB",
-      "Cyrl_sr_CI_SRB"
-    ).foreach(collationName => {
-      val error = intercept[SparkException] {
-        fetchCollation(collationName)
-      }
-
+      ("CI_en", "ceb"),
+      ("USA_CI_en", "UNICODE"),
+      ("en_CI_USA", "en_USA"),
+      ("CI_sr_Cyrl_SRB", "sr_Cyrl_SRB"),
+      ("sr_CI_Cyrl_SRB", "sr_Cyrl_SRB"),
+      ("sr_Cyrl_CI_SRB", "sr_Cyrl_SRB"),
+      ("CI_Cyrl_sr", "sr_Cyrl_SRB"),
+      ("Cyrl_CI_sr", "he_ISR"),
+      ("Cyrl_CI_sr_SRB", "sr_Cyrl_SRB"),
+      ("Cyrl_sr_CI_SRB", "sr_Cyrl_SRB"),
+      // no locale specified
+      ("_CI_AI", "af_CI_AI, am_CI_AI, ar_CI_AI"),
+      ("", "af, am, ar")
+    ).foreach { case (collationName, proposals) => {
+      val error = intercept[SparkException] { fetchCollation(collationName) }
       assert(error.getErrorClass === "COLLATION_INVALID_NAME")
-      assert(error.getMessageParameters.asScala === Map("collationName" -> 
collationName))
-    })
+
+      assert(error.getMessageParameters.asScala === Map(
+        "collationName" -> collationName, "proposals" -> proposals))
+    }}
   }
 
   test("collations name normalization for ICU non-root localization") {
@@ -407,28 +408,33 @@ class CollationFactorySuite extends AnyFunSuite with 
Matchers { // scalastyle:ig
     })
   }
 
-  test("repeated and/or incompatible specifiers in collation name") {
+  test("repeated and/or incompatible and/or misplaced specifiers in collation 
name") {
     Seq(
-      "UTF8_LCASE_LCASE",
-      "UNICODE_CS_CS",
-      "UNICODE_CI_CI",
-      "UNICODE_CI_CS",
-      "UNICODE_CS_CI",
-      "UNICODE_AS_AS",
-      "UNICODE_AI_AI",
-      "UNICODE_AS_AI",
-      "UNICODE_AI_AS",
-      "UNICODE_AS_CS_AI",
-      "UNICODE_CS_AI_CI",
-      "UNICODE_CS_AS_CI_AI"
-    ).foreach(collationName => {
+      ("UTF8_LCASE_LCASE", "UTF8_LCASE"),
+      ("UNICODE_CS_CS", "UNICODE_CS"),
+      ("UNICODE_CI_CI", "UNICODE_CI"),
+      ("UNICODE_CI_CS", "UNICODE_CS"),
+      ("UNICODE_CS_CI", "UNICODE_CS"),
+      ("UNICODE_AS_AS", "UNICODE_AS"),
+      ("UNICODE_AI_AI", "UNICODE_AI"),
+      ("UNICODE_AS_AI", "UNICODE_AS"),
+      ("UNICODE_AI_AS", "UNICODE_AS"),
+      ("UNICODE_AS_CS_AI", "UNICODE_AS_CS"),
+      ("UNICODE_CS_AI_CI", "UNICODE_CS_AI"),
+      ("UNICODE_CS_AS_CI_AI", "UNICODE_CS_AS"),
+      ("UNICODE__CS__AS", "UNICODE_AS"),
+      ("UNICODE-CS-AS", "UNICODE"),
+      ("UNICODECSAS", "UNICODE"),
+      ("_CS_AS_UNICODE", "UNICODE")
+    ).foreach { case (collationName, proposals) =>
       val error = intercept[SparkException] {
         fetchCollation(collationName)
       }
 
       assert(error.getErrorClass === "COLLATION_INVALID_NAME")
-      assert(error.getMessageParameters.asScala === Map("collationName" -> 
collationName))
-    })
+      assert(error.getMessageParameters.asScala === Map(
+        "collationName" -> collationName, "proposals" -> proposals))
+    }
   }
 
   test("basic ICU collator checks") {
diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 72f358f87d62..2d9ae7a89b81 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -487,7 +487,7 @@
   },
   "COLLATION_INVALID_NAME" : {
     "message" : [
-      "The value <collationName> does not represent a correct collation name."
+      "The value <collationName> does not represent a correct collation name. 
Suggested valid collation names: [<proposals>]."
     ],
     "sqlState" : "42704"
   },
@@ -1944,7 +1944,7 @@
     "subClass" : {
       "DEFAULT_COLLATION" : {
         "message" : [
-          "Cannot resolve the given default collation."
+          "Cannot resolve the given default collation. Suggested valid 
collation names: ['<proposals>']?"
         ]
       },
       "TIME_ZONE" : {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d0dc75017fa6..3b6374e712c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -782,7 +782,8 @@ object SQLConf {
           }
         },
         "DEFAULT_COLLATION",
-        _ => Map())
+        collationName => Map(
+          "proposals" -> 
CollationFactory.getClosestSuggestionsOnInvalidName(collationName, 3)))
       .createWithDefault("UTF8_BINARY")
 
   val ICU_CASE_MAPPINGS_ENABLED =
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
index d20dab8eb8ac..a4651c6c4c7e 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
@@ -63,7 +63,7 @@ class CollationExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       exception = intercept[SparkException] { Collate(Literal("abc"), 
"UTF8_BS") },
       errorClass = "COLLATION_INVALID_NAME",
       sqlState = "42704",
-      parameters = Map("collationName" -> "UTF8_BS"))
+      parameters = Map("collationName" -> "UTF8_BS", "proposals" -> 
"UTF8_LCASE"))
   }
 
   test("collation on non-explicit default collation") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index c4eaedfb215e..f662b86eaf81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -153,7 +153,7 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
       exception = intercept[SparkException] { sql("select 'aaa' collate 
UTF8_BS") },
       errorClass = "COLLATION_INVALID_NAME",
       sqlState = "42704",
-      parameters = Map("collationName" -> "UTF8_BS"))
+      parameters = Map("collationName" -> "UTF8_BS", "proposals" -> 
"UTF8_LCASE"))
   }
 
   test("disable bucketing on collated string column") {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 0639913c8f81..404ec865c1b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -514,7 +514,8 @@ class SQLConfSuite extends QueryTest with 
SharedSparkSession {
       errorClass = "INVALID_CONF_VALUE.DEFAULT_COLLATION",
       parameters = Map(
         "confValue" -> "UNICODE_C",
-        "confName" -> "spark.sql.session.collation.default"
+        "confName" -> "spark.sql.session.collation.default",
+        "proposals" -> "UNICODE"
       ))
 
     withSQLConf(SQLConf.COLLATION_ENABLED.key -> "false") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to