This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new afe310d617e5 [SPARK-47351][SQL] Add collation support for StringToMap & Mask string expressions
afe310d617e5 is described below

commit afe310d617e5d5e1fd79e7d42e2bbafe93c6d3a8
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Fri Apr 26 20:33:29 2024 +0800

    [SPARK-47351][SQL] Add collation support for StringToMap & Mask string expressions
    
    ### What changes were proposed in this pull request?
    Introduce collation awareness for string expressions: str_to_map & mask.
    
    ### Why are the changes needed?
    Add collation support for built-in string functions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use collated strings within arguments for built-in string functions: str_to_map & mask.
    
    ### How was this patch tested?
    End-to-end SQL tests (new CollationSQLExpressionsSuite).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46165 from uros-db/SPARK-47351.
    
    Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CollationTypeCasts.scala |  2 +-
 .../catalyst/expressions/complexTypeCreator.scala  |  8 +-
 .../sql/catalyst/expressions/maskExpressions.scala | 44 +++++-----
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 98 ++++++++++++++++++++++
 4 files changed, 129 insertions(+), 23 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 473d552b3d94..c7ca5607481d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -64,7 +64,7 @@ object CollationTypeCasts extends TypeCoercionRule {
 
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: 
Greatest | _: Least |
-      _: Coalesce | _: BinaryExpression | _: ConcatWs) =>
+      _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask) =>
       val newChildren = collateToSingleType(otherExpr.children)
       otherExpr.withNewChildren(newChildren)
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 3eb6225b5426..c38b6cea9a0a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -570,11 +571,12 @@ case class StringToMap(text: Expression, pairDelim: 
Expression, keyValueDelim: E
   override def second: Expression = pairDelim
   override def third: Expression = keyValueDelim
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, 
StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
 
-  override def dataType: DataType = MapType(StringType, StringType)
+  override def dataType: DataType = MapType(first.dataType, first.dataType)
 
-  private lazy val mapBuilder = new ArrayBasedMapBuilder(StringType, 
StringType)
+  private lazy val mapBuilder = new ArrayBasedMapBuilder(first.dataType, 
first.dataType)
 
   override def nullSafeEval(
       inputString: Any,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
index e5157685a9a6..c11357352c79 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, 
InputParameter}
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.types.{AbstractDataType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
 // scalastyle:off line.size.limit
@@ -79,12 +81,14 @@ import org.apache.spark.unsafe.types.UTF8String
 object MaskExpressionBuilder extends ExpressionBuilder {
   override def functionSignature: Option[FunctionSignature] = {
     val strArg = InputParameter("str")
-    val upperCharArg = InputParameter("upperChar", 
Some(Literal(Mask.MASKED_UPPERCASE)))
-    val lowerCharArg = InputParameter("lowerChar", 
Some(Literal(Mask.MASKED_LOWERCASE)))
-    val digitCharArg = InputParameter("digitChar", 
Some(Literal(Mask.MASKED_DIGIT)))
-    val otherCharArg = InputParameter(
-      "otherChar",
-      Some(Literal(Mask.MASKED_IGNORE, StringType)))
+    val upperCharArg = InputParameter("upperChar",
+      Some(Literal.create(Mask.MASKED_UPPERCASE, 
SQLConf.get.defaultStringType)))
+    val lowerCharArg = InputParameter("lowerChar",
+      Some(Literal.create(Mask.MASKED_LOWERCASE, 
SQLConf.get.defaultStringType)))
+    val digitCharArg = InputParameter("digitChar",
+      Some(Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType)))
+    val otherCharArg = InputParameter("otherChar",
+      Some(Literal.create(Mask.MASKED_IGNORE, SQLConf.get.defaultStringType)))
     val functionSignature: FunctionSignature = FunctionSignature(Seq(
       strArg, upperCharArg, lowerCharArg, digitCharArg, otherCharArg))
     Some(functionSignature)
@@ -109,33 +113,34 @@ case class Mask(
   def this(input: Expression) =
     this(
       input,
-      Literal(Mask.MASKED_UPPERCASE),
-      Literal(Mask.MASKED_LOWERCASE),
-      Literal(Mask.MASKED_DIGIT),
-      Literal(Mask.MASKED_IGNORE, StringType))
+      Literal.create(Mask.MASKED_UPPERCASE, SQLConf.get.defaultStringType),
+      Literal.create(Mask.MASKED_LOWERCASE, SQLConf.get.defaultStringType),
+      Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType),
+      Literal.create(Mask.MASKED_IGNORE, input.dataType))
 
   def this(input: Expression, upperChar: Expression) =
     this(
       input,
       upperChar,
-      Literal(Mask.MASKED_LOWERCASE),
-      Literal(Mask.MASKED_DIGIT),
-      Literal(Mask.MASKED_IGNORE, StringType))
+      Literal.create(Mask.MASKED_LOWERCASE, SQLConf.get.defaultStringType),
+      Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType),
+      Literal.create(Mask.MASKED_IGNORE, input.dataType))
 
   def this(input: Expression, upperChar: Expression, lowerChar: Expression) =
     this(
       input,
       upperChar,
       lowerChar,
-      Literal(Mask.MASKED_DIGIT),
-      Literal(Mask.MASKED_IGNORE, StringType))
+      Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType),
+      Literal.create(Mask.MASKED_IGNORE, input.dataType))
 
   def this(
       input: Expression,
       upperChar: Expression,
       lowerChar: Expression,
       digitChar: Expression) =
-    this(input, upperChar, lowerChar, digitChar, Literal(Mask.MASKED_IGNORE, 
StringType))
+    this(input, upperChar, lowerChar, digitChar,
+      Literal.create(Mask.MASKED_IGNORE, input.dataType))
 
   override def checkInputDataTypes(): TypeCheckResult = {
 
@@ -187,7 +192,8 @@ case class Mask(
    *      NumericType, IntegralType, FractionalType.
    */
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringType, StringType, StringType, StringType, StringType)
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation,
+      StringTypeAnyCollation, StringTypeAnyCollation)
 
   override def nullable: Boolean = true
 
@@ -276,7 +282,7 @@ case class Mask(
    * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
    * the dataType of an unresolved expression (i.e., when `resolved` == false).
    */
-  override def dataType: DataType = StringType
+  override def dataType: DataType = input.dataType
 
   /**
    * Returns a Seq of the children of this node. Children should not change. 
Immutability required
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
new file mode 100644
index 000000000000..5cc0f568db77
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{MapType, StringType}
+
+// scalastyle:off nonascii
+class CollationSQLExpressionsSuite
+  extends QueryTest
+  with SharedSparkSession {
+
+  test("Support StringToMap expression with collation") {
+    // Supported collations
+    case class StringToMapTestCase[R](t: String, p: String, k: String, c: 
String, result: R)
+    val testCases = Seq(
+      StringToMapTestCase("a:1,b:2,c:3", ",", ":", "UTF8_BINARY",
+        Map("a" -> "1", "b" -> "2", "c" -> "3")),
+      StringToMapTestCase("A-1;B-2;C-3", ";", "-", "UTF8_BINARY_LCASE",
+        Map("A" -> "1", "B" -> "2", "C" -> "3")),
+      StringToMapTestCase("1:a,2:b,3:c", ",", ":", "UNICODE",
+        Map("1" -> "a", "2" -> "b", "3" -> "c")),
+      StringToMapTestCase("1/A!2/B!3/C", "!", "/", "UNICODE_CI",
+        Map("1" -> "A", "2" -> "B", "3" -> "C"))
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT str_to_map(collate('${t.t}', '${t.c}'), '${t.p}', 
'${t.k}');"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      val dataType = MapType(StringType(t.c), StringType(t.c), true)
+      assert(sql(query).schema.fields.head.dataType.sameType(dataType))
+    })
+  }
+
+  test("Support Mask expression with collation") {
+    // Supported collations
+    case class MaskTestCase[R](i: String, u: String, l: String, d: String, o: 
String, c: String,
+      result: R)
+    val testCases = Seq(
+      MaskTestCase("ab-CD-12-@$", null, null, null, null, "UTF8_BINARY", 
"ab-CD-12-@$"),
+      MaskTestCase("ab-CD-12-@$", "X", null, null, null, "UTF8_BINARY_LCASE", 
"ab-XX-12-@$"),
+      MaskTestCase("ab-CD-12-@$", "X", "x", null, null, "UNICODE", 
"xx-XX-12-@$"),
+      MaskTestCase("ab-CD-12-@$", "X", "x", "0", "#", "UNICODE_CI", 
"xx#XX#00###")
+    )
+    testCases.foreach(t => {
+      def col(s: String): String = if (s == null) "null" else s"collate('$s', 
'${t.c}')"
+      val query = s"SELECT mask(${col(t.i)}, ${col(t.u)}, ${col(t.l)}, 
${col(t.d)}, ${col(t.o)})"
+      // Result & data type
+      var result = sql(query)
+      checkAnswer(result, Row(t.result))
+      assert(result.schema.fields.head.dataType.sameType(StringType(t.c)))
+    })
+    // Implicit casting
+    val testCasting = Seq(
+      MaskTestCase("ab-CD-12-@$", "X", "x", "0", "#", "UNICODE_CI", 
"xx#XX#00###")
+    )
+    testCasting.foreach(t => {
+      def col(s: String): String = if (s == null) "null" else s"collate('$s', 
'${t.c}')"
+      def str(s: String): String = if (s == null) "null" else s"'$s'"
+      val query1 = s"SELECT mask(${col(t.i)}, ${str(t.u)}, ${str(t.l)}, 
${str(t.d)}, ${str(t.o)})"
+      val query2 = s"SELECT mask(${str(t.i)}, ${col(t.u)}, ${str(t.l)}, 
${str(t.d)}, ${str(t.o)})"
+      val query3 = s"SELECT mask(${str(t.i)}, ${str(t.u)}, ${col(t.l)}, 
${str(t.d)}, ${str(t.o)})"
+      val query4 = s"SELECT mask(${str(t.i)}, ${str(t.u)}, ${str(t.l)}, 
${col(t.d)}, ${str(t.o)})"
+      val query5 = s"SELECT mask(${str(t.i)}, ${str(t.u)}, ${str(t.l)}, 
${str(t.d)}, ${col(t.o)})"
+      for (q <- Seq(query1, query2, query3, query4, query5)) {
+        val result = sql(q)
+        checkAnswer(result, Row(t.result))
+        assert(result.schema.fields.head.dataType.sameType(StringType(t.c)))
+      }
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT 
mask(collate('ab-CD-12-@$','UNICODE'),collate('X','UNICODE_CI'),'x','0','#')")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  // TODO: Add more tests for other SQL expressions
+
+}
+// scalastyle:on nonascii


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to