This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d86208d043a [SPARK-43620][CONNECT][PS] Fix Pandas APIs depends on
unsupported features
d86208d043a is described below
commit d86208d043a27a02d3c4ccf6a929c7e6b8ad0292
Author: Haejoon Lee <[email protected]>
AuthorDate: Wed Oct 4 12:22:34 2023 +0900
[SPARK-43620][CONNECT][PS] Fix Pandas APIs depends on unsupported features
### What changes were proposed in this pull request?
This PR proposes to fix the Pandas APIs that have dependency on unsupported
PySpark features.
### Why are the changes needed?
To increase the API coverage for Pandas API on Spark with Spark Connect.
### Does this PR introduce _any_ user-facing change?
Pandas data type APIs such as `astype` and `factorize` are supported on
Spark Connect.
### How was this patch tested?
Enabling the existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43120 from itholic/SPARK-43620.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/base.py | 25 +++++++--------
python/pyspark/pandas/data_type_ops/base.py | 26 ++++++++++++----
.../pandas/data_type_ops/categorical_ops.py | 7 ++---
.../data_type_ops/test_parity_binary_ops.py | 4 +--
.../data_type_ops/test_parity_boolean_ops.py | 4 ---
.../data_type_ops/test_parity_categorical_ops.py | 12 --------
.../connect/data_type_ops/test_parity_date_ops.py | 4 ---
.../data_type_ops/test_parity_datetime_ops.py | 4 ---
.../connect/data_type_ops/test_parity_null_ops.py | 4 +--
.../connect/data_type_ops/test_parity_num_ops.py | 4 ---
.../data_type_ops/test_parity_string_ops.py | 4 ---
.../data_type_ops/test_parity_timedelta_ops.py | 4 ---
.../tests/connect/indexes/test_parity_base.py | 4 ---
.../tests/connect/indexes/test_parity_category.py | 36 +---------------------
.../tests/connect/series/test_parity_compute.py | 4 ---
.../tests/connect/test_parity_categorical.py | 24 ---------------
16 files changed, 39 insertions(+), 131 deletions(-)
diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index ed6e983fdc8..fa513e8b9b6 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -1704,16 +1704,10 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
if len(categories) == 0:
scol = F.lit(None)
else:
- kvs = list(
- chain(
- *[
- (F.lit(code), F.lit(category))
- for code, category in enumerate(categories)
- ]
- )
- )
- map_scol = F.create_map(*kvs)
- scol = map_scol[self.spark.column]
+ scol = F.lit(None)
+ for code, category in reversed(list(enumerate(categories))):
+ scol = F.when(self.spark.column == F.lit(code),
F.lit(category)).otherwise(scol)
+
codes, uniques = self._with_new_scol(
scol.alias(self._internal.data_spark_column_names[0])
).factorize(use_na_sentinel=use_na_sentinel)
@@ -1761,9 +1755,16 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
if len(kvs) == 0: # uniques are all missing values
new_scol = F.lit(na_sentinel_code)
else:
- map_scol = F.create_map(*kvs)
null_scol = F.when(self.isnull().spark.column,
F.lit(na_sentinel_code))
- new_scol = null_scol.otherwise(map_scol[self.spark.column])
+ mapped_scol = None
+ for i in range(0, len(kvs), 2):
+ key = kvs[i]
+ value = kvs[i + 1]
+ if mapped_scol is None:
+ mapped_scol = F.when(self.spark.column == key, value)
+ else:
+ mapped_scol = mapped_scol.when(self.spark.column == key,
value)
+ new_scol = null_scol.otherwise(mapped_scol)
codes =
self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0]))
diff --git a/python/pyspark/pandas/data_type_ops/base.py
b/python/pyspark/pandas/data_type_ops/base.py
index 5d497a55a5f..4f57aa65be7 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -17,7 +17,6 @@
import numbers
from abc import ABCMeta
-from itertools import chain
from typing import Any, Optional, Union
import numpy as np
@@ -130,12 +129,27 @@ def _as_categorical_type(
if len(categories) == 0:
scol = F.lit(-1)
else:
- kvs = chain(
- *[(F.lit(category), F.lit(code)) for code, category in
enumerate(categories)]
- )
- map_scol = F.create_map(*kvs)
+ scol = F.lit(-1)
+ if isinstance(
+
index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]),
BinaryType
+ ):
+ from pyspark.sql.functions import base64
+
+ stringified_column = base64(index_ops.spark.column)
+ for code, category in enumerate(categories):
+ # Convert each category to base64 before comparison
+ base64_category = F.base64(F.lit(category))
+ scol = F.when(stringified_column == base64_category,
F.lit(code)).otherwise(
+ scol
+ )
+ else:
+ stringified_column = F.format_string("%s",
index_ops.spark.column)
+
+ for code, category in enumerate(categories):
+ scol = F.when(stringified_column == F.lit(category),
F.lit(code)).otherwise(
+ scol
+ )
- scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
return index_ops._with_new_scol(
scol.cast(spark_type),
field=index_ops._internal.data_fields[0].copy(
diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py
b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 824666b5819..bbaded42be9 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-from itertools import chain
from typing import cast, Any, Union
import pandas as pd
@@ -135,7 +134,7 @@ def _to_cat(index_ops: IndexOpsLike) -> IndexOpsLike:
if len(categories) == 0:
scol = F.lit(None)
else:
- kvs = chain(*[(F.lit(code), F.lit(category)) for code, category in
enumerate(categories)])
- map_scol = F.create_map(*kvs)
- scol = map_scol[index_ops.spark.column]
+ scol = F.lit(None)
+ for code, category in reversed(list(enumerate(categories))):
+ scol = F.when(index_ops.spark.column == F.lit(code),
F.lit(category)).otherwise(scol)
return index_ops._with_new_scol(scol)
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
index 663c0007389..29b13868e03 100644
---
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
+++
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
@@ -25,9 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class BinaryOpsParityTests(
BinaryOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase,
ReusedConnectTestCase
):
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
+ pass
if __name__ == "__main__":
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
index 52d517967eb..9ad2aa0ad17 100644
---
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
+++
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
@@ -30,10 +30,6 @@ class BooleanOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_boolean_ops
import * # noqa: F401
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
index b680e5b3d79..1b4dabdb045 100644
---
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
+++
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
@@ -30,18 +30,6 @@ class CategoricalOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_eq(self):
- super().test_eq()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_ne(self):
- super().test_ne()
-
if __name__ == "__main__":
from
pyspark.pandas.tests.connect.data_type_ops.test_parity_categorical_ops import *
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
index e7b1c7de70d..baa3180baaa 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
@@ -30,10 +30,6 @@ class DateOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_date_ops
import * # noqa: F401
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
index 6d081b10aba..2641e3a32dc 100644
---
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
+++
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
@@ -30,10 +30,6 @@ class DatetimeOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_datetime_ops
import * # noqa: F401
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
index 63b53c02fd7..5df4c791c98 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
@@ -25,9 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class NullOpsParityTests(
NullOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase,
ReusedConnectTestCase
):
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
+ pass
if __name__ == "__main__":
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
index 04aa24c4045..56eba708c94 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
@@ -30,10 +30,6 @@ class NumOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops import
* # noqa: F401
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
index ecbf94a6bde..f507756a7a4 100644
---
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
+++
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
@@ -30,10 +30,6 @@ class StringOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_string_ops
import * # noqa: F401
diff --git
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
index 058dd2bfd3f..edd29fa1ed2 100644
---
a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
+++
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
@@ -30,10 +30,6 @@ class TimedeltaOpsParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_timedelta_ops
import * # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
index 3cf4dc9b3d2..8f1f2d2221c 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
@@ -29,10 +29,6 @@ class IndexesParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_factorize(self):
- super().test_factorize()
-
@unittest.skip("TODO(SPARK-43704): Enable
IndexesParityTests.test_to_series.")
def test_to_series(self):
super().test_to_series()
diff --git
a/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
b/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
index d99d013306f..aed7df26202 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
@@ -24,41 +24,7 @@ from pyspark.testing.pandasutils import
PandasOnSparkTestUtils, TestUtils
class CategoricalIndexParityTests(
CategoricalIndexTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
):
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_append(self):
- super().test_append()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_factorize(self):
- super().test_factorize()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_intersection(self):
- super().test_intersection()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_remove_categories(self):
- super().test_remove_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_remove_unused_categories(self):
- super().test_remove_unused_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_reorder_categories(self):
- super().test_reorder_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_set_categories(self):
- super().test_set_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_union(self):
- super().test_union()
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
index 31916f12b4e..8876fcb1398 100644
--- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
+++ b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
@@ -24,10 +24,6 @@ from pyspark.testing.pandasutils import
PandasOnSparkTestUtils
class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils,
ReusedConnectTestCase):
pass
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_factorize(self):
- super().test_factorize()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.series.test_parity_compute import * #
noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/test_parity_categorical.py
b/python/pyspark/pandas/tests/connect/test_parity_categorical.py
index 210cfce8ddb..ca880aef572 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_categorical.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_categorical.py
@@ -29,30 +29,6 @@ class CategoricalParityTests(
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_astype(self):
- super().test_astype()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_factorize(self):
- super().test_factorize()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_remove_categories(self):
- super().test_remove_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_remove_unused_categories(self):
- super().test_remove_unused_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_reorder_categories(self):
- super().test_reorder_categories()
-
- @unittest.skip("TODO(SPARK-43620): Support `Column` for
SparkConnectColumn.__getitem__.")
- def test_set_categories(self):
- super().test_set_categories()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.test_parity_categorical import * #
noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]