This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ccef4df7abb0 [SPARK-49354][SQL] `split_part` should check whether the 
`collation` values of all parameter types are the same
ccef4df7abb0 is described below

commit ccef4df7abb0c794ac134e396844481c9537e65c
Author: panbingkun <[email protected]>
AuthorDate: Mon Aug 26 14:24:52 2024 +0200

    [SPARK-49354][SQL] `split_part` should check whether the `collation` values 
of all parameter types are the same
    
    ### What changes were proposed in this pull request?
    Following the same principle as https://github.com/apache/spark/pull/47825#pullrequestreview-2250729020,
    the parameter `delimiter` of expression `split_part` is treated as a (`collation-dependent`)
    delimiter rather than a (`collation-unaware`) regular expression, so its collation must match
    the collation of `str`.
    
    ### Why are the changes needed?
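    Strengthen the parameter data type check of expression `split_part` (and of the internal
    `StringSplitSQL` expression): the collations of `str` and `delimiter` are now coerced to a
    single collation, and conflicting collations are reported as a `COLLATION_MISMATCH` error.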
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    - Add test cases to `collations.sql` and `CollationStringExpressionsSuite`.
    - Pass GitHub Actions (GA).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47845 from panbingkun/SPARK-49354.
    
    Lead-authored-by: panbingkun <[email protected]>
    Co-authored-by: panbingkun <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../sql/catalyst/analysis/CollationTypeCasts.scala | 10 ++++
 .../sql-tests/analyzer-results/collations.sql.out  | 52 +++++++++++++++++++
 .../test/resources/sql-tests/inputs/collations.sql | 11 ++++
 .../resources/sql-tests/results/collations.sql.out | 59 ++++++++++++++++++++++
 .../sql/CollationStringExpressionsSuite.scala      | 46 +++++++++++++++++
 5 files changed, 178 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 5f08e6769e5c..497cd74deefd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -90,6 +90,16 @@ object CollationTypeCasts extends TypeCoercionRule {
       val newValues = collateToSingleType(mapCreate.values)
       mapCreate.withNewChildren(newKeys.zip(newValues).flatMap(pair => 
Seq(pair._1, pair._2)))
 
+    case splitPart: SplitPart =>
+      val Seq(str, delimiter, partNum) = splitPart.children
+      val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
+      splitPart.withNewChildren(Seq(newStr, newDelimiter, partNum))
+
+    case stringSplitSQL: StringSplitSQL =>
+      val Seq(str, delimiter) = stringSplitSQL.children
+      val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
+      stringSplitSQL.withNewChildren(Seq(newStr, newDelimiter))
+
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: 
Greatest | _: Least |
       _: Coalesce | _: ArrayContains | _: ArrayExcept | _: ConcatWs | _: Mask 
| _: StringReplace |
diff --git 
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
index 0ebde354dd7b..66f3931bdf4b 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
@@ -441,3 +441,55 @@ drop table t4
 -- !query analysis
 DropTable false, false
 +- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t4
+
+
+-- !query
+create table t5(str string collate utf8_binary, delimiter string collate 
utf8_lcase, partNum int) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t5`, false
+
+
+-- !query
+insert into t5 values('11AB12AB13', 'AB', 2)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in 
comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in 
comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, 
org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included 
in comparison]/{warehouse_dir}/t5), [str, delimiter, partNum]
++- Project [cast(col1#x as string) AS str#x, cast(col2#x as string collate 
UTF8_LCASE) AS delimiter#x, cast(col3#x as int) AS partNum#x]
+   +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+select split_part(str, delimiter, partNum) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+  "sqlState" : "42P21"
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_lcase, 
partNum) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+  "sqlState" : "42P21",
+  "messageParameters" : {
+    "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+  }
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_binary, 
partNum) from t5
+-- !query analysis
+Project [split_part(collate(str#x, utf8_binary), collate(delimiter#x, 
utf8_binary), partNum#x) AS split_part(collate(str, utf8_binary), 
collate(delimiter, utf8_binary), partNum)#x]
++- SubqueryAlias spark_catalog.default.t5
+   +- Relation spark_catalog.default.t5[str#x,delimiter#x,partNum#x] parquet
+
+
+-- !query
+drop table t5
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
index 4fa1fb9e1cdb..92d5f11659fb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -101,3 +101,14 @@ select str_to_map(text collate utf8_binary, pairDelim 
collate utf8_lcase, keyVal
 select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary, 
keyValueDelim collate utf8_binary) from t4;
 
 drop table t4;
+
+-- create table for split_part
+create table t5(str string collate utf8_binary, delimiter string collate 
utf8_lcase, partNum int) using parquet;
+
+insert into t5 values('11AB12AB13', 'AB', 2);
+
+select split_part(str, delimiter, partNum) from t5;
+select split_part(str collate utf8_binary, delimiter collate utf8_lcase, 
partNum) from t5;
+select split_part(str collate utf8_binary, delimiter collate utf8_binary, 
partNum) from t5;
+
+drop table t5;
diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
index 0bf42abfb885..37e3161e63fa 100644
--- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
@@ -486,3 +486,62 @@ drop table t4
 struct<>
 -- !query output
 
+
+
+-- !query
+create table t5(str string collate utf8_binary, delimiter string collate 
utf8_lcase, partNum int) using parquet
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+insert into t5 values('11AB12AB13', 'AB', 2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+select split_part(str, delimiter, partNum) from t5
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+  "sqlState" : "42P21"
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_lcase, 
partNum) from t5
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+  "sqlState" : "42P21",
+  "messageParameters" : {
+    "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+  }
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_binary, 
partNum) from t5
+-- !query schema
+struct<split_part(collate(str, utf8_binary), collate(delimiter, utf8_binary), 
partNum):string>
+-- !query output
+12
+
+
+-- !query
+drop table t5
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 21bcafda720f..412c003a0dba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -112,6 +113,51 @@ class CollationStringExpressionsSuite
     })
   }
 
+  test("Support `StringSplitSQL` string expression with collation") {
+    case class StringSplitSQLTestCase[R](
+        str: String,
+        delimiter: String,
+        collation: String,
+        result: R)
+    val testCases = Seq(
+      StringSplitSQLTestCase("1a2", "a", "UTF8_BINARY", Array("1", "2")),
+      StringSplitSQLTestCase("1a2", "a", "UNICODE", Array("1", "2")),
+      StringSplitSQLTestCase("1a2", "A", "UTF8_LCASE", Array("1", "2")),
+      StringSplitSQLTestCase("1a2", "A", "UNICODE_CI", Array("1", "2"))
+    )
+    testCases.foreach(t => {
+      // Unit test.
+      val str = Literal.create(t.str, StringType(t.collation))
+      val delimiter = Literal.create(t.delimiter, StringType(t.collation))
+      checkEvaluation(StringSplitSQL(str, delimiter), t.result)
+    })
+
+    // Because `StringSplitSQL` is an internal expression,
+    // E2E SQL test cannot be performed in `collations.sql`.
+    checkError(
+      exception = intercept[AnalysisException] {
+        val expr = StringSplitSQL(
+          Cast(Literal.create("1a2"), StringType("UTF8_BINARY")),
+          Cast(Literal.create("a"), StringType("UTF8_LCASE")))
+        CollationTypeCasts.transform(expr)
+      },
+      errorClass = "COLLATION_MISMATCH.IMPLICIT",
+      sqlState = "42P21",
+      parameters = Map.empty
+    )
+    checkError(
+      exception = intercept[AnalysisException] {
+        val expr = StringSplitSQL(
+          Collate(Literal.create("1a2"), "UTF8_BINARY"),
+          Collate(Literal.create("a"), "UTF8_LCASE"))
+        CollationTypeCasts.transform(expr)
+      },
+      errorClass = "COLLATION_MISMATCH.EXPLICIT",
+      sqlState = "42P21",
+      parameters = Map("explicitTypes" -> "`string`, `string collate UTF8_LCASE`")
+    )
+  }
+
   test("Support `Contains` string expression with collation") {
     case class ContainsTestCase[R](left: String, right: String, collation: 
String, result: R)
     val testCases = Seq(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to