Michael Tong created SPARK-27691:
------------------------------------
Summary: Issue when running queries using filter predicates on pandas GROUPED_AGG udfs
Key: SPARK-27691
URL: https://issues.apache.org/jira/browse/SPARK-27691
Project: Spark
Issue Type: Bug
Components: Input/Output
Affects Versions: 2.4.2
Reporter: Michael Tong
I am currently running pyspark 2.4.2 and am unable to run the following code.
{code:java}
from pyspark.sql import SparkSession, functions, types
import pandas as pd
import random

spark = SparkSession.builder.getOrCreate()

# initialize test data: a constant-False boolean column and a random
# grouping key of 0 or 1
test_data = [[False, int(random.random() * 2)] for i in range(10000)]
test_data = pd.DataFrame(test_data, columns=['bool_value', 'int_value'])

# grouped-aggregate pandas udf: does any value in the group evaluate to True?
pandas_any_udf = functions.pandas_udf(lambda x: x.any(), types.BooleanType(),
                                      functions.PandasUDFType.GROUPED_AGG)

# create spark DataFrame and build the query
test_df = spark.createDataFrame(test_data)
test_df = test_df.groupby('int_value').agg(pandas_any_udf('bool_value').alias('bool_any_result'))
test_df = test_df.filter(functions.col('bool_any_result') == True)

# write to output
test_df.write.parquet('/tmp/mtong/write_test', mode='overwrite')
{code}
Below is a truncated error message.
{code:java}
Py4JJavaError: An error occurred while calling o1125.parquet. :
org.apache.spark.SparkException: Job aborted.
...
Exchange hashpartitioning(int_value#123L, 2000)
+- *(1) Filter (<lambda>(bool_value#122) = true)
+- Scan ExistingRDD arrow[bool_value#122,int_value#123L]
...
Caused by: java.lang.UnsupportedOperationException: Cannot evaluate expression: <lambda>(input[0, boolean, true])
{code}
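The misplaced filter is already visible before the job runs: calling explain() (a standard DataFrame method) on the final DataFrame prints a plan in which the <lambda> predicate sits below the aggregation, exactly as in the fragment above.
{code:java}
# print the parsed, analyzed, optimized, and physical plans
# without actually running the job
test_df.explain(True)
{code}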
What appears to be happening is that the query optimizer incorrectly pushes the filter predicate on bool_any_result down past the group by, so the <lambda> aggregate expression is evaluated as a row-level filter on bool_value. This causes the query to error out before Spark ever executes the job. I have also run a variant of this query with functions.count() as the aggregation function, and it ran fine, so I believe this is an error that only affects pandas udfs.
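As a possible workaround (a sketch; I am not certain it fixes the planner behavior), Spark 2.4 supports excluding individual optimizer rules through spark.sql.optimizer.excludedRules. Excluding the predicate pushdown rule should keep the filter above the aggregate, at the cost of losing pushdown for every other query in the session:
{code:java}
# sketch: exclude the catalyst predicate pushdown rule for this session
# (assumes PushDownPredicate is the rule that moves the filter; this
# disables pushdown globally, so unset it once the query has run)
spark.conf.set(
    'spark.sql.optimizer.excludedRules',
    'org.apache.spark.sql.catalyst.optimizer.PushDownPredicate')
{code}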
Variant of the query with a standard aggregation function (this one runs fine):
{code:java}
# same query shape, but with a built-in aggregate instead of a pandas udf
test_df = spark.createDataFrame(test_data)
test_df = test_df.groupby('int_value').agg(functions.count('bool_value').alias('bool_counts'))
test_df = test_df.filter(functions.col('bool_counts') > 0)
{code}
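Another untested sketch of a workaround: materialize the aggregated result (for example by caching it and forcing a count) before applying the filter, so the filter is planned on top of the in-memory relation rather than pushed below the aggregation:
{code:java}
# sketch: cache and materialize the aggregate first, then filter;
# whether this reliably blocks the pushdown is an assumption
agg_df = spark.createDataFrame(test_data) \
    .groupby('int_value') \
    .agg(pandas_any_udf('bool_value').alias('bool_any_result'))
agg_df.cache()
agg_df.count()  # force materialization
filtered_df = agg_df.filter(functions.col('bool_any_result') == True)
{code}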