This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c03ebb467ec2 [SPARK-48272][SQL][PYTHON][CONNECT] Add function 
`timestamp_diff`
c03ebb467ec2 is described below

commit c03ebb467ec268d894f3d97bea388129a840f5cf
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 15 15:55:09 2024 +0800

    [SPARK-48272][SQL][PYTHON][CONNECT] Add function `timestamp_diff`
    
    ### What changes were proposed in this pull request?
    Add function `timestamp_diff` by reusing the existing proto
    
https://github.com/apache/spark/blob/c4df12cc884cddefcfcf8324b4d7b9349fb4f6a0/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L1971-L1974
    
    ### Why are the changes needed?
    This method is missing from the DataFrame API because it is not in
`FunctionRegistry`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new method `timestamp_diff` is added to the Scala, Python, and Connect APIs.
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46576 from zhengruifeng/df_ts_diff.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../scala/org/apache/spark/sql/functions.scala     |  10 ++++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 ++
 .../function_timestamp_diff.explain                |   2 +
 .../queries/function_timestamp_diff.json           |  33 ++++++++++++
 .../queries/function_timestamp_diff.proto.bin      | Bin 0 -> 145 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  |  10 ++--
 .../source/reference/pyspark.sql/functions.rst     |   1 +
 python/pyspark/sql/connect/functions/builtin.py    |   7 +++
 python/pyspark/sql/functions/builtin.py            |  60 +++++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     |  11 ++++
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |   3 +-
 11 files changed, 135 insertions(+), 6 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index bf41ada97916..c537f535c6b2 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -5953,6 +5953,16 @@ object functions {
    */
   def timestamp_micros(e: Column): Column = Column.fn("timestamp_micros", e)
 
+  /**
+   * Gets the difference between the timestamps in the specified units by 
truncating the fraction
+   * part.
+   *
+   * @group datetime_funcs
+   * @since 4.0.0
+   */
+  def timestamp_diff(unit: String, start: Column, end: Column): Column =
+    Column.fn("timestampdiff", lit(unit), start, end)
+
   /**
    * Parses the `timestamp` expression with the `format` expression to a 
timestamp without time
    * zone. Returns null with invalid input.
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 144b45bdfd31..e6955805d38d 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2305,6 +2305,10 @@ class PlanGenerationTestSuite
     fn.timestamp_micros(fn.col("x"))
   }
 
+  temporalFunctionTest("timestamp_diff") {
+    fn.timestamp_diff("year", fn.col("t"), fn.col("t"))
+  }
+
   // Array of Long
   // Array of Long
   // Array of Array of Long
diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_timestamp_diff.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_timestamp_diff.explain
new file mode 100644
index 000000000000..7a0a3ff8c53d
--- /dev/null
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_timestamp_diff.explain
@@ -0,0 +1,2 @@
+Project [timestampdiff(year, t#0, t#0, Some(America/Los_Angeles)) AS 
timestampdiff(year, t, t)#0L]
++- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_timestamp_diff.json
 
b/connector/connect/common/src/test/resources/query-tests/queries/function_timestamp_diff.json
new file mode 100644
index 000000000000..635cbb45460e
--- /dev/null
+++ 
b/connector/connect/common/src/test/resources/query-tests/queries/function_timestamp_diff.json
@@ -0,0 +1,33 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": 
"struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "timestampdiff",
+        "arguments": [{
+          "literal": {
+            "string": "year"
+          }
+        }, {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "t"
+          }
+        }, {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "t"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_timestamp_diff.proto.bin
 
b/connector/connect/common/src/test/resources/query-tests/queries/function_timestamp_diff.proto.bin
new file mode 100644
index 000000000000..3a81fd8b318c
Binary files /dev/null and 
b/connector/connect/common/src/test/resources/query-tests/queries/function_timestamp_diff.proto.bin
 differ
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 492dac12f614..149b2e482c1d 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1823,6 +1823,11 @@ class SparkConnectPlanner(
           new BloomFilterAggregate(children(0), children(1), children(2))
             .toAggregateExpression())
 
+      case "timestampdiff" if fun.getArgumentsCount == 3 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val unit = extractString(children(0), "unit")
+        Some(TimestampDiff(unit, children(1), children(2)))
+
       case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val timeCol = children.head
@@ -1968,11 +1973,6 @@ class SparkConnectPlanner(
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         Some(NullIndex(children(0)))
 
-      case "timestampdiff" if fun.getArgumentsCount == 3 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        val unit = extractString(children(0), "unit")
-        Some(TimestampDiff(unit, children(1), children(2)))
-
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst 
b/python/docs/source/reference/pyspark.sql/functions.rst
index fb3273bf95e7..16cf7e1337bb 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -281,6 +281,7 @@ Date and Timestamp Functions
     quarter
     second
     session_window
+    timestamp_diff
     timestamp_micros
     timestamp_millis
     timestamp_seconds
diff --git a/python/pyspark/sql/connect/functions/builtin.py 
b/python/pyspark/sql/connect/functions/builtin.py
index ea4fa2ba967b..a063c1b30165 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -3397,6 +3397,13 @@ def timestamp_micros(col: "ColumnOrName") -> Column:
 timestamp_micros.__doc__ = pysparkfuncs.timestamp_micros.__doc__
 
 
+def timestamp_diff(unit: str, start: "ColumnOrName", end: "ColumnOrName") -> 
Column:
+    return _invoke_function_over_columns("timestampdiff", lit(unit), start, 
end)
+
+
+timestamp_diff.__doc__ = pysparkfuncs.timestamp_diff.__doc__
+
+
 def window(
     timeColumn: "ColumnOrName",
     windowDuration: str,
diff --git a/python/pyspark/sql/functions/builtin.py 
b/python/pyspark/sql/functions/builtin.py
index bca86d907567..060c835df2be 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -9401,6 +9401,66 @@ def timestamp_micros(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("timestamp_micros", col)
 
 
+@_try_remote_functions
+def timestamp_diff(unit: str, start: "ColumnOrName", end: "ColumnOrName") -> 
Column:
+    """
+    Gets the difference between the timestamps in the specified units by 
truncating
+    the fraction part.
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    unit : str
+        This indicates the units of the difference between the given 
timestamps.
+        Supported options are (case insensitive): "YEAR", "QUARTER", "MONTH", 
"WEEK",
+        "DAY", "HOUR", "MINUTE", "SECOND", "MILLISECOND" and "MICROSECOND".
+    start : :class:`~pyspark.sql.Column` or str
+        A timestamp which the expression subtracts from `endTimestamp`.
+    end : :class:`~pyspark.sql.Column` or str
+        A timestamp from which the expression subtracts `startTimestamp`.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        the difference between the timestamps.
+
+    Examples
+    --------
+    >>> import datetime
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame(
+    ...     [(datetime.datetime(2016, 3, 11, 9, 0, 7), datetime.datetime(2024, 
4, 2, 9, 0, 7))],
+    ... ).toDF("start", "end")
+    >>> df.select(sf.timestamp_diff("year", "start", "end")).show()
+    +-------------------------------+
+    |timestampdiff(year, start, end)|
+    +-------------------------------+
+    |                              8|
+    +-------------------------------+
+    >>> df.select(sf.timestamp_diff("WEEK", "start", "end")).show()
+    +-------------------------------+
+    |timestampdiff(WEEK, start, end)|
+    +-------------------------------+
+    |                            420|
+    +-------------------------------+
+    >>> df.select(sf.timestamp_diff("day", "end", "start")).show()
+    +------------------------------+
+    |timestampdiff(day, end, start)|
+    +------------------------------+
+    |                         -2944|
+    +------------------------------+
+    """
+    from pyspark.sql.classic.column import _to_java_column
+
+    return _invoke_function(
+        "timestamp_diff",
+        unit,
+        _to_java_column(start),
+        _to_java_column(end),
+    )
+
+
 @_try_remote_functions
 def window(
     timeColumn: "ColumnOrName",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d1b449bf27aa..6a4d74cf158e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -5741,6 +5741,17 @@ object functions {
    */
   def timestamp_micros(e: Column): Column = Column.fn("timestamp_micros", e)
 
+  /**
+   * Gets the difference between the timestamps in the specified units by 
truncating
+   * the fraction part.
+   *
+   * @group datetime_funcs
+   * @since 4.0.0
+   */
+  def timestamp_diff(unit: String, start: Column, end: Column): Column = 
withExpr {
+    TimestampDiff(unit, start.expr, end.expr)
+  }
+
   /**
    * Parses the `timestamp` expression with the `format` expression
    * to a timestamp without time zone. Returns null with invalid input.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e42f397cbfc2..f4b16190dcd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -82,7 +82,8 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
       "bucket", "days", "hours", "months", "years", // Datasource v2 partition 
transformations
       "product", // Discussed in https://github.com/apache/spark/pull/30745
       "unwrap_udt",
-      "collect_top_k"
+      "collect_top_k",
+      "timestamp_diff"
     )
 
     // We only consider functions matching this pattern, this excludes 
symbolic and other


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to