wangyum commented on a change in pull request #30853:
URL: https://github.com/apache/spark/pull/30853#discussion_r546583196
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
##########
@@ -542,6 +542,18 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
+      case Cast(i @ If(_, trueValue, falseValue), dataType, timeZoneId)
Review comment:
Could we split this into two PRs? We cannot use `makeCopy` for a `Cast`
expression: its constructor takes three arguments, but `makeCopy` is called
with only one, so it fails like this:
```
Is otherCopyArgs specified correctly for Cast.
Exception message: wrong number of arguments
ctor: public org.apache.spark.sql.catalyst.expressions.Cast(org.apache.spark.sql.catalyst.expressions.Expression,org.apache.spark.sql.types.DataType,scala.Option)?
types: class org.apache.spark.sql.catalyst.expressions.Literal
args: 1
, tree: cast(if ((id#1L = 1)) 1 else 2 as int)
  at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$1(TreeNode.scala:515)
  at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
  ... 129 more
```
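To illustrate the failure outside Spark, here is a minimal sketch. `Lit`, `Not1`, `Cast3` and `reflectiveCopy` are hypothetical stand-ins, not Catalyst classes or the real `makeCopy`; the sketch only assumes that the copy boils down to a reflective constructor call with the supplied arguments, which is what the trace above suggests.

```scala
object MakeCopySketch {
  // Hypothetical stand-ins for Catalyst expressions (not Spark's classes).
  sealed trait Expr
  case class Lit(value: Int) extends Expr
  // A unary expression whose constructor takes exactly one argument.
  case class Not1(child: Expr) extends Expr
  // A unary expression whose constructor takes three arguments, like Cast.
  case class Cast3(child: Expr, dataType: String, timeZoneId: Option[String]) extends Expr

  // Assumed simplification of makeCopy: invoke the first public constructor
  // reflectively with exactly the arguments that were passed in.
  def reflectiveCopy(e: Expr, args: Array[AnyRef]): Expr =
    e.getClass.getConstructors.head.newInstance(args: _*).asInstanceOf[Expr]

  // The guard proposed in this comment: only rewrite when the
  // constructor takes a single parameter.
  def hasSingleArgCtor(e: Expr): Boolean =
    e.getClass.getConstructors.headOption.forall(_.getParameterCount == 1)
}
```

With these definitions, `reflectiveCopy(Not1(Lit(1)), Array[AnyRef](Lit(2)))` succeeds, while the same call on a `Cast3` throws an `IllegalArgumentException` because one argument is supplied for a three-parameter constructor. `hasSingleArgCtor` returns `false` for `Cast3`, so the guard would skip it.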
Maybe the complete change could look like this:
```
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
    if u.getClass.getConstructors.headOption.forall(_.getParameterCount == 1) &&
      atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
  i.copy(
    trueValue = u.makeCopy(Array(trueValue)),
    falseValue = u.makeCopy(Array(falseValue)))

case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
    if u.getClass.getConstructors.headOption.forall(_.getParameterCount == 1) &&
      atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
  c.copy(
    branches.map(e => e.copy(_2 = u.makeCopy(Array(e._2)))),
    elseValue.map(e => u.makeCopy(Array(e))))

case Cast(i @ If(_, trueValue, falseValue), dataType, timeZoneId)
    if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
  i.copy(
    trueValue = Cast(trueValue, dataType, timeZoneId),
    falseValue = Cast(falseValue, dataType, timeZoneId))

case Cast(c @ CaseWhen(branches, elseValue), dataType, timeZoneId)
    if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
  c.copy(
    branches.map(e => e.copy(_2 = Cast(e._2, dataType, timeZoneId))),
    elseValue.map(e => Cast(e, dataType, timeZoneId)))
```
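For reference, the `Cast`/`If` case can be tried on a toy expression tree. This is only a sketch: the types below are simplified, hypothetical versions of the Catalyst ones, and the body of `atMostOneUnfoldable` is an assumption (fewer than two non-foldable expressions), not copied from Spark.

```scala
object PushIntoBranchesSketch {
  // Simplified, hypothetical expression model (not Spark's classes).
  sealed trait Expr { def foldable: Boolean }
  case class Lit(value: Int) extends Expr { def foldable: Boolean = true }
  case class Attr(name: String) extends Expr { def foldable: Boolean = false }
  case class If(predicate: Expr, trueValue: Expr, falseValue: Expr) extends Expr {
    def foldable: Boolean =
      predicate.foldable && trueValue.foldable && falseValue.foldable
  }
  case class Cast(child: Expr, dataType: String, timeZoneId: Option[String] = None)
      extends Expr {
    def foldable: Boolean = child.foldable
  }

  // Assumed definition: at most one of the expressions is not foldable.
  def atMostOneUnfoldable(exprs: Seq[Expr]): Boolean =
    exprs.count(e => !e.foldable) < 2

  // Push a Cast into both branches of an If by rebuilding the Cast
  // explicitly with all three arguments, so no reflective copy is needed.
  def pushCast(e: Expr): Expr = e match {
    case Cast(i @ If(_, t, f), dataType, timeZoneId)
        if atMostOneUnfoldable(Seq(t, f)) =>
      i.copy(
        trueValue = Cast(t, dataType, timeZoneId),
        falseValue = Cast(f, dataType, timeZoneId))
    case other => other
  }
}
```

In this sketch, `pushCast(Cast(If(Attr("p"), Lit(1), Lit(2)), "int"))` moves the cast onto both literals, while an `If` whose two branches are both non-foldable is returned unchanged.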