stefankandic commented on code in PR #48663:
URL: https://github.com/apache/spark/pull/48663#discussion_r1848612667


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -193,88 +144,208 @@ object CollationTypeCoercion {
     case other => other
   }
 
+  /**
+   * If childType is collated and target is UTF8_BINARY, the collation of the 
output
+   * should be that of the childType.
+   */
+  private def shouldRemoveCast(cast: Cast): Boolean = {
+    val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined
+    val isChildTypeCollatedString = cast.child.dataType match {
+      case st: StringType => !st.isUTF8BinaryCollation
+      case _ => false
+    }
+    val targetType = cast.dataType
+
+    isUserDefined && isChildTypeCollatedString && targetType == StringType
+  }
+
   /**
    * Extracts StringTypes from filtered hasStringType
    */
   @tailrec
-  private def extractStringType(dt: DataType): StringType = dt match {
-    case st: StringType => st
+  private def extractStringType(dt: DataType): Option[StringType] = dt match {
+    case st: StringType => Some(st)
     case ArrayType(et, _) => extractStringType(et)
+    case _ => None
   }
 
   /**
    * Casts given expression to collated StringType with id equal to 
collationId only
    * if expression has StringType in the first place.
-   * @param expr
-   * @param collationId
-   * @return
    */
-  def castStringType(expr: Expression, st: StringType): Option[Expression] =
-    castStringType(expr.dataType, st).map { dt => Cast(expr, dt)}
+  def castStringType(expr: Expression, st: StringType): Expression = {
+    castStringType(expr.dataType, st)
+      .map(dt => Cast(expr, dt))
+      .getOrElse(expr)
+  }
 
   private def castStringType(inType: DataType, castType: StringType): 
Option[DataType] = {
-    @Nullable val ret: DataType = inType match {
-      case st: StringType if st.collationId != castType.collationId => castType
+    inType match {
+      case st: StringType if st.collationId != castType.collationId =>
+        Some(castType)
       case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
-      case _ => null
+        castStringType(arrType, castType).map(ArrayType(_, nullable))
+      case _ => None
     }
-    Option(ret)
   }
 
   /**
    * Collates input expressions to a single collation.
    */
-  def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = {
-    val st = getOutputCollation(exprs)
+  def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = {
+    val lctOpt = findLeastCommonStringType(expressions)
 
-    exprs.map(e => castStringType(e, st).getOrElse(e))
+    lctOpt match {
+      case Some(lct) =>
+        expressions.map(e => castStringType(e, lct))
+      case _ =>
+        expressions
+    }
   }
 
   /**
-   * Based on the data types of the input expressions this method determines
-   * a collation type which the output will have. This function accepts Seq of
-   * any expressions, but will only be affected by collated StringTypes or
-   * complex DataTypes with collated StringTypes (e.g. ArrayType)
+   * Tries to find the least common StringType among the given expressions.
    */
-  def getOutputCollation(expr: Seq[Expression]): StringType = {
-    val explicitTypes = expr.filter {
-        case _: Collate => true
-        case _ => false
-      }
-      .map(_.dataType.asInstanceOf[StringType].collationId)
-      .distinct
-
-    explicitTypes.size match {
-      // We have 1 explicit collation
-      case 1 => StringType(explicitTypes.head)
-      // Multiple explicit collations occurred
-      case size if size > 1 =>
-        throw QueryCompilationErrors
-          .explicitCollationMismatchError(
-            explicitTypes.map(t => StringType(t))
-          )
-      // Only implicit or default collations present
-      case 0 =>
-        val implicitTypes = expr.filter {
-            case Literal(_, _: StringType) => false
-            case cast: Cast if 
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
-              cast.child.dataType.isInstanceOf[StringType]
-            case _ => true
-          }
-          .map(_.dataType)
-          .filter(hasStringType)
-          .map(extractStringType(_).collationId)
-          .distinct
-
-        if (implicitTypes.length > 1) {
-          throw QueryCompilationErrors.implicitCollationMismatchError(
-            implicitTypes.map(t => StringType(t))
-          )
+  private def findLeastCommonStringType(expressions: Seq[Expression]): 
Option[StringType] = {
+    if (!expressions.exists(e => 
SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+      return None
+    }
+
+    val collationContextWinner = 
expressions.foldLeft(findCollationContext(expressions.head)) {
+      case (Some(left), right) =>
+        findCollationContext(right).flatMap { ctx =>
+          collationPrecedenceWinner(left, ctx)
         }
-        else {
-          
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+      case (None, _) => return None
+    }
+
+    collationContextWinner.flatMap { cc =>
+      extractStringType(cc.dataType)
+    }
+  }
+
+  /**
+   * Tries to find the collation context for the given expression.
+   * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the 
expression,
+   * so that the context can be reused later.
+   */
+  private def findCollationContext(expr: Expression): Option[CollationContext] 
= {
+
+    def getChildren: Seq[Expression] = expr match {
+      // we don't need to consider the struct name expressions
+      case struct: CreateNamedStruct => struct.valExprs
+      case _ => expr.children
+    }
+
+    val contextOpt = expr match {
+      case _ if hasCollationContextTag(expr) =>
+        Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+
+      // if `expr` doesn't have a string in its dataType then it doesn't
+      // have the collation context either
+      case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) =>
+        None
+
+      case collate: Collate =>
+        Some(CollationContext(collate.dataType, Explicit))
+
+      case _: Alias | _: SubqueryExpression | _: AttributeReference | _: 
VariableReference =>
+        Some(CollationContext(expr.dataType, Implicit))
+
+      case _: Literal =>

Review Comment:
   I don't think `Cast` should be special-cased here; it should just take 
the collation context of its child.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to