Repository: spark
Updated Branches:
  refs/heads/master 803e7f087 -> adcb7d335


[SPARK-3855][SQL] Preserve the result attribute of python UDFs through 
transformations

In the current implementation it was possible for the reference to change after 
analysis.

Author: Michael Armbrust <[email protected]>

Closes #2717 from marmbrus/pythonUdfResults and squashes the following commits:

da14879 [Michael Armbrust] Fix test
6343bcb [Michael Armbrust] add test
9533286 [Michael Armbrust] Correctly preserve the result attribute of python 
UDFs through transformations


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/adcb7d33
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/adcb7d33
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/adcb7d33

Branch: refs/heads/master
Commit: adcb7d3350032dda69a43de724c8bdff5fef2c67
Parents: 803e7f0
Author: Michael Armbrust <[email protected]>
Authored: Fri Oct 17 14:12:07 2014 -0700
Committer: Patrick Wendell <[email protected]>
Committed: Fri Oct 17 14:12:07 2014 -0700

----------------------------------------------------------------------
 python/pyspark/tests.py                                 |  6 ++++++
 .../apache/spark/sql/execution/SparkStrategies.scala    |  2 +-
 .../org/apache/spark/sql/execution/pythonUdfs.scala     | 12 ++++++++++--
 3 files changed, 17 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/adcb7d33/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index ceab574..f5ccf31 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -683,6 +683,12 @@ class SQLTests(ReusedPySparkTestCase):
         [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
+    def test_udf2(self):
+        self.sqlCtx.registerFunction("strlen", lambda string: len(string))
+        
self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 
1").collect()
+        self.assertEqual(u"4", res[0])
+
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)

http://git-wip-us.apache.org/repos/asf/spark/blob/adcb7d33/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4f1af72..79e4ddb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -295,7 +295,7 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
         execution.Exchange(HashPartitioning(expressions, numPartitions), 
planLater(child)) :: Nil
-      case e @ EvaluatePython(udf, child) =>
+      case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
       case _ => Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/adcb7d33/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 0977da3..be729e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -105,13 +105,21 @@ private[spark] object ExtractPythonUdfs extends 
Rule[LogicalPlan] {
   }
 }
 
+object EvaluatePython {
+  def apply(udf: PythonUDF, child: LogicalPlan) =
+    new EvaluatePython(udf, child, AttributeReference("pythonUDF", 
udf.dataType)())
+}
+
 /**
  * :: DeveloperApi ::
  * Evaluates a [[PythonUDF]], appending the result to the end of the input 
tuple.
  */
 @DeveloperApi
-case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends 
logical.UnaryNode {
-  val resultAttribute = AttributeReference("pythonUDF", udf.dataType, 
nullable=true)()
+case class EvaluatePython(
+    udf: PythonUDF,
+    child: LogicalPlan,
+    resultAttribute: AttributeReference)
+  extends logical.UnaryNode {
 
   def output = child.output :+ resultAttribute
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to