BryanCutler commented on a change in pull request #26585:
URL: https://github.com/apache/spark/pull/26585#discussion_r429422097



##########
File path: python/pyspark/sql/pandas/serializers.py
##########
@@ -154,6 +155,9 @@ def create_array(s, t):
             # Ensure timestamp series are in expected form for Spark internal representation
             if t is not None and pa.types.is_timestamp(t):
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
+            elif type(s.dtype) == pd.CategoricalDtype:
+                # FIXME: This can be removed once minimum pyarrow version is >= 0.16.1

Review comment:
       Please change `FIXME` to `NOTE`. `FIXME` makes it sound like we are adding broken code, which isn't the case; this branch simply becomes unnecessary once the minimum pyarrow version is >= 0.16.1.

##########
File path: python/pyspark/sql/pandas/serializers.py
##########
@@ -142,6 +142,7 @@ def _create_batch(self, series):
         """
         import pandas as pd
         import pyarrow as pa
+

Review comment:
       nit: remove this added blank line

##########
File path: python/pyspark/sql/tests/test_arrow.py
##########
@@ -415,6 +415,20 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
         for case in cases:
             run_test(*case)
 
+    def test_createDateFrame_with_category_type(self):
+        pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
+        pdf["B"] = pdf["A"].astype('category')
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
+            arrow_df = self.spark.createDataFrame(pdf)
+            result_arrow = arrow_df.toPandas()
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf)
+            result_spark = df.toPandas()
+
+        assert_frame_equal(result_spark, result_arrow)
+

Review comment:
       Could you add an assert that the Spark DataFrame has column "B" as a string type?

##########
File path: python/pyspark/sql/tests/test_pandas_udf_scalar.py
##########
@@ -897,6 +897,30 @@ def test_timestamp_dst(self):
             result = df.withColumn('time', foo_udf(df.time))
             self.assertEquals(df.collect(), result.collect())
 
+    def test_createDateFrame_with_category_type(self):

Review comment:
       This test module is for `pandas_udf`s, not for `createDataFrame`. We do still need a test here that exercises this through a `pandas_udf`: the user specifies a return type of `string` and then returns a categorical pandas.Series with string categories. For example:
   
   ```python
   @pandas_udf('string')
   def f(x):
       return x.astype('category')
   
   pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
   df = spark.createDataFrame(pdf).withColumn("B", f(col("A")))
   result = df.toPandas()
   # Check result "B" is equal to "A"
   ```
   

##########
File path: python/pyspark/sql/tests/test_pandas_udf_scalar.py
##########
@@ -897,6 +897,30 @@ def test_timestamp_dst(self):
             result = df.withColumn('time', foo_udf(df.time))
             self.assertEquals(df.collect(), result.collect())
 
+    def test_createDateFrame_with_category_type(self):
+        pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
+        pdf["B"] = pdf["A"].astype('category')
+        category_first_element = dict(enumerate(pdf['B'].cat.categories))[0]
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
+            arrow_df = self.spark.createDataFrame(pdf)
+            arrow_type = arrow_df.dtypes[1][1]
+            result_arrow = arrow_df.collect()
+            arrow_first_category_element = result_arrow[0][1]
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf)
+            spark_type = df.dtypes[1][1]
+            result_spark = df.collect()
+            spark_first_category_element = result_spark[0][1]
+
+        # ensure original category elements are string
+        assert isinstance(category_first_element, str)
+        # spark dataframe and arrow execution mode enabled dataframe type must match pandas
+        assert spark_type == arrow_type == 'string'
+        assert isinstance(arrow_first_category_element, str)
+        assert isinstance(spark_first_category_element, str)

Review comment:
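       Oh yeah, please move these assertions into the other test (`test_createDateFrame_with_category_type` in `python/pyspark/sql/tests/test_arrow.py`).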




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to