stefankandic commented on code in PR #48585:
URL: https://github.com/apache/spark/pull/48585#discussion_r1810661923
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala:
##########
@@ -148,6 +149,10 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     // interval type the string should be promoted as. There are many possible interval
     // types, such as year interval, month interval, day interval, hour interval, etc.
     case (_: StringType, _: AnsiIntervalType) => None
+    // [SPARK-50060] If a binary operation contains at least one collated string types, we can't
+    // decide which collation the result should have.
+    case (d1: StringType, d2: StringType) if !UnsafeRowUtils.isBinaryStable(d1) ||
+      !UnsafeRowUtils.isBinaryStable(d2) => None
Review Comment:
we should probably not use `UnsafeRowUtils` here; `SchemaUtils.hasNonUTF8BinaryCollation` would be a better choice
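For reference, a sketch of how that case might read with the suggested helper; this assumes `SchemaUtils.hasNonUTF8BinaryCollation` accepts a single `DataType`, so treat it as illustrative rather than the final change:
```scala
// Sketch only: same None result, but detecting collated strings via SchemaUtils
// (assumed call: SchemaUtils.hasNonUTF8BinaryCollation(dt: DataType): Boolean).
case (d1: StringType, d2: StringType)
    if SchemaUtils.hasNonUTF8BinaryCollation(d1) ||
      SchemaUtils.hasNonUTF8BinaryCollation(d2) => None
```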
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala:
##########
@@ -162,7 +162,7 @@ object CollationTypeCasts extends TypeCoercionRule {
*/
def getOutputCollation(expr: Seq[Expression]): StringType = {
val explicitTypes = expr.filter {
- case _: Collate => true
+ case _: Collate | Alias(_: Collate, _) => true
Review Comment:
Someone else will probably want to check whether an expression is explicitly collated as well;
can you maybe extract this into a separate function?
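Something along these lines could work, for example (the name `isExplicitCollation` is just a placeholder):
```scala
// Placeholder helper: true if the expression pins an explicit collation, even under an alias.
private def isExplicitCollation(expr: Expression): Boolean = expr match {
  case _: Collate | Alias(_: Collate, _) => true
  case _ => false
}
```
`getOutputCollation` could then use `expr.filter(isExplicitCollation)`, and other rules could reuse the same check.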
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -310,6 +310,14 @@ abstract class TypeCoercionBase {
}
}
+  /** Checks whether the output of a logical plan has a collated string output type at
+   * the corresponding index */
+  private def hasOutputCollatedStringTypes(
+      children: Seq[LogicalPlan],
+      attrIndex: Int): Boolean = {
+    children.exists(p => !UnsafeRowUtils.isBinaryStable(p.output(attrIndex).dataType))
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala:
##########
@@ -3260,6 +3260,46 @@ class CollationSQLExpressionsSuite
}
}
+ test("SPARK-50060: set operators with conflicting and non-conflicting
collations") {
+ val setOperators = Seq[(String, Seq[Row])](
+ ("UNION", Seq(Row("a"))),
+ ("INTERSECT", Seq(Row("a"))),
+ ("EXCEPT", Seq()),
+ ("UNION ALL", Seq(Row("A"), Row("a"))),
+ ("INTERSECT ALL", Seq(Row("a"))),
+ ("EXCEPT ALL", Seq()),
+ ("UNION DISTINCT", Seq(Row("a"))),
+ ("INTERSECT DISTINCT", Seq(Row("a"))),
+ ("EXCEPT DISTINCT", Seq()))
+
+ Seq[Boolean](true, false).foreach{ ansi_enabled =>
Review Comment:
we can merge this like:
```scala
Seq(true, false).foreach { ansiEnabled =>
  withSQLConf(
    SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
    SqlApiConf.DEFAULT_COLLATION -> "UNICODE"
  ) {
    // Your code here
  }
}
```
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala:
##########
@@ -3260,6 +3260,46 @@ class CollationSQLExpressionsSuite
}
}
+ test("SPARK-50060: set operators with conflicting and non-conflicting
collations") {
+ val setOperators = Seq[(String, Seq[Row])](
+ ("UNION", Seq(Row("a"))),
+ ("INTERSECT", Seq(Row("a"))),
+ ("EXCEPT", Seq()),
+ ("UNION ALL", Seq(Row("A"), Row("a"))),
+ ("INTERSECT ALL", Seq(Row("a"))),
+ ("EXCEPT ALL", Seq()),
+ ("UNION DISTINCT", Seq(Row("a"))),
+ ("INTERSECT DISTINCT", Seq(Row("a"))),
+ ("EXCEPT DISTINCT", Seq()))
+
+ Seq[Boolean](true, false).foreach{ ansi_enabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi_enabled.toString) {
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+ setOperators.foreach { case (operator, result) =>
+ // Check whether conflict between explicit collations is detected
+ val conflictingQuery =
+ s"SELECT 'a' COLLATE UTF8_LCASE $operator SELECT 'A' COLLATE UNICODE_CI"
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(conflictingQuery)
+ },
+ condition = "COLLATION_MISMATCH.EXPLICIT",
+ parameters = Map(
+ "explicitTypes" -> "\"STRING COLLATE UTF8_LCASE\", \"STRING COLLATE UNICODE_CI\"")
+ )
+
+ // Check whether conflict between default and explicit collation is resolved in
+ // favor of the explicit collation
+ val nonConflictingQuery = s"SELECT 'a' COLLATE UTF8_LCASE AS val $operator " +
+ s"SELECT 'A' AS val ORDER BY val COLLATE UTF8_BINARY"
+ checkAnswer(sql(nonConflictingQuery), result)
Review Comment:
Can we add some cases with an implicit mismatch? For example, where we have a table and select
its column, or where we apply some string function to the literal?
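Purely as a sketch, an implicit-mismatch case inside the same `setOperators.foreach` loop could look roughly like this; the table/column names are made up, and the exact error condition to assert on would need to be confirmed against the analyzer:
```scala
// Hypothetical implicit mismatch: two table columns with different collations feeding the
// same set operator. Names and the asserted condition prefix are assumptions, not from the PR.
withTable("t1", "t2") {
  sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE) USING parquet")
  sql("CREATE TABLE t2 (c STRING COLLATE UNICODE_CI) USING parquet")
  val implicitMismatchQuery = s"SELECT c FROM t1 $operator SELECT c FROM t2"
  val e = intercept[AnalysisException] {
    sql(implicitMismatchQuery)
  }
  assert(e.getCondition.startsWith("COLLATION_MISMATCH"))
}
```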
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -905,6 +924,10 @@ object TypeCoercion extends TypeCoercionBase {
   /** Promotes all the way to StringType. */
   private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
+    // [SPARK-50060] If a binary operation contains at least one collated string type,
+    // we can't decide which collation the result should have.
+    case (d1 : StringType, d2: StringType) if !UnsafeRowUtils.isBinaryStable(d1) ||
+      !UnsafeRowUtils.isBinaryStable(d2) => None
Review Comment:
ditto
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -310,6 +310,14 @@ abstract class TypeCoercionBase {
}
}
+  /** Checks whether the output of a logical plan has a collated string output type at
+   * the corresponding index */
+  private def hasOutputCollatedStringTypes(
Review Comment:
```suggestion
private def outputHasNonUTF8BinaryCollation(
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala:
##########
@@ -335,9 +343,20 @@ abstract class TypeCoercionBase {
      // Return the result after the widen data types have been found for all the children
      if (attrIndex >= children.head.output.length) return castedTypes.toSeq
-     // For the attrIndex-th attribute, find the widest type
-     val widenTypeOpt = findWiderCommonType(children.map(_.output(attrIndex).dataType))
-     castedTypes.enqueue(widenTypeOpt)
+     // If the `children` output has at least one collated string type as the output type at the
+     // `attrIndex`-th attribute, `CollationTypeCasts` is called to determine the output collation.
+     if (hasOutputCollatedStringTypes(children, attrIndex)) {
+       // Since this function is called for set operators, we can be sure that their children
+       // are all `Project` nodes.
+       val collatedOutputType = CollationTypeCasts.getOutputCollation(
Review Comment:
Can we maybe move this into a separate function that returns the type that should be added to
the queue? Also, we could add an assert that all of those children are actually `Project` nodes.
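For instance, something along these lines, assuming `getOutputCollation` can take the children's output attributes directly (the helper name is a placeholder):
```scala
// Placeholder helper: returns the type to enqueue for the attrIndex-th attribute when at
// least one child outputs a collated string type there. The assert encodes the assumption
// that this path is only hit for set operators, whose children are all Project nodes.
private def getCollatedOutputType(
    children: Seq[LogicalPlan],
    attrIndex: Int): DataType = {
  assert(children.forall(_.isInstanceOf[Project]),
    "Expected all set operator children to be Project nodes")
  CollationTypeCasts.getOutputCollation(children.map(_.output(attrIndex)))
}
```
The caller would then just enqueue `Some(getCollatedOutputType(children, attrIndex))` instead of building the type inline.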