This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c885a7cf57d [SPARK-45074][PYTHON][CONNECT] `DataFrame.{sort, sortWithinPartitions}` support column ordinals
6c885a7cf57d is described below

commit 6c885a7cf57df328b03308cff2eed814bda156e4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 4 23:31:23 2023 -0700

    [SPARK-45074][PYTHON][CONNECT] `DataFrame.{sort, sortWithinPartitions}` support column ordinals
    
    ### What changes were proposed in this pull request?
    `DataFrame.{sort, sortWithinPartitions}` support column ordinals
    
    ### Why are the changes needed?
    For feature parity with SQL, where ordinals are already supported in GROUP BY and ORDER BY:
    
    SQL:
    ```
    select a, 1, sum(b) from v group by 1, 2 order by 3, 1;
    ```
    
    DataFrame:
    ```
    df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(3, 1)
    ```
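    
    As a minimal sketch of the equivalence (it assumes an active `spark` session; the
    sample data and the view name `v` below are made up for illustration):
    
    ```
    from pyspark.sql import functions as sf
    
    # Illustrative data; any (a, b) pairs work.
    df = spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["a", "b"])
    df.createOrReplaceTempView("v")
    
    # SQL ordinals are 1-based positions in the select list.
    sql_df = spark.sql("select a, 1, sum(b) from v group by 1, 2 order by 3, 1")
    
    # DataFrame ordinals use the same 1-based positions; a negative ordinal
    # (e.g. .sort(-3, 1)) sorts that position descending, and an ordinal of
    # zero raises IndexError, since ordinals are 1-based unlike df[i].
    py_df = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(3, 1)
    
    # Row objects compare positionally, so this checks values, not column names.
    assert sql_df.collect() == py_df.collect()
    ```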
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new feature: `sort`, `sortWithinPartitions` (and the `orderBy` alias) now accept 1-based column ordinals.
    
    ### How was this patch tested?
    Added unit tests (`test_order_by_ordinal` in `python/pyspark/sql/tests/test_group.py`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    NO
    
    Closes #42809 from zhengruifeng/py_oderby_ordinal.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/connect/dataframe.py |  33 ++++++--
 python/pyspark/sql/dataframe.py         | 134 +++++++++++++++++++++++++++++---
 python/pyspark/sql/tests/test_group.py  |  53 +++++++++++++
 3 files changed, 202 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b22fdc1383cf..c443023ce02a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -593,7 +593,9 @@ class DataFrame:
     tail.__doc__ = PySparkDataFrame.tail.__doc__
 
     def _sort_cols(
-        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+        self,
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
     ) -> List[Column]:
         """Return a JVM Seq of Columns that describes the sort order"""
         if cols is None:
@@ -602,11 +604,24 @@ class DataFrame:
                 message_parameters={"item": "cols"},
             )
 
-        _cols: List[Column] = []
         if len(cols) == 1 and isinstance(cols[0], list):
-            _cols = [_to_col(c) for c in cols[0]]
-        else:
-            _cols = [_to_col(cast("ColumnOrName", c)) for c in cols]
+            cols = cols[0]
+
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                # ordinal is 1-based
+                if c > 0:
+                    _c = self[c - 1]
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _c = self[-c - 1].desc()
+                else:
+                    raise IndexError("Column ordinal must not be zero!")
+            else:
+                _c = c  # type: ignore[assignment]
+            _cols.append(_to_col(cast("ColumnOrName", _c)))
 
         ascending = kwargs.get("ascending", True)
         if isinstance(ascending, (bool, int)):
@@ -623,7 +638,9 @@ class DataFrame:
         return _cols
 
     def sort(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         return DataFrame.withPlan(
             plan.Sort(
@@ -639,7 +656,9 @@ class DataFrame:
     orderBy = sort
 
     def sortWithinPartitions(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         return DataFrame.withPlan(
             plan.Sort(
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3d7bdd7a0b2b..f59ae40542b9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2853,7 +2853,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(jdf, self.sparkSession)
 
     def sortWithinPartitions(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         """Returns a new :class:`DataFrame` with each partition sorted by the 
specified column(s).
 
@@ -2862,10 +2864,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
-        cols : str, list or :class:`Column`, optional
-            list of :class:`Column` or column names to sort by.
+        cols : int, str, list or :class:`Column`, optional
+            list of :class:`Column` or column names or column ordinals to sort by.
 
         Other Parameters
         ----------------
@@ -2879,17 +2884,42 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         :class:`DataFrame`
             DataFrame sorted by partitions.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+        If a column ordinal is negative, it means sort descending.
+
         Examples
         --------
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
         >>> df.sortWithinPartitions("age", ascending=False)
         DataFrame[age: bigint, name: string]
+
+        >>> df.coalesce(1).sortWithinPartitions(1).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  5|  Bob|
+        +---+-----+
+
+        >>> df.coalesce(1).sortWithinPartitions(-1).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        +---+-----+
         """
         jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
         return DataFrame(jdf, self.sparkSession)
 
     def sort(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         """Returns a new :class:`DataFrame` sorted by the specified column(s).
 
@@ -2898,10 +2928,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
-        cols : str, list, or :class:`Column`, optional
-             list of :class:`Column` or column names to sort by.
+        cols : int, str, list, or :class:`Column`, optional
+             list of :class:`Column` or column names or column ordinals to sort by.
 
         Other Parameters
         ----------------
@@ -2915,15 +2948,29 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         :class:`DataFrame`
             Sorted DataFrame.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+        If a column ordinal is negative, it means sort descending.
+
         Examples
         --------
-        >>> from pyspark.sql.functions import desc, asc
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame([
         ...     (2, "Alice"), (5, "Bob")], schema=["age", "name"])
 
         Sort the DataFrame in ascending order.
 
-        >>> df.sort(asc("age")).show()
+        >>> df.sort(sf.asc("age")).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  5|  Bob|
+        +---+-----+
+
+        >>> df.sort(1).show()
         +---+-----+
         |age| name|
         +---+-----+
@@ -2940,6 +2987,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  5|  Bob|
         |  2|Alice|
         +---+-----+
+
         >>> df.orderBy(df.age.desc()).show()
         +---+-----+
         |age| name|
@@ -2947,6 +2995,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  5|  Bob|
         |  2|Alice|
         +---+-----+
+
         >>> df.sort("age", ascending=False).show()
         +---+-----+
         |age| name|
@@ -2955,11 +3004,38 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  2|Alice|
         +---+-----+
 
+        >>> df.sort(-1).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        +---+-----+
+
         Specify multiple columns
 
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame([
         ...     (2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"])
-        >>> df.orderBy(desc("age"), "name").show()
+        >>> df.orderBy(sf.desc("age"), "name").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        |  2|  Bob|
+        +---+-----+
+
+        >>> df.orderBy(-1, "name").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        |  2|  Bob|
+        +---+-----+
+
+        >>> df.orderBy(-1, 2).show()
         +---+-----+
         |age| name|
         +---+-----+
@@ -2978,6 +3054,24 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  2|  Bob|
         |  2|Alice|
         +---+-----+
+
+        >>> df.orderBy([1, "name"], ascending=[False, False]).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|  Bob|
+        |  2|Alice|
+        +---+-----+
+
+        >>> df.orderBy([1, 2], ascending=[False, False]).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|  Bob|
+        |  2|Alice|
+        +---+-----+
         """
         jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
         return DataFrame(jdf, self.sparkSession)
@@ -3026,7 +3120,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return self._jseq(_cols, _to_java_column)
 
     def _sort_cols(
-        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+        self,
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
     ) -> JavaObject:
         """Return a JVM Seq of Columns that describes the sort order"""
         if not cols:
@@ -3036,7 +3132,23 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             )
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
-        jcols = [_to_java_column(cast("ColumnOrName", c)) for c in cols]
+
+        jcols = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                # ordinal is 1-based
+                if c > 0:
+                    _c = self[c - 1]
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _c = self[-c - 1].desc()
+                else:
+                    raise IndexError("Column ordinal must not be zero!")
+            else:
+                _c = c  # type: ignore[assignment]
+            jcols.append(_to_java_column(cast("ColumnOrName", _c)))
+
         ascending = kwargs.get("ascending", True)
         if isinstance(ascending, (bool, int)):
             if not ascending:
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index d481d725ebfb..6981601cb129 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -96,6 +96,59 @@ class GroupTestsMixin:
             with self.assertRaises(IndexError):
                 df.groupBy(10).agg(sf.sum("b"))
 
+    def test_order_by_ordinal(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 2),
+                (2, 1),
+                (2, 2),
+                (3, 1),
+                (3, 2),
+            ],
+            ["a", "b"],
+        )
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+
+            df1 = spark.sql("select * from v order by 1 desc;")
+            df2 = df.orderBy(-1)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select * from v order by 1 desc, b desc;")
+            df2 = df.orderBy(-1, df.b.desc())
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select * from v order by 1 desc, 2 desc;")
+            df2 = df.orderBy(-1, -2)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # groupby ordinal with orderby ordinal
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2 order by 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(1)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2 order by 3, 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(3, 1)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # negative cases: ordinal out of range
+            with self.assertRaises(IndexError):
+                df.sort(0)
+
+            with self.assertRaises(IndexError):
+                df.orderBy(3)
+
+            with self.assertRaises(IndexError):
+                df.orderBy(-3)
+
 
 class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
     pass

