This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 079a9c5 [SPARK-36771][PYTHON] Fix `pop` of Categorical Series
079a9c5 is described below
commit 079a9c52925818532b57c9cec1ddd31be723885e
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Sep 21 14:11:21 2021 -0700
[SPARK-36771][PYTHON] Fix `pop` of Categorical Series
### What changes were proposed in this pull request?
Fix `pop` of Categorical Series to be consistent with the latest pandas
(1.3.2) behavior.
### Why are the changes needed?
As reported in https://github.com/databricks/koalas/issues/2198, pandas API on Spark
behaves differently from pandas when calling `pop` on a Categorical Series.
### Does this PR introduce _any_ user-facing change?
Yes, `pop` on a Categorical Series now returns the category value (e.g. `'a'`) instead of its integer category code (e.g. `0`).
#### From
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0 a
1 b
2 c
3 a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
0
>>> psser
1 b
2 c
3 a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
0
>>> psser
1 b
2 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```
#### To
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0 a
1 b
2 c
3 a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
'a'
>>> psser
1 b
2 c
3 a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
'a'
>>> psser
1 b
2 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```
### How was this patch tested?
Unit tests.
Closes #34052 from xinrong-databricks/cat_pop.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/pyspark/pandas/series.py | 8 ++++++--
python/pyspark/pandas/tests/test_series.py | 25 +++++++++++++++++++++++++
2 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index d72c08d..da0d2fb 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -47,7 +47,7 @@ import numpy as np
import pandas as pd
from pandas.core.accessor import CachedAccessor
from pandas.io.formats.printing import pprint_thing
-from pandas.api.types import is_list_like, is_hashable
+from pandas.api.types import is_list_like, is_hashable, CategoricalDtype
from pandas.tseries.frequencies import DateOffset
from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
from pyspark.sql.types import (
@@ -4098,7 +4098,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
pdf = sdf.limit(2).toPandas()
length = len(pdf)
if length == 1:
- return pdf[internal.data_spark_column_names[0]].iloc[0]
+ val = pdf[internal.data_spark_column_names[0]].iloc[0]
+ if isinstance(self.dtype, CategoricalDtype):
+ return self.dtype.categories[val]
+ else:
+ return val
item_string = name_like_string(item)
sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME,
SF.lit(str(item_string)))
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 09e5d30..b7bb121 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
with self.assertRaisesRegex(KeyError, msg):
psser.pop(("lama", "speed", "x"))
+ pser = pd.Series(["a", "b", "c", "a"], dtype="category")
+ psser = ps.from_pandas(pser)
+
+ if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
+ self.assert_eq(psser.pop(0), pser.pop(0))
+ self.assert_eq(psser, pser)
+
+ self.assert_eq(psser.pop(3), pser.pop(3))
+ self.assert_eq(psser, pser)
+ else:
+ # Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly.
+ self.assert_eq(psser.pop(0), "a")
+ self.assert_eq(
+ psser,
+ pd.Series(
+ pd.Categorical(["b", "c", "a"], categories=["a", "b",
"c"]), index=[1, 2, 3]
+ ),
+ )
+
+ self.assert_eq(psser.pop(3), "a")
+ self.assert_eq(
+ psser,
+ pd.Series(pd.Categorical(["b", "c"], categories=["a", "b",
"c"]), index=[1, 2]),
+ )
+
def test_replace(self):
pser = pd.Series([10, 20, 15, 30, np.nan], name="x")
psser = ps.Series(pser)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]