This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6976ae76a650 [SPARK-55322][SQL] `MaxBy` and `MinBy` Overload with K
Elements
6976ae76a650 is described below
commit 6976ae76a650212dd614857d9128f5f75e8e6397
Author: Alexis Schlomer <[email protected]>
AuthorDate: Fri Feb 20 13:27:35 2026 -0800
[SPARK-55322][SQL] `MaxBy` and `MinBy` Overload with K Elements
### What changes were proposed in this pull request?
Adds an optional third parameter `k` to `max_by` and `min_by` aggregate
functions, enabling them to return an array of the top/bottom k values instead
of a single value.
```
-- Returns array of 2 values with highest y
SELECT max_by(x, y, 2) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS t(x,
y);
-- ["b", "c"]
-- Returns array of 2 values with lowest y
SELECT min_by(x, y, 2) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS t(x,
y);
-- ["a", "c"]
```
Implementation uses a bounded heap during aggregation, avoiding full sorts.
NULL ordering keys are excluded (consistent with existing max_by/min_by
behavior).
### Why are the changes needed?
Replaces verbose CTE + window function ranking patterns with a single
aggregation call. Aligns Spark with Snowflake, DuckDB, and Trino, which offer
similar functionality.
### Does this PR introduce _any_ user-facing change?
Yes. New 3-argument overloads for `max_by(expr, ord, k)` and `min_by(expr,
ord, k)` that return array<T> instead of T. Existing 2-argument versions are
unchanged. See example above.
### How was this patch tested?
Unit tests in `DataFrameAggregateSuite.scala` covering basic usage, edge
cases (k=1, k > rows, all NULLs), varied types (int, string, struct), GROUP BY,
DataFrame API, and error conditions.
New golden SQL file.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: claude-4.5-opus-high (co-author)
Closes #54134 from AlSchlo/master.
Lead-authored-by: Alexis Schlomer <[email protected]>
Co-authored-by: Alexis Schlomer <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
python/pyspark/sql/connect/functions/builtin.py | 8 +-
python/pyspark/sql/functions/builtin.py | 62 ++-
.../sql/tests/connect/test_connect_function.py | 10 +
python/pyspark/sql/tests/test_functions.py | 42 ++
.../scala/org/apache/spark/sql/functions.scala | 74 ++++
.../sql/catalyst/analysis/FunctionRegistry.scala | 4 +-
.../catalyst/expressions/aggregate/MaxMinByK.scala | 303 +++++++++++++++
.../expressions/aggregate/MaxMinByKHeap.scala | 170 ++++++++
.../sql-functions/sql-expression-schema.md | 4 +-
.../analyzer-results/max-min-by-k.sql.out | 432 +++++++++++++++++++++
.../resources/sql-tests/inputs/max-min-by-k.sql | 105 +++++
.../sql-tests/results/max-min-by-k.sql.out | 347 +++++++++++++++++
.../apache/spark/sql/DataFrameAggregateSuite.scala | 301 ++++++++++++++
13 files changed, 1850 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/sql/connect/functions/builtin.py
b/python/pyspark/sql/connect/functions/builtin.py
index b22fd224ae67..02109d2ef2cc 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -1236,7 +1236,9 @@ def max(col: "ColumnOrName") -> Column:
max.__doc__ = pysparkfuncs.max.__doc__
-def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
+def max_by(col: "ColumnOrName", ord: "ColumnOrName", k: Optional[int] = None)
-> Column:
+ if k is not None:
+ return _invoke_function_over_columns("max_by", col, ord, lit(k))
return _invoke_function_over_columns("max_by", col, ord)
@@ -1264,7 +1266,9 @@ def min(col: "ColumnOrName") -> Column:
min.__doc__ = pysparkfuncs.min.__doc__
-def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
+def min_by(col: "ColumnOrName", ord: "ColumnOrName", k: Optional[int] = None)
-> Column:
+ if k is not None:
+ return _invoke_function_over_columns("min_by", col, ord, lit(k))
return _invoke_function_over_columns("min_by", col, ord)
diff --git a/python/pyspark/sql/functions/builtin.py
b/python/pyspark/sql/functions/builtin.py
index 7bf6128b3d5e..58662be0e960 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -1329,17 +1329,23 @@ def min(col: "ColumnOrName") -> Column:
@_try_remote_functions
-def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
+def max_by(col: "ColumnOrName", ord: "ColumnOrName", k: Optional[int] = None)
-> Column:
"""
- Returns the value from the `col` parameter that is associated with the
maximum value
+ Returns the value(s) from the `col` parameter that are associated with the
maximum value(s)
from the `ord` parameter. This function is often used to find the `col`
parameter value
corresponding to the maximum `ord` parameter value within each group when
used with groupBy().
+ When `k` is specified, returns an array of up to `k` values associated
with the top `k`
+ maximum values from `ord`.
+
.. versionadded:: 3.3.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.2.0
+ Added optional `k` parameter to return top-k values.
+
Notes
-----
The function is non-deterministic so the output order can be different for
those
@@ -1353,12 +1359,16 @@ def max_by(col: "ColumnOrName", ord: "ColumnOrName") ->
Column:
ord : :class:`~pyspark.sql.Column` or column name
The column that needs to be maximized. This could be the column
instance
or the column name as string.
+ k : int, optional
+ If specified, returns an array of up to `k` values associated with the
top `k`
+ maximum ordering values, sorted in descending order by the ordering
column.
+ Must be a positive integer literal <= 100000.
Returns
-------
:class:`~pyspark.sql.Column`
A column object representing the value from `col` that is associated
with
- the maximum value from `ord`.
+ the maximum value from `ord`. If `k` is specified, returns an array of
values.
Examples
--------
@@ -1410,22 +1420,43 @@ def max_by(col: "ColumnOrName", ord: "ColumnOrName") ->
Column:
| Consult| Henry|
| Finance| George|
+----------+---------------------------+
+
+ Example 4: Using `max_by` with `k` to get top-k values
+
+ >>> import pyspark.sql.functions as sf
+ >>> df = spark.createDataFrame([
+ ... ("a", 10), ("b", 50), ("c", 20), ("d", 40)],
+ ... schema=("x", "y"))
+ >>> df.select(sf.max_by("x", "y", 2)).show()
+ +---------------+
+ |max_by(x, y, 2)|
+ +---------------+
+ | [b, d]|
+ +---------------+
"""
+ if k is not None:
+ return _invoke_function_over_columns("max_by", col, ord, lit(k))
return _invoke_function_over_columns("max_by", col, ord)
@_try_remote_functions
-def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
+def min_by(col: "ColumnOrName", ord: "ColumnOrName", k: Optional[int] = None)
-> Column:
"""
- Returns the value from the `col` parameter that is associated with the
minimum value
+ Returns the value(s) from the `col` parameter that are associated with the
minimum value(s)
from the `ord` parameter. This function is often used to find the `col`
parameter value
corresponding to the minimum `ord` parameter value within each group when
used with groupBy().
+ When `k` is specified, returns an array of up to `k` values associated
with the bottom `k`
+ minimum values from `ord`.
+
.. versionadded:: 3.3.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.2.0
+ Added optional `k` parameter to return bottom-k values.
+
Notes
-----
The function is non-deterministic so the output order can be different for
those
@@ -1439,12 +1470,16 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") ->
Column:
ord : :class:`~pyspark.sql.Column` or column name
The column that needs to be minimized. This could be the column
instance
or the column name as string.
+ k : int, optional
+ If specified, returns an array of up to `k` values associated with the
bottom `k`
+ minimum ordering values, sorted in ascending order by the ordering
column.
+ Must be a positive integer literal <= 100000.
Returns
-------
:class:`~pyspark.sql.Column`
Column object that represents the value from `col` associated with
- the minimum value from `ord`.
+ the minimum value from `ord`. If `k` is specified, returns an array of
values.
Examples
--------
@@ -1496,7 +1531,22 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") ->
Column:
| Consult| Eva|
| Finance| Frank|
+----------+---------------------------+
+
+ Example 4: Using `min_by` with `k` to get bottom-k values
+
+ >>> import pyspark.sql.functions as sf
+ >>> df = spark.createDataFrame([
+ ... ("a", 10), ("b", 50), ("c", 20), ("d", 40)],
+ ... schema=("x", "y"))
+ >>> df.select(sf.min_by("x", "y", 2)).show()
+ +---------------+
+ |min_by(x, y, 2)|
+ +---------------+
+ | [a, c]|
+ +---------------+
"""
+ if k is not None:
+ return _invoke_function_over_columns("min_by", col, ord, lit(k))
return _invoke_function_over_columns("min_by", col, ord)
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py
b/python/pyspark/sql/tests/connect/test_connect_function.py
index 9af96dbb56ea..de1584cb3bc3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -622,6 +622,16 @@ class SparkConnectFunctionTests(ReusedMixedTestCase,
PandasOnSparkTestUtils):
sdf.groupBy("a").agg(sfunc(sdf.b,
"c")).orderBy("a").toPandas(),
)
+ # test max_by and min_by with k parameter
+ self.assert_eq(
+ cdf.select(CF.max_by(cdf.b, "c", 2)).toPandas(),
+ sdf.select(SF.max_by(sdf.b, "c", 2)).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.min_by(cdf.b, "c", 2)).toPandas(),
+ sdf.select(SF.min_by(sdf.b, "c", 2)).toPandas(),
+ )
+
# test grouping
self.assert_eq(
cdf.cube("a").agg(CF.grouping("a"),
CF.sum("c")).orderBy("a").toPandas(),
diff --git a/python/pyspark/sql/tests/test_functions.py
b/python/pyspark/sql/tests/test_functions.py
index bfad39d6babc..b2a566afd500 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -3654,6 +3654,48 @@ class FunctionsTestsMixin:
)
self.assertEqual(results, [expected])
+ def test_max_by_min_by_with_k(self):
+ """Test max_by and min_by aggregate functions with k parameter"""
+ df = self.spark.createDataFrame(
+ [("a", 10), ("b", 50), ("c", 20), ("d", 40), ("e", 30)],
+ schema=("x", "y"),
+ )
+
+ # Test max_by with k
+ result = df.select(F.max_by("x", "y", 3)).collect()[0][0]
+ self.assertEqual(result, ["b", "d", "e"])
+
+ # Test min_by with k
+ result = df.select(F.min_by("x", "y", 3)).collect()[0][0]
+ self.assertEqual(result, ["a", "c", "e"])
+
+ # Test k = 1
+ result = df.select(F.max_by("x", "y", 1)).collect()[0][0]
+ self.assertEqual(result, ["b"])
+
+ result = df.select(F.min_by("x", "y", 1)).collect()[0][0]
+ self.assertEqual(result, ["a"])
+
+ # Test k larger than row count
+ result = df.select(F.max_by("x", "y", 10)).collect()[0][0]
+ self.assertEqual(sorted(result), ["a", "b", "c", "d", "e"])
+
+ # Test with groupBy
+ df2 = self.spark.createDataFrame(
+ [
+ ("Eng", "Alice", 120000),
+ ("Eng", "Bob", 95000),
+ ("Eng", "Carol", 110000),
+ ("Sales", "Dave", 80000),
+ ("Sales", "Eve", 75000),
+ ("Sales", "Frank", 85000),
+ ],
+ schema=("dept", "emp", "salary"),
+ )
+ result = df2.groupBy("dept").agg(F.max_by("emp", "salary",
2)).orderBy("dept").collect()
+ self.assertEqual(result[0][1], ["Alice", "Carol"]) # Eng
+ self.assertEqual(result[1][1], ["Frank", "Dave"]) # Sales
+
class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 261584ef186c..69cd1ca3dc9f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -938,6 +938,42 @@ object functions {
*/
def max_by(e: Column, ord: Column): Column = Column.fn("max_by", e, ord)
+ /**
+ * Aggregate function: returns an array of values associated with the top
`k` values of `ord`.
+ *
+ * The result array contains values in descending order by their associated
ordering values.
+ * Returns null if there are no non-null ordering values.
+ *
+ * @note
+ * The function is non-deterministic because the order of collected
results depends on the
+ * order of the rows which may be non-deterministic after a shuffle when
there are ties in the
+ * ordering expression.
+ * @note
+ * The maximum value of `k` is 100000.
+ *
+ * @group agg_funcs
+ * @since 4.2.0
+ */
+ def max_by(e: Column, ord: Column, k: Int): Column = Column.fn("max_by", e,
ord, lit(k))
+
+ /**
+ * Aggregate function: returns an array of values associated with the top
`k` values of `ord`.
+ *
+ * The result array contains values in descending order by their associated
ordering values.
+ * Returns null if there are no non-null ordering values.
+ *
+ * @note
+ * The function is non-deterministic because the order of collected
results depends on the
+ * order of the rows which may be non-deterministic after a shuffle when
there are ties in the
+ * ordering expression.
+ * @note
+ * The maximum value of `k` is 100000.
+ *
+ * @group agg_funcs
+ * @since 4.2.0
+ */
+ def max_by(e: Column, ord: Column, k: Column): Column = Column.fn("max_by",
e, ord, k)
+
/**
* Aggregate function: returns the average of the values in a group. Alias
for avg.
*
@@ -990,6 +1026,44 @@ object functions {
*/
def min_by(e: Column, ord: Column): Column = Column.fn("min_by", e, ord)
+ /**
+ * Aggregate function: returns an array of values associated with the bottom
`k` values of
+ * `ord`.
+ *
+ * The result array contains values in ascending order by their associated
ordering values.
+ * Returns null if there are no non-null ordering values.
+ *
+ * @note
+ * The function is non-deterministic because the order of collected
results depends on the
+ * order of the rows which may be non-deterministic after a shuffle when
there are ties in the
+ * ordering expression.
+ * @note
+ * The maximum value of `k` is 100000.
+ *
+ * @group agg_funcs
+ * @since 4.2.0
+ */
+ def min_by(e: Column, ord: Column, k: Int): Column = Column.fn("min_by", e,
ord, lit(k))
+
+ /**
+ * Aggregate function: returns an array of values associated with the bottom
`k` values of
+ * `ord`.
+ *
+ * The result array contains values in ascending order by their associated
ordering values.
+ * Returns null if there are no non-null ordering values.
+ *
+ * @note
+ * The function is non-deterministic because the order of collected
results depends on the
+ * order of the rows which may be non-deterministic after a shuffle when
there are ties in the
+ * ordering expression.
+ * @note
+ * The maximum value of `k` is 100000.
+ *
+ * @group agg_funcs
+ * @since 4.2.0
+ */
+ def min_by(e: Column, ord: Column, k: Column): Column = Column.fn("min_by",
e, ord, k)
+
/**
* Aggregate function: returns the exact percentile(s) of numeric column
`expr` at the given
* percentage(s) with value range in [0.0, 1.0].
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9468b14f6da1..86c705abb6f3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -485,10 +485,10 @@ object FunctionRegistry {
expression[Last]("last"),
expression[Last]("last_value", true),
expression[Max]("max"),
- expression[MaxBy]("max_by"),
+ expressionBuilder("max_by", MaxByBuilder),
expression[Average]("mean", true),
expression[Min]("min"),
- expression[MinBy]("min_by"),
+ expressionBuilder("min_by", MinByBuilder),
expression[Percentile]("percentile"),
expressionBuilder("percentile_cont", PercentileContBuilder),
expressionBuilder("percentile_disc", PercentileDiscBuilder),
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxMinByK.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxMinByK.scala
new file mode 100644
index 000000000000..a0a7acd93097
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxMinByK.scala
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder,
TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types._
+
+/**
+ * Returns top/bottom K values ordered by orderingExpr.
+ * Uses a heap (min-heap for max_by, max-heap for min_by) to efficiently
maintain K elements.
+ * This is the internal implementation used by max_by(x, y, k) and min_by(x,
y, k).
+ * Returns NULL if there are no non-NULL ordering values or the input is empty.
+ */
+case class MaxMinByK(
+ valueExpr: Expression,
+ orderingExpr: Expression,
+ kExpr: Expression,
+ reverse: Boolean = false,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends ImperativeAggregate
+ with TernaryLike[Expression]
+ with ImplicitCastInputTypes {
+
+ def this(valueExpr: Expression, orderingExpr: Expression, kExpr: Expression)
=
+ this(valueExpr, orderingExpr, kExpr, false, 0, 0)
+
+ def this(valueExpr: Expression, orderingExpr: Expression, kExpr: Expression,
reverse: Boolean) =
+ this(valueExpr, orderingExpr, kExpr, reverse, 0, 0)
+
+ final val MAX_K = 100000
+ // After ImplicitCastInputTypes casts kExpr to IntegerType and
+ // checkInputDataTypes() validates foldability and range, eval() is safe
here.
+ lazy val k: Int = kExpr.eval().asInstanceOf[Int]
+
+ override def first: Expression = valueExpr
+ override def second: Expression = orderingExpr
+ override def third: Expression = kExpr
+
+ override def prettyName: String = if (reverse) "min_by" else "max_by"
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = ArrayType(valueExpr.dataType, containsNull
= true)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(
+ AnyDataType,
+ AnyDataType,
+ IntegerType
+ )
+
+ private lazy val valuesAttr = AttributeReference(
+ "values",
+ ArrayType(valueExpr.dataType, containsNull = true),
+ nullable = false
+ )()
+ private lazy val orderingsAttr = AttributeReference(
+ "orderings",
+ ArrayType(orderingExpr.dataType, containsNull = true),
+ nullable = false
+ )()
+ private lazy val heapIndicesAttr = AttributeReference(
+ "heapIndices",
+ BinaryType,
+ nullable = false
+ )()
+
+ override lazy val aggBufferAttributes: Seq[AttributeReference] =
+ Seq(valuesAttr, orderingsAttr, heapIndicesAttr)
+
+ private val VALUES_OFFSET = 0
+ private val ORDERINGS_OFFSET = 1
+ private val HEAP_OFFSET = 2
+
+ override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ override def aggBufferSchema: StructType =
DataTypeUtils.fromAttributes(aggBufferAttributes)
+ override def defaultResult: Option[Literal] = Option(Literal.create(null,
dataType))
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val parentCheck = super.checkInputDataTypes()
+ if (!parentCheck.isSuccess) return parentCheck
+
+ if (!kExpr.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> "k",
+ "inputType" -> kExpr.dataType.catalogString,
+ "inputExpr" -> kExpr.sql
+ )
+ )
+ }
+
+ val orderingCheck = TypeUtils.checkForOrderingExpr(orderingExpr.dataType,
prettyName)
+ if (!orderingCheck.isSuccess) return orderingCheck
+
+ if (k < 1 || k > MAX_K) {
+ return DataTypeMismatch(
+ errorSubClass = "VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "exprName" -> toSQLId("k"),
+ "valueRange" -> s"[1, $MAX_K]",
+ "currentValue" -> k.toString
+ )
+ )
+ }
+
+ TypeCheckResult.TypeCheckSuccess
+ }
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(orderingExpr.dataType)
+
+ // max_by uses min-heap (smaller at top), min_by uses max-heap (larger at
top)
+ private def heapCompare(ordA: Any, ordB: Any): Int =
+ if (reverse) -ordering.compare(ordA, ordB) else ordering.compare(ordA,
ordB)
+
+ override def initialize(buffer: InternalRow): Unit = {
+ val offset = mutableAggBufferOffset
+ buffer.update(offset + VALUES_OFFSET, new GenericArrayData(new
Array[Any](k)))
+ buffer.update(offset + ORDERINGS_OFFSET, new GenericArrayData(new
Array[Any](k)))
+ // heapBytes is binary: [size (4 bytes), idx0 (4 bytes), ..., idx(k-1) (4
bytes)]
+ val heapBytes = new Array[Byte]((k + 1) * 4)
+ // size is already 0 (zero-initialized)
+ buffer.update(offset + HEAP_OFFSET, heapBytes)
+ }
+
+ override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow):
Unit = {
+ val ord = orderingExpr.eval(inputRow)
+ if (ord == null) return
+
+ val value = valueExpr.eval(inputRow)
+ val offset = mutableAggBufferOffset
+
+ val valuesArr = MaxMinByKHeap.getMutableArray(
+ mutableAggBuffer, offset + VALUES_OFFSET, valueExpr.dataType)
+ val orderingsArr = MaxMinByKHeap.getMutableArray(
+ mutableAggBuffer, offset + ORDERINGS_OFFSET, orderingExpr.dataType)
+ val heap = MaxMinByKHeap.getMutableHeap(mutableAggBuffer, offset +
HEAP_OFFSET)
+
+ MaxMinByKHeap.insert(value, ord, k, valuesArr, orderingsArr, heap,
heapCompare)
+ }
+
+ override def merge(mutableAggBuffer: InternalRow, inputAggBuffer:
InternalRow): Unit = {
+ val offset = mutableAggBufferOffset
+ val inOff = inputAggBufferOffset
+
+ val valuesArr = MaxMinByKHeap.getMutableArray(
+ mutableAggBuffer, offset + VALUES_OFFSET, valueExpr.dataType)
+ val orderingsArr = MaxMinByKHeap.getMutableArray(
+ mutableAggBuffer, offset + ORDERINGS_OFFSET, orderingExpr.dataType)
+ val heap = MaxMinByKHeap.getMutableHeap(mutableAggBuffer, offset +
HEAP_OFFSET)
+
+ val inputValues = MaxMinByKHeap.readArray(
+ inputAggBuffer.getArray(inOff + VALUES_OFFSET), valueExpr.dataType)
+ val inputOrderings = MaxMinByKHeap.readArray(
+ inputAggBuffer.getArray(inOff + ORDERINGS_OFFSET), orderingExpr.dataType)
+ val inputHeap = inputAggBuffer.getBinary(inOff + HEAP_OFFSET)
+ val inputHeapSize = MaxMinByKHeap.getSize(inputHeap)
+
+ for (i <- 0 until inputHeapSize) {
+ val idx = MaxMinByKHeap.getIdx(inputHeap, i)
+ val inputOrd = inputOrderings(idx)
+ if (inputOrd != null) {
+ MaxMinByKHeap.insert(inputValues(idx), inputOrd, k, valuesArr,
orderingsArr, heap,
+ heapCompare)
+ }
+ }
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ val offset = mutableAggBufferOffset
+
+ val valuesArr = MaxMinByKHeap.getMutableArray(
+ buffer, offset + VALUES_OFFSET, valueExpr.dataType)
+ val orderingsArr = MaxMinByKHeap.getMutableArray(
+ buffer, offset + ORDERINGS_OFFSET, orderingExpr.dataType)
+ val heap = MaxMinByKHeap.getMutableHeap(buffer, offset + HEAP_OFFSET)
+ val heapSize = MaxMinByKHeap.getSize(heap)
+
+ if (heapSize == 0) return null
+
+ val elements = new Array[(Any, Any)](heapSize)
+ for (i <- 0 until heapSize) {
+ elements(i) = (InternalRow.copyValue(valuesArr(i)), orderingsArr(i))
+ }
+
+ // Sort result array (heap maintains K elements but not in sorted order).
+ val sorted = if (reverse) {
+ elements.sortWith { (a, b) => ordering.compare(a._2, b._2) < 0 }
+ } else {
+ elements.sortWith { (a, b) => ordering.compare(a._2, b._2) > 0 }
+ }
+ new GenericArrayData(sorted.map(_._1))
+ }
+
+ override def withNewMutableAggBufferOffset(newOffset: Int):
ImperativeAggregate =
+ copy(mutableAggBufferOffset = newOffset)
+
+ override def withNewInputAggBufferOffset(newOffset: Int):
ImperativeAggregate =
+ copy(inputAggBufferOffset = newOffset)
+
+ override protected def withNewChildrenInternal(
+ newFirst: Expression,
+ newSecond: Expression,
+ newThird: Expression): MaxMinByK =
+ copy(valueExpr = newFirst, orderingExpr = newSecond, kExpr = newThird)
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(x, y) - Returns the value of `x` associated with the maximum value
of `y`.
+ _FUNC_(x, y, k) - Returns an array of the `k` values of `x` associated
with the
+ maximum values of `y`, sorted in descending order by `y`.
+ Returns NULL if there are no non-NULL ordering values.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS
tab(x, y);
+ b
+ > SELECT _FUNC_(x, y, 2) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS
tab(x, y);
+ ["b","c"]
+ """,
+ note = """
+ The function is non-deterministic so the output order can be different for
+ those associated with the same values of `y`.
+
+ The maximum value of `k` is 100000.
+ """,
+ group = "agg_funcs",
+ since = "4.2.0")
+// scalastyle:on line.size.limit
+object MaxByBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ expressions.length match {
+ case 2 => MaxBy(expressions(0), expressions(1))
+ case 3 => new MaxMinByK(expressions(0), expressions(1), expressions(2),
reverse = false)
+ case n =>
+ throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), n)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(x, y) - Returns the value of `x` associated with the minimum value
of `y`.
+ _FUNC_(x, y, k) - Returns an array of the `k` values of `x` associated
with the
+ minimum values of `y`, sorted in ascending order by `y`.
+ Returns NULL if there are no non-NULL ordering values.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS
tab(x, y);
+ a
+ > SELECT _FUNC_(x, y, 2) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS
tab(x, y);
+ ["a","c"]
+ """,
+ note = """
+ The function is non-deterministic so the output order can be different for
+ those associated with the same values of `y`.
+
+ The maximum value of `k` is 100000.
+ """,
+ group = "agg_funcs",
+ since = "4.2.0")
+// scalastyle:on line.size.limit
+object MinByBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ expressions.length match {
+ case 2 => MinBy(expressions(0), expressions(1))
+ case 3 => new MaxMinByK(expressions(0), expressions(1), expressions(2),
reverse = true)
+ case n =>
+ throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), n)
+ }
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxMinByKHeap.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxMinByKHeap.scala
new file mode 100644
index 000000000000..0e855d099906
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxMinByKHeap.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Helper for MaxMinByK aggregate providing heap operations.
+ * Heap operates on indices to avoid copying large values.
+ *
+ * Binary heap layout: [size (4 bytes), idx0 (4 bytes), idx1 (4 bytes), ...,
idx(k-1) (4 bytes)]
+ * Total size: (k + 1) * 4 bytes
+ *
+ * All integers are stored in native byte order via Platform (Unsafe) for
direct access.
+ */
+private[catalyst] object MaxMinByKHeap {
+
+ def getSize(heap: Array[Byte]): Int =
+ Platform.getInt(heap, Platform.BYTE_ARRAY_OFFSET)
+
+ def setSize(heap: Array[Byte], size: Int): Unit =
+ Platform.putInt(heap, Platform.BYTE_ARRAY_OFFSET, size)
+
+ def getIdx(heap: Array[Byte], pos: Int): Int =
+ Platform.getInt(heap, Platform.BYTE_ARRAY_OFFSET + (pos + 1).toLong * 4)
+
+ def setIdx(heap: Array[Byte], pos: Int, idx: Int): Unit =
+ Platform.putInt(heap, Platform.BYTE_ARRAY_OFFSET + (pos + 1).toLong * 4,
idx)
+
+ def swap(heap: Array[Byte], i: Int, j: Int): Unit = {
+ val tmp = getIdx(heap, i)
+ setIdx(heap, i, getIdx(heap, j))
+ setIdx(heap, j, tmp)
+ }
+
+ def siftUp(
+ heap: Array[Byte],
+ pos: Int,
+ orderings: Array[Any],
+ compare: (Any, Any) => Int): Unit = {
+ var current = pos
+ while (current > 0) {
+ val parent = (current - 1) / 2
+ val curOrd = orderings(getIdx(heap, current))
+ val parOrd = orderings(getIdx(heap, parent))
+
+ if (compare(curOrd, parOrd) < 0) {
+ swap(heap, current, parent)
+ current = parent
+ } else {
+ return
+ }
+ }
+ }
+
+ def siftDown(
+ heap: Array[Byte],
+ pos: Int,
+ size: Int,
+ orderings: Array[Any],
+ compare: (Any, Any) => Int): Unit = {
+ var current = pos
+ while (2 * current + 1 < size) {
+ val left = 2 * current + 1
+ val right = left + 1
+ val leftOrd = orderings(getIdx(heap, left))
+
+ val preferred = if (right < size) {
+ val rightOrd = orderings(getIdx(heap, right))
+ if (compare(rightOrd, leftOrd) < 0) right else left
+ } else {
+ left
+ }
+
+ val curOrd = orderings(getIdx(heap, current))
+ val prefOrd = orderings(getIdx(heap, preferred))
+ if (compare(curOrd, prefOrd) <= 0) {
+ return
+ }
+
+ swap(heap, current, preferred)
+ current = preferred
+ }
+ }
+
+ /**
+ * Insert element into heap. If heap is full, replaces root if new element
is better.
+ */
+ def insert(
+ value: Any,
+ ord: Any,
+ k: Int,
+ valuesArr: Array[Any],
+ orderingsArr: Array[Any],
+ heap: Array[Byte],
+ compare: (Any, Any) => Int): Unit = {
+ val size = getSize(heap)
+ if (size < k) {
+ valuesArr(size) = InternalRow.copyValue(value)
+ orderingsArr(size) = InternalRow.copyValue(ord)
+
+ setIdx(heap, size, size)
+ siftUp(heap, size, orderingsArr, compare)
+ setSize(heap, size + 1)
+ } else if (compare(ord, orderingsArr(getIdx(heap, 0))) > 0) {
+ val idx = getIdx(heap, 0)
+ valuesArr(idx) = InternalRow.copyValue(value)
+ orderingsArr(idx) = InternalRow.copyValue(ord)
+
+ siftDown(heap, 0, size, orderingsArr, compare)
+ }
+ }
+
+ /**
+ * Get mutable array from buffer for in-place updates.
+ * Converts UnsafeArrayData (after spill) to GenericArrayData.
+ */
+ def getMutableArray(buffer: InternalRow, offset: Int, elementType:
DataType): Array[Any] = {
+ buffer.getArray(offset) match {
+ case g: GenericArrayData =>
+ g.array.asInstanceOf[Array[Any]]
+ case other =>
+ val size = other.numElements()
+ val newArr = new Array[Any](size)
+
+ for (i <- 0 until size) {
+ if (!other.isNullAt(i)) {
+ newArr(i) = InternalRow.copyValue(other.get(i, elementType))
+ }
+ }
+
+ val newArrayData = new GenericArrayData(newArr)
+ buffer.update(offset, newArrayData)
+ newArr
+ }
+ }
+
+ /**
+ * Get mutable heap binary buffer from buffer for in-place updates.
+ */
+ def getMutableHeap(buffer: InternalRow, offset: Int): Array[Byte] = {
+ buffer.getBinary(offset)
+ }
+
+ /** Read-only view of array data, used during merge to read input buffer. */
+ def readArray(arr: ArrayData, elementType: DataType): IndexedSeq[Any] = {
+ val size = arr.numElements()
+ (0 until size).map { i =>
+ if (arr.isNullAt(i)) null else arr.get(i, elementType)
+ }
+ }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 960cec325f54..80dc935bf506 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -492,10 +492,10 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | listagg |
SELECT listagg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) |
struct<listagg(col, NULL):string> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | string_agg |
SELECT string_agg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) |
struct<string_agg(col, NULL):string> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Max | max | SELECT
max(col) FROM VALUES (10), (50), (20) AS tab(col) | struct<max(col):int> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT
max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) |
struct<max_by(x, y):string> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.MaxByBuilder | max_by |
SELECT max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) |
struct<max_by(x, y):string> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Median | median | SELECT
median(col) FROM VALUES (0), (10) AS tab(col) | struct<median(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Min | min | SELECT
min(col) FROM VALUES (10), (-1), (20) AS tab(col) | struct<min(col):int> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.MinBy | min_by | SELECT
min_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) |
struct<min_by(x, y):string> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.MinByBuilder | min_by |
SELECT min_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) |
struct<min_by(x, y):string> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ModeBuilder | mode |
SELECT mode(col) FROM VALUES (0), (10), (10) AS tab(col) |
struct<mode(col):int> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Percentile | percentile
| SELECT percentile(col, 0.3) FROM VALUES (0), (10) AS tab(col) |
struct<percentile(col, 0.3, 1):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.PercentileContBuilder |
percentile_cont | SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY col) FROM
VALUES (0), (10) AS tab(col) | struct<percentile_cont(0.25) WITHIN GROUP (ORDER
BY col):double> |
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/max-min-by-k.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/max-min-by-k.sql.out
new file mode 100644
index 000000000000..a5095b351cc5
--- /dev/null
+++
b/sql/core/src/test/resources/sql-tests/analyzer-results/max-min-by-k.sql.out
@@ -0,0 +1,432 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW basic_data AS SELECT * FROM VALUES
+ ('Alice', 85),
+ ('Bob', 92),
+ ('Carol', 78),
+ ('Dave', 95),
+ ('Eve', 88)
+AS basic_data(name, score)
+-- !query analysis
+CreateViewCommand `basic_data`, SELECT * FROM VALUES
+ ('Alice', 85),
+ ('Bob', 92),
+ ('Carol', 78),
+ ('Dave', 95),
+ ('Eve', 88)
+AS basic_data(name, score), false, true, LocalTempView, UNSUPPORTED, true
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT max_by(name, score, 3) FROM basic_data
+-- !query analysis
+Aggregate [max_by(name#x, score#x, 3, false, 0, 0) AS max_by(name, score, 3)#x]
++- SubqueryAlias basic_data
+ +- View (`basic_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as int) AS
score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT min_by(name, score, 3) FROM basic_data
+-- !query analysis
+Aggregate [min_by(name#x, score#x, 3, true, 0, 0) AS min_by(name, score, 3)#x]
++- SubqueryAlias basic_data
+ +- View (`basic_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as int) AS
score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT max_by(name, score, 1) FROM basic_data
+-- !query analysis
+Aggregate [max_by(name#x, score#x, 1, false, 0, 0) AS max_by(name, score, 1)#x]
++- SubqueryAlias basic_data
+ +- View (`basic_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as int) AS
score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT min_by(name, score, 1) FROM basic_data
+-- !query analysis
+Aggregate [min_by(name#x, score#x, 1, true, 0, 0) AS min_by(name, score, 1)#x]
++- SubqueryAlias basic_data
+ +- View (`basic_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as int) AS
score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT max_by(name, score, 10) FROM basic_data
+-- !query analysis
+Aggregate [max_by(name#x, score#x, 10, false, 0, 0) AS max_by(name, score,
10)#x]
++- SubqueryAlias basic_data
+ +- View (`basic_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as int) AS
score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT min_by(name, score, 10) FROM basic_data
+-- !query analysis
+Aggregate [min_by(name#x, score#x, 10, true, 0, 0) AS min_by(name, score,
10)#x]
++- SubqueryAlias basic_data
+ +- View (`basic_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as int) AS
score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias basic_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW dept_data AS SELECT * FROM VALUES
+ ('Eng', 'Alice', 120000),
+ ('Eng', 'Bob', 95000),
+ ('Eng', 'Carol', 110000),
+ ('Sales', 'Dave', 80000),
+ ('Sales', 'Eve', 75000),
+ ('Sales', 'Frank', 85000)
+AS dept_data(dept, emp, salary)
+-- !query analysis
+CreateViewCommand `dept_data`, SELECT * FROM VALUES
+ ('Eng', 'Alice', 120000),
+ ('Eng', 'Bob', 95000),
+ ('Eng', 'Carol', 110000),
+ ('Sales', 'Dave', 80000),
+ ('Sales', 'Eve', 75000),
+ ('Sales', 'Frank', 85000)
+AS dept_data(dept, emp, salary), false, true, LocalTempView, UNSUPPORTED, true
+ +- Project [dept#x, emp#x, salary#x]
+ +- SubqueryAlias dept_data
+ +- LocalRelation [dept#x, emp#x, salary#x]
+
+
+-- !query
+SELECT dept, max_by(emp, salary, 2) FROM dept_data GROUP BY dept ORDER BY dept
+-- !query analysis
+Sort [dept#x ASC NULLS FIRST], true
++- Aggregate [dept#x], [dept#x, max_by(emp#x, salary#x, 2, false, 0, 0) AS
max_by(emp, salary, 2)#x]
+ +- SubqueryAlias dept_data
+ +- View (`dept_data`, [dept#x, emp#x, salary#x])
+ +- Project [cast(dept#x as string) AS dept#x, cast(emp#x as string)
AS emp#x, cast(salary#x as int) AS salary#x]
+ +- Project [dept#x, emp#x, salary#x]
+ +- SubqueryAlias dept_data
+ +- LocalRelation [dept#x, emp#x, salary#x]
+
+
+-- !query
+SELECT dept, min_by(emp, salary, 2) FROM dept_data GROUP BY dept ORDER BY dept
+-- !query analysis
+Sort [dept#x ASC NULLS FIRST], true
++- Aggregate [dept#x], [dept#x, min_by(emp#x, salary#x, 2, true, 0, 0) AS
min_by(emp, salary, 2)#x]
+ +- SubqueryAlias dept_data
+ +- View (`dept_data`, [dept#x, emp#x, salary#x])
+ +- Project [cast(dept#x as string) AS dept#x, cast(emp#x as string)
AS emp#x, cast(salary#x as int) AS salary#x]
+ +- Project [dept#x, emp#x, salary#x]
+ +- SubqueryAlias dept_data
+ +- LocalRelation [dept#x, emp#x, salary#x]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW null_data AS SELECT * FROM VALUES
+ ('a', 10),
+ ('b', NULL),
+ ('c', 30),
+ ('d', 20)
+AS null_data(x, y)
+-- !query analysis
+CreateViewCommand `null_data`, SELECT * FROM VALUES
+ ('a', 10),
+ ('b', NULL),
+ ('c', 30),
+ ('d', 20)
+AS null_data(x, y), false, true, LocalTempView, UNSUPPORTED, true
+ +- Project [x#x, y#x]
+ +- SubqueryAlias null_data
+ +- LocalRelation [x#x, y#x]
+
+
+-- !query
+SELECT max_by(x, y, 2) FROM null_data
+-- !query analysis
+Aggregate [max_by(x#x, y#x, 2, false, 0, 0) AS max_by(x, y, 2)#x]
++- SubqueryAlias null_data
+ +- View (`null_data`, [x#x, y#x])
+ +- Project [cast(x#x as string) AS x#x, cast(y#x as int) AS y#x]
+ +- Project [x#x, y#x]
+ +- SubqueryAlias null_data
+ +- LocalRelation [x#x, y#x]
+
+
+-- !query
+SELECT min_by(x, y, 2) FROM null_data
+-- !query analysis
+Aggregate [min_by(x#x, y#x, 2, true, 0, 0) AS min_by(x, y, 2)#x]
++- SubqueryAlias null_data
+ +- View (`null_data`, [x#x, y#x])
+ +- Project [cast(x#x as string) AS x#x, cast(y#x as int) AS y#x]
+ +- Project [x#x, y#x]
+ +- SubqueryAlias null_data
+ +- LocalRelation [x#x, y#x]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW null_value_data AS SELECT * FROM VALUES
+ (NULL, 10),
+ ('b', 20),
+ ('c', 30)
+AS null_value_data(x, y)
+-- !query analysis
+CreateViewCommand `null_value_data`, SELECT * FROM VALUES
+ (NULL, 10),
+ ('b', 20),
+ ('c', 30)
+AS null_value_data(x, y), false, true, LocalTempView, UNSUPPORTED, true
+ +- Project [x#x, y#x]
+ +- SubqueryAlias null_value_data
+ +- LocalRelation [x#x, y#x]
+
+
+-- !query
+SELECT max_by(x, y, 2) FROM null_value_data
+-- !query analysis
+Aggregate [max_by(x#x, y#x, 2, false, 0, 0) AS max_by(x, y, 2)#x]
++- SubqueryAlias null_value_data
+ +- View (`null_value_data`, [x#x, y#x])
+ +- Project [cast(x#x as string) AS x#x, cast(y#x as int) AS y#x]
+ +- Project [x#x, y#x]
+ +- SubqueryAlias null_value_data
+ +- LocalRelation [x#x, y#x]
+
+
+-- !query
+SELECT min_by(x, y, 2) FROM null_value_data
+-- !query analysis
+Aggregate [min_by(x#x, y#x, 2, true, 0, 0) AS min_by(x, y, 2)#x]
++- SubqueryAlias null_value_data
+ +- View (`null_value_data`, [x#x, y#x])
+ +- Project [cast(x#x as string) AS x#x, cast(y#x as int) AS y#x]
+ +- Project [x#x, y#x]
+ +- SubqueryAlias null_value_data
+ +- LocalRelation [x#x, y#x]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW typed_data AS SELECT * FROM VALUES
+ ('a', 1.5),
+ ('b', 2.5),
+ ('c', 0.5)
+AS typed_data(name, score)
+-- !query analysis
+CreateViewCommand `typed_data`, SELECT * FROM VALUES
+ ('a', 1.5),
+ ('b', 2.5),
+ ('c', 0.5)
+AS typed_data(name, score), false, true, LocalTempView, UNSUPPORTED, true
+ +- Project [name#x, score#x]
+ +- SubqueryAlias typed_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT max_by(name, score, 2) FROM typed_data
+-- !query analysis
+Aggregate [max_by(name#x, score#x, 2, false, 0, 0) AS max_by(name, score, 2)#x]
++- SubqueryAlias typed_data
+ +- View (`typed_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as
decimal(2,1)) AS score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias typed_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+SELECT min_by(name, score, 2) FROM typed_data
+-- !query analysis
+Aggregate [min_by(name#x, score#x, 2, true, 0, 0) AS min_by(name, score, 2)#x]
++- SubqueryAlias typed_data
+ +- View (`typed_data`, [name#x, score#x])
+ +- Project [cast(name#x as string) AS name#x, cast(score#x as
decimal(2,1)) AS score#x]
+ +- Project [name#x, score#x]
+ +- SubqueryAlias typed_data
+ +- LocalRelation [name#x, score#x]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW date_data AS SELECT * FROM VALUES
+ ('event1', DATE '2024-01-15'),
+ ('event2', DATE '2024-03-20'),
+ ('event3', DATE '2024-02-10')
+AS date_data(event, event_date)
+-- !query analysis
+CreateViewCommand `date_data`, SELECT * FROM VALUES
+ ('event1', DATE '2024-01-15'),
+ ('event2', DATE '2024-03-20'),
+ ('event3', DATE '2024-02-10')
+AS date_data(event, event_date), false, true, LocalTempView, UNSUPPORTED, true
+ +- Project [event#x, event_date#x]
+ +- SubqueryAlias date_data
+ +- LocalRelation [event#x, event_date#x]
+
+
+-- !query
+SELECT max_by(event, event_date, 2) FROM date_data
+-- !query analysis
+Aggregate [max_by(event#x, event_date#x, 2, false, 0, 0) AS max_by(event,
event_date, 2)#x]
++- SubqueryAlias date_data
+ +- View (`date_data`, [event#x, event_date#x])
+ +- Project [cast(event#x as string) AS event#x, cast(event_date#x as
date) AS event_date#x]
+ +- Project [event#x, event_date#x]
+ +- SubqueryAlias date_data
+ +- LocalRelation [event#x, event_date#x]
+
+
+-- !query
+SELECT min_by(event, event_date, 2) FROM date_data
+-- !query analysis
+Aggregate [min_by(event#x, event_date#x, 2, true, 0, 0) AS min_by(event,
event_date, 2)#x]
++- SubqueryAlias date_data
+ +- View (`date_data`, [event#x, event_date#x])
+ +- Project [cast(event#x as string) AS event#x, cast(event_date#x as
date) AS event_date#x]
+ +- Project [event#x, event_date#x]
+ +- SubqueryAlias date_data
+ +- LocalRelation [event#x, event_date#x]
+
+
+-- !query
+SELECT dept, emp, salary,
+ max_by(emp, salary, 2) OVER (PARTITION BY dept) as top2,
+ min_by(emp, salary, 2) OVER (PARTITION BY dept) as bottom2
+FROM dept_data
+ORDER BY dept, emp
+-- !query analysis
+Sort [dept#x ASC NULLS FIRST, emp#x ASC NULLS FIRST], true
++- Project [dept#x, emp#x, salary#x, top2#x, bottom2#x]
+ +- Project [dept#x, emp#x, salary#x, top2#x, bottom2#x, top2#x, bottom2#x]
+ +- Window [max_by(emp#x, salary#x, 2, false, 0, 0)
windowspecdefinition(dept#x, specifiedwindowframe(RowFrame,
unboundedpreceding$(), unboundedfollowing$())) AS top2#x, min_by(emp#x,
salary#x, 2, true, 0, 0) windowspecdefinition(dept#x,
specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$()))
AS bottom2#x], [dept#x]
+ +- Project [dept#x, emp#x, salary#x]
+ +- SubqueryAlias dept_data
+ +- View (`dept_data`, [dept#x, emp#x, salary#x])
+ +- Project [cast(dept#x as string) AS dept#x, cast(emp#x as
string) AS emp#x, cast(salary#x as int) AS salary#x]
+ +- Project [dept#x, emp#x, salary#x]
+ +- SubqueryAlias dept_data
+ +- LocalRelation [dept#x, emp#x, salary#x]
+
+
+-- !query
+SELECT max_by(name, score, 0) FROM basic_data
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "currentValue" : "0",
+ "exprName" : "`k`",
+ "sqlExpr" : "\"max_by(name, score, 0)\"",
+ "valueRange" : "[1, 100000]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 29,
+ "fragment" : "max_by(name, score, 0)"
+ } ]
+}
+
+
+-- !query
+SELECT max_by(name, score, -1) FROM basic_data
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "currentValue" : "-1",
+ "exprName" : "`k`",
+ "sqlExpr" : "\"max_by(name, score, -1)\"",
+ "valueRange" : "[1, 100000]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 30,
+ "fragment" : "max_by(name, score, -1)"
+ } ]
+}
+
+
+-- !query
+SELECT max_by(name, score, 100001) FROM basic_data
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "currentValue" : "100001",
+ "exprName" : "`k`",
+ "sqlExpr" : "\"max_by(name, score, 100001)\"",
+ "valueRange" : "[1, 100000]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "max_by(name, score, 100001)"
+ } ]
+}
+
+
+-- !query
+DROP VIEW basic_data
+-- !query analysis
+DropTempViewCommand basic_data
+
+
+-- !query
+DROP VIEW dept_data
+-- !query analysis
+DropTempViewCommand dept_data
+
+
+-- !query
+DROP VIEW null_data
+-- !query analysis
+DropTempViewCommand null_data
+
+
+-- !query
+DROP VIEW null_value_data
+-- !query analysis
+DropTempViewCommand null_value_data
+
+
+-- !query
+DROP VIEW typed_data
+-- !query analysis
+DropTempViewCommand typed_data
+
+
+-- !query
+DROP VIEW date_data
+-- !query analysis
+DropTempViewCommand date_data
diff --git a/sql/core/src/test/resources/sql-tests/inputs/max-min-by-k.sql
b/sql/core/src/test/resources/sql-tests/inputs/max-min-by-k.sql
new file mode 100644
index 000000000000..74ebc6fc56e8
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/max-min-by-k.sql
@@ -0,0 +1,105 @@
+-- Test max_by and min_by aggregate functions with k parameter
+
+-- Test data
+CREATE OR REPLACE TEMPORARY VIEW basic_data AS SELECT * FROM VALUES
+ ('Alice', 85),
+ ('Bob', 92),
+ ('Carol', 78),
+ ('Dave', 95),
+ ('Eve', 88)
+AS basic_data(name, score);
+
+-- Basic max_by with k
+SELECT max_by(name, score, 3) FROM basic_data;
+
+-- Basic min_by with k
+SELECT min_by(name, score, 3) FROM basic_data;
+
+-- k = 1 (should return array with single element)
+SELECT max_by(name, score, 1) FROM basic_data;
+SELECT min_by(name, score, 1) FROM basic_data;
+
+-- k larger than row count (should return all elements)
+SELECT max_by(name, score, 10) FROM basic_data;
+SELECT min_by(name, score, 10) FROM basic_data;
+
+-- GROUP BY test data
+CREATE OR REPLACE TEMPORARY VIEW dept_data AS SELECT * FROM VALUES
+ ('Eng', 'Alice', 120000),
+ ('Eng', 'Bob', 95000),
+ ('Eng', 'Carol', 110000),
+ ('Sales', 'Dave', 80000),
+ ('Sales', 'Eve', 75000),
+ ('Sales', 'Frank', 85000)
+AS dept_data(dept, emp, salary);
+
+-- max_by with GROUP BY
+SELECT dept, max_by(emp, salary, 2) FROM dept_data GROUP BY dept ORDER BY dept;
+
+-- min_by with GROUP BY
+SELECT dept, min_by(emp, salary, 2) FROM dept_data GROUP BY dept ORDER BY dept;
+
+-- NULL handling: NULL ordering values are skipped
+CREATE OR REPLACE TEMPORARY VIEW null_data AS SELECT * FROM VALUES
+ ('a', 10),
+ ('b', NULL),
+ ('c', 30),
+ ('d', 20)
+AS null_data(x, y);
+
+SELECT max_by(x, y, 2) FROM null_data;
+SELECT min_by(x, y, 2) FROM null_data;
+
+-- NULL values (not ordering) are preserved
+CREATE OR REPLACE TEMPORARY VIEW null_value_data AS SELECT * FROM VALUES
+ (NULL, 10),
+ ('b', 20),
+ ('c', 30)
+AS null_value_data(x, y);
+
+SELECT max_by(x, y, 2) FROM null_value_data;
+SELECT min_by(x, y, 2) FROM null_value_data;
+
+-- Different data types for ordering
+CREATE OR REPLACE TEMPORARY VIEW typed_data AS SELECT * FROM VALUES
+ ('a', 1.5),
+ ('b', 2.5),
+ ('c', 0.5)
+AS typed_data(name, score);
+
+SELECT max_by(name, score, 2) FROM typed_data;
+SELECT min_by(name, score, 2) FROM typed_data;
+
+-- Date ordering
+CREATE OR REPLACE TEMPORARY VIEW date_data AS SELECT * FROM VALUES
+ ('event1', DATE '2024-01-15'),
+ ('event2', DATE '2024-03-20'),
+ ('event3', DATE '2024-02-10')
+AS date_data(event, event_date);
+
+SELECT max_by(event, event_date, 2) FROM date_data;
+SELECT min_by(event, event_date, 2) FROM date_data;
+
+-- Window function test
+SELECT dept, emp, salary,
+ max_by(emp, salary, 2) OVER (PARTITION BY dept) as top2,
+ min_by(emp, salary, 2) OVER (PARTITION BY dept) as bottom2
+FROM dept_data
+ORDER BY dept, emp;
+
+-- Error case: k must be positive
+SELECT max_by(name, score, 0) FROM basic_data;
+
+-- Error case: k must be positive (negative)
+SELECT max_by(name, score, -1) FROM basic_data;
+
+-- Error case: k exceeds maximum limit (100000)
+SELECT max_by(name, score, 100001) FROM basic_data;
+
+-- Cleanup
+DROP VIEW basic_data;
+DROP VIEW dept_data;
+DROP VIEW null_data;
+DROP VIEW null_value_data;
+DROP VIEW typed_data;
+DROP VIEW date_data;
diff --git a/sql/core/src/test/resources/sql-tests/results/max-min-by-k.sql.out
b/sql/core/src/test/resources/sql-tests/results/max-min-by-k.sql.out
new file mode 100644
index 000000000000..88e47e7e7885
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/max-min-by-k.sql.out
@@ -0,0 +1,347 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW basic_data AS SELECT * FROM VALUES
+ ('Alice', 85),
+ ('Bob', 92),
+ ('Carol', 78),
+ ('Dave', 95),
+ ('Eve', 88)
+AS basic_data(name, score)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT max_by(name, score, 3) FROM basic_data
+-- !query schema
+struct<max_by(name, score, 3):array<string>>
+-- !query output
+["Dave","Bob","Eve"]
+
+
+-- !query
+SELECT min_by(name, score, 3) FROM basic_data
+-- !query schema
+struct<min_by(name, score, 3):array<string>>
+-- !query output
+["Carol","Alice","Eve"]
+
+
+-- !query
+SELECT max_by(name, score, 1) FROM basic_data
+-- !query schema
+struct<max_by(name, score, 1):array<string>>
+-- !query output
+["Dave"]
+
+
+-- !query
+SELECT min_by(name, score, 1) FROM basic_data
+-- !query schema
+struct<min_by(name, score, 1):array<string>>
+-- !query output
+["Carol"]
+
+
+-- !query
+SELECT max_by(name, score, 10) FROM basic_data
+-- !query schema
+struct<max_by(name, score, 10):array<string>>
+-- !query output
+["Dave","Bob","Eve","Alice","Carol"]
+
+
+-- !query
+SELECT min_by(name, score, 10) FROM basic_data
+-- !query schema
+struct<min_by(name, score, 10):array<string>>
+-- !query output
+["Carol","Alice","Eve","Bob","Dave"]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW dept_data AS SELECT * FROM VALUES
+ ('Eng', 'Alice', 120000),
+ ('Eng', 'Bob', 95000),
+ ('Eng', 'Carol', 110000),
+ ('Sales', 'Dave', 80000),
+ ('Sales', 'Eve', 75000),
+ ('Sales', 'Frank', 85000)
+AS dept_data(dept, emp, salary)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT dept, max_by(emp, salary, 2) FROM dept_data GROUP BY dept ORDER BY dept
+-- !query schema
+struct<dept:string,max_by(emp, salary, 2):array<string>>
+-- !query output
+Eng ["Alice","Carol"]
+Sales ["Frank","Dave"]
+
+
+-- !query
+SELECT dept, min_by(emp, salary, 2) FROM dept_data GROUP BY dept ORDER BY dept
+-- !query schema
+struct<dept:string,min_by(emp, salary, 2):array<string>>
+-- !query output
+Eng ["Bob","Carol"]
+Sales ["Eve","Dave"]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW null_data AS SELECT * FROM VALUES
+ ('a', 10),
+ ('b', NULL),
+ ('c', 30),
+ ('d', 20)
+AS null_data(x, y)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT max_by(x, y, 2) FROM null_data
+-- !query schema
+struct<max_by(x, y, 2):array<string>>
+-- !query output
+["c","d"]
+
+
+-- !query
+SELECT min_by(x, y, 2) FROM null_data
+-- !query schema
+struct<min_by(x, y, 2):array<string>>
+-- !query output
+["a","d"]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW null_value_data AS SELECT * FROM VALUES
+ (NULL, 10),
+ ('b', 20),
+ ('c', 30)
+AS null_value_data(x, y)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT max_by(x, y, 2) FROM null_value_data
+-- !query schema
+struct<max_by(x, y, 2):array<string>>
+-- !query output
+["c","b"]
+
+
+-- !query
+SELECT min_by(x, y, 2) FROM null_value_data
+-- !query schema
+struct<min_by(x, y, 2):array<string>>
+-- !query output
+[null,"b"]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW typed_data AS SELECT * FROM VALUES
+ ('a', 1.5),
+ ('b', 2.5),
+ ('c', 0.5)
+AS typed_data(name, score)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT max_by(name, score, 2) FROM typed_data
+-- !query schema
+struct<max_by(name, score, 2):array<string>>
+-- !query output
+["b","a"]
+
+
+-- !query
+SELECT min_by(name, score, 2) FROM typed_data
+-- !query schema
+struct<min_by(name, score, 2):array<string>>
+-- !query output
+["c","a"]
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW date_data AS SELECT * FROM VALUES
+ ('event1', DATE '2024-01-15'),
+ ('event2', DATE '2024-03-20'),
+ ('event3', DATE '2024-02-10')
+AS date_data(event, event_date)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT max_by(event, event_date, 2) FROM date_data
+-- !query schema
+struct<max_by(event, event_date, 2):array<string>>
+-- !query output
+["event2","event3"]
+
+
+-- !query
+SELECT min_by(event, event_date, 2) FROM date_data
+-- !query schema
+struct<min_by(event, event_date, 2):array<string>>
+-- !query output
+["event1","event3"]
+
+
+-- !query
+SELECT dept, emp, salary,
+ max_by(emp, salary, 2) OVER (PARTITION BY dept) as top2,
+ min_by(emp, salary, 2) OVER (PARTITION BY dept) as bottom2
+FROM dept_data
+ORDER BY dept, emp
+-- !query schema
+struct<dept:string,emp:string,salary:int,top2:array<string>,bottom2:array<string>>
+-- !query output
+Eng Alice 120000 ["Alice","Carol"] ["Bob","Carol"]
+Eng Bob 95000 ["Alice","Carol"] ["Bob","Carol"]
+Eng Carol 110000 ["Alice","Carol"] ["Bob","Carol"]
+Sales Dave 80000 ["Frank","Dave"] ["Eve","Dave"]
+Sales Eve 75000 ["Frank","Dave"] ["Eve","Dave"]
+Sales Frank 85000 ["Frank","Dave"] ["Eve","Dave"]
+
+
+-- !query
+SELECT max_by(name, score, 0) FROM basic_data
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "currentValue" : "0",
+ "exprName" : "`k`",
+ "sqlExpr" : "\"max_by(name, score, 0)\"",
+ "valueRange" : "[1, 100000]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 29,
+ "fragment" : "max_by(name, score, 0)"
+ } ]
+}
+
+
+-- !query
+SELECT max_by(name, score, -1) FROM basic_data
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "currentValue" : "-1",
+ "exprName" : "`k`",
+ "sqlExpr" : "\"max_by(name, score, -1)\"",
+ "valueRange" : "[1, 100000]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 30,
+ "fragment" : "max_by(name, score, -1)"
+ } ]
+}
+
+
+-- !query
+SELECT max_by(name, score, 100001) FROM basic_data
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "currentValue" : "100001",
+ "exprName" : "`k`",
+ "sqlExpr" : "\"max_by(name, score, 100001)\"",
+ "valueRange" : "[1, 100000]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "max_by(name, score, 100001)"
+ } ]
+}
+
+
+-- !query
+DROP VIEW basic_data
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW dept_data
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW null_data
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW null_value_data
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW typed_data
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW date_data
+-- !query schema
+struct<>
+-- !query output
+
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 6c14a869f201..34adee8f89b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1156,6 +1156,307 @@ class DataFrameAggregateSuite extends QueryTest
}
}
+ test("max_by and min_by with k") {
+ // End-to-end coverage for the 3-argument max_by/min_by overloads:
+ // result contents and ordering, k boundary values (1, > row count, the
+ // 100000 maximum), NULL handling in both value and ordering positions,
+ // mixed value/ordering types (int, string, struct, array, map, decimal),
+ // GROUP BY, the DataFrame API, and analysis-time errors on `k`.
+ // Basic: string values, integer ordering
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("b", "c"), Seq("a", "c")) :: Nil
+ )
+
+ // DataFrame API
+ checkAnswer(
+ spark.sql("SELECT * FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS
tab(x, y)")
+ .agg(max_by(col("x"), col("y"), 2), min_by(col("x"), col("y"), 2)),
+ Row(Seq("b", "c"), Seq("a", "c")) :: Nil
+ )
+
+ // k larger than available rows
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 5), min_by(x, y, 5)
+ |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("b", "c", "a"), Seq("a", "c", "b")) :: Nil
+ )
+
+ // k = 1
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 1), min_by(x, y, 1)
+ |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("b"), Seq("a")) :: Nil
+ )
+
+ // NULL orderings are skipped
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("c", "a"), Seq("a", "c")) :: Nil
+ )
+
+ // All NULL orderings yields null
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)
+ """.stripMargin),
+ Row(null, null) :: Nil
+ )
+
+ // Empty input yields null
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES (('a', 10)) AS tab(x, y) WHERE false
+ """.stripMargin),
+ Row(null, null) :: Nil
+ )
+
+ // Integer values, integer ordering
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES ((1, 100)), ((2, 200)), ((3, 150)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(2, 3), Seq(1, 3)) :: Nil
+ )
+
+ // 10 elements, k=3 - forces heap replacements
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 3), min_by(x, y, 3)
+ |FROM VALUES ((1, 50)), ((2, 30)), ((3, 80)), ((4, 10)), ((5, 90)),
+ | ((6, 20)), ((7, 70)), ((8, 40)), ((9, 60)), ((10, 100))
+ |AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(10, 5, 3), Seq(4, 6, 2)) :: Nil
+ )
+
+ // descending input order (worst case for min-heap)
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 3)
+ |FROM VALUES ((1, 100)), ((2, 90)), ((3, 80)), ((4, 70)), ((5, 60)),
+ | ((6, 50)), ((7, 40)), ((8, 30)), ((9, 20)), ((10, 10))
+ |AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(1, 2, 3)) :: Nil
+ )
+
+ // ascending input order (worst case for max-heap in min_by)
+ checkAnswer(
+ sql(
+ """
+ |SELECT min_by(x, y, 3)
+ |FROM VALUES ((1, 10)), ((2, 20)), ((3, 30)), ((4, 40)), ((5, 50)),
+ | ((6, 60)), ((7, 70)), ((8, 80)), ((9, 90)), ((10, 100))
+ |AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(1, 2, 3)) :: Nil
+ )
+
+ // Large k with many elements
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 5), min_by(x, y, 5)
+ |FROM VALUES ((1, 15)), ((2, 25)), ((3, 35)), ((4, 45)), ((5, 55)),
+ | ((6, 65)), ((7, 75)), ((8, 85))
+ |AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(8, 7, 6, 5, 4), Seq(1, 2, 3, 4, 5)) :: Nil
+ )
+
+ // Duplicate ordering values (non-deterministic order within ties, but set
should match)
+ val dupsResult = sql(
+ """
+ |SELECT max_by(x, y, 3)
+ |FROM VALUES ((1, 50)), ((2, 50)), ((3, 50)), ((4, 10)), ((5, 90))
+ |AS tab(x, y)
+ """.stripMargin).collect()(0).getSeq[Int](0).toSet
+ assert(dupsResult.contains(5)) // 90 is highest, must be included
+ assert(dupsResult.size == 3)
+
+ // Struct ordering
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES (('a', (10, 20))), (('b', (10, 50))), (('c', (10, 60)))
AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("c", "b"), Seq("a", "b")) :: Nil
+ )
+
+ // Array ordering
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES (('a', array(10, 20))), (('b', array(10, 50))), (('c',
array(10, 60)))
+ |AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("c", "b"), Seq("a", "b")) :: Nil
+ )
+
+ // Struct values
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES ((('a', 1), 10)), ((('b', 2), 50)), ((('c', 3), 20)) AS
tab(x, y)
+ """.stripMargin),
+ Row(Seq(Row("b", 2), Row("c", 3)), Seq(Row("a", 1), Row("c", 3))) :: Nil
+ )
+
+ // Array values
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES ((array(1, 2), 10)), ((array(3, 4), 50)), ((array(5,
6), 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(Seq(3, 4), Seq(5, 6)), Seq(Seq(1, 2), Seq(5, 6))) :: Nil
+ )
+
+ // Map values
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2), min_by(x, y, 2)
+ |FROM VALUES ((map('a', 1), 10)), ((map('b', 2), 50)), ((map('c',
3), 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq(Map("b" -> 2), Map("c" -> 3)), Seq(Map("a" -> 1), Map("c" ->
3))) :: Nil
+ )
+
+ // GROUP BY
+ checkAnswer(
+ sql(
+ """
+ |SELECT course, max_by(year, earnings, 2), min_by(year, earnings, 2)
+ |FROM VALUES
+ | (('Java', 2012, 20000)), (('Java', 2013, 30000)),
+ | (('dotNET', 2012, 15000)), (('dotNET', 2013, 48000))
+ |AS tab(course, year, earnings)
+ |GROUP BY course
+ |ORDER BY course
+ """.stripMargin),
+ Row("Java", Seq(2013, 2012), Seq(2012, 2013)) ::
+ Row("dotNET", Seq(2013, 2012), Seq(2012, 2013)) :: Nil
+ )
+
+ // Error: k must be a constant (not a column reference)
+ Seq("max_by", "min_by").foreach { fn =>
+ val error = intercept[AnalysisException] {
+ sql(s"SELECT $fn(x, y, z) FROM VALUES (('a', 10, 2)) AS tab(x, y,
z)").collect()
+ }
+ assert(error.getMessage.contains("NON_FOLDABLE_INPUT") ||
+ error.getMessage.contains("constant integer"))
+ }
+
+ // Float k is implicitly cast to integer (truncated) - 2.5 becomes 2
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 2.5), min_by(x, y, 2.5)
+ |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("b", "c"), Seq("a", "c")) :: Nil
+ )
+
+ // Error: string k cannot be cast to integer
+ Seq("max_by", "min_by").foreach { fn =>
+ val error = intercept[Exception] {
+ sql(s"SELECT $fn(x, y, 'two') FROM VALUES (('a', 10)) AS tab(x,
y)").collect()
+ }
+ assert(error.getMessage.contains("CAST_INVALID_INPUT") ||
+ error.getMessage.contains("cannot be cast"))
+ }
+
+ // Error: k must be positive
+ Seq("max_by", "min_by").foreach { fn =>
+ val error = intercept[Exception] {
+ sql(s"SELECT $fn(x, y, 0) FROM VALUES (('a', 10)) AS tab(x,
y)").collect()
+ }
+ assert(error.getMessage.contains("VALUE_OUT_OF_RANGE") ||
+ error.getMessage.contains("positive"))
+ }
+
+ // Error: non-orderable type (MAP)
+ withTempView("tempView") {
+ Seq((0, "a"), (1, "b"), (2, "c"))
+ .toDF("x", "y")
+ .select($"x", map($"x", $"y").as("y"))
+ .createOrReplaceTempView("tempView")
+ Seq("max_by", "min_by").foreach { fn =>
+ val mapError = intercept[AnalysisException] {
+ sql(s"SELECT $fn(x, y, 2) FROM tempView").collect()
+ }
+ assert(mapError.getMessage.contains("INVALID_ORDERING_TYPE") ||
+ mapError.getMessage.contains("not orderable"))
+ }
+ }
+
+ // Error: non-orderable type (ARRAY<MAP>)
+ withTempView("tempView") {
+ Seq((0, "a"), (1, "b"), (2, "c"))
+ .toDF("x", "y")
+ .select($"x", array(map($"x", $"y")).as("y"))
+ .createOrReplaceTempView("tempView")
+ Seq("max_by", "min_by").foreach { fn =>
+ val error = intercept[AnalysisException] {
+ sql(s"SELECT $fn(x, y, 2) FROM tempView").collect()
+ }
+ assert(error.getMessage.contains("INVALID_ORDERING_TYPE"))
+ }
+ }
+
+ // Error: non-orderable type (VARIANT)
+ withTempView("tempView") {
+ sql("SELECT 'a' as x, parse_json('{\"k\": 1}') as y")
+ .createOrReplaceTempView("tempView")
+ Seq("max_by", "min_by").foreach { fn =>
+ val error = intercept[AnalysisException] {
+ sql(s"SELECT $fn(x, y, 2) FROM tempView").collect()
+ }
+ assert(error.getMessage.contains("INVALID_ORDERING_TYPE"))
+ }
+ }
+
+ // Error: k exceeds maximum limit (100000)
+ Seq("max_by", "min_by").foreach { fn =>
+ val error = intercept[Exception] {
+ sql(s"SELECT $fn(x, y, 100001) FROM VALUES (('a', 10)) AS tab(x,
y)").collect()
+ }
+ assert(error.getMessage.contains("VALUE_OUT_OF_RANGE") ||
+ error.getMessage.contains("100000"))
+ }
+
+ // Large k
+ checkAnswer(
+ sql(
+ """
+ |SELECT max_by(x, y, 100000), min_by(x, y, 100000)
+ |FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)
+ """.stripMargin),
+ Row(Seq("b", "c", "a"), Seq("a", "c", "b")) :: Nil
+ )
+ }
+
test("percentile_like") {
// percentile
checkAnswer(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]