This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 899dec2  [SPARK-37738][PYTHON] Support column type inputs for second 
arg of date manipulation functions
899dec2 is described below

commit 899dec2ea36bcd934f1512b26e7fc8903e9f5ba8
Author: Daniel Davies <ddav...@palantir.com>
AuthorDate: Thu Dec 30 16:07:39 2021 +0100

    [SPARK-37738][PYTHON] Support column type inputs for second arg of date 
manipulation functions
    
    ### What changes were proposed in this pull request?
    See https://issues.apache.org/jira/browse/SPARK-37738
    
    There is a discrepancy between the Scala Spark API and the PySpark API:
    date_add/date_sub/add_months accept only an 'int' for their second
    parameter ('days', or 'months' for add_months) in PySpark, but accept
    either a column or an integer in Scala.
    
    This PR allows that second parameter ('days'/'months') to be either a
    Column (or column name) or an int in PySpark.
    
    ### Why are the changes needed?
    Users should see a consistent API across Python & Scala side processing.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, additive only.
    
    >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'add'])
    >>> df.select(date_add(df.dt, df.add).alias('next_date')).collect()
    [Row(next_date=datetime.date(2015, 4, 10))]
    
    ### How was this patch tested?
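    Three new unit tests (one each for date_add, date_sub and add_months),
    plus new doctest examples passing a column as the second argument.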
    
    Closes #35032 from Daniel-Davies/master.
    
    Lead-authored-by: Daniel Davies <ddav...@palantir.com>
    Co-authored-by: Daniel-Davies 
<33356828+daniel-dav...@users.noreply.github.com>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/sql/functions.py            | 33 ++++++++++++-----
 python/pyspark/sql/tests/test_functions.py | 57 ++++++++++++++++++++++++++++++
 2 files changed, 81 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index a0f2bbf..4791d3c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2200,7 +2200,7 @@ def make_date(year: "ColumnOrName", month: 
"ColumnOrName", day: "ColumnOrName")
     return Column(jc)
 
 
-def date_add(start: "ColumnOrName", days: int) -> Column:
+def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> 
Column:
     """
     Returns the date that is `days` days after `start`
 
@@ -2208,16 +2208,21 @@ def date_add(start: "ColumnOrName", days: int) -> 
Column:
 
     Examples
     --------
-    >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+    >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'add'])
     >>> df.select(date_add(df.dt, 1).alias('next_date')).collect()
     [Row(next_date=datetime.date(2015, 4, 9))]
+    >>> df.select(date_add(df.dt, 
df.add.cast('integer')).alias('next_date')).collect()
+    [Row(next_date=datetime.date(2015, 4, 10))]
     """
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
-    return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
+
+    days = lit(days) if isinstance(days, int) else days
+
+    return Column(sc._jvm.functions.date_add(_to_java_column(start), 
_to_java_column(days)))
 
 
-def date_sub(start: "ColumnOrName", days: int) -> Column:
+def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> 
Column:
     """
     Returns the date that is `days` days before `start`
 
@@ -2225,13 +2230,18 @@ def date_sub(start: "ColumnOrName", days: int) -> 
Column:
 
     Examples
     --------
-    >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+    >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'sub'])
     >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect()
     [Row(prev_date=datetime.date(2015, 4, 7))]
+    >>> df.select(date_sub(df.dt, 
df.sub.cast('integer')).alias('prev_date')).collect()
+    [Row(prev_date=datetime.date(2015, 4, 6))]
     """
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
-    return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
+
+    days = lit(days) if isinstance(days, int) else days
+
+    return Column(sc._jvm.functions.date_sub(_to_java_column(start), 
_to_java_column(days)))
 
 
 def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column:
@@ -2251,7 +2261,7 @@ def datediff(end: "ColumnOrName", start: "ColumnOrName") 
-> Column:
     return Column(sc._jvm.functions.datediff(_to_java_column(end), 
_to_java_column(start)))
 
 
-def add_months(start: "ColumnOrName", months: int) -> Column:
+def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> 
Column:
     """
     Returns the date that is `months` months after `start`
 
@@ -2259,13 +2269,18 @@ def add_months(start: "ColumnOrName", months: int) -> 
Column:
 
     Examples
     --------
-    >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+    >>> df = spark.createDataFrame([('2015-04-08', 2)], ['dt', 'add'])
     >>> df.select(add_months(df.dt, 1).alias('next_month')).collect()
     [Row(next_month=datetime.date(2015, 5, 8))]
+    >>> df.select(add_months(df.dt, 
df.add.cast('integer')).alias('next_month')).collect()
+    [Row(next_month=datetime.date(2015, 6, 8))]
     """
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
-    return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
+
+    months = lit(months) if isinstance(months, int) else months
+
+    return Column(sc._jvm.functions.add_months(_to_java_column(start), 
_to_java_column(months)))
 
 
 def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: 
bool = True) -> Column:
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index b6a3278..eb3b433 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -43,6 +43,9 @@ from pyspark.sql.functions import (
     csc,
     cot,
     make_date,
+    date_add,
+    date_sub,
+    add_months,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
 
@@ -286,6 +289,60 @@ class FunctionsTests(ReusedSQLTestCase):
         row = df.select(dayofweek(df.date)).first()
         self.assertEqual(row[0], 2)
 
+    # Test added for SPARK-37738; change Python API to accept both col & int 
as input
+    def test_date_add_function(self):
+        dt = datetime.date(2021, 12, 27)
+
+        # Note; number var in Python gets converted to LongType column;
+        # this is not supported by the function, so cast to Integer explicitly
+        df = self.spark.createDataFrame([Row(date=dt, add=2)], "date date, add 
integer")
+
+        self.assertTrue(
+            all(
+                df.select(
+                    date_add(df.date, df.add) == datetime.date(2021, 12, 29),
+                    date_add(df.date, "add") == datetime.date(2021, 12, 29),
+                    date_add(df.date, 3) == datetime.date(2021, 12, 30),
+                ).first()
+            )
+        )
+
+    # Test added for SPARK-37738; change Python API to accept both col & int 
as input
+    def test_date_sub_function(self):
+        dt = datetime.date(2021, 12, 27)
+
+        # Note; number var in Python gets converted to LongType column;
+        # this is not supported by the function, so cast to Integer explicitly
+        df = self.spark.createDataFrame([Row(date=dt, sub=2)], "date date, sub 
integer")
+
+        self.assertTrue(
+            all(
+                df.select(
+                    date_sub(df.date, df.sub) == datetime.date(2021, 12, 25),
+                    date_sub(df.date, "sub") == datetime.date(2021, 12, 25),
+                    date_sub(df.date, 3) == datetime.date(2021, 12, 24),
+                ).first()
+            )
+        )
+
+    # Test added for SPARK-37738; change Python API to accept both col & int 
as input
+    def test_add_months_function(self):
+        dt = datetime.date(2021, 12, 27)
+
+        # Note; number in Python gets converted to LongType column;
+        # this is not supported by the function, so cast to Integer explicitly
+        df = self.spark.createDataFrame([Row(date=dt, add=2)], "date date, add 
integer")
+
+        self.assertTrue(
+            all(
+                df.select(
+                    add_months(df.date, df.add) == datetime.date(2022, 2, 27),
+                    add_months(df.date, "add") == datetime.date(2022, 2, 27),
+                    add_months(df.date, 3) == datetime.date(2022, 3, 27),
+                ).first()
+            )
+        )
+
     def test_make_date(self):
         # SPARK-36554: expose make_date expression
         df = self.spark.createDataFrame([(2020, 6, 26)], ["Y", "M", "D"])

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to