This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2bb36fb271b [SPARK-45936][PS] Optimize `Index.symmetric_difference`
2bb36fb271b is described below
commit 2bb36fb271b60dda68567b92613a3664a7aae2b8
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 16 10:40:05 2023 +0900
[SPARK-45936][PS] Optimize `Index.symmetric_difference`
### What changes were proposed in this pull request?
Add a helper function for `XOR`, and use it to optimize
`Index.symmetric_difference`
### Why are the changes needed?
the old plan is too complex
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43816 from zhengruifeng/ps_base_diff.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/indexes/base.py | 4 ++--
python/pyspark/pandas/indexes/multi.py | 15 ++-------------
python/pyspark/pandas/utils.py | 17 +++++++++++++++++
3 files changed, 21 insertions(+), 15 deletions(-)
diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 6c6ee9ae0d7..a515a79dcd7 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -72,6 +72,7 @@ from pyspark.pandas.utils import (
     validate_index_loc,
     ERROR_MESSAGE_CANNOT_COMBINE,
     log_advice,
+    xor,
 )
from pyspark.pandas.internal import (
InternalField,
@@ -1468,8 +1469,7 @@ class Index(IndexOpsMixin):
 
         sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns)
         sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns)
-
-        sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))
+        sdf_symdiff = xor(sdf_self, sdf_other)
 
         if sort:
             sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 62b42c1fcd0..7d2712cbb53 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -38,6 +38,7 @@ from pyspark.pandas.utils import (
     scol_for,
     verify_temp_column_name,
     validate_index_loc,
+    xor,
 )
from pyspark.pandas.internal import (
InternalField,
@@ -809,19 +810,7 @@ class MultiIndex(Index):
 
         sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns)
         sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns)
-
-        tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__")
-        tmp_max_col = verify_temp_column_name(sdf_self, "__multi_index_max_tag__")
-        tmp_min_col = verify_temp_column_name(sdf_self, "__multi_index_min_tag__")
-
-        sdf_symdiff = (
-            sdf_self.withColumn(tmp_tag_col, F.lit(0))
-            .union(sdf_other.withColumn(tmp_tag_col, F.lit(1)))
-            .groupBy(*self._internal.index_spark_column_names)
-            .agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col))
-            .where(F.col(tmp_min_col) == F.col(tmp_max_col))
-            .select(*self._internal.index_spark_column_names)
-        )
+        sdf_symdiff = xor(sdf_self, sdf_other)
 
         if sort:
             sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 9f372a53079..57c1ddbe6ae 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -1033,6 +1033,23 @@ def validate_index_loc(index: "Index", loc: int) -> None:
)
+def xor(df1: PySparkDataFrame, df2: PySparkDataFrame) -> PySparkDataFrame:
+    colNames = df1.columns
+
+    tmp_tag_col = verify_temp_column_name(df1, "__temporary_tag__")
+    tmp_max_col = verify_temp_column_name(df1, "__temporary_max_tag__")
+    tmp_min_col = verify_temp_column_name(df1, "__temporary_min_tag__")
+
+    return (
+        df1.withColumn(tmp_tag_col, F.lit(0))
+        .union(df2.withColumn(tmp_tag_col, F.lit(1)))
+        .groupBy(*colNames)
+        .agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col))
+        .where(F.col(tmp_min_col) == F.col(tmp_max_col))
+        .select(*colNames)
+    )
+
+
def _test() -> None:
import os
import doctest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]