This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f98b250bf5 [Multi-stage] Support is_enable_group_trim agg option 
(#14664)
f98b250bf5 is described below

commit f98b250bf5b15cac6a6b48b878263fb999517b73
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Mon Dec 30 23:02:30 2024 -0800

    [Multi-stage] Support is_enable_group_trim agg option (#14664)
---
 pinot-common/src/main/proto/plan.proto             |   2 +
 .../pinot/calcite/rel/hint/PinotHintOptions.java   |   3 +-
 .../calcite/rel/logical/PinotLogicalAggregate.java |  55 +++---
 .../PinotAggregateExchangeNodeInsertRule.java      | 216 ++++++++++++++++-----
 .../calcite/rel/rules/PinotQueryRuleSets.java      |   4 +-
 .../query/parser/CalciteRexExpressionParser.java   |   4 +-
 .../query/planner/explain/PlanNodeMerger.java      |   6 +
 .../planner/logical/EquivalentStagesFinder.java    |   4 +-
 .../planner/logical/RelToPlanNodeConverter.java    |   2 +-
 .../query/planner/plannode/AggregateNode.java      |  30 ++-
 .../query/planner/serde/PlanNodeDeserializer.java  |   3 +-
 .../query/planner/serde/PlanNodeSerializer.java    |   2 +
 .../src/test/resources/queries/GroupByPlans.json   |  49 +++++
 .../plan/server/ServerPlanRequestVisitor.java      |  39 ++--
 .../runtime/operator/AggregateOperatorTest.java    |   2 +-
 .../runtime/operator/MultiStageAccountingTest.java |   2 +-
 .../src/test/resources/queries/QueryHints.json     |   8 +
 17 files changed, 325 insertions(+), 106 deletions(-)

diff --git a/pinot-common/src/main/proto/plan.proto 
b/pinot-common/src/main/proto/plan.proto
index 49d3573076..e3b2bbf654 100644
--- a/pinot-common/src/main/proto/plan.proto
+++ b/pinot-common/src/main/proto/plan.proto
@@ -69,6 +69,8 @@ message AggregateNode {
   repeated int32 groupKeys = 3;
   AggType aggType = 4;
   bool leafReturnFinalResult = 5;
+  repeated Collation collations = 6;
+  int32 limit = 7;
 }
 
 message FilterNode {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
index 558b2f8985..3c676edd18 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
@@ -42,7 +42,8 @@ public class PinotHintOptions {
   public static class AggregateOptions {
     public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = 
"is_partitioned_by_group_by_keys";
     public static final String IS_LEAF_RETURN_FINAL_RESULT = 
"is_leaf_return_final_result";
-    public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = 
"is_skip_leaf_stage_group_by";
+    public static final String IS_SKIP_LEAF_STAGE_GROUP_BY = 
"is_skip_leaf_stage_group_by";
+    public static final String IS_ENABLE_GROUP_TRIM = "is_enable_group_trim";
 
     public static final String NUM_GROUPS_LIMIT = "num_groups_limit";
     public static final String MAX_INITIAL_RESULT_HOLDER_CAPACITY = 
"max_initial_result_holder_capacity";
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java
index 241c44703e..f9edb412c8 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java
@@ -22,6 +22,7 @@ import java.util.List;
 import javax.annotation.Nullable;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelWriter;
 import org.apache.calcite.rel.core.Aggregate;
@@ -35,39 +36,36 @@ public class PinotLogicalAggregate extends Aggregate {
   private final AggType _aggType;
   private final boolean _leafReturnFinalResult;
 
+  // The following fields are set when group trim is enabled, and are 
extracted from the Sort on top of this Aggregate.
+  private final List<RelFieldCollation> _collations;
+  private final int _limit;
+
   public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, 
List<RelHint> hints, RelNode input,
       ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, 
List<AggregateCall> aggCalls,
-      AggType aggType, boolean leafReturnFinalResult) {
+      AggType aggType, boolean leafReturnFinalResult, @Nullable 
List<RelFieldCollation> collations, int limit) {
     super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls);
     _aggType = aggType;
     _leafReturnFinalResult = leafReturnFinalResult;
+    _collations = collations;
+    _limit = limit;
   }
 
-  public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, 
List<RelHint> hints, RelNode input,
-      ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, 
List<AggregateCall> aggCalls,
-      AggType aggType) {
-    this(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls, 
aggType, false);
-  }
-
-  public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, 
AggType aggType,
-      boolean leafReturnFinalResult) {
-    this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), 
aggRel.getInput(), aggRel.getGroupSet(),
-        aggRel.getGroupSets(), aggCalls, aggType, leafReturnFinalResult);
+  public PinotLogicalAggregate(Aggregate aggRel, RelNode input, 
ImmutableBitSet groupSet,
+      @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls, 
AggType aggType,
+      boolean leafReturnFinalResult, @Nullable List<RelFieldCollation> 
collations, int limit) {
+    this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, 
groupSet, groupSets, aggCalls, aggType,
+        leafReturnFinalResult, collations, limit);
   }
 
-  public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, 
AggType aggType) {
-    this(aggRel, aggCalls, aggType, false);
-  }
-
-  public PinotLogicalAggregate(Aggregate aggRel, RelNode input, 
List<AggregateCall> aggCalls, AggType aggType) {
-    this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, 
aggRel.getGroupSet(),
-        aggRel.getGroupSets(), aggCalls, aggType);
+  public PinotLogicalAggregate(Aggregate aggRel, RelNode input, 
List<AggregateCall> aggCalls, AggType aggType,
+      boolean leafReturnFinalResult, @Nullable List<RelFieldCollation> 
collations, int limit) {
+    this(aggRel, input, aggRel.getGroupSet(), aggRel.getGroupSets(), aggCalls, 
aggType,
+        leafReturnFinalResult, collations, limit);
   }
 
   public PinotLogicalAggregate(Aggregate aggRel, RelNode input, 
ImmutableBitSet groupSet, List<AggregateCall> aggCalls,
-      AggType aggType, boolean leafReturnFinalResult) {
-    this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, 
groupSet, null, aggCalls, aggType,
-        leafReturnFinalResult);
+      AggType aggType, boolean leafReturnFinalResult, @Nullable 
List<RelFieldCollation> collations, int limit) {
+    this(aggRel, input, groupSet, null, aggCalls, aggType, 
leafReturnFinalResult, collations, limit);
   }
 
   public AggType getAggType() {
@@ -78,11 +76,20 @@ public class PinotLogicalAggregate extends Aggregate {
     return _leafReturnFinalResult;
   }
 
+  @Nullable
+  public List<RelFieldCollation> getCollations() {
+    return _collations;
+  }
+
+  public int getLimit() {
+    return _limit;
+  }
+
   @Override
   public PinotLogicalAggregate copy(RelTraitSet traitSet, RelNode input, 
ImmutableBitSet groupSet,
       @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) 
{
     return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, 
groupSet, groupSets, aggCalls, _aggType,
-        _leafReturnFinalResult);
+        _leafReturnFinalResult, _collations, _limit);
   }
 
   @Override
@@ -90,12 +97,14 @@ public class PinotLogicalAggregate extends Aggregate {
     RelWriter relWriter = super.explainTerms(pw);
     relWriter.item("aggType", _aggType);
     relWriter.itemIf("leafReturnFinalResult", true, _leafReturnFinalResult);
+    relWriter.itemIf("collations", _collations, _collations != null);
+    relWriter.itemIf("limit", _limit, _limit > 0);
     return relWriter;
   }
 
   @Override
   public RelNode withHints(List<RelHint> hintList) {
     return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, 
groupSet, groupSets, aggCalls, _aggType,
-        _leafReturnFinalResult);
+        _leafReturnFinalResult, _collations, _limit);
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index df11fdb49a..84b2a274aa 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -28,10 +28,12 @@ import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelDistributions;
+import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.Sort;
 import org.apache.calcite.rel.core.Union;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
@@ -82,49 +84,161 @@ import 
org.apache.pinot.segment.spi.AggregationFunctionType;
  * - COUNT(*) with a GROUP_BY_KEY transforms into: COUNT(*)__LEAF --> 
COUNT(*)__FINAL, where
  *   - COUNT(*)__LEAF produces TUPLE[ SUM(1), GROUP_BY_KEY ]
  *   - COUNT(*)__FINAL produces TUPLE[ SUM(COUNT(*)__LEAF), GROUP_BY_KEY ]
+ *
+ * There are 3 sub-rules:
+ * 1. {@link SortProjectAggregate}:
+ *   Matches the case when there's a Sort on top of Project on top of 
Aggregate, and enable group trim hint is present.
+ *   E.g.
+ *     SELECT /*+ aggOptions(is_enable_group_trim='true') * /
+ *     COUNT(*) AS cnt, col1 FROM myTable GROUP BY col1 ORDER BY cnt DESC 
LIMIT 10
+ *   It will extract the collations and limit from the Sort node, and set them 
into the Aggregate node. It works only
+ *   when the sort key is a direct reference to the input, i.e. no transform 
on the input columns.
+ * 2. {@link SortAggregate}:
+ *   Matches the case when there's a Sort on top of Aggregate, and enable 
group trim hint is present.
+ *   E.g.
+ *     SELECT /*+ aggOptions(is_enable_group_trim='true') * /
+ *     col1, COUNT(*) AS cnt FROM myTable GROUP BY col1 ORDER BY cnt DESC 
LIMIT 10
+ *   It will extract the collations and limit from the Sort node, and set them 
into the Aggregate node.
+ * 3. {@link WithoutSort}:
+ *   Matches Aggregate node if there is no match of {@link 
SortProjectAggregate} or {@link SortAggregate}.
+ *
+ * TODO:
+ *   1. Always enable group trim when the result is guaranteed to be accurate
+ *   2. Add intermediate stage group trim
+ *   3. Allow tuning group trim parameters with query hint
  */
-public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
-  public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
-      new 
PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);
-
-  public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) {
-    // NOTE: Explicitly match for LogicalAggregate because after applying the 
rule, LogicalAggregate is replaced with
-    //       PinotLogicalAggregate, and the rule won't be applied again.
-    super(operand(LogicalAggregate.class, any()), factory, null);
+public class PinotAggregateExchangeNodeInsertRule {
+
+  public static class SortProjectAggregate extends RelOptRule {
+    public static final SortProjectAggregate INSTANCE = new 
SortProjectAggregate(PinotRuleUtils.PINOT_REL_FACTORY);
+
+    private SortProjectAggregate(RelBuilderFactory factory) {
+      // NOTE: Explicitly match for LogicalAggregate because after applying 
the rule, LogicalAggregate is replaced with
+      //       PinotLogicalAggregate, and the rule won't be applied again.
+      super(operand(Sort.class, operand(Project.class, 
operand(LogicalAggregate.class, any()))), factory, null);
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+      LogicalAggregate aggRel = call.rel(2);
+      if (aggRel.getGroupSet().isEmpty()) {
+        return;
+      }
+      Map<String, String> hintOptions =
+          PinotHintStrategyTable.getHintOptions(aggRel.getHints(), 
PinotHintOptions.AGGREGATE_HINT_OPTIONS);
+      if (hintOptions == null || !Boolean.parseBoolean(
+          
hintOptions.get(PinotHintOptions.AggregateOptions.IS_ENABLE_GROUP_TRIM))) {
+        return;
+      }
+
+      Sort sortRel = call.rel(0);
+      Project projectRel = call.rel(1);
+      List<RexNode> projects = projectRel.getProjects();
+      List<RelFieldCollation> collations = 
sortRel.getCollation().getFieldCollations();
+      List<RelFieldCollation> newCollations = new 
ArrayList<>(collations.size());
+      for (RelFieldCollation fieldCollation : collations) {
+        RexNode project = projects.get(fieldCollation.getFieldIndex());
+        if (project instanceof RexInputRef) {
+          newCollations.add(fieldCollation.withFieldIndex(((RexInputRef) 
project).getIndex()));
+        } else {
+          // Cannot enable group trim when the sort key is not a direct 
reference to the input.
+          return;
+        }
+      }
+      int limit = 0;
+      if (sortRel.fetch != null) {
+        limit = RexLiteral.intValue(sortRel.fetch);
+      }
+      if (limit <= 0) {
+        // Cannot enable group trim when there is no limit.
+        return;
+      }
+
+      PinotLogicalAggregate newAggRel = createPlan(call, aggRel, true, 
hintOptions, newCollations, limit);
+      RelNode newProjectRel = projectRel.copy(projectRel.getTraitSet(), 
List.of(newAggRel));
+      call.transformTo(sortRel.copy(sortRel.getTraitSet(), 
List.of(newProjectRel)));
+    }
   }
 
-  /**
-   * Split the AGG into 3 plan fragments, all with the same AGG type (in some 
cases the final agg name may be different)
-   * Pinot internal plan fragment optimization can use the info of the input 
data type to infer whether it should
-   * generate the "final-stage AGG operator" or "intermediate-stage AGG 
operator" or "leaf-stage AGG operator"
-   *
-   * @param call the {@link RelOptRuleCall} on match.
-   * @see org.apache.pinot.core.query.aggregation.function.AggregationFunction
-   */
-  @Override
-  public void onMatch(RelOptRuleCall call) {
-    Aggregate aggRel = call.rel(0);
-    boolean hasGroupBy = !aggRel.getGroupSet().isEmpty();
-    RelCollation collation = extractWithInGroupCollation(aggRel);
-    Map<String, String> hintOptions =
-        PinotHintStrategyTable.getHintOptions(aggRel.getHints(), 
PinotHintOptions.AGGREGATE_HINT_OPTIONS);
-    // Collation is not supported in leaf stage aggregation.
-    if (collation != null || (hasGroupBy && hintOptions != null && 
Boolean.parseBoolean(
-        
hintOptions.get(PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION))))
 {
-      call.transformTo(createPlanWithExchangeDirectAggregation(call, 
collation));
-    } else if (hasGroupBy && hintOptions != null && Boolean.parseBoolean(
+  public static class SortAggregate extends RelOptRule {
+    public static final SortAggregate INSTANCE = new 
SortAggregate(PinotRuleUtils.PINOT_REL_FACTORY);
+
+    private SortAggregate(RelBuilderFactory factory) {
+      // NOTE: Explicitly match for LogicalAggregate because after applying 
the rule, LogicalAggregate is replaced with
+      //       PinotLogicalAggregate, and the rule won't be applied again.
+      super(operand(Sort.class, operand(LogicalAggregate.class, any())), 
factory, null);
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+      LogicalAggregate aggRel = call.rel(1);
+      if (aggRel.getGroupSet().isEmpty()) {
+        return;
+      }
+      Map<String, String> hintOptions =
+          PinotHintStrategyTable.getHintOptions(aggRel.getHints(), 
PinotHintOptions.AGGREGATE_HINT_OPTIONS);
+      if (hintOptions == null || !Boolean.parseBoolean(
+          
hintOptions.get(PinotHintOptions.AggregateOptions.IS_ENABLE_GROUP_TRIM))) {
+        return;
+      }
+
+      Sort sortRel = call.rel(0);
+      List<RelFieldCollation> collations = 
sortRel.getCollation().getFieldCollations();
+      int limit = 0;
+      if (sortRel.fetch != null) {
+        limit = RexLiteral.intValue(sortRel.fetch);
+      }
+      if (limit <= 0) {
+        // Cannot enable group trim when there is no limit.
+        return;
+      }
+
+      PinotLogicalAggregate newAggRel = createPlan(call, aggRel, true, 
hintOptions, collations, limit);
+      call.transformTo(sortRel.copy(sortRel.getTraitSet(), 
List.of(newAggRel)));
+    }
+  }
+
+  public static class WithoutSort extends RelOptRule {
+    public static final WithoutSort INSTANCE = new 
WithoutSort(PinotRuleUtils.PINOT_REL_FACTORY);
+
+    private WithoutSort(RelBuilderFactory factory) {
+      // NOTE: Explicitly match for LogicalAggregate because after applying 
the rule, LogicalAggregate is replaced with
+      //       PinotLogicalAggregate, and the rule won't be applied again.
+      super(operand(LogicalAggregate.class, any()), factory, null);
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+      Aggregate aggRel = call.rel(0);
+      Map<String, String> hintOptions =
+          PinotHintStrategyTable.getHintOptions(aggRel.getHints(), 
PinotHintOptions.AGGREGATE_HINT_OPTIONS);
+      call.transformTo(
+          createPlan(call, aggRel, !aggRel.getGroupSet().isEmpty(), 
hintOptions != null ? hintOptions : Map.of(), null,
+              0));
+    }
+  }
+
+  private static PinotLogicalAggregate createPlan(RelOptRuleCall call, 
Aggregate aggRel, boolean hasGroupBy,
+      Map<String, String> hintOptions, @Nullable List<RelFieldCollation> 
collations, int limit) {
+    // WITHIN GROUP collation is not supported in leaf stage aggregation.
+    RelCollation withinGroupCollation = extractWithinGroupCollation(aggRel);
+    if (withinGroupCollation != null || (hasGroupBy && Boolean.parseBoolean(
+        
hintOptions.get(PinotHintOptions.AggregateOptions.IS_SKIP_LEAF_STAGE_GROUP_BY))))
 {
+      return createPlanWithExchangeDirectAggregation(call, aggRel, 
withinGroupCollation, collations, limit);
+    } else if (hasGroupBy && Boolean.parseBoolean(
         
hintOptions.get(PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)))
 {
-      call.transformTo(new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, 
AggType.DIRECT, false), AggType.DIRECT));
+      return new PinotLogicalAggregate(aggRel, aggRel.getInput(), 
buildAggCalls(aggRel, AggType.DIRECT, false),
+          AggType.DIRECT, false, collations, limit);
     } else {
-      boolean leafReturnFinalResult = hintOptions != null && 
Boolean.parseBoolean(
-          
hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT));
-      call.transformTo(createPlanWithLeafExchangeFinalAggregate(call, 
leafReturnFinalResult));
+      boolean leafReturnFinalResult =
+          
Boolean.parseBoolean(hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT));
+      return createPlanWithLeafExchangeFinalAggregate(aggRel, 
leafReturnFinalResult, collations, limit);
     }
   }
 
   // TODO: Currently it only handles one WITHIN GROUP collation across all 
AggregateCalls.
   @Nullable
-  private static RelCollation extractWithInGroupCollation(Aggregate aggRel) {
+  private static RelCollation extractWithinGroupCollation(Aggregate aggRel) {
     for (AggregateCall aggCall : aggRel.getAggCallList()) {
       RelCollation collation = aggCall.getCollation();
       if (!collation.getFieldCollations().isEmpty()) {
@@ -138,55 +252,54 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
    * Use this group by optimization to skip leaf stage aggregation when 
aggregating at leaf level is not desired. In many
    * situations it is wasted effort to do group-by on the leaf, e.g. when the 
cardinality of the group-by column is very high.
    */
-  private static PinotLogicalAggregate 
createPlanWithExchangeDirectAggregation(RelOptRuleCall call,
-      @Nullable RelCollation collation) {
-    Aggregate aggRel = call.rel(0);
+  private static PinotLogicalAggregate 
createPlanWithExchangeDirectAggregation(RelOptRuleCall call, Aggregate aggRel,
+      @Nullable RelCollation withinGroupCollation, @Nullable 
List<RelFieldCollation> collations, int limit) {
     RelNode input = aggRel.getInput();
     // Create Project when there's none below the aggregate.
     if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) {
-      aggRel = (Aggregate) generateProjectUnderAggregate(call);
+      aggRel = (Aggregate) generateProjectUnderAggregate(call, aggRel);
       input = aggRel.getInput();
     }
 
     ImmutableBitSet groupSet = aggRel.getGroupSet();
     RelDistribution distribution = RelDistributions.hash(groupSet.asList());
     RelNode exchange;
-    if (collation != null) {
+    if (withinGroupCollation != null) {
       // Insert a LogicalSort node between exchange and aggregate when 
collation exists.
-      exchange = PinotLogicalSortExchange.create(input, distribution, 
collation, false, true);
+      exchange = PinotLogicalSortExchange.create(input, distribution, 
withinGroupCollation, false, true);
     } else {
       exchange = PinotLogicalExchange.create(input, distribution);
     }
 
-    return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, 
AggType.DIRECT, false), AggType.DIRECT);
+    return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, 
AggType.DIRECT, false), AggType.DIRECT,
+        false, collations, limit);
   }
 
   /**
    * Aggregate node will be split into LEAF + EXCHANGE + FINAL.
    * TODO: Add optional INTERMEDIATE stage to reduce hotspot.
    */
-  private static PinotLogicalAggregate 
createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call,
-      boolean leafReturnFinalResult) {
-    Aggregate aggRel = call.rel(0);
+  private static PinotLogicalAggregate 
createPlanWithLeafExchangeFinalAggregate(Aggregate aggRel,
+      boolean leafReturnFinalResult, @Nullable List<RelFieldCollation> 
collations, int limit) {
     // Create a LEAF aggregate.
     PinotLogicalAggregate leafAggRel =
-        new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.LEAF, 
leafReturnFinalResult), AggType.LEAF,
-            leafReturnFinalResult);
+        new PinotLogicalAggregate(aggRel, aggRel.getInput(), 
buildAggCalls(aggRel, AggType.LEAF, leafReturnFinalResult),
+            AggType.LEAF, leafReturnFinalResult, collations, limit);
     // Create an EXCHANGE node over the LEAF aggregate.
     PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel,
         RelDistributions.hash(ImmutableIntList.range(0, 
aggRel.getGroupCount())));
     // Create a FINAL aggregate over the EXCHANGE.
-    return convertAggFromIntermediateInput(call, exchange, AggType.FINAL, 
leafReturnFinalResult);
+    return convertAggFromIntermediateInput(aggRel, exchange, AggType.FINAL, 
leafReturnFinalResult, collations, limit);
   }
 
   /**
    * The following is copied from {@link 
AggregateExtractProjectRule#onMatch(RelOptRuleCall)} with modification to take
    * aggregate input as input.
    */
-  private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) {
-    final Aggregate aggregate = call.rel(0);
+  private static RelNode generateProjectUnderAggregate(RelOptRuleCall call, 
Aggregate aggregate) {
     // --------------- MODIFIED ---------------
     final RelNode input = aggregate.getInput();
+    // final Aggregate aggregate = call.rel(0);
     // final RelNode input = call.rel(1);
     // ------------- END MODIFIED -------------
 
@@ -230,9 +343,8 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
     return relBuilder.build();
   }
 
-  private static PinotLogicalAggregate 
convertAggFromIntermediateInput(RelOptRuleCall call,
-      PinotLogicalExchange exchange, AggType aggType, boolean 
leafReturnFinalResult) {
-    Aggregate aggRel = call.rel(0);
+  private static PinotLogicalAggregate 
convertAggFromIntermediateInput(Aggregate aggRel, PinotLogicalExchange exchange,
+      AggType aggType, boolean leafReturnFinalResult, @Nullable 
List<RelFieldCollation> collations, int limit) {
     RelNode input = aggRel.getInput();
     List<RexNode> projects = findImmediateProjects(input);
 
@@ -269,7 +381,7 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
     }
 
     return new PinotLogicalAggregate(aggRel, exchange, 
ImmutableBitSet.range(groupCount), aggCalls, aggType,
-        leafReturnFinalResult);
+        leafReturnFinalResult, collations, limit);
   }
 
   private static List<AggregateCall> buildAggCalls(Aggregate aggRel, AggType 
aggType, boolean leafReturnFinalResult) {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
index e831e7460a..80e524e11f 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
@@ -136,7 +136,9 @@ public class PinotQueryRuleSets {
 
       PinotSingleValueAggregateRemoveRule.INSTANCE,
       PinotJoinExchangeNodeInsertRule.INSTANCE,
-      PinotAggregateExchangeNodeInsertRule.INSTANCE,
+      PinotAggregateExchangeNodeInsertRule.SortProjectAggregate.INSTANCE,
+      PinotAggregateExchangeNodeInsertRule.SortAggregate.INSTANCE,
+      PinotAggregateExchangeNodeInsertRule.WithoutSort.INSTANCE,
       PinotWindowExchangeNodeInsertRule.INSTANCE,
       PinotSetOpExchangeNodeInsertRule.INSTANCE,
 
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index a20b2479d4..fdd19a9aef 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -29,7 +29,6 @@ import org.apache.pinot.common.request.PinotQuery;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.planner.plannode.SortNode;
 import org.apache.pinot.spi.utils.BooleanUtils;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.sql.parsers.ParserUtils;
@@ -96,8 +95,7 @@ public class CalciteRexExpressionParser {
     return expressions;
   }
 
-  public static List<Expression> convertOrderByList(SortNode node, PinotQuery 
pinotQuery) {
-    List<RelFieldCollation> collations = node.getCollations();
+  public static List<Expression> convertOrderByList(List<RelFieldCollation> 
collations, PinotQuery pinotQuery) {
     List<Expression> orderByExpressions = new ArrayList<>(collations.size());
     for (RelFieldCollation collation : collations) {
       orderByExpressions.add(convertOrderBy(collation, pinotQuery));
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java
index 611d441725..6ae02da45f 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java
@@ -147,6 +147,12 @@ class PlanNodeMerger {
       if (node.isLeafReturnFinalResult() != 
otherNode.isLeafReturnFinalResult()) {
         return null;
       }
+      if (!node.getCollations().equals(otherNode.getCollations())) {
+        return null;
+      }
+      if (node.getLimit() != otherNode.getLimit()) {
+        return null;
+      }
       List<PlanNode> children = mergeChildren(node, context);
       if (children == null) {
         return null;
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
index 55813264ff..61cf5d5be6 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
@@ -195,7 +195,9 @@ public class EquivalentStagesFinder {
             && Objects.equals(node1.getFilterArgs(), that.getFilterArgs())
             && Objects.equals(node1.getGroupKeys(), that.getGroupKeys())
             && node1.getAggType() == that.getAggType()
-            && node1.isLeafReturnFinalResult() == 
that.isLeafReturnFinalResult();
+            && node1.isLeafReturnFinalResult() == 
that.isLeafReturnFinalResult()
+            && Objects.equals(node1.getCollations(), that.getCollations())
+            && node1.getLimit() == that.getLimit();
       }
 
       @Override
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
index 3817011612..3f5ab2261e 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
@@ -264,7 +264,7 @@ public final class RelToPlanNodeConverter {
     }
     return new AggregateNode(DEFAULT_STAGE_ID, 
toDataSchema(node.getRowType()), NodeHint.fromRelHints(node.getHints()),
         convertInputs(node.getInputs()), functionCalls, filterArgs, 
node.getGroupSet().asList(), node.getAggType(),
-        node.isLeafReturnFinalResult());
+        node.isLeafReturnFinalResult(), node.getCollations(), node.getLimit());
   }
 
   private ProjectNode convertLogicalProject(LogicalProject node) {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
index be4a6d9fb8..5e6fda1e1b 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
@@ -20,6 +20,8 @@ package org.apache.pinot.query.planner.plannode;
 
 import java.util.List;
 import java.util.Objects;
+import javax.annotation.Nullable;
+import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
 
@@ -31,15 +33,22 @@ public class AggregateNode extends BasePlanNode {
   private final AggType _aggType;
   private final boolean _leafReturnFinalResult;
 
+  // The following fields are set when group trim is enabled, and are 
extracted from the Sort on top of this Aggregate.
+  // The group trim behavior at leaf stage is shared with single-stage engine.
+  private final List<RelFieldCollation> _collations;
+  private final int _limit;
+
   public AggregateNode(int stageId, DataSchema dataSchema, NodeHint nodeHint, 
List<PlanNode> inputs,
       List<RexExpression.FunctionCall> aggCalls, List<Integer> filterArgs, 
List<Integer> groupKeys, AggType aggType,
-      boolean leafReturnFinalResult) {
+      boolean leafReturnFinalResult, @Nullable List<RelFieldCollation> 
collations, int limit) {
     super(stageId, dataSchema, nodeHint, inputs);
     _aggCalls = aggCalls;
     _filterArgs = filterArgs;
     _groupKeys = groupKeys;
     _aggType = aggType;
     _leafReturnFinalResult = leafReturnFinalResult;
+    _collations = collations != null ? collations : List.of();
+    _limit = limit;
   }
 
   public List<RexExpression.FunctionCall> getAggCalls() {
@@ -62,6 +71,14 @@ public class AggregateNode extends BasePlanNode {
     return _leafReturnFinalResult;
   }
 
+  public List<RelFieldCollation> getCollations() {
+    return _collations;
+  }
+
+  public int getLimit() {
+    return _limit;
+  }
+
   @Override
   public String explain() {
     return "AGGREGATE_" + _aggType;
@@ -75,7 +92,7 @@ public class AggregateNode extends BasePlanNode {
   @Override
   public PlanNode withInputs(List<PlanNode> inputs) {
     return new AggregateNode(_stageId, _dataSchema, _nodeHint, inputs, 
_aggCalls, _filterArgs, _groupKeys, _aggType,
-        _leafReturnFinalResult);
+        _leafReturnFinalResult, _collations, _limit);
   }
 
   @Override
@@ -90,14 +107,15 @@ public class AggregateNode extends BasePlanNode {
       return false;
     }
     AggregateNode that = (AggregateNode) o;
-    return Objects.equals(_aggCalls, that._aggCalls) && 
Objects.equals(_filterArgs, that._filterArgs) && Objects.equals(
-        _groupKeys, that._groupKeys) && _aggType == that._aggType
-        && _leafReturnFinalResult == that._leafReturnFinalResult;
+    return _leafReturnFinalResult == that._leafReturnFinalResult && _limit == 
that._limit && Objects.equals(_aggCalls,
+        that._aggCalls) && Objects.equals(_filterArgs, that._filterArgs) && 
Objects.equals(_groupKeys, that._groupKeys)
+        && _aggType == that._aggType && Objects.equals(_collations, 
that._collations);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, 
_aggType, _leafReturnFinalResult);
+    return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, 
_aggType, _leafReturnFinalResult,
+        _collations, _limit);
   }
 
   /**
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java
index abd474ebce..0f68514189 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java
@@ -87,7 +87,8 @@ public class PlanNodeDeserializer {
     return new AggregateNode(protoNode.getStageId(), 
extractDataSchema(protoNode), extractNodeHint(protoNode),
         extractInputs(protoNode), 
convertFunctionCalls(protoAggregateNode.getAggCallsList()),
         protoAggregateNode.getFilterArgsList(), 
protoAggregateNode.getGroupKeysList(),
-        convertAggType(protoAggregateNode.getAggType()), 
protoAggregateNode.getLeafReturnFinalResult());
+        convertAggType(protoAggregateNode.getAggType()), 
protoAggregateNode.getLeafReturnFinalResult(),
+        convertCollations(protoAggregateNode.getCollationsList()), 
protoAggregateNode.getLimit());
   }
 
   private static FilterNode deserializeFilterNode(Plan.PlanNode protoNode) {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java
index 65ccb13b2c..e7862173e7 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java
@@ -98,6 +98,8 @@ public class PlanNodeSerializer {
           .addAllGroupKeys(node.getGroupKeys())
           .setAggType(convertAggType(node.getAggType()))
           .setLeafReturnFinalResult(node.isLeafReturnFinalResult())
+          .addAllCollations(convertCollations(node.getCollations()))
+          .setLimit(node.getLimit())
           .build();
       builder.setAggregateNode(aggregateNode);
       return null;
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json 
b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index 63a69f5e8e..8e513066d9 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -249,6 +249,55 @@
           "\n              LogicalTableScan(table=[[default, a]])",
           "\n"
         ]
+      },
+      {
+        "description": "SQL hint based group by optimization with partitioned 
aggregated values and group trim enabled",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_leaf_return_final_result='true', is_enable_group_trim='true') */ 
col1, COUNT(DISTINCT col2) AS cnt FROM a WHERE col3 >= 0 GROUP BY col1 ORDER BY 
cnt DESC LIMIT 10",
+        "output": [
+          "Execution Plan",
+          "\nLogicalSort(sort0=[$1], dir0=[DESC], offset=[0], fetch=[10])",
+          "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[1 
DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
+          "\n    LogicalSort(sort0=[$1], dir0=[DESC], fetch=[10])",
+          "\n      PinotLogicalAggregate(group=[{0}], 
agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], leafReturnFinalResult=[true], 
collations=[[1 DESC]], limit=[10])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          PinotLogicalAggregate(group=[{0}], 
agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], leafReturnFinalResult=[true], 
collations=[[1 DESC]], limit=[10])",
+          "\n            LogicalFilter(condition=[>=($2, 0)])",
+          "\n              LogicalTableScan(table=[[default, a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with group trim 
enabled without returning group key",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_enable_group_trim='true') */ COUNT(DISTINCT col2) AS cnt FROM a 
WHERE a.col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10",
+        "output": [
+          "Execution Plan",
+          "\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0], fetch=[10])",
+          "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0 
DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
+          "\n    LogicalSort(sort0=[$0], dir0=[DESC], fetch=[10])",
+          "\n      LogicalProject(cnt=[$1])",
+          "\n        PinotLogicalAggregate(group=[{0}], 
agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], collations=[[1 DESC]], limit=[10])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            PinotLogicalAggregate(group=[{0}], 
agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], collations=[[1 DESC]], limit=[10])",
+          "\n              LogicalFilter(condition=[>=($2, 0)])",
+          "\n                LogicalTableScan(table=[[default, a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based distinct optimization with group trim 
enabled",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_enable_group_trim='true') */ DISTINCT col1, col2 FROM a WHERE 
col3 >= 0 LIMIT 10",
+        "output": [
+          "Execution Plan",
+          "\nLogicalSort(offset=[0], fetch=[10])",
+          "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[]], 
isSortOnSender=[false], isSortOnReceiver=[false])",
+          "\n    LogicalSort(fetch=[10])",
+          "\n      PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL], 
collations=[[]], limit=[10])",
+          "\n        PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n          PinotLogicalAggregate(group=[{0, 1}], aggType=[LEAF], 
collations=[[]], limit=[10])",
+          "\n            LogicalFilter(condition=[>=($2, 0)])",
+          "\n              LogicalTableScan(table=[[default, a]])",
+          "\n"
+        ]
       }
     ]
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
index bd58b7f64f..1ac11809aa 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
@@ -22,6 +22,7 @@ import com.google.common.base.Preconditions;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.request.DataSource;
@@ -71,22 +72,29 @@ public class ServerPlanRequestVisitor implements 
PlanNodeVisitor<Void, ServerPla
   public Void visitAggregate(AggregateNode node, ServerPlanRequestContext 
context) {
     if (visit(node.getInputs().get(0), context)) {
       PinotQuery pinotQuery = context.getPinotQuery();
-      if (pinotQuery.getGroupByList() == null) {
-        List<Expression> groupByList = 
CalciteRexExpressionParser.convertInputRefs(node.getGroupKeys(), pinotQuery);
+      List<Expression> groupByList = 
CalciteRexExpressionParser.convertInputRefs(node.getGroupKeys(), pinotQuery);
+      if (!groupByList.isEmpty()) {
         pinotQuery.setGroupByList(groupByList);
-        pinotQuery.setSelectList(
-            CalciteRexExpressionParser.convertAggregateList(groupByList, 
node.getAggCalls(), node.getFilterArgs(),
-                pinotQuery));
-        if (node.getAggType() == AggregateNode.AggType.DIRECT) {
-          
pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT,
-              "true");
-        } else if (node.isLeafReturnFinalResult()) {
-          pinotQuery.putToQueryOptions(
-              
CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED,
 "true");
+      }
+      pinotQuery.setSelectList(
+          CalciteRexExpressionParser.convertAggregateList(groupByList, 
node.getAggCalls(), node.getFilterArgs(),
+              pinotQuery));
+      if (node.getAggType() == AggregateNode.AggType.DIRECT) {
+        
pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT,
 "true");
+      } else if (node.isLeafReturnFinalResult()) {
+        pinotQuery.putToQueryOptions(
+            
CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED,
 "true");
+      }
+      int limit = node.getLimit();
+      if (limit > 0) {
+        List<RelFieldCollation> collations = node.getCollations();
+        if (!collations.isEmpty()) {
+          
pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations,
 pinotQuery));
         }
-        // there cannot be any more modification of PinotQuery post agg, thus 
this is the last one possible.
-        context.setLeafStageBoundaryNode(node);
+        pinotQuery.setLimit(limit);
       }
+      // There cannot be any more modification of PinotQuery post agg, thus 
this is the last one possible.
+      context.setLeafStageBoundaryNode(node);
     }
     return null;
   }
@@ -193,8 +201,9 @@ public class ServerPlanRequestVisitor implements 
PlanNodeVisitor<Void, ServerPla
     if (visit(node.getInputs().get(0), context)) {
       PinotQuery pinotQuery = context.getPinotQuery();
       if (pinotQuery.getOrderByList() == null) {
-        if (!node.getCollations().isEmpty()) {
-          
pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(node, 
pinotQuery));
+        List<RelFieldCollation> collations = node.getCollations();
+        if (!collations.isEmpty()) {
+          
pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations,
 pinotQuery));
         }
         if (node.getFetch() >= 0) {
           pinotQuery.setLimit(node.getFetch());
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index f7f56e0ccb..b2e73f226a 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -273,7 +273,7 @@ public class AggregateOperatorTest {
       List<Integer> filterArgs, List<Integer> groupKeys, PlanNode.NodeHint 
nodeHint) {
     return new AggregateOperator(OperatorTestUtil.getTracingContext(), _input,
         new AggregateNode(-1, resultSchema, nodeHint, List.of(), aggCalls, 
filterArgs, groupKeys, AggType.DIRECT,
-            false));
+            false, null, 0));
   }
 
   private AggregateOperator getOperator(DataSchema resultSchema, 
List<RexExpression.FunctionCall> aggCalls,
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
index fc7ebba0b4..05ccf57621 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
@@ -152,7 +152,7 @@ public class MultiStageAccountingTest implements ITest {
         new DataSchema(new String[]{"group", "sum"}, new 
DataSchema.ColumnDataType[]{INT, DOUBLE});
     return new AggregateOperator(OperatorTestUtil.getTracingContext(), input,
         new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, 
List.of(), aggCalls, filterArgs, groupKeys,
-            AggregateNode.AggType.DIRECT, false));
+            AggregateNode.AggType.DIRECT, false, null, 0));
   }
 
   private static MultiStageOperator getHashJoinOperator() {
diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json 
b/pinot-query-runtime/src/test/resources/queries/QueryHints.json
index e7c2ca3757..e8d30ed409 100644
--- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json
+++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json
@@ -321,6 +321,14 @@
         "description": "aggregate with skip intermediate stage hint (via hint 
option is_partitioned_by_group_by_keys)",
         "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') 
*/ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num), COUNT(DISTINCT 
{tbl1}.val) FROM {tbl1} WHERE {tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY 
{tbl1}.num"
       },
+      {
+        "description": "aggregate with skip intermediate stage and enable 
group trim hint",
+        "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true', 
is_enable_group_trim='true') */ num, COUNT(*), SUM(val), SUM(num), 
COUNT(DISTINCT val) FROM {tbl1} WHERE val >= 0 AND name != 'a' GROUP BY num 
ORDER BY COUNT(*) DESC, num LIMIT 1"
+      },
+      {
+        "description": "distinct with enable group trim hint",
+        "sql": "SELECT /*+ aggOptions(is_enable_group_trim='true') */ DISTINCT 
num, val FROM {tbl1} WHERE val >= 0 AND name != 'a' ORDER BY val DESC, num 
LIMIT 1"
+      },
       {
         "description": "join with pre-partitioned left and right tables",
         "sql": "SELECT {tbl1}.num, {tbl1}.val, {tbl2}.data FROM {tbl1} /*+ 
tableOptions(partition_function='hashcode', partition_key='num', 
partition_size='4') */ JOIN {tbl2} /*+ 
tableOptions(partition_function='hashcode', partition_key='num', 
partition_size='4') */ ON {tbl1}.num = {tbl2}.num WHERE {tbl2}.data > 0"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to