This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 94f5f4ff659 [SPARK-43704][CONNECT][PS] Support `MultiIndex` for
`to_series()`
94f5f4ff659 is described below
commit 94f5f4ff659180b20700f7cc61cd195d52da56e6
Author: Haejoon Lee <[email protected]>
AuthorDate: Sun Oct 8 09:54:37 2023 -0700
[SPARK-43704][CONNECT][PS] Support `MultiIndex` for `to_series()`
### What changes were proposed in this pull request?
This PR proposes to support `MultiIndex` for `to_series()`.
### Why are the changes needed?
So far, `to_series()` has not worked properly for `MultiIndex` because the
underlying data structure differs between pandas and Spark. See the
examples in the next section for more detail.
### Does this PR introduce _any_ user-facing change?
**Before**
```python
>>> psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
>>> psmidx.to_series()
A B {'__index_level_0__': 'A', '__index_level_1__'...
C {'__index_level_0__': 'A', '__index_level_1__'...
B C {'__index_level_0__': 'B', '__index_level_1__'...
dtype: object
```
**After**
```python
>>> psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
>>> psmidx.to_series()
A B [A, B]
C [A, C]
B C [B, C]
dtype: object
```
### How was this patch tested?
Enabled the existing unit test `IndexesParityTests.test_to_series`, which was previously skipped.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43228 from itholic/SPARK-43704.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/pandas/indexes/base.py | 14 +++++++++++++-
python/pyspark/pandas/indexing.py | 12 ++++++++----
.../pandas/tests/connect/indexes/test_parity_base.py | 4 ----
3 files changed, 21 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/pandas/indexes/base.py
b/python/pyspark/pandas/indexes/base.py
index c020e918d37..5652c6a8a85 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -916,7 +916,19 @@ class Index(IndexOpsMixin):
data_fields=[field],
column_label_names=None,
)
- return first_series(DataFrame(internal))
+
+ result = first_series(DataFrame(internal))
+ if self._internal.index_level == 1:
+ return result
+ else:
+ # MultiIndex
+ def struct_to_array(scol: Column) -> Column:
+ field_names = result._internal.spark_type_for(
+ scol
+ ).fieldNames() # type: ignore[attr-defined]
+ return F.array([scol[field] for field in field_names])
+
+ return result.spark.transform(struct_to_array)
def to_frame(self, index: bool = True, name: Optional[Name] = None) ->
DataFrame:
"""
diff --git a/python/pyspark/pandas/indexing.py
b/python/pyspark/pandas/indexing.py
index 460eb37af78..c725d01d673 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -1077,12 +1077,16 @@ class LocIndexer(LocIndexerLike):
return reduce(lambda x, y: x & y, conds), None, None
else:
- from pyspark.sql.types import StructType
+ from pyspark.sql.types import ArrayType, StructType
index = self._psdf_or_psser.index
- index_data_type = [ # type: ignore[assignment]
- f.dataType for f in cast(StructType,
index.to_series().spark.data_type)
- ]
+ data_type = index.to_series().spark.data_type
+ if isinstance(data_type, StructType):
+ index_data_type = [f.dataType for f in data_type] # type:
ignore[assignment]
+ elif isinstance(data_type, ArrayType):
+ index_data_type = [ # type: ignore[assignment]
+ data_type.elementType for _ in
range(index._internal.index_level)
+ ]
start = rows_sel.start
if start is not None:
diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
index 8f1f2d2221c..83ce92eb34b 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
@@ -29,10 +29,6 @@ class IndexesParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43704): Enable
IndexesParityTests.test_to_series.")
- def test_to_series(self):
- super().test_to_series()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.indexes.test_parity_base import * #
noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]