This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 0329479acb67 [SPARK-47359][SQL] Support TRANSLATE function to work with collated strings
0329479acb67 is described below

commit 0329479acb6758c4d3e53d514ea832a181d31065
Author: Milan Dankovic <milan.danko...@databricks.com>
AuthorDate: Tue Apr 30 22:28:56 2024 +0800

    [SPARK-47359][SQL] Support TRANSLATE function to work with collated strings

    ### What changes were proposed in this pull request?
    Extend built-in string functions to support non-binary, non-lowercase collation for: `translate`

    ### Why are the changes needed?
    Update collation support for built-in string functions in Spark.

    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use COLLATE within arguments for built-in string function TRANSLATE in Spark SQL queries, using non-binary collations such as UNICODE_CI.

    ### How was this patch tested?
    Unit tests for queries using StringTranslate (CollationStringExpressionsSuite.scala).

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #45820 from miland-db/miland-db/string-translate.

Authored-by: Milan Dankovic <milan.danko...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationSupport.java | 85 ++++++++++++++++++++++ .../sql/catalyst/analysis/CollationTypeCasts.scala | 3 +- .../catalyst/expressions/stringExpressions.scala | 28 ++++--- .../sql/CollationStringExpressionsSuite.scala | 74 +++++++++++++++++++ 4 files changed, 180 insertions(+), 10 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 9778ca31209e..b77671cee90b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -25,7 +25,9 @@ import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.types.UTF8String; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.regex.Pattern; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; @@ -483,6 +485,56 @@ public final class CollationSupport { } } + public static class StringTranslate { + public static UTF8String exec(final UTF8String source, Map<String, String> dict, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(source, dict); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(source, dict); + } else { + return execICU(source, dict, collationId); + } + } + public static String genCode(final String source, final String dict, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.EndsWith.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + 
"Binary(%s, %s)", source, dict); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", source, dict); + } else { + return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId); + } + } + public static UTF8String execBinary(final UTF8String source, Map<String, String> dict) { + return source.translate(dict); + } + public static UTF8String execLowercase(final UTF8String source, Map<String, String> dict) { + String srcStr = source.toString(); + StringBuilder sb = new StringBuilder(); + int charCount = 0; + for (int k = 0; k < srcStr.length(); k += charCount) { + int codePoint = srcStr.codePointAt(k); + charCount = Character.charCount(codePoint); + String subStr = srcStr.substring(k, k + charCount); + String translated = dict.get(subStr.toLowerCase()); + if (null == translated) { + sb.append(subStr); + } else if (!"\0".equals(translated)) { + sb.append(translated); + } + } + return UTF8String.fromString(sb.toString()); + } + public static UTF8String execICU(final UTF8String source, Map<String, String> dict, + final int collationId) { + return source.translate(CollationAwareUTF8String.getCollationAwareDict( + source, dict, collationId)); + } + } + // TODO: Add more collation-aware string expressions. 
/** @@ -808,6 +860,39 @@ public final class CollationSupport { } } + private static Map<String, String> getCollationAwareDict(UTF8String string, + Map<String, String> dict, int collationId) { + String srcStr = string.toString(); + + Map<String, String> collationAwareDict = new HashMap<>(); + for (String key : dict.keySet()) { + StringSearch stringSearch = + CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); + + int pos = 0; + while ((pos = stringSearch.next()) != StringSearch.DONE) { + int codePoint = srcStr.codePointAt(pos); + int charCount = Character.charCount(codePoint); + String newKey = srcStr.substring(pos, pos + charCount); + + boolean exists = false; + for (String existingKey : collationAwareDict.keySet()) { + if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { + collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); + exists = true; + break; + } + } + + if (!exists) { + collationAwareDict.put(newKey, dict.get(key)); + } + } + } + + return collationAwareDict; + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 1130677d5f1b..44349384187e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -73,7 +73,8 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | - _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace) => + _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace | + _: StringTranslate) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0352046b920..0769c8e609ec 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} @@ -859,9 +859,14 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: object StringTranslate { - def buildDict(matchingString: UTF8String, replaceString: UTF8String) + def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int) : JMap[String, String] = { - val matching = matchingString.toString() + val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + matchingString.toString().toLowerCase() + } else { + matchingString.toString() + } + val replace = replaceString.toString() val dict = new HashMap[String, String]() var i = 0 @@ -912,13 +917,16 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac @transient private var lastReplace: UTF8String = _ @transient private var dict: JMap[String, String] = _ + 
final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId + override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = { if (matchingEval != lastMatching || replaceEval != lastReplace) { lastMatching = matchingEval.asInstanceOf[UTF8String].clone() lastReplace = replaceEval.asInstanceOf[UTF8String].clone() - dict = StringTranslate.buildDict(lastMatching, lastReplace) + dict = StringTranslate.buildDict(lastMatching, lastReplace, collationId) } - srcEval.asInstanceOf[UTF8String].translate(dict) + + CollationSupport.StringTranslate.exec(srcEval.asInstanceOf[UTF8String], dict, collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -939,15 +947,17 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac $termLastMatching = $matching.clone(); $termLastReplace = $replace.clone(); $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict($termLastMatching, $termLastReplace); + .buildDict($termLastMatching, $termLastReplace, $collationId); } - ${ev.value} = $src.translate($termDict); + ${ev.value} = CollationSupport.StringTranslate. 
+ exec($src, $termDict, $collationId); """ }) } - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def dataType: DataType = srcExpr.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) override def first: Expression = srcExpr override def second: Expression = matchingExpr override def third: Expression = replaceExpr diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 989e418b7477..b9a4fecd0465 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -248,6 +248,80 @@ class CollationStringExpressionsSuite } assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("TRANSLATE check result on explicitly collated string") { + // Supported collations + case class TranslateTestCase[R](input: String, matchExpression: String, + replaceExpression: String, collation: String, result: R) + val testCases = Seq( + TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", "41a2s3a4e"), + TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", "41a2s3a4e"), + TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY_LCASE", "xXaxsXaxe"), + TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY_LCASE", "xxaxsXaxex"), + TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY_LCASE", "xXaxsXaxeX"), + // scalastyle:off + TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY_LCASE", "test大千世AB大千世A"), + TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY_LCASE", "大千世界abca大千世界"), + TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY_LCASE", "oeso大千世界大千世界"), + 
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY_LCASE", "大千世界大千世界OesO"), + TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY_LCASE", "世世世界世世世界tesT"), + // scalastyle:on + TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"), + TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"), + TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"), + TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"), + // scalastyle:off + TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"), + TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"), + TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"), + // scalastyle:on + TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"), + TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"), + TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"), + TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"), + // scalastyle:off + TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"), + TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"), + TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"), + TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"), + TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"), + // scalastyle:on + TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY_LCASE", "14234e"), + TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"), + TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"), + TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"), + TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY_LCASE", "41a2s3a4e"), + TranslateTestCase("Translate", "Rnlt", 
"123495834634", "UNICODE", "Tra2s3a4e"), + TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"), + TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"), + TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"), + TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY_LCASE", "123f"), + TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), + TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f") + ) + + testCases.foreach(t => { + val query = s"SELECT translate(collate('${t.input}', '${t.collation}')," + + s"collate('${t.matchExpression}', '${t.collation}')," + + s"collate('${t.replaceExpression}', '${t.collation}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType( + StringType(CollationFactory.collationNameToId(t.collation)))) + // Implicit casting + checkAnswer(sql(s"SELECT translate(collate('${t.input}', '${t.collation}')," + + s"'${t.matchExpression}', '${t.replaceExpression}')"), Row(t.result)) + checkAnswer(sql(s"SELECT translate('${t.input}', collate('${t.matchExpression}'," + + s"'${t.collation}'), '${t.replaceExpression}')"), Row(t.result)) + checkAnswer(sql(s"SELECT translate('${t.input}', '${t.matchExpression}'," + + s"collate('${t.replaceExpression}', '${t.collation}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql(s"SELECT translate(collate('Translate', 'UTF8_BINARY_LCASE')," + + s"collate('Rnlt', 'UNICODE'), '1234')") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } test("Support Replace string expression with collation") { case class ReplaceTestCase[R](source: String, search: String, replace: String, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org