This is an automated email from the ASF dual-hosted git repository.
yamamuro pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new ffda450 [SPARK-32090][SQL] Improve UserDefinedType.equal() to make it
be symmetrical
ffda450 is described below
commit ffda450f7e7e07996de35dec5c6f060ddf74c2b6
Author: yi.wu <[email protected]>
AuthorDate: Sun Jun 28 21:49:10 2020 -0700
[SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical
### What changes were proposed in this pull request?
This PR fixes `UserDefinedType.equals()` by comparing the UDT classes instead
of checking `acceptsType()`.
### Why are the changes needed?
An equality comparison between two UDT types can give a different result when
the order of the operands is switched, which violates the symmetry requirement
of the `equals` contract:
```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // true
println(udt2 == udt1) // false
```
### Does this PR introduce _any_ user-facing change?
Yes.
Before:
```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // true
println(udt2 == udt1) // false
```
After:
```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // false
println(udt2 == udt1) // false
```
### How was this patch tested?
Added unit tests for `equals` and `acceptsType`.
Closes #28923 from Ngone51/fix-udt-equal.
Authored-by: yi.wu <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../org/apache/spark/sql/types/UserDefinedType.scala | 2 +-
.../org/apache/spark/sql/UserDefinedTypeSuite.scala | 18 ++++++++++++++++++
2 files changed, 19 insertions(+), 1 deletion(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 6af16e2..592ce03 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -90,7 +90,7 @@ abstract class UserDefinedType[UserType >: Null] extends
DataType with Serializa
override def hashCode(): Int = getClass.hashCode()
override def equals(other: Any): Boolean = other match {
- case that: UserDefinedType[_] => this.acceptsType(that)
+ case that: UserDefinedType[_] => this.getClass == that.getClass
case _ => false
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 4e74e92..a278841 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -183,6 +183,24 @@ class UserDefinedTypeSuite extends QueryTest with
SharedSQLContext with ParquetT
MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
+
+ test("SPARK-32090: equal") {
+ val udt1 = new ExampleBaseTypeUDT
+ val udt2 = new ExampleSubTypeUDT
+ val udt3 = new ExampleSubTypeUDT
+ assert(udt1 !== udt2)
+ assert(udt2 !== udt1)
+ assert(udt2 === udt3)
+ assert(udt3 === udt2)
+ }
+
+ test("SPARK-32090: acceptsType") {
+ val udt1 = new ExampleBaseTypeUDT
+ val udt2 = new ExampleSubTypeUDT
+ assert(udt1.acceptsType(udt2))
+ assert(!udt2.acceptsType(udt1))
+ }
+
test("register user type: MyDenseVector for MyLabeledPoint") {
val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v:
Double) => v }
val labelsArrays: Array[Double] = labels.collect()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]