Repository: spark Updated Branches: refs/heads/master 803e7f087 -> adcb7d335
[SPARK-3855][SQL] Preserve the result attribute of python UDFs through transformations In the current implementation it was possible for the reference to change after analysis. Author: Michael Armbrust <[email protected]> Closes #2717 from marmbrus/pythonUdfResults and squashes the following commits: da14879 [Michael Armbrust] Fix test 6343bcb [Michael Armbrust] add test 9533286 [Michael Armbrust] Correctly preserve the result attribute of python UDFs through transformations Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/adcb7d33 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/adcb7d33 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/adcb7d33 Branch: refs/heads/master Commit: adcb7d3350032dda69a43de724c8bdff5fef2c67 Parents: 803e7f0 Author: Michael Armbrust <[email protected]> Authored: Fri Oct 17 14:12:07 2014 -0700 Committer: Patrick Wendell <[email protected]> Committed: Fri Oct 17 14:12:07 2014 -0700 ---------------------------------------------------------------------- python/pyspark/tests.py | 6 ++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/execution/pythonUdfs.scala | 12 ++++++++++-- 3 files changed, 17 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/adcb7d33/python/pyspark/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index ceab574..f5ccf31 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -683,6 +683,12 @@ class SQLTests(ReusedPySparkTestCase): [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") 
+ [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(u"4", res[0]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) http://git-wip-us.apache.org/repos/asf/spark/blob/adcb7d33/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4f1af72..79e4ddb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -295,7 +295,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child) => + case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil http://git-wip-us.apache.org/repos/asf/spark/blob/adcb7d33/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 0977da3..be729e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -105,13 +105,21 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } } +object EvaluatePython { + def apply(udf: PythonUDF, 
child: LogicalPlan) = + new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) +} + /** * :: DeveloperApi :: * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ @DeveloperApi -case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { - val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { def output = child.output :+ resultAttribute } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
