This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 816aba3  [SPARK-34521][PYTHON][SQL] Fix spark.createDataFrame when 
using pandas with StringDtype
816aba3 is described below

commit 816aba355cf8601688ba765b9c852b7feb64d3c2
Author: Nicolas Azrak <[email protected]>
AuthorDate: Wed Dec 15 22:03:19 2021 -0800

    [SPARK-34521][PYTHON][SQL] Fix spark.createDataFrame when using pandas with 
StringDtype
    
    ### What changes were proposed in this pull request?
    
    This change fixes `SPARK-34521`. It allows creating a Spark DataFrame from 
a pandas DataFrame that has a `StringDtype` column when Arrow optimization 
(`spark.sql.execution.arrow.pyspark.enabled`) is enabled.
    
    ### Why are the changes needed?
    
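    Pandas stores string columns in two different ways: using a numpy `ndarray` 
or using a custom `StringArray`. The `StringArray` version is used when 
specifying `dtype="string"`. Because such an array is converted through the 
`__arrow_array__` protocol, pyarrow rejects the null `mask` that Spark passes 
to `pa.Array.from_pandas`, so Spark cannot serialize the column to Arrow. 
Skipping the mask for arrays that implement `__arrow_array__` fixes this 
problem.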
    
    However, because pandas handles string columns in these different ways, 
`spark.createDataFrame(pandas_dataframe).toPandas()` might not be equal to 
`pandas_dataframe`: the column dtype could be different.
    
    More info: https://pandas.pydata.org/docs/user_guide/text.html
    
    ### Does this PR introduce _any_ user-facing change?
    
    Creating a Spark `DataFrame` from a pandas `DataFrame` that uses the string 
dtype, with `"spark.sql.execution.arrow.pyspark.enabled"` set to true, no longer 
throws an exception and returns the expected DataFrame.
    
    Before:
    
    `spark.createDataFrame(pd.DataFrame({"A": ['a', 'b', 'c']}, 
dtype="string"))`
    ```
    Error
    Traceback (most recent call last):
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/tests/test_arrow.py", 
line 415, in test_createDataFrame_with_string_dtype
        print(self.spark.createDataFrame(pd.DataFrame({"A": ['a', 'b', 'c']}, 
dtype="string")))
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/session.py", line 823, 
in createDataFrame
        return super(SparkSession, self).createDataFrame(  # type: 
ignore[call-overload]
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/conversion.py", 
line 358, in createDataFrame
        return self._create_from_pandas_with_arrow(data, schema, timezone)
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/conversion.py", 
line 550, in _create_from_pandas_with_arrow
        self._sc  # type: ignore[attr-defined]
      File "/home/nico/projects/playground/spark/python/pyspark/context.py", 
line 611, in _serialize_to_jvm
        serializer.dump_stream(data, tempFile)
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
 line 221, in dump_stream
        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
 line 81, in dump_stream
        for batch in iterator:
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
 line 220, in <genexpr>
        batches = (self._create_batch(series) for series in iterator)
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
 line 211, in _create_batch
        arrs.append(create_array(s, t))
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
 line 185, in create_array
        raise e
      File 
"/home/nico/projects/playground/spark/python/pyspark/sql/pandas/serializers.py",
 line 175, in create_array
        array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
      File "pyarrow/array.pxi", line 904, in pyarrow.lib.Array.from_pandas
      File "pyarrow/array.pxi", line 252, in pyarrow.lib.array
      File "pyarrow/array.pxi", line 107, in 
pyarrow.lib._handle_arrow_array_protocol
    ValueError: Cannot specify a mask or a size when passing an object that is 
converted with the __arrow_array__ protocol.
    
    ```
    
    After:
    
    `spark.createDataFrame(pd.DataFrame({"A": ['a', 'b', 'c']}, 
dtype="string"))`
    > `DataFrame[A: string]`
    
    ### How was this patch tested?
    
    Using the new `test_createDataFrame_with_string_dtype` and 
`test_createDataFrame_with_int64` tests.
    
    Closes #34509 from nicolasazrak/SPARK-34521.
    
    Authored-by: Nicolas Azrak <[email protected]>
    Signed-off-by: Bryan Cutler <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py |  5 ++++-
 python/pyspark/sql/tests/test_arrow.py   | 21 +++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 47c98c8..992e82b 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -215,7 +215,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         series = ((s, None) if not isinstance(s, (list, tuple)) else s for s 
in series)
 
         def create_array(s, t):
-            mask = s.isnull()
+            if hasattr(s.array, "__arrow_array__"):
+                mask = None
+            else:
+                mask = s.isnull()
             # Ensure timestamp series are in expected form for Spark internal 
representation
             if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
                 s = _check_series_convert_timestamps_internal(s, 
self._timezone)
diff --git a/python/pyspark/sql/tests/test_arrow.py 
b/python/pyspark/sql/tests/test_arrow.py
index 99705fb..ff42ade 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -548,6 +548,27 @@ class ArrowTests(ReusedSQLTestCase):
                 self.assertEqual(m, map_data[i])
                 self.assertEqual(m_arrow, map_data[i])
 
+    def test_createDataFrame_with_string_dtype(self):
+        # SPARK-34521: spark.createDataFrame does not support Pandas 
StringDtype extension type
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": 
True}):
+            data = [["abc"], ["def"], [None], ["ghi"], [None]]
+            pandas_df = pd.DataFrame(data, columns=["col"], dtype="string")
+            schema = StructType([StructField("col", StringType(), True)])
+            df = self.spark.createDataFrame(pandas_df, schema=schema)
+
+            # dtypes won't match. Pandas has two different ways to store 
string columns:
+            # using ndarray (when dtype isn't specified) or using a 
StringArray when dtype="string".
+            # When calling dataframe#toPandas() it will use the ndarray 
version.
+            # Changing that to use a StringArray would be backwards 
incompatible.
+            assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
+
+    def test_createDataFrame_with_int64(self):
+        # SPARK-34521: spark.createDataFrame does not support Pandas 
StringDtype extension type
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": 
True}):
+            pandas_df = pd.DataFrame({"col": [1, 2, 3, None]}, dtype="Int64")
+            df = self.spark.createDataFrame(pandas_df)
+            assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
+
     def test_toPandas_with_map_type(self):
         pdf = pd.DataFrame(
             {"id": [0, 1, 2, 3], "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 
1, "b": 2, "c": 3}]}

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to