MaxGekk commented on code in PR #48936:
URL: https://github.com/apache/spark/pull/48936#discussion_r1868194153


##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala:
##########
@@ -450,7 +450,10 @@ class CollationSQLRegexpSuite
         },
         condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
         parameters = Map(
-          "sqlExpr" -> "\"regexp_replace(collate(ABCDE, UNICODE_CI), .c., FFF, 
1)\"",
+          "sqlExpr" ->
+            """
+              |"regexp_replace(collate(ABCDE, UNICODE_CI), .c., 'FFF' collate 
UNICODE_CI, 1)"
+              |""".stripMargin.trim,

Review Comment:
   ```suggestion
             "sqlExpr" ->
               """"regexp_replace(collate(ABCDE, UNICODE_CI), .c., 'FFF' collate UNICODE_CI, 1)"""",
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -164,31 +160,89 @@ object CollationTypeCoercion {
   }
 
   /**
-   * Extracts StringTypes from filtered hasStringType
+   * Changes the data type of the expression to the given `newType`.
    */
-  @tailrec
-  private def extractStringType(dt: DataType): Option[StringType] = dt match {
-    case st: StringType => Some(st)
-    case ArrayType(et, _) => extractStringType(et)
-    case _ => None
+  private def changeType(expr: Expression, newType: DataType): Expression = {
+    mergeTypes(expr.dataType, newType) match {
+      case Some(newDataType) if newDataType != expr.dataType =>
+        
assert(!newDataType.existsRecursively(_.isInstanceOf[StringTypeWithContext]))
+
+        val exprWithNewType = expr match {
+          case lit: Literal => lit.copy(dataType = newDataType)
+          case cast: Cast => cast.copy(dataType = newDataType)
+          case _ => Cast(expr, newDataType)
+        }
+
+        // also copy the collation context tag
+        if (hasCollationContextTag(expr)) {
+          exprWithNewType.setTagValue(
+            COLLATION_CONTEXT_TAG, expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+        }
+        exprWithNewType
+
+      case _ =>
+        expr
+    }
   }
 
   /**
-   * Casts given expression to collated StringType with id equal to 
collationId only
-   * if expression has StringType in the first place.
+   * If possible, returns the new data type from `inType` by applying
+   * the collation of `castType`.
    */
-  def castStringType(expr: Expression, st: StringType): Expression = {
-    castStringType(expr.dataType, st)
-      .map(dt => Cast(expr, dt))
-      .getOrElse(expr)
+  private def mergeTypes(inType: DataType, castType: DataType): 
Option[DataType] = {
+    val outType = mergeStructurally(inType, castType) {
+      case (_: StringType, right: StringTypeWithContext) =>
+        right.stringType
+    }
+
+    outType
   }
 
-  private def castStringType(inType: DataType, castType: StringType): 
Option[DataType] = {
-    inType match {
-      case st: StringType if st.collationId != castType.collationId =>
-        Some(castType)
-      case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable))
+  /**
+   * Merges two data types structurally according to the given base case.
+   */
+  private def mergeStructurally(
+      leftType: DataType,
+      rightType: DataType)
+      (baseCase: PartialFunction[(DataType, DataType), DataType]): 
Option[DataType] = {
+    (leftType, rightType) match {
+
+      // handle the base cases first
+      case _ if baseCase.isDefinedAt((leftType, rightType)) =>
+        Option(baseCase(leftType, rightType))
+
+      case _ if leftType == rightType =>
+        Some(leftType)
+
+      case (ArrayType(leftElemType, nullable), ArrayType(rightElemType, _)) =>
+        mergeStructurally(leftElemType, rightElemType)( 
baseCase).map(ArrayType(_, nullable))

Review Comment:
   ```suggestion
           mergeStructurally(leftElemType, rightElemType)(baseCase).map(ArrayType(_, nullable))
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -164,31 +160,89 @@ object CollationTypeCoercion {
   }
 
   /**
-   * Extracts StringTypes from filtered hasStringType
+   * Changes the data type of the expression to the given `newType`.
    */
-  @tailrec
-  private def extractStringType(dt: DataType): Option[StringType] = dt match {
-    case st: StringType => Some(st)
-    case ArrayType(et, _) => extractStringType(et)
-    case _ => None
+  private def changeType(expr: Expression, newType: DataType): Expression = {
+    mergeTypes(expr.dataType, newType) match {
+      case Some(newDataType) if newDataType != expr.dataType =>
+        
assert(!newDataType.existsRecursively(_.isInstanceOf[StringTypeWithContext]))
+
+        val exprWithNewType = expr match {
+          case lit: Literal => lit.copy(dataType = newDataType)
+          case cast: Cast => cast.copy(dataType = newDataType)
+          case _ => Cast(expr, newDataType)
+        }
+
+        // also copy the collation context tag
+        if (hasCollationContextTag(expr)) {
+          exprWithNewType.setTagValue(
+            COLLATION_CONTEXT_TAG, expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+        }
+        exprWithNewType
+
+      case _ =>
+        expr
+    }
   }
 
   /**
-   * Casts given expression to collated StringType with id equal to 
collationId only
-   * if expression has StringType in the first place.
+   * If possible, returns the new data type from `inType` by applying
+   * the collation of `castType`.
    */
-  def castStringType(expr: Expression, st: StringType): Expression = {
-    castStringType(expr.dataType, st)
-      .map(dt => Cast(expr, dt))
-      .getOrElse(expr)
+  private def mergeTypes(inType: DataType, castType: DataType): 
Option[DataType] = {
+    val outType = mergeStructurally(inType, castType) {
+      case (_: StringType, right: StringTypeWithContext) =>
+        right.stringType
+    }
+
+    outType
   }
 
-  private def castStringType(inType: DataType, castType: StringType): 
Option[DataType] = {
-    inType match {
-      case st: StringType if st.collationId != castType.collationId =>
-        Some(castType)
-      case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable))
+  /**
+   * Merges two data types structurally according to the given base case.
+   */
+  private def mergeStructurally(
+      leftType: DataType,
+      rightType: DataType)
+      (baseCase: PartialFunction[(DataType, DataType), DataType]): 
Option[DataType] = {
+    (leftType, rightType) match {
+
+      // handle the base cases first
+      case _ if baseCase.isDefinedAt((leftType, rightType)) =>
+        Option(baseCase(leftType, rightType))
+
+      case _ if leftType == rightType =>
+        Some(leftType)
+
+      case (ArrayType(leftElemType, nullable), ArrayType(rightElemType, _)) =>
+        mergeStructurally(leftElemType, rightElemType)( 
baseCase).map(ArrayType(_, nullable))

Review Comment:
   Could you clarify how `nullable` is merged? If the left array type has
`nullable` set to `false` but the right one has `true`, you propose setting it to
`false`? I would expect something like `leftNullable || rightNullable`.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -223,151 +277,195 @@ object CollationTypeCoercion {
     val collationContextWinner = 
expressions.foldLeft(findCollationContext(expressions.head)) {
       case (Some(left), right) =>
         findCollationContext(right).flatMap { ctx =>
-          collationPrecedenceWinner(left, ctx)
+          mergeWinner(left, ctx)
         }
       case (None, _) => return None

Review Comment:
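   Not related to this PR, but the `return` seems unnecessary: `case (None, _) => None` would give the same result, because once the accumulator is `None` it stays `None` for the rest of the fold (and it avoids a non-local return from inside the lambda).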



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -164,31 +160,89 @@ object CollationTypeCoercion {
   }
 
   /**
-   * Extracts StringTypes from filtered hasStringType
+   * Changes the data type of the expression to the given `newType`.
    */
-  @tailrec
-  private def extractStringType(dt: DataType): Option[StringType] = dt match {
-    case st: StringType => Some(st)
-    case ArrayType(et, _) => extractStringType(et)
-    case _ => None
+  private def changeType(expr: Expression, newType: DataType): Expression = {
+    mergeTypes(expr.dataType, newType) match {
+      case Some(newDataType) if newDataType != expr.dataType =>
+        
assert(!newDataType.existsRecursively(_.isInstanceOf[StringTypeWithContext]))
+
+        val exprWithNewType = expr match {
+          case lit: Literal => lit.copy(dataType = newDataType)
+          case cast: Cast => cast.copy(dataType = newDataType)
+          case _ => Cast(expr, newDataType)
+        }
+
+        // also copy the collation context tag
+        if (hasCollationContextTag(expr)) {
+          exprWithNewType.setTagValue(
+            COLLATION_CONTEXT_TAG, expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+        }
+        exprWithNewType
+
+      case _ =>
+        expr
+    }
   }
 
   /**
-   * Casts given expression to collated StringType with id equal to 
collationId only
-   * if expression has StringType in the first place.
+   * If possible, returns the new data type from `inType` by applying
+   * the collation of `castType`.
    */
-  def castStringType(expr: Expression, st: StringType): Expression = {
-    castStringType(expr.dataType, st)
-      .map(dt => Cast(expr, dt))
-      .getOrElse(expr)
+  private def mergeTypes(inType: DataType, castType: DataType): 
Option[DataType] = {
+    val outType = mergeStructurally(inType, castType) {
+      case (_: StringType, right: StringTypeWithContext) =>
+        right.stringType
+    }
+
+    outType
   }
 
-  private def castStringType(inType: DataType, castType: StringType): 
Option[DataType] = {
-    inType match {
-      case st: StringType if st.collationId != castType.collationId =>
-        Some(castType)
-      case ArrayType(arrType, nullable) =>
-        castStringType(arrType, castType).map(ArrayType(_, nullable))
+  /**
+   * Merges two data types structurally according to the given base case.
+   */
+  private def mergeStructurally(
+      leftType: DataType,
+      rightType: DataType)
+      (baseCase: PartialFunction[(DataType, DataType), DataType]): 
Option[DataType] = {
+    (leftType, rightType) match {
+
+      // handle the base cases first
+      case _ if baseCase.isDefinedAt((leftType, rightType)) =>
+        Option(baseCase(leftType, rightType))
+
+      case _ if leftType == rightType =>
+        Some(leftType)
+
+      case (ArrayType(leftElemType, nullable), ArrayType(rightElemType, _)) =>
+        mergeStructurally(leftElemType, rightElemType)( 
baseCase).map(ArrayType(_, nullable))
+
+      case (MapType(leftKey, leftValue, nullable), MapType(rightKey, 
rightValue, _)) =>

Review Comment:
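   Same question here regarding `nullable` (the map's `valueContainsNull`): shouldn't it be merged from both sides rather than taken from the left one only?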



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -223,151 +277,195 @@ object CollationTypeCoercion {
     val collationContextWinner = 
expressions.foldLeft(findCollationContext(expressions.head)) {
       case (Some(left), right) =>
         findCollationContext(right).flatMap { ctx =>
-          collationPrecedenceWinner(left, ctx)
+          mergeWinner(left, ctx)
         }
       case (None, _) => return None
     }
-
-    collationContextWinner.flatMap { cc =>
-      extractStringType(cc.dataType)
-    }
+    collationContextWinner
   }
 
   /**
    * Tries to find the collation context for the given expression.
    * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the 
expression,
    * so that the context can be reused later.
    */
-  private def findCollationContext(expr: Expression): Option[CollationContext] 
= {
+  private def findCollationContext(expr: Expression): Option[DataType] = {

Review Comment:
   Should we update the Scaladoc above to reflect the new return type (`Option[DataType]` rather than `Option[CollationContext]`)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to