This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 94f5f4ff659 [SPARK-43704][CONNECT][PS] Support `MultiIndex` for 
`to_series()`
94f5f4ff659 is described below

commit 94f5f4ff659180b20700f7cc61cd195d52da56e6
Author: Haejoon Lee <[email protected]>
AuthorDate: Sun Oct 8 09:54:37 2023 -0700

    [SPARK-43704][CONNECT][PS] Support `MultiIndex` for `to_series()`
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to support `MultiIndex` for `to_series()`.
    
    ### Why are the changes needed?
    
    So far, `to_series()` for `MultiIndex` is not working properly since the 
underlying data structure differs between pandas and Spark. See the 
examples in the next section for more detail.
    
    ### Does this PR introduce _any_ user-facing change?
    
    **Before**
    ```python
    >>> psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
    >>> psmidx.to_series()
    A  B    {'__index_level_0__': 'A', '__index_level_1__'...
       C    {'__index_level_0__': 'A', '__index_level_1__'...
    B  C    {'__index_level_0__': 'B', '__index_level_1__'...
    dtype: object
    ```
    
    **After**
    ```python
    >>> psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
    >>> psmidx.to_series()
    A  B    [A, B]
       C    [A, C]
    B  C    [B, C]
    dtype: object
    ```
    
    ### How was this patch tested?
    
    Enabling the existing UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43228 from itholic/SPARK-43704.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/pandas/indexes/base.py                      | 14 +++++++++++++-
 python/pyspark/pandas/indexing.py                          | 12 ++++++++----
 .../pandas/tests/connect/indexes/test_parity_base.py       |  4 ----
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/pandas/indexes/base.py 
b/python/pyspark/pandas/indexes/base.py
index c020e918d37..5652c6a8a85 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -916,7 +916,19 @@ class Index(IndexOpsMixin):
             data_fields=[field],
             column_label_names=None,
         )
-        return first_series(DataFrame(internal))
+
+        result = first_series(DataFrame(internal))
+        if self._internal.index_level == 1:
+            return result
+        else:
+            # MultiIndex
+            def struct_to_array(scol: Column) -> Column:
+                field_names = result._internal.spark_type_for(
+                    scol
+                ).fieldNames()  # type: ignore[attr-defined]
+                return F.array([scol[field] for field in field_names])
+
+            return result.spark.transform(struct_to_array)
 
     def to_frame(self, index: bool = True, name: Optional[Name] = None) -> 
DataFrame:
         """
diff --git a/python/pyspark/pandas/indexing.py 
b/python/pyspark/pandas/indexing.py
index 460eb37af78..c725d01d673 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -1077,12 +1077,16 @@ class LocIndexer(LocIndexerLike):
 
             return reduce(lambda x, y: x & y, conds), None, None
         else:
-            from pyspark.sql.types import StructType
+            from pyspark.sql.types import ArrayType, StructType
 
             index = self._psdf_or_psser.index
-            index_data_type = [  # type: ignore[assignment]
-                f.dataType for f in cast(StructType, 
index.to_series().spark.data_type)
-            ]
+            data_type = index.to_series().spark.data_type
+            if isinstance(data_type, StructType):
+                index_data_type = [f.dataType for f in data_type]  # type: 
ignore[assignment]
+            elif isinstance(data_type, ArrayType):
+                index_data_type = [  # type: ignore[assignment]
+                    data_type.elementType for _ in 
range(index._internal.index_level)
+                ]
 
             start = rows_sel.start
             if start is not None:
diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py 
b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
index 8f1f2d2221c..83ce92eb34b 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
@@ -29,10 +29,6 @@ class IndexesParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43704): Enable 
IndexesParityTests.test_to_series.")
-    def test_to_series(self):
-        super().test_to_series()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.indexes.test_parity_base import *  # 
noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to