This is an automated email from the ASF dual-hosted git repository.
cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 816aba3 [SPARK-34521][PYTHON][SQL] Fix spark.createDataFrame when
using pandas with StringDtype
816aba3 is described below
commit 816aba355cf8601688ba765b9c852b7feb64d3c2
Author: Nicolas Azrak <[email protected]>
AuthorDate: Wed Dec 15 22:03:19 2021 -0800
[SPARK-34521][PYTHON][SQL] Fix spark.createDataFrame when using pandas with
StringDtype
### What changes were proposed in this pull request?
This change fixes `SPARK-34521`. It allows creating a spark DataFrame from
a pandas DataFrame that is using a `StringDtype` column and arrow pyspark
enabled.
### Why are the changes needed?
Pandas stores string columns in two different ways: using a numpy `ndarray`
or using a custom `StringArray`. The `StringArray` version is used when
specifying the `dtype=string`. When that happens, spark cannot serialize the
column to arrow. Converting the `Series` beforehand fixes this problem.
However, due to the different ways to handle string columns, doing
`spark.createDataFrame(pandas_dataframe).toPandas()` might not be equal to
`pandas_dataframe`. The column dtype could be different.
More info: https://pandas.pydata.org/docs/user_guide/text.html
### Does this PR introduce _any_ user-facing change?
Trying to create a spark `DataFrame` from a pandas `DataFrame` using a
string dtype and `"spark.sql.execution.arrow.pyspark.enabled"` now doesn't
throw an exception and returns the expected dataframe.
Before:
`spark.createDataFrame(pd.DataFrame({"A": ['a', 'b', 'c']},
dtype="string"))`
```
Error
Traceback (most recent call last):
File
"/home/nico/projects/playground/spark/python/pyspark/sql/tests/test_arrow.py",
line 415, in test_createDataFrame_with_string_dtype
print(self.spark.createDataFrame(pd.DataFrame({"A": ['a', 'b', 'c']},
dtype="string")))
File
"/home/nico/projects/playground/spark/python/pyspark/sql/session.py", line 823,
in createDataFrame
return super(SparkSession, self).createDataFrame( # type:
ignore[call-overload]
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/conversion.py",
line 358, in createDataFrame
return self._create_from_pandas_with_arrow(data, schema, timezone)
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/conversion.py",
line 550, in _create_from_pandas_with_arrow
self._sc # type: ignore[attr-defined]
File "/home/nico/projects/playground/spark/python/pyspark/context.py",
line 611, in _serialize_to_jvm
serializer.dump_stream(data, tempFile)
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
line 221, in dump_stream
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
line 81, in dump_stream
for batch in iterator:
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
line 220, in <genexpr>
batches = (self._create_batch(series) for series in iterator)
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
line 211, in _create_batch
arrs.append(create_array(s, t))
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
line 185, in create_array
raise e
File
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
line 175, in create_array
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
File "pyarrow/array.pxi", line 904, in pyarrow.lib.Array.from_pandas
File "pyarrow/array.pxi", line 252, in pyarrow.lib.array
File "pyarrow/array.pxi", line 107, in
pyarrow.lib._handle_arrow_array_protocol
ValueError: Cannot specify a mask or a size when passing an object that is
converted with the __arrow_array__ protocol.
```
After:
`spark.createDataFrame(pd.DataFrame({"A": ['a', 'b', 'c']},
dtype="string"))`
> `DataFrame[A: string]`
### How was this patch tested?
Using the `test_createDataFrame_with_string_dtype` test.
Closes #34509 from nicolasazrak/SPARK-34521.
Authored-by: Nicolas Azrak <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 5 ++++-
python/pyspark/sql/tests/test_arrow.py | 21 +++++++++++++++++++++
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index 47c98c8..992e82b 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -215,7 +215,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s
in series)
def create_array(s, t):
- mask = s.isnull()
+ if hasattr(s.array, "__arrow_array__"):
+ mask = None
+ else:
+ mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal
representation
if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
s = _check_series_convert_timestamps_internal(s,
self._timezone)
diff --git a/python/pyspark/sql/tests/test_arrow.py
b/python/pyspark/sql/tests/test_arrow.py
index 99705fb..ff42ade 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -548,6 +548,27 @@ class ArrowTests(ReusedSQLTestCase):
self.assertEqual(m, map_data[i])
self.assertEqual(m_arrow, map_data[i])
+ def test_createDataFrame_with_string_dtype(self):
+ # SPARK-34521: spark.createDataFrame does not support Pandas
StringDtype extension type
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled":
True}):
+ data = [["abc"], ["def"], [None], ["ghi"], [None]]
+ pandas_df = pd.DataFrame(data, columns=["col"], dtype="string")
+ schema = StructType([StructField("col", StringType(), True)])
+ df = self.spark.createDataFrame(pandas_df, schema=schema)
+
+ # dtypes won't match. Pandas has two different ways to store
string columns:
+ # using ndarray (when dtype isn't specified) or using a
StringArray when dtype="string".
+ # When calling dataframe#toPandas() it will use the ndarray
version.
+ # Changing that to use a StringArray would be backwards
incompatible.
+ assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
+
+ def test_createDataFrame_with_int64(self):
+ # SPARK-34521: spark.createDataFrame does not support Pandas
StringDtype extension type
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled":
True}):
+ pandas_df = pd.DataFrame({"col": [1, 2, 3, None]}, dtype="Int64")
+ df = self.spark.createDataFrame(pandas_df)
+ assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
+
def test_toPandas_with_map_type(self):
pdf = pd.DataFrame(
{"id": [0, 1, 2, 3], "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a":
1, "b": 2, "c": 3}]}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]