This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 079a9c5  [SPARK-36771][PYTHON] Fix `pop` of Categorical Series
079a9c5 is described below

commit 079a9c52925818532b57c9cec1ddd31be723885e
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Tue Sep 21 14:11:21 2021 -0700

    [SPARK-36771][PYTHON] Fix `pop` of Categorical Series
    
    ### What changes were proposed in this pull request?
    Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior.
    
    ### Why are the changes needed?
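    As reported in https://github.com/databricks/koalas/issues/2198, pandas API on Spark behaves differently from pandas on `pop` of a Categorical Series: it returns the underlying integer category code (e.g. `0`) instead of the category value (e.g. `'a'`).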
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `pop` on a Categorical Series now returns the popped category value (e.g. `'a'`) instead of its integer category code (e.g. `0`).
    
    #### From
    ```py
    >>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
    >>> psser
    0    a
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(0)
    0
    >>> psser
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(3)
    0
    >>> psser
    1    b
    2    c
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    ```
    
    #### To
    ```py
    >>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
    >>> psser
    0    a
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(0)
    'a'
    >>> psser
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(3)
    'a'
    >>> psser
    1    b
    2    c
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #34052 from xinrong-databricks/cat_pop.
    
    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/pandas/series.py            |  8 ++++++--
 python/pyspark/pandas/tests/test_series.py | 25 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index d72c08d..da0d2fb 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -47,7 +47,7 @@ import numpy as np
 import pandas as pd
 from pandas.core.accessor import CachedAccessor
 from pandas.io.formats.printing import pprint_thing
-from pandas.api.types import is_list_like, is_hashable
+from pandas.api.types import is_list_like, is_hashable, CategoricalDtype
 from pandas.tseries.frequencies import DateOffset
 from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
 from pyspark.sql.types import (
@@ -4098,7 +4098,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             pdf = sdf.limit(2).toPandas()
             length = len(pdf)
             if length == 1:
-                return pdf[internal.data_spark_column_names[0]].iloc[0]
+                val = pdf[internal.data_spark_column_names[0]].iloc[0]
+                if isinstance(self.dtype, CategoricalDtype):
+                    return self.dtype.categories[val]
+                else:
+                    return val
 
             item_string = name_like_string(item)
             sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, 
SF.lit(str(item_string)))
diff --git a/python/pyspark/pandas/tests/test_series.py 
b/python/pyspark/pandas/tests/test_series.py
index 09e5d30..b7bb121 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaisesRegex(KeyError, msg):
             psser.pop(("lama", "speed", "x"))
 
+        pser = pd.Series(["a", "b", "c", "a"], dtype="category")
+        psser = ps.from_pandas(pser)
+
+        if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
+            self.assert_eq(psser.pop(0), pser.pop(0))
+            self.assert_eq(psser, pser)
+
+            self.assert_eq(psser.pop(3), pser.pop(3))
+            self.assert_eq(psser, pser)
+        else:
+            # Before pandas 1.3.0, `pop` modifies the dtype of categorical 
series wrongly.
+            self.assert_eq(psser.pop(0), "a")
+            self.assert_eq(
+                psser,
+                pd.Series(
+                    pd.Categorical(["b", "c", "a"], categories=["a", "b", 
"c"]), index=[1, 2, 3]
+                ),
+            )
+
+            self.assert_eq(psser.pop(3), "a")
+            self.assert_eq(
+                psser,
+                pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", 
"c"]), index=[1, 2]),
+            )
+
     def test_replace(self):
         pser = pd.Series([10, 20, 15, 30, np.nan], name="x")
         psser = ps.Series(pser)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to