This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0329479acb67 [SPARK-47359][SQL] Support TRANSLATE function to work with collated strings
0329479acb67 is described below

commit 0329479acb6758c4d3e53d514ea832a181d31065
Author: Milan Dankovic <milan.danko...@databricks.com>
AuthorDate: Tue Apr 30 22:28:56 2024 +0800

    [SPARK-47359][SQL] Support TRANSLATE function to work with collated strings
    
    ### What changes were proposed in this pull request?
    Extend built-in string functions to support non-binary, non-lowercase collation for: `translate`
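
    At a high level, the new `CollationSupport.StringTranslate` entry point dispatches on the collation's equality semantics. Below is a minimal Scala sketch of that dispatch, using only names introduced by this patch (the committed implementation further down is Java):

    ```scala
    import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport}
    import org.apache.spark.unsafe.types.UTF8String

    // Sketch only: mirrors the CollationSupport.StringTranslate.exec added in this patch.
    def translateWithCollation(
        source: UTF8String,
        dict: java.util.Map[String, String],
        collationId: Int): UTF8String = {
      val collation = CollationFactory.fetchCollation(collationId)
      if (collation.supportsBinaryEquality) {
        // binary collations keep the old byte-for-byte translation
        CollationSupport.StringTranslate.execBinary(source, dict)
      } else if (collation.supportsLowercaseEquality) {
        // UTF8_BINARY_LCASE matches dictionary keys after lowercasing
        CollationSupport.StringTranslate.execLowercase(source, dict)
      } else {
        // other collations (e.g. UNICODE_CI) go through ICU string search
        CollationSupport.StringTranslate.execICU(source, dict, collationId)
      }
    }
    ```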
    
    ### Why are the changes needed?
    Update collation support for built-in string functions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use COLLATE within arguments for built-in
    string function TRANSLATE in Spark SQL queries, using non-binary collations
    such as UNICODE_CI.
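
    For illustration, a minimal sketch of the new behaviour, assuming a running `SparkSession` named `spark`; the query and expected value are taken from the test cases added in this patch:

    ```scala
    // Only the source string is explicitly collated; the match/replace arguments
    // are implicitly cast to the same collation (see CollationTypeCasts below).
    // Under UNICODE_CI, 'T' and 'r' hit the dictionary entries 't' -> '4' and 'R' -> '1'.
    spark.sql(
      "SELECT translate(collate('Translate', 'UNICODE_CI'), 'Rnlt', '1234')"
    ).show()
    // Expected single value: 41a2s3a4e
    ```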
    
    ### How was this patch tested?
    Unit tests for queries using StringTranslate (CollationStringExpressionsSuite.scala).
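
    Each case checks the result, the result's collation, and implicit casting; a representative assertion inside the suite (values taken from the new test data) looks roughly like:

    ```scala
    // Sketch of one added check: the result of TRANSLATE keeps the collation
    // of its arguments, here UNICODE_CI.
    val query = "SELECT translate(collate('Translate', 'UNICODE_CI'), " +
      "collate('Rnlt', 'UNICODE_CI'), collate('1234', 'UNICODE_CI'))"
    checkAnswer(sql(query), Row("41a2s3a4e"))
    assert(sql(query).schema.fields.head.dataType.sameType(
      StringType(CollationFactory.collationNameToId("UNICODE_CI"))))
    ```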
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45820 from miland-db/miland-db/string-translate.
    
    Authored-by: Milan Dankovic <milan.danko...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationSupport.java  | 85 ++++++++++++++++++++++
 .../sql/catalyst/analysis/CollationTypeCasts.scala |  3 +-
 .../catalyst/expressions/stringExpressions.scala   | 28 ++++---
 .../sql/CollationStringExpressionsSuite.scala      | 74 +++++++++++++++++++
 4 files changed, 180 insertions(+), 10 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 9778ca31209e..b77671cee90b 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -25,7 +25,9 @@ import org.apache.spark.unsafe.UTF8StringBuilder;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.regex.Pattern;
 
 import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
@@ -483,6 +485,56 @@ public final class CollationSupport {
     }
   }
 
+  public static class StringTranslate {
+    public static UTF8String exec(final UTF8String source, Map<String, String> dict,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(source, dict);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(source, dict);
+      } else {
+        return execICU(source, dict, collationId);
+      }
+    }
+    public static String genCode(final String source, final String dict, final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringTranslate.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s)", source, dict);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s)", source, dict);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId);
+      }
+    }
+    public static UTF8String execBinary(final UTF8String source, Map<String, String> dict) {
+      return source.translate(dict);
+    }
+    public static UTF8String execLowercase(final UTF8String source, Map<String, String> dict) {
+      String srcStr = source.toString();
+      StringBuilder sb = new StringBuilder();
+      int charCount = 0;
+      for (int k = 0; k < srcStr.length(); k += charCount) {
+        int codePoint = srcStr.codePointAt(k);
+        charCount = Character.charCount(codePoint);
+        String subStr = srcStr.substring(k, k + charCount);
+        String translated = dict.get(subStr.toLowerCase());
+        if (null == translated) {
+          sb.append(subStr);
+        } else if (!"\0".equals(translated)) {
+          sb.append(translated);
+        }
+      }
+      return UTF8String.fromString(sb.toString());
+    }
+    public static UTF8String execICU(final UTF8String source, Map<String, String> dict,
+        final int collationId) {
+      return source.translate(CollationAwareUTF8String.getCollationAwareDict(
+        source, dict, collationId));
+    }
+  }
+
   // TODO: Add more collation-aware string expressions.
 
   /**
@@ -808,6 +860,39 @@ public final class CollationSupport {
       }
     }
 
+    private static Map<String, String> getCollationAwareDict(UTF8String string,
+        Map<String, String> dict, int collationId) {
+      String srcStr = string.toString();
+
+      Map<String, String> collationAwareDict = new HashMap<>();
+      for (String key : dict.keySet()) {
+        StringSearch stringSearch =
+          CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);
+
+        int pos = 0;
+        while ((pos = stringSearch.next()) != StringSearch.DONE) {
+          int codePoint = srcStr.codePointAt(pos);
+          int charCount = Character.charCount(codePoint);
+          String newKey = srcStr.substring(pos, pos + charCount);
+
+          boolean exists = false;
+          for (String existingKey : collationAwareDict.keySet()) {
+            if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
+              collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
+              exists = true;
+              break;
+            }
+          }
+
+          if (!exists) {
+            collationAwareDict.put(newKey, dict.get(key));
+          }
+        }
+      }
+
+      return collationAwareDict;
+    }
+
   }
 
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 1130677d5f1b..44349384187e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -73,7 +73,8 @@ object CollationTypeCasts extends TypeCoercionRule {
 
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
-      _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace) =>
+      _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace |
+      _: StringTranslate) =>
       val newChildren = collateToSingleType(otherExpr.children)
       otherExpr.withNewChildren(newChildren)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b0352046b920..0769c8e609ec 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
@@ -859,9 +859,14 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
 
 object StringTranslate {
 
-  def buildDict(matchingString: UTF8String, replaceString: UTF8String)
+  def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int)
     : JMap[String, String] = {
-    val matching = matchingString.toString()
+    val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
+      matchingString.toString().toLowerCase()
+    } else {
+      matchingString.toString()
+    }
+
     val replace = replaceString.toString()
     val dict = new HashMap[String, String]()
     var i = 0
@@ -912,13 +917,16 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
   @transient private var lastReplace: UTF8String = _
   @transient private var dict: JMap[String, String] = _
 
+  final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId
+
   override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
     if (matchingEval != lastMatching || replaceEval != lastReplace) {
       lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
       lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
-      dict = StringTranslate.buildDict(lastMatching, lastReplace)
+      dict = StringTranslate.buildDict(lastMatching, lastReplace, collationId)
     }
-    srcEval.asInstanceOf[UTF8String].translate(dict)
+
+    CollationSupport.StringTranslate.exec(srcEval.asInstanceOf[UTF8String], dict, collationId)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -939,15 +947,17 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
         $termLastMatching = $matching.clone();
         $termLastReplace = $replace.clone();
         $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
-          .buildDict($termLastMatching, $termLastReplace);
+          .buildDict($termLastMatching, $termLastReplace, $collationId);
       }
-      ${ev.value} = $src.translate($termDict);
+      ${ev.value} = CollationSupport.StringTranslate.
+      exec($src, $termDict, $collationId);
       """
     })
   }
 
-  override def dataType: DataType = StringType
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
+  override def dataType: DataType = srcExpr.dataType
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
   override def first: Expression = srcExpr
   override def second: Expression = matchingExpr
   override def third: Expression = replaceExpr
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 989e418b7477..b9a4fecd0465 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -248,6 +248,80 @@ class CollationStringExpressionsSuite
     }
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
+  test("TRANSLATE check result on explicitly collated string") {
+    // Supported collations
+    case class TranslateTestCase[R](input: String, matchExpression: String,
+        replaceExpression: String, collation: String, result: R)
+    val testCases = Seq(
+      TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", 
"41a2s3a4e"),
+      TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", 
"41a2s3a4e"),
+      TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY_LCASE", 
"xXaxsXaxe"),
+      TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY_LCASE", 
"xxaxsXaxex"),
+      TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY_LCASE", 
"xXaxsXaxeX"),
+      // scalastyle:off
+      TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY_LCASE", 
"test大千世AB大千世A"),
+      TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY_LCASE", 
"大千世界abca大千世界"),
+      TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY_LCASE", 
"oeso大千世界大千世界"),
+      TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY_LCASE", 
"大千世界大千世界OesO"),
+      TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY_LCASE", 
"世世世界世世世界tesT"),
+      // scalastyle:on
+      TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"),
+      TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"),
+      TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", 
"TxaxsXaxeX"),
+      TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", 
"TXaxsXaxex"),
+      // scalastyle:off
+      TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", 
"test大千世AX大千世A"),
+      TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"),
+      TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"),
+      // scalastyle:on
+      TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", 
"41a2s3a4e"),
+      TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", 
"xXaxsXaxe"),
+      TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", 
"xxaxsXaxex"),
+      TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", 
"xXaxsXaxeX"),
+      // scalastyle:off
+      TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", 
"test大千世AB大千世A"),
+      TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", 
"大千世界abca大千世界"),
+      TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", 
"oeso大千世界大千世界"),
+      TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", 
"大千世界大千世界OesO"),
+      TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", 
"世世世界世世世界tesT"),
+      // scalastyle:on
+      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", 
"UTF8_BINARY_LCASE", "14234e"),
+      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", 
"14234e"),
+      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", 
"Tr4234e"),
+      TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", 
"Tr4234e"),
+      TranslateTestCase("Translate", "Rnlt", "123495834634", 
"UTF8_BINARY_LCASE", "41a2s3a4e"),
+      TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE", 
"Tra2s3a4e"),
+      TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", 
"41a2s3a4e"),
+      TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", 
"Tra2s3a4e"),
+      TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"),
+      TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY_LCASE", "123f"),
+      TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"),
+      TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f")
+    )
+
+    testCases.foreach(t => {
+      val query = s"SELECT translate(collate('${t.input}', '${t.collation}')," 
+
+        s"collate('${t.matchExpression}', '${t.collation}')," +
+        s"collate('${t.replaceExpression}', '${t.collation}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(
+        StringType(CollationFactory.collationNameToId(t.collation))))
+      // Implicit casting
+      checkAnswer(sql(s"SELECT translate(collate('${t.input}', 
'${t.collation}')," +
+        s"'${t.matchExpression}', '${t.replaceExpression}')"), Row(t.result))
+      checkAnswer(sql(s"SELECT translate('${t.input}', 
collate('${t.matchExpression}'," +
+        s"'${t.collation}'), '${t.replaceExpression}')"), Row(t.result))
+      checkAnswer(sql(s"SELECT translate('${t.input}', 
'${t.matchExpression}'," +
+        s"collate('${t.replaceExpression}', '${t.collation}'))"), 
Row(t.result))
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql(s"SELECT translate(collate('Translate', 'UTF8_BINARY_LCASE')," +
+        s"collate('Rnlt', 'UNICODE'), '1234')")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
 
   test("Support Replace string expression with collation") {
     case class ReplaceTestCase[R](source: String, search: String, replace: String,

