This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new c961e7c  [SPARK-27917][SQL][BACKPORT-2.4] canonical form of CaseWhen object is incorrect
c961e7c is described below

commit c961e7caaebb3cbb593230c450b1518b5807bed6
Author: sandeep katta <sandeep.katta2...@gmail.com>
AuthorDate: Tue Jun 11 07:14:24 2019 -0700

    [SPARK-27917][SQL][BACKPORT-2.4] canonical form of CaseWhen object is incorrect
    
    ## What changes were proposed in this pull request?
    The canonical form of a CaseWhen object is computed incorrectly: the expressions inside its branches are not canonicalized.
    
    For example, consider the following CaseWhen object:
        val attrRef = AttributeReference("ACCESS_CHECK", StringType)()
        val caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))))
    
    The output of caseWhenObj1.canonicalized is:
    
    **Before Fix**: CASE WHEN ACCESS_CHECK#0 THEN A END
    
    **After Fix** : CASE WHEN none#0 THEN A END
    
    As a result, semanticEquals fails when the same attribute is referenced under an alias with a differently-cased name, as in the statements below. semanticEquals returns true only if the canonicalized forms of both expressions are the same.
    
    val attrRef = AttributeReference("ACCESS_CHECK", StringType)()
    val aliasAttrRef = attrRef.withName("access_check")
    val caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))))
    val caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A"))))
    
    **assert(caseWhenObj2.semanticEquals(caseWhenObj1)) fails**
    
    **caseWhenObj1.canonicalized**
    
    Before Fix: CASE WHEN ACCESS_CHECK#0 THEN A END
    After Fix: CASE WHEN none#0 THEN A END
    **caseWhenObj2.canonicalized**
    
    Before Fix: CASE WHEN access_check#0 THEN A END
    After Fix: CASE WHEN none#0 THEN A END
    
    ## How was this patch tested?
    Added a unit test to ConditionalExpressionSuite.
    
    Closes #24836 from sandeep-katta/spark2.4.
    
    Authored-by: sandeep katta <sandeep.katta2...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  3 +++
 .../expressions/ConditionalExpressionSuite.scala   | 24 ++++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index becfa8d..e3b8b08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -211,6 +211,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
     def mapChild(child: Any): Any = child match {
       case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
+      // Map both elements of any 2-tuple, e.g. CaseWhen's (condition, value) branches
+      case (left, right) => (mapChild(left), mapChild(right))
       case nonChild: AnyRef => nonChild
       case null => null
     }
@@ -226,6 +228,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         // `mapValues` is lazy and we need to force it to materialize
         m.mapValues(mapChild).view.force
       case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
+      case Some(child) => Some(mapChild(child))
       case nonChild: AnyRef => nonChild
       case null => null
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index f489d33..2e1858c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -222,4 +222,28 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), 
Literal(-1)).genCode(ctx)
     assert(ctx.inlinedMutableStates.size == 1)
   }
+
+  test("SPARK-27917 test semantic equals of CaseWhen") {
+    val attrRef = AttributeReference("ACCESS_CHECK", StringType)()
+    val aliasAttrRef = attrRef.withName("access_check")
+    // Equality: same attribute (same exprId) under names differing only in case
+    var caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))))
+    var caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A"))))
+    assert(caseWhenObj1.semanticEquals(caseWhenObj2))
+    assert(caseWhenObj2.semanticEquals(caseWhenObj1))
+    // Inequality: same condition, but the THEN literals differ in case ("A" vs "a")
+    caseWhenObj2 = CaseWhen(Seq((attrRef, Literal("a"))))
+    assert(!caseWhenObj1.semanticEquals(caseWhenObj2))
+    assert(!caseWhenObj2.semanticEquals(caseWhenObj1))
+    // Equality with an else value: else attributes differ only in name case
+    caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))), attrRef.withName("ELSEVALUE"))
+    caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A"))), aliasAttrRef.withName("elsevalue"))
+    assert(caseWhenObj1.semanticEquals(caseWhenObj2))
+    assert(caseWhenObj2.semanticEquals(caseWhenObj1))
+    // Inequality with an else value: the else literals differ in case
+    caseWhenObj1 = CaseWhen(Seq((attrRef, Literal("A"))), Literal("ELSEVALUE"))
+    caseWhenObj2 = CaseWhen(Seq((aliasAttrRef, Literal("A"))), Literal("elsevalue"))
+    assert(!caseWhenObj1.semanticEquals(caseWhenObj2))
+    assert(!caseWhenObj2.semanticEquals(caseWhenObj1))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to