This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 461ffa1a681b [SPARK-47692][SQL] Fix default StringType meaning in implicit casting
461ffa1a681b is described below

commit 461ffa1a681b3d2fd2b0e32f22a45e30b45ba707
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Wed Apr 24 16:12:20 2024 +0800

    [SPARK-47692][SQL] Fix default StringType meaning in implicit casting
    
    ### What changes were proposed in this pull request?
    Addition of priority flag to StringType.
    
    ### Why are the changes needed?
    In order to follow casting rules for collations, we need to know whether StringType is considered default, implicit or explicit.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ### How was this patch tested?
    Implicit tests in CollationSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45819 from mihailom-db/SPARK-47692.
    
    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/internal/types/AbstractStringType.scala    |  3 +-
 .../sql/catalyst/analysis/CollationTypeCasts.scala | 30 ++++++----
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  5 +-
 .../org/apache/spark/sql/CollationSuite.scala      | 66 +++++++++++++++++++++-
 4 files changed, 88 insertions(+), 16 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index 6403295fe20c..0828c2d6fc10 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.sql.internal.types
 
+import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 
 /**
  * StringTypeCollated is an abstract class for StringType with collation support.
  */
 abstract class AbstractStringType extends AbstractDataType {
-  override private[sql] def defaultConcreteType: DataType = StringType
+  override private[sql] def defaultConcreteType: DataType = 
SqlApiConf.get.defaultStringType
   override private[sql] def simpleString: String = "string"
 }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 3affd91dd3b8..c6232a870dff 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -22,7 +22,7 @@ import javax.annotation.Nullable
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, 
haveSameType}
-import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, 
CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, 
Expression, Greatest, If, In, InSubquery, Least, Overlay, StringLPad, 
StringRPad}
+import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, 
CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, 
Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, StringLPad, 
StringRPad}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
@@ -48,9 +48,9 @@ object CollationTypeCasts extends TypeCoercionRule {
     case eltExpr: Elt =>
       eltExpr.withNewChildren(eltExpr.children.head +: 
collateToSingleType(eltExpr.children.tail))
 
-    case overlay: Overlay =>
-      overlay.withNewChildren(collateToSingleType(Seq(overlay.input, 
overlay.replace))
-        ++ Seq(overlay.pos, overlay.len))
+    case overlayExpr: Overlay =>
+      overlayExpr.withNewChildren(collateToSingleType(Seq(overlayExpr.input, 
overlayExpr.replace))
+        ++ Seq(overlayExpr.pos, overlayExpr.len))
 
     case stringPadExpr @ (_: StringRPad | _: StringLPad) =>
       val Seq(str, len, pad) = stringPadExpr.children
@@ -108,7 +108,12 @@ object CollationTypeCasts extends TypeCoercionRule {
    * complex DataTypes with collated StringTypes (e.g. ArrayType)
    */
   def getOutputCollation(expr: Seq[Expression]): StringType = {
-    val explicitTypes = expr.filter(_.isInstanceOf[Collate])
+    val explicitTypes = expr.filter {
+        case _: Collate => true
+        case cast: Cast if 
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined =>
+          cast.dataType.isInstanceOf[StringType]
+        case _ => false
+      }
       .map(_.dataType.asInstanceOf[StringType].collationId)
       .distinct
 
@@ -123,17 +128,22 @@ object CollationTypeCasts extends TypeCoercionRule {
           )
       // Only implicit or default collations present
       case 0 =>
-        val implicitTypes = expr.map(_.dataType)
+        val implicitTypes = expr.filter {
+            case Literal(_, _: StringType) => false
+            case cast: Cast if 
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
+              cast.child.dataType.isInstanceOf[StringType]
+            case _ => true
+          }
+          .map(_.dataType)
           .filter(hasStringType)
-          .map(extractStringType)
-          .filter(dt => dt.collationId != 
SQLConf.get.defaultStringType.collationId)
-          .distinctBy(_.collationId)
+          .map(extractStringType(_).collationId)
+          .distinct
 
         if (implicitTypes.length > 1) {
           throw QueryCompilationErrors.implicitCollationMismatchError()
         }
         else {
-          implicitTypes.headOption.getOrElse(SQLConf.get.defaultStringType)
+          
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
         }
     }
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 259e28b62bca..506314effde3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -998,9 +998,10 @@ object TypeCoercion extends TypeCoercionBase {
       case (_: StringType, AnyTimestampType) => 
AnyTimestampType.defaultConcreteType
       case (_: StringType, BinaryType) => BinaryType
       // Cast any atomic type to string.
-      case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] 
=> StringType
+      case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] 
=> st
       case (any: AtomicType, st: AbstractStringType)
-        if !any.isInstanceOf[StringType] => st.defaultConcreteType
+        if !any.isInstanceOf[StringType] =>
+        st.defaultConcreteType
 
       // When we reach here, input type is not acceptable for any types in 
this type collection,
       // try to find the first one we can implicitly cast.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 9aad96c696ea..26f7726c3964 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.connector.{DatasourceV2SQLBase, 
FakeV2ProviderWithCustomSchema}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
@@ -412,7 +413,7 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  test("implicit casting of collated strings") {
+  test("SPARK-47210: Implicit casting of collated strings") {
     val tableName = "parquet_dummy_implicit_cast_t22"
     withTable(tableName) {
       spark.sql(
@@ -566,7 +567,66 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  test("cast of default collated strings in IN expression") {
+  test("SPARK-47692: Parameter marker with EXECUTE IMMEDIATE implicit 
casting") {
+    sql(s"DECLARE stmtStr1 = 'SELECT collation(:var1 || :var2)';")
+    sql(s"DECLARE stmtStr2 = 'SELECT collation(:var1 || (\\\'a\\\' COLLATE 
UNICODE))';")
+
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE stmtStr1 USING
+          | 'a' AS var1,
+          | 'b' AS var2;""".stripMargin),
+      Seq(Row("UTF8_BINARY"))
+    )
+
+    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+      checkAnswer(
+        sql(
+          """EXECUTE IMMEDIATE stmtStr1 USING
+            | 'a' AS var1,
+            | 'b' AS var2;""".stripMargin),
+        Seq(Row("UNICODE"))
+      )
+    }
+
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE stmtStr2 USING
+          | 'a' AS var1;""".stripMargin),
+      Seq(Row("UNICODE"))
+    )
+
+    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+      checkAnswer(
+        sql(
+          """EXECUTE IMMEDIATE stmtStr2 USING
+            | 'a' AS var1;""".stripMargin),
+        Seq(Row("UNICODE"))
+      )
+    }
+  }
+
+  test("SPARK-47692: Parameter markers with variable mapping") {
+    checkAnswer(
+      spark.sql(
+        "SELECT collation(:var1 || :var2)",
+        Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")),
+            "var2" -> Literal.create('b', StringType("UNICODE")))),
+      Seq(Row("UTF8_BINARY"))
+    )
+
+    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+      checkAnswer(
+        spark.sql(
+          "SELECT collation(:var1 || :var2)",
+          Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")),
+              "var2" -> Literal.create('b', StringType("UNICODE")))),
+        Seq(Row("UNICODE"))
+      )
+    }
+  }
+
+  test("SPARK-47210: Cast of default collated strings in IN expression") {
     val tableName = "t1"
     withTable(tableName) {
       spark.sql(
@@ -591,7 +651,7 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
   }
 
   // TODO(SPARK-47210): Add indeterminate support
-  test("indeterminate collation checks") {
+  test("SPARK-47210: Indeterminate collation checks") {
     val tableName = "t1"
     val newTableName = "t2"
     withTable(tableName) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to