This is an automated email from the ASF dual-hosted git repository.
jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 20014b6c5b [CALCITE-5283] Add ARG_MIN, ARG_MAX (aka MIN_BY, MAX_BY)
aggregate functions
20014b6c5b is described below
commit 20014b6c5b9b57d29248206f19d63ef50f7a5c0f
Author: zoudan <[email protected]>
AuthorDate: Wed Nov 23 15:42:49 2022 +0800
[CALCITE-5283] Add ARG_MIN, ARG_MAX (aka MIN_BY, MAX_BY) aggregate functions
Close apache/calcite#2981
---
.../calcite/adapter/enumerable/RexImpTable.java | 69 ++++++++++++++++++
.../org/apache/calcite/runtime/SqlFunctions.java | 81 ++++++++++++++++++---
.../main/java/org/apache/calcite/sql/SqlKind.java | 6 ++
.../calcite/sql/fun/SqlBasicAggFunction.java | 9 +++
.../calcite/sql/fun/SqlLibraryOperators.java | 12 ++++
.../calcite/sql/fun/SqlStdOperatorTable.java | 18 +++++
.../org/apache/calcite/sql/type/OperandTypes.java | 23 ++++++
.../org/apache/calcite/util/BuiltInMethod.java | 6 ++
.../apache/calcite/test/SqlToRelConverterTest.java | 26 +++++++
.../org/apache/calcite/test/SqlValidatorTest.java | 10 +++
.../apache/calcite/test/SqlToRelConverterTest.xml | 52 ++++++++++++++
core/src/test/resources/sql/agg.iq | 82 ++++++++++++++++++++++
core/src/test/resources/sql/winagg.iq | 24 +++++++
site/_docs/reference.md | 4 ++
.../org/apache/calcite/test/SqlOperatorTest.java | 18 +++++
15 files changed, 431 insertions(+), 9 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 35d563e23e..e61c7731d6 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -148,7 +148,9 @@ import static
org.apache.calcite.sql.fun.SqlLibraryOperators.LEFT;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAX_BY;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MD5;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.MIN_BY;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MONTHNAME;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.POW;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_REPLACE;
@@ -181,6 +183,8 @@ import static
org.apache.calcite.sql.fun.SqlStdOperatorTable.ABS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ACOS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.AND;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ANY_VALUE;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARG_MAX;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARG_MIN;
import static
org.apache.calcite.sql.fun.SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASCII;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASIN;
@@ -734,6 +738,10 @@ public class RexImpTable {
constructorSupplier(MinMaxImplementor.class);
aggMap.put(MIN, minMax);
aggMap.put(MAX, minMax);
+ aggMap.put(ARG_MIN, constructorSupplier(ArgMinMaxImplementor.class));
+ aggMap.put(ARG_MAX, constructorSupplier(ArgMinMaxImplementor.class));
+ aggMap.put(MIN_BY, constructorSupplier(ArgMinMaxImplementor.class));
+ aggMap.put(MAX_BY, constructorSupplier(ArgMinMaxImplementor.class));
aggMap.put(ANY_VALUE, minMax);
aggMap.put(SOME, minMax);
aggMap.put(EVERY, minMax);
@@ -1204,6 +1212,67 @@ public class RexImpTable {
}
}
+ /** Implementor for the {@code ARG_MIN} and {@code ARG_MAX} aggregate
+ * functions. */
+ static class ArgMinMaxImplementor extends StrictAggImplementor {
+ @Override protected void implementNotNullReset(AggContext info,
+ AggResetContext reset) {
+ // acc[0] = null;
+ reset.currentBlock().add(
+ Expressions.statement(
+ Expressions.assign(reset.accumulator().get(0),
+ Expressions.constant(null))));
+
+ final Type compType = info.parameterTypes().get(1);
+ final Primitive p = Primitive.of(compType);
+ final boolean isMin = info.aggregation().kind == SqlKind.ARG_MIN;
+ final Object inf = p == null ? null : (isMin ? p.max : p.min);
+ //acc[1] = isMin ? {max value} : {min value};
+ reset.currentBlock().add(
+ Expressions.statement(
+ Expressions.assign(reset.accumulator().get(1),
+ Expressions.constant(inf, compType))));
+ }
+
+ @Override public void implementNotNullAdd(AggContext info,
+ AggAddContext add) {
+ Expression accComp = add.accumulator().get(1);
+ Expression argValue = add.arguments().get(0);
+ Expression argComp = add.arguments().get(1);
+ final Type compType = info.parameterTypes().get(1);
+ final Primitive p = Primitive.of(compType);
+ final boolean isMin = info.aggregation().kind == SqlKind.ARG_MIN;
+
+ final Method method =
+ (isMin
+ ? (p == null ? BuiltInMethod.LT_NULLABLE : BuiltInMethod.LT)
+ : (p == null ? BuiltInMethod.GT_NULLABLE : BuiltInMethod.GT))
+ .method;
+ Expression compareExpression =
+ Expressions.call(method.getDeclaringClass(),
+ method.getName(),
+ argComp,
+ accComp);
+
+ final BlockBuilder thenBlock =
+ new BlockBuilder(true, add.currentBlock());
+ thenBlock.add(
+ Expressions.statement(
+ Expressions.assign(add.accumulator().get(0), argValue)));
+ thenBlock.add(
+ Expressions.statement(
+ Expressions.assign(add.accumulator().get(1), argComp)));
+
+ add.currentBlock()
+ .add(Expressions.ifThen(compareExpression, thenBlock.toBlock()));
+ }
+
+ @Override public List<Type> getNotNullState(AggContext info) {
+ return ImmutableList.of(Object.class, // the result value
+ info.parameterTypes().get(1)); // the compare value
+ }
+ }
+
/** Implementor for the {@code SINGLE_VALUE} aggregate function. */
static class SingleValueImplementor implements AggImplementor {
@Override public List<Type> getStateType(AggContext info) {
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index bb617f8cf1..1f5c0f21b2 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -894,7 +894,7 @@ public class SqlFunctions {
/** SQL <code>&lt;</code> operator applied to boolean values. */
public static boolean lt(boolean b0, boolean b1) {
- return compare(b0, b1) < 0;
+ return Boolean.compare(b0, b1) < 0;
}
/** SQL <code>&lt;</code> operator applied to String values. */
@@ -917,6 +917,40 @@ public class SqlFunctions {
return b0.compareTo(b1) < 0;
}
+ /** Returns whether {@code b0} is less than {@code b1}
+ * (or {@code b1} is null). Helper for {@code ARG_MIN}. */
+ public static <T extends Comparable<T>> boolean ltNullable(T b0, T b1) {
+ return b1 == null || b0 != null && b0.compareTo(b1) < 0;
+ }
+
+ public static boolean lt(byte b0, byte b1) {
+ return b0 < b1;
+ }
+
+ public static boolean lt(char b0, char b1) {
+ return b0 < b1;
+ }
+
+ public static boolean lt(short b0, short b1) {
+ return b0 < b1;
+ }
+
+ public static boolean lt(int b0, int b1) {
+ return b0 < b1;
+ }
+
+ public static boolean lt(long b0, long b1) {
+ return b0 < b1;
+ }
+
+ public static boolean lt(float b0, float b1) {
+ return b0 < b1;
+ }
+
+ public static boolean lt(double b0, double b1) {
+ return b0 < b1;
+ }
+
/** SQL <code>&lt;</code> operator applied to Object values. */
public static boolean ltAny(Object b0, Object b1) {
if (b0.getClass().equals(b1.getClass())
@@ -934,7 +968,7 @@ public class SqlFunctions {
/** SQL <code>≤</code> operator applied to boolean values. */
public static boolean le(boolean b0, boolean b1) {
- return compare(b0, b1) <= 0;
+ return Boolean.compare(b0, b1) <= 0;
}
/** SQL <code>≤</code> operator applied to String values. */
@@ -975,7 +1009,7 @@ public class SqlFunctions {
/** SQL <code>&gt;</code> operator applied to boolean values. */
public static boolean gt(boolean b0, boolean b1) {
- return compare(b0, b1) > 0;
+ return Boolean.compare(b0, b1) > 0;
}
/** SQL <code>&gt;</code> operator applied to String values. */
@@ -998,6 +1032,40 @@ public class SqlFunctions {
return b0.compareTo(b1) > 0;
}
+ /** Returns whether {@code b0} is greater than {@code b1}
+ * (or {@code b1} is null). Helper for {@code ARG_MAX}. */
+ public static <T extends Comparable<T>> boolean gtNullable(T b0, T b1) {
+ return b1 == null || b0 != null && b0.compareTo(b1) > 0;
+ }
+
+ public static boolean gt(byte b0, byte b1) {
+ return b0 > b1;
+ }
+
+ public static boolean gt(char b0, char b1) {
+ return b0 > b1;
+ }
+
+ public static boolean gt(short b0, short b1) {
+ return b0 > b1;
+ }
+
+ public static boolean gt(int b0, int b1) {
+ return b0 > b1;
+ }
+
+ public static boolean gt(long b0, long b1) {
+ return b0 > b1;
+ }
+
+ public static boolean gt(float b0, float b1) {
+ return b0 > b1;
+ }
+
+ public static boolean gt(double b0, double b1) {
+ return b0 > b1;
+ }
+
/** SQL <code>&gt;</code> operator applied to Object values (at least one
* operand has ANY type; neither may be null). */
public static boolean gtAny(Object b0, Object b1) {
@@ -1016,7 +1084,7 @@ public class SqlFunctions {
/** SQL <code>≥</code> operator applied to boolean values. */
public static boolean ge(boolean b0, boolean b1) {
- return compare(b0, b1) >= 0;
+ return Boolean.compare(b0, b1) >= 0;
}
/** SQL <code>≥</code> operator applied to String values. */
@@ -1971,11 +2039,6 @@ public class SqlFunctions {
return b0 == null || b1 != null && b0.compareTo(b1) < 0 ? b1 : b0;
}
- /** Boolean comparison. */
- public static int compare(boolean x, boolean y) {
- return x == y ? 0 : x ? 1 : -1;
- }
-
/** CAST(FLOAT AS VARCHAR). */
public static String toString(float x) {
if (x == 0) {
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index 6e310c7793..269b80e2d7 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -888,6 +888,12 @@ public enum SqlKind {
/** The {@code MODE} aggregate function. */
MODE,
+ /** The {@code ARG_MAX} aggregate function. */
+ ARG_MAX,
+
+ /** The {@code ARG_MIN} aggregate function. */
+ ARG_MIN,
+
/** The {@code PERCENTILE_CONT} aggregate function. */
PERCENTILE_CONT,
diff --git
a/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
b/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
index 54b07254fa..95623dfc56 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
@@ -128,6 +128,15 @@ public final class SqlBasicAggFunction extends
SqlAggFunction {
return requireNonNull(super.getOperandTypeChecker(), "operandTypeChecker");
}
+ /** Sets {@link #getName()}. */
+ public SqlAggFunction withName(String name) {
+ return new SqlBasicAggFunction(name, getSqlIdentifier(), kind,
+ getReturnTypeInference(), getOperandTypeInference(),
+ getOperandTypeChecker(), getFunctionType(), requiresOrder(),
+ requiresOver(), requiresGroupOrder(), distinctOptionality, syntax,
+ allowsNullTreatment, allowsSeparator, percentile);
+ }
+
/** Sets {@link #getDistinctOptionality()}. */
SqlBasicAggFunction withDistinct(Optionality distinctOptionality) {
return new SqlBasicAggFunction(getName(), getSqlIdentifier(), kind,
diff --git
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 6d36db0b3e..3bbe6fdd1e 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -470,6 +470,18 @@ public abstract class SqlLibraryOperators {
.withAllowsSeparator(true)
.withSyntax(SqlSyntax.ORDERED_FUNCTION);
+ /** The "MAX_BY(value, comp)" aggregate function, Spark's
+ * equivalent to {@link SqlStdOperatorTable#ARG_MAX}. */
+ @LibraryOperator(libraries = {SPARK})
+ public static final SqlAggFunction MAX_BY =
+ SqlStdOperatorTable.ARG_MAX.withName("MAX_BY");
+
+ /** The "MIN_BY(value, comp)" aggregate function, Spark's
+ * equivalent to {@link SqlStdOperatorTable#ARG_MIN}. */
+ @LibraryOperator(libraries = {SPARK})
+ public static final SqlAggFunction MIN_BY =
+ SqlStdOperatorTable.ARG_MIN.withName("MIN_BY");
+
/** The "DATE(string)" function, equivalent to "CAST(string AS DATE). */
@LibraryOperator(libraries = {BIG_QUERY})
public static final SqlFunction DATE =
diff --git
a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index e5604775f7..92e38fd03a 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -1005,6 +1005,24 @@ public class SqlStdOperatorTable extends
ReflectiveSqlOperatorTable {
public static final SqlAggFunction APPROX_COUNT_DISTINCT =
new SqlCountAggFunction("APPROX_COUNT_DISTINCT");
+ /**
+ * <code>ARG_MAX</code> aggregate function.
+ */
+ public static final SqlBasicAggFunction ARG_MAX =
+ SqlBasicAggFunction.create("ARG_MAX", SqlKind.ARG_MAX,
+ ReturnTypes.ARG0_NULLABLE_IF_EMPTY, OperandTypes.ANY_COMPARABLE)
+ .withGroupOrder(Optionality.FORBIDDEN)
+ .withFunctionType(SqlFunctionCategory.SYSTEM);
+
+ /**
+ * <code>ARG_MIN</code> aggregate function.
+ */
+ public static final SqlBasicAggFunction ARG_MIN =
+ SqlBasicAggFunction.create("ARG_MIN", SqlKind.ARG_MIN,
+ ReturnTypes.ARG0_NULLABLE_IF_EMPTY, OperandTypes.ANY_COMPARABLE)
+ .withGroupOrder(Optionality.FORBIDDEN)
+ .withFunctionType(SqlFunctionCategory.SYSTEM);
+
/**
* <code>MIN</code> aggregate function.
*/
diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
index 82f813c10d..8f42fbd9eb 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
@@ -730,6 +730,29 @@ public abstract class OperandTypes {
public static final SqlSingleOperandTypeChecker ANY_STRING_STRING =
family(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.STRING);
+ /**
+ * Operand type-checking strategy used by {@code ARG_MIN(value, comp)} and
+ * similar functions, where the first operand can have any type and the second
+ * must be comparable.
+ */
+ public static final SqlOperandTypeChecker ANY_COMPARABLE =
+ new SqlOperandTypeChecker() {
+ @Override public boolean checkOperandTypes(SqlCallBinding callBinding,
+ boolean throwOnFailure) {
+ getOperandCountRange().isValidCount(callBinding.getOperandCount());
+ RelDataType type = callBinding.getOperandType(1);
+ return type.getComparability() == RelDataTypeComparability.ALL;
+ }
+
+ @Override public SqlOperandCountRange getOperandCountRange() {
+ return SqlOperandCountRanges.of(2);
+ }
+
+ @Override public String getAllowedSignatures(SqlOperator op, String
opName) {
+ return opName + "(<ANY>, <COMPARABLE_TYPE>)";
+ }
+ };
+
public static final SqlSingleOperandTypeChecker CURSOR =
family(SqlTypeFamily.CURSOR);
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 22a36ebb9a..52ff2bf1ec 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -455,6 +455,12 @@ public enum BuiltInMethod {
NOT(SqlFunctions.class, "not", Boolean.class),
LESSER(SqlFunctions.class, "lesser", Comparable.class, Comparable.class),
GREATER(SqlFunctions.class, "greater", Comparable.class, Comparable.class),
+ LT_NULLABLE(SqlFunctions.class, "ltNullable", Comparable.class,
+ Comparable.class),
+ GT_NULLABLE(SqlFunctions.class, "gtNullable", Comparable.class,
+ Comparable.class),
+ LT(SqlFunctions.class, "lt", boolean.class, boolean.class),
+ GT(SqlFunctions.class, "gt", boolean.class, boolean.class),
BIT_AND(SqlFunctions.class, "bitAnd", long.class, long.class),
BIT_OR(SqlFunctions.class, "bitOr", long.class, long.class),
BIT_XOR(SqlFunctions.class, "bitXor", long.class, long.class),
diff --git
a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
index 63885e7a10..f7460a6611 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
@@ -4277,6 +4277,32 @@ class SqlToRelConverterTest extends SqlToRelTestBase {
sql(sql).ok();
}
+ @Test void testArgMinFunction() {
+ final String sql = "select arg_min(ename, deptno)\n"
+ + "from emp";
+ sql(sql).withTrim(true).ok();
+ }
+
+ @Test void testArgMinFunctionWithWinAgg() {
+ final String sql = "select job,\n"
+ + " arg_min(ename, deptno) over (partition by job order by sal)\n"
+ + "from emp";
+ sql(sql).withTrim(true).ok();
+ }
+
+ @Test void testArgMaxFunction() {
+ final String sql = "select arg_max(ename, deptno)\n"
+ + "from emp";
+ sql(sql).withTrim(true).ok();
+ }
+
+ @Test void testArgMaxFunctionWithWinAgg() {
+ final String sql = "select job,\n"
+ + " arg_max(ename, deptno) over (partition by job order by sal)\n"
+ + "from emp";
+ sql(sql).withTrim(true).ok();
+ }
+
@Test void testModeFunction() {
final String sql = "select mode(deptno)\n"
+ "from emp";
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index 26d5b3cfe1..17cc9306af 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -7101,6 +7101,16 @@ public class SqlValidatorTest extends
SqlValidatorTestCase {
sql("SELECT MAX(5) FROM emp").ok();
}
+ @Test void testArgMinMaxFunctions() {
+ sql("SELECT ARG_MIN(1, true) from emp").ok();
+ sql("SELECT ARG_MAX(2, false) from emp").ok();
+
+ sql("SELECT ARG_MIN(sal, deptno) FROM emp").ok();
+ sql("SELECT ARG_MAX(deptno, sal) FROM emp").ok();
+ sql("SELECT ARG_MIN('a', 5.5) FROM emp").ok();
+ sql("SELECT ARG_MAX('b', 5) FROM emp").ok();
+ }
+
@Test void testModeFunction() {
sql("select MODE(sal) from emp").ok();
sql("select MODE(sal) over (order by empno) from emp").ok();
diff --git
a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index eb23e84764..b69a9a400c 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -363,6 +363,58 @@ LogicalProject(ANYEMPNO=[$1])
LogicalAggregate(group=[{}], ANYEMPNO=[ANY_VALUE($0)])
LogicalProject(EMPNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testArgMaxFunction">
+ <Resource name="sql">
+ <![CDATA[select arg_max(ename, deptno)
+from emp]]>
+ </Resource>
+ <Resource name="plan">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[ARG_MAX($0, $1)])
+ LogicalProject(ENAME=[$1], DEPTNO=[$7])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testArgMaxFunctionWithWinAgg">
+ <Resource name="sql">
+ <![CDATA[select job,
+ arg_max(ename, deptno) over (partition by job order by sal)
+from emp]]>
+ </Resource>
+ <Resource name="plan">
+ <![CDATA[
+LogicalProject(JOB=[$2], EXPR$1=[ARG_MAX($1, $7) OVER (PARTITION BY $2 ORDER
BY $5)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testArgMinFunction">
+ <Resource name="sql">
+ <![CDATA[select arg_min(ename, deptno)
+from emp]]>
+ </Resource>
+ <Resource name="plan">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[ARG_MIN($0, $1)])
+ LogicalProject(ENAME=[$1], DEPTNO=[$7])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testArgMinFunctionWithWinAgg">
+ <Resource name="sql">
+ <![CDATA[select job,
+ arg_min(ename, deptno) over (partition by job order by sal)
+from emp]]>
+ </Resource>
+ <Resource name="plan">
+ <![CDATA[
+LogicalProject(JOB=[$2], EXPR$1=[ARG_MIN($1, $7) OVER (PARTITION BY $2 ORDER
BY $5)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
diff --git a/core/src/test/resources/sql/agg.iq
b/core/src/test/resources/sql/agg.iq
index c90c3b11c0..150025bb89 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -3519,5 +3519,87 @@ where
+--------+------+--------+------+
(2 rows)
+!ok
+
+# [CALCITE-5283] Add ARG_MIN, ARG_MAX aggregate functions
+
+# ARG_MIN, ARG_MAX without GROUP BY
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma
+from emp;
++-------+-------+
+| MI | MA |
++-------+-------+
+| CLARK | ALLEN |
++-------+-------+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX with DISTINCT
+select arg_min(distinct ename, deptno) as mi, arg_max(distinct ename, deptno)
as ma
+from emp;
++-------+-------+
+| MI | MA |
++-------+-------+
+| CLARK | ALLEN |
++-------+-------+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX function with WHERE.
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma
+from emp
+where deptno <= 20;
++-------+-------+
+| MI | MA |
++-------+-------+
+| CLARK | SMITH |
++-------+-------+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX function with WHERE that removes all rows.
+# Result is NULL even though ARG_MIN, ARG_MAX is applied to a not-NULL column.
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma
+from emp
+where deptno > 60;
++----+----+
+| MI | MA |
++----+----+
+| | |
++----+----+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX function with GROUP BY. Note that key is NULL but result is not NULL.
+select deptno, arg_min(ename, ename) as mi, arg_max(ename, ename) as ma
+from emp
+group by deptno;
++--------+-------+--------+
+| DEPTNO | MI | MA |
++--------+-------+--------+
+| 10 | CLARK | MILLER |
+| 20 | ADAMS | SMITH |
+| 30 | ALLEN | WARD |
++--------+-------+--------+
+(3 rows)
+
+!ok
+
+# ARG_MIN, ARG_MAX applied to an integer.
+select arg_min(deptno, empno) as mi,
+ arg_max(deptno, empno) as ma,
+ arg_max(deptno, empno) filter (where job = 'MANAGER') as mamgr
+from emp;
++----+----+-------+
+| MI | MA | MAMGR |
++----+----+-------+
+| 20 | 10 | 10 |
++----+----+-------+
+(1 row)
+
!ok
# End agg.iq
diff --git a/core/src/test/resources/sql/winagg.iq
b/core/src/test/resources/sql/winagg.iq
index ce4264c053..3187bf158e 100644
--- a/core/src/test/resources/sql/winagg.iq
+++ b/core/src/test/resources/sql/winagg.iq
@@ -740,5 +740,29 @@ from emp;
+--------+-------+---+
(9 rows)
+!ok
+
+# [CALCITE-5283] Add ARG_MIN, ARG_MAX aggregate functions
+
+# ARG_MIN, ARG_MAX as window functions, with PARTITION BY and ORDER BY.
+select gender,
+ arg_min(ename, deptno) over (partition by gender order by ename) as mi,
+ arg_max(ename, deptno) over (partition by gender order by ename) as ma
+from emp;
++--------+-------+-------+
+| GENDER | MI | MA |
++--------+-------+-------+
+| F | Alice | Alice |
+| F | Alice | Eve |
+| F | Alice | Grace |
+| F | Jane | Grace |
+| F | Jane | Grace |
+| F | Jane | Grace |
+| M | Adam | Adam |
+| M | Bob | Adam |
+| M | Bob | Adam |
++--------+-------+-------+
+(9 rows)
+
!ok
# End winagg.iq
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index 4de83ccc24..1fb0dfcfa6 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -1862,6 +1862,8 @@ and `LISTAGG`).
| Operator syntax | Description
|:---------------------------------- |:-----------
| ANY_VALUE( [ ALL | DISTINCT ] value) | Returns one of the values of
*value* across all input values; this is NOT specified in the SQL standard
+| ARG_MAX(value, comp) | Returns *value* for the maximum value of *comp* in the group
+| ARG_MIN(value, comp) | Returns *value* for the minimum value of *comp* in the group
| APPROX_COUNT_DISTINCT(value [, value ]*) | Returns the approximate
number of distinct values of *value*; the database is allowed to use an
approximation but is not required to
| AVG( [ ALL | DISTINCT ] numeric) | Returns the average
(arithmetic mean) of *numeric* across all input values
| BIT_AND( [ ALL | DISTINCT ] value) | Returns the bitwise AND of
all non-null input values, or null if none; integer and binary types are
supported
@@ -2740,6 +2742,8 @@ Dialect-specific aggregate functions.
| m | GROUP_CONCAT( [ ALL | DISTINCT ] value [, value ]* [ ORDER BY
orderItem [, orderItem ]* ] [ SEPARATOR separator ] ) | MySQL-specific variant
of `LISTAGG`
| b | LOGICAL_AND(condition) | Synonym for `EVERY`
| b | LOGICAL_OR(condition) | Synonym for `SOME`
+| s | MAX_BY(value, comp) | Synonym for `ARG_MAX`
+| s | MIN_BY(value, comp) | Synonym for `ARG_MIN`
| b p | STRING_AGG( [ ALL | DISTINCT ] value [, separator] [ ORDER BY
orderItem [, orderItem ]* ] ) | Synonym for `LISTAGG`
Usage Examples:
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index b56c566567..995406dce3 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -9438,6 +9438,24 @@ public class SqlOperatorTest {
isSingle("02"));
}
+ @Test void testArgMin() {
+ final SqlOperatorFixture f0 = fixture().withTester(t -> TESTER);
+ final String[] xValues = {"2", "3", "4", "4", "5", "7"};
+
+ final Consumer<SqlOperatorFixture> consumer = f -> {
+ f.checkAgg("arg_min(mod(x, 3), x)", xValues, isSingle("2"));
+ f.checkAgg("arg_max(mod(x, 3), x)", xValues, isSingle("1"));
+ };
+
+ final Consumer<SqlOperatorFixture> consumer2 = f -> {
+ f.checkAgg("min_by(mod(x, 3), x)", xValues, isSingle("2"));
+ f.checkAgg("max_by(mod(x, 3), x)", xValues, isSingle("1"));
+ };
+
+ consumer.accept(f0);
+ consumer2.accept(f0.withLibrary(SqlLibrary.SPARK));
+ }
+
/**
* Tests that CAST fails when given a value just outside the valid range for
* that type. For example,