stefankandic commented on code in PR #48585:
URL: https://github.com/apache/spark/pull/48585#discussion_r1810661923


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala:
##########
@@ -148,6 +149,10 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       // interval type the string should be promoted as. There are many 
possible interval
       // types, such as year interval, month interval, day interval, hour 
interval, etc.
       case (_: StringType, _: AnsiIntervalType) => None
+      // [SPARK-50060] If a binary operation contains at least one collated 
string types, we can't
+      // decide which collation the result should have.
+      case (d1: StringType, d2: StringType) if 
!UnsafeRowUtils.isBinaryStable(d1) ||
+        !UnsafeRowUtils.isBinaryStable(d2) => None

Review Comment:
   We should probably not use `UnsafeRowUtils` here;
`SchemaUtils.hasNonUTF8BinaryCollation` would be a better choice.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala:
##########
@@ -162,7 +162,7 @@ object CollationTypeCasts extends TypeCoercionRule {
    */
   def getOutputCollation(expr: Seq[Expression]): StringType = {
     val explicitTypes = expr.filter {
-        case _: Collate => true
+        case _: Collate | Alias(_: Collate, _) => true

Review Comment:
   Other callers will probably also need to check whether an expression has an explicit
collation; could you extract this check into a separate helper function?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -310,6 +310,14 @@ abstract class TypeCoercionBase {
       }
     }
 
+    /** Checks whether the output of a logical plan has a collated string 
output type at
+     * the corresponding index */
+    private def hasOutputCollatedStringTypes(
+      children: Seq[LogicalPlan],
+      attrIndex: Int): Boolean = {
+      children.exists(p => 
!UnsafeRowUtils.isBinaryStable(p.output(attrIndex).dataType))

Review Comment:
   Ditto: use `SchemaUtils.hasNonUTF8BinaryCollation` here instead of `UnsafeRowUtils.isBinaryStable`.



##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala:
##########
@@ -3260,6 +3260,46 @@ class CollationSQLExpressionsSuite
     }
   }
 
+  test("SPARK-50060: set operators with conflicting and non-conflicting 
collations") {
+    val setOperators = Seq[(String, Seq[Row])](
+      ("UNION", Seq(Row("a"))),
+      ("INTERSECT", Seq(Row("a"))),
+      ("EXCEPT", Seq()),
+      ("UNION ALL", Seq(Row("A"), Row("a"))),
+      ("INTERSECT ALL", Seq(Row("a"))),
+      ("EXCEPT ALL", Seq()),
+      ("UNION DISTINCT", Seq(Row("a"))),
+      ("INTERSECT DISTINCT", Seq(Row("a"))),
+      ("EXCEPT DISTINCT", Seq()))
+
+    Seq[Boolean](true, false).foreach{ ansi_enabled =>

Review Comment:
   We can merge these two nested `withSQLConf` calls into a single one, like this:
   ```scala
   Seq(true, false).foreach { ansiEnabled =>
     withSQLConf(
       SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
       SqlApiConf.DEFAULT_COLLATION -> "UNICODE"
     ) {
       // Your code here
     }
   }
   ```



##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala:
##########
@@ -3260,6 +3260,46 @@ class CollationSQLExpressionsSuite
     }
   }
 
+  test("SPARK-50060: set operators with conflicting and non-conflicting 
collations") {
+    val setOperators = Seq[(String, Seq[Row])](
+      ("UNION", Seq(Row("a"))),
+      ("INTERSECT", Seq(Row("a"))),
+      ("EXCEPT", Seq()),
+      ("UNION ALL", Seq(Row("A"), Row("a"))),
+      ("INTERSECT ALL", Seq(Row("a"))),
+      ("EXCEPT ALL", Seq()),
+      ("UNION DISTINCT", Seq(Row("a"))),
+      ("INTERSECT DISTINCT", Seq(Row("a"))),
+      ("EXCEPT DISTINCT", Seq()))
+
+    Seq[Boolean](true, false).foreach{ ansi_enabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi_enabled.toString) {
+        withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+          setOperators.foreach { case (operator, result) =>
+            // Check whether conflict between explicit collations is detected
+            val conflictingQuery =
+              s"SELECT 'a' COLLATE UTF8_LCASE $operator SELECT 'A' COLLATE 
UNICODE_CI"
+            checkError(
+              exception = intercept[AnalysisException] {
+                sql(conflictingQuery)
+              },
+              condition = "COLLATION_MISMATCH.EXPLICIT",
+              parameters = Map(
+                "explicitTypes" -> "\"STRING COLLATE UTF8_LCASE\", \"STRING 
COLLATE UNICODE_CI\"")
+            )
+
+            // Check whether conflict between default and explicit collation 
is resolved in
+            // favor of the explicit collation
+            val nonConflictingQuery = s"SELECT 'a' COLLATE UTF8_LCASE AS val 
$operator " +
+              s"SELECT 'A' AS val ORDER BY val COLLATE UTF8_BINARY"
+            checkAnswer(sql(nonConflictingQuery), result)

Review Comment:
   Can we add some cases with an implicit collation mismatch? For example, create a table
with a collated column and select that column, or apply a string function to the literal.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -905,6 +924,10 @@ object TypeCoercion extends TypeCoercionBase {
 
   /** Promotes all the way to StringType. */
   private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] 
= (dt1, dt2) match {
+    // [SPARK-50060] If a binary operation contains at least one collated 
string type,
+    // we can't decide which collation the result should have.
+    case (d1 : StringType, d2: StringType) if 
!UnsafeRowUtils.isBinaryStable(d1) ||
+      !UnsafeRowUtils.isBinaryStable(d2) => None

Review Comment:
   Ditto: use `SchemaUtils.hasNonUTF8BinaryCollation` here instead of `UnsafeRowUtils.isBinaryStable`.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -310,6 +310,14 @@ abstract class TypeCoercionBase {
       }
     }
 
+    /** Checks whether the output of a logical plan has a collated string 
output type at
+     * the corresponding index */
+    private def hasOutputCollatedStringTypes(

Review Comment:
   ```suggestion
       private def outputHasNonUTF8BinaryCollation(
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -335,9 +343,20 @@ abstract class TypeCoercionBase {
       // Return the result after the widen data types have been found for all 
the children
       if (attrIndex >= children.head.output.length) return castedTypes.toSeq
 
-      // For the attrIndex-th attribute, find the widest type
-      val widenTypeOpt = 
findWiderCommonType(children.map(_.output(attrIndex).dataType))
-      castedTypes.enqueue(widenTypeOpt)
+      // If the `children` output has at least one collated string type as the 
output type at the
+      // `attrIndex`-th attribute, `CollationTypeCasts` is called to determine 
the output collation.
+      if (hasOutputCollatedStringTypes(children, attrIndex)) {
+        // Since this function is called for set operators, we can be sure 
that their children
+        // are all `Project` nodes.
+        val collatedOutputType = CollationTypeCasts.getOutputCollation(

Review Comment:
   Can we move this into a separate function that returns the type to be added to the
queue? We could also add an assertion that all of the children are actually `Project`
nodes.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to