Repository: spark
Updated Branches:
  refs/heads/master 0cbadf28c -> 0f90f4e6a


[SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns.

## What changes were proposed in this pull request?

This PR adds equality operators to UDT classes so that they can be correctly 
tested for dataType equality during union operations.

Previously, trying to unionAll two DataFrames with UDT columns, as shown below, 
failed with `AnalysisException: u"unresolved operator 'Union;"`.

```
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql import types

schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)])

a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema)
b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema)

c = a.unionAll(b)
```

## How was this patch tested?

Tested using two unit tests: one in python/pyspark/sql/tests.py and one in DataFrameSuite.

Additional information here : https://issues.apache.org/jira/browse/SPARK-13410

Author: Franklyn D'souza <[email protected]>

Closes #11279 from damnMeddlingKid/udt-union-all.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f90f4e6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f90f4e6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f90f4e6

Branch: refs/heads/master
Commit: 0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459
Parents: 0cbadf2
Author: Franklyn D'souza <[email protected]>
Authored: Sun Feb 21 16:58:17 2016 -0800
Committer: Reynold Xin <[email protected]>
Committed: Sun Feb 21 16:58:17 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                       | 18 ++++++++++++++++++
 .../apache/spark/sql/types/UserDefinedType.scala  | 10 ++++++++++
 .../apache/spark/sql/test/ExamplePointUDT.scala   |  7 ++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala     | 16 ++++++++++++++++
 4 files changed, 50 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f90f4e6/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e30aa0a..cc11c0f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -601,6 +601,24 @@ class SQLTests(ReusedPySparkTestCase):
         point = df1.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
+    def test_unionAll_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row1 = (1.0, ExamplePoint(1.0, 2.0))
+        row2 = (2.0, ExamplePoint(3.0, 4.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df1 = self.sqlCtx.createDataFrame([row1], schema)
+        df2 = self.sqlCtx.createDataFrame([row2], schema)
+
+        result = df1.unionAll(df2).orderBy("label").collect()
+        self.assertEqual(
+            result,
+            [
+                Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
+                Row(label=2.0, point=ExamplePoint(3.0, 4.0))
+            ]
+        )
+
     def test_column_operators(self):
         ci = self.df.key
         cs = self.df.value

http://git-wip-us.apache.org/repos/asf/spark/blob/0f90f4e6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index d7a2c23..7664c30 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -86,6 +86,11 @@ abstract class UserDefinedType[UserType] extends DataType 
with Serializable {
     this.getClass == dataType.getClass
 
   override def sql: String = sqlType.sql
+
+  override def equals(other: Any): Boolean = other match {
+    case that: UserDefinedType[_] => this.acceptsType(that)
+    case _ => false
+  }
 }
 
 /**
@@ -112,4 +117,9 @@ private[sql] class PythonUserDefinedType(
       ("serializedClass" -> serializedPyClass) ~
       ("sqlType" -> sqlType.jsonValue)
   }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
+    case _ => false
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0f90f4e6/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 20a17ba..e2c9fc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
  * @param y y coordinate
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double) extends 
Serializable
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends 
Serializable {
+  override def equals(other: Any): Boolean = other match {
+    case that: ExamplePoint => this.x == that.x && this.y == that.y
+    case _ => false
+  }
+}
 
 /**
  * User-defined type for [[ExamplePoint]].

http://git-wip-us.apache.org/repos/asf/spark/blob/0f90f4e6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 50a2464..4930c48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -112,6 +112,22 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
     )
   }
 
+  test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
+    val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 
2.0))))
+    val schema1 = StructType(Array(StructField("label", IntegerType, false),
+                    StructField("point", new ExamplePointUDT(), false)))
+    val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 
4.0))))
+    val schema2 = StructType(Array(StructField("label", IntegerType, false),
+                    StructField("point", new ExamplePointUDT(), false)))
+    val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
+    val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+
+    checkAnswer(
+      df1.unionAll(df2).orderBy("label"),
+      Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 
4.0)))
+    )
+  }
+
   test("empty data frame") {
     assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
     assert(sqlContext.emptyDataFrame.count() === 0)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to