This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b1147a05a6c [SPARK-47567][SQL] Support LOCATE function to work with collated strings
7b1147a05a6c is described below

commit 7b1147a05a6ca54276538d766c089980b9ee5d59
Author: Milan Dankovic <milan.danko...@databricks.com>
AuthorDate: Mon Apr 29 17:24:36 2024 +0800

    [SPARK-47567][SQL] Support LOCATE function to work with collated strings
    
    ### What changes were proposed in this pull request?
    Extend the built-in string function LOCATE to support collated strings: UTF8_BINARY,
    UTF8_BINARY_LCASE, and non-binary, non-lowercase (ICU-based) collations such as UNICODE and
    UNICODE_CI.
    
    ### Why are the changes needed?
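    LOCATE previously declared its string inputs as the default `StringType`, so it could not be
    used with collated strings. This change extends collation support for Spark's built-in string
    functions to LOCATE.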
    
    ### Does this PR introduce _any_ user-facing change?
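    Yes. Users can now use COLLATE in the arguments of the built-in string function LOCATE in
    Spark SQL queries, including with non-binary collations such as UNICODE_CI. Both string
    arguments must resolve to the same collation: explicitly collating them differently fails
    with COLLATION_MISMATCH.EXPLICIT.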
    
    ### How was this patch tested?
    Unit tests for the collation-aware implementation (`CollationSupportSuite.java`) and for SQL
    queries using StringLocate (`CollationStringExpressionsSuite.scala`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45791 from miland-db/miland-db/string-locate.
    
    Authored-by: Milan Dankovic <milan.danko...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationSupport.java  | 38 +++++++++++++
 .../spark/unsafe/types/CollationSupportSuite.java  | 65 ++++++++++++++++++++++
 .../sql/catalyst/analysis/CollationTypeCasts.scala |  4 ++
 .../catalyst/expressions/stringExpressions.scala   | 14 +++--
 .../sql/CollationStringExpressionsSuite.scala      | 34 +++++++++++
 5 files changed, 149 insertions(+), 6 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 0fc37c169612..0c81b99de916 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -403,6 +403,44 @@ public final class CollationSupport {
     }
   }
 
+  /**
+   * Collation-aware implementation of the LOCATE string function.
+   *
+   * <p>Every variant works with 0-based character positions: {@code start} is the position in
+   * {@code string} where the search begins, and the result is the position of the first
+   * occurrence of {@code substring} at or after {@code start}, or -1 if there is none. Converting
+   * from and to SQL's 1-based positions is the caller's responsibility.
+   */
+  public static class StringLocate {
+    /**
+     * Dispatches to the variant that matches the collation's capabilities: binary equality,
+     * lowercase equality, or otherwise the ICU-based search.
+     */
+    public static int exec(final UTF8String string, final UTF8String substring, final int start,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(string, substring, start);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(string, substring, start);
+      } else {
+        return execICU(string, substring, start, collationId);
+      }
+    }
+
+    /**
+     * Returns a Java source snippet that calls the variant matching the collation directly, so
+     * the collation is resolved once, at code generation time.
+     *
+     * <p>{@code string} and {@code substring} are spliced in verbatim as Java expressions, but
+     * {@code start} is embedded as an integer literal, so the snippet only fits call sites whose
+     * start position is already known when the code is generated.
+     */
+    public static String genCode(final String string, final String substring, final int start,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringLocate.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s, %d)", string, substring, start);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId);
+      }
+    }
+
+    /** Search for collations that support binary equality; delegates to UTF8String.indexOf. */
+    public static int execBinary(final UTF8String string, final UTF8String substring,
+        final int start) {
+      return string.indexOf(substring, start);
+    }
+
+    /**
+     * Case-insensitive search for collations that support lowercase equality.
+     *
+     * <p>Lowercasing can change the number of characters (for example 'İ', U+0130, lowercases to
+     * 'i' followed by U+0307), so searching inside {@code string.toLowerCase()} would apply
+     * {@code start} to, and report the result in, the lowercased string's positions instead of
+     * the original's. Instead, candidate windows of the original string are lowercased one by
+     * one and compared with the lowercased pattern, so all positions refer to {@code string}.
+     */
+    public static int execLowercase(final UTF8String string, final UTF8String substring,
+        final int start) {
+      if (substring.numBytes() == 0) {
+        // Keep UTF8String.indexOf's result for an empty pattern.
+        return string.indexOf(substring, start);
+      }
+      final UTF8String lowercaseSubstring = substring.toLowerCase();
+      final int patternLength = lowercaseSubstring.numChars();
+      final int[] codePoints = string.toString().codePoints().toArray();
+      // Callers pass start >= 0; clamp defensively so a negative value searches from the start.
+      for (int i = Math.max(start, 0); i < codePoints.length; i++) {
+        // Lowercasing never removes characters, so a matching window of the original string is
+        // at most as long as the lowercased pattern, and its lowercased length only grows with it.
+        final int maxWindow = Math.min(patternLength, codePoints.length - i);
+        for (int len = 1; len <= maxWindow; len++) {
+          final UTF8String window =
+            UTF8String.fromString(new String(codePoints, i, len)).toLowerCase();
+          final int windowLength = window.numChars();
+          if (windowLength == patternLength && window.equals(lowercaseSubstring)) {
+            return i;
+          }
+          if (windowLength >= patternLength) {
+            break;
+          }
+        }
+      }
+      return -1;
+    }
+
+    /** Search for the remaining (ICU-based) collations; delegates to CollationAwareUTF8String. */
+    public static int execICU(final UTF8String string, final UTF8String substring, final int start,
+        final int collationId) {
+      return CollationAwareUTF8String.indexOf(string, substring, start, collationId);
+    }
+  }
+
   // TODO: Add more collation-aware string expressions.
 
   /**
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 6c79fc821317..030c7a7a1e3c 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -652,6 +652,71 @@ public class CollationSupportSuite {
     assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy");
   }
 
+  /**
+   * Asserts that locating {@code substring} in {@code string} from {@code start} under the given
+   * collation yields {@code expected}. Both positions are 1-based as in SQL's LOCATE, with 0
+   * meaning "not found"; they are converted to and from the 0-based positions used by
+   * {@code CollationSupport.StringLocate.exec}.
+   */
+  private void assertLocate(String substring, String string, int start, String collationName,
+      int expected) throws SparkException {
+    UTF8String substr = UTF8String.fromString(substring);
+    UTF8String str = UTF8String.fromString(string);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    int position = CollationSupport.StringLocate.exec(str, substr, start - 1, collationId);
+    assertEquals(expected, position + 1);
+  }
+
+  /**
+   * Tests LOCATE under UTF8_BINARY, UTF8_BINARY_LCASE, UNICODE and UNICODE_CI. Positions given to
+   * and expected from assertLocate are 1-based; an expected value of 0 means "not found".
+   */
+  @Test
+  public void testLocate() throws SparkException {
+    // If you add tests with start < 1, be careful to understand the behavior of the indexOf method
+    // and usage of indexOf in the StringLocate class.
+    // UTF8_BINARY: case-sensitive.
+    assertLocate("aa", "aaads", 1, "UTF8_BINARY", 1);
+    assertLocate("aa", "aaads", 2, "UTF8_BINARY", 2);
+    assertLocate("aa", "aaads", 3, "UTF8_BINARY", 0);
+    assertLocate("Aa", "aaads", 1, "UTF8_BINARY", 0);
+    assertLocate("Aa", "aAads", 1, "UTF8_BINARY", 2);
+    assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY", 0);
+    assertLocate("界X", "test大千世界X大千世界", 1, "UTF8_BINARY", 8);
+    assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY", 13);
+    // UTF8_BINARY_LCASE: case-insensitive.
+    assertLocate("AA", "aaads", 1, "UTF8_BINARY_LCASE", 1);
+    assertLocate("aa", "aAads", 2, "UTF8_BINARY_LCASE", 2);
+    assertLocate("aa", "aaAds", 3, "UTF8_BINARY_LCASE", 0);
+    assertLocate("abC", "abcabc", 1, "UTF8_BINARY_LCASE", 1);
+    assertLocate("abC", "abCabc", 2, "UTF8_BINARY_LCASE", 4);
+    assertLocate("abc", "abcabc", 4, "UTF8_BINARY_LCASE", 4);
+    assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY_LCASE", 8);
+    assertLocate("界X", "test大千世界Xtest大千世界", 1, "UTF8_BINARY_LCASE", 8);
+    assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY_LCASE", 13);
+    assertLocate("大千", "test大千世界大千世界", 1, "UTF8_BINARY_LCASE", 5);
+    assertLocate("大千", "test大千世界大千世界", 9, "UTF8_BINARY_LCASE", 9);
+    assertLocate("大千", "大千世界大千世界", 1, "UTF8_BINARY_LCASE", 1);
+    // UNICODE: case-sensitive.
+    assertLocate("aa", "Aaads", 1, "UNICODE", 2);
+    assertLocate("AA", "aaads", 1, "UNICODE", 0);
+    assertLocate("aa", "aAads", 2, "UNICODE", 0);
+    assertLocate("aa", "aaAds", 3, "UNICODE", 0);
+    assertLocate("abC", "abcabc", 1, "UNICODE", 0);
+    assertLocate("abC", "abCabc", 2, "UNICODE", 0);
+    assertLocate("abC", "abCabC", 2, "UNICODE", 4);
+    assertLocate("abc", "abcabc", 1, "UNICODE", 1);
+    assertLocate("abc", "abcabc", 3, "UNICODE", 4);
+    assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE", 0);
+    assertLocate("界X", "test大千世界X大千世界", 1, "UNICODE", 8);
+    assertLocate("界", "test大千世界X大千世界", 13, "UNICODE", 13);
+    // UNICODE_CI: case-insensitive.
+    assertLocate("AA", "aaads", 1, "UNICODE_CI", 1);
+    assertLocate("aa", "aAads", 2, "UNICODE_CI", 2);
+    assertLocate("aa", "aaAds", 3, "UNICODE_CI", 0);
+    assertLocate("abC", "abcabc", 1, "UNICODE_CI", 1);
+    assertLocate("abC", "abCabc", 2, "UNICODE_CI", 4);
+    assertLocate("abc", "abcabc", 4, "UNICODE_CI", 4);
+    assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8);
+    assertLocate("界", "test大千世界X大千世界", 13, "UNICODE_CI", 13);
+    assertLocate("大千", "test大千世界大千世界", 1, "UNICODE_CI", 5);
+    assertLocate("大千", "test大千世界大千世界", 9, "UNICODE_CI", 9);
+    assertLocate("大千", "大千世界大千世界", 1, "UNICODE_CI", 1);
+    // Case-variable character length: 'İ' (U+0130) is a single code point, while its lowercase
+    // form is 'i' followed by U+0307 (two code points); UNICODE_CI treats the two as equal.
+    // Positions count code points of the searched string, which is why the last case expects 12,
+    // not the 11 of the analogous case where the target contains 'İ'.
+    assertLocate("i̇o", "İo世界大千世界", 1, "UNICODE_CI", 1);
+    assertLocate("i̇o", "大千İo世界大千世界", 1, "UNICODE_CI", 3);
+    assertLocate("i̇o", "世界İo大千世界大千İo", 4, "UNICODE_CI", 11);
+    assertLocate("İo", "i̇o世界大千世界", 1, "UNICODE_CI", 1);
+    assertLocate("İo", "大千i̇o世界大千世界", 1, "UNICODE_CI", 3);
+    assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12);
+  }
+
   // TODO: Test more collation-aware string expressions.
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 3ae251e56772..f69218812d36 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -45,6 +45,10 @@ object CollationTypeCasts extends TypeCoercionRule {
           caseWhenExpr.elseValue.map(e => castStringType(e, 
outputStringType).getOrElse(e))
         CaseWhen(newBranches, newElseValue)
 
+    case stringLocate: StringLocate =>
+      stringLocate.withNewChildren(collateToSingleType(
+        Seq(stringLocate.first, stringLocate.second)) :+ stringLocate.third)
+
     case eltExpr: Elt =>
       eltExpr.withNewChildren(eltExpr.children.head +: 
collateToSingleType(eltExpr.children.tail))
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 91401a3ea3ae..2d7f9652986a 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1457,12 +1457,15 @@ case class StringLocate(substr: Expression, str: 
Expression, start: Expression)
     this(substr, str, Literal(1))
   }
 
+  final lazy val collationId: Int = 
first.dataType.asInstanceOf[StringType].collationId
+
   override def first: Expression = substr
   override def second: Expression = str
   override def third: Expression = start
   override def nullable: Boolean = substr.nullable || str.nullable
   override def dataType: DataType = IntegerType
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, 
IntegerType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
 
   override def eval(input: InternalRow): Any = {
     val s = start.eval(input)
@@ -1482,9 +1485,8 @@ case class StringLocate(substr: Expression, str: 
Expression, start: Expression)
           if (sVal < 1) {
             0
           } else {
-            l.asInstanceOf[UTF8String].indexOf(
-              r.asInstanceOf[UTF8String],
-              s.asInstanceOf[Int] - 1) + 1
+            CollationSupport.StringLocate.exec(l.asInstanceOf[UTF8String],
+              r.asInstanceOf[UTF8String], s.asInstanceOf[Int] - 1, 
collationId) + 1;
           }
         }
       }
@@ -1505,8 +1507,8 @@ case class StringLocate(substr: Expression, str: 
Expression, start: Expression)
           ${strGen.code}
           if (!${strGen.isNull}) {
             if (${startGen.value} > 0) {
-              ${ev.value} = ${strGen.value}.indexOf(${substrGen.value},
-                ${startGen.value} - 1) + 1;
+              ${ev.value} = CollationSupport.StringLocate.exec(${strGen.value},
+              ${substrGen.value}, ${startGen.value} - 1, $collationId) + 1;
             }
           } else {
             ${ev.isNull} = true;
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 305c51c0b703..d88c15fb2325 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -661,6 +661,40 @@ class CollationStringExpressionsSuite
     assert(sql(query).schema.fields.head.dataType.sameType(StringType(0)))
   }
 
+  // Checks LOCATE with collated arguments: the result value and its IntegerType, implicit
+  // casting when only one argument is explicitly collated, and COLLATION_MISMATCH.EXPLICIT when
+  // the two arguments are explicitly collated differently.
+  test("Support Locate string expression with collation") {
+    case class StringLocateTestCase[R](substring: String, string: String, 
start: Integer,
+        c: String, result: R)
+    val testCases = Seq(
+      // scalastyle:off
+      // A start position below 1 yields 0, whatever the collation.
+      StringLocateTestCase("aa", "aaads", 0, "UTF8_BINARY", 0),
+      StringLocateTestCase("aa", "Aaads", 0, "UTF8_BINARY_LCASE", 0),
+      StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UTF8_BINARY_LCASE", 8),
+      StringLocateTestCase("aBc", "abcabc", 4, "UTF8_BINARY_LCASE", 4),
+      StringLocateTestCase("aa", "Aaads", 0, "UNICODE", 0),
+      StringLocateTestCase("abC", "abCabC", 2, "UNICODE", 4),
+      StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0),
+      StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8)
+      // scalastyle:on
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT locate(collate('${t.substring}','${t.c}')," +
+        s"collate('${t.string}','${t.c}'),${t.start})"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
+      // Implicit casting: only one of the two string arguments is explicitly collated
+      checkAnswer(sql(s"SELECT locate(collate('${t.substring}','${t.c}')," +
+        s"'${t.string}',${t.start})"), Row(t.result))
+      checkAnswer(sql(s"SELECT locate('${t.substring}',collate('${t.string}'," 
+
+        s"'${t.c}'),${t.start})"), Row(t.result))
+    })
+    // Collation mismatch: the two string arguments are explicitly given different collations
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT locate(collate('aBc', 'UTF8_BINARY'),collate('abcabc', 
'UTF8_BINARY_LCASE'),4)")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
   // TODO: Add more tests for other string expressions
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to