This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 20014b6c5b [CALCITE-5283] Add ARG_MIN, ARG_MAX (aka MIN_BY, MAX_BY) aggregate functions
20014b6c5b is described below

commit 20014b6c5b9b57d29248206f19d63ef50f7a5c0f
Author: zoudan <zou...@bytedance.com>
AuthorDate: Wed Nov 23 15:42:49 2022 +0800

    [CALCITE-5283] Add ARG_MIN, ARG_MAX (aka MIN_BY, MAX_BY) aggregate functions
    
    Close apache/calcite#2981
---
 .../calcite/adapter/enumerable/RexImpTable.java    | 69 ++++++++++++++++++
 .../org/apache/calcite/runtime/SqlFunctions.java   | 81 ++++++++++++++++++---
 .../main/java/org/apache/calcite/sql/SqlKind.java  |  6 ++
 .../calcite/sql/fun/SqlBasicAggFunction.java       |  9 +++
 .../calcite/sql/fun/SqlLibraryOperators.java       | 12 ++++
 .../calcite/sql/fun/SqlStdOperatorTable.java       | 18 +++++
 .../org/apache/calcite/sql/type/OperandTypes.java  | 23 ++++++
 .../org/apache/calcite/util/BuiltInMethod.java     |  6 ++
 .../apache/calcite/test/SqlToRelConverterTest.java | 26 +++++++
 .../org/apache/calcite/test/SqlValidatorTest.java  | 10 +++
 .../apache/calcite/test/SqlToRelConverterTest.xml  | 52 ++++++++++++++
 core/src/test/resources/sql/agg.iq                 | 82 ++++++++++++++++++++++
 core/src/test/resources/sql/winagg.iq              | 24 +++++++
 site/_docs/reference.md                            |  4 ++
 .../org/apache/calcite/test/SqlOperatorTest.java   | 18 +++++
 15 files changed, 431 insertions(+), 9 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 35d563e23e..e61c7731d6 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -148,7 +148,9 @@ import static org.apache.calcite.sql.fun.SqlLibraryOperators.LEFT;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAX_BY;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.MD5;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.MIN_BY;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.MONTHNAME;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.POW;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_REPLACE;
@@ -181,6 +183,8 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ABS;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ACOS;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.AND;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ANY_VALUE;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARG_MAX;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARG_MIN;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASCII;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ASIN;
@@ -734,6 +738,10 @@ public class RexImpTable {
           constructorSupplier(MinMaxImplementor.class);
       aggMap.put(MIN, minMax);
       aggMap.put(MAX, minMax);
+      aggMap.put(ARG_MIN, constructorSupplier(ArgMinMaxImplementor.class));
+      aggMap.put(ARG_MAX, constructorSupplier(ArgMinMaxImplementor.class));
+      aggMap.put(MIN_BY, constructorSupplier(ArgMinMaxImplementor.class));
+      aggMap.put(MAX_BY, constructorSupplier(ArgMinMaxImplementor.class));
       aggMap.put(ANY_VALUE, minMax);
       aggMap.put(SOME, minMax);
       aggMap.put(EVERY, minMax);
@@ -1204,6 +1212,94 @@ public class RexImpTable {
     }
   }
 
+  /** Implementor for the {@code ARG_MIN} and {@code ARG_MAX} aggregate
+   * functions, and for their Spark synonyms {@code MIN_BY} and
+   * {@code MAX_BY}, which share the same {@link SqlKind}.
+   *
+   * <p>The accumulator holds the value of the best row seen so far, that
+   * row's comparison value, and a flag that records whether any row has
+   * been accumulated yet. For primitive types the comparison value is
+   * seeded with {@code Primitive.max} (ARG_MIN) or {@code Primitive.min}
+   * (ARG_MAX); without the flag, a group whose comparison values never
+   * beat that seed (for example, all equal to it) would never update the
+   * accumulator and would wrongly return NULL. */
+  static class ArgMinMaxImplementor extends StrictAggImplementor {
+    @Override protected void implementNotNullReset(AggContext info,
+        AggResetContext reset) {
+      final List<Expression> acc = reset.accumulator();
+
+      // acc[0] = null;
+      reset.currentBlock().add(
+          Expressions.statement(
+              Expressions.assign(acc.get(0), Expressions.constant(null))));
+
+      // acc[1] = isMin ? {max value} : {min value};
+      // This is only a type-correct placeholder; it is not read while
+      // acc[2] is false.
+      final Type compType = info.parameterTypes().get(1);
+      final Primitive p = Primitive.of(compType);
+      final boolean isMin = info.aggregation().kind == SqlKind.ARG_MIN;
+      final Object inf = p == null ? null : (isMin ? p.max : p.min);
+      reset.currentBlock().add(
+          Expressions.statement(
+              Expressions.assign(acc.get(1),
+                  Expressions.constant(inf, compType))));
+
+      // acc[2] = false;
+      reset.currentBlock().add(
+          Expressions.statement(
+              Expressions.assign(acc.get(2), Expressions.constant(false))));
+    }
+
+    @Override public void implementNotNullAdd(AggContext info,
+        AggAddContext add) {
+      final List<Expression> acc = add.accumulator();
+      final Expression accValue = acc.get(0);
+      final Expression accComp = acc.get(1);
+      final Expression accSeen = acc.get(2);
+      final Expression argValue = add.arguments().get(0);
+      final Expression argComp = add.arguments().get(1);
+      final Type compType = info.parameterTypes().get(1);
+      final Primitive p = Primitive.of(compType);
+      final boolean isMin = info.aggregation().kind == SqlKind.ARG_MIN;
+
+      final Method method =
+          (isMin
+              ? (p == null ? BuiltInMethod.LT_NULLABLE : BuiltInMethod.LT)
+              : (p == null ? BuiltInMethod.GT_NULLABLE : BuiltInMethod.GT))
+              .method;
+      // The overload is looked up by name against the argument types, so
+      // the lt/gt variant that matches the comparison type is called.
+      final Expression beatsCurrent =
+          Expressions.call(method.getDeclaringClass(), method.getName(),
+              argComp, accComp);
+
+      // if (!acc[2] || beatsCurrent) {
+      //   acc[0] = value; acc[1] = comp; acc[2] = true;
+      // }
+      final BlockBuilder thenBlock =
+          new BlockBuilder(true, add.currentBlock());
+      thenBlock.add(
+          Expressions.statement(Expressions.assign(accValue, argValue)));
+      thenBlock.add(
+          Expressions.statement(Expressions.assign(accComp, argComp)));
+      thenBlock.add(
+          Expressions.statement(
+              Expressions.assign(accSeen, Expressions.constant(true))));
+
+      add.currentBlock().add(
+          Expressions.ifThen(
+              Expressions.orElse(Expressions.not(accSeen), beatsCurrent),
+              thenBlock.toBlock()));
+    }
+
+    @Override public List<Type> getNotNullState(AggContext info) {
+      return ImmutableList.of(Object.class, // the result value
+          info.parameterTypes().get(1), // the compare value
+          boolean.class); // whether any row has been accumulated
+    }
+  }
+
   /** Implementor for the {@code SINGLE_VALUE} aggregate function. */
   static class SingleValueImplementor implements AggImplementor {
     @Override public List<Type> getStateType(AggContext info) {
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index bb617f8cf1..1f5c0f21b2 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -894,7 +894,7 @@ public class SqlFunctions {
 
   /** SQL <code>&lt;</code> operator applied to boolean values. */
   public static boolean lt(boolean b0, boolean b1) {
-    return compare(b0, b1) < 0;
+    return Boolean.compare(b0, b1) < 0;
   }
 
   /** SQL <code>&lt;</code> operator applied to String values. */
@@ -917,6 +917,49 @@ public class SqlFunctions {
     return b0.compareTo(b1) < 0;
   }
 
+  /** Returns whether {@code b0} is less than {@code b1}.
+   *
+   * <p>Returns true if {@code b1} is null; otherwise returns false if
+   * {@code b0} is null. Helper for {@code ARG_MIN}. */
+  public static <T extends Comparable<T>> boolean ltNullable(T b0, T b1) {
+    return b1 == null || b0 != null && b0.compareTo(b1) < 0;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to byte values. */
+  public static boolean lt(byte b0, byte b1) {
+    return b0 < b1;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to char values. */
+  public static boolean lt(char b0, char b1) {
+    return b0 < b1;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to short values. */
+  public static boolean lt(short b0, short b1) {
+    return b0 < b1;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to int values. */
+  public static boolean lt(int b0, int b1) {
+    return b0 < b1;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to long values. */
+  public static boolean lt(long b0, long b1) {
+    return b0 < b1;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to float values. */
+  public static boolean lt(float b0, float b1) {
+    return b0 < b1;
+  }
+
+  /** SQL <code>&lt;</code> operator applied to double values. */
+  public static boolean lt(double b0, double b1) {
+    return b0 < b1;
+  }
+
   /** SQL <code>&lt;</code> operator applied to Object values. */
   public static boolean ltAny(Object b0, Object b1) {
     if (b0.getClass().equals(b1.getClass())
@@ -934,7 +977,7 @@ public class SqlFunctions {
 
   /** SQL <code>&le;</code> operator applied to boolean values. */
   public static boolean le(boolean b0, boolean b1) {
-    return compare(b0, b1) <= 0;
+    return Boolean.compare(b0, b1) <= 0;
   }
 
   /** SQL <code>&le;</code> operator applied to String values. */
@@ -975,7 +1018,7 @@ public class SqlFunctions {
 
   /** SQL <code>&gt;</code> operator applied to boolean values. */
   public static boolean gt(boolean b0, boolean b1) {
-    return compare(b0, b1) > 0;
+    return Boolean.compare(b0, b1) > 0;
   }
 
   /** SQL <code>&gt;</code> operator applied to String values. */
@@ -998,6 +1041,49 @@ public class SqlFunctions {
     return b0.compareTo(b1) > 0;
   }
 
+  /** Returns whether {@code b0} is greater than {@code b1}.
+   *
+   * <p>Returns true if {@code b1} is null; otherwise returns false if
+   * {@code b0} is null. Helper for {@code ARG_MAX}. */
+  public static <T extends Comparable<T>> boolean gtNullable(T b0, T b1) {
+    return b1 == null || b0 != null && b0.compareTo(b1) > 0;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to byte values. */
+  public static boolean gt(byte b0, byte b1) {
+    return b0 > b1;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to char values. */
+  public static boolean gt(char b0, char b1) {
+    return b0 > b1;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to short values. */
+  public static boolean gt(short b0, short b1) {
+    return b0 > b1;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to int values. */
+  public static boolean gt(int b0, int b1) {
+    return b0 > b1;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to long values. */
+  public static boolean gt(long b0, long b1) {
+    return b0 > b1;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to float values. */
+  public static boolean gt(float b0, float b1) {
+    return b0 > b1;
+  }
+
+  /** SQL <code>&gt;</code> operator applied to double values. */
+  public static boolean gt(double b0, double b1) {
+    return b0 > b1;
+  }
+
   /** SQL <code>&gt;</code> operator applied to Object values (at least one
    * operand has ANY type; neither may be null). */
   public static boolean gtAny(Object b0, Object b1) {
@@ -1016,7 +1102,7 @@ public class SqlFunctions {
 
   /** SQL <code>&ge;</code> operator applied to boolean values. */
   public static boolean ge(boolean b0, boolean b1) {
-    return compare(b0, b1) >= 0;
+    return Boolean.compare(b0, b1) >= 0;
   }
 
   /** SQL <code>&ge;</code> operator applied to String values. */
@@ -1971,11 +2057,14 @@ public class SqlFunctions {
     return b0 == null || b1 != null && b0.compareTo(b1) < 0 ? b1 : b0;
   }
 
-  /** Boolean comparison. */
+  /** Boolean comparison.
+   *
+   * @deprecated Use {@link Boolean#compare(boolean, boolean)} */
+  @Deprecated // to be removed before 2.0
   public static int compare(boolean x, boolean y) {
     return x == y ? 0 : x ? 1 : -1;
   }
 
   /** CAST(FLOAT AS VARCHAR). */
   public static String toString(float x) {
     if (x == 0) {
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index 6e310c7793..269b80e2d7 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -888,6 +888,14 @@ public enum SqlKind {
   /** The {@code MODE} aggregate function. */
   MODE,
 
+  /** The {@code ARG_MAX} aggregate function; also the kind of its Spark
+   * synonym {@code MAX_BY}. */
+  ARG_MAX,
+
+  /** The {@code ARG_MIN} aggregate function; also the kind of its Spark
+   * synonym {@code MIN_BY}. */
+  ARG_MIN,
+
   /** The {@code PERCENTILE_CONT} aggregate function. */
   PERCENTILE_CONT,
 
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
index 54b07254fa..95623dfc56 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlBasicAggFunction.java
@@ -128,6 +128,16 @@ public final class SqlBasicAggFunction extends SqlAggFunction {
     return requireNonNull(super.getOperandTypeChecker(), "operandTypeChecker");
   }
 
+  /** Sets {@link #getName()}. All other properties, including
+   * {@link #getKind()}, are copied from this function. */
+  public SqlBasicAggFunction withName(String name) {
+    return new SqlBasicAggFunction(name, getSqlIdentifier(), kind,
+        getReturnTypeInference(), getOperandTypeInference(),
+        getOperandTypeChecker(), getFunctionType(), requiresOrder(),
+        requiresOver(), requiresGroupOrder(), distinctOptionality, syntax,
+        allowsNullTreatment, allowsSeparator, percentile);
+  }
+
   /** Sets {@link #getDistinctOptionality()}. */
   SqlBasicAggFunction withDistinct(Optionality distinctOptionality) {
     return new SqlBasicAggFunction(getName(), getSqlIdentifier(), kind,
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 6d36db0b3e..3bbe6fdd1e 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -470,6 +470,18 @@ public abstract class SqlLibraryOperators {
           .withAllowsSeparator(true)
           .withSyntax(SqlSyntax.ORDERED_FUNCTION);
 
+  /** The "MAX_BY(value, comp)" aggregate function, Spark's
+   * equivalent to {@link SqlStdOperatorTable#ARG_MAX}. */
+  @LibraryOperator(libraries = {SPARK})
+  public static final SqlAggFunction MAX_BY =
+      SqlStdOperatorTable.ARG_MAX.withName("MAX_BY");
+
+  /** The "MIN_BY(value, comp)" aggregate function, Spark's
+   * equivalent to {@link SqlStdOperatorTable#ARG_MIN}. */
+  @LibraryOperator(libraries = {SPARK})
+  public static final SqlAggFunction MIN_BY =
+      SqlStdOperatorTable.ARG_MIN.withName("MIN_BY");
+
   /** The "DATE(string)" function, equivalent to "CAST(string AS DATE). */
   @LibraryOperator(libraries = {BIG_QUERY})
   public static final SqlFunction DATE =
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index e5604775f7..92e38fd03a 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -1005,6 +1005,26 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
   public static final SqlAggFunction APPROX_COUNT_DISTINCT =
       new SqlCountAggFunction("APPROX_COUNT_DISTINCT");
 
+  /**
+   * <code>ARG_MAX(value, comp)</code> aggregate function; returns
+   * {@code value} from the row that has the maximum {@code comp}.
+   */
+  public static final SqlBasicAggFunction ARG_MAX =
+      SqlBasicAggFunction.create("ARG_MAX", SqlKind.ARG_MAX,
+          ReturnTypes.ARG0_NULLABLE_IF_EMPTY, OperandTypes.ANY_COMPARABLE)
+          .withGroupOrder(Optionality.FORBIDDEN)
+          .withFunctionType(SqlFunctionCategory.SYSTEM);
+
+  /**
+   * <code>ARG_MIN(value, comp)</code> aggregate function; returns
+   * {@code value} from the row that has the minimum {@code comp}.
+   */
+  public static final SqlBasicAggFunction ARG_MIN =
+      SqlBasicAggFunction.create("ARG_MIN", SqlKind.ARG_MIN,
+          ReturnTypes.ARG0_NULLABLE_IF_EMPTY, OperandTypes.ANY_COMPARABLE)
+          .withGroupOrder(Optionality.FORBIDDEN)
+          .withFunctionType(SqlFunctionCategory.SYSTEM);
+
   /**
    * <code>MIN</code> aggregate function.
    */
diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
index 82f813c10d..8f42fbd9eb 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
@@ -730,6 +730,36 @@ public abstract class OperandTypes {
   public static final SqlSingleOperandTypeChecker ANY_STRING_STRING =
       family(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.STRING);
 
+  /**
+   * Operand type-checking strategy used by {@code ARG_MIN(value, comp)} and
+   * similar functions, where the first operand can have any type and the second
+   * must be comparable.
+   */
+  public static final SqlOperandTypeChecker ANY_COMPARABLE =
+      new SqlOperandTypeChecker() {
+        @Override public boolean checkOperandTypes(SqlCallBinding callBinding,
+            boolean throwOnFailure) {
+          // Check the count first, so that getOperandType(1) cannot fail.
+          final boolean valid =
+              getOperandCountRange().isValidCount(callBinding.getOperandCount())
+                  && callBinding.getOperandType(1).getComparability()
+                      == RelDataTypeComparability.ALL;
+          if (!valid && throwOnFailure) {
+            throw callBinding.newValidationSignatureError();
+          }
+          return valid;
+        }
+
+        @Override public SqlOperandCountRange getOperandCountRange() {
+          return SqlOperandCountRanges.of(2);
+        }
+
+        @Override public String getAllowedSignatures(SqlOperator op,
+            String opName) {
+          return opName + "(<ANY>, <COMPARABLE_TYPE>)";
+        }
+      };
+
   public static final SqlSingleOperandTypeChecker CURSOR =
       family(SqlTypeFamily.CURSOR);
 
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 22a36ebb9a..52ff2bf1ec 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -455,6 +455,15 @@ public enum BuiltInMethod {
   NOT(SqlFunctions.class, "not", Boolean.class),
   LESSER(SqlFunctions.class, "lesser", Comparable.class, Comparable.class),
   GREATER(SqlFunctions.class, "greater", Comparable.class, Comparable.class),
+  LT_NULLABLE(SqlFunctions.class, "ltNullable", Comparable.class,
+      Comparable.class),
+  GT_NULLABLE(SqlFunctions.class, "gtNullable", Comparable.class,
+      Comparable.class),
+  // LT and GT refer to the (boolean, boolean) overloads. Callers that need
+  // another primitive overload should use only the method's class and name
+  // and let the overload be resolved against the actual argument types.
+  LT(SqlFunctions.class, "lt", boolean.class, boolean.class),
+  GT(SqlFunctions.class, "gt", boolean.class, boolean.class),
   BIT_AND(SqlFunctions.class, "bitAnd", long.class, long.class),
   BIT_OR(SqlFunctions.class, "bitOr", long.class, long.class),
   BIT_XOR(SqlFunctions.class, "bitXor", long.class, long.class),
diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
index 63885e7a10..f7460a6611 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
@@ -4277,6 +4277,32 @@ class SqlToRelConverterTest extends SqlToRelTestBase {
     sql(sql).ok();
   }
 
+  @Test void testArgMinFunction() {
+    final String sql = "select arg_min(ename, deptno)\n"
+        + "from emp";
+    sql(sql).withTrim(true).ok();
+  }
+
+  @Test void testArgMinFunctionWithWinAgg() {
+    final String sql = "select job,\n"
+        + "  arg_min(ename, deptno) over (partition by job order by sal)\n"
+        + "from emp";
+    sql(sql).withTrim(true).ok();
+  }
+
+  @Test void testArgMaxFunction() {
+    final String sql = "select arg_max(ename, deptno)\n"
+        + "from emp";
+    sql(sql).withTrim(true).ok();
+  }
+
+  @Test void testArgMaxFunctionWithWinAgg() {
+    final String sql = "select job,\n"
+        + "  arg_max(ename, deptno) over (partition by job order by sal)\n"
+        + "from emp";
+    sql(sql).withTrim(true).ok();
+  }
+
   @Test void testModeFunction() {
     final String sql = "select mode(deptno)\n"
         + "from emp";
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index 26d5b3cfe1..17cc9306af 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -7101,6 +7101,16 @@ public class SqlValidatorTest extends SqlValidatorTestCase {
     sql("SELECT MAX(5) FROM emp").ok();
   }
 
+  @Test void testArgMinMaxFunctions() {
+    sql("SELECT ARG_MIN(1, true) FROM emp").ok();
+    sql("SELECT ARG_MAX(2, false) FROM emp").ok();
+
+    sql("SELECT ARG_MIN(sal, deptno) FROM emp").ok();
+    sql("SELECT ARG_MAX(deptno, sal) FROM emp").ok();
+    sql("SELECT ARG_MIN('a', 5.5) FROM emp").ok();
+    sql("SELECT ARG_MAX('b', 5) FROM emp").ok();
+  }
+
   @Test void testModeFunction() {
     sql("select MODE(sal) from emp").ok();
     sql("select MODE(sal) over (order by empno) from emp").ok();
diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index eb23e84764..b69a9a400c 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -363,6 +363,58 @@ LogicalProject(ANYEMPNO=[$1])
 LogicalAggregate(group=[{}], ANYEMPNO=[ANY_VALUE($0)])
   LogicalProject(EMPNO=[$0])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testArgMaxFunction">
+    <Resource name="sql">
+      <![CDATA[select arg_max(ename, deptno)
+from emp]]>
+    </Resource>
+    <Resource name="plan">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[ARG_MAX($0, $1)])
+  LogicalProject(ENAME=[$1], DEPTNO=[$7])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testArgMaxFunctionWithWinAgg">
+    <Resource name="sql">
+      <![CDATA[select job,
+  arg_max(ename, deptno) over (partition by job order by sal)
+from emp]]>
+    </Resource>
+    <Resource name="plan">
+      <![CDATA[
+LogicalProject(JOB=[$2], EXPR$1=[ARG_MAX($1, $7) OVER (PARTITION BY $2 ORDER BY $5)])
+  LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testArgMinFunction">
+    <Resource name="sql">
+      <![CDATA[select arg_min(ename, deptno)
+from emp]]>
+    </Resource>
+    <Resource name="plan">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[ARG_MIN($0, $1)])
+  LogicalProject(ENAME=[$1], DEPTNO=[$7])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testArgMinFunctionWithWinAgg">
+    <Resource name="sql">
+      <![CDATA[select job,
+  arg_min(ename, deptno) over (partition by job order by sal)
+from emp]]>
+    </Resource>
+    <Resource name="plan">
+      <![CDATA[
+LogicalProject(JOB=[$2], EXPR$1=[ARG_MIN($1, $7) OVER (PARTITION BY $2 ORDER BY $5)])
+  LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index c90c3b11c0..150025bb89 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -3519,5 +3519,87 @@ where
 +--------+------+--------+------+
 (2 rows)
 
+!ok
+
+# [CALCITE-5283] Add ARG_MIN, ARG_MAX aggregate functions
+
+# ARG_MIN, ARG_MAX without GROUP BY
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma
+from emp;
++-------+-------+
+| MI    | MA    |
++-------+-------+
+| CLARK | ALLEN |
++-------+-------+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX with DISTINCT
+select arg_min(distinct ename, deptno) as mi, arg_max(distinct ename, deptno) as ma
+from emp;
++-------+-------+
+| MI    | MA    |
++-------+-------+
+| CLARK | ALLEN |
++-------+-------+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX function with WHERE.
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma
+from emp
+where deptno <= 20;
++-------+-------+
+| MI    | MA    |
++-------+-------+
+| CLARK | SMITH |
++-------+-------+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX function with WHERE that removes all rows.
+# Result is NULL even though ARG_MIN, ARG_MAX are applied to a not-NULL column.
+select arg_min(ename, deptno) as mi, arg_max(ename, deptno) as ma
+from emp
+where deptno > 60;
++----+----+
+| MI | MA |
++----+----+
+|    |    |
++----+----+
+(1 row)
+
+!ok
+
+# ARG_MIN, ARG_MAX function with GROUP BY; one result row per group.
+select deptno, arg_min(ename, ename) as mi, arg_max(ename, ename) as ma
+from emp
+group by deptno;
++--------+-------+--------+
+| DEPTNO | MI    | MA     |
++--------+-------+--------+
+|     10 | CLARK | MILLER |
+|     20 | ADAMS | SMITH  |
+|     30 | ALLEN | WARD   |
++--------+-------+--------+
+(3 rows)
+
+!ok
+
+# ARG_MIN, ARG_MAX applied to an integer.
+select arg_min(deptno, empno) as mi,
+  arg_max(deptno, empno) as ma,
+  arg_max(deptno, empno) filter (where job = 'MANAGER') as mamgr
+from emp;
++----+----+-------+
+| MI | MA | MAMGR |
++----+----+-------+
+| 20 | 10 |    10 |
++----+----+-------+
+(1 row)
+
 !ok
 # End agg.iq
diff --git a/core/src/test/resources/sql/winagg.iq b/core/src/test/resources/sql/winagg.iq
index ce4264c053..3187bf158e 100644
--- a/core/src/test/resources/sql/winagg.iq
+++ b/core/src/test/resources/sql/winagg.iq
@@ -740,5 +740,29 @@ from emp;
 +--------+-------+---+
 (9 rows)
 
+!ok
+
+# [CALCITE-5283] Add ARG_MIN, ARG_MAX aggregate functions
+
+# ARG_MIN, ARG_MAX as window aggregates with ORDER BY (running window).
+select gender,
+  arg_min(ename, deptno) over (partition by gender order by ename) as mi,
+  arg_max(ename, deptno) over (partition by gender order by ename) as ma
+from emp;
++--------+-------+-------+
+| GENDER | MI    | MA    |
++--------+-------+-------+
+| F      | Alice | Alice |
+| F      | Alice | Eve   |
+| F      | Alice | Grace |
+| F      | Jane  | Grace |
+| F      | Jane  | Grace |
+| F      | Jane  | Grace |
+| M      | Adam  | Adam  |
+| M      | Bob   | Adam  |
+| M      | Bob   | Adam  |
++--------+-------+-------+
+(9 rows)
+
 !ok
 # End winagg.iq
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index 4de83ccc24..1fb0dfcfa6 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -1862,6 +1862,8 @@ and `LISTAGG`).
 | Operator syntax                    | Description
 |:---------------------------------- |:-----------
 | ANY_VALUE( [ ALL &#124; DISTINCT ] value)     | Returns one of the values of *value* across all input values; this is NOT specified in the SQL standard
 | APPROX_COUNT_DISTINCT(value [, value ]*)      | Returns the approximate number of distinct values of *value*; the database is allowed to use an approximation but is not required to
+| ARG_MAX(value, comp)                          | Returns *value* for the maximum value of *comp* in the group
+| ARG_MIN(value, comp)                          | Returns *value* for the minimum value of *comp* in the group
 | AVG( [ ALL &#124; DISTINCT ] numeric)         | Returns the average (arithmetic mean) of *numeric* across all input values
 | BIT_AND( [ ALL &#124; DISTINCT ] value)       | Returns the bitwise AND of all non-null input values, or null if none; integer and binary types are supported
@@ -2740,6 +2742,8 @@ Dialect-specific aggregate functions.
 | m | GROUP_CONCAT( [ ALL &#124; DISTINCT ] value [, value ]* [ ORDER BY orderItem [, orderItem ]* ] [ SEPARATOR separator ] ) | MySQL-specific variant of `LISTAGG`
 | b | LOGICAL_AND(condition)                         | Synonym for `EVERY`
 | b | LOGICAL_OR(condition)                          | Synonym for `SOME`
+| s | MAX_BY(value, comp)                            | Synonym for `ARG_MAX`
+| s | MIN_BY(value, comp)                            | Synonym for `ARG_MIN`
 | b p | STRING_AGG( [ ALL &#124; DISTINCT ] value [, separator] [ ORDER BY orderItem [, orderItem ]* ] ) | Synonym for `LISTAGG`
 
 Usage Examples:
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index b56c566567..995406dce3 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -9438,6 +9438,24 @@ public class SqlOperatorTest {
         isSingle("02"));
   }
 
+  @Test void testArgMinMaxFunctions() {
+    final SqlOperatorFixture f0 = fixture().withTester(t -> TESTER);
+    final String[] xValues = {"2", "3", "4", "4", "5", "7"};
+
+    final Consumer<SqlOperatorFixture> consumer = f -> {
+      f.checkAgg("arg_min(mod(x, 3), x)", xValues, isSingle("2"));
+      f.checkAgg("arg_max(mod(x, 3), x)", xValues, isSingle("1"));
+    };
+
+    final Consumer<SqlOperatorFixture> consumer2 = f -> {
+      f.checkAgg("min_by(mod(x, 3), x)", xValues, isSingle("2"));
+      f.checkAgg("max_by(mod(x, 3), x)", xValues, isSingle("1"));
+    };
+
+    consumer.accept(f0);
+    consumer2.accept(f0.withLibrary(SqlLibrary.SPARK));
+  }
+
   /**
    * Tests that CAST fails when given a value just outside the valid range for
    * that type. For example,

Reply via email to