This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new c961e7c [SPARK-27917][SQL][BACKPORT-2.4] canonical form of CaseWhen object is incorrect c961e7c is described below commit c961e7caaebb3cbb593230c450b1518b5807bed6 Author: sandeep katta <sandeep.katta2...@gmail.com> AuthorDate: Tue Jun 11 07:14:24 2019 -0700 [SPARK-27917][SQL][BACKPORT-2.4] canonical form of CaseWhen object is incorrect ## What changes were proposed in this pull request? Canonicalization is not handled for CaseWhen objects. For example, consider the CaseWhen object below: val attrRef = AttributeReference("ACCESS_CHECK", StringType)() val caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A")))) caseWhenObj1.canonicalized **output** is as below: CASE WHEN ACCESS_CHECK#0 THEN A END (**Before Fix**) **After Fix**: CASE WHEN none#0 THEN A END So when an attribute reference differs only by name (aliasAttrRef), as in the statements below, semantic equals fails. Semantic equals returns true if the canonicalized forms of both expressions are the same. val attrRef = AttributeReference("ACCESS_CHECK", StringType)() val aliasAttrRef = attrRef.withName("access_check") val caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A")))) val caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A")))) **assert(caseWhenObj2.semanticEquals(caseWhenObj1)) fails** **caseWhenObj1.canonicalized** Before Fix: CASE WHEN ACCESS_CHECK#0 THEN A END After Fix: CASE WHEN none#0 THEN A END **caseWhenObj2.canonicalized** Before Fix: CASE WHEN access_check#0 THEN A END After Fix: CASE WHEN none#0 THEN A END ## How was this patch tested? Added UT Closes #24836 from sandeep-katta/spark2.4. 
Authored-by: sandeep katta <sandeep.katta2...@gmail.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 3 +++ .../expressions/ConditionalExpressionSuite.scala | 24 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index becfa8d..e3b8b08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -211,6 +211,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } def mapChild(child: Any): Any = child match { case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + // CaseWhen Case or any tuple type + case (left, right) => (mapChild(left), mapChild(right)) case nonChild: AnyRef => nonChild case null => null } @@ -226,6 +228,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // `mapValues` is lazy and we need to force it to materialize m.mapValues(mapChild).view.force case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case Some(child) => Some(mapChild(child)) case nonChild: AnyRef => nonChild case null => null } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index f489d33..2e1858c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -222,4 +222,28 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), 
Literal(-1)).genCode(ctx) assert(ctx.inlinedMutableStates.size == 1) } + + test("SPARK-27917 test semantic equals of CaseWhen") { + val attrRef = AttributeReference("ACCESS_CHECK", StringType)() + val aliasAttrRef = attrRef.withName("access_check") + // Test for Equality + var caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A")))) + var caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A")))) + assert(caseWhenObj1.semanticEquals(caseWhenObj2)) + assert(caseWhenObj2.semanticEquals(caseWhenObj1)) + // Test for inEquality + caseWhenObj2 = CaseWhen(Seq((attrRef, Literal("a")))) + assert(!caseWhenObj1.semanticEquals(caseWhenObj2)) + assert(!caseWhenObj2.semanticEquals(caseWhenObj1)) + // Test with elseValue with Equality + caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))), attrRef.withName("ELSEVALUE")) + caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A"))), aliasAttrRef.withName("elsevalue")) + assert(caseWhenObj1.semanticEquals(caseWhenObj2)) + assert(caseWhenObj2.semanticEquals(caseWhenObj1)) + caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))), Literal("ELSEVALUE")) + caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A"))), Literal("elsevalue")) + // Test with elseValue with inEquality + assert(!caseWhenObj1.semanticEquals(caseWhenObj2)) + assert(!caseWhenObj2.semanticEquals(caseWhenObj1)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org