This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push: new 20014b6c5b [CALCITE-5283] Add ARG_MIN, ARG_MAX (aka MIN_BY, MAX_BY) aggregate functions 20014b6c5b is described below commit 20014b6c5b9b57d29248206f19d63ef50f7a5c0f Author: zoudan <zou...@bytedance.com> AuthorDate: Wed Nov 23 15:42:49 2022 +0800 [CALCITE-5283] Add ARG_MIN, ARG_MAX (aka MIN_BY, MAX_BY) aggregate functions Close apache/calcite#2981 --- .../calcite/adapter/enumerable/RexImpTable.java | 69 ++++++++++++++++++ .../org/apache/calcite/runtime/SqlFunctions.java | 81 ++++++++++++++++++--- .../main/java/org/apache/calcite/sql/SqlKind.java | 6 ++ .../calcite/sql/fun/SqlBasicAggFunction.java | 9 +++ .../calcite/sql/fun/SqlLibraryOperators.java | 12 ++++ .../calcite/sql/fun/SqlStdOperatorTable.java | 18 +++++ .../org/apache/calcite/sql/type/OperandTypes.java | 23 ++++++ .../org/apache/calcite/util/BuiltInMethod.java | 6 ++ .../apache/calcite/test/SqlToRelConverterTest.java | 26 +++++++ .../org/apache/calcite/test/SqlValidatorTest.java | 10 +++ .../apache/calcite/test/SqlToRelConverterTest.xml | 52 ++++++++++++++ core/src/test/resources/sql/agg.iq | 82 ++++++++++++++++++++++ core/src/test/resources/sql/winagg.iq | 24 +++++++ site/_docs/reference.md | 4 ++ .../org/apache/calcite/test/SqlOperatorTest.java | 18 +++++ 15 files changed, 431 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 35d563e23e..e61c7731d6 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -148,7 +148,9 @@ import static org.apache.calcite.sql.fun.SqlLibraryOperators.LEFT; import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND; import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR; import static 
org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD; +import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAX_BY; import static org.apache.calcite.sql.fun.SqlLibraryOperators.MD5; +import static org.apache.calcite.sql.fun.SqlLibraryOperators.MIN_BY; import static org.apache.calcite.sql.fun.SqlLibraryOperators.MONTHNAME; import static org.apache.calcite.sql.fun.SqlLibraryOperators.POW; import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_REPLACE; @@ -181,6 +183,8 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ABS; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ACOS; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.AND; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ANY_VALUE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARG_MAX; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARG_MIN; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASCII; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASIN; @@ -734,6 +738,10 @@ public class RexImpTable { constructorSupplier(MinMaxImplementor.class); aggMap.put(MIN, minMax); aggMap.put(MAX, minMax); + aggMap.put(ARG_MIN, constructorSupplier(ArgMinMaxImplementor.class)); + aggMap.put(ARG_MAX, constructorSupplier(ArgMinMaxImplementor.class)); + aggMap.put(MIN_BY, constructorSupplier(ArgMinMaxImplementor.class)); + aggMap.put(MAX_BY, constructorSupplier(ArgMinMaxImplementor.class)); aggMap.put(ANY_VALUE, minMax); aggMap.put(SOME, minMax); aggMap.put(EVERY, minMax); @@ -1204,6 +1212,67 @@ public class RexImpTable { } } + /** Implementor for the {@code ARG_MIN} and {@code ARG_MAX} aggregate + * functions. 
*/ + static class ArgMinMaxImplementor extends StrictAggImplementor { + @Override protected void implementNotNullReset(AggContext info, + AggResetContext reset) { + // acc[0] = null; + reset.currentBlock().add( + Expressions.statement( + Expressions.assign(reset.accumulator().get(0), + Expressions.constant(null)))); + + final Type compType = info.parameterTypes().get(1); + final Primitive p = Primitive.of(compType); + final boolean isMin = info.aggregation().kind == SqlKind.ARG_MIN; + final Object inf = p == null ? null : (isMin ? p.max : p.min); + //acc[1] = isMin ? {max value} : {min value}; + reset.currentBlock().add( + Expressions.statement( + Expressions.assign(reset.accumulator().get(1), + Expressions.constant(inf, compType)))); + } + + @Override public void implementNotNullAdd(AggContext info, + AggAddContext add) { + Expression accComp = add.accumulator().get(1); + Expression argValue = add.arguments().get(0); + Expression argComp = add.arguments().get(1); + final Type compType = info.parameterTypes().get(1); + final Primitive p = Primitive.of(compType); + final boolean isMin = info.aggregation().kind == SqlKind.ARG_MIN; + + final Method method = + (isMin + ? (p == null ? BuiltInMethod.LT_NULLABLE : BuiltInMethod.LT) + : (p == null ? 
BuiltInMethod.GT_NULLABLE : BuiltInMethod.GT)) + .method; + Expression compareExpression = + Expressions.call(method.getDeclaringClass(), + method.getName(), + argComp, + accComp); + + final BlockBuilder thenBlock = + new BlockBuilder(true, add.currentBlock()); + thenBlock.add( + Expressions.statement( + Expressions.assign(add.accumulator().get(0), argValue))); + thenBlock.add( + Expressions.statement( + Expressions.assign(add.accumulator().get(1), argComp))); + + add.currentBlock() + .add(Expressions.ifThen(compareExpression, thenBlock.toBlock())); + } + + @Override public List<Type> getNotNullState(AggContext info) { + return ImmutableList.of(Object.class, // the result value + info.parameterTypes().get(1)); // the compare value + } + } + /** Implementor for the {@code SINGLE_VALUE} aggregate function. */ static class SingleValueImplementor implements AggImplementor { @Override public List<Type> getStateType(AggContext info) { diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java index bb617f8cf1..1f5c0f21b2 100644 --- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java +++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java @@ -894,7 +894,7 @@ public class SqlFunctions { /** SQL <code><</code> operator applied to boolean values. */ public static boolean lt(boolean b0, boolean b1) { - return compare(b0, b1) < 0; + return Boolean.compare(b0, b1) < 0; } /** SQL <code><</code> operator applied to String values. */ @@ -917,6 +917,40 @@ public class SqlFunctions { return b0.compareTo(b1) < 0; } + /** Returns whether {@code b0} is less than {@code b1} + * (or {@code b1} is null). Helper for {@code ARG_MIN}. 
*/ + public static <T extends Comparable<T>> boolean ltNullable(T b0, T b1) { + return b1 == null || b0 != null && b0.compareTo(b1) < 0; + } + + public static boolean lt(byte b0, byte b1) { + return b0 < b1; + } + + public static boolean lt(char b0, char b1) { + return b0 < b1; + } + + public static boolean lt(short b0, short b1) { + return b0 < b1; + } + + public static boolean lt(int b0, int b1) { + return b0 < b1; + } + + public static boolean lt(long b0, long b1) { + return b0 < b1; + } + + public static boolean lt(float b0, float b1) { + return b0 < b1; + } + + public static boolean lt(double b0, double b1) { + return b0 < b1; + } + /** SQL <code><</code> operator applied to Object values. */ public static boolean ltAny(Object b0, Object b1) { if (b0.getClass().equals(b1.getClass()) @@ -934,7 +968,7 @@ public class SqlFunctions { /** SQL <code>≤</code> operator applied to boolean values. */ public static boolean le(boolean b0, boolean b1) { - return compare(b0, b1) <= 0; + return Boolean.compare(b0, b1) <= 0; } /** SQL <code>≤</code> operator applied to String values. */ @@ -975,7 +1009,7 @@ public class SqlFunctions { /** SQL <code>></code> operator applied to boolean values. */ public static boolean gt(boolean b0, boolean b1) { - return compare(b0, b1) > 0; + return Boolean.compare(b0, b1) > 0; } /** SQL <code>></code> operator applied to String values. */ @@ -998,6 +1032,40 @@ public class SqlFunctions { return b0.compareTo(b1) > 0; } + /** Returns whether {@code b0} is greater than {@code b1} + * (or {@code b1} is null). Helper for {@code ARG_MAX}. 
*/ + public static <T extends Comparable<T>> boolean gtNullable(T b0, T b1) { + return b1 == null || b0 != null && b0.compareTo(b1) > 0; + } + + public static boolean gt(byte b0, byte b1) { + return b0 > b1; + } + + public static boolean gt(char b0, char b1) { + return b0 > b1; + } + + public static boolean gt(short b0, short b1) { + return b0 > b1; + } + + public static boolean gt(int b0, int b1) { + return b0 > b1; + } + + public static boolean gt(long b0, long b1) { + return b0 > b1; + } + + public static boolean gt(float b0, float b1) { + return b0 > b1; + } + + public static boolean gt(double b0, double b1) { + return b0 > b1; + } + /** SQL <code>></code> operator applied to Object values (at least one * operand has ANY type; neither may be null). */ public static boolean gtAny(Object b0, Object b1) { @@ -1016,7 +1084,7 @@ public class SqlFunctions { /** SQL <code>≥</code> operator applied to boolean values. */ public static boolean ge(boolean b0, boolean b1) { - return compare(b0, b1) >= 0; + return Boolean.compare(b0, b1) >= 0; } /** SQL <code>≥</code> operator applied to String values. */ @@ -1971,11 +2039,6 @@ public class SqlFunctions { return b0 == null || b1 != null && b0.compareTo(b1) < 0 ? b1 : b0; } - /** Boolean comparison. */ - public static int compare(boolean x, boolean y) { - return x == y ? 0 : x ? 1 : -1; - } - /** CAST(FLOAT AS VARCHAR). */ public static String toString(float x) { if (x == 0) { diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java index 6e310c7793..269b80e2d7 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java @@ -888,6 +888,12 @@ public enum SqlKind { /** The {@code MODE} aggregate function. */ MODE, + /** The {@code ARG_MAX} aggregate function. */ + ARG_MAX, + + /** The {@code ARG_MIN} aggregate function. */ + ARG_MIN, + /** The {@code PERCENTILE_CONT} aggregate function. 
*/ PERCENTILE_CONT, diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java index 54b07254fa..95623dfc56 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java @@ -128,6 +128,15 @@ public final class SqlBasicAggFunction extends SqlAggFunction { return requireNonNull(super.getOperandTypeChecker(), "operandTypeChecker"); } + /** Sets {@link #getName()}. */ + public SqlAggFunction withName(String name) { + return new SqlBasicAggFunction(name, getSqlIdentifier(), kind, + getReturnTypeInference(), getOperandTypeInference(), + getOperandTypeChecker(), getFunctionType(), requiresOrder(), + requiresOver(), requiresGroupOrder(), distinctOptionality, syntax, + allowsNullTreatment, allowsSeparator, percentile); + } + /** Sets {@link #getDistinctOptionality()}. */ SqlBasicAggFunction withDistinct(Optionality distinctOptionality) { return new SqlBasicAggFunction(getName(), getSqlIdentifier(), kind, diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 6d36db0b3e..3bbe6fdd1e 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -470,6 +470,18 @@ public abstract class SqlLibraryOperators { .withAllowsSeparator(true) .withSyntax(SqlSyntax.ORDERED_FUNCTION); + /** The "MAX_BY(value, comp)" aggregate function, Spark's + * equivalent to {@link SqlStdOperatorTable#ARG_MAX}. */ + @LibraryOperator(libraries = {SPARK}) + public static final SqlAggFunction MAX_BY = + SqlStdOperatorTable.ARG_MAX.withName("MAX_BY"); + + /** The "MIN_BY(value, comp)" aggregate function, Spark's + * equivalent to {@link SqlStdOperatorTable#ARG_MIN}.
*/ + @LibraryOperator(libraries = {SPARK}) + public static final SqlAggFunction MIN_BY = + SqlStdOperatorTable.ARG_MIN.withName("MIN_BY"); + /** The "DATE(string)" function, equivalent to "CAST(string AS DATE). */ @LibraryOperator(libraries = {BIG_QUERY}) public static final SqlFunction DATE = diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java index e5604775f7..92e38fd03a 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java @@ -1005,6 +1005,24 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable { public static final SqlAggFunction APPROX_COUNT_DISTINCT = new SqlCountAggFunction("APPROX_COUNT_DISTINCT"); + /** + * <code>ARG_MAX</code> aggregate function. + */ + public static final SqlBasicAggFunction ARG_MAX = + SqlBasicAggFunction.create("ARG_MAX", SqlKind.ARG_MAX, + ReturnTypes.ARG0_NULLABLE_IF_EMPTY, OperandTypes.ANY_COMPARABLE) + .withGroupOrder(Optionality.FORBIDDEN) + .withFunctionType(SqlFunctionCategory.SYSTEM); + + /** + * <code>ARG_MIN</code> aggregate function. + */ + public static final SqlBasicAggFunction ARG_MIN = + SqlBasicAggFunction.create("ARG_MIN", SqlKind.ARG_MIN, + ReturnTypes.ARG0_NULLABLE_IF_EMPTY, OperandTypes.ANY_COMPARABLE) + .withGroupOrder(Optionality.FORBIDDEN) + .withFunctionType(SqlFunctionCategory.SYSTEM); + /** * <code>MIN</code> aggregate function. 
*/ diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java index 82f813c10d..8f42fbd9eb 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java @@ -730,6 +730,29 @@ public abstract class OperandTypes { public static final SqlSingleOperandTypeChecker ANY_STRING_STRING = family(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.STRING); + /** + * Operand type-checking strategy used by {@code ARG_MIN(value, comp)} and + * similar functions, where the first operand can have any type and the second + * must be comparable. + */ + public static final SqlOperandTypeChecker ANY_COMPARABLE = + new SqlOperandTypeChecker() { + @Override public boolean checkOperandTypes(SqlCallBinding callBinding, + boolean throwOnFailure) { + getOperandCountRange().isValidCount(callBinding.getOperandCount()); + RelDataType type = callBinding.getOperandType(1); + return type.getComparability() == RelDataTypeComparability.ALL; + } + + @Override public SqlOperandCountRange getOperandCountRange() { + return SqlOperandCountRanges.of(2); + } + + @Override public String getAllowedSignatures(SqlOperator op, String opName) { + return opName + "(<ANY>, <COMPARABLE_TYPE>)"; + } + }; + public static final SqlSingleOperandTypeChecker CURSOR = family(SqlTypeFamily.CURSOR); diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 22a36ebb9a..52ff2bf1ec 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -455,6 +455,12 @@ public enum BuiltInMethod { NOT(SqlFunctions.class, "not", Boolean.class), LESSER(SqlFunctions.class, "lesser", Comparable.class, Comparable.class), GREATER(SqlFunctions.class, "greater", Comparable.class, Comparable.class), + 
LT_NULLABLE(SqlFunctions.class, "ltNullable", Comparable.class, + Comparable.class), + GT_NULLABLE(SqlFunctions.class, "gtNullable", Comparable.class, + Comparable.class), + LT(SqlFunctions.class, "lt", boolean.class, boolean.class), + GT(SqlFunctions.class, "gt", boolean.class, boolean.class), BIT_AND(SqlFunctions.class, "bitAnd", long.class, long.class), BIT_OR(SqlFunctions.class, "bitOr", long.class, long.class), BIT_XOR(SqlFunctions.class, "bitXor", long.class, long.class), diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java index 63885e7a10..f7460a6611 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java @@ -4277,6 +4277,32 @@ class SqlToRelConverterTest extends SqlToRelTestBase { sql(sql).ok(); } + @Test void testArgMinFunction() { + final String sql = "select arg_min(ename, deptno)\n" + + "from emp"; + sql(sql).withTrim(true).ok(); + } + + @Test void testArgMinFunctionWithWinAgg() { + final String sql = "select job,\n" + + " arg_min(ename, deptno) over (partition by job order by sal)\n" + + "from emp"; + sql(sql).withTrim(true).ok(); + } + + @Test void testArgMaxFunction() { + final String sql = "select arg_max(ename, deptno)\n" + + "from emp"; + sql(sql).withTrim(true).ok(); + } + + @Test void testArgMaxFunctionWithWinAgg() { + final String sql = "select job,\n" + + " arg_max(ename, deptno) over (partition by job order by sal)\n" + + "from emp"; + sql(sql).withTrim(true).ok(); + } + @Test void testModeFunction() { final String sql = "select mode(deptno)\n" + "from emp"; diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java index 26d5b3cfe1..17cc9306af 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java +++ 
b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java @@ -7101,6 +7101,16 @@ public class SqlValidatorTest extends SqlValidatorTestCase { sql("SELECT MAX(5) FROM emp").ok(); } + @Test void testArgMinMaxFunctions() { + sql("SELECT ARG_MIN(1, true) from emp").ok(); + sql("SELECT ARG_MAX(2, false) from emp").ok(); + + sql("SELECT ARG_MIN(sal, deptno) FROM emp").ok(); + sql("SELECT ARG_MAX(deptno, sal) FROM emp").ok(); + sql("SELECT ARG_MIN('a', 5.5) FROM emp").ok(); + sql("SELECT ARG_MAX('b', 5) FROM emp").ok(); + } + @Test void testModeFunction() { sql("select MODE(sal) from emp").ok(); sql("select MODE(sal) over (order by empno) from emp").ok(); diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index eb23e84764..b69a9a400c 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -363,6 +363,58 @@ LogicalProject(ANYEMPNO=[$1]) LogicalAggregate(group=[{}], ANYEMPNO=[ANY_VALUE($0)]) LogicalProject(EMPNO=[$0]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testArgMaxFunction"> + <Resource name="sql"> + <![CDATA[select arg_max(ename, deptno) +from emp]]> + </Resource> + <Resource name="plan"> + <![CDATA[ +LogicalAggregate(group=[{}], EXPR$0=[ARG_MAX($0, $1)]) + LogicalProject(ENAME=[$1], DEPTNO=[$7]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testArgMaxFunctionWithWinAgg"> + <Resource name="sql"> + <![CDATA[select job, + arg_max(ename, deptno) over (partition by job order by sal) +from emp]]> + </Resource> + <Resource name="plan"> + <![CDATA[ +LogicalProject(JOB=[$2], EXPR$1=[ARG_MAX($1, $7) OVER (PARTITION BY $2 ORDER BY $5)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase 
name="testArgMinFunction"> + <Resource name="sql"> + <![CDATA[select arg_min(ename, deptno) +from emp]]> + </Resource> + <Resource name="plan"> + <![CDATA[ +LogicalAggregate(group=[{}], EXPR$0=[ARG_MIN($0, $1)]) + LogicalProject(ENAME=[$1], DEPTNO=[$7]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testArgMinFunctionWithWinAgg"> + <Resource name="sql"> + <![CDATA[select job, + arg_min(ename, deptno) over (partition by job order by sal) +from emp]]> + </Resource> + <Resource name="plan"> + <![CDATA[ +LogicalProject(JOB=[$2], EXPR$1=[ARG_MIN($1, $7) OVER (PARTITION BY $2 ORDER BY $5)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource> </TestCase> diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq index c90c3b11c0..150025bb89 100644 --- a/core/src/test/resources/sql/agg.iq +++ b/core/src/test/resources/sql/agg.iq @@ -3519,5 +3519,87 @@ where +--------+------+--------+------+ (2 rows) +!ok + +# [CALCITE-5283] Add ARG_MIN, ARG_MAX aggregate function + +# ARG_MIN, ARG_MAX without GROUP BY +select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma +from emp; ++-------+-------+ +| MI | MA | ++-------+-------+ +| CLARK | ALLEN | ++-------+-------+ +(1 row) + +!ok + +# ARG_MIN, ARG_MAX with DISTINCT +select arg_min(distinct ename, deptno) as mi, arg_max(distinct ename, deptno) as ma +from emp; ++-------+-------+ +| MI | MA | ++-------+-------+ +| CLARK | ALLEN | ++-------+-------+ +(1 row) + +!ok + +# ARG_MIN, ARG_MAX function with WHERE. +select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma +from emp +where deptno <= 20; ++-------+-------+ +| MI | MA | ++-------+-------+ +| CLARK | SMITH | ++-------+-------+ +(1 row) + +!ok + +# ARG_MIN, ARG_MAX function with WHERE that removes all rows. +# Result is NULL even though ARG_MIN, ARG_MAX is applied to a not-NULL column. 
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma +from emp +where deptno > 60; ++----+----+ +| MI | MA | ++----+----+ +| | | ++----+----+ +(1 row) + +!ok + +# ARG_MIN, ARG_MAX function with GROUP BY. Here the comparison key is the value itself (ename). +select deptno, arg_min(ename, ename) as mi, arg_max(ename, ename) as ma +from emp +group by deptno; ++--------+-------+--------+ +| DEPTNO | MI | MA | ++--------+-------+--------+ +| 10 | CLARK | MILLER | +| 20 | ADAMS | SMITH | +| 30 | ALLEN | WARD | ++--------+-------+--------+ +(3 rows) + +!ok + +# ARG_MIN, ARG_MAX applied to an integer. +select arg_min(deptno, empno) as mi, + arg_max(deptno, empno) as ma, + arg_max(deptno, empno) filter (where job = 'MANAGER') as mamgr +from emp; ++----+----+-------+ +| MI | MA | MAMGR | ++----+----+-------+ +| 20 | 10 | 10 | ++----+----+-------+ +(1 row) + !ok # End agg.iq diff --git a/core/src/test/resources/sql/winagg.iq b/core/src/test/resources/sql/winagg.iq index ce4264c053..3187bf158e 100644 --- a/core/src/test/resources/sql/winagg.iq +++ b/core/src/test/resources/sql/winagg.iq @@ -740,5 +740,29 @@ from emp; +--------+-------+---+ (9 rows) +!ok + +# [CALCITE-5283] Add ARG_MIN, ARG_MAX aggregate function + +# ARG_MIN, ARG_MAX window functions with PARTITION BY and ORDER BY. +select gender, + arg_min(ename, deptno) over (partition by gender order by ename) as mi, + arg_max(ename, deptno) over (partition by gender order by ename) as ma +from emp; ++--------+-------+-------+ +| GENDER | MI | MA | ++--------+-------+-------+ +| F | Alice | Alice | +| F | Alice | Eve | +| F | Alice | Grace | +| F | Jane | Grace | +| F | Jane | Grace | +| F | Jane | Grace | +| M | Adam | Adam | +| M | Bob | Adam | +| M | Bob | Adam | ++--------+-------+-------+ +(9 rows) + !ok # End winagg.iq diff --git a/site/_docs/reference.md b/site/_docs/reference.md index 4de83ccc24..1fb0dfcfa6 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -1862,6 +1862,8 @@ and `LISTAGG`).
| Operator syntax | Description |:---------------------------------- |:----------- | ANY_VALUE( [ ALL | DISTINCT ] value) | Returns one of the values of *value* across all input values; this is NOT specified in the SQL standard +| ARG_MAX(value, comp) | Returns *value* for the maximum value of *comp* in the group +| ARG_MIN(value, comp) | Returns *value* for the minimum value of *comp* in the group | APPROX_COUNT_DISTINCT(value [, value ]*) | Returns the approximate number of distinct values of *value*; the database is allowed to use an approximation but is not required to | AVG( [ ALL | DISTINCT ] numeric) | Returns the average (arithmetic mean) of *numeric* across all input values | BIT_AND( [ ALL | DISTINCT ] value) | Returns the bitwise AND of all non-null input values, or null if none; integer and binary types are supported @@ -2740,6 +2742,8 @@ Dialect-specific aggregate functions. | m | GROUP_CONCAT( [ ALL | DISTINCT ] value [, value ]* [ ORDER BY orderItem [, orderItem ]* ] [ SEPARATOR separator ] ) | MySQL-specific variant of `LISTAGG` | b | LOGICAL_AND(condition) | Synonym for `EVERY` | b | LOGICAL_OR(condition) | Synonym for `SOME` +| s | MAX_BY(value, comp) | Synonym for `ARG_MAX` +| s | MIN_BY(value, comp) | Synonym for `ARG_MIN` | b p | STRING_AGG( [ ALL | DISTINCT ] value [, separator] [ ORDER BY orderItem [, orderItem ]* ] ) | Synonym for `LISTAGG` Usage Examples: diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index b56c566567..995406dce3 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -9438,6 +9438,24 @@ public class SqlOperatorTest { isSingle("02")); } + @Test void testArgMin() { + final SqlOperatorFixture f0 = fixture().withTester(t -> TESTER); + final String[] xValues = {"2", "3", "4", "4", "5", "7"}; + + final Consumer<SqlOperatorFixture> 
consumer = f -> { + f.checkAgg("arg_min(mod(x, 3), x)", xValues, isSingle("2")); + f.checkAgg("arg_max(mod(x, 3), x)", xValues, isSingle("1")); + }; + + final Consumer<SqlOperatorFixture> consumer2 = f -> { + f.checkAgg("min_by(mod(x, 3), x)", xValues, isSingle("2")); + f.checkAgg("max_by(mod(x, 3), x)", xValues, isSingle("1")); + }; + + consumer.accept(f0); + consumer2.accept(f0.withLibrary(SqlLibrary.SPARK)); + } + /** * Tests that CAST fails when given a value just outside the valid range for * that type. For example,