tedyu commented on code in PR #48936:
URL: https://github.com/apache/spark/pull/48936#discussion_r1869917985


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -164,31 +160,89 @@ object CollationTypeCoercion {
   }
 
   /**
-   * Extracts StringTypes from filtered hasStringType
+   * Changes the data type of the expression to the given `newType`.
    */
-  @tailrec
-  private def extractStringType(dt: DataType): Option[StringType] = dt match {
-    case st: StringType => Some(st)
-    case ArrayType(et, _) => extractStringType(et)
-    case _ => None
+  private def changeType(expr: Expression, newType: DataType): Expression = {
+    mergeTypes(expr.dataType, newType) match {
+      case Some(newDataType) if newDataType != expr.dataType =>
+        assert(!newDataType.existsRecursively(_.isInstanceOf[StringTypeWithContext]))
+
+        val exprWithNewType = expr match {
+          case lit: Literal => lit.copy(dataType = newDataType)
+          case cast: Cast => cast.copy(dataType = newDataType)
+          case _ => Cast(expr, newDataType)
+        }
+
+        // also copy the collation context tag
+        if (hasCollationContextTag(expr)) {
+          exprWithNewType.setTagValue(
+            COLLATION_CONTEXT_TAG, expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+        }
+        exprWithNewType
+
+      case _ =>
+        expr
+    }
   }
 
   /**
-   * Casts given expression to collated StringType with id equal to collationId only
-   * if expression has StringType in the first place.
+   * If possible, returns the new data type from `inType` by applying
+   * the collation of `castType`.
    */
-  def castStringType(expr: Expression, st: StringType): Expression = {
-    castStringType(expr.dataType, st)
-      .map(dt => Cast(expr, dt))
-      .getOrElse(expr)
+  private def mergeTypes(inType: DataType, castType: DataType): Option[DataType] = {
+    val outType = mergeStructurally(inType, castType) {
+      case (_: StringType, right: StringTypeWithContext) =>
+        right.stringType
+    }
+
+    outType
   }
 
-  private def castStringType(inType: DataType, castType: StringType): Option[DataType] = {
-    inType match {
-      case st: StringType if st.collationId != castType.collationId =>
-        Some(castType)
-      case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable))
+  /**
+   * Merges two data types structurally according to the given base case.
+   */
+  private def mergeStructurally(
+      leftType: DataType,
+      rightType: DataType)
+      (baseCase: PartialFunction[(DataType, DataType), DataType]): Option[DataType] = {
+    (leftType, rightType) match {
+
+      // handle the base cases first
+      case _ if baseCase.isDefinedAt((leftType, rightType)) =>
+        Option(baseCase(leftType, rightType))
+
+      case _ if leftType == rightType =>
+        Some(leftType)
+
+      case (ArrayType(leftElemType, nullable), ArrayType(rightElemType, _)) =>
+        mergeStructurally(leftElemType, rightElemType)(baseCase).map(ArrayType(_, nullable))

Review Comment:
   Makes sense.
   Maybe name the parameters `original` and `update`; those names would imply that
   the right-hand type contains updates that are applied to the left-hand type.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to