This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 2a5f0ab  [SPARK-30921][PYSPARK] Predicates on python udf should not be 
pushdown through Aggregate
2a5f0ab is described below

commit 2a5f0aba73fa7a933b605a4108001dca51a91eb5
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Apr 6 09:36:20 2020 +0900

    [SPARK-30921][PYSPARK] Predicates on python udf should not be pushdown 
through Aggregate
    
    ### What changes were proposed in this pull request?
    
    This patch proposes to skip pushing down predicates on PythonUDFs through
    Aggregate.
    
    ### Why are the changes needed?
    
    Predicates on PythonUDFs cannot be pushed down through Aggregate. The
    pushed-down predicates cannot be evaluated, because PythonUDFs cannot be
    evaluated inside a Filter; doing so causes an error like:
    
    ```
    Caused by: java.lang.UnsupportedOperationException: Cannot generate code 
for expression: mean(input[1, struct<bar:bigint>, true].bar)
            at 
org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:304)
            at 
org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:303)
            at 
org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:52)
            at 
org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
            at scala.Option.getOrElse(Option.scala:189)
            at 
org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:141)
            at 
org.apache.spark.sql.catalyst.expressions.CastBase.doGenCode(Cast.scala:821)
            at 
org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
            at scala.Option.getOrElse(Option.scala:189)
    ```
    
    ### Does this PR introduce any user-facing change?
    
    Yes. Previously, predicates on PythonUDFs were pushed down through
    Aggregate, which could cause an error. After this change, the query works.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #28089 from viirya/SPARK-30921.
    
    Authored-by: Liang-Chi Hsieh <[email protected]>
    Signed-off-by: HyukjinKwon <[email protected]>
    (cherry picked from commit 1f0287148977adb416001cb0988e919a2698c8e0)
    Signed-off-by: HyukjinKwon <[email protected]>
---
 python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py | 17 +++++++++++++++++
 .../apache/spark/sql/catalyst/optimizer/Optimizer.scala |  5 +++--
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py 
b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 2167978..224c8ce 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -491,6 +491,23 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
             agg2 = self.spark.sql("select max_udf(id) from table")
             assert_frame_equal(agg1.toPandas(), agg2.toPandas())
 
+    def test_no_predicate_pushdown_through(self):
+        # SPARK-30921: We should not pushdown predicates of PythonUDFs through 
Aggregate.
+        import numpy as np
+
+        @pandas_udf('float', PandasUDFType.GROUPED_AGG)
+        def mean(x):
+            return np.mean(x)
+
+        df = self.spark.createDataFrame([
+            Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)
+        ])
+
+        agg = df.groupBy('id').agg(mean('foo').alias("mean"))
+        filtered = agg.filter(agg['mean'] > 40.0)
+
+        assert(filtered.collect()[0]["mean"] == 42.0)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 30ad6bfe..d93c4a5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1204,9 +1204,10 @@ object PushPredicateThroughNonJoin extends 
Rule[LogicalPlan] with PredicateHelpe
 
   def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
     // Find all the aliased expressions in the aggregate list that don't 
include any actual
-    // AggregateExpression, and create a map from the alias to the expression
+    // AggregateExpression or PythonUDF, and create a map from the alias to 
the expression
     val aliasMap = plan.aggregateExpressions.collect {
-      case a: Alias if 
a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+      case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
+          PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
         (a.toAttribute, a.child)
     }
     AttributeMap(aliasMap)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to