This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 13b0251 [SPARK-36963][SQL] Add max_by/min_by to sql.functions
13b0251 is described below
commit 13b02512db6e1d735994d79b92c9783e78f5f745
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Oct 11 09:47:04 2021 +0900
[SPARK-36963][SQL] Add max_by/min_by to sql.functions
### What changes were proposed in this pull request?
Add max_by/min_by to sql.functions
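For illustration, a minimal usage sketch (assuming Spark 3.3.0+, where these helpers exist; the `courseSales` data and column names mirror the example used in `DataFrameAggregateSuite`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, max_by, min_by}

val spark = SparkSession.builder().master("local[*]").appName("MaxByMinByExample").getOrCreate()
import spark.implicits._

// Hypothetical example data: (course, year, earnings).
val courseSales = Seq(
  ("dotNET", 2012, 10000), ("dotNET", 2013, 48000),
  ("Java", 2012, 20000), ("Java", 2013, 30000)
).toDF("course", "year", "earnings")

// For each course, pick the year with the highest / lowest earnings.
courseSales
  .groupBy("course")
  .agg(
    max_by(col("year"), col("earnings")).as("year_of_max_earnings"),
    min_by(col("year"), col("earnings")).as("year_of_min_earnings"))
  .show()
```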
### Why are the changes needed?
For convenience: the `max_by`/`min_by` aggregate expressions are already available in SQL, but there are no corresponding helpers in `sql.functions`.
### Does this PR introduce _any_ user-facing change?
Yes, the new methods `max_by` and `min_by` are added to `sql.functions`.
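Before this change, the same aggregation could be expressed from the DataFrame API only through SQL text or `expr`, for example (a sketch reusing the `courseSales` DataFrame from above and the pre-existing SQL `max_by` expression):

```scala
import org.apache.spark.sql.functions.expr

// Pre-existing workaround: go through the SQL expression rather than a typed helper.
courseSales
  .groupBy("course")
  .agg(expr("max_by(year, earnings)").as("year_of_max_earnings"))
  .show()
```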
### How was this patch tested?
Existing test suites and new test cases added to `DataFrameAggregateSuite`.
Closes #34229 from zhengruifeng/functions_add_max_min_by.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/scala/org/apache/spark/sql/functions.scala | 16 ++++++++++++++++
.../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++
2 files changed, 26 insertions(+)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7bca29f..b32c1f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -674,6 +674,14 @@ object functions {
def max(columnName: String): Column = max(Column(columnName))
/**
+ * Aggregate function: returns the value associated with the maximum value of ord.
+ *
+ * @group agg_funcs
+ * @since 3.3.0
+ */
+ def max_by(e: Column, ord: Column): Column = withAggregateFunction { MaxBy(e.expr, ord.expr) }
+
+ /**
* Aggregate function: returns the average of the values in a group.
* Alias for avg.
*
@@ -708,6 +716,14 @@ object functions {
def min(columnName: String): Column = min(Column(columnName))
/**
+ * Aggregate function: returns the value associated with the minimum value of ord.
+ *
+ * @group agg_funcs
+ * @since 3.3.0
+ */
+ def min_by(e: Column, ord: Column): Column = withAggregateFunction { MinBy(e.expr, ord.expr) }
+
+ /**
* Aggregate function: returns the approximate `percentile` of the numeric column `col` which
* is the smallest value in the ordered `col` values (sorted from least to greatest) such that
* no more than `percentage` of `col` values is less than the value or equal to that value.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1f8638c..c3076c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -876,6 +876,11 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
checkAnswer(
+ courseSales.groupBy("course").agg(max_by(col("year"), col("earnings"))),
+ Row("dotNET", 2013) :: Row("Java", 2013) :: Nil
+ )
+
+ checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c',
20)) AS tab(x, y)"),
Row("b") :: Nil
)
@@ -932,6 +937,11 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
checkAnswer(
+ courseSales.groupBy("course").agg(min_by(col("year"), col("earnings"))),
+ Row("dotNET", 2012) :: Row("Java", 2012) :: Nil
+ )
+
+ checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c',
20)) AS tab(x, y)"),
Row("a") :: Nil
)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]