This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0eb96ae6eb86 [SPARK-47620][PYTHON][CONNECT] Add a helper function to sort columns
0eb96ae6eb86 is described below

commit 0eb96ae6eb8680155d4c6974dadaeebd7475a1fc
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 28 12:22:27 2024 +0800

    [SPARK-47620][PYTHON][CONNECT] Add a helper function to sort columns
    
    ### What changes were proposed in this pull request?
    Add a helper function `_sort_col` to sort columns
    
    ### Why are the changes needed?
    simple code refactoring
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #45743 from zhengruifeng/connect_sort_col.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py         | 16 +++-------------
 python/pyspark/sql/connect/functions/builtin.py | 12 ++++++++++++
 python/pyspark/sql/connect/plan.py              | 15 ++-------------
 3 files changed, 17 insertions(+), 26 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 82929ccfbc4f..672ac8b9c25c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -72,7 +72,6 @@ from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
-    SortOrder,
     ColumnReference,
     UnresolvedRegex,
     UnresolvedStar,
@@ -349,15 +348,6 @@ class DataFrame:
     def repartitionByRange(  # type: ignore[misc]
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
-        def _convert_col(col: "ColumnOrName") -> Column:
-            if isinstance(col, Column):
-                if isinstance(col._expr, SortOrder):
-                    return col
-                else:
-                    return col.asc()
-            else:
-                return F.col(col).asc()
-
         if isinstance(numPartitions, int):
             if not numPartitions > 0:
                 raise PySparkValueError(
@@ -375,14 +365,14 @@ class DataFrame:
             else:
                 return DataFrame(
                     plan.RepartitionByExpression(
-                        self._plan, numPartitions, [_convert_col(c) for c in cols]
+                        self._plan, numPartitions, [F._sort_col(c) for c in cols]
                     ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             return DataFrame(
                 plan.RepartitionByExpression(
-                    self._plan, None, [_convert_col(c) for c in [numPartitions] + list(cols)]
+                    self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
                 ),
                 self.sparkSession,
             )
@@ -729,7 +719,7 @@ class DataFrame:
                 message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
             )
 
-        return _cols
+        return [F._sort_col(c) for c in _cols]
 
     def sort(
         self,
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index c423c5f188ef..57c1f881eebf 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -45,6 +45,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     CaseWhen,
+    SortOrder,
     Expression,
     LiteralExpression,
     ColumnReference,
@@ -88,6 +89,17 @@ def _to_col(col: "ColumnOrName") -> Column:
     return col if isinstance(col, Column) else column(col)
 
 
+def _sort_col(col: "ColumnOrName") -> Column:
+    assert isinstance(col, (Column, str))
+    if isinstance(col, Column):
+        if isinstance(col._expr, SortOrder):
+            return col
+        else:
+            return col.asc()
+    else:
+        return column(col).asc()
+
+
 def _invoke_function(name: str, *args: Union[Column, Expression]) -> Column:
     """
     Simple wrapper function that converts the arguments into the appropriate types.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d9dd8874398c..863c27fabf6b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -46,10 +46,7 @@ from pyspark.sql.types import DataType
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import (
-    Expression,
-    SortOrder,
-)
+from pyspark.sql.connect.expressions import Expression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
 from pyspark.errors import (
     PySparkValueError,
@@ -674,19 +671,11 @@ class Sort(LogicalPlan):
         self.columns = columns
         self.is_global = is_global
 
-    def _convert_col(
-        self, col: Column, session: "SparkConnectClient"
-    ) -> proto.Expression.SortOrder:
-        if isinstance(col._expr, SortOrder):
-            return col._expr.to_plan(session).sort_order
-        else:
-            return SortOrder(col._expr).to_plan(session).sort_order
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.sort.input.CopyFrom(self._child.plan(session))
-        plan.sort.order.extend([self._convert_col(c, session) for c in self.columns])
+        plan.sort.order.extend([c.to_plan(session).sort_order for c in self.columns])
         plan.sort.is_global = self.is_global
         return plan
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to