This is an automated email from the ASF dual-hosted git repository.

siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 69d2fae2b4 [multi-stage] Support query plan for GROUP BY only in the 
intermediary stage (#10248)
69d2fae2b4 is described below

commit 69d2fae2b4c8388a9ec7a559deb07ea2e62d8f26
Author: Vivek Iyer Vaidyanathan <[email protected]>
AuthorDate: Mon Feb 27 02:47:43 2023 -0800

    [multi-stage] Support query plan for GROUP BY only in the intermediary 
stage (#10248)
    
    * Group By Optimization: Skip leaf stage aggregation
    
    * Fix checkstyle violations and add runtime tests
---
 .../BrokerRequestHandlerDelegate.java              |  18 +--
 .../calcite/rel/hint/PinotHintStrategyTable.java   |  17 +++
 .../PinotAggregateExchangeNodeInsertRule.java      | 160 ++++++++++++++++++---
 .../pinot/query/QueryEnvironmentTestBase.java      |  15 ++
 .../src/test/resources/queries/AggregatePlans.json |  31 ++++
 .../src/test/resources/queries/GroupByPlans.json   | 157 ++++++++++++++++++++
 .../src/test/resources/queries/OrderByPlans.json   |  30 ++++
 .../src/test/resources/queries/Aggregates.json     |  66 +++++++++
 8 files changed, 463 insertions(+), 31 deletions(-)

diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BrokerRequestHandlerDelegate.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BrokerRequestHandlerDelegate.java
index 3e6a0598be..360291ec69 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BrokerRequestHandlerDelegate.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BrokerRequestHandlerDelegate.java
@@ -47,15 +47,15 @@ public class BrokerRequestHandlerDelegate implements 
BrokerRequestHandler {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(BrokerRequestHandlerDelegate.class);
 
   private final BrokerRequestHandler _singleStageBrokerRequestHandler;
-  private final BrokerRequestHandler _multiStageWorkerRequestHandler;
+  private final BrokerRequestHandler _multiStageBrokerRequestHandler;
   private final BrokerMetrics _brokerMetrics;
   private final String _brokerId;
 
   public BrokerRequestHandlerDelegate(String brokerId, BrokerRequestHandler 
singleStageBrokerRequestHandler,
-      @Nullable BrokerRequestHandler multiStageWorkerRequestHandler, 
BrokerMetrics brokerMetrics) {
+      @Nullable BrokerRequestHandler multiStageBrokerRequestHandler, 
BrokerMetrics brokerMetrics) {
     _brokerId = brokerId;
     _singleStageBrokerRequestHandler = singleStageBrokerRequestHandler;
-    _multiStageWorkerRequestHandler = multiStageWorkerRequestHandler;
+    _multiStageBrokerRequestHandler = multiStageBrokerRequestHandler;
     _brokerMetrics = brokerMetrics;
   }
 
@@ -64,8 +64,8 @@ public class BrokerRequestHandlerDelegate implements 
BrokerRequestHandler {
     if (_singleStageBrokerRequestHandler != null) {
       _singleStageBrokerRequestHandler.start();
     }
-    if (_multiStageWorkerRequestHandler != null) {
-      _multiStageWorkerRequestHandler.start();
+    if (_multiStageBrokerRequestHandler != null) {
+      _multiStageBrokerRequestHandler.start();
     }
   }
 
@@ -74,8 +74,8 @@ public class BrokerRequestHandlerDelegate implements 
BrokerRequestHandler {
     if (_singleStageBrokerRequestHandler != null) {
       _singleStageBrokerRequestHandler.shutDown();
     }
-    if (_multiStageWorkerRequestHandler != null) {
-      _multiStageWorkerRequestHandler.shutDown();
+    if (_multiStageBrokerRequestHandler != null) {
+      _multiStageBrokerRequestHandler.shutDown();
     }
   }
 
@@ -99,9 +99,9 @@ public class BrokerRequestHandlerDelegate implements 
BrokerRequestHandler {
           CommonConstants.Broker.Request.QUERY_OPTIONS));
     }
 
-    if (_multiStageWorkerRequestHandler != null && 
Boolean.parseBoolean(sqlNodeAndOptions.getOptions().get(
+    if (_multiStageBrokerRequestHandler != null && 
Boolean.parseBoolean(sqlNodeAndOptions.getOptions().get(
           
CommonConstants.Broker.Request.QueryOptionKey.USE_MULTISTAGE_ENGINE))) {
-        return _multiStageWorkerRequestHandler.handleRequest(request, 
requesterIdentity, requestContext);
+        return _multiStageBrokerRequestHandler.handleRequest(request, 
requesterIdentity, requestContext);
     } else {
       return _singleStageBrokerRequestHandler.handleRequest(request, 
sqlNodeAndOptions, requesterIdentity,
           requestContext);
diff --git 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/hint/PinotHintStrategyTable.java
 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/hint/PinotHintStrategyTable.java
index ca4bc189a6..2d4980aa7c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/hint/PinotHintStrategyTable.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/hint/PinotHintStrategyTable.java
@@ -18,6 +18,9 @@
  */
 package org.apache.calcite.rel.hint;
 
+import com.google.common.collect.ImmutableList;
+
+
 /**
  * Default hint strategy set for Pinot query.
  */
@@ -25,6 +28,10 @@ public class PinotHintStrategyTable {
   public static final String INTERNAL_AGG_INTERMEDIATE_STAGE = 
"aggIntermediateStage";
   public static final String INTERNAL_AGG_FINAL_STAGE = "aggFinalStage";
 
+  public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = 
"skipLeafStageGroupByAggregation";
+
+
+
   private PinotHintStrategyTable() {
     // do not instantiate.
   }
@@ -32,5 +39,15 @@ public class PinotHintStrategyTable {
   public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE = 
HintStrategyTable.builder()
       .hintStrategy(INTERNAL_AGG_INTERMEDIATE_STAGE, HintPredicates.AGGREGATE)
       .hintStrategy(INTERNAL_AGG_FINAL_STAGE, HintPredicates.AGGREGATE)
+      .hintStrategy(SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION, 
HintPredicates.AGGREGATE)
       .build();
+
+  /**
+   * Tells whether any hint in {@code hintList} carries the name {@code hintName}.
+   *
+   * @param hintList hints to search
+   * @param hintName hint name to look for, compared with {@code equals}
+   * @return {@code true} if at least one hint's {@code hintName} equals the requested name, {@code false} otherwise
+   */
+  public static boolean containsHint(ImmutableList<RelHint> hintList, String hintName) {
+    return hintList.stream().anyMatch(candidate -> candidate.hintName.equals(hintName));
+  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index 71bb7100e8..99736870bf 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -29,13 +29,18 @@ import java.util.Map;
 import java.util.Set;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelDistributions;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.hint.PinotHintStrategyTable;
 import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalExchange;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
@@ -100,11 +105,21 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
   @Override
   public void onMatch(RelOptRuleCall call) {
     Aggregate oldAggRel = call.rel(0);
-    ImmutableList<RelHint> orgHints = oldAggRel.getHints();
+    ImmutableList<RelHint> oldHints = oldAggRel.getHints();
+
+    // If "skipLeafStageGroupByAggregation" SQLHint is passed, the leaf stage 
aggregation is skipped. This only
+    // applies for Group By Aggregations.
+    if (!oldAggRel.getGroupSet().isEmpty() && 
PinotHintStrategyTable.containsHint(oldHints,
+        PinotHintStrategyTable.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)) {
+      // This is not the default path. Use this group by optimization to skip 
leaf stage aggregation when aggregating
+      // at leaf level could be wasted effort. eg: when cardinality of group 
by column is very high.
+      createPlanWithoutLeafAggregation(call);
+      return;
+    }
 
     // 1. attach leaf agg RelHint to original agg.
     ImmutableList<RelHint> newLeafAggHints =
-        new 
ImmutableList.Builder<RelHint>().addAll(orgHints).add(AggregateNode.INTERMEDIATE_STAGE_HINT).build();
+        new 
ImmutableList.Builder<RelHint>().addAll(oldHints).add(AggregateNode.INTERMEDIATE_STAGE_HINT).build();
     Aggregate newLeafAgg =
         new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), 
newLeafAggHints, oldAggRel.getInput(),
             oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), 
oldAggRel.getAggCallList());
@@ -119,16 +134,16 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
     }
 
     // 3. attach intermediate agg stage.
-    RelNode newAggNode = makeNewIntermediateAgg(call, oldAggRel, exchange);
+    RelNode newAggNode = makeNewIntermediateAgg(call, oldAggRel, exchange, 
true, null, null);
     call.transformTo(newAggNode);
   }
 
-  private RelNode makeNewIntermediateAgg(RelOptRuleCall ruleCall, Aggregate 
oldAggRel, LogicalExchange exchange) {
+  private RelNode makeNewIntermediateAgg(RelOptRuleCall ruleCall, Aggregate 
oldAggRel, LogicalExchange exchange,
+      boolean isLeafStageAggregationPresent, List<Integer> argList, 
List<Integer> groupByList) {
 
     // add the exchange as the input node to the relation builder.
     RelBuilder relBuilder = ruleCall.builder();
     relBuilder.push(exchange);
-    List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());
 
     // make input ref to the exchange after the leaf aggregate.
     RexBuilder rexBuilder = exchange.getCluster().getRexBuilder();
@@ -144,14 +159,15 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
 
     for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) 
{
       AggregateCall oldCall = oldCalls.get(oldCallIndex);
-      convertAggCall(rexBuilder, oldAggRel, oldCallIndex, oldCall, newCalls, 
aggCallMapping, inputExprs);
+      convertAggCall(rexBuilder, oldAggRel, oldCallIndex, oldCall, newCalls, 
aggCallMapping,
+          isLeafStageAggregationPresent, argList);
     }
 
     // create new aggregate relation.
     ImmutableList<RelHint> orgHints = oldAggRel.getHints();
     ImmutableList<RelHint> newIntermediateAggHints =
         new 
ImmutableList.Builder<RelHint>().addAll(orgHints).add(AggregateNode.FINAL_STAGE_HINT).build();
-    ImmutableBitSet groupSet = ImmutableBitSet.range(nGroups);
+    ImmutableBitSet groupSet = groupByList == null ? 
ImmutableBitSet.range(nGroups) : ImmutableBitSet.of(groupByList);
     relBuilder.aggregate(
         relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)),
         newCalls);
@@ -169,7 +185,7 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
    */
   private static void convertAggCall(RexBuilder rexBuilder, Aggregate 
oldAggRel, int oldCallIndex,
       AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, 
RexNode> aggCallMapping,
-      List<RexNode> inputExprs) {
+      boolean isLeafStageAggregationPresent, List<Integer> argList) {
     final int nGroups = oldAggRel.getGroupCount();
     final SqlAggFunction oldAggregation = oldCall.getAggregation();
     final SqlKind aggKind = oldAggregation.getKind();
@@ -177,23 +193,31 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
     Preconditions.checkState(SUPPORTED_AGG_KIND.contains(aggKind), 
"Unsupported SQL aggregation "
         + "kind: {}. Only splittable aggregation functions are supported!", 
aggKind);
 
-    // Special treatment on COUNT
     AggregateCall newCall;
-    if (oldAggregation instanceof SqlCountAggFunction) {
-      newCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), 
oldCall.isDistinct(), oldCall.isApproximate(),
-          oldCall.ignoreNulls(), convertArgList(nGroups + oldCallIndex, 
Collections.singletonList(oldCallIndex)),
-          oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, 
oldCall.type, oldCall.getName());
+    if (isLeafStageAggregationPresent) {
+
+      // Special treatment for Count. If count is performed at the Leaf Stage, 
a Sum needs to be performed at the
+      // intermediate stage.
+      if (oldAggregation instanceof SqlCountAggFunction) {
+        newCall =
+            AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), 
oldCall.isDistinct(), oldCall.isApproximate(),
+                oldCall.ignoreNulls(), convertArgList(nGroups + oldCallIndex, 
Collections.singletonList(oldCallIndex)),
+                oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, 
oldCall.type, oldCall.getName());
+      } else {
+        newCall = AggregateCall.create(oldCall.getAggregation(), 
oldCall.isDistinct(), oldCall.isApproximate(),
+            oldCall.ignoreNulls(), convertArgList(nGroups + oldCallIndex, 
oldCall.getArgList()), oldCall.filterArg,
+            oldCall.distinctKeys, oldCall.collation, oldCall.type, 
oldCall.getName());
+      }
     } else {
-      newCall = AggregateCall.create(
-          oldCall.getAggregation(), oldCall.isDistinct(), 
oldCall.isApproximate(), oldCall.ignoreNulls(),
-          convertArgList(nGroups + oldCallIndex, oldCall.getArgList()), 
oldCall.filterArg, oldCall.distinctKeys,
-          oldCall.collation, oldCall.type, oldCall.getName());
+      List<Integer> newArgList = oldCall.getArgList().size() == 0 ? 
Collections.emptyList()
+          : Collections.singletonList(argList.get(oldCallIndex));
+
+      newCall = AggregateCall.create(oldCall.getAggregation(), 
oldCall.isDistinct(), oldCall.isApproximate(),
+          oldCall.ignoreNulls(), newArgList, oldCall.filterArg, 
oldCall.distinctKeys, oldCall.collation, oldCall.type,
+          oldCall.getName());
     }
-    rexBuilder.addAggCall(newCall,
-        nGroups,
-        newCalls,
-        aggCallMapping,
-        oldAggRel.getInput()::fieldIsNullable);
+
+    rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, 
oldAggRel.getInput()::fieldIsNullable);
   }
 
   private static List<Integer> convertArgList(int oldCallIndexWithShift, 
List<Integer> argList) {
@@ -201,4 +225,96 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
         "Unable to convert call as the argList contains more than 1 argument");
     return argList.size() == 1 ? 
Collections.singletonList(oldCallIndexWithShift) : Collections.emptyList();
   }
+
+  /**
+   * Rewrites a GROUP BY aggregate into a plan without a leaf-stage aggregation: a project over the aggregate's
+   * input, a hash exchange keyed on the group-by columns, and a single aggregate on top of the exchange.
+   *
+   * @param call rule call whose first matched node ({@code call.rel(0)}) is the aggregate to rewrite; the
+   *     rewritten plan is handed back through {@code call.transformTo}.
+   */
+  private void createPlanWithoutLeafAggregation(RelOptRuleCall call) {
+    Aggregate oldAggRel = call.rel(0);
+    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
+    // Typed as Project (not LogicalProject) so an existing child of any Project subtype can be reused without an
+    // unchecked downcast; the exchange below only needs a RelNode.
+    Project project;
+
+    List<Integer> newAggArgColumns = new ArrayList<>();
+    List<Integer> newAggGroupByColumns = new ArrayList<>();
+
+    // 1. Create the LogicalProject node if it does not exist. This is to send only the relevant columns over
+    //    the wire for intermediate aggregation.
+    if (childRel instanceof Project) {
+      // Avoid creating a new project if the child node of aggregation is already a project node. The aggregate's
+      // own argument and group-by indexes already refer to this project's output, so they are used as-is.
+      project = (Project) childRel;
+      newAggArgColumns = fetchNewAggArgCols(oldAggRel.getAggCallList());
+      newAggGroupByColumns = oldAggRel.getGroupSet().asList();
+    } else {
+      // Create a leaf stage project. This is done so that only the required columns are sent over the wire for
+      // intermediate aggregation. If there are multiple aggregations on the same column, the column is projected
+      // only once. The two lists are filled in by the helper with indexes into the new project's output.
+      project = createLogicalProjectForAggregate(oldAggRel, newAggArgColumns, newAggGroupByColumns);
+    }
+
+    // 2. Create an exchange on top of the project, hash-distributed on the group-by columns.
+    LogicalExchange exchange = LogicalExchange.create(project, RelDistributions.hash(newAggGroupByColumns));
+
+    // 3. Create an intermediate stage aggregation. 'false' signals that no leaf stage aggregation is present.
+    RelNode newAggNode =
+        makeNewIntermediateAgg(call, oldAggRel, exchange, false, newAggArgColumns, newAggGroupByColumns);
+
+    call.transformTo(newAggNode);
+  }
+
+  /**
+   * Builds a {@link LogicalProject} over the aggregate's input that keeps the group-by columns first, followed by
+   * each aggregation argument column that has not been projected yet.
+   *
+   * @param oldAggRel aggregate whose group set and aggregate calls decide which input columns are projected
+   * @param newAggArgColumns output list; receives one project index per aggregation argument, or -1 for a call
+   *     that has no arguments
+   * @param newAggGroupByCols output list; receives the project index of every group-by column
+   * @return the project created on top of the aggregate's current input
+   */
+  private LogicalProject createLogicalProjectForAggregate(Aggregate oldAggRel, List<Integer> newAggArgColumns,
+      List<Integer> newAggGroupByCols) {
+    RelNode aggInput = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
+    RexBuilder inputRexBuilder = aggInput.getCluster().getRexBuilder();
+    List<RelDataTypeField> inputFields = aggInput.getRowType().getFieldList();
+
+    List<RexNode> projectedExprs = new ArrayList<>();
+    List<String> projectedNames = new ArrayList<>();
+    // Input column index -> position of that column in projectedExprs.
+    Map<Integer, Integer> projectedPositions = new HashMap<>();
+
+    // Group-by columns come first, in group-set iteration order.
+    for (int groupCol : oldAggRel.getGroupSet().asSet()) {
+      int position = projectedExprs.size();
+      projectedExprs.add(inputRexBuilder.makeInputRef(aggInput, groupCol));
+      projectedNames.add(inputFields.get(groupCol).getName());
+      projectedPositions.put(groupCol, position);
+      newAggGroupByCols.add(position);
+    }
+
+    // Then the aggregation arguments; a column already projected is referenced again instead of duplicated.
+    for (AggregateCall aggCall : oldAggRel.getAggCallList()) {
+      List<Integer> callArgs = aggCall.getArgList();
+      if (callArgs.isEmpty()) {
+        // Placeholder for calls without arguments.
+        newAggArgColumns.add(-1);
+        continue;
+      }
+      for (Integer argCol : callArgs) {
+        Integer position = projectedPositions.get(argCol);
+        if (position == null) {
+          position = projectedExprs.size();
+          projectedExprs.add(inputRexBuilder.makeInputRef(aggInput, argCol));
+          projectedNames.add(inputFields.get(argCol).getName());
+          projectedPositions.put(argCol, position);
+        }
+        newAggArgColumns.add(position);
+      }
+    }
+
+    return LogicalProject.create(aggInput, Collections.emptyList(), projectedExprs, projectedNames);
+  }
+
+  /**
+   * Flattens the argument indexes of the given aggregate calls into a single list, in call order.
+   *
+   * @param oldAggCallList aggregate calls to read the argument lists from
+   * @return a new list holding every argument index of every call; a call without arguments contributes a
+   *     single -1 placeholder
+   */
+  private List<Integer> fetchNewAggArgCols(List<AggregateCall> oldAggCallList) {
+    List<Integer> flattenedArgs = new ArrayList<>();
+
+    for (AggregateCall aggCall : oldAggCallList) {
+      List<Integer> callArgs = aggCall.getArgList();
+      if (callArgs.isEmpty()) {
+        // This can be true for COUNT. Add a placeholder value which will be ignored.
+        flattenedArgs.add(-1);
+      } else {
+        for (Integer argCol : callArgs) {
+          flattenedArgs.add(argCol);
+        }
+      }
+    }
+
+    return flattenedArgs;
+  }
 }
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 3a69d3dceb..c5debf4e85 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -109,6 +109,21 @@ public class QueryEnvironmentTestBase {
         new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2, 
a.col1), MIN(a.col3) OVER (ORDER BY a.col2, "
             + "a.col1) FROM a"},
         new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), 
MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col1, 
SUM(a.col3) FROM a WHERE a.col3 >= 0"
+            + " AND a.col2 = 'a' GROUP BY a.col1"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col1, 
COUNT(*) FROM a WHERE a.col3 >= 0 "
+            + "AND a.col2 = 'a' GROUP BY a.col1"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col2, 
a.col1, SUM(a.col3) FROM a WHERE a"
+            + ".col3 >= 0 AND a.col1 = 'a'  GROUP BY a.col1, a.col2"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col1, 
AVG(b.col3) FROM a JOIN b ON a.col1 "
+            + "= b.col2  WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 
GROUP BY a.col1"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col1 as 
v1, a.col1 as v2, AVG(a.col3) FROM"
+            + " a GROUP BY v1, v2"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col2, 
COUNT(*), SUM(a.col3), SUM(a.col1) "
+            + "FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col2 
HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 "
+            + "AND MIN(a.col3) < 20 AND SUM(a.col3) <= 10 AND AVG(a.col3) = 
5"},
+        new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col2, 
a.col3 FROM a JOIN b ON a.col1 = b"
+            + ".col1  WHERE a.col3 >= 0 GROUP BY a.col2, a.col3"},
     };
   }
 
diff --git a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json 
b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
index 42f5609396..83ec13080d 100644
--- a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
@@ -59,6 +59,37 @@
           "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
+      },
+      {
+        "description": "Select aggregates with filters and select alias. The 
group by aggregate hint should be a no-op.",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
AVG(a.col3) as avg, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 
'pink floyd'",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(avg=[/(CAST($0):DOUBLE, $1)], count=[$1])",
+          "\n  LogicalProject($f0=[CASE(=($1, 0), null:INTEGER, $0)], 
$f1=[$1])",
+          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
+          "\n      LogicalExchange(distribution=[hash])",
+          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
+          "\n          LogicalProject(col2=[$0], col3=[$1])",
+          "\n            LogicalFilter(condition=[AND(>=($1, 0), =($0, 'pink 
floyd'))])",
+          "\n              LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "Select aggregates with filters and select alias. The 
group by aggregate hint should be a no-op.",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 
'pink floyd'",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(sum=[CASE(=($1, 0), null:INTEGER, $0)], 
count=[$1])",
+          "\n  LogicalAggregate(group=[{}], sum=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
+          "\n    LogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], sum=[$SUM0($1)], 
agg#1=[COUNT()])",
+          "\n        LogicalProject(col2=[$0], col3=[$1])",
+          "\n          LogicalFilter(condition=[AND(>=($1, 0), =($0, 'pink 
floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
       }
     ]
   }
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json 
b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index 7c340693e5..f977b7f7a4 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -13,6 +13,20 @@
           "\n"
         ]
       },
+      {
+        "description": "Group by with select and multiple aggregations on 1 column (no SQL hint)",
+        "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3), AVG(a.col3), 
MAX(a.col3), MIN(a.col3) FROM a GROUP BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($2):DOUBLE 
NOT NULL, $3)], EXPR$3=[$4], EXPR$4=[$5])",
+          "\n  LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$1], $f3=[$2], 
EXPR$3=[$3], EXPR$4=[$4])",
+          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[$SUM0($2)], EXPR$3=[MAX($3)], EXPR$4=[MIN($4)])",
+          "\n      LogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalAggregate(group=[{2}], EXPR$1=[$SUM0($1)], 
agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+          "\n          LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
       {
         "description": "Group by with filter",
         "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a WHERE 
a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
@@ -91,6 +105,149 @@
           "\n                  LogicalTableScan(table=[[a]])",
           "\n"
         ]
+      },
+      {
+        "description": "SQL hint based group by optimization with select and 
aggregate column",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3) FROM a GROUP BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\n  LogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalProject(col1=[$2], col3=[$1])",
+          "\n      LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with select and 
AVG aggregation",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, AVG(a.col3) FROM a GROUP BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, 
$2)])",
+          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
+          "\n    LogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalProject(col1=[$2], col3=[$1])",
+          "\n        LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with select and multiple aggregations on 1 column",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3), AVG(a.col3), MAX(a.col3), MIN(a.col3) FROM a GROUP BY 
a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($2):DOUBLE 
NOT NULL, $3)], EXPR$3=[$4], EXPR$4=[$5])",
+          "\n  LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$1], $f3=[$2], 
EXPR$3=[$3], EXPR$4=[$4])",
+          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+          "\n      LogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalProject(col1=[$2], col3=[$1])",
+          "\n          LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with filter",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalAggregate(group=[{2}], EXPR$1=[$SUM0($1)])",
+          "\n  LogicalExchange(distribution=[hash[2]])",
+          "\n    LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($1, 0), =($0, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with filter",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3), MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' 
GROUP BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalAggregate(group=[{2}], EXPR$1=[$SUM0($1)], 
EXPR$2=[MAX($1)])",
+          "\n  LogicalExchange(distribution=[hash[2]])",
+          "\n    LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($1, 0), =($0, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization count(*) with 
filter",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, COUNT(*) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
+        "notes": "TODO: Needs follow up. Project should only keep a.col1 since 
the other columns are pushed to the filter, but it currently keeps them all",
+        "output": [
+          "Execution Plan",
+          "\nLogicalAggregate(group=[{2}], EXPR$1=[COUNT()])",
+          "\n  LogicalExchange(distribution=[hash[2]])",
+          "\n    LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($1, 0), =($0, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization on 2 columns with 
filter",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col2, a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col1 = 'a'  GROUP BY 
a.col1, a.col2",
+        "output": [
+          "Execution Plan",
+          "\nLogicalAggregate(group=[{0, 2}], EXPR$2=[$SUM0($1)])",
+          "\n  LogicalExchange(distribution=[hash[0, 2]])",
+          "\n    LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($1, 0), =($2, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with having 
clause",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, COUNT(*), SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP 
BY a.col1 HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 AND MIN(a.col3) < 20 AND 
SUM(a.col3) <= 10 AND AVG(a.col3) = 5",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =($5, 5))])",
+          "\n    LogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2], $f3=[$3], 
$f4=[$4], $f5=[/(CAST($5):DOUBLE NOT NULL, $1)])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2], 
$f3=[$3], $f4=[$4], $f5=[$2])",
+          "\n        LogicalAggregate(group=[{2}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n          LogicalExchange(distribution=[hash[2]])",
+          "\n            LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n              LogicalFilter(condition=[AND(>=($1, 0), =($0, 
'a'))])",
+          "\n                LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with having 
clause but no count",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1 
HAVING MAX(a.col3) >= 0 AND MIN(a.col3) < 20 AND SUM(a.col3) <= 10 AND 
AVG(a.col3) = 5",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1])",
+          "\n  LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), 
=($4, 5))])",
+          "\n    LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$2], $f3=[$3], 
$f4=[/(CAST($4):DOUBLE NOT NULL, $5)])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$2], $f3=[$3], 
$f4=[$1], $f5=[$4])",
+          "\n        LogicalAggregate(group=[{2}], EXPR$1=[$SUM0($1)], 
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
+          "\n          LogicalExchange(distribution=[hash[2]])",
+          "\n            LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n              LogicalFilter(condition=[AND(>=($1, 0), =($0, 
'a'))])",
+          "\n                LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "SQL hint based group by optimization with having 
clause and select alias",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1 as value1, COUNT(*) AS count, SUM(a.col3) as SUM FROM a WHERE a.col3 >= 
0 AND a.col2 = 'a' GROUP BY a.col1 HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 
AND MIN(a.col3) < 20 AND SUM(a.col3) <= 10 AND AVG(a.col3) = 5",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =($5, 5))])",
+          "\n    LogicalProject(col1=[$0], count=[$1], SUM=[$2], $f3=[$3], 
$f4=[$4], $f5=[/(CAST($5):DOUBLE NOT NULL, $1)])",
+          "\n      LogicalProject(col1=[$0], count=[$1], SUM=[$2], $f3=[$3], 
$f4=[$4], $f5=[$2])",
+          "\n        LogicalAggregate(group=[{2}], count=[COUNT()], 
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n          LogicalExchange(distribution=[hash[2]])",
+          "\n            LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n              LogicalFilter(condition=[AND(>=($1, 0), =($0, 
'a'))])",
+          "\n                LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
       }
     ]
   }
diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json 
b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
index 94f8fe9ba3..b1aa977161 100644
--- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
@@ -66,6 +66,21 @@
           "\n"
         ]
       },
+      {
+        "description": "Order by and group by with hint",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3) FROM a GROUP BY a.col1 ORDER BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
+          "\n  LogicalSortExchange(distribution=[hash], collation=[[0]])",
+          "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
+          "\n      LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\n        LogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalProject(col1=[$2], col3=[$1])",
+          "\n            LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
+      },
       {
         "description": "Order by and group by with alias",
         "sql": "EXPLAIN PLAN FOR SELECT a.col1 AS value1, SUM(a.col3) AS sum 
FROM a GROUP BY a.col1 ORDER BY a.col1",
@@ -80,6 +95,21 @@
           "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
+      },
+      {
+        "description": "Order by and group by with alias with SqlHint",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1 AS value1, SUM(a.col3) AS sum FROM a GROUP BY a.col1 ORDER BY a.col1",
+        "output": [
+          "Execution Plan",
+          "\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
+          "\n  LogicalSortExchange(distribution=[hash], collation=[[0]])",
+          "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
+          "\n      LogicalAggregate(group=[{0}], sum=[$SUM0($1)])",
+          "\n        LogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalProject(col1=[$2], col3=[$1])",
+          "\n            LogicalTableScan(table=[[a]])",
+          "\n"
+        ]
       }
     ]
   }
diff --git a/pinot-query-runtime/src/test/resources/queries/Aggregates.json 
b/pinot-query-runtime/src/test/resources/queries/Aggregates.json
index 51a45cebdb..ade91b5ee1 100644
--- a/pinot-query-runtime/src/test/resources/queries/Aggregates.json
+++ b/pinot-query-runtime/src/test/resources/queries/Aggregates.json
@@ -277,5 +277,71 @@
       { "sql": "SELECT upper(string_col), count(int_col) FROM {tbl} GROUP BY 
upper(string_col) HAVING sum(int_col) > 0 ORDER BY upper(string_col)" },
       { "sql": "SELECT upper(string_col), count(int_col) FROM {tbl} GROUP BY 
upper(string_col) HAVING sum(int_col) >= 0 AND count(int_col) >= 0 ORDER BY 
count(int_col)" }
     ]
+  },
+  "aggregate_with_hints": {
+    "tables": {
+      "tbl": {
+        "schema": [
+          {"name": "int_col", "type": "INT"},
+          {"name": "double_col", "type": "DOUBLE"},
+          {"name": "string_col", "type": "STRING"},
+          {"name": "bool_col", "type": "BOOLEAN"}
+        ],
+        "inputs": [
+          [2, 300, "a", true],
+          [2, 400, "a", true],
+          [3, 100, "b", false],
+          [100, 1, "b", false],
+          [101, 1.01, "c", false],
+          [150, 1.5, "c", false],
+          [175, 1.75, "c", true]
+        ]
+      }
+    },
+    "queries": [
+      {
+        "psql": "4.2.7",
+        "description": "aggregation without groupby. hint is a no-op.",
+        "sql": "SELECT /*+ skipLeafStageGroupByAggregation */ avg(double_col) 
FROM {tbl}"
+      },
+      {
+        "psql": "4.2.7",
+        "description": "count, sum group by order by",
+        "sql": "select /*+ skipLeafStageGroupByAggregation */ string_col, 
count(int_col), sum(double_col) from {tbl} group by string_col order by 
string_col;"
+      },
+      {
+        "psql": "4.2.7",
+        "description": "max, min, count, sum group by order by. Multiple aggregations on 
single column.",
+        "sql": "select /*+ skipLeafStageGroupByAggregation */ string_col, 
max(int_col), min(double_col), count(int_col), sum(double_col) from {tbl} group 
by string_col order by string_col;"
+      },
+      {
+        "psql": "9.21.0",
+        "description": "aggregate boolean column",
+        "sql": "SELECT /*+ skipLeafStageGroupByAggregation */ 
bool_and(bool_col), bool_or(bool_col) FROM {tbl} GROUP BY string_col"
+      },
+      {
+        "psql": "9.21.0",
+        "description": "aggregate boolean column no group by",
+        "sql": "SELECT /*+ skipLeafStageGroupByAggregation */ 
bool_and(bool_col), bool_or(bool_col) FROM {tbl}"
+      },
+      {
+        "ignored": true,
+        "comment": "sum empty returns [0] instead of [null] at the moment",
+        "description": "sum empty input after filter with subquery",
+        "sql": "SELECT /*+ skipLeafStageGroupByAggregation */ sum(int_col) 
FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col 
BETWEEN 1 AND 0 GROUP BY string_col )"
+      },
+      {
+        "description": "count(*) empty input after filter with sub-query",
+        "sql": "SELECT count(*) FROM {tbl} WHERE string_col IN ( SELECT /*+ 
skipLeafStageGroupByAggregation */ string_col FROM {tbl} WHERE int_col BETWEEN 
1 AND 0 GROUP BY string_col )"
+      },
+      {
+        "description": "count(int_col) empty input after filter with sub-query",
+        "sql": "SELECT count(int_col) FROM {tbl} WHERE string_col IN ( SELECT 
/*+ skipLeafStageGroupByAggregation */ string_col FROM {tbl} WHERE int_col 
BETWEEN 1 AND 0 GROUP BY string_col )"
+      },
+      {
+        "description": "group by optimization with filter",
+        "sql": "SELECT /*+ skipLeafStageGroupByAggregation */ double_col, 
sum(int_col) FROM {tbl} WHERE int_col > 3 AND double_col > 1.0 GROUP BY 
double_col"
+      }
+    ]
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to