This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7b1147a05a6c [SPARK-47567][SQL] Support LOCATE function to work with collated strings 7b1147a05a6c is described below commit 7b1147a05a6ca54276538d766c089980b9ee5d59 Author: Milan Dankovic <milan.danko...@databricks.com> AuthorDate: Mon Apr 29 17:24:36 2024 +0800 [SPARK-47567][SQL] Support LOCATE function to work with collated strings ### What changes were proposed in this pull request? Extend the built-in string function LOCATE to support non-binary, non-lowercase collations. ### Why are the changes needed? To update collation support for built-in string functions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users can now use COLLATE within the arguments of the built-in string function LOCATE in Spark SQL queries, including non-binary collations such as UNICODE_CI. ### How was this patch tested? Unit tests for StringLocate (`CollationSupportSuite.java`) and for queries using LOCATE (`CollationStringExpressionsSuite.scala`). ### Was this patch authored or co-authored using generative AI tooling? No Closes #45791 from miland-db/miland-db/string-locate. 
Authored-by: Milan Dankovic <milan.danko...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationSupport.java | 38 +++++++++++++ .../spark/unsafe/types/CollationSupportSuite.java | 65 ++++++++++++++++++++++ .../sql/catalyst/analysis/CollationTypeCasts.scala | 4 ++ .../catalyst/expressions/stringExpressions.scala | 14 +++-- .../sql/CollationStringExpressionsSuite.scala | 34 +++++++++++ 5 files changed, 149 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 0fc37c169612..0c81b99de916 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -403,6 +403,44 @@ public final class CollationSupport { } } + public static class StringLocate { + public static int exec(final UTF8String string, final UTF8String substring, final int start, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(string, substring, start); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(string, substring, start); + } else { + return execICU(string, substring, start, collationId); + } + } + public static String genCode(final String string, final String substring, final int start, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringLocate.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); + } 
else { + return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); + } + } + public static int execBinary(final UTF8String string, final UTF8String substring, + final int start) { + return string.indexOf(substring, start); + } + public static int execLowercase(final UTF8String string, final UTF8String substring, + final int start) { + return string.toLowerCase().indexOf(substring.toLowerCase(), start); + } + public static int execICU(final UTF8String string, final UTF8String substring, final int start, + final int collationId) { + return CollationAwareUTF8String.indexOf(string, substring, start, collationId); + } + } + // TODO: Add more collation-aware string expressions. /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 6c79fc821317..030c7a7a1e3c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -652,6 +652,71 @@ public class CollationSupportSuite { assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"); } + private void assertLocate(String substring, String string, Integer start, String collationName, + Integer expected) throws SparkException { + UTF8String substr = UTF8String.fromString(substring); + UTF8String str = UTF8String.fromString(string); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected, CollationSupport.StringLocate.exec(str, substr, + start - 1, collationId) + 1); + } + + @Test + public void testLocate() throws SparkException { + // If you add tests with start < 1 be careful to understand the behavior of the indexOf method + // and usage of indexOf in the StringLocate class. 
+ assertLocate("aa", "aaads", 1, "UTF8_BINARY", 1); + assertLocate("aa", "aaads", 2, "UTF8_BINARY", 2); + assertLocate("aa", "aaads", 3, "UTF8_BINARY", 0); + assertLocate("Aa", "aaads", 1, "UTF8_BINARY", 0); + assertLocate("Aa", "aAads", 1, "UTF8_BINARY", 2); + assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY", 0); + assertLocate("界X", "test大千世界X大千世界", 1, "UTF8_BINARY", 8); + assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY", 13); + assertLocate("AA", "aaads", 1, "UTF8_BINARY_LCASE", 1); + assertLocate("aa", "aAads", 2, "UTF8_BINARY_LCASE", 2); + assertLocate("aa", "aaAds", 3, "UTF8_BINARY_LCASE", 0); + assertLocate("abC", "abcabc", 1, "UTF8_BINARY_LCASE", 1); + assertLocate("abC", "abCabc", 2, "UTF8_BINARY_LCASE", 4); + assertLocate("abc", "abcabc", 4, "UTF8_BINARY_LCASE", 4); + assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY_LCASE", 8); + assertLocate("界X", "test大千世界Xtest大千世界", 1, "UTF8_BINARY_LCASE", 8); + assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY_LCASE", 13); + assertLocate("大千", "test大千世界大千世界", 1, "UTF8_BINARY_LCASE", 5); + assertLocate("大千", "test大千世界大千世界", 9, "UTF8_BINARY_LCASE", 9); + assertLocate("大千", "大千世界大千世界", 1, "UTF8_BINARY_LCASE", 1); + assertLocate("aa", "Aaads", 1, "UNICODE", 2); + assertLocate("AA", "aaads", 1, "UNICODE", 0); + assertLocate("aa", "aAads", 2, "UNICODE", 0); + assertLocate("aa", "aaAds", 3, "UNICODE", 0); + assertLocate("abC", "abcabc", 1, "UNICODE", 0); + assertLocate("abC", "abCabc", 2, "UNICODE", 0); + assertLocate("abC", "abCabC", 2, "UNICODE", 4); + assertLocate("abc", "abcabc", 1, "UNICODE", 1); + assertLocate("abc", "abcabc", 3, "UNICODE", 4); + assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE", 0); + assertLocate("界X", "test大千世界X大千世界", 1, "UNICODE", 8); + assertLocate("界", "test大千世界X大千世界", 13, "UNICODE", 13); + assertLocate("AA", "aaads", 1, "UNICODE_CI", 1); + assertLocate("aa", "aAads", 2, "UNICODE_CI", 2); + assertLocate("aa", "aaAds", 3, "UNICODE_CI", 0); + assertLocate("abC", "abcabc", 1, 
"UNICODE_CI", 1); + assertLocate("abC", "abCabc", 2, "UNICODE_CI", 4); + assertLocate("abc", "abcabc", 4, "UNICODE_CI", 4); + assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8); + assertLocate("界", "test大千世界X大千世界", 13, "UNICODE_CI", 13); + assertLocate("大千", "test大千世界大千世界", 1, "UNICODE_CI", 5); + assertLocate("大千", "test大千世界大千世界", 9, "UNICODE_CI", 9); + assertLocate("大千", "大千世界大千世界", 1, "UNICODE_CI", 1); + // Case-variable character length + assertLocate("i̇o", "İo世界大千世界", 1, "UNICODE_CI", 1); + assertLocate("i̇o", "大千İo世界大千世界", 1, "UNICODE_CI", 3); + assertLocate("i̇o", "世界İo大千世界大千İo", 4, "UNICODE_CI", 11); + assertLocate("İo", "i̇o世界大千世界", 1, "UNICODE_CI", 1); + assertLocate("İo", "大千i̇o世界大千世界", 1, "UNICODE_CI", 3); + assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); // 12 instead of 11 + } + // TODO: Test more collation-aware string expressions. /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 3ae251e56772..f69218812d36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -45,6 +45,10 @@ object CollationTypeCasts extends TypeCoercionRule { caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e)) CaseWhen(newBranches, newElseValue) + case stringLocate: StringLocate => + stringLocate.withNewChildren(collateToSingleType( + Seq(stringLocate.first, stringLocate.second)) :+ stringLocate.third) + case eltExpr: Elt => eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 
91401a3ea3ae..2d7f9652986a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1457,12 +1457,15 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) this(substr, str, Literal(1)) } + final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId + override def first: Expression = substr override def second: Expression = str override def third: Expression = start override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) override def eval(input: InternalRow): Any = { val s = start.eval(input) @@ -1482,9 +1485,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) if (sVal < 1) { 0 } else { - l.asInstanceOf[UTF8String].indexOf( - r.asInstanceOf[UTF8String], - s.asInstanceOf[Int] - 1) + 1 + CollationSupport.StringLocate.exec(l.asInstanceOf[UTF8String], + r.asInstanceOf[UTF8String], s.asInstanceOf[Int] - 1, collationId) + 1; } } } @@ -1505,8 +1507,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) ${strGen.code} if (!${strGen.isNull}) { if (${startGen.value} > 0) { - ${ev.value} = ${strGen.value}.indexOf(${substrGen.value}, - ${startGen.value} - 1) + 1; + ${ev.value} = CollationSupport.StringLocate.exec(${strGen.value}, + ${substrGen.value}, ${startGen.value} - 1, $collationId) + 1; } } else { ${ev.isNull} = true; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 305c51c0b703..d88c15fb2325 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -661,6 +661,40 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(0))) } + test("Support Locate string expression with collation") { + case class StringLocateTestCase[R](substring: String, string: String, start: Integer, + c: String, result: R) + val testCases = Seq( + // scalastyle:off + StringLocateTestCase("aa", "aaads", 0, "UTF8_BINARY", 0), + StringLocateTestCase("aa", "Aaads", 0, "UTF8_BINARY_LCASE", 0), + StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UTF8_BINARY_LCASE", 8), + StringLocateTestCase("aBc", "abcabc", 4, "UTF8_BINARY_LCASE", 4), + StringLocateTestCase("aa", "Aaads", 0, "UNICODE", 0), + StringLocateTestCase("abC", "abCabC", 2, "UNICODE", 4), + StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0), + StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8) + // scalastyle:on + ) + testCases.foreach(t => { + val query = s"SELECT locate(collate('${t.substring}','${t.c}')," + + s"collate('${t.string}','${t.c}'),${t.start})" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // Implicit casting + checkAnswer(sql(s"SELECT locate(collate('${t.substring}','${t.c}')," + + s"'${t.string}',${t.start})"), Row(t.result)) + checkAnswer(sql(s"SELECT locate('${t.substring}',collate('${t.string}'," + + s"'${t.c}'),${t.start})"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT locate(collate('aBc', 'UTF8_BINARY'),collate('abcabc', 'UTF8_BINARY_LCASE'),4)") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + // TODO: Add more tests for other string expressions } --------------------------------------------------------------------- To 
unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org