This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 256fc51508e4 [SPARK-47411][SQL] Support StringInstr & FindInSet functions to work with collated strings
256fc51508e4 is described below

commit 256fc51508e4eac3efc4746ef0ef92132bc40643
Author: Milan Dankovic <milan.danko...@databricks.com>
AuthorDate: Mon Apr 22 20:45:43 2024 +0800

    [SPARK-47411][SQL] Support StringInstr & FindInSet functions to work with collated strings
    
    ### What changes were proposed in this pull request?
    Extend the built-in string functions instr and find_in_set to support non-binary, non-lowercase collations.
    
    ### Why are the changes needed?
    Update collation support for built-in string functions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use COLLATE within arguments for the built-in string functions INSTR and FIND_IN_SET in Spark SQL queries, using non-binary collations such as UNICODE_CI.
    
    ### How was this patch tested?
    Unit tests for queries using "collate" (CollationSupportSuite and CollationStringExpressionsSuite).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45643 from miland-db/miland-db/substr-functions.
    
    Authored-by: Milan Dankovic <milan.danko...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationFactory.java  |  17 +++-
 .../spark/sql/catalyst/util/CollationSupport.java  | 112 +++++++++++++++++++++
 .../spark/unsafe/types/CollationSupportSuite.java  |  82 +++++++++++++++
 .../catalyst/expressions/stringExpressions.scala   |  28 ++++--
 .../sql/CollationStringExpressionsSuite.scala      |  62 ++++++++++++
 5 files changed, 288 insertions(+), 13 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 9786c559da44..93691e28c692 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -196,10 +196,21 @@ public final class CollationFactory {
       final UTF8String targetUTF8String,
       final UTF8String patternUTF8String,
       final int collationId) {
-    String pattern = patternUTF8String.toString();
-    CharacterIterator target = new 
StringCharacterIterator(targetUTF8String.toString());
+    return getStringSearch(targetUTF8String.toString(), 
patternUTF8String.toString(), collationId);
+  }
+
+  /**
+   * Returns a StringSearch object for the given pattern and target strings, 
under collation
+   * rules corresponding to the given collationId. The external ICU library 
StringSearch object can
+   * be used to find occurrences of the pattern in the target string, while 
respecting collation.
+   */
+  public static StringSearch getStringSearch(
+          final String targetString,
+          final String patternString,
+          final int collationId) {
+    CharacterIterator target = new StringCharacterIterator(targetString);
     Collator collator = CollationFactory.fetchCollation(collationId).collator;
-    return new StringSearch(pattern, target, (RuleBasedCollator) collator);
+    return new StringSearch(patternString, target, (RuleBasedCollator) 
collator);
   }
 
   /**
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index f54e6b162a93..d54e297413f4 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -144,6 +144,76 @@ public final class CollationSupport {
     }
   }
 
+  public static class FindInSet {
+    public static int exec(final UTF8String word, final UTF8String set, final 
int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(word, set);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(word, set);
+      } else {
+        return execICU(word, set, collationId);
+      }
+    }
+    public static String genCode(final String word, final String set, final 
int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.FindInSet.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s)", word, set);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s)", word, set);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId);
+      }
+    }
+    public static int execBinary(final UTF8String word, final UTF8String set) {
+      return set.findInSet(word);
+    }
+    public static int execLowercase(final UTF8String word, final UTF8String 
set) {
+      return set.toLowerCase().findInSet(word.toLowerCase());
+    }
+    public static int execICU(final UTF8String word, final UTF8String set,
+                                  final int collationId) {
+      return CollationAwareUTF8String.findInSet(word, set, collationId);
+    }
+  }
+
+  public static class StringInstr {
+    public static int exec(final UTF8String string, final UTF8String substring,
+        final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(string, substring);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(string, substring);
+      } else {
+        return execICU(string, substring, collationId);
+      }
+    }
+    public static String genCode(final String string, final String substring,
+        final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringInstr.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s)", string, substring);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s)", string, substring);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %d)", string, substring, 
collationId);
+      }
+    }
+    public static int execBinary(final UTF8String string, final UTF8String 
substring) {
+      return string.indexOf(substring, 0);
+    }
+    public static int execLowercase(final UTF8String string, final UTF8String 
substring) {
+      return string.toLowerCase().indexOf(substring.toLowerCase(), 0);
+    }
+    public static int execICU(final UTF8String string, final UTF8String 
substring,
+        final int collationId) {
+      return CollationAwareUTF8String.indexOf(string, substring, 0, 
collationId);
+    }
+  }
+
   // TODO: Add more collation-aware string expressions.
 
   /**
@@ -164,6 +234,48 @@ public final class CollationSupport {
 
   private static class CollationAwareUTF8String {
 
+    private static int findInSet(final UTF8String match, final UTF8String set, 
int collationId) {
+      if (match.contains(UTF8String.fromString(","))) {
+        return 0;
+      }
+
+      String setString = set.toString();
+      StringSearch stringSearch = CollationFactory.getStringSearch(setString, 
match.toString(),
+        collationId);
+
+      int wordStart = 0;
+      while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
+        boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 
1) == ',';
+        boolean isValidEnd = wordStart + stringSearch.getMatchLength() == 
setString.length()
+                || setString.charAt(wordStart + stringSearch.getMatchLength()) 
== ',';
+
+        if (isValidStart && isValidEnd) {
+          int pos = 0;
+          for (int i = 0; i < setString.length() && i < wordStart; i++) {
+            if (setString.charAt(i) == ',') {
+              pos++;
+            }
+          }
+
+          return pos + 1;
+        }
+      }
+
+      return 0;
+    }
+
+    private static int indexOf(final UTF8String target, final UTF8String 
pattern,
+        final int start, final int collationId) {
+      if (pattern.numBytes() == 0) {
+        return 0;
+      }
+
+      StringSearch stringSearch = CollationFactory.getStringSearch(target, 
pattern, collationId);
+      stringSearch.setIndex(start);
+
+      return stringSearch.next();
+    }
+
   }
 
 }
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 3c0d999089e7..36acf1c9b7a6 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -261,6 +261,88 @@ public class CollationSupportSuite {
     assertEndsWith("The i̇o", "İo", "UNICODE_CI", true);
   }
 
+  private void assertStringInstr(String string, String substring, String 
collationName,
+          Integer expected) throws SparkException {
+    UTF8String str = UTF8String.fromString(string);
+    UTF8String substr = UTF8String.fromString(substring);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertEquals(expected, CollationSupport.StringInstr.exec(str, substr, 
collationId) + 1);
+  }
+
+  @Test
+  public void testStringInstr() throws SparkException {
+    assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0);
+    assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0);
+    assertStringInstr("aaads", "ds", "UTF8_BINARY", 4);
+    assertStringInstr("xxxx", "", "UTF8_BINARY", 1);
+    assertStringInstr("", "xxxx", "UTF8_BINARY", 0);
+    assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5);
+    assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8);
+    assertStringInstr("aaads", "Aa", "UTF8_BINARY_LCASE", 1);
+    assertStringInstr("aaaDs", "de", "UTF8_BINARY_LCASE", 0);
+    assertStringInstr("aaaDs", "ds", "UTF8_BINARY_LCASE", 4);
+    assertStringInstr("xxxx", "", "UTF8_BINARY_LCASE", 1);
+    assertStringInstr("", "xxxx", "UTF8_BINARY_LCASE", 0);
+    assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY_LCASE", 5);
+    assertStringInstr("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8);
+    assertStringInstr("aaads", "Aa", "UNICODE", 0);
+    assertStringInstr("aaads", "aa", "UNICODE", 1);
+    assertStringInstr("aaads", "de", "UNICODE", 0);
+    assertStringInstr("xxxx", "", "UNICODE", 1);
+    assertStringInstr("", "xxxx", "UNICODE", 0);
+    assertStringInstr("test大千世界X大千世界", "界x", "UNICODE", 0);
+    assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8);
+    assertStringInstr("aaads", "AD", "UNICODE_CI", 3);
+    assertStringInstr("aaads", "dS", "UNICODE_CI", 4);
+    assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0);
+    assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8);
+    assertStringInstr("abİo12", "i̇o", "UNICODE_CI", 3);
+    assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3);
+  }
+
+  private void assertFindInSet(String word, String set, String collationName,
+        Integer expected) throws SparkException {
+    UTF8String w = UTF8String.fromString(word);
+    UTF8String s = UTF8String.fromString(set);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId));
+  }
+
+  @Test
+  public void testFindInSet() throws SparkException {
+    assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0);
+    assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1);
+    assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5);
+    assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0);
+    assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0);
+    assertFindInSet("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
+    assertFindInSet("c", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4);
+    assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3);
+    assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 1);
+    assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
+    assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
+    assertFindInSet("XX", "xx", "UTF8_BINARY_LCASE", 1);
+    assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
+    assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_BINARY_LCASE", 4);
+    assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0);
+    assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3);
+    assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0);
+    assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0);
+    assertFindInSet("xx", "xx", "UNICODE", 1);
+    assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0);
+    assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5);
+    assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0);
+    assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4);
+    assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5);
+    assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0);
+    assertFindInSet("XX", "xx", "UNICODE_CI", 1);
+    assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4);
+    assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5);
+    assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5);
+    assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2);
+    assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
+  }
+
   // TODO: Test more collation-aware string expressions.
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index bd2c3baf4fe8..2b7703ed82b3 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -978,15 +978,19 @@ case class StringTranslate(srcExpr: Expression, 
matchingExpr: Expression, replac
 case class FindInSet(left: Expression, right: Expression) extends 
BinaryExpression
     with ImplicitCastInputTypes with NullIntolerant {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+  final lazy val collationId: Int = 
left.dataType.asInstanceOf[StringType].collationId
 
-  override protected def nullSafeEval(word: Any, set: Any): Any =
-    set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
+
+  override protected def nullSafeEval(word: Any, set: Any): Any = {
+    CollationSupport.FindInSet.
+      exec(word.asInstanceOf[UTF8String], set.asInstanceOf[UTF8String], 
collationId)
+  }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (word, set) =>
-      s"${ev.value} = $set.findInSet($word);"
-    )
+    defineCodeGen(ctx, ev, (word, set) => CollationSupport.FindInSet.
+      genCode(word, set, collationId))
   }
 
   override def dataType: DataType = IntegerType
@@ -1350,20 +1354,24 @@ case class StringTrimRight(srcStr: Expression, trimStr: 
Option[Expression] = Non
 case class StringInstr(str: Expression, substr: Expression)
   extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
+  final lazy val collationId: Int = 
left.dataType.asInstanceOf[StringType].collationId
+
   override def left: Expression = str
   override def right: Expression = substr
   override def dataType: DataType = IntegerType
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
 
   override def nullSafeEval(string: Any, sub: Any): Any = {
-    string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 
1
+    CollationSupport.StringInstr.
+      exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], 
collationId) + 1
   }
 
   override def prettyName: String = "instr"
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, (l, r) =>
-      s"($l).indexOf($r, 0) + 1")
+      defineCodeGen(ctx, ev, (string, substring) =>
+        CollationSupport.StringInstr.genCode(string, substring, collationId) + 
" + 1")
   }
 
   override protected def withNewChildrenInternal(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 07be8d48e869..35f63ce010a9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -115,6 +115,68 @@ class CollationStringExpressionsSuite
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
 
+  test("Support StringInStr string expression with collation") {
+    case class StringInStrTestCase[R](string: String, substring: String, c: 
String, result: R)
+    val testCases = Seq(
+      // scalastyle:off
+      StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5),
+      StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8),
+      StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0),
+      StringInStrTestCase("test大千世界X大千世界", "界y", "UNICODE_CI", 0),
+      StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8),
+      StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3)
+      // scalastyle:on
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT instr(collate('${t.string}','${t.c}')," +
+        s"collate('${t.substring}','${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
+      // Implicit casting
+      checkAnswer(sql(s"SELECT instr(collate('${t.string}','${t.c}')," +
+        s"'${t.substring}')"), Row(t.result))
+      checkAnswer(sql(s"SELECT instr('${t.string}'," +
+        s"collate('${t.substring}','${t.c}'))"), Row(t.result))
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql(s"SELECT instr(collate('aaads','UTF8_BINARY'), 
collate('Aa','UTF8_BINARY_LCASE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  test("Support FindInSet string expression with collation") {
+    case class FindInSetTestCase[R](word: String, set: String, c: String, 
result: R)
+    val testCases = Seq(
+      FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0),
+      FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4),
+      FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0),
+      // scalastyle:off
+      FindInSetTestCase("i̇o", "ab,İo,12", "UNICODE_CI", 2),
+      FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2)
+      // scalastyle:on
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +
+        s"collate('${t.set}', '${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
+      // Implicit casting
+      checkAnswer(sql(s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +
+        s"'${t.set}')"), Row(t.result))
+      checkAnswer(sql(s"SELECT find_in_set('${t.word}'," +
+        s"collate('${t.set}', '${t.c}'))"), Row(t.result))
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql(s"SELECT find_in_set(collate('AB','UTF8_BINARY')," +
+        s"collate('ab,xyz,fgh','UTF8_BINARY_LCASE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
   test("Support StartsWith string expression with collation") {
     // Supported collations
     case class StartsWithTestCase[R](l: String, r: String, c: String, result: 
R)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to