Repository: spark
Updated Branches:
refs/heads/branch-1.6 0784e02fd -> 573a2c97e
[SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns.
## What changes were proposed in this pull request?
This PR adds equality operators to UDT classes so that they can be correctly
tested for dataType equality during union operations.
This was previously causing `AnalysisException: u"unresolved operator
'Union;"` when trying to unionAll two DataFrames with UDT columns, as below.
```
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql import types
schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)])
a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema)
b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema)
c = a.unionAll(b)
```
## How was this patch tested?
Tested using two unit tests in python/pyspark/sql/tests.py and the DataFrameSuite.
Additional information here : https://issues.apache.org/jira/browse/SPARK-13410
rxin
Author: Franklyn D'souza <[email protected]>
Closes #11333 from damnMeddlingKid/udt-union-patch.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/573a2c97
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/573a2c97
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/573a2c97
Branch: refs/heads/branch-1.6
Commit: 573a2c97e9a9b8feae22f8af173fb158d59e5332
Parents: 0784e02
Author: Franklyn D'souza <[email protected]>
Authored: Tue Feb 23 15:34:04 2016 -0800
Committer: Reynold Xin <[email protected]>
Committed: Tue Feb 23 15:34:04 2016 -0800
----------------------------------------------------------------------
python/pyspark/sql/tests.py | 18 ++++++++++++++++++
.../apache/spark/sql/types/UserDefinedType.scala | 10 ++++++++++
.../apache/spark/sql/test/ExamplePointUDT.scala | 7 ++++++-
.../org/apache/spark/sql/DataFrameSuite.scala | 16 ++++++++++++++++
4 files changed, 50 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/573a2c97/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6b06984..0dc4274 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -601,6 +601,24 @@ class SQLTests(ReusedPySparkTestCase):
point = df1.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+ def test_unionAll_with_udt(self):
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+ row1 = (1.0, ExamplePoint(1.0, 2.0))
+ row2 = (2.0, ExamplePoint(3.0, 4.0))
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ df1 = self.sqlCtx.createDataFrame([row1], schema)
+ df2 = self.sqlCtx.createDataFrame([row2], schema)
+
+ result = df1.unionAll(df2).orderBy("label").collect()
+ self.assertEqual(
+ result,
+ [
+ Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
+ Row(label=2.0, point=ExamplePoint(3.0, 4.0))
+ ]
+ )
+
def test_column_operators(self):
ci = self.df.key
cs = self.df.value
http://git-wip-us.apache.org/repos/asf/spark/blob/573a2c97/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 4305903..6f092c9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -84,6 +84,11 @@ abstract class UserDefinedType[UserType] extends DataType
with Serializable {
override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass
+
+ override def equals(other: Any): Boolean = other match {
+ case that: UserDefinedType[_] => this.acceptsType(that)
+ case _ => false
+ }
}
/**
@@ -110,4 +115,9 @@ private[sql] class PythonUserDefinedType(
("serializedClass" -> serializedPyClass) ~
("sqlType" -> sqlType.jsonValue)
}
+
+ override def equals(other: Any): Boolean = other match {
+ case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
+ case _ => false
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/573a2c97/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 8d4854b..4aaf321 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
* @param y y coordinate
*/
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double) extends
Serializable
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends
Serializable {
+ override def equals(other: Any): Boolean = other match {
+ case that: ExamplePoint => this.x == that.x && this.y == that.y
+ case _ => false
+ }
+}
/**
* User-defined type for [[ExamplePoint]].
http://git-wip-us.apache.org/repos/asf/spark/blob/573a2c97/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f2ac560..8994556 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -516,6 +516,22 @@ class DataFrameSuite extends QueryTest with
SharedSQLContext {
}
}
+ test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
+ val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0,
2.0))))
+ val schema1 = StructType(Array(StructField("label", IntegerType, false),
+ StructField("point", new ExamplePointUDT(), false)))
+ val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0,
4.0))))
+ val schema2 = StructType(Array(StructField("label", IntegerType, false),
+ StructField("point", new ExamplePointUDT(), false)))
+ val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
+ val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+
+ checkAnswer(
+ df1.unionAll(df2).orderBy("label"),
+ Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0,
4.0)))
+ )
+ }
+
ignore("show") {
// This test case is intended ignored, but to make sure it compiles
correctly
testData.select($"*").show()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]