This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f9a2077fd32f [SPARK-49810][PYTHON] Extract the preparation of 
`DataFrame.sort` to parent class
f9a2077fd32f is described below

commit f9a2077fd32faf63796a68cbb3483b486f220b1c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Sep 28 16:21:30 2024 +0900

    [SPARK-49810][PYTHON] Extract the preparation of `DataFrame.sort` to parent 
class
    
    ### What changes were proposed in this pull request?
    Extract the column-preparation logic of `DataFrame.sort` (and `sortWithinPartitions`) into the parent `DataFrame` class.
    
    ### Why are the changes needed?
    To deduplicate code: the logic in the two subclasses (classic and Connect `DataFrame`) is nearly identical.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48282 from zhengruifeng/py_sql_sort.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/classic/dataframe.py | 52 +++---------------------------
 python/pyspark/sql/connect/dataframe.py | 53 +++----------------------------
 python/pyspark/sql/dataframe.py         | 56 +++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 96 deletions(-)

diff --git a/python/pyspark/sql/classic/dataframe.py 
b/python/pyspark/sql/classic/dataframe.py
index 0dd66a9d8654..9f9dedbd3820 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -55,6 +55,7 @@ from pyspark.serializers import BatchedSerializer, 
CPickleSerializer, UTF8Deseri
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.column import Column
+from pyspark.sql.functions import builtin as F
 from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.merge import MergeIntoWriter
@@ -873,7 +874,8 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, 
PandasConversionMixin):
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+        _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
+        jdf = self._jdf.sortWithinPartitions(self._jseq(_cols, 
_to_java_column))
         return DataFrame(jdf, self.sparkSession)
 
     def sort(
@@ -881,7 +883,8 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, 
PandasConversionMixin):
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
+        _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
+        jdf = self._jdf.sort(self._jseq(_cols, _to_java_column))
         return DataFrame(jdf, self.sparkSession)
 
     orderBy = sort
@@ -928,51 +931,6 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, 
PandasConversionMixin):
                 _cols.append(c)  # type: ignore[arg-type]
         return self._jseq(_cols, _to_java_column)
 
-    def _sort_cols(
-        self,
-        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
-        kwargs: Dict[str, Any],
-    ) -> "JavaObject":
-        """Return a JVM Seq of Columns that describes the sort order"""
-        if not cols:
-            raise PySparkValueError(
-                errorClass="CANNOT_BE_EMPTY",
-                messageParameters={"item": "column"},
-            )
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-
-        jcols = []
-        for c in cols:
-            if isinstance(c, int) and not isinstance(c, bool):
-                # ordinal is 1-based
-                if c > 0:
-                    _c = self[c - 1]
-                # negative ordinal means sort by desc
-                elif c < 0:
-                    _c = self[-c - 1].desc()
-                else:
-                    raise PySparkIndexError(
-                        errorClass="ZERO_INDEX",
-                        messageParameters={},
-                    )
-            else:
-                _c = c  # type: ignore[assignment]
-            jcols.append(_to_java_column(cast("ColumnOrName", _c)))
-
-        ascending = kwargs.get("ascending", True)
-        if isinstance(ascending, (bool, int)):
-            if not ascending:
-                jcols = [jc.desc() for jc in jcols]
-        elif isinstance(ascending, list):
-            jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, 
jcols)]
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_BOOL_OR_LIST",
-                messageParameters={"arg_name": "ascending", "arg_type": 
type(ascending).__name__},
-            )
-        return self._jseq(jcols)
-
     def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame:
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]  # type: ignore[assignment]
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 146cfe11bc50..136fe60532df 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -739,62 +739,16 @@ class DataFrame(ParentDataFrame):
     def tail(self, num: int) -> List[Row]:
         return DataFrame(plan.Tail(child=self._plan, limit=num), 
session=self._session).collect()
 
-    def _sort_cols(
-        self,
-        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
-        kwargs: Dict[str, Any],
-    ) -> List[Column]:
-        """Return a JVM Seq of Columns that describes the sort order"""
-        if cols is None:
-            raise PySparkValueError(
-                errorClass="CANNOT_BE_EMPTY",
-                messageParameters={"item": "cols"},
-            )
-
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-
-        _cols: List[Column] = []
-        for c in cols:
-            if isinstance(c, int) and not isinstance(c, bool):
-                # ordinal is 1-based
-                if c > 0:
-                    _c = self[c - 1]
-                # negative ordinal means sort by desc
-                elif c < 0:
-                    _c = self[-c - 1].desc()
-                else:
-                    raise PySparkIndexError(
-                        errorClass="ZERO_INDEX",
-                        messageParameters={},
-                    )
-            else:
-                _c = c  # type: ignore[assignment]
-            _cols.append(F._to_col(cast("ColumnOrName", _c)))
-
-        ascending = kwargs.get("ascending", True)
-        if isinstance(ascending, (bool, int)):
-            if not ascending:
-                _cols = [c.desc() for c in _cols]
-        elif isinstance(ascending, list):
-            _cols = [c if asc else c.desc() for asc, c in zip(ascending, 
_cols)]
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_BOOL_OR_LIST",
-                messageParameters={"arg_name": "ascending", "arg_type": 
type(ascending).__name__},
-            )
-
-        return [F._sort_col(c) for c in _cols]
-
     def sort(
         self,
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
+        _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
         res = DataFrame(
             plan.Sort(
                 self._plan,
-                columns=self._sort_cols(cols, kwargs),
+                columns=[F._sort_col(c) for c in _cols],
                 is_global=True,
             ),
             session=self._session,
@@ -809,10 +763,11 @@ class DataFrame(ParentDataFrame):
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
+        _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
         res = DataFrame(
             plan.Sort(
                 self._plan,
-                columns=self._sort_cols(cols, kwargs),
+                columns=[F._sort_col(c) for c in _cols],
                 is_global=False,
             ),
             session=self._session,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 142034583dbd..5906108163b4 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2891,6 +2891,62 @@ class DataFrame:
         """
         ...
 
+    def _preapare_cols_for_sort(
+        self,
+        _to_col: Callable[[str], Column],
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
+    ) -> Sequence[Column]:
+        """Normalize the arguments of ``sort``/``sortWithinPartitions`` into sort columns.
+
+        Parameters
+        ----------
+        _to_col : callable
+            Converts a column name (``str``) into a :class:`Column`.
+        cols : sequence
+            Sort keys as passed by the caller: column names, :class:`Column`
+            objects, or non-zero 1-based ordinals (a negative ordinal selects the
+            column at ``-ordinal`` and sorts it descending). A single list
+            argument is unpacked and used as the sequence of keys.
+        kwargs : dict
+            Keyword arguments of the caller; only ``ascending`` is read. It may be
+            a bool/int applied to every key, or a list paired positionally with
+            the keys. Defaults to ``True``.
+
+        Returns
+        -------
+        list of :class:`Column`
+            One column per sort key, with ``.desc()`` applied where requested.
+
+        Raises
+        ------
+        PySparkValueError
+            If ``cols`` is empty.
+        PySparkIndexError
+            If an ordinal is ``0``.
+        PySparkTypeError
+            If a key is not an int, str or Column, or if ``ascending`` is neither
+            a bool/int nor a list.
+        """
+        from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkIndexError
+
+        if not cols:
+            raise PySparkValueError(
+                errorClass="CANNOT_BE_EMPTY", messageParameters={"item": "cols"}
+            )
+
+        # Support the ``df.sort(["a", "b"])`` calling convention.
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+
+        _cols: List[Column] = []
+        for c in cols:
+            # bool is a subclass of int, so it must be excluded explicitly;
+            # a bool key falls through to the type error below.
+            if isinstance(c, int) and not isinstance(c, bool):
+                # ordinal is 1-based
+                if c > 0:
+                    _cols.append(self[c - 1])
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _cols.append(self[-c - 1].desc())
+                else:
+                    raise PySparkIndexError(
+                        errorClass="ZERO_INDEX",
+                        messageParameters={},
+                    )
+            elif isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(_to_col(c))
+            else:
+                raise PySparkTypeError(
+                    errorClass="NOT_COLUMN_OR_INT_OR_STR",
+                    messageParameters={
+                        "arg_name": "col",
+                        "arg_type": type(c).__name__,
+                    },
+                )
+
+        ascending = kwargs.get("ascending", True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                _cols = [c.desc() for c in _cols]
+        elif isinstance(ascending, list):
+            # NOTE: zip stops at the shorter input, so sort keys without a matching
+            # ``ascending`` entry are dropped from the result.
+            _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
+        else:
+            # Same error class the per-backend implementations raised before they
+            # were merged into this helper, so callers see no behavior change.
+            raise PySparkTypeError(
+                errorClass="NOT_BOOL_OR_LIST",
+                messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
+            )
+        return _cols
+
     orderBy = sort
 
     @dispatch_df_method


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to