GitHub user eric-maynard commented on a diff in the pull request:
https://github.com/apache/spark/pull/21470#discussion_r192819128
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
---
@@ -803,18 +803,60 @@ object TypeCoercion {
e.copy(left = Cast(e.left, TimestampType))
}
- case b @ BinaryOperator(left, right) if left.dataType !=
right.dataType =>
- findTightestCommonType(left.dataType, right.dataType).map {
commonType =>
- if (b.inputType.acceptsType(commonType)) {
- // If the expression accepts the tightest common type, cast to
that.
- val newLeft = if (left.dataType == commonType) left else
Cast(left, commonType)
- val newRight = if (right.dataType == commonType) right else
Cast(right, commonType)
- b.withNewChildren(Seq(newLeft, newRight))
- } else {
- // Otherwise, don't do anything with the expression.
- b
- }
- }.getOrElse(b) // If there is no applicable conversion, leave
expression unchanged.
+ case b @ BinaryOperator(left, right)
+ if !BinaryOperator.sameType(left.dataType, right.dataType) =>
+ (left.dataType, right.dataType) match {
+ case (StructType(fields1), StructType(fields2)) =>
+ val commonTypes =
scala.collection.mutable.ArrayBuffer.empty[DataType]
+ val len = fields1.length
+ var i = 0
+ var continue = fields1.length == fields2.length
+ while (i < len && continue) {
+ val commonType = findTightestCommonType(fields1(i).dataType,
fields2(i).dataType)
+ if (commonType.isDefined) {
+ commonTypes += commonType.get
+ } else {
+ continue = false
+ }
+ i += 1
+ }
+
+ if (continue) {
+ val newLeftST = new StructType(fields1.zip(commonTypes).map {
+ case (f, commonType) => f.copy(dataType = commonType)
+ })
+ val newLeft = if (left.dataType == newLeftST) left else
Cast(left, newLeftST)
+
+ val newRightST = new StructType(fields2.zip(commonTypes).map
{
+ case (f, commonType) => f.copy(dataType = commonType)
+ })
+ val newRight = if (right.dataType == newRightST) right else
Cast(right, newRightST)
+
+ if (b.inputType.acceptsType(newLeftST) &&
b.inputType.acceptsType(newRightST)) {
+ b.withNewChildren(Seq(newLeft, newRight))
+ } else {
+ // type not acceptable, don't do anything with the
expression.
+ b
+ }
+ } else {
+ // left struct type and right struct type have different
number of fields, or some
+ // fields don't have a common type, don't do anything with
the expression.
+ b
+ }
+
+ case _ =>
+ findTightestCommonType(left.dataType, right.dataType).map {
commonType =>
+ if (b.inputType.acceptsType(commonType)) {
+ // If the expression accepts the tightest common type,
cast to that.
+ val newLeft = if (left.dataType == commonType) left else
Cast(left, commonType)
--- End diff --
This `if (x.dataType == t) x else Cast(x, t)` expression (Scala's equivalent of a ternary) appears several times in this PR. Could we factor it out into a helper method?
```
/**
 * Returns `e` cast to `possibleType`, or `e` itself if it already has that type.
 *
 * @param e            the expression that may need a cast
 * @param possibleType the data type `e` should have after coercion
 * @return `e` unchanged when `e.dataType == possibleType`, otherwise `Cast(e, possibleType)`
 */
private def castIfNeeded(e: Expression, possibleType: DataType): Expression = {
  // Only wrap in a Cast when the types differ, so no redundant Cast node is added.
  if (e.dataType == possibleType) e else Cast(e, possibleType)
}
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]