This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d0dbc6c5e5c4 [SPARK-50470][SQL] Block usage of collations for map keys
d0dbc6c5e5c4 is described below
commit d0dbc6c5e5c44f64c3f13e676e0fb468a3ae7f57
Author: Aleksei Shishkin <[email protected]>
AuthorDate: Tue Dec 3 23:33:08 2024 +0100
[SPARK-50470][SQL] Block usage of collations for map keys
### What changes were proposed in this pull request?
According to the issue description, using collations on map keys may lead
to unexpected effects (key duplication). This PR blocks such usage, but it is
still available via a Spark conf flag.
The old behavior can be re-enabled by setting
`spark.sql.collation.allowInMapKeys` to `true` (it defaults to `false`).
### Why are the changes needed?
Because such usage may create ambiguous situations. For example, if we have
`map('A' -> 1, 'a' -> 2)`, then changing the collation from UTF8_BINARY to
UTF8_LCASE breaks key uniqueness.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49030 from
Alexvsalexvsalex/SPARK-50470_block_collations_on_maps_key.
Authored-by: Aleksei Shishkin <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 5 ++
.../sql/catalyst/analysis/CheckAnalysis.scala | 8 ++
.../spark/sql/errors/QueryCompilationErrors.scala | 6 ++
.../org/apache/spark/sql/internal/SQLConf.scala | 9 +++
.../org/apache/spark/sql/util/SchemaUtils.scala | 11 +++
.../spark/sql/execution/command/tables.scala | 3 +
.../InsertIntoHadoopFsRelationCommand.scala | 3 +
.../spark/sql/execution/datasources/rules.scala | 3 +
.../spark/sql/CollationExpressionWalkerSuite.scala | 75 ++++++++++---------
.../spark/sql/CollationSQLExpressionsSuite.scala | 8 +-
.../org/apache/spark/sql/CollationSuite.scala | 87 +++++++++++++++++-----
.../CollatedFilterPushDownToParquetSuite.scala | 8 +-
.../collation/CollationTypePrecedenceSuite.scala | 46 ++++++------
13 files changed, 187 insertions(+), 85 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 157989c09d09..63c4d18c99de 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5193,6 +5193,11 @@
"The SQL pipe operator syntax using |> does not support <clauses>."
]
},
+ "COLLATIONS_IN_MAP_KEYS" : {
+ "message" : [
+ "Collated strings for keys of maps"
+ ]
+ },
"COMBINATION_QUERY_RESULT_CLAUSES" : {
"message" : [
"Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY."
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c13da35334ba..d4b97ff037f3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1560,15 +1560,23 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
alter.conf.resolver)
}
+ def checkNoCollationsInMapKeys(colsToAdd: Seq[QualifiedColType]): Unit = {
+ if (!alter.conf.allowCollationsInMapKeys) {
+ colsToAdd.foreach(col =>
SchemaUtils.checkNoCollationsInMapKeys(col.dataType))
+ }
+ }
+
alter match {
case AddColumns(table: ResolvedTable, colsToAdd) =>
colsToAdd.foreach { colToAdd =>
checkColumnNotExists("add", colToAdd.name, table.schema)
}
checkColumnNameDuplication(colsToAdd)
+ checkNoCollationsInMapKeys(colsToAdd)
case ReplaceColumns(_: ResolvedTable, colsToAdd) =>
checkColumnNameDuplication(colsToAdd)
+ checkNoCollationsInMapKeys(colsToAdd)
case RenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName)
=>
checkColumnNotExists("rename", col.path :+ newName, table.schema)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 7d7a490c9790..b673d5a04315 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -351,6 +351,12 @@ private[sql] object QueryCompilationErrors extends
QueryErrorsBase with Compilat
)
}
+ def collatedStringsInMapKeysNotSupportedError(): Throwable = {
+ new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.COLLATIONS_IN_MAP_KEYS",
+ messageParameters = Map.empty)
+ }
+
def trimCollationNotEnabledError(): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_FEATURE.TRIM_COLLATION",
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2a05508a1754..c0d35fa0ce2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -767,6 +767,13 @@ object SQLConf {
.checkValue(_ > 0, "The initial number of partitions must be positive.")
.createOptional
+ lazy val ALLOW_COLLATIONS_IN_MAP_KEYS =
+ buildConf("spark.sql.collation.allowInMapKeys")
+ .doc("Allow for non-UTF8_BINARY collated strings inside of map's keys")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
lazy val TRIM_COLLATION_ENABLED =
buildConf("spark.sql.collation.trim.enabled")
.internal()
@@ -5585,6 +5592,8 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
}
}
+ def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)
+
def trimCollationEnabled: Boolean = getConf(TRIM_COLLATION_ENABLED)
override def defaultStringType: StringType = {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index 1e0bac331dc7..0aadd3cd3a44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -304,6 +304,17 @@ private[spark] object SchemaUtils {
}
}
+ def checkNoCollationsInMapKeys(schema: DataType): Unit = schema match {
+ case m: MapType =>
+ if (hasNonUTF8BinaryCollation(m.keyType)) {
+ throw
QueryCompilationErrors.collatedStringsInMapKeysNotSupportedError()
+ }
+ checkNoCollationsInMapKeys(m.valueType)
+ case s: StructType => s.fields.foreach(field =>
checkNoCollationsInMapKeys(field.dataType))
+ case a: ArrayType => checkNoCollationsInMapKeys(a.elementType)
+ case _ =>
+ }
+
/**
* Replaces any collated string type with non collated StringType
* recursively in the given data type.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 9ecd3fd19aa6..84b73a74f3ab 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -241,6 +241,9 @@ case class AlterTableAddColumnsCommand(
SchemaUtils.checkColumnNameDuplication(
(colsWithProcessedDefaults ++ catalogTable.schema).map(_.name),
conf.caseSensitiveAnalysis)
+ if (!conf.allowCollationsInMapKeys) {
+ colsToAdd.foreach(col =>
SchemaUtils.checkNoCollationsInMapKeys(col.dataType))
+ }
DDLUtils.checkTableColumns(catalogTable,
StructType(colsWithProcessedDefaults))
val existingSchema = CharVarcharUtils.getRawSchema(catalogTable.dataSchema)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index aed129c7dccc..8a795f074881 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -84,6 +84,9 @@ case class InsertIntoHadoopFsRelationCommand(
outputColumnNames,
sparkSession.sessionState.conf.caseSensitiveAnalysis)
}
+ if (!conf.allowCollationsInMapKeys) {
+ SchemaUtils.checkNoCollationsInMapKeys(query.schema)
+ }
val hadoopConf =
sparkSession.sessionState.newHadoopConfWithOptions(options)
val fs = outputPath.getFileSystem(hadoopConf)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 713161cc49ce..23596861a647 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -343,6 +343,9 @@ case class PreprocessTableCreation(catalog: SessionCatalog)
extends Rule[Logical
SchemaUtils.checkSchemaColumnNameDuplication(
schema,
conf.caseSensitiveAnalysis)
+ if (!conf.allowCollationsInMapKeys) {
+ SchemaUtils.checkNoCollationsInMapKeys(schema)
+ }
val normalizedPartCols = normalizePartitionColumns(schema, table)
val normalizedBucketSpec = normalizeBucketSpec(schema, table)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
index bc62fa5fdd33..e3622c310185 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
-import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.internal.types._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -636,48 +636,49 @@ class CollationExpressionWalkerSuite extends
SparkFunSuite with SharedSparkSessi
val expr = headConstructor.newInstance(args:
_*).asInstanceOf[ExpectsInputTypes]
withTable("tbl", "tbl_lcase") {
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ val utf8_df = generateTableData(expr.inputTypes.take(2), Utf8Binary)
+ val utf8_lcase_df = generateTableData(expr.inputTypes.take(2),
Utf8Lcase)
+
+ val utf8BinaryResult = try {
+ val df = utf8_df.selectExpr(transformExpressionToString(expr,
Utf8Binary))
+ df.getRows(1, 0)
+ scala.util.Right(df)
+ } catch {
+ case e: Throwable => scala.util.Left(e)
+ }
+ val utf8LcaseResult = try {
+ val df =
utf8_lcase_df.selectExpr(transformExpressionToString(expr, Utf8Lcase))
+ df.getRows(1, 0)
+ scala.util.Right(df)
+ } catch {
+ case e: Throwable => scala.util.Left(e)
+ }
- val utf8_df = generateTableData(expr.inputTypes.take(2), Utf8Binary)
- val utf8_lcase_df = generateTableData(expr.inputTypes.take(2),
Utf8Lcase)
-
- val utf8BinaryResult = try {
- val df = utf8_df.selectExpr(transformExpressionToString(expr,
Utf8Binary))
- df.getRows(1, 0)
- scala.util.Right(df)
- } catch {
- case e: Throwable => scala.util.Left(e)
- }
- val utf8LcaseResult = try {
- val df = utf8_lcase_df.selectExpr(transformExpressionToString(expr,
Utf8Lcase))
- df.getRows(1, 0)
- scala.util.Right(df)
- } catch {
- case e: Throwable => scala.util.Left(e)
- }
-
- assert(utf8BinaryResult.isLeft === utf8LcaseResult.isLeft)
+ assert(utf8BinaryResult.isLeft === utf8LcaseResult.isLeft)
- if (utf8BinaryResult.isRight) {
- val utf8BinaryResultChecked = utf8BinaryResult.getOrElse(null)
- val utf8LcaseResultChecked = utf8LcaseResult.getOrElse(null)
+ if (utf8BinaryResult.isRight) {
+ val utf8BinaryResultChecked = utf8BinaryResult.getOrElse(null)
+ val utf8LcaseResultChecked = utf8LcaseResult.getOrElse(null)
- val dt = utf8BinaryResultChecked.schema.fields.head.dataType
+ val dt = utf8BinaryResultChecked.schema.fields.head.dataType
- dt match {
- case st if utf8BinaryResultChecked != null &&
utf8LcaseResultChecked != null &&
- hasStringType(st) =>
- // scalastyle:off caselocale
- assert(utf8BinaryResultChecked.getRows(1,
0).map(_.map(_.toLowerCase))(1) ===
- utf8LcaseResultChecked.getRows(1,
0).map(_.map(_.toLowerCase))(1))
+ dt match {
+ case st if utf8BinaryResultChecked != null &&
utf8LcaseResultChecked != null &&
+ hasStringType(st) =>
+ // scalastyle:off caselocale
+ assert(utf8BinaryResultChecked.getRows(1,
0).map(_.map(_.toLowerCase))(1) ===
+ utf8LcaseResultChecked.getRows(1,
0).map(_.map(_.toLowerCase))(1))
// scalastyle:on caselocale
- case _ =>
- assert(utf8BinaryResultChecked.getRows(1, 0)(1) ===
- utf8LcaseResultChecked.getRows(1, 0)(1))
+ case _ =>
+ assert(utf8BinaryResultChecked.getRows(1, 0)(1) ===
+ utf8LcaseResultChecked.getRows(1, 0)(1))
+ }
+ }
+ else {
+ assert(utf8BinaryResult.getOrElse(new Exception()).getClass
+ == utf8LcaseResult.getOrElse(new Exception()).getClass)
}
- }
- else {
- assert(utf8BinaryResult.getOrElse(new Exception()).getClass
- == utf8LcaseResult.getOrElse(new Exception()).getClass)
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index b2bb3eaffd41..4e91fd721a07 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -2006,9 +2006,11 @@ class CollationSQLExpressionsSuite
}
val tableName = s"t_${t1.collationId}_mode_nested_map_struct1"
withTable(tableName) {
- sql(s"CREATE TABLE ${tableName}(" +
- s"i STRUCT<m1: MAP<STRING COLLATE ${t1.collationId}, INT>>) USING
parquet")
- sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ sql(s"CREATE TABLE ${tableName}(" +
+ s"i STRUCT<m1: MAP<STRING COLLATE ${t1.collationId}, INT>>) USING
parquet")
+ sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
+ }
val query = "SELECT lower(cast(mode(i).m1 as string))" +
s" FROM ${tableName}"
val queryResult = sql(query)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index dc7d14b21bec..11f2c4b997a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -221,23 +221,54 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
parameters = Map("collationName" -> "UTF8_BS", "proposals" ->
"UTF8_LCASE"))
}
+ test("fail on table creation with collated strings as map key") {
+ withTable("table_1", "table_2") {
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("CREATE TABLE table_1 (col MAP<STRING COLLATE UNICODE, STRING>)
USING parquet")
+ },
+ condition = "UNSUPPORTED_FEATURE.COLLATIONS_IN_MAP_KEYS"
+ )
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ sql("CREATE TABLE table_2 (col MAP<STRING COLLATE UNICODE, STRING>)
USING parquet")
+ }
+ }
+ }
+
+ test("fail on adding column with collated map key") {
+ withTable("table_1") {
+ sql("CREATE TABLE table_1 (id INTEGER) USING parquet")
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("ALTER TABLE table_1 ADD COLUMN col1 MAP<ARRAY<STRING COLLATE
UNICODE>, INTEGER>")
+ },
+ condition = "UNSUPPORTED_FEATURE.COLLATIONS_IN_MAP_KEYS"
+ )
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ sql("ALTER TABLE table_1 ADD COLUMN col1 MAP<ARRAY<STRING COLLATE
UNICODE>, INTEGER>")
+ }
+ }
+ }
+
test("disable bucketing on collated string column") {
def createTable(bucketColumns: String*): Unit = {
val tableName = "test_partition_tbl"
withTable(tableName) {
- sql(
- s"""
- |CREATE TABLE $tableName (
- | id INT,
- | c1 STRING COLLATE UNICODE,
- | c2 STRING,
- | struct_col STRUCT<col1: STRING COLLATE UNICODE, col2: STRING>,
- | array_col ARRAY<STRING COLLATE UNICODE>,
- | map_col MAP<STRING COLLATE UNICODE, STRING>
- |) USING parquet
- |CLUSTERED BY (${bucketColumns.mkString(",")})
- |INTO 4 BUCKETS""".stripMargin
- )
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ sql(
+ s"""
+ |CREATE TABLE $tableName (
+ | id INT,
+ | c1 STRING COLLATE UNICODE,
+ | c2 STRING,
+ | struct_col STRUCT<col1: STRING COLLATE UNICODE, col2:
STRING>,
+ | array_col ARRAY<STRING COLLATE UNICODE>,
+ | map_col MAP<STRING COLLATE UNICODE, STRING>
+ |) USING parquet
+ |CLUSTERED BY (${bucketColumns.mkString(",")})
+ |INTO 4 BUCKETS""".stripMargin
+ )
+ }
}
}
// should work fine on default collated columns
@@ -1124,7 +1155,9 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
}
// map doesn't support aggregation
withTable(table) {
- sql(s"create table $table (m map<string collate utf8_lcase, string>)
using parquet")
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ sql(s"create table $table (m map<string collate utf8_lcase, string>)
using parquet")
+ }
val query = s"select distinct m from $table"
checkError(
exception = intercept[ExtendedAnalysisException](sql(query)),
@@ -1166,8 +1199,10 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
}
// map doesn't support joins
withTable(tableLeft, tableRight) {
- Seq(tableLeft, tableRight).map(tab =>
- sql(s"create table $tab (m map<string collate utf8_lcase, string>)
using parquet"))
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ Seq(tableLeft, tableRight).map(tab =>
+ sql(s"create table $tab (m map<string collate utf8_lcase, string>)
using parquet"))
+ }
val query =
s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m
= $tableRight.m"
val ctx = s"$tableLeft.m = $tableRight.m"
@@ -1418,7 +1453,10 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val tableName = "t"
withTable(tableName) {
- withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
+ withSQLConf(
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen,
+ SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true"
+ ) {
sql(s"create table $tableName" +
s" (m map<string$collationSetup, string$collationSetup>)")
sql(s"insert into $tableName values (map('aaa', 'AAA'))")
@@ -1443,7 +1481,10 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val tableName = "t"
withTable(tableName) {
- withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
+ withSQLConf(
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen,
+ SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true"
+ ) {
sql(s"create table $tableName" +
s" (m map<struct<fld1: string$collationSetup, fld2:
string$collationSetup>, " +
s"struct<fld1: string$collationSetup, fld2:
string$collationSetup>>)")
@@ -1470,7 +1511,10 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
val tableName = "t"
withTable(tableName) {
- withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
+ withSQLConf(
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen,
+ SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true"
+ ) {
sql(s"create table $tableName " +
s"(m map<array<string$collationSetup>,
array<string$collationSetup>>)")
sql(s"insert into $tableName values (map(array('aaa', 'bbb'),
array('ccc', 'ddd')))")
@@ -1493,7 +1537,10 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
test(s"Check that order by on map with$collationSetup strings fails
($codeGen)") {
val tableName = "t"
withTable(tableName) {
- withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
+ withSQLConf(
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen,
+ SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true"
+ ) {
sql(s"create table $tableName" +
s" (m map<string$collationSetup, string$collationSetup>, " +
s" c integer)")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala
index 9b54fe4bb052..8bb4a1c803e8 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala
@@ -43,7 +43,7 @@ abstract class CollatedFilterPushDownToParquetSuite extends
QueryTest
val collatedStructNestedCol = "f1"
val collatedStructFieldAccess =
s"$collatedStructCol.$collatedStructNestedCol"
val collatedArrayCol = "c3"
- val collatedMapCol = "c4"
+ val nonCollatedMapCol = "c4"
val lcaseCollation = "'UTF8_LCASE'"
@@ -69,7 +69,7 @@ abstract class CollatedFilterPushDownToParquetSuite extends
QueryTest
| named_struct('$collatedStructNestedCol',
| COLLATE(c, $lcaseCollation)) as $collatedStructCol,
| array(COLLATE(c, $lcaseCollation)) as $collatedArrayCol,
- | map(COLLATE(c, $lcaseCollation), 1) as $collatedMapCol
+ | map(c, 1) as $nonCollatedMapCol
|FROM VALUES ('aaa'), ('AAA'), ('bbb')
|as data(c)
|""".stripMargin)
@@ -215,9 +215,9 @@ abstract class CollatedFilterPushDownToParquetSuite extends
QueryTest
test("map - parquet does not support null check on complex types") {
testPushDown(
- filterString = s"map_keys($collatedMapCol) != array(collate('aaa',
$lcaseCollation))",
+ filterString = s"map_keys($nonCollatedMapCol) != array('aaa')",
expectedPushedFilters = Seq.empty,
- expectedRowCount = 1)
+ expectedRowCount = 2)
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
index 93e36afae242..bb6fce1fb1b6 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.collation
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
class CollationTypePrecedenceSuite extends QueryTest with SharedSparkSession {
@@ -397,27 +398,30 @@ class CollationTypePrecedenceSuite extends QueryTest with
SharedSparkSession {
sql(s"SELECT c1 FROM $tableName WHERE $condition = 'B'")
withTable(tableName) {
- sql(s"""
- |CREATE TABLE $tableName (
- | c1 MAP<STRING COLLATE UNICODE_CI, STRING COLLATE UNICODE_CI>,
- | c2 STRING
- |) USING $dataSource
- |""".stripMargin)
-
- sql(s"INSERT INTO $tableName VALUES (map('a', 'b'), 'a')")
-
- Seq("c1['A']",
- "c1['A' COLLATE UNICODE_CI]",
- "c1[c2 COLLATE UNICODE_CI]").foreach { condition =>
- checkAnswer(selectQuery(condition), Seq(Row(Map("a" -> "b"))))
- }
-
- Seq(
- // different explicit collation
- "c1['A' COLLATE UNICODE]",
- // different implicit collation
- "c1[c2]").foreach { condition =>
- assertThrowsError(selectQuery(condition),
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ withSQLConf(SQLConf.ALLOW_COLLATIONS_IN_MAP_KEYS.key -> "true") {
+ sql(
+ s"""
+ |CREATE TABLE $tableName (
+ | c1 MAP<STRING COLLATE UNICODE_CI, STRING COLLATE UNICODE_CI>,
+ | c2 STRING
+ |) USING $dataSource
+ |""".stripMargin)
+
+ sql(s"INSERT INTO $tableName VALUES (map('a', 'b'), 'a')")
+
+ Seq("c1['A']",
+ "c1['A' COLLATE UNICODE_CI]",
+ "c1[c2 COLLATE UNICODE_CI]").foreach { condition =>
+ checkAnswer(selectQuery(condition), Seq(Row(Map("a" -> "b"))))
+ }
+
+ Seq(
+ // different explicit collation
+ "c1['A' COLLATE UNICODE]",
+ // different implicit collation
+ "c1[c2]").foreach { condition =>
+ assertThrowsError(selectQuery(condition),
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]