This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6be3560f3c89 [SPARK-48364][SQL] Add AbstractMapType type casting and fix RaiseError parameter map to work with collated strings
6be3560f3c89 is described below

commit 6be3560f3c89e212e850a0788d24a7c0755ea35b
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Wed May 22 05:21:23 2024 -0700

    [SPARK-48364][SQL] Add AbstractMapType type casting and fix RaiseError parameter map to work with collated strings
    
    ### What changes were proposed in this pull request?
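    Following up on the introduction of AbstractMapType (https://github.com/apache/spark/pull/46458) and the changes that introduced collation awareness for the RaiseError expression (https://github.com/apache/spark/pull/46461), this PR adds the appropriate type casting rules for AbstractMapType. It also changes the expected input type of RaiseError's parameter map from `MapType(StringType, StringType)` to `AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation)`, so that maps containing collated strings are accepted.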
    
    ### Why are the changes needed?
    Fix the CI failure in the `Support RaiseError misc expression with collation` test when ANSI mode is disabled.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, implicit type casting is now supported for map types that contain collated strings.
    
    ### How was this patch tested?
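    Added `CollationSQLExpressionsANSIOffSuite`, which extends `CollationSQLExpressionsSuite` and reruns its tests with ANSI mode disabled. Updated the affected expectations in `CollationSuite`.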
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46661 from uros-db/fix-abstract-map.
    
    Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CollationTypeCasts.scala | 15 ++++++++++++-
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 13 +++++++++--
 .../spark/sql/catalyst/expressions/misc.scala      |  4 ++--
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 10 +++++++--
 .../org/apache/spark/sql/CollationSuite.scala      | 25 ++--------------------
 5 files changed, 37 insertions(+), 30 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index a50dad7c8cdb..00abdf4ee19d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -25,7 +25,7 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}
 
 object CollationTypeCasts extends TypeCoercionRule {
   override val transform: PartialFunction[Expression, Expression] = {
@@ -85,6 +85,11 @@ object CollationTypeCasts extends TypeCoercionRule {
   private def extractStringType(dt: DataType): StringType = dt match {
     case st: StringType => st
     case ArrayType(et, _) => extractStringType(et)
+    case MapType(kt, vt, _) => if (hasStringType(kt)) {
+        extractStringType(kt)
+      } else {
+        extractStringType(vt)
+      }
   }
 
   /**
@@ -102,6 +107,14 @@ object CollationTypeCasts extends TypeCoercionRule {
       case st: StringType if st.collationId != castType.collationId => castType
       case ArrayType(arrType, nullable) =>
         castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
+      case MapType(keyType, valueType, nullable) =>
+        val newKeyType = castStringType(keyType, castType).getOrElse(keyType)
+        val newValueType = castStringType(valueType, 
castType).getOrElse(valueType)
+        if (newKeyType != keyType || newValueType != valueType) {
+          MapType(newKeyType, newValueType, nullable)
+        } else {
+          null
+        }
       case _ => null
     }
     Option(ret)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 936bb22baa46..7866f47c28b1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.{AbstractArrayType, 
AbstractStringType, StringTypeAnyCollation}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, 
AbstractMapType, AbstractStringType, StringTypeAnyCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence
 
@@ -1048,6 +1048,15 @@ object TypeCoercion extends TypeCoercionBase {
           }
         }
 
+      case (MapType(fromKeyType, fromValueType, fn), 
AbstractMapType(toKeyType, toValueType)) =>
+        val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
+        val newValueType = implicitCast(fromValueType, toValueType).orNull
+        if (newKeyType != null && newValueType != null) {
+          MapType(newKeyType, newValueType, fn)
+        } else {
+          null
+        }
+
       case _ => null
     }
     Option(ret)
@@ -1080,10 +1089,10 @@ object TypeCoercion extends TypeCoercionBase {
   /**
    * Whether the data type contains StringType.
    */
-  @tailrec
   def hasStringType(dt: DataType): Boolean = dt match {
     case _: StringType => true
     case ArrayType(et, _) => hasStringType(et)
+    case MapType(kt, vt, _) => hasStringType(kt) || hasStringType(vt)
     // Add StructType if we support string promotion for struct fields in the 
future.
     case _ => false
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e9fa362de14c..d9d7cd2cd0c1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, 
RandomUUIDGenerator}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.{AbstractMapType, 
StringTypeAnyCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: 
Expression, dataType:
   override def foldable: Boolean = false
   override def nullable: Boolean = true
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeAnyCollation, MapType(StringType, StringType))
+    Seq(StringTypeAnyCollation, AbstractMapType(StringTypeAnyCollation, 
StringTypeAnyCollation))
 
   override def left: Expression = errorClass
   override def right: Expression = errorParms
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 828245bb3fdd..f3d07ba47b71 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -21,8 +21,8 @@ import java.text.SimpleDateFormat
 
 import scala.collection.immutable.Seq
 
-import org.apache.spark.{SparkException, SparkIllegalArgumentException, 
SparkRuntimeException}
-import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 
@@ -1636,3 +1636,9 @@ class CollationSQLExpressionsSuite
 
 }
 // scalastyle:on nonascii
+
+class CollationSQLExpressionsANSIOffSuite extends CollationSQLExpressionsSuite 
{
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set(SQLConf.ANSI_ENABLED, false)
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index b22a762a2954..657fd4504cac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -32,6 +32,7 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
SortMergeJoinExec}
 import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.internal.types.{AbstractMapType, 
StringTypeAnyCollation}
 import org.apache.spark.sql.types.{MapType, StringType, StructField, 
StructType}
 
 class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
@@ -954,10 +955,7 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
         errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
         parameters = Map(
           "functionName" -> "`=`",
-          "dataType" -> toSQLType(MapType(
-            
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
-            StringType
-          )),
+          "dataType" -> toSQLType(AbstractMapType(StringTypeAnyCollation, 
StringTypeAnyCollation)),
           "sqlExpr" -> "\"(m = m)\""),
         context = ExpectedContext(ctx, query.length - ctx.length, query.length 
- 1))
     }
@@ -1010,25 +1008,6 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
         |select map('a' collate utf8_binary_lcase, 1, 'b' collate 
utf8_binary_lcase, 2)
         |['A' collate utf8_binary_lcase]
         |""".stripMargin), Seq(Row(1)))
-    val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate 
utf8_binary_lcase, 2)['AaA']"
-    val query = s"select $ctx"
-    checkError(
-      exception = intercept[AnalysisException](sql(query)),
-      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-      parameters = Map(
-        "sqlExpr" -> "\"map(collate(aaa), 1, collate(AAA), 2)[AaA]\"",
-        "paramIndex" -> "second",
-        "inputSql" -> "\"AaA\"",
-        "inputType" -> toSQLType(StringType),
-        "requiredType" -> toSQLType(StringType(
-          CollationFactory.collationNameToId("UTF8_BINARY_LCASE")))
-      ),
-      context = ExpectedContext(
-        fragment = ctx,
-        start = query.length - ctx.length,
-        stop = query.length - 1
-      )
-    )
   }
 
   test("window aggregates should respect collation") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to