This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 56393b90723 [SPARK-45091][PYTHON][CONNECT][SQL] Function `floor/round/bround` accept Column type `scale`
56393b90723 is described below
commit 56393b90723257a757b7b87fb623847ef03d4bf3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Sep 7 10:44:25 2023 +0800
[SPARK-45091][PYTHON][CONNECT][SQL] Function `floor/round/bround` accept Column type `scale`
### What changes were proposed in this pull request?
1. `floor`: add the missing `scale` parameter in Python, which has existed in Scala for a long time;
2. `round`/`bround`: let the `scale` parameter accept Column type, to be consistent with `floor`/`ceil`/`ceiling` (see the usage sketch below).
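For illustration, a minimal usage sketch of the new signatures (assuming an active `spark` session; the expected values follow the doctests added in this patch):

    >>> import pyspark.sql.functions as sf
    >>> # floor now accepts an optional scale in Python: floor(2.1267, 2) -> 2.12
    >>> spark.range(1).select(sf.floor(sf.lit(2.1267), sf.lit(2))).show()
    >>> # round/bround now also accept a Column-typed scale
    >>> spark.range(1).select(sf.round(sf.lit(2.5), sf.lit(0))).show()   # HALF_UP: 3.0
    >>> spark.range(1).select(sf.bround(sf.lit(2.5), sf.lit(0))).show()  # HALF_EVEN: 2.0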
### Why are the changes needed?
To make these related functions consistent with one another across the Scala, Python, and Spark Connect APIs.
### Does this PR introduce _any_ user-facing change?
Yes. In Python, `floor` gains an optional `scale` parameter, and `round`/`bround` now accept a Column-typed `scale`; matching Scala overloads are added for `round`/`bround`.
### How was this patch tested?
Added doctests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42833 from zhengruifeng/py_func_floor.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../scala/org/apache/spark/sql/functions.scala | 18 +++++
python/pyspark/sql/connect/functions.py | 24 ++++--
python/pyspark/sql/functions.py | 89 ++++++++++++++++++----
.../scala/org/apache/spark/sql/functions.scala | 22 ++++++
4 files changed, 131 insertions(+), 22 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 54bf0106956..bf536c349cb 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2845,6 +2845,15 @@ object functions {
*/
def round(e: Column, scale: Int): Column = Column.fn("round", e, lit(scale))
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is
+ * greater than or equal to 0 or at integral part when `scale` is less than 0.
+ *
+ * @group math_funcs
+ * @since 4.0.0
+ */
+ def round(e: Column, scale: Column): Column = Column.fn("round", e, scale)
+
/**
* Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
*
@@ -2862,6 +2871,15 @@ object functions {
*/
def bround(e: Column, scale: Int): Column = Column.fn("bround", e, lit(scale))
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is
+ * greater than or equal to 0 or at integral part when `scale` is less than 0.
+ *
+ * @group math_funcs
+ * @since 4.0.0
+ */
+ def bround(e: Column, scale: Column): Column = Column.fn("bround", e, scale)
+
/**
* @param e
* angle in radians
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index cc03a3a3578..892ad6e6295 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -538,8 +538,12 @@ def bin(col: "ColumnOrName") -> Column:
bin.__doc__ = pysparkfuncs.bin.__doc__
-def bround(col: "ColumnOrName", scale: int = 0) -> Column:
-    return _invoke_function("bround", _to_col(col), lit(scale))
+def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
+    if scale is None:
+        return _invoke_function_over_columns("bround", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("bround", col, scale)
bround.__doc__ = pysparkfuncs.bround.__doc__
@@ -644,8 +648,12 @@ def factorial(col: "ColumnOrName") -> Column:
factorial.__doc__ = pysparkfuncs.factorial.__doc__
-def floor(col: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("floor", col)
+def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
+    if scale is None:
+        return _invoke_function_over_columns("floor", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("floor", col, scale)
floor.__doc__ = pysparkfuncs.floor.__doc__
@@ -773,8 +781,12 @@ def rint(col: "ColumnOrName") -> Column:
rint.__doc__ = pysparkfuncs.rint.__doc__
-def round(col: "ColumnOrName", scale: int = 0) -> Column:
-    return _invoke_function("round", _to_col(col), lit(scale))
+def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
+    if scale is None:
+        return _invoke_function_over_columns("round", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("round", col, scale)
round.__doc__ = pysparkfuncs.round.__doc__
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index de91cced206..c1e24ba25ac 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2346,7 +2346,7 @@ def expm1(col: "ColumnOrName") -> Column:
@try_remote_functions
-def floor(col: "ColumnOrName") -> Column:
+def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
"""
Computes the floor of the given value.
@@ -2359,6 +2359,11 @@ def floor(col: "ColumnOrName") -> Column:
----------
col : :class:`~pyspark.sql.Column` or str
column to find floor for.
+    scale : :class:`~pyspark.sql.Column` or int
+        an optional parameter to control the rounding behavior.
+
+        .. versionadded:: 4.0.0
+
Returns
-------
@@ -2367,15 +2372,27 @@ def floor(col: "ColumnOrName") -> Column:
Examples
--------
-    >>> df = spark.range(1)
-    >>> df.select(floor(lit(2.5))).show()
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.floor(sf.lit(2.5))).show()
     +----------+
     |FLOOR(2.5)|
     +----------+
     |         2|
     +----------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.floor(sf.lit(2.1267), sf.lit(2))).show()
+    +----------------+
+    |floor(2.1267, 2)|
+    +----------------+
+    |            2.12|
+    +----------------+
"""
-    return _invoke_function_over_columns("floor", col)
+    if scale is None:
+        return _invoke_function_over_columns("floor", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("floor", col, scale)
@try_remote_functions
@@ -5631,7 +5648,7 @@ def randn(seed: Optional[int] = None) -> Column:
@try_remote_functions
-def round(col: "ColumnOrName", scale: int = 0) -> Column:
+def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
"""
Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0
or at integral part when `scale` < 0.
@@ -5645,8 +5662,11 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column:
----------
col : :class:`~pyspark.sql.Column` or str
input column to round.
-    scale : int optional default 0
-        scale value.
+    scale : :class:`~pyspark.sql.Column` or int
+        an optional parameter to control the rounding behavior.
+
+        .. versionchanged:: 4.0.0
+            Support Column type.
Returns
-------
@@ -5655,14 +5675,31 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column:
Examples
--------
-    >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
-    [Row(r=3.0)]
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.round(sf.lit(2.5))).show()
+    +-------------+
+    |round(2.5, 0)|
+    +-------------+
+    |          3.0|
+    +-------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.round(sf.lit(2.1267), sf.lit(2))).show()
+    +----------------+
+    |round(2.1267, 2)|
+    +----------------+
+    |            2.13|
+    +----------------+
"""
- return _invoke_function("round", _to_java_column(col), scale)
+ if scale is None:
+ return _invoke_function_over_columns("round", col)
+ else:
+ scale = lit(scale) if isinstance(scale, int) else scale
+ return _invoke_function_over_columns("round", col, scale)
@try_remote_functions
-def bround(col: "ColumnOrName", scale: int = 0) -> Column:
+def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
"""
Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0
or at integral part when `scale` < 0.
@@ -5676,8 +5713,11 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column:
----------
col : :class:`~pyspark.sql.Column` or str
input column to round.
-    scale : int optional default 0
-        scale value.
+    scale : :class:`~pyspark.sql.Column` or int
+        an optional parameter to control the rounding behavior.
+
+        .. versionchanged:: 4.0.0
+            Support Column type.
Returns
-------
@@ -5686,10 +5726,27 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column:
Examples
--------
-    >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
-    [Row(r=2.0)]
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.bround(sf.lit(2.5))).show()
+    +--------------+
+    |bround(2.5, 0)|
+    +--------------+
+    |           2.0|
+    +--------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.bround(sf.lit(2.1267), sf.lit(2))).show()
+    +-----------------+
+    |bround(2.1267, 2)|
+    +-----------------+
+    |             2.13|
+    +-----------------+
"""
- return _invoke_function("bround", _to_java_column(col), scale)
+ if scale is None:
+ return _invoke_function_over_columns("bround", col)
+ else:
+ scale = lit(scale) if isinstance(scale, int) else scale
+ return _invoke_function_over_columns("bround", col, scale)
@try_remote_functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2f496785a6e..c1f674d2c0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2881,6 +2881,17 @@ object functions {
*/
def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_UP round mode
+ * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+ *
+ * @group math_funcs
+ * @since 4.0.0
+ */
+ def round(e: Column, scale: Column): Column = withExpr {
+   Round(e.expr, scale.expr)
+ }
+
/**
* Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
*
@@ -2898,6 +2909,17 @@ object functions {
*/
def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
+ * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+ *
+ * @group math_funcs
+ * @since 4.0.0
+ */
+ def bround(e: Column, scale: Column): Column = withExpr {
+   BRound(e.expr, scale.expr)
+ }
+
/**
* @param e angle in radians
* @return secant of the angle
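A closing usage note, as a hedged sketch rather than part of the patch's doctests: the negative-`scale` branch described in the new API docs rounds at the integral part, and the Column-typed `scale` overloads cover it too. Assuming an active `spark` session, one would expect:

    >>> import pyspark.sql.functions as sf
    >>> spark.range(1).select(sf.round(sf.lit(1250.0), sf.lit(-2))).show()   # HALF_UP at hundreds: expect 1300.0
    >>> spark.range(1).select(sf.bround(sf.lit(1250.0), sf.lit(-2))).show()  # HALF_EVEN at hundreds: expect 1200.0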