This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ef9c8e0d045 [SPARK-41439][CONNECT][PYTHON][FOLLOWUP] Make unpivot of `connect/dataframe.py` consistent with `pyspark/dataframe.py`
ef9c8e0d045 is described below

commit ef9c8e0d045576fb325ef337319fe6d59b7ce858
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Mon Dec 12 08:34:21 2022 +0800

    [SPARK-41439][CONNECT][PYTHON][FOLLOWUP] Make unpivot of `connect/dataframe.py` consistent with `pyspark/dataframe.py`
    
    ### What changes were proposed in this pull request?
    This PR makes `unpivot` of `connect/dataframe.py` consistent with `pyspark/dataframe.py` and adds test cases for Connect's `unpivot`.
    
    This PR follows up https://github.com/apache/spark/pull/38973
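
    For illustration, here is a minimal sketch of the call shapes `unpivot` now accepts on the Connect DataFrame, mirroring `pyspark/dataframe.py` (the DataFrame `df` and its columns are hypothetical):

        # df has columns id, name, age (hypothetical)
        df.unpivot(["id"], ["name", "age"], "variable", "value")   # lists
        df.unpivot(("id",), ("name", "age"), "variable", "value")  # tuples
        df.unpivot("id", "name", "variable", "value")              # single columns
        df.unpivot("id", None, "variable", "value")                # values=None: all non-id columns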
    
    ### Why are the changes needed?
    1. Makes `unpivot` of `connect/dataframe.py` consistent with `pyspark/dataframe.py`.
    2. Adds test cases for Connect's `unpivot`.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'. This is a new API.
    
    ### How was this patch tested?
    New test cases.
    
    Closes #39019 from beliefer/SPARK-41439_followup.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            | 22 ++++++++++++++++++---
 .../sql/tests/connect/test_connect_basic.py        | 23 ++++++++++++++++++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  2 +-
 3 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 4c1956cc577..08d48bb11f2 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -826,8 +826,8 @@ class DataFrame(object):
 
     def unpivot(
         self,
-        ids: List["ColumnOrName"],
-        values: List["ColumnOrName"],
+        ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
         variableColumnName: str,
         valueColumnName: str,
     ) -> "DataFrame":
@@ -852,8 +852,24 @@ class DataFrame(object):
         -------
         :class:`DataFrame`
         """
+
+        def to_jcols(
+            cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
+        ) -> List["ColumnOrName"]:
+            if cols is None:
+                lst = []
+            elif isinstance(cols, tuple):
+                lst = list(cols)
+            elif isinstance(cols, list):
+                lst = cols
+            else:
+                lst = [cols]
+            return lst
+
         return DataFrame.withPlan(
-            plan.Unpivot(self._plan, ids, values, variableColumnName, valueColumnName),
+            plan.Unpivot(
+                self._plan, to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+            ),
             self._session,
         )
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9d49cfd321c..6dabbaedffe 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -783,6 +783,29 @@ class SparkConnectTests(SparkConnectSQLTestCase):
                 """Cannot resolve column name "x" among (a, b, c)""", 
str(context.exception)
             )
 
+    def test_unpivot(self):
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot(["id"], ["name"], "variable", "value")
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot(["id"], ["name"], "variable", "value")
+            .toPandas(),
+        )
+
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot("id", None, "variable", "value")
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .filter("id > 3")
+            .unpivot("id", None, "variable", "value")
+            .toPandas(),
+        )
+
     def test_with_columns(self):
         # SPARK-41256: test withColumn(s).
         self.assert_eq(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 83e21e42bad..e0cd54195f3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -189,7 +189,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
 
         plan = (
             df.filter(df.col_name > 3)
-            .unpivot(["id"], [], "variable", "value")
+            .unpivot(["id"], None, "variable", "value")
             ._plan.to_proto(self.connect)
         )
         self.assertTrue(len(plan.root.unpivot.ids) == 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
