Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/22857#discussion_r236098841
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
---
@@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] {
flattenConcats(concat)
}
}
+
+/**
+ * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further
optimizations.
+ *
+ * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover,
it transforms predicates
+ * in all [[If]] expressions as well as branch conditions in all
[[CaseWhen]] expressions.
+ *
+ * For example, `Filter(Literal(null, _))` is equal to
`Filter(FalseLiteral)`.
+ *
+ * Another example containing branches is `Filter(If(cond, FalseLiteral,
Literal(null, _)))`;
+ * this can be optimized to `Filter(If(cond, FalseLiteral,
FalseLiteral))`, and eventually
+ * `Filter(FalseLiteral)`.
+ *
+ * As this rule is not limited to conditions in [[Filter]] and [[Join]],
arbitrary plans can
+ * benefit from it. For example, `Project(If(And(cond, Literal(null)),
Literal(1), Literal(2)))`
+ * can be simplified into `Project(Literal(2))`.
+ *
+ * As a result, many unnecessary computations can be removed in the query
optimization phase.
+ */
+object ReplaceNullWithFalse extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case f @ Filter(cond, _) => f.copy(condition =
replaceNullWithFalse(cond))
+ case j @ Join(_, _, _, Some(cond)) => j.copy(condition =
Some(replaceNullWithFalse(cond)))
+ case p: LogicalPlan => p transformExpressions {
+ case i @ If(pred, _, _) => i.copy(predicate =
replaceNullWithFalse(pred))
+ case cw @ CaseWhen(branches, _) =>
+ val newBranches = branches.map { case (cond, value) =>
+ replaceNullWithFalse(cond) -> value
+ }
+ cw.copy(branches = newBranches)
+ }
+ }
+
+ /**
+ * Recursively replaces `Literal(null, _)` with `FalseLiteral`.
+ *
+ * Note that `transformExpressionsDown` can not be used here as we must
stop as soon as we hit
+ * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or
`Literal(null, _)`.
+ */
+ private def replaceNullWithFalse(e: Expression): Expression = e match {
+ case cw: CaseWhen if cw.dataType == BooleanType =>
+ val newBranches = cw.branches.map { case (cond, value) =>
+ replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
+ }
+ val newElseValue = cw.elseValue.map(replaceNullWithFalse)
+ CaseWhen(newBranches, newElseValue)
+ case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
+ If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal),
replaceNullWithFalse(falseVal))
+ case And(left, right) =>
+ And(replaceNullWithFalse(left), replaceNullWithFalse(right))
+ case Or(left, right) =>
+ Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
+ case Literal(null, _) => FalseLiteral
--- End diff --
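Here, for safety, we should also check the data type: match `Literal(null, BooleanType)` instead of `Literal(null, _)`, so that only boolean-typed null literals are replaced with `FalseLiteral`.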
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]