stefankandic commented on code in PR #48663:
URL: https://github.com/apache/spark/pull/48663#discussion_r1848606677


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -193,88 +144,208 @@ object CollationTypeCoercion {
     case other => other
   }
 
+  /**
+   * If childType is collated and target is UTF8_BINARY, the collation of the 
output
+   * should be that of the childType.
+   */
+  private def shouldRemoveCast(cast: Cast): Boolean = {
+    val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined
+    val isChildTypeCollatedString = cast.child.dataType match {
+      case st: StringType => !st.isUTF8BinaryCollation
+      case _ => false
+    }
+    val targetType = cast.dataType
+
+    isUserDefined && isChildTypeCollatedString && targetType == StringType
+  }
+
   /**
    * Extracts StringTypes from filtered hasStringType
    */
   @tailrec
-  private def extractStringType(dt: DataType): StringType = dt match {
-    case st: StringType => st
+  private def extractStringType(dt: DataType): Option[StringType] = dt match {
+    case st: StringType => Some(st)
     case ArrayType(et, _) => extractStringType(et)
+    case _ => None
   }
 
   /**
    * Casts given expression to collated StringType with id equal to 
collationId only
    * if expression has StringType in the first place.
-   * @param expr
-   * @param collationId
-   * @return
    */
-  def castStringType(expr: Expression, st: StringType): Option[Expression] =
-    castStringType(expr.dataType, st).map { dt => Cast(expr, dt)}
+  def castStringType(expr: Expression, st: StringType): Expression = {
+    castStringType(expr.dataType, st)
+      .map(dt => Cast(expr, dt))
+      .getOrElse(expr)
+  }
 
   private def castStringType(inType: DataType, castType: StringType): 
Option[DataType] = {
-    @Nullable val ret: DataType = inType match {
-      case st: StringType if st.collationId != castType.collationId => castType
+    inType match {
+      case st: StringType if st.collationId != castType.collationId =>
+        Some(castType)
       case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
-      case _ => null
+        castStringType(arrType, castType).map(ArrayType(_, nullable))
+      case _ => None
     }
-    Option(ret)
   }
 
   /**
    * Collates input expressions to a single collation.
    */
-  def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = {
-    val st = getOutputCollation(exprs)
+  def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = {
+    val lctOpt = findLeastCommonStringType(expressions)
 
-    exprs.map(e => castStringType(e, st).getOrElse(e))
+    lctOpt match {
+      case Some(lct) =>
+        expressions.map(e => castStringType(e, lct))
+      case _ =>
+        expressions
+    }
   }
 
   /**
-   * Based on the data types of the input expressions this method determines
-   * a collation type which the output will have. This function accepts Seq of
-   * any expressions, but will only be affected by collated StringTypes or
-   * complex DataTypes with collated StringTypes (e.g. ArrayType)
+   * Tries to find the least common StringType among the given expressions.
    */
-  def getOutputCollation(expr: Seq[Expression]): StringType = {
-    val explicitTypes = expr.filter {
-        case _: Collate => true
-        case _ => false
-      }
-      .map(_.dataType.asInstanceOf[StringType].collationId)
-      .distinct
-
-    explicitTypes.size match {
-      // We have 1 explicit collation
-      case 1 => StringType(explicitTypes.head)
-      // Multiple explicit collations occurred
-      case size if size > 1 =>
-        throw QueryCompilationErrors
-          .explicitCollationMismatchError(
-            explicitTypes.map(t => StringType(t))
-          )
-      // Only implicit or default collations present
-      case 0 =>
-        val implicitTypes = expr.filter {
-            case Literal(_, _: StringType) => false
-            case cast: Cast if 
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
-              cast.child.dataType.isInstanceOf[StringType]
-            case _ => true
-          }
-          .map(_.dataType)
-          .filter(hasStringType)
-          .map(extractStringType(_).collationId)
-          .distinct
-
-        if (implicitTypes.length > 1) {
-          throw QueryCompilationErrors.implicitCollationMismatchError(
-            implicitTypes.map(t => StringType(t))
-          )
+  private def findLeastCommonStringType(expressions: Seq[Expression]): 
Option[StringType] = {
+    if (!expressions.exists(e => 
SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+      return None
+    }
+
+    val collationContextWinner = 
expressions.foldLeft(findCollationContext(expressions.head)) {
+      case (Some(left), right) =>
+        findCollationContext(right).flatMap { ctx =>
+          collationPrecedenceWinner(left, ctx)
         }
-        else {
-          
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+      case (None, _) => return None
+    }
+
+    collationContextWinner.flatMap { cc =>
+      extractStringType(cc.dataType)
+    }
+  }
+
+  /**
+   * Tries to find the collation context for the given expression.
+   * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the 
expression,
+   * so that the context can be reused later.
+   */
+  private def findCollationContext(expr: Expression): Option[CollationContext] 
= {
+
+    def getChildren: Seq[Expression] = expr match {
+      // we don't need to consider the struct name expressions
+      case struct: CreateNamedStruct => struct.valExprs
+      case _ => expr.children
+    }
+
+    val contextOpt = expr match {
+      case _ if hasCollationContextTag(expr) =>
+        Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+
+      // if `expr` doesn't have a string in its dataType then it doesn't
+      // have the collation context either
+      case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) =>
+        None
+
+      case collate: Collate =>
+        Some(CollationContext(collate.dataType, Explicit))
+
+      case _: Alias | _: SubqueryExpression | _: AttributeReference | _: 
VariableReference =>
+        Some(CollationContext(expr.dataType, Implicit))
+
+      case _: Literal =>
+        Some(CollationContext(expr.dataType, Default))
+
+      // if it does have a string type but none of its children do
+      // then the collation context strength is default
+      case _ if 
!expr.children.exists(_.dataType.existsRecursively(_.isInstanceOf[StringType])) 
=>
+        Some(CollationContext(expr.dataType, Default))
+
+      case extract: ExtractValue =>

Review Comment:
   No. Say we have a map column and we do `mapCol['key' COLLATE UNICODE]`: 
since we are just retrieving the value, the result shouldn't have an explicit 
collation just because the key is explicit. That's why we should only look at 
the child and not at all children.
   
   There is a test for this called `access collated map via literal`.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -193,88 +144,208 @@ object CollationTypeCoercion {
     case other => other
   }
 
+  /**
+   * If childType is collated and target is UTF8_BINARY, the collation of the 
output
+   * should be that of the childType.
+   */
+  private def shouldRemoveCast(cast: Cast): Boolean = {
+    val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined
+    val isChildTypeCollatedString = cast.child.dataType match {
+      case st: StringType => !st.isUTF8BinaryCollation
+      case _ => false
+    }
+    val targetType = cast.dataType
+
+    isUserDefined && isChildTypeCollatedString && targetType == StringType
+  }
+
   /**
    * Extracts StringTypes from filtered hasStringType
    */
   @tailrec
-  private def extractStringType(dt: DataType): StringType = dt match {
-    case st: StringType => st
+  private def extractStringType(dt: DataType): Option[StringType] = dt match {
+    case st: StringType => Some(st)
     case ArrayType(et, _) => extractStringType(et)
+    case _ => None
   }
 
   /**
    * Casts given expression to collated StringType with id equal to 
collationId only
    * if expression has StringType in the first place.
-   * @param expr
-   * @param collationId
-   * @return
    */
-  def castStringType(expr: Expression, st: StringType): Option[Expression] =
-    castStringType(expr.dataType, st).map { dt => Cast(expr, dt)}
+  def castStringType(expr: Expression, st: StringType): Expression = {
+    castStringType(expr.dataType, st)
+      .map(dt => Cast(expr, dt))
+      .getOrElse(expr)
+  }
 
   private def castStringType(inType: DataType, castType: StringType): 
Option[DataType] = {
-    @Nullable val ret: DataType = inType match {
-      case st: StringType if st.collationId != castType.collationId => castType
+    inType match {
+      case st: StringType if st.collationId != castType.collationId =>
+        Some(castType)
       case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
-      case _ => null
+        castStringType(arrType, castType).map(ArrayType(_, nullable))
+      case _ => None
     }
-    Option(ret)
   }
 
   /**
    * Collates input expressions to a single collation.
    */
-  def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = {
-    val st = getOutputCollation(exprs)
+  def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = {
+    val lctOpt = findLeastCommonStringType(expressions)
 
-    exprs.map(e => castStringType(e, st).getOrElse(e))
+    lctOpt match {
+      case Some(lct) =>
+        expressions.map(e => castStringType(e, lct))
+      case _ =>
+        expressions
+    }
   }
 
   /**
-   * Based on the data types of the input expressions this method determines
-   * a collation type which the output will have. This function accepts Seq of
-   * any expressions, but will only be affected by collated StringTypes or
-   * complex DataTypes with collated StringTypes (e.g. ArrayType)
+   * Tries to find the least common StringType among the given expressions.
    */
-  def getOutputCollation(expr: Seq[Expression]): StringType = {
-    val explicitTypes = expr.filter {
-        case _: Collate => true
-        case _ => false
-      }
-      .map(_.dataType.asInstanceOf[StringType].collationId)
-      .distinct
-
-    explicitTypes.size match {
-      // We have 1 explicit collation
-      case 1 => StringType(explicitTypes.head)
-      // Multiple explicit collations occurred
-      case size if size > 1 =>
-        throw QueryCompilationErrors
-          .explicitCollationMismatchError(
-            explicitTypes.map(t => StringType(t))
-          )
-      // Only implicit or default collations present
-      case 0 =>
-        val implicitTypes = expr.filter {
-            case Literal(_, _: StringType) => false
-            case cast: Cast if 
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
-              cast.child.dataType.isInstanceOf[StringType]
-            case _ => true
-          }
-          .map(_.dataType)
-          .filter(hasStringType)
-          .map(extractStringType(_).collationId)
-          .distinct
-
-        if (implicitTypes.length > 1) {
-          throw QueryCompilationErrors.implicitCollationMismatchError(
-            implicitTypes.map(t => StringType(t))
-          )
+  private def findLeastCommonStringType(expressions: Seq[Expression]): 
Option[StringType] = {
+    if (!expressions.exists(e => 
SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+      return None
+    }
+
+    val collationContextWinner = 
expressions.foldLeft(findCollationContext(expressions.head)) {
+      case (Some(left), right) =>
+        findCollationContext(right).flatMap { ctx =>
+          collationPrecedenceWinner(left, ctx)
         }
-        else {
-          
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+      case (None, _) => return None
+    }
+
+    collationContextWinner.flatMap { cc =>
+      extractStringType(cc.dataType)
+    }
+  }
+
+  /**
+   * Tries to find the collation context for the given expression.
+   * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the 
expression,
+   * so that the context can be reused later.
+   */
+  private def findCollationContext(expr: Expression): Option[CollationContext] 
= {
+
+    def getChildren: Seq[Expression] = expr match {
+      // we don't need to consider the struct name expressions
+      case struct: CreateNamedStruct => struct.valExprs
+      case _ => expr.children
+    }
+
+    val contextOpt = expr match {
+      case _ if hasCollationContextTag(expr) =>
+        Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+
+      // if `expr` doesn't have a string in its dataType then it doesn't
+      // have the collation context either
+      case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) =>
+        None
+
+      case collate: Collate =>
+        Some(CollationContext(collate.dataType, Explicit))
+
+      case _: Alias | _: SubqueryExpression | _: AttributeReference | _: 
VariableReference =>
+        Some(CollationContext(expr.dataType, Implicit))
+
+      case _: Literal =>
+        Some(CollationContext(expr.dataType, Default))
+
+      // if it does have a string type but none of its children do
+      // then the collation context strength is default
+      case _ if 
!expr.children.exists(_.dataType.existsRecursively(_.isInstanceOf[StringType])) 
=>
+        Some(CollationContext(expr.dataType, Default))
+
+      case extract: ExtractValue =>
+        findCollationContext(extract.child)
+          .map(cc => CollationContext(extract.dataType, cc.strength))
+
+      case _ =>
+        val contextWinnerOpt = getChildren
+          .flatMap(findCollationContext)
+          .foldLeft(Option.empty[CollationContext]) {
+            case (Some(left), right) =>
+              collationPrecedenceWinner(left, right)
+            case (None, right) =>
+              Some(right)
+          }
+
+        contextWinnerOpt.map { context =>
+          if (hasStringType(expr.dataType)) {
+            CollationContext(expr.dataType, context.strength)
+          } else {
+            context
+          }
         }
     }
+
+    contextOpt.foreach(expr.setTagValue(COLLATION_CONTEXT_TAG, _))
+    contextOpt
   }
+
+  /**
+   * Returns the collation context that wins in precedence between left and 
right.
+   */
+  private def collationPrecedenceWinner(
+      left: CollationContext,
+      right: CollationContext): Option[CollationContext] = {
+
+    val (leftStringType, rightStringType) =
+      (extractStringType(left.dataType), extractStringType(right.dataType)) 
match {
+        case (Some(l), Some(r)) =>
+          (l, r)
+        case (None, None) =>
+          return None
+        case (Some(_), None) =>

Review Comment:
   This can happen if the data type of the context is a map or a struct, since 
`extractStringType` only works for strings and arrays, as you mentioned 
below.
   
   However, I'm not 100% sure we can hit this case, as there is usually 
something above a struct or a map that extracts only the primitive field (like 
`ExtractValue`).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to