HyukjinKwon closed pull request #22533: [SPARK-18818][PYTHON] Add 'ascending' parameter to Window.orderBy()
URL: https://github.com/apache/spark/pull/22533
This pull request was opened from a forked repository. Because GitHub does not display the diff of a foreign pull request (one from a fork) once it has been closed or merged, the diff is reproduced below for the sake of provenance:
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index e7dec11c69b57..c9203b0d26ba7 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -54,6 +54,22 @@ def _to_java_column(col):
     return jcol
 
 
+def _to_sorted_java_columns(cols, ascending=True):
+    if len(cols) == 1 and isinstance(cols[0], list):
+        cols = cols[0]
+    jcols = [_to_java_column(c) for c in cols]
+    if isinstance(ascending, (bool, int)):
+        if not ascending:
+            jcols = [jc.desc() for jc in jcols]
+    elif isinstance(ascending, list):
+        jcols = [jc if asc else jc.desc()
+                 for asc, jc in zip(ascending, jcols)]
+    else:
+        raise TypeError("Ascending can only be boolean or list, but got %s" % type(ascending))
+
+    return jcols
+
+
 def _to_seq(sc, cols, converter=None):
     """
     Convert a list of Column (or names) into a JVM Seq of Column.
@@ -712,3 +728,4 @@ def _test():
 
 if __name__ == "__main__":
     _test()
+
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 21bc69b8236fd..f4b8b22012d6c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -34,7 +34,8 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import _parse_datatype_json_string
-from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column, \
+    _to_sorted_java_columns
 from pyspark.sql.readwriter import DataFrameWriter
 from pyspark.sql.streaming import DataStreamWriter
 from pyspark.sql.types import IntegralType
@@ -1125,21 +1126,8 @@ def _jcols(self, *cols):
     def _sort_cols(self, cols, kwargs):
         """ Return a JVM Seq of Columns that describes the sort order
         """
-        if not cols:
-            raise ValueError("should sort by at least one column")
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-        jcols = [_to_java_column(c) for c in cols]
         ascending = kwargs.get('ascending', True)
-        if isinstance(ascending, (bool, int)):
-            if not ascending:
-                jcols = [jc.desc() for jc in jcols]
-        elif isinstance(ascending, list):
-            jcols = [jc if asc else jc.desc()
-                     for asc, jc in zip(ascending, jcols)]
-        else:
-            raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
-        return self._jseq(jcols)
+        return self._jseq(_to_sorted_java_columns(cols, ascending))
 
     @since("1.3.1")
     def describe(self, *cols):
@@ -2332,3 +2320,4 @@ def _test():
 
 if __name__ == "__main__":
     _test()
+
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index d19ced954f04e..418f96f16b5cb 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -20,7 +20,7 @@
     long = int
 
 from pyspark import since, SparkContext
-from pyspark.sql.column import Column, _to_seq, _to_java_column
+from pyspark.sql.column import Column, _to_seq, _to_java_column, _to_sorted_java_columns
 
 __all__ = ["Window", "WindowSpec"]
 
@@ -44,6 +44,9 @@ class Window(object):
     >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
     >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3)
 
+    >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
+    >>> window = Window.orderBy("date", ascending=False).partitionBy("country").rangeBetween(-3, 3)
+
     .. note:: When ordering is not defined, an unbounded window frame (rowFrame,
          unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined,
          a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default.
@@ -76,12 +79,40 @@ def partitionBy(*cols):
 
     @staticmethod
     @since(1.4)
-    def orderBy(*cols):
+    def orderBy(*cols, **kwargs):
         """
         Creates a :class:`WindowSpec` with the ordering defined.
+
+        :param cols: names of columns or expressions.
+        :param ascending: boolean or list of boolean (default True).
+            Sort ascending vs. descending. Specify list for multiple sort orders.
+            If a list is specified, length of the list must equal length of the `cols`.
+
+        .. versionchanged:: 2.5
+            Added optional ``ascending`` argument.
+
+        >>> from pyspark.sql import functions as F, SparkSession, Window
+        >>> spark = SparkSession.builder.getOrCreate()
+        >>> df = spark.createDataFrame(
+        ...     [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"])
+        >>> window = Window.orderBy("id", ascending=False).partitionBy("category").rowsBetween(
+        ...     Window.unboundedPreceding, Window.currentRow)
+        >>> df.withColumn("sum", F.sum("id").over(window)).show()
+        +---+--------+---+
+        | id|category|sum|
+        +---+--------+---+
+        |  3|       b|  3|
+        |  2|       b|  5|
+        |  1|       b|  6|
+        |  2|       a|  2|
+        |  1|       a|  3|
+        |  1|       a|  4|
+        +---+--------+---+
         """
         sc = SparkContext._active_spark_context
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
+        ascending = kwargs.get('ascending', True)
+        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(
+            _to_seq(sc, _to_sorted_java_columns(cols, ascending)))
         return WindowSpec(jspec)
 
     @staticmethod
@@ -195,13 +226,20 @@ def partitionBy(self, *cols):
         return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))
 
     @since(1.4)
-    def orderBy(self, *cols):
+    def orderBy(self, *cols, **kwargs):
         """
         Defines the ordering columns in a :class:`WindowSpec`.
 
-        :param cols: names of columns or expressions
+        :param cols: names of columns or expressions.
+        :param ascending: boolean or list of boolean (default True).
+            Sort ascending vs. descending. Specify list for multiple sort orders.
+            If a list is specified, length of the list must equal length of the `cols`.
         """
-        return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))
+        ascending = kwargs.get('ascending', True)
+        return WindowSpec(self._jspec.orderBy(_to_seq(
+            SparkContext._active_spark_context,
+            _to_sorted_java_columns(cols, ascending)
+        )))
 
     @since(1.4)
     def rowsBetween(self, start, end):
@@ -273,3 +311,4 @@ def _test():
 
 if __name__ == "__main__":
     _test()
+
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]