This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f5365d0dc59 [SPARK-45034][SQL] Support deterministic mode function
f5365d0dc59 is described below

commit f5365d0dc590d4965a269da223dbd72fbb764595
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Sun Sep 17 21:37:57 2023 +0300

    [SPARK-45034][SQL] Support deterministic mode function
    
    ### What changes were proposed in this pull request?
    This PR adds a new optional argument to the `mode` aggregate function
    to provide deterministic results. When multiple values have the same
    greatest frequency, the new boolean argument can be used to get the
    lowest or highest value instead of an arbitrary one.
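
    For example (an illustrative sketch, not part of the patch, using the
    DataFrame API added in this PR; every value below occurs once, so all
    values tie for the greatest frequency):

        val df = Seq(-10, 0, 10).toDF("col")
        df.select(mode(col("col"), true))   // deterministic: always -10, the lowest
        df.select(mode(col("col")))         // may return any of -10, 0 or 10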
    
    ### Why are the changes needed?
    To make the function more user-friendly.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds a new argument to the `mode` function.
    
    ### How was this patch tested?
    Added new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42755 from peter-toth/SPARK-45034-deterministic-mode-function.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../scala/org/apache/spark/sql/functions.scala     |  14 ++-
 .../explain-results/function_mode.explain          |   2 +-
 .../query-tests/queries/function_mode.json         |   4 +
 .../query-tests/queries/function_mode.proto.bin    | Bin 173 -> 179 bytes
 python/pyspark/sql/connect/functions.py            |   4 +-
 python/pyspark/sql/functions.py                    |  35 ++++--
 .../sql/catalyst/expressions/aggregate/Mode.scala  |  76 ++++++++++--
 .../scala/org/apache/spark/sql/functions.scala     |  16 ++-
 .../sql-functions/sql-expression-schema.md         |   2 +-
 .../sql-tests/analyzer-results/group-by.sql.out    | 120 ++++++++++++++++++-
 .../test/resources/sql-tests/inputs/group-by.sql   |  11 ++
 .../resources/sql-tests/results/group-by.sql.out   | 132 ++++++++++++++++++++-
 .../apache/spark/sql/DatasetAggregatorSuite.scala  |  10 ++
 13 files changed, 397 insertions(+), 29 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index b2102d4ba55..83f0ee64501 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -827,7 +827,19 @@ object functions {
    * @group agg_funcs
    * @since 3.4.0
    */
-  def mode(e: Column): Column = Column.fn("mode", e)
+  def mode(e: Column): Column = mode(e, deterministic = false)
+
+  /**
+   * Aggregate function: returns the most frequent value in a group.
+   *
+   * When multiple values have the same greatest frequency, any of the values is returned
+   * if deterministic is false or not defined, or the lowest value is returned if
+   * deterministic is true.
+   *
+   * @group agg_funcs
+   * @since 4.0.0
+   */
+  def mode(e: Column, deterministic: Boolean): Column = Column.fn("mode", e, lit(deterministic))
 
   /**
   * Aggregate function: returns the maximum value of the expression in a group.
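
A minimal usage sketch of the new Scala overload (illustrative only, not part
of the patch; assumes a SparkSession whose implicits are in scope):

    import org.apache.spark.sql.functions.{col, mode}

    val df = Seq(-10, 0, 10).toDF("col")      // every value occurs once, so all tie
    df.select(mode(col("col"))).show()        // non-deterministic: any of -10, 0, 10
    df.select(mode(col("col"), true)).show()  // deterministic: always -10, the lowest
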
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain
index dfa2113a2c3..28bbb44b0fd 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain
@@ -1,2 +1,2 @@
-Aggregate [mode(a#0, 0, 0) AS mode(a)#0]
+Aggregate [mode(a#0, 0, 0, false) AS mode(a, false)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json
index 8e8183e9e08..5c26edee803 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json
@@ -18,6 +18,10 @@
           "unresolvedAttribute": {
             "unparsedIdentifier": "a"
           }
+        }, {
+          "literal": {
+            "boolean": false
+          }
         }]
       }
     }]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin
index dca0953a387..cc115e43172 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin differ
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 892ad6e6295..f89b1aae500 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1136,8 +1136,8 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
 min_by.__doc__ = pysparkfuncs.min_by.__doc__
 
 
-def mode(col: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("mode", col)
+def mode(col: "ColumnOrName", deterministic: bool = False) -> Column:
+    return _invoke_function("mode", _to_col(col), lit(deterministic))
 
 
 mode.__doc__ = pysparkfuncs.mode.__doc__
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 31936241619..1e12b9bf469 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -737,16 +737,21 @@ def abs(col: "ColumnOrName") -> Column:
 
 
 @_try_remote_functions
-def mode(col: "ColumnOrName") -> Column:
+def mode(col: "ColumnOrName", deterministic: bool = False) -> Column:
     """
     Returns the most frequent value in a group.
 
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 4.0.0
+        Supports the deterministic argument.
+
     Parameters
     ----------
     col : :class:`~pyspark.sql.Column` or str
         target column to compute on.
+    deterministic : bool, optional
+        if there are multiple equally-frequent results then return the lowest
+        (defaults to false).
 
     Returns
     -------
@@ -765,14 +770,26 @@ def mode(col: "ColumnOrName") -> Column:
     ...     ("dotNET", 2013, 48000), ("Java", 2013, 30000)],
     ...     schema=("course", "year", "earnings"))
     >>> df.groupby("course").agg(mode("year")).show()
-    +------+----------+
-    |course|mode(year)|
-    +------+----------+
-    |  Java|      2012|
-    |dotNET|      2012|
-    +------+----------+
-    """
-    return _invoke_function_over_columns("mode", col)
+    +------+-----------------+
+    |course|mode(year, false)|
+    +------+-----------------+
+    |  Java|             2012|
+    |dotNET|             2012|
+    +------+-----------------+
+
+    When multiple values have the same greatest frequency, any of the values is
+    returned if deterministic is false or not defined, or the lowest value is
+    returned if deterministic is true.
+
+    >>> df2 = spark.createDataFrame([(-10,), (0,), (10,)], ["col"])
+    >>> df2.select(mode("col", False), mode("col", True)).show()
+    +----------------+---------------+
+    |mode(col, false)|mode(col, true)|
+    +----------------+---------------+
+    |               0|            -10|
+    +----------------+---------------+
+    """
+    return _invoke_function("mode", _to_java_column(col), deterministic)
 
 
 @_try_remote_functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index cad7d1f07dc..4ac44d9d2c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -18,15 +18,22 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes}
-import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, Literal}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
+import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType}
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr
+import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType}
 import org.apache.spark.util.collection.OpenHashMap
 
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(col) - Returns the most frequent value for the values within `col`. NULL values are ignored. If all the values are NULL, or there are 0 rows, returns NULL.",
+  usage = """
+    _FUNC_(col[, deterministic]) - Returns the most frequent value for the values within `col`. NULL values are ignored. If all the values are NULL, or there are 0 rows, returns NULL.
+      When multiple values have the same greatest frequency, any of the values is returned if `deterministic` is false or not defined, or the lowest value is returned if `deterministic` is true.""",
   examples = """
     Examples:
       > SELECT _FUNC_(col) FROM VALUES (0), (10), (10) AS tab(col);
@@ -35,6 +42,10 @@ import org.apache.spark.util.collection.OpenHashMap
        0-10
      > SELECT _FUNC_(col) FROM VALUES (0), (10), (10), (null), (null), (null) AS tab(col);
        10
+      > SELECT _FUNC_(col, false) FROM VALUES (-10), (0), (10) AS tab(col);
+       0
+      > SELECT _FUNC_(col, true) FROM VALUES (-10), (0), (10) AS tab(col);
+       -10
   """,
   group = "agg_funcs",
   since = "3.4.0")
@@ -42,17 +53,53 @@ import org.apache.spark.util.collection.OpenHashMap
 case class Mode(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer
-  with ImplicitCastInputTypes with UnaryLike[Expression] {
+    inputAggBufferOffset: Int = 0,
+    deterministicExpr: Expression = Literal.FalseLiteral)
+  extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
+    with BinaryLike[Expression] {
 
   def this(child: Expression) = this(child, 0, 0)
 
+  def this(child: Expression, deterministicExpr: Expression) = {
+    this(child, 0, 0, deterministicExpr)
+  }
+
+  @transient
+  protected lazy val deterministicResult = deterministicExpr.eval().asInstanceOf[Boolean]
+
+  override def left: Expression = child
+
+  override def right: Expression = deterministicExpr
+
   // Returns null for empty inputs
   override def nullable: Boolean = true
 
   override def dataType: DataType = child.dataType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      return defaultCheck
+    }
+    if (!deterministicExpr.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("deterministic"),
+          "inputType" -> toSQLType(deterministicExpr.dataType),
+          "inputExpr" -> toSQLExpr(deterministicExpr)
+        )
+      )
+    } else if (deterministicExpr.eval() == null) {
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        messageParameters = Map("exprName" -> toSQLId("deterministic")))
+    } else {
+      TypeCheckSuccess
+    }
+  }
 
   override def prettyName: String = "mode"
 
@@ -81,7 +128,16 @@ case class Mode(
       return null
     }
 
-    buffer.maxBy(_._2)._1
+    (if (deterministicResult) {
+      // When a deterministic result is required but multiple keys have the same greatest
+      // frequency, select the lowest key.
+      val defaultKeyOrdering =
+        PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]]
+      val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering.reverse)
+      buffer.maxBy { case (key, count) => (count, key) }(ordering)
+    } else {
+      buffer.maxBy(_._2)
+    })._1
   }
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode =
@@ -90,8 +146,8 @@ case class Mode(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Mode =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override protected def withNewChildInternal(newChild: Expression): Expression =
-    copy(child = newChild)
+  override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+    copy(child = newLeft, deterministicExpr = newRight)
 }
 
 /**
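
The tie-breaking trick in eval() above is worth isolating: ordering the
buffer entries by (count, key) with the key ordering reversed makes maxBy
pick the greatest count and, among ties, the lowest key. A self-contained
sketch in plain Scala (a standalone illustration, not Spark code):

    // counts per key, as in the aggregation buffer; all keys tie at count 1
    val buffer = Map[Long, Long](-10L -> 1L, 0L -> 1L, 10L -> 1L)

    // max by count first; the reversed key ordering makes the lowest key
    // the "largest" element among equal counts
    val ordering = Ordering.Tuple2(Ordering.Long, Ordering.Long.reverse)
    val winner = buffer.maxBy { case (key, count) => (count, key) }(ordering)._1
    println(winner)  // -10
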
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5935695818e..dcde01ec408 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -870,7 +870,21 @@ object functions {
    * @group agg_funcs
    * @since 3.4.0
    */
-  def mode(e: Column): Column = withAggregateFunction { Mode(e.expr) }
+  def mode(e: Column): Column = mode(e, deterministic = false)
+
+  /**
+   * Aggregate function: returns the most frequent value in a group.
+   *
+   * When multiple values have the same greatest frequency, any of the values is returned
+   * if deterministic is false or not defined, or the lowest value is returned if
+   * deterministic is true.
+   *
+   * @group agg_funcs
+   * @since 4.0.0
+   */
+  def mode(e: Column, deterministic: Boolean): Column = withAggregateFunction {
+    Mode(e.expr, deterministicExpr = lit(deterministic).expr)
+  }
 
   /**
   * Aggregate function: returns the maximum value of the expression in a group.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index f518a67e1fa..9e06d5ac58a 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -404,7 +404,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.Median | median | SELECT median(col) FROM VALUES (0), (10) AS tab(col) | struct<median(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Min | min | SELECT min(col) FROM VALUES (10), (-1), (20) AS tab(col) | struct<min(col):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.MinBy | min_by | SELECT min_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) | struct<min_by(x, y):string> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.Mode | mode | SELECT mode(col) FROM VALUES (0), (10), (10) AS tab(col) | struct<mode(col):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.Mode | mode | SELECT mode(col) FROM VALUES (0), (10), (10) AS tab(col) | struct<mode(col, false):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Percentile | percentile | SELECT percentile(col, 0.3) FROM VALUES (0), (10) AS tab(col) | struct<percentile(col, 0.3, 1):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgx(y, x):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgy(y, x):double> |
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
index 202ceee1804..56b2553045f 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
@@ -1155,7 +1155,7 @@ Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_lis
 -- !query
 SELECT mode(a), mode(b) FROM testData
 -- !query analysis
-Aggregate [mode(a#x, 0, 0) AS mode(a)#x, mode(b#x, 0, 0) AS mode(b)#x]
+Aggregate [mode(a#x, 0, 0, false) AS mode(a, false)#x, mode(b#x, 0, 0, false) AS mode(b, false)#x]
 +- SubqueryAlias testdata
    +- View (`testData`, [a#x,b#x])
       +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x]
@@ -1168,7 +1168,7 @@ Aggregate [mode(a#x, 0, 0) AS mode(a)#x, mode(b#x, 0, 0) AS mode(b)#x]
 SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a
 -- !query analysis
 Sort [a#x ASC NULLS FIRST], true
-+- Aggregate [a#x], [a#x, mode(b#x, 0, 0) AS mode(b)#x]
++- Aggregate [a#x], [a#x, mode(b#x, 0, 0, false) AS mode(b, false)#x]
    +- SubqueryAlias testdata
       +- View (`testData`, [a#x,b#x])
          +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x]
@@ -1196,3 +1196,119 @@ Aggregate [c#x], [(c#x * 2) AS d#x]
          +- Project [if ((a#x < 0)) 0 else a#x AS b#x]
             +- SubqueryAlias t1
                +- LocalRelation [a#x]
+
+
+-- !query
+SELECT mode(col) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query analysis
+Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x]
++- SubqueryAlias tab
+   +- LocalRelation [col#x]
+
+
+-- !query
+SELECT mode(col, false) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query analysis
+Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x]
++- SubqueryAlias tab
+   +- LocalRelation [col#x]
+
+
+-- !query
+SELECT mode(col, true) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query analysis
+Aggregate [mode(col#x, 0, 0, true) AS mode(col, true)#x]
++- SubqueryAlias tab
+   +- LocalRelation [col#x]
+
+
+-- !query
+SELECT mode(col, 'true') FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputSql" : "\"true\"",
+    "inputType" : "\"STRING\"",
+    "paramIndex" : "2",
+    "requiredType" : "\"BOOLEAN\"",
+    "sqlExpr" : "\"mode(col, true)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 24,
+    "fragment" : "mode(col, 'true')"
+  } ]
+}
+
+
+-- !query
+SELECT mode(col, null) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "exprName" : "`deterministic`",
+    "sqlExpr" : "\"mode(col, NULL)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 22,
+    "fragment" : "mode(col, null)"
+  } ]
+}
+
+
+-- !query
+SELECT mode(col, b) FROM VALUES (-10, false), (0, false), (10, false) AS tab(col, b)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"b\"",
+    "inputName" : "`deterministic`",
+    "inputType" : "\"BOOLEAN\"",
+    "sqlExpr" : "\"mode(col, b)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 19,
+    "fragment" : "mode(col, b)"
+  } ]
+}
+
+
+-- !query
+SELECT mode(col) FROM VALUES (map(1, 'a')) AS tab(col)
+-- !query analysis
+Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x]
++- SubqueryAlias tab
+   +- LocalRelation [col#x]
+
+
+-- !query
+SELECT mode(col, false) FROM VALUES (map(1, 'a')) AS tab(col)
+-- !query analysis
+Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x]
++- SubqueryAlias tab
+   +- LocalRelation [col#x]
+
+
+-- !query
+SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col)
+-- !query analysis
+Aggregate [mode(col#x, 0, 0, true) AS mode(col, true)#x]
++- SubqueryAlias tab
+   +- LocalRelation [col#x]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index c35cdb0de27..4b76510b65f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -264,3 +264,14 @@ FROM (
          GROUP BY b
      ) t3
 GROUP BY c;
+
+-- SPARK-45034: Support deterministic mode function
+SELECT mode(col) FROM VALUES (-10), (0), (10) AS tab(col);
+SELECT mode(col, false) FROM VALUES (-10), (0), (10) AS tab(col);
+SELECT mode(col, true) FROM VALUES (-10), (0), (10) AS tab(col);
+SELECT mode(col, 'true') FROM VALUES (-10), (0), (10) AS tab(col);
+SELECT mode(col, null) FROM VALUES (-10), (0), (10) AS tab(col);
+SELECT mode(col, b) FROM VALUES (-10, false), (0, false), (10, false) AS tab(col, b);
+SELECT mode(col) FROM VALUES (map(1, 'a')) AS tab(col);
+SELECT mode(col, false) FROM VALUES (map(1, 'a')) AS tab(col);
+SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col);
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index db79646fe43..ac92c369de2 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1089,7 +1089,7 @@ struct<a:int,collect_list(b):array<int>,collect_list(b):array<int>>
 -- !query
 SELECT mode(a), mode(b) FROM testData
 -- !query schema
-struct<mode(a):int,mode(b):int>
+struct<mode(a, false):int,mode(b, false):int>
 -- !query output
 3      1
 
@@ -1097,7 +1097,7 @@ struct<mode(a):int,mode(b):int>
 -- !query
 SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a
 -- !query schema
-struct<a:int,mode(b):int>
+struct<a:int,mode(b, false):int>
 -- !query output
 NULL   1
 1      1
@@ -1121,3 +1121,131 @@ struct<d:int>
 -- !query output
 0
 2
+
+
+-- !query
+SELECT mode(col) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query schema
+struct<mode(col, false):int>
+-- !query output
+0
+
+
+-- !query
+SELECT mode(col, false) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query schema
+struct<mode(col, false):int>
+-- !query output
+0
+
+
+-- !query
+SELECT mode(col, true) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query schema
+struct<mode(col, true):int>
+-- !query output
+-10
+
+
+-- !query
+SELECT mode(col, 'true') FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputSql" : "\"true\"",
+    "inputType" : "\"STRING\"",
+    "paramIndex" : "2",
+    "requiredType" : "\"BOOLEAN\"",
+    "sqlExpr" : "\"mode(col, true)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 24,
+    "fragment" : "mode(col, 'true')"
+  } ]
+}
+
+
+-- !query
+SELECT mode(col, null) FROM VALUES (-10), (0), (10) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "exprName" : "`deterministic`",
+    "sqlExpr" : "\"mode(col, NULL)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 22,
+    "fragment" : "mode(col, null)"
+  } ]
+}
+
+
+-- !query
+SELECT mode(col, b) FROM VALUES (-10, false), (0, false), (10, false) AS tab(col, b)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"b\"",
+    "inputName" : "`deterministic`",
+    "inputType" : "\"BOOLEAN\"",
+    "sqlExpr" : "\"mode(col, b)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 19,
+    "fragment" : "mode(col, b)"
+  } ]
+}
+
+
+-- !query
+SELECT mode(col) FROM VALUES (map(1, 'a')) AS tab(col)
+-- !query schema
+struct<mode(col, false):map<int,string>>
+-- !query output
+{1:"a"}
+
+
+-- !query
+SELECT mode(col, false) FROM VALUES (map(1, 'a')) AS tab(col)
+-- !query schema
+struct<mode(col, false):map<int,string>>
+-- !query output
+{1:"a"}
+
+
+-- !query
+SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkIllegalArgumentException
+{
+  "errorClass" : "_LEGACY_ERROR_TEMP_2005",
+  "messageParameters" : {
+    "dataType" : "PhysicalMapType"
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index e9daa825dd4..2de2d90e7dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -432,4 +432,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
     val agg = df.select(mode(col("a"))).as[String]
     checkDataset(agg, "3")
   }
+
+  test("SPARK-45034: Support deterministic mode function") {
+    val df = Seq(-10, 0, 10).toDF("col")
+
+    val agg = df.select(mode(col("col"), false))
+    checkAnswer(agg, Row(0))
+
+    val agg2 = df.select(mode(col("col"), true))
+    checkAnswer(agg2, Row(-10))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
