This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ccef4df7abb0 [SPARK-49354][SQL] `split_part` should check whether the
`collation` values of all parameter types are the same
ccef4df7abb0 is described below
commit ccef4df7abb0c794ac134e396844481c9537e65c
Author: panbingkun <[email protected]>
AuthorDate: Mon Aug 26 14:24:52 2024 +0200
[SPARK-49354][SQL] `split_part` should check whether the `collation` values
of all parameter types are the same
### What changes were proposed in this pull request?
The same principle as
https://github.com/apache/spark/pull/47825#pullrequestreview-2250729020, the
parameter `delimiter` in expression `split_part` is treated as
a (`collation-dependent`) delimiter, rather than a (`collation-unaware`) regular
expression.
### Why are the changes needed?
Strengthen the parameter data type check of expression `split_part` to
avoid potential issues.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Add some `test case` to `collations.sql`.
- Pass GA.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47845 from panbingkun/SPARK-49354.
Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../sql/catalyst/analysis/CollationTypeCasts.scala | 10 ++++
.../sql-tests/analyzer-results/collations.sql.out | 52 +++++++++++++++++++
.../test/resources/sql-tests/inputs/collations.sql | 11 ++++
.../resources/sql-tests/results/collations.sql.out | 59 ++++++++++++++++++++++
.../sql/CollationStringExpressionsSuite.scala | 46 +++++++++++++++++
5 files changed, 178 insertions(+)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 5f08e6769e5c..497cd74deefd 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -90,6 +90,16 @@ object CollationTypeCasts extends TypeCoercionRule {
val newValues = collateToSingleType(mapCreate.values)
mapCreate.withNewChildren(newKeys.zip(newValues).flatMap(pair =>
Seq(pair._1, pair._2)))
+ case splitPart: SplitPart =>
+ val Seq(str, delimiter, partNum) = splitPart.children
+ val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
+ splitPart.withNewChildren(Seq(newStr, newDelimiter, partNum))
+
+ case stringSplitSQL: StringSplitSQL =>
+ val Seq(str, delimiter) = stringSplitSQL.children
+ val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
+ stringSplitSQL.withNewChildren(Seq(newStr, newDelimiter))
+
case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _:
Greatest | _: Least |
_: Coalesce | _: ArrayContains | _: ArrayExcept | _: ConcatWs | _: Mask
| _: StringReplace |
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
index 0ebde354dd7b..66f3931bdf4b 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
@@ -441,3 +441,55 @@ drop table t4
-- !query analysis
DropTable false, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t4
+
+
+-- !query
+create table t5(str string collate utf8_binary, delimiter string collate
utf8_lcase, partNum int) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t5`, false
+
+
+-- !query
+insert into t5 values('11AB12AB13', 'AB', 2)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in
comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in
comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`,
org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included
in comparison]/{warehouse_dir}/t5), [str, delimiter, partNum]
++- Project [cast(col1#x as string) AS str#x, cast(col2#x as string collate
UTF8_LCASE) AS delimiter#x, cast(col3#x as int) AS partNum#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+select split_part(str, delimiter, partNum) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_lcase,
partNum) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_binary,
partNum) from t5
+-- !query analysis
+Project [split_part(collate(str#x, utf8_binary), collate(delimiter#x,
utf8_binary), partNum#x) AS split_part(collate(str, utf8_binary),
collate(delimiter, utf8_binary), partNum)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[str#x,delimiter#x,partNum#x] parquet
+
+
+-- !query
+drop table t5
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
index 4fa1fb9e1cdb..92d5f11659fb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -101,3 +101,14 @@ select str_to_map(text collate utf8_binary, pairDelim
collate utf8_lcase, keyVal
select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary,
keyValueDelim collate utf8_binary) from t4;
drop table t4;
+
+-- create table for split_part
+create table t5(str string collate utf8_binary, delimiter string collate
utf8_lcase, partNum int) using parquet;
+
+insert into t5 values('11AB12AB13', 'AB', 2);
+
+select split_part(str, delimiter, partNum) from t5;
+select split_part(str collate utf8_binary, delimiter collate utf8_lcase,
partNum) from t5;
+select split_part(str collate utf8_binary, delimiter collate utf8_binary,
partNum) from t5;
+
+drop table t5;
diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
index 0bf42abfb885..37e3161e63fa 100644
--- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
@@ -486,3 +486,62 @@ drop table t4
struct<>
-- !query output
+
+
+-- !query
+create table t5(str string collate utf8_binary, delimiter string collate
utf8_lcase, partNum int) using parquet
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+insert into t5 values('11AB12AB13', 'AB', 2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+select split_part(str, delimiter, partNum) from t5
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_lcase,
partNum) from t5
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select split_part(str collate utf8_binary, delimiter collate utf8_binary,
partNum) from t5
+-- !query schema
+struct<split_part(collate(str, utf8_binary), collate(delimiter, utf8_binary),
partNum):string>
+-- !query output
+12
+
+
+-- !query
+drop table t5
+-- !query schema
+struct<>
+-- !query output
+
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 21bcafda720f..412c003a0dba 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -112,6 +113,51 @@ class CollationStringExpressionsSuite
})
}
+ test("Support `StringSplitSQL` string expression with collation") {
+ case class StringSplitSQLTestCase[R](
+ str: String,
+ delimiter: String,
+ collation: String,
+ result: R)
+ val testCases = Seq(
+ StringSplitSQLTestCase("1a2", "a", "UTF8_BINARY", Array("1", "2")),
+ StringSplitSQLTestCase("1a2", "a", "UNICODE", Array("1", "2")),
+ StringSplitSQLTestCase("1a2", "A", "UTF8_LCASE", Array("1", "2")),
+ StringSplitSQLTestCase("1a2", "A", "UNICODE_CI", Array("1", "2"))
+ )
+ testCases.foreach(t => {
+ // Unit test.
+ val str = Literal.create(t.str, StringType(t.collation))
+ val delimiter = Literal.create(t.delimiter, StringType(t.collation))
+ checkEvaluation(StringSplitSQL(str, delimiter), t.result)
+ })
+
+ // Because `StringSplitSQL` is an internal expression,
+ // E2E SQL test cannot be performed in `collations.sql`.
+ checkError(
+ exception = intercept[AnalysisException] {
+ val expr = StringSplitSQL(
+ Cast(Literal.create("1a2"), StringType("UTF8_BINARY")),
+ Cast(Literal.create("a"), StringType("UTF8_LCASE")))
+ CollationTypeCasts.transform(expr)
+ },
+ errorClass = "COLLATION_MISMATCH.IMPLICIT",
+ sqlState = "42P21",
+ parameters = Map.empty
+ )
+ checkError(
+ exception = intercept[AnalysisException] {
+ val expr = StringSplitSQL(
+ Collate(Literal.create("1a2"), "UTF8_BINARY"),
+ Collate(Literal.create("a"), "UTF8_LCASE"))
+ CollationTypeCasts.transform(expr)
+ },
+ errorClass = "COLLATION_MISMATCH.EXPLICIT",
+ sqlState = "42P21",
+ parameters = Map("explicitTypes" -> "`string`, `string collate
UTF8_LCASE`")
+ )
+ }
+
test("Support `Contains` string expression with collation") {
case class ContainsTestCase[R](left: String, right: String, collation:
String, result: R)
val testCases = Seq(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]