stefankandic commented on code in PR #48936:
URL: https://github.com/apache/spark/pull/48936#discussion_r1868261563
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -164,31 +160,89 @@ object CollationTypeCoercion {
}
/**
- * Extracts StringTypes from filtered hasStringType
+ * Changes the data type of the expression to the given `newType`.
*/
- @tailrec
- private def extractStringType(dt: DataType): Option[StringType] = dt match {
- case st: StringType => Some(st)
- case ArrayType(et, _) => extractStringType(et)
- case _ => None
+ private def changeType(expr: Expression, newType: DataType): Expression = {
+ mergeTypes(expr.dataType, newType) match {
+ case Some(newDataType) if newDataType != expr.dataType =>
+
assert(!newDataType.existsRecursively(_.isInstanceOf[StringTypeWithContext]))
+
+ val exprWithNewType = expr match {
+ case lit: Literal => lit.copy(dataType = newDataType)
+ case cast: Cast => cast.copy(dataType = newDataType)
+ case _ => Cast(expr, newDataType)
+ }
+
+ // also copy the collation context tag
+ if (hasCollationContextTag(expr)) {
+ exprWithNewType.setTagValue(
+ COLLATION_CONTEXT_TAG, expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+ }
+ exprWithNewType
+
+ case _ =>
+ expr
+ }
}
/**
- * Casts given expression to collated StringType with id equal to
collationId only
- * if expression has StringType in the first place.
+ * If possible, returns the new data type from `inType` by applying
+ * the collation of `castType`.
*/
- def castStringType(expr: Expression, st: StringType): Expression = {
- castStringType(expr.dataType, st)
- .map(dt => Cast(expr, dt))
- .getOrElse(expr)
+ private def mergeTypes(inType: DataType, castType: DataType):
Option[DataType] = {
+ val outType = mergeStructurally(inType, castType) {
+ case (_: StringType, right: StringTypeWithContext) =>
+ right.stringType
+ }
+
+ outType
}
- private def castStringType(inType: DataType, castType: StringType):
Option[DataType] = {
- inType match {
- case st: StringType if st.collationId != castType.collationId =>
- Some(castType)
- case ArrayType(arrType, nullable) =>
- castStringType(arrType, castType).map(ArrayType(_, nullable))
+ /**
+ * Merges two data types structurally according to the given base case.
+ */
+ private def mergeStructurally(
+ leftType: DataType,
+ rightType: DataType)
+ (baseCase: PartialFunction[(DataType, DataType), DataType]):
Option[DataType] = {
+ (leftType, rightType) match {
+
+ // handle the base cases first
+ case _ if baseCase.isDefinedAt((leftType, rightType)) =>
+ Option(baseCase(leftType, rightType))
+
+ case _ if leftType == rightType =>
+ Some(leftType)
+
+ case (ArrayType(leftElemType, nullable), ArrayType(rightElemType, _)) =>
+ mergeStructurally(leftElemType, rightElemType)(
baseCase).map(ArrayType(_, nullable))
Review Comment:
Let's say you have an expression `e` with children `c1, c2, c3`. You first
find their least common type `t = LCT(c1, c2, c3)`, and then apply that type
to every child (bear in mind that some of these children can have complex
types while others have primitive types).
When applying the new type `t`, we should not change the nullability that
each child's type originally had, which is why I am doing it this way.
I'm not sure I could have named the parameters better, but `left` and
`right` are not symmetric: we merge the collations of the right type into
the left one and keep everything else about the left type unchanged.
It wouldn't make sense for one of the children to suddenly become nullable
(or non-nullable) after coercion.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -164,31 +160,89 @@ object CollationTypeCoercion {
}
/**
- * Extracts StringTypes from filtered hasStringType
+ * Changes the data type of the expression to the given `newType`.
*/
- @tailrec
- private def extractStringType(dt: DataType): Option[StringType] = dt match {
- case st: StringType => Some(st)
- case ArrayType(et, _) => extractStringType(et)
- case _ => None
+ private def changeType(expr: Expression, newType: DataType): Expression = {
+ mergeTypes(expr.dataType, newType) match {
+ case Some(newDataType) if newDataType != expr.dataType =>
+
assert(!newDataType.existsRecursively(_.isInstanceOf[StringTypeWithContext]))
+
+ val exprWithNewType = expr match {
+ case lit: Literal => lit.copy(dataType = newDataType)
+ case cast: Cast => cast.copy(dataType = newDataType)
+ case _ => Cast(expr, newDataType)
+ }
+
+ // also copy the collation context tag
+ if (hasCollationContextTag(expr)) {
+ exprWithNewType.setTagValue(
+ COLLATION_CONTEXT_TAG, expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+ }
+ exprWithNewType
+
+ case _ =>
+ expr
+ }
}
/**
- * Casts given expression to collated StringType with id equal to
collationId only
- * if expression has StringType in the first place.
+ * If possible, returns the new data type from `inType` by applying
+ * the collation of `castType`.
*/
- def castStringType(expr: Expression, st: StringType): Expression = {
- castStringType(expr.dataType, st)
- .map(dt => Cast(expr, dt))
- .getOrElse(expr)
+ private def mergeTypes(inType: DataType, castType: DataType):
Option[DataType] = {
+ val outType = mergeStructurally(inType, castType) {
+ case (_: StringType, right: StringTypeWithContext) =>
+ right.stringType
+ }
+
+ outType
}
- private def castStringType(inType: DataType, castType: StringType):
Option[DataType] = {
- inType match {
- case st: StringType if st.collationId != castType.collationId =>
- Some(castType)
- case ArrayType(arrType, nullable) =>
- castStringType(arrType, castType).map(ArrayType(_, nullable))
+ /**
+ * Merges two data types structurally according to the given base case.
+ */
+ private def mergeStructurally(
+ leftType: DataType,
+ rightType: DataType)
+ (baseCase: PartialFunction[(DataType, DataType), DataType]):
Option[DataType] = {
+ (leftType, rightType) match {
+
+ // handle the base cases first
+ case _ if baseCase.isDefinedAt((leftType, rightType)) =>
+ Option(baseCase(leftType, rightType))
+
+ case _ if leftType == rightType =>
+ Some(leftType)
+
+ case (ArrayType(leftElemType, nullable), ArrayType(rightElemType, _)) =>
+ mergeStructurally(leftElemType, rightElemType)(
baseCase).map(ArrayType(_, nullable))
+
+ case (MapType(leftKey, leftValue, nullable), MapType(rightKey,
rightValue, _)) =>
Review Comment:
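Answered above: the same reasoning as for the `ArrayType` case applies here (nullability is taken from the left type).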
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]