This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 2a5f0ab [SPARK-30921][PYSPARK] Predicates on python udf should not be
pushdown through Aggregate
2a5f0ab is described below
commit 2a5f0aba73fa7a933b605a4108001dca51a91eb5
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Apr 6 09:36:20 2020 +0900
[SPARK-30921][PYSPARK] Predicates on python udf should not be pushdown
through Aggregate
### What changes were proposed in this pull request?
This patch proposes to stop pushing predicates on PythonUDFs down through
Aggregate.
### Why are the changes needed?
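Predicates on PythonUDFs cannot be pushed down through Aggregate. When such a
predicate is pushed down, the aliased aggregate output is replaced by its
defining grouped-aggregate Pandas UDF. That UDF cannot be evaluated inside a
Filter below the Aggregate, so the query fails with an error like: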
```
Caused by: java.lang.UnsupportedOperationException: Cannot generate code
for expression: mean(input[1, struct<bar:bigint>, true].bar)
at
org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:304)
at
org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:303)
at
org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:52)
at
org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
at scala.Option.getOrElse(Option.scala:189)
at
org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:141)
at
org.apache.spark.sql.catalyst.expressions.CastBase.doGenCode(Cast.scala:821)
at
org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
at scala.Option.getOrElse(Option.scala:189)
```
### Does this PR introduce any user-facing change?
Yes. Previously, predicates on PythonUDFs could be pushed down through
Aggregate and cause an error. After this change, such queries work.
### How was this patch tested?
Unit test.
Closes #28089 from viirya/SPARK-30921.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
(cherry picked from commit 1f0287148977adb416001cb0988e919a2698c8e0)
Signed-off-by: HyukjinKwon <[email protected]>
---
python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py | 17 +++++++++++++++++
.../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 5 +++--
2 files changed, 20 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 2167978..224c8ce 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -491,6 +491,23 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
agg2 = self.spark.sql("select max_udf(id) from table")
assert_frame_equal(agg1.toPandas(), agg2.toPandas())
+ def test_no_predicate_pushdown_through(self):
+ # SPARK-30921: We should not pushdown predicates of PythonUDFs through Aggregate.
+ import numpy as np
+
+ @pandas_udf('float', PandasUDFType.GROUPED_AGG)
+ def mean(x):
+ return np.mean(x)
+
+ df = self.spark.createDataFrame([
+ Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)
+ ])
+
+ agg = df.groupBy('id').agg(mean('foo').alias("mean"))
+ filtered = agg.filter(agg['mean'] > 40.0)
+
+ assert(filtered.collect()[0]["mean"] == 42.0)
+
if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 30ad6bfe..d93c4a5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1204,9 +1204,10 @@ object PushPredicateThroughNonJoin extends
Rule[LogicalPlan] with PredicateHelpe
def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
// Find all the aliased expressions in the aggregate list that don't
include any actual
- // AggregateExpression, and create a map from the alias to the expression
+ // AggregateExpression or PythonUDF, and create a map from the alias to
the expression
val aliasMap = plan.aggregateExpressions.collect {
- case a: Alias if
a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+ case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
+ PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
(a.toAttribute, a.child)
}
AttributeMap(aliasMap)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]