[
https://issues.apache.org/jira/browse/SPARK-24760?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16541209#comment-16541209
]
Mortada Mehyar commented on SPARK-24760:
----------------------------------------
[~bryanc] thanks for the example. That looks like reasonable behavior on the part of pandas/pyarrow to me, though: the pd.Series in your example has dtype 'O' and is being converted to a float dtype, and None becomes NaN whether the conversion happens in pandas or in pyarrow.
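To illustrate (a minimal pandas-only sketch with made-up values, not the exact data from the earlier example):

```python
import math
import pandas as pd

# An object-dtype Series holding a None (illustrative values only)
s = pd.Series([1.0, None], dtype=object)
print(s.dtype)  # object

# Casting to a float dtype turns the None into NaN
converted = s.astype('float64')
assert math.isnan(converted.iloc[1])
```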
I also found another example in PySpark that highlights the NaN-handling issue without creating a pandas_udf at all:
{code}
mortada_mehyar ~ $ pyspark --conf "spark.sql.execution.arrow.enabled=true"
In [1]: import pandas as pd
In [2]: df = pd.DataFrame({'a': [1, 2], 'b': [float('nan'), 1.2]})
In [3]: spark.createDataFrame(df).show()
+---+----+
|  a|   b|
+---+----+
|  1|null|
|  2| 1.2|
+---+----+
mortada_mehyar ~ $ pyspark --conf "spark.sql.execution.arrow.enabled=false"
In [1]: import pandas as pd
In [2]: df = pd.DataFrame({'a': [1, 2], 'b': [float('nan'), 1.2]})
In [3]: spark.createDataFrame(df).show()
+---+---+
|  a|  b|
+---+---+
|  1|NaN|
|  2|1.2|
+---+---+
{code}
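For context, a minimal pure-Python sketch (no Spark involved) of why the distinction matters: NaN is a real float value while None is the absence of one, so silently mapping one to the other changes meaning.

```python
import math

nan = float('nan')

# NaN is a genuine float value, not a missing value
print(math.isnan(nan))  # True
print(nan is None)      # False

# NaN famously compares unequal to itself; None does not
print(nan == nan)       # False
print(None is None)     # True
```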
Note that the code is identical; the only difference is the value of `spark.sql.execution.arrow.enabled`, which defaults to false in Spark 2.3.1. I think this would be quite surprising behavior for users.
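One user-side mitigation (an untested sketch of a common pandas idiom, not an official workaround) is to coerce NaN to None before calling createDataFrame, so both code paths receive an explicit null:

```python
import pandas as pd

df = pd.DataFrame({'a': [1, 2], 'b': [float('nan'), 1.2]})

# Cast to object first so None can be stored, then replace NaN with None
normalized = df.astype(object).where(df.notna(), None)
print(normalized['b'].tolist())

# spark.createDataFrame(normalized) should then see nulls regardless of
# the Arrow setting (assumption based on the behavior shown above)
```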
cc [~wesmckinn] would appreciate your input, thanks!
> Pandas UDF does not handle NaN correctly
> ----------------------------------------
>
> Key: SPARK-24760
> URL: https://issues.apache.org/jira/browse/SPARK-24760
> Project: Spark
> Issue Type: Bug
> Components: PySpark
> Affects Versions: 2.3.0, 2.3.1
> Environment: Spark 2.3.1
> Pandas 0.23.1
> Reporter: Mortada Mehyar
> Priority: Minor
>
> I noticed that having `NaN` values when using the new Pandas UDF feature
> triggers a JVM exception. I'm not sure whether this is an issue with PySpark
> or PyArrow. Here is a somewhat contrived example to showcase the problem.
> {code}
> In [1]: import pandas as pd
> ...: from pyspark.sql.functions import lit, pandas_udf, PandasUDFType
> In [2]: d = [{'key': 'a', 'value': 1},
>    ...:      {'key': 'a', 'value': 2},
>    ...:      {'key': 'b', 'value': 3},
>    ...:      {'key': 'b', 'value': -2}]
>    ...: df = spark.createDataFrame(d, "key: string, value: int")
>    ...: df.show()
> +---+-----+
> |key|value|
> +---+-----+
> |  a|    1|
> |  a|    2|
> |  b|    3|
> |  b|   -2|
> +---+-----+
> In [3]: df_tmp = df.withColumn('new', lit(1.0))  # add a DoubleType column
>    ...: df_tmp.printSchema()
> root
>  |-- key: string (nullable = true)
>  |-- value: integer (nullable = true)
>  |-- new: double (nullable = false)
> {code}
> The Pandas UDF simply creates a new column in which negative values are
> replaced with a particular float, in this case infinity, and it works fine:
> {code}
> In [4]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
>    ...: def func(pdf):
>    ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('inf'))
>    ...:     return pdf
> In [5]: df.groupby('key').apply(func).show()
> +---+-----+--------+
> |key|value|     new|
> +---+-----+--------+
> |  b|    3|     3.0|
> |  b|   -2|Infinity|
> |  a|    1|     1.0|
> |  a|    2|     2.0|
> +---+-----+--------+
> {code}
> However, if we set this value to NaN, it triggers an exception:
> {code}
> In [6]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
>    ...: def func(pdf):
>    ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('nan'))
>    ...:     return pdf
>    ...:
>    ...: df.groupby('key').apply(func).show()
> [Stage 23:======================================================> (73 + 2) / 75]
> 2018-07-07 16:26:27 ERROR Executor:91 - Exception in task 36.0 in stage 23.0 (TID 414)
> java.lang.IllegalStateException: Value at index is null
> 	at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
> 	at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
> 	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
> 	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
> 	at org.apache.spark.scheduler.Task.run(Task.scala:109)
> 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
> 	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
> 	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
> 	at java.lang.Thread.run(Thread.java:745)
> 2018-07-07 16:26:27 WARN TaskSetManager:66 - Lost task 36.0 in stage 23.0 (TID 414, localhost, executor driver): java.lang.IllegalStateException: Value at index is null
> 	at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
> 	at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
> 	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
> 	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
> 	at org.apache.spark.scheduler.Task.run(Task.scala:109)
> 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
> 	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
> 	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
> 	at java.lang.Thread.run(Thread.java:745)
> {code}
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]