This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 74c82642941 [SPARK-40812][CONNECT][PYTHON][FOLLOW-UP] Improve 
Deduplicate in Python client
74c82642941 is described below

commit 74c826429416493a6d1d0efdf83b0e561dc33591
Author: Rui Wang <[email protected]>
AuthorDate: Mon Oct 24 10:50:55 2022 +0800

    [SPARK-40812][CONNECT][PYTHON][FOLLOW-UP] Improve Deduplicate in Python 
client
    
    ### What changes were proposed in this pull request?
    
    Following up on https://github.com/apache/spark/pull/38276, this PR improves 
both the `distinct()` and `dropDuplicates()` DataFrame APIs in the Python client, 
which both depend on the `Deduplicate` plan in the Connect proto.
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UT
    
    Closes #38327 from amaliujia/python_deduplicate.
    
    Authored-by: Rui Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py            | 41 +++++++++++++++++++---
 python/pyspark/sql/connect/plan.py                 | 39 ++++++++++++++++++++
 .../sql/tests/connect/test_connect_plan_only.py    | 19 ++++++++++
 3 files changed, 95 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index eabcf433ae9..2b7e3d52039 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -157,11 +157,44 @@ class DataFrame(object):
     def describe(self, cols: List[ColumnRef]) -> Any:
         ...
 
+    def dropDuplicates(self, subset: Optional[List[str]] = None) -> 
"DataFrame":
+        """Return a new :class:`DataFrame` with duplicate rows removed,
+        optionally only deduplicating based on certain columns.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        subset : List of column names, optional
+            List of columns to use for duplicate comparison (default All 
columns).
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame without duplicated rows.
+        """
+        if subset is None:
+            return DataFrame.withPlan(
+                plan.Deduplicate(child=self._plan, all_columns_as_keys=True), 
session=self._session
+            )
+        else:
+            return DataFrame.withPlan(
+                plan.Deduplicate(child=self._plan, column_names=subset), 
session=self._session
+            )
+
     def distinct(self) -> "DataFrame":
-        """Returns all distinct rows."""
-        all_cols = self.columns
-        gf = self.groupBy(*all_cols)
-        return gf.agg()
+        """Returns a new :class:`DataFrame` containing the distinct rows in 
this :class:`DataFrame`.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame with distinct rows.
+        """
+        return DataFrame.withPlan(
+            plan.Deduplicate(child=self._plan, all_columns_as_keys=True), 
session=self._session
+        )
 
     def drop(self, *cols: "ColumnOrString") -> "DataFrame":
         all_cols = self.columns
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 297b15994d3..d6b6f9e3b67 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -327,6 +327,45 @@ class Offset(LogicalPlan):
         """
 
 
+class Deduplicate(LogicalPlan):
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        all_columns_as_keys: bool = False,
+        column_names: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__(child)
+        self.all_columns_as_keys = all_columns_as_keys
+        self.column_names = column_names
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
+        if self.column_names is not None:
+            plan.deduplicate.column_names.extend(self.column_names)
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child 
else ""
+        return (
+            f"{' ' * indent}<all_columns_as_keys={self.all_columns_as_keys} "
+            f"column_names={self.column_names}>\n{c_buf}"
+        )
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+            <li>
+                <b>Deduplicate</b><br />
+                all_columns_as_keys: {self.all_columns_as_keys} <br />
+                column_names: {self.column_names} <br />
+                {self._child_repr_()}
+            </li>
+        </ul>
+        """
+
+
 class Sort(LogicalPlan):
     def __init__(
         self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, 
ColumnRef, str]
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py 
b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 3b609db7a02..450f5c70fab 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -72,6 +72,25 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.with_replacement, True)
         self.assertEqual(plan.root.sample.seed.seed, -1)
 
+    def test_deduplicate(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        distinct_plan = df.distinct()._plan.to_proto(self.connect)
+        self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, 
True)
+        self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0)
+
+        deduplicate_on_all_columns_plan = 
df.dropDuplicates()._plan.to_proto(self.connect)
+        
self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys,
 True)
+        
self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names),
 0)
+
+        deduplicate_on_subset_columns_plan = df.dropDuplicates(["name", 
"height"])._plan.to_proto(
+            self.connect
+        )
+        self.assertEqual(
+            
deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False
+        )
+        
self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names),
 2)
+
     def test_relation_alias(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.alias("table_alias")._plan.to_proto(self.connect)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to