peter-toth commented on code in PR #42755:
URL: https://github.com/apache/spark/pull/42755#discussion_r1325516552


##########
python/pyspark/sql/functions.py:
##########
@@ -1195,12 +1195,12 @@ def mode(col: "ColumnOrName") -> Column:
     ...     ("dotNET", 2013, 48000), ("Java", 2013, 30000)],
     ...     schema=("course", "year", "earnings"))
     >>> df.groupby("course").agg(mode("year")).show()
-    +------+----------+
-    |course|mode(year)|
-    +------+----------+
-    |  Java|      2012|
-    |dotNET|      2012|
-    +------+----------+
+    +------+-----------------+
+    |course|mode(year, false)|

Review Comment:
   Thanks! This is actually a good question, and I think we should.
   I've added the new parameter to both the Scala and Python APIs, in both SQL
   and Connect, in the latest commit.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala:
##########
@@ -35,24 +42,57 @@ import org.apache.spark.util.collection.OpenHashMap
        0-10
       > SELECT _FUNC_(col) FROM VALUES (0), (10), (10), (null), (null), (null) AS tab(col);
        10
+      > SELECT _FUNC_(col, false) FROM VALUES (-10), (0), (10) AS tab(col);
+       0
+      > SELECT _FUNC_(col, true) FROM VALUES (-10), (0), (10) AS tab(col);
+       -10
   """,
   group = "agg_funcs",
   since = "3.4.0")
 // scalastyle:on line.size.limit
 case class Mode(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer
-  with ImplicitCastInputTypes with UnaryLike[Expression] {
+    inputAggBufferOffset: Int = 0,
+    deterministicResult: Expression = Literal.FalseLiteral)
+  extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
+    with BinaryLike[Expression] {
 
   def this(child: Expression) = this(child, 0, 0)
 
+  def this(child: Expression, deterministicResult: Expression) = {
+    this(child, 0, 0, deterministicResult)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = deterministicResult
+
   // Returns null for empty inputs
   override def nullable: Boolean = true
 
   override def dataType: DataType = child.dataType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      return defaultCheck
+    }
+    if (!deterministicResult.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "deterministicResult",

Review Comment:
   Ok, fixed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to