This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 44cfce8  [SPARK-36274][PYTHON] Fix equality comparison of unordered Categoricals
44cfce8 is described below

commit 44cfce8548f87aebd725442e1dc8a635318c5267
Author: Xinrong Meng <[email protected]>
AuthorDate: Fri Jul 23 18:30:59 2021 -0700

    [SPARK-36274][PYTHON] Fix equality comparison of unordered Categoricals
    
    ### What changes were proposed in this pull request?
    Fix equality comparison of unordered Categoricals.
    
    ### Why are the changes needed?
    Equality comparison between Categorical Series currently compares their codes. That is not valid for unordered Categoricals: two Series can have the same categories listed in a different order, so the same value can map to different codes in each Series.
    
    So we should first map each Series' codes back to their category values and then compare those values for equality.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    From:
    ```py
    >>> psser1 = ps.Series(pd.Categorical(list("abca")))
    >>> psser2 = ps.Series(pd.Categorical(list("bcaa"), categories=list("bca")))
    >>> with ps.option_context("compute.ops_on_diff_frames", True):
    ...     (psser1 == psser2).sort_index()
    ...
    0     True
    1     True
    2     True
    3    False
    dtype: bool
    ```
    
    To:
    ```py
    >>> psser1 = ps.Series(pd.Categorical(list("abca")))
    >>> psser2 = ps.Series(pd.Categorical(list("bcaa"), categories=list("bca")))
    >>> with ps.option_context("compute.ops_on_diff_frames", True):
    ...     (psser1 == psser2).sort_index()
    ...
    0    False
    1    False
    2    False
    3     True
    dtype: bool
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #33497 from xinrong-databricks/cat_bug.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Takuya UESHIN <[email protected]>
    (cherry picked from commit 85adc2ff60812f4af7befe0e8791d868a23359ae)
    Signed-off-by: Takuya UESHIN <[email protected]>
---
 .../pandas/data_type_ops/categorical_ops.py        | 27 +++++++++++++---------
 .../tests/data_type_ops/test_categorical_ops.py    | 26 +++++++++++++++++++++
 2 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py 
b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 932b9ed..36d5181 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -61,16 +61,7 @@ class CategoricalOps(DataTypeOps):
         if isinstance(dtype, CategoricalDtype) and cast(CategoricalDtype, 
dtype).categories is None:
             return index_ops.copy()
 
-        categories = cast(CategoricalDtype, index_ops.dtype).categories
-        if len(categories) == 0:
-            scol = SF.lit(None)
-        else:
-            kvs = chain(
-                *[(SF.lit(code), SF.lit(category)) for code, category in 
enumerate(categories)]
-            )
-            map_scol = F.create_map(*kvs)
-            scol = map_scol[index_ops.spark.column]
-        return index_ops._with_new_scol(scol).astype(dtype)
+        return _to_cat(index_ops).astype(dtype)
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         return _compare(left, right, Column.__eq__, 
is_equality_comparison=True)
@@ -119,7 +110,10 @@ def _compare(
         # Check if categoricals have the same dtype, same categories, and same 
ordered
         if hash(left.dtype) != hash(right.dtype):
             raise TypeError("Categoricals can only be compared if 'categories' 
are the same.")
-        return column_op(f)(left, right)
+        if cast(CategoricalDtype, left.dtype).ordered:
+            return column_op(f)(left, right)
+        else:
+            return column_op(f)(_to_cat(left), _to_cat(right))
     elif not is_list_like(right):
         categories = cast(CategoricalDtype, left.dtype).categories
         if right not in categories:
@@ -128,3 +122,14 @@ def _compare(
         return column_op(f)(left, right_code)
     else:
         raise TypeError("Cannot compare a Categorical with the given type.")
+
+
+def _to_cat(index_ops: IndexOpsLike) -> IndexOpsLike:
+    categories = cast(CategoricalDtype, index_ops.dtype).categories
+    if len(categories) == 0:
+        scol = SF.lit(None)
+    else:
+        kvs = chain(*[(SF.lit(code), SF.lit(category)) for code, category in 
enumerate(categories)])
+        map_scol = F.create_map(*kvs)
+        scol = map_scol[index_ops.spark.column]
+    return index_ops._with_new_scol(scol)
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py 
b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
index c9d150c..1dc9c39 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
@@ -48,6 +48,9 @@ class CategoricalOpsTest(PandasOnSparkTestCase, 
TestCasesUtils):
                 "that_ordered_string_cat": pd.Categorical(
                     ["z", "y", "x"], categories=["x", "z", "y"], ordered=True
                 ),
+                "this_given_cat_string_cat": pd.Series(
+                    pd.Categorical(["x", "y", "z"], categories=list("zyx"))
+                ),
             }
         )
 
@@ -253,6 +256,18 @@ class CategoricalOpsTest(PandasOnSparkTestCase, 
TestCasesUtils):
             psdf["this_string_cat"] == psdf["that_string_cat"],
         )
 
+        self.assert_eq(
+            pdf["this_string_cat"] == pdf["this_given_cat_string_cat"],
+            psdf["this_string_cat"] == psdf["this_given_cat_string_cat"],
+        )
+
+        pser1 = pd.Series(pd.Categorical(list("abca")))
+        pser2 = pd.Series(pd.Categorical(list("bcaa"), categories=list("bca")))
+        psser1 = ps.from_pandas(pser1)
+        psser2 = ps.from_pandas(pser2)
+        with option_context("compute.ops_on_diff_frames", True):
+            self.assert_eq(pser1 == pser2, (psser1 == psser2).sort_index())
+
     def test_ne(self):
         pdf, psdf = self.pdf, self.psdf
 
@@ -302,6 +317,17 @@ class CategoricalOpsTest(PandasOnSparkTestCase, 
TestCasesUtils):
             pdf["this_string_cat"] != pdf["that_string_cat"],
             psdf["this_string_cat"] != psdf["that_string_cat"],
         )
+        self.assert_eq(
+            pdf["this_string_cat"] != pdf["this_given_cat_string_cat"],
+            psdf["this_string_cat"] != psdf["this_given_cat_string_cat"],
+        )
+
+        pser1 = pd.Series(pd.Categorical(list("abca")))
+        pser2 = pd.Series(pd.Categorical(list("bcaa"), categories=list("bca")))
+        psser1 = ps.from_pandas(pser1)
+        psser2 = ps.from_pandas(pser2)
+        with option_context("compute.ops_on_diff_frames", True):
+            self.assert_eq(pser1 != pser2, (psser1 != psser2).sort_index())
 
     def test_lt(self):
         pdf, psdf = self.pdf, self.psdf

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to