This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 56393b90723 [SPARK-45091][PYTHON][CONNECT][SQL] Function `floor/round/bround` accept Column type `scale`
56393b90723 is described below

commit 56393b90723257a757b7b87fb623847ef03d4bf3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Sep 7 10:44:25 2023 +0800

    [SPARK-45091][PYTHON][CONNECT][SQL] Function `floor/round/bround` accept Column type `scale`
    
    ### What changes were proposed in this pull request?
    1. `floor`: add the missing `scale` parameter in Python, which has long existed in Scala;
    2. `round`/`bround`: let the `scale` parameter accept a Column type, for consistency with `floor`/`ceil`/`ceiling` (see the usage sketch below).
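    
    A minimal usage sketch of the new Column-typed `scale` (mirroring the
    doctests added in this patch; `spark` is assumed to be an active
    SparkSession):
    
        >>> import pyspark.sql.functions as sf
        >>> spark.range(1).select(sf.floor(sf.lit(2.1267), sf.lit(2))).show()
        +----------------+
        |floor(2.1267, 2)|
        +----------------+
        |            2.12|
        +----------------+
        >>> spark.range(1).select(sf.round(sf.lit(2.1267), sf.lit(2))).show()
        +----------------+
        |round(2.1267, 2)|
        +----------------+
        |            2.13|
        +----------------+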
    
    ### Why are the changes needed?
    To make these related functions consistent with one another.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes: `floor` gains an optional `scale` parameter in Python, and `round`/`bround` now also accept a Column-typed `scale`.
    
    ### How was this patch tested?
    Added doctests.
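    
    A sketch of how to run these doctests locally (the exact invocation is an
    assumption, following the usual Spark dev workflow's test runner):
    
        python/run-tests --testnames 'pyspark.sql.functions'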
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42833 from zhengruifeng/py_func_floor.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../scala/org/apache/spark/sql/functions.scala     | 18 +++++
 python/pyspark/sql/connect/functions.py            | 24 ++++--
 python/pyspark/sql/functions.py                    | 89 ++++++++++++++++++----
 .../scala/org/apache/spark/sql/functions.scala     | 22 ++++++
 4 files changed, 131 insertions(+), 22 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 54bf0106956..bf536c349cb 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2845,6 +2845,15 @@ object functions {
    */
   def round(e: Column, scale: Int): Column = Column.fn("round", e, lit(scale))
 
+  /**
+   * Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is
+   * greater than or equal to 0 or at integral part when `scale` is less than 0.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def round(e: Column, scale: Column): Column = Column.fn("round", e, scale)
+
   /**
    * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
    *
@@ -2862,6 +2871,15 @@ object functions {
    */
   def bround(e: Column, scale: Int): Column = Column.fn("bround", e, lit(scale))
 
+  /**
+   * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is
+   * greater than or equal to 0 or at integral part when `scale` is less than 0.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def bround(e: Column, scale: Column): Column = Column.fn("bround", e, scale)
+
   /**
    * @param e
    *   angle in radians
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index cc03a3a3578..892ad6e6295 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -538,8 +538,12 @@ def bin(col: "ColumnOrName") -> Column:
 bin.__doc__ = pysparkfuncs.bin.__doc__
 
 
-def bround(col: "ColumnOrName", scale: int = 0) -> Column:
-    return _invoke_function("bround", _to_col(col), lit(scale))
+def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
+    if scale is None:
+        return _invoke_function_over_columns("bround", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("bround", col, scale)
 
 
 bround.__doc__ = pysparkfuncs.bround.__doc__
@@ -644,8 +648,12 @@ def factorial(col: "ColumnOrName") -> Column:
 factorial.__doc__ = pysparkfuncs.factorial.__doc__
 
 
-def floor(col: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("floor", col)
+def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
+    if scale is None:
+        return _invoke_function_over_columns("floor", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("floor", col, scale)
 
 
 floor.__doc__ = pysparkfuncs.floor.__doc__
@@ -773,8 +781,12 @@ def rint(col: "ColumnOrName") -> Column:
 rint.__doc__ = pysparkfuncs.rint.__doc__
 
 
-def round(col: "ColumnOrName", scale: int = 0) -> Column:
-    return _invoke_function("round", _to_col(col), lit(scale))
+def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
+    if scale is None:
+        return _invoke_function_over_columns("round", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("round", col, scale)
 
 
 round.__doc__ = pysparkfuncs.round.__doc__
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index de91cced206..c1e24ba25ac 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2346,7 +2346,7 @@ def expm1(col: "ColumnOrName") -> Column:
 
 
 @try_remote_functions
-def floor(col: "ColumnOrName") -> Column:
+def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     """
     Computes the floor of the given value.
 
@@ -2359,6 +2359,11 @@ def floor(col: "ColumnOrName") -> Column:
     ----------
     col : :class:`~pyspark.sql.Column` or str
         column to find floor for.
+    scale : :class:`~pyspark.sql.Column` or int
+        an optional parameter to control the rounding behavior.
+
+            .. versionadded:: 4.0.0
+
 
     Returns
     -------
@@ -2367,15 +2372,27 @@ def floor(col: "ColumnOrName") -> Column:
 
     Examples
     --------
-    >>> df = spark.range(1)
-    >>> df.select(floor(lit(2.5))).show()
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.floor(sf.lit(2.5))).show()
     +----------+
     |FLOOR(2.5)|
     +----------+
     |         2|
     +----------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.floor(sf.lit(2.1267), sf.lit(2))).show()
+    +----------------+
+    |floor(2.1267, 2)|
+    +----------------+
+    |            2.12|
+    +----------------+
     """
-    return _invoke_function_over_columns("floor", col)
+    if scale is None:
+        return _invoke_function_over_columns("floor", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("floor", col, scale)
 
 
 @try_remote_functions
@@ -5631,7 +5648,7 @@ def randn(seed: Optional[int] = None) -> Column:
 
 
 @try_remote_functions
-def round(col: "ColumnOrName", scale: int = 0) -> Column:
+def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     """
     Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0
     or at integral part when `scale` < 0.
@@ -5645,8 +5662,11 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column:
     ----------
     col : :class:`~pyspark.sql.Column` or str
         input column to round.
-    scale : int optional default 0
-        scale value.
+    scale : :class:`~pyspark.sql.Column` or int
+        an optional parameter to control the rounding behavior.
+
+            .. versionchanged:: 4.0.0
+                Support Column type.
 
     Returns
     -------
@@ -5655,14 +5675,31 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column:
 
     Examples
     --------
-    >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
-    [Row(r=3.0)]
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.round(sf.lit(2.5))).show()
+    +-------------+
+    |round(2.5, 0)|
+    +-------------+
+    |          3.0|
+    +-------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.round(sf.lit(2.1267), sf.lit(2))).show()
+    +----------------+
+    |round(2.1267, 2)|
+    +----------------+
+    |            2.13|
+    +----------------+
     """
-    return _invoke_function("round", _to_java_column(col), scale)
+    if scale is None:
+        return _invoke_function_over_columns("round", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("round", col, scale)
 
 
 @try_remote_functions
-def bround(col: "ColumnOrName", scale: int = 0) -> Column:
+def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     """
     Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0
     or at integral part when `scale` < 0.
@@ -5676,8 +5713,11 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column:
     ----------
     col : :class:`~pyspark.sql.Column` or str
         input column to round.
-    scale : int optional default 0
-        scale value.
+    scale : :class:`~pyspark.sql.Column` or int
+        an optional parameter to control the rounding behavior.
+
+            .. versionchanged:: 4.0.0
+                Support Column type.
 
     Returns
     -------
@@ -5686,10 +5726,27 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column:
 
     Examples
     --------
-    >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
-    [Row(r=2.0)]
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.bround(sf.lit(2.5))).show()
+    +--------------+
+    |bround(2.5, 0)|
+    +--------------+
+    |           2.0|
+    +--------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(1).select(sf.bround(sf.lit(2.1267), sf.lit(2))).show()
+    +-----------------+
+    |bround(2.1267, 2)|
+    +-----------------+
+    |             2.13|
+    +-----------------+
     """
-    return _invoke_function("bround", _to_java_column(col), scale)
+    if scale is None:
+        return _invoke_function_over_columns("bround", col)
+    else:
+        scale = lit(scale) if isinstance(scale, int) else scale
+        return _invoke_function_over_columns("bround", col, scale)
 
 
 @try_remote_functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2f496785a6e..c1f674d2c0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2881,6 +2881,17 @@ object functions {
    */
   def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
 
+  /**
+   * Round the value of `e` to `scale` decimal places with HALF_UP round mode
+   * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def round(e: Column, scale: Column): Column = withExpr {
+    Round(e.expr, scale.expr)
+  }
+
   /**
    * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
    *
@@ -2898,6 +2909,17 @@ object functions {
    */
   def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
 
+  /**
+   * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
+   * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def bround(e: Column, scale: Column): Column = withExpr {
+    BRound(e.expr, scale.expr)
+  }
+
   /**
    * @param e angle in radians
    * @return secant of the angle


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
