[ https://issues.apache.org/jira/browse/SPARK-24760?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16541209#comment-16541209 ]
Mortada Mehyar commented on SPARK-24760:
----------------------------------------

[~bryanc] thanks for the example. That looks like reasonable pandas/pyarrow behavior to me, though: the pd.Series you have is of dtype 'O', and it is being converted to a float dtype, so None becomes NaN whether the conversion happens in pandas or in pyarrow.

I did find another example in pyspark that highlights the NaN handling issue without creating a pandas_udf:

{code}
mortada_mehyar ~ $ pyspark --conf "spark.sql.execution.arrow.enabled=true"

In [1]: import pandas as pd

In [2]: df = pd.DataFrame({'a': [1, 2], 'b': [float('nan'), 1.2]})

In [3]: spark.createDataFrame(df).show()
+---+----+
|  a|   b|
+---+----+
|  1|null|
|  2| 1.2|
+---+----+

mortada_mehyar ~ $ pyspark --conf "spark.sql.execution.arrow.enabled=false"

In [1]: import pandas as pd

In [2]: df = pd.DataFrame({'a': [1, 2], 'b': [float('nan'), 1.2]})

In [3]: spark.createDataFrame(df).show()
+---+---+
|  a|  b|
+---+---+
|  1|NaN|
|  2|1.2|
+---+---+
{code}

Note that the code is identical; the only difference is the value of `spark.sql.execution.arrow.enabled`, which defaults to false in Spark 2.3.1. I think this would be quite surprising behavior for users.

cc [~wesmckinn] would appreciate your input, thanks!

> Pandas UDF does not handle NaN correctly
> ----------------------------------------
>
>                 Key: SPARK-24760
>                 URL: https://issues.apache.org/jira/browse/SPARK-24760
>             Project: Spark
>          Issue Type: Bug
>          Components: PySpark
>    Affects Versions: 2.3.0, 2.3.1
>         Environment: Spark 2.3.1
> Pandas 0.23.1
>            Reporter: Mortada Mehyar
>            Priority: Minor
>
> I noticed that having `NaN` values when using the new Pandas UDF feature
> triggers a JVM exception. Not sure if this is an issue with PySpark or
> PyArrow. Here is a somewhat contrived example to showcase the problem.
> {code}
> In [1]: import pandas as pd
>    ...: from pyspark.sql.functions import lit, pandas_udf, PandasUDFType
>
> In [2]: d = [{'key': 'a', 'value': 1},
>    ...:      {'key': 'a', 'value': 2},
>    ...:      {'key': 'b', 'value': 3},
>    ...:      {'key': 'b', 'value': -2}]
>    ...: df = spark.createDataFrame(d, "key: string, value: int")
>    ...: df.show()
> +---+-----+
> |key|value|
> +---+-----+
> |  a|    1|
> |  a|    2|
> |  b|    3|
> |  b|   -2|
> +---+-----+
>
> In [3]: df_tmp = df.withColumn('new', lit(1.0))  # add a DoubleType column
>    ...: df_tmp.printSchema()
> root
>  |-- key: string (nullable = true)
>  |-- value: integer (nullable = true)
>  |-- new: double (nullable = false)
> {code}
> The Pandas UDF simply creates a new column in which negative values are set
> to a particular float. With INF it works fine:
> {code}
> In [4]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
>    ...: def func(pdf):
>    ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('inf'))
>    ...:     return pdf
>
> In [5]: df.groupby('key').apply(func).show()
> +---+-----+--------+
> |key|value|     new|
> +---+-----+--------+
> |  b|    3|     3.0|
> |  b|   -2|Infinity|
> |  a|    1|     1.0|
> |  a|    2|     2.0|
> +---+-----+--------+
> {code}
> However, if we set this value to NaN, it triggers an exception:
> {code}
> In [6]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
>    ...: def func(pdf):
>    ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('nan'))
>    ...:     return pdf
>    ...:
>    ...: df.groupby('key').apply(func).show()
> [Stage 23:======================================================> (73 + 2) / 75]
> 2018-07-07 16:26:27 ERROR Executor:91 - Exception in task 36.0 in stage 23.0 (TID 414)
> java.lang.IllegalStateException: Value at index is null
> 	at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
> 	at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
> 	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
> 	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
> 	at org.apache.spark.scheduler.Task.run(Task.scala:109)
> 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
> 	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
> 	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
> 	at java.lang.Thread.run(Thread.java:745)
> 2018-07-07 16:26:27 WARN TaskSetManager:66 - Lost task 36.0 in stage 23.0 (TID 414, localhost, executor driver): java.lang.IllegalStateException: Value at index is null
> 	at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
> 	at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
> 	at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
> 	at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
> 	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
> 	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
> 	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
> 	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
> 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
> 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
> 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
> 	at org.apache.spark.scheduler.Task.run(Task.scala:109)
> 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
> 	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
> 	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
> 	at java.lang.Thread.run(Thread.java:745)
> {code}



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org