This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7290000d51d7 [SPARK-48035][SQL] Fix try_add/try_multiply being semantically equal to add/multiply 7290000d51d7 is described below commit 7290000d51d72ad3a3fb395a7d1975c84b8f8df4 Author: Supun Nakandala <supun.nakand...@databricks.com> AuthorDate: Tue May 7 10:02:27 2024 +0900 [SPARK-48035][SQL] Fix try_add/try_multiply being semantically equal to add/multiply ### What changes were proposed in this pull request? - This PR fixes a correctness bug in commutative operator canonicalization where we currently do not take into account the evaluation mode during operand reordering. - As a result, the following condition incorrectly evaluates to true: ``` val l1 = Literal(1) val l2 = Literal(2) val l3 = Literal(3) val expr1 = Add(Add(l1, l2), l3) val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3) expr1.semanticEquals(expr2) ``` - To fix the issue, we now reorder commutative operands only if all operators have the same evaluation mode. ### Why are the changes needed? - To fix a correctness bug. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Added unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #46307 from db-scnakandala/db-scnakandala/master. 
Authored-by: Supun Nakandala <supun.nakand...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/catalyst/expressions/Expression.scala | 14 ++++++++++++ .../sql/catalyst/expressions/arithmetic.scala | 23 ++++++++++++-------- .../catalyst/expressions/CanonicalizeSuite.scala | 25 ++++++++++++++++++++++ 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index de15ec43c4f3..2759f5a29c79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1378,6 +1378,20 @@ trait CommutativeExpression extends Expression { } reorderResult } + + /** + * Helper method to collect the evaluation mode of the commutative expressions. This is + * used by the canonicalized methods of [[Add]] and [[Multiply]] operators to ensure that + * all operands have the same evaluation mode before reordering the operands. 
+ */ + protected def collectEvalModes( + e: Expression, + f: PartialFunction[CommutativeExpression, Seq[EvalMode.Value]] + ): Seq[EvalMode.Value] = e match { + case c: CommutativeExpression if f.isDefinedAt(c) => + f(c) ++ c.children.flatMap(collectEvalModes(_, f)) + case _ => Nil + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9eecf81684ce..91c10a53af8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -452,13 +452,14 @@ case class Add( copy(left = newLeft, right = newRight) override lazy val canonicalized: Expression = { - // TODO: do not reorder consecutive `Add`s with different `evalMode` - val reorderResult = buildCanonicalizedPlan( + val evalModes = collectEvalModes(this, {case Add(_, _, evalMode) => Seq(evalMode)}) + lazy val reorderResult = buildCanonicalizedPlan( { case Add(l, r, _) => Seq(l, r) }, { case (l: Expression, r: Expression) => Add(l, r, evalMode)}, Some(evalMode) ) - if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) { + if (resolved && evalModes.forall(_ == evalMode) && reorderResult.resolved && + reorderResult.dataType == dataType) { reorderResult } else { // SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is @@ -608,12 +609,16 @@ case class Multiply( newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight) override lazy val canonicalized: Expression = { - // TODO: do not reorder consecutive `Multiply`s with different `evalMode` - buildCanonicalizedPlan( - { case Multiply(l, r, _) => Seq(l, r) }, - { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)}, - Some(evalMode) - ) + val evalModes = collectEvalModes(this, {case Multiply(_, 
_, evalMode) => Seq(evalMode)}) + if (evalModes.forall(_ == evalMode)) { + buildCanonicalizedPlan( + { case Multiply(l, r, _) => Seq(l, r) }, + { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)}, + Some(evalMode) + ) + } else { + withCanonicalizedChildren + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 3366d99dd75e..7e545d332105 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -454,4 +454,29 @@ class CanonicalizeSuite extends SparkFunSuite { // different. assert(common3.canonicalized != common4.canonicalized) } + + test("SPARK-48035: Add/Multiply operator canonicalization should take into account the" + + "evaluation mode of the operands before operand reordering") { + Seq(1, 10) map { multiCommutativeOpOptThreshold => + val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD) + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, + multiCommutativeOpOptThreshold.toString) + try { + val l1 = Literal(1) + val l2 = Literal(2) + val l3 = Literal(3) + + val expr1 = Add(Add(l1, l2), l3) + val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3) + assert(!expr1.semanticEquals(expr2)) + + val expr3 = Multiply(Multiply(l1, l2), l3) + val expr4 = Multiply(Multiply(l2, l1, EvalMode.TRY), l3) + assert(!expr3.semanticEquals(expr4)) + } finally { + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, + default.toString) + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org