Dat Nguyen created ARROW-5016:
---------------------------------
Summary: Failed to convert 'float' to 'double' when using
pandas_udf and pyspark
Key: ARROW-5016
URL: https://issues.apache.org/jira/browse/ARROW-5016
Project: Apache Arrow
Issue Type: Bug
Components: Python
Affects Versions: 0.12.1
Environment: Linux 68b0517ddf1c 3.10.0-862.11.6.el7.x86_64 #1 SMP
GNU/Linux
Reporter: Dat Nguyen
Hi everyone,
I would like to report a (potential) bug. I followed the official
[Usage Guide for Pandas with Apache Arrow|https://spark.apache.org/docs/2.4.0/sql-pyspark-pandas-with-arrow.html].
However, `libarrow` throws an error for the type conversion from float to double.
Here is the example and its output.
pyarrow==0.12.1
{code:title=reproduce_bug.py}
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, col

spark = SparkSession.builder.appName('ReproduceBug').getOrCreate()

df = spark.createDataFrame(
    [(1, "a"), (1, "a"), (1, "b")],
    ("id", "value"))
df.show()
# Spark DataFrame
# +---+-----+
# | id|value|
# +---+-----+
# |  1|    a|
# |  1|    a|
# |  1|    b|
# +---+-----+

# Potential bug: this aggregation fails during Arrow serialization
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def compute_frequencies(sha256):
    total = sha256.count()
    per_groups = sha256.groupby(sha256).transform('count')
    score = per_groups / total
    return score

df.groupBy("id") \
    .agg(compute_frequencies(col('value'))) \
    .show()

spark.stop()
{code}
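For reference, the UDF body can be run against one group's values in plain pandas (my own sketch; `values` below stands for the "value" column of the single `id == 1` group). It returns a length-3 Series rather than a single aggregated scalar:

{code:title=udf_in_plain_pandas.py}
import pandas as pd

# The "value" column for the only group (id == 1)
values = pd.Series(["a", "a", "b"])

total = values.count()                                  # 3
per_groups = values.groupby(values).transform("count")  # 2, 2, 1
score = per_groups / total

print(score.tolist())  # [0.666..., 0.666..., 0.333...]
print(len(score))      # 3 -- a Series, not one scalar per group
{code}

These are the same three values that appear in the error output below.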
{code:title=output}
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-3-d4f781f64db1> in <module>
     32
     33 df.groupBy("id")\
---> 34     .agg(compute_frequencies(col('value')))\
     35     .show()
     36

/usr/local/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    376         """
    377         if isinstance(truncate, bool) and truncate:
--> 378             print(self._jdf.showString(n, 20, vertical))
    379         else:
    380             print(self._jdf.showString(n, int(truncate), vertical))

/usr/local/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258
   1259         for temp_arg in temp_args:

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/local/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o186.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 44 in stage 23.0 failed 1 times, most recent failure: Lost task 44.0 in stage 23.0 (TID 601, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 284, in dump_stream
    batch = _create_batch(series, self._timezone)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 253, in _create_batch
    arrs = [create_array(s, t) for s, t in series]
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 253, in <listcomp>
    arrs = [create_array(s, t) for s, t in series]
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 251, in create_array
    return pa.Array.from_pandas(s, mask=mask, type=t)
  File "pyarrow/array.pxi", line 536, in pyarrow.lib.Array.from_pandas
  File "pyarrow/array.pxi", line 176, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 85, in pyarrow.lib._ndarray_to_array
  File "pyarrow/error.pxi", line 81, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert 0    0.666667
1    0.666667
2    0.333333
Name: _0, dtype: float64 with type Series: tried to convert to double
{code}
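The failing call in the traceback is `pa.Array.from_pandas(s, mask=mask, type=t)` inside `serializers.py`. The standalone sketch below (my own, not Spark code) shows that pyarrow converts a flat float64 Series to double without complaint, and that the ArrowInvalid above matches what happens when the object handed to pyarrow is a Series whose elements are themselves Series, i.e. when the GROUPED_AGG result is not a scalar:

{code:title=pyarrow_conversion_sketch.py}
import pandas as pd
import pyarrow as pa

# A flat float64 Series converts to Arrow double without error.
flat = pd.Series([0.666667, 0.666667, 0.333333])
arr = pa.Array.from_pandas(flat, type=pa.float64())
print(arr.type)  # double

# An object Series whose single element is itself a Series -- the shape
# produced when each group's aggregation result is a Series, not a scalar.
nested = pd.Series([flat])
try:
    pa.Array.from_pandas(nested, type=pa.float64())
except pa.ArrowInvalid as e:
    print(e)  # Could not convert ... with type Series: tried to convert to double
{code}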
Please let me know if you would like any further information.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)