Github user eric-maynard commented on a diff in the pull request: https://github.com/apache/spark/pull/21470#discussion_r192819128 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala --- @@ -803,18 +803,60 @@ object TypeCoercion { e.copy(left = Cast(e.left, TimestampType)) } - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType).map { commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + case b @ BinaryOperator(left, right) + if !BinaryOperator.sameType(left.dataType, right.dataType) => + (left.dataType, right.dataType) match { + case (StructType(fields1), StructType(fields2)) => + val commonTypes = scala.collection.mutable.ArrayBuffer.empty[DataType] + val len = fields1.length + var i = 0 + var continue = fields1.length == fields2.length + while (i < len && continue) { + val commonType = findTightestCommonType(fields1(i).dataType, fields2(i).dataType) + if (commonType.isDefined) { + commonTypes += commonType.get + } else { + continue = false + } + i += 1 + } + + if (continue) { + val newLeftST = new StructType(fields1.zip(commonTypes).map { + case (f, commonType) => f.copy(dataType = commonType) + }) + val newLeft = if (left.dataType == newLeftST) left else Cast(left, newLeftST) + + val newRightST = new StructType(fields2.zip(commonTypes).map { + case (f, commonType) => f.copy(dataType = commonType) + }) + val newRight = if (right.dataType == newRightST) right else Cast(right, newRightST) + + if 
(b.inputType.acceptsType(newLeftST) && b.inputType.acceptsType(newRightST)) { + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // type not acceptable, don't do anything with the expression. + b + } + } else { + // left struct type and right struct type have different number of fields, or some + // fields don't have a common type, don't do anything with the expression. + b + } + + case _ => + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) --- End diff -- This `if (x.dataType == t) x else Cast(x, t)` pattern crops up several times in this PR. Maybe we can extract it into a helper method? ``` private def castIfNeeded(e: Expression, possibleType: DataType): Expression = { if (e.dataType == possibleType) e else Cast(e, possibleType) } ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org