This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 079a9c5  [SPARK-36771][PYTHON] Fix `pop` of Categorical Series
079a9c5 is described below

commit 079a9c52925818532b57c9cec1ddd31be723885e
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Tue Sep 21 14:11:21 2021 -0700

    [SPARK-36771][PYTHON] Fix `pop` of Categorical Series

    ### What changes were proposed in this pull request?
    Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior.

    ### Why are the changes needed?
    As reported in https://github.com/databricks/koalas/issues/2198, pandas API on Spark behaves differently from pandas for `pop` of Categorical Series.

    ### Does this PR introduce _any_ user-facing change?
    Yes, the results of `pop` of Categorical Series change.

    #### From
    ```py
    >>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
    >>> psser
    0    a
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(0)
    0
    >>> psser
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(3)
    0
    >>> psser
    1    b
    2    c
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    ```

    #### To
    ```py
    >>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
    >>> psser
    0    a
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(0)
    'a'
    >>> psser
    1    b
    2    c
    3    a
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    >>> psser.pop(3)
    'a'
    >>> psser
    1    b
    2    c
    dtype: category
    Categories (3, object): ['a', 'b', 'c']
    ```

    ### How was this patch tested?
    Unit tests.

    Closes #34052 from xinrong-databricks/cat_pop.

Authored-by: Xinrong Meng <xinrong.m...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/pandas/series.py | 8 ++++++-- python/pyspark/pandas/tests/test_series.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index d72c08d..da0d2fb 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -47,7 +47,7 @@ import numpy as np import pandas as pd from pandas.core.accessor import CachedAccessor from pandas.io.formats.printing import pprint_thing -from pandas.api.types import is_list_like, is_hashable +from pandas.api.types import is_list_like, is_hashable, CategoricalDtype from pandas.tseries.frequencies import DateOffset from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame from pyspark.sql.types import ( @@ -4098,7 +4098,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]): pdf = sdf.limit(2).toPandas() length = len(pdf) if length == 1: - return pdf[internal.data_spark_column_names[0]].iloc[0] + val = pdf[internal.data_spark_column_names[0]].iloc[0] + if isinstance(self.dtype, CategoricalDtype): + return self.dtype.categories[val] + else: + return val item_string = name_like_string(item) sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string))) diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index 09e5d30..b7bb121 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaisesRegex(KeyError, msg): psser.pop(("lama", "speed", "x")) + pser = pd.Series(["a", "b", "c", "a"], dtype="category") + psser = ps.from_pandas(pser) + + if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"): + self.assert_eq(psser.pop(0), pser.pop(0)) + self.assert_eq(psser, pser) + + 
self.assert_eq(psser.pop(3), pser.pop(3)) + self.assert_eq(psser, pser) + else: + # Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly. + self.assert_eq(psser.pop(0), "a") + self.assert_eq( + psser, + pd.Series( + pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3] + ), + ) + + self.assert_eq(psser.pop(3), "a") + self.assert_eq( + psser, + pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]), + ) + def test_replace(self): pser = pd.Series([10, 20, 15, 30, np.nan], name="x") psser = ps.Series(pser) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org