This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new ffda450  [SPARK-32090][SQL] Improve UserDefinedType.equal() to make it 
be symmetrical
ffda450 is described below

commit ffda450f7e7e07996de35dec5c6f060ddf74c2b6
Author: yi.wu <yi...@databricks.com>
AuthorDate: Sun Jun 28 21:49:10 2020 -0700

    [SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical
    
    ### What changes were proposed in this pull request?
    
    This PR fixes `UserDefinedType.equal()` by comparing the UDT class instead of 
checking `acceptsType()`.
    
    ### Why are the changes needed?
    
    It's weird that equality comparison between two UDT types can have 
different results by switching the order:
    
    ```scala
    // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
    val udt1 = new ExampleBaseTypeUDT
    val udt2 = new ExampleSubTypeUDT
    println(udt1 == udt2) // true
    println(udt2 == udt1) // false
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    Before:
    ```scala
    // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
    val udt1 = new ExampleBaseTypeUDT
    val udt2 = new ExampleSubTypeUDT
    println(udt1 == udt2) // true
    println(udt2 == udt1) // false
    ```
    
    After:
    ```scala
    // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
    val udt1 = new ExampleBaseTypeUDT
    val udt2 = new ExampleSubTypeUDT
    println(udt1 == udt2) // false
    println(udt2 == udt1) // false
    ```
    
    ### How was this patch tested?
    
    Added a unit test.
    
    Closes #28923 from Ngone51/fix-udt-equal.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../org/apache/spark/sql/types/UserDefinedType.scala   |  2 +-
 .../org/apache/spark/sql/UserDefinedTypeSuite.scala    | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 6af16e2..592ce03 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -90,7 +90,7 @@ abstract class UserDefinedType[UserType >: Null] extends 
DataType with Serializa
   override def hashCode(): Int = getClass.hashCode()
 
   override def equals(other: Any): Boolean = other match {
-    case that: UserDefinedType[_] => this.acceptsType(that)
+    case that: UserDefinedType[_] => this.getClass == that.getClass
     case _ => false
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 4e74e92..a278841 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -183,6 +183,24 @@ class UserDefinedTypeSuite extends QueryTest with 
SharedSQLContext with ParquetT
     MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
     MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
 
+
+  test("SPARK-32090: equal") {
+    val udt1 = new ExampleBaseTypeUDT
+    val udt2 = new ExampleSubTypeUDT
+    val udt3 = new ExampleSubTypeUDT
+    assert(udt1 !== udt2)
+    assert(udt2 !== udt1)
+    assert(udt2 === udt3)
+    assert(udt3 === udt2)
+  }
+
+  test("SPARK-32090: acceptsType") {
+    val udt1 = new ExampleBaseTypeUDT
+    val udt2 = new ExampleSubTypeUDT
+    assert(udt1.acceptsType(udt2))
+    assert(!udt2.acceptsType(udt1))
+  }
+
   test("register user type: MyDenseVector for MyLabeledPoint") {
     val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: 
Double) => v }
     val labelsArrays: Array[Double] = labels.collect()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to