MaxGekk commented on code in PR #42755:
URL: https://github.com/apache/spark/pull/42755#discussion_r1323175490


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala:
##########
@@ -35,24 +42,57 @@ import org.apache.spark.util.collection.OpenHashMap
        0-10
       > SELECT _FUNC_(col) FROM VALUES (0), (10), (10), (null), (null), (null) 
AS tab(col);
        10
+      > SELECT _FUNC_(col, false) FROM VALUES (-10), (0), (10) AS tab(col);
+       0
+      > SELECT _FUNC_(col, true) FROM VALUES (-10), (0), (10) AS tab(col);
+       -10
   """,
   group = "agg_funcs",
   since = "3.4.0")
 // scalastyle:on line.size.limit
 case class Mode(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer
-  with ImplicitCastInputTypes with UnaryLike[Expression] {
+    inputAggBufferOffset: Int = 0,
+    deterministicResult: Expression = Literal.FalseLiteral)
+  extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
+    with BinaryLike[Expression] {
 
   def this(child: Expression) = this(child, 0, 0)
 
+  def this(child: Expression, deterministicResult: Expression) = {
+    this(child, 0, 0, deterministicResult)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = deterministicResult
+
   // Returns null for empty inputs
   override def nullable: Boolean = true
 
   override def dataType: DataType = child.dataType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, 
BooleanType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      return defaultCheck
+    }
+    if (!deterministicResult.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "deterministicResult",

Review Comment:
   It is an identifier, so it should be quoted with `toSQLId`, e.g. `toSQLId("deterministicResult")`.



##########
python/pyspark/sql/functions.py:
##########
@@ -1195,12 +1195,12 @@ def mode(col: "ColumnOrName") -> Column:
     ...     ("dotNET", 2013, 48000), ("Java", 2013, 30000)],
     ...     schema=("course", "year", "earnings"))
     >>> df.groupby("course").agg(mode("year")).show()
-    +------+----------+
-    |course|mode(year)|
-    +------+----------+
-    |  Java|      2012|
-    |dotNET|      2012|
-    +------+----------+
+    +------+-----------------+
+    |course|mode(year, false)|

Review Comment:
   Shall we support the new `deterministicResult` parameter in the Python API too?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to