This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 737a850 [SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical 737a850 is described below commit 737a85033a543613fbdc6ac5737a0d23f7e9f108 Author: yi.wu <yi...@databricks.com> AuthorDate: Sun Jun 28 21:49:10 2020 -0700 [SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical ### What changes were proposed in this pull request? This PR fixes `UserDefinedType.equals()` by comparing the UDT classes instead of checking `acceptsType()`. ### Why are the changes needed? It's weird that an equality comparison between two UDTs can give a different result when the order of the operands is switched: ```scala // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass val udt1 = new ExampleBaseTypeUDT val udt2 = new ExampleSubTypeUDT println(udt1 == udt2) // true println(udt2 == udt1) // false ``` ### Does this PR introduce _any_ user-facing change? Yes. Before: ```scala // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass val udt1 = new ExampleBaseTypeUDT val udt2 = new ExampleSubTypeUDT println(udt1 == udt2) // true println(udt2 == udt1) // false ``` After: ```scala // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass val udt1 = new ExampleBaseTypeUDT val udt2 = new ExampleSubTypeUDT println(udt1 == udt2) // false println(udt2 == udt1) // false ``` ### How was this patch tested? Added unit tests. Closes #28923 from Ngone51/fix-udt-equal. 
Authored-by: yi.wu <yi...@databricks.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- .../org/apache/spark/sql/types/UserDefinedType.scala | 2 +- .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 6af16e2..592ce03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -90,7 +90,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other match { - case that: UserDefinedType[_] => this.acceptsType(that) + case that: UserDefinedType[_] => this.getClass == that.getClass case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index ed8ab1c..7c126f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -134,6 +134,24 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF() + + test("SPARK-32090: equal") { + val udt1 = new ExampleBaseTypeUDT + val udt2 = new ExampleSubTypeUDT + val udt3 = new ExampleSubTypeUDT + assert(udt1 !== udt2) + assert(udt2 !== udt1) + assert(udt2 === udt3) + assert(udt3 === udt2) + } + + test("SPARK-32090: acceptsType") { + val udt1 = new ExampleBaseTypeUDT + val udt2 = new ExampleSubTypeUDT + assert(udt1.acceptsType(udt2)) + assert(!udt2.acceptsType(udt1)) + } 
+ test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org