Repository: spark
Updated Branches:
refs/heads/branch-2.0 8ef31fbd7 -> b8e1b7c8a
[SPARK-15888] [SQL] fix Python UDF with aggregate
## What changes were proposed in this pull request?
After we moved the ExtractPythonUDFs rule into the physical plan, Python UDFs
no longer work on top of an aggregate: they can't be evaluated before the
aggregate and must be evaluated after it. This PR adds another rule that
extracts these kinds of Python UDFs from the logical Aggregate and creates a
Project on top of the Aggregate to evaluate them.
## How was this patch tested?
Added regression tests. The plan of the added test query looks like this:
```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17,
sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
+- LogicalRDD [key#5L, value#6]
== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17,
sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
+- LogicalRDD [key#5L, value#6]
== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29,
sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
+- LogicalRDD [key#5L, value#6]
== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
+- *HashAggregate(key=[<lambda>(key#5L)#31],
functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
+- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
+- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31],
functions=[partial_sum(cast(pythonUDF1#35 as bigint))],
output=[<lambda>(key#5L)#31,sum#33L])
+- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L,
value#6, pythonUDF0#34, pythonUDF1#35]
+- Scan ExistingRDD[key#5L,value#6]
```
Author: Davies Liu <[email protected]>
Closes #13682 from davies/fix_py_udf.
(cherry picked from commit 5389013acc99367729dfc6deeb2cecc9edd1e24c)
Signed-off-by: Davies Liu <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b8e1b7c8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b8e1b7c8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b8e1b7c8
Branch: refs/heads/branch-2.0
Commit: b8e1b7c8ac7cec71ede10bfd44b0b86d8ba80af7
Parents: 8ef31fb
Author: Davies Liu <[email protected]>
Authored: Wed Jun 15 13:38:04 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Wed Jun 15 13:38:17 2016 -0700
----------------------------------------------------------------------
python/pyspark/sql/tests.py | 10 ++-
.../spark/sql/execution/SparkOptimizer.scala | 6 +-
.../execution/python/BatchEvalPythonExec.scala | 2 +
.../execution/python/ExtractPythonUDFs.scala | 70 +++++++++++++++++---
4 files changed, 77 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b8e1b7c8/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691..c631ad8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1,
"2")], ["key", "value"])
- from pyspark.sql.functions import udf, col
+ from pyspark.sql.functions import udf, col, sum
from pyspark.sql.types import BooleanType
my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
self.assertEqual(sel.collect(), [Row(key=1)])
+ my_copy = udf(lambda x: x, IntegerType())
+ my_add = udf(lambda a, b: int(a + b), IntegerType())
+ my_strlen = udf(lambda x: len(x), IntegerType())
+ sel = df.groupBy(my_copy(col("key")).alias("k"))\
+ .agg(sum(my_strlen(col("value"))).alias("s"))\
+ .select(my_add(col("k"), col("s")).alias("t"))
+ self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
http://git-wip-us.apache.org/repos/asf/spark/blob/b8e1b7c8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 08b2d7f..12a10cb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf
class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog, conf) {
- override def batches: Seq[Batch] = super.batches :+ Batch(
- "User Provided Optimizers", fixedPoint,
experimentalMethods.extraOptimizations: _*)
+ override def batches: Seq[Batch] = super.batches :+
+ Batch("Extract Python UDF from Aggregate", Once,
ExtractPythonUDFFromAggregate) :+
+ Batch("User Provided Optimizers", fixedPoint,
experimentalMethods.extraOptimizations: _*)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/b8e1b7c8/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 061d7c7..d9bf4d3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output:
Seq[Attribute], chi
def children: Seq[SparkPlan] = child :: Nil
+ override def producedAttributes: AttributeSet =
AttributeSet(output.drop(child.output.length))
+
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions,
Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/b8e1b7c8/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index ab19236..668470e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,12 +18,68 @@
package org.apache.spark.sql.execution.python
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan,
Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Extracts all the Python UDFs in logical aggregate, which depends on
aggregate expression or
+ * grouping key, evaluate them after aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+ /**
+ * Returns whether the expression could only be evaluated within aggregate.
+ */
+ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+ e.isInstanceOf[AggregateExpression] ||
+ agg.groupingExpressions.exists(_.semanticEquals(e))
+ }
+
+ private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate):
Boolean = {
+ expr.find {
+ e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_,
agg)).isDefined
+ }.isDefined
+ }
+
+ private def extract(agg: Aggregate): LogicalPlan = {
+ val projList = new ArrayBuffer[NamedExpression]()
+ val aggExpr = new ArrayBuffer[NamedExpression]()
+ agg.aggregateExpressions.foreach { expr =>
+ if (hasPythonUdfOverAggregate(expr, agg)) {
+ // Python UDF can only be evaluated after aggregate
+ val newE = expr transformDown {
+ case e: Expression if belongAggregate(e, agg) =>
+ val alias = e match {
+ case a: NamedExpression => a
+ case o => Alias(e, "agg")()
+ }
+ aggExpr += alias
+ alias.toAttribute
+ }
+ projList += newE.asInstanceOf[NamedExpression]
+ } else {
+ aggExpr += expr
+ projList += expr.toAttribute
+ }
+ }
+ // There is no Python UDF over aggregate expression
+ Project(projList, agg.copy(aggregateExpressions = aggExpr))
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case agg: Aggregate if
agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+ extract(agg)
+ }
+}
+
+
/**
* Extracts PythonUDFs from operators, rewriting the query plan so that the
UDF can be evaluated
* alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends
Rule[SparkPlan] {
}
/**
- * Extract all the PythonUDFs from the current operator.
+ * Extract all the PythonUDFs from the current operator and evaluate them
before the operator.
*/
- def extract(plan: SparkPlan): SparkPlan = {
+ private def extract(plan: SparkPlan): SparkPlan = {
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+ // ignore the PythonUDF that come from second/third aggregate, which is
not used
+ .filter(udf => udf.references.subsetOf(plan.inputSet))
if (udfs.isEmpty) {
// If there aren't any, we are done.
plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends
Rule[SparkPlan] {
// Other cases are disallowed as they are ambiguous or would require a
cartesian
// product.
udfs.filterNot(attributeMap.contains).foreach { udf =>
- if (udf.references.subsetOf(plan.inputSet)) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more
than one child.")
- } else {
- sys.error(s"Unable to evaluate PythonUDF $udf. Missing input
attributes.")
- }
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than
one child.")
}
val rewritten = plan.transformExpressions {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]