fjy closed pull request #6228: [Backport] Support projection after sorting in SQL
URL: https://github.com/apache/incubator-druid/pull/6228
 
 
   

This pull request was merged from a forked repository. Because GitHub
hides the original diff of a fork-based (foreign) pull request once it
has been merged, the diff is reproduced below for the sake of provenance:

diff --git 
a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java 
b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
index 2532c8d7f82..09436b96e9d 100644
--- a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
+++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
@@ -36,6 +36,7 @@
 import io.druid.sql.calcite.table.RowSignature;
 
 import javax.annotation.Nullable;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
@@ -112,7 +113,7 @@ public static Aggregation create(final AggregatorFactory 
aggregatorFactory)
 
   public static Aggregation create(final PostAggregator postAggregator)
   {
-    return new Aggregation(ImmutableList.of(), ImmutableList.of(), 
postAggregator);
+    return new Aggregation(Collections.emptyList(), Collections.emptyList(), 
postAggregator);
   }
 
   public static Aggregation create(
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java 
b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
index bca4481992f..5503f50adf9 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
@@ -89,6 +89,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.OptionalInt;
 import java.util.TreeSet;
 import java.util.stream.Collectors;
 
@@ -105,9 +106,11 @@
   private final DimFilter filter;
   private final SelectProjection selectProjection;
   private final Grouping grouping;
+  private final SortProject sortProject;
+  private final DefaultLimitSpec limitSpec;
   private final RowSignature outputRowSignature;
   private final RelDataType outputRowType;
-  private final DefaultLimitSpec limitSpec;
+
   private final Query query;
 
   public DruidQuery(
@@ -128,15 +131,22 @@ public DruidQuery(
     this.selectProjection = computeSelectProjection(partialQuery, 
plannerContext, sourceRowSignature);
     this.grouping = computeGrouping(partialQuery, plannerContext, 
sourceRowSignature, rexBuilder);
 
+    final RowSignature sortingInputRowSignature;
+
     if (this.selectProjection != null) {
-      this.outputRowSignature = this.selectProjection.getOutputRowSignature();
+      sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
     } else if (this.grouping != null) {
-      this.outputRowSignature = this.grouping.getOutputRowSignature();
+      sortingInputRowSignature = this.grouping.getOutputRowSignature();
     } else {
-      this.outputRowSignature = sourceRowSignature;
+      sortingInputRowSignature = sourceRowSignature;
     }
 
-    this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
+    this.sortProject = computeSortProject(partialQuery, plannerContext, 
sortingInputRowSignature, grouping);
+
+    // outputRowSignature is used only for scan and select query, and thus 
sort and grouping must be null
+    this.outputRowSignature = sortProject == null ? sortingInputRowSignature : 
sortProject.getOutputRowSignature();
+
+    this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
     this.query = computeQuery();
   }
 
@@ -235,7 +245,7 @@ private static Grouping computeGrouping(
   )
   {
     final Aggregate aggregate = partialQuery.getAggregate();
-    final Project postProject = partialQuery.getPostProject();
+    final Project aggregateProject = partialQuery.getAggregateProject();
 
     if (aggregate == null) {
       return null;
@@ -265,49 +275,27 @@ private static Grouping computeGrouping(
         plannerContext
     );
 
-    if (postProject == null) {
+    if (aggregateProject == null) {
       return Grouping.create(dimensions, aggregations, havingFilter, 
aggregateRowSignature);
     } else {
-      final List<String> rowOrder = new ArrayList<>();
-
-      int outputNameCounter = 0;
-      for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
-        // Attempt to convert to PostAggregator.
-        final DruidExpression postAggregatorExpression = 
Expressions.toDruidExpression(
-            plannerContext,
-            aggregateRowSignature,
-            postAggregatorRexNode
-        );
-
-        if (postAggregatorExpression == null) {
-          throw new CannotBuildQueryException(postProject, 
postAggregatorRexNode);
-        }
-
-        if (postAggregatorDirectColumnIsOk(aggregateRowSignature, 
postAggregatorExpression, postAggregatorRexNode)) {
-          // Direct column access, without any type cast as far as Druid's 
runtime is concerned.
-          // (There might be a SQL-level type cast that we don't care about)
-          rowOrder.add(postAggregatorExpression.getDirectColumn());
-        } else {
-          final String postAggregatorName = "p" + outputNameCounter++;
-          final PostAggregator postAggregator = new ExpressionPostAggregator(
-              postAggregatorName,
-              postAggregatorExpression.getExpression(),
-              null,
-              plannerContext.getExprMacroTable()
-          );
-          aggregations.add(Aggregation.create(postAggregator));
-          rowOrder.add(postAggregator.getName());
-        }
-      }
+      final ProjectRowOrderAndPostAggregations 
projectRowOrderAndPostAggregations = computePostAggregations(
+          plannerContext,
+          aggregateRowSignature,
+          aggregateProject,
+          0
+      );
+      projectRowOrderAndPostAggregations.postAggregations.forEach(
+          postAggregator -> 
aggregations.add(Aggregation.create(postAggregator))
+      );
 
       // Remove literal dimensions that did not appear in the projection. This 
is useful for queries
       // like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can 
generate, and for which we don't
       // actually want to include a dimension 'dummy'.
-      final ImmutableBitSet postProjectBits = 
RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
+      final ImmutableBitSet aggregateProjectBits = 
RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
       for (int i = dimensions.size() - 1; i >= 0; i--) {
         final DimensionExpression dimension = dimensions.get(i);
         if (Parser.parse(dimension.getDruidExpression().getExpression(), 
plannerContext.getExprMacroTable())
-                  .isLiteral() && !postProjectBits.get(i)) {
+                  .isLiteral() && !aggregateProjectBits.get(i)) {
           dimensions.remove(i);
         }
       }
@@ -316,11 +304,98 @@ private static Grouping computeGrouping(
           dimensions,
           aggregations,
           havingFilter,
-          RowSignature.from(rowOrder, postProject.getRowType())
+          RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, 
aggregateProject.getRowType())
       );
     }
   }
 
+  @Nullable
+  private SortProject computeSortProject(
+      PartialDruidQuery partialQuery,
+      PlannerContext plannerContext,
+      RowSignature sortingInputRowSignature,
+      Grouping grouping
+  )
+  {
+    final Project sortProject = partialQuery.getSortProject();
+    if (sortProject == null) {
+      return null;
+    } else {
+      final List<PostAggregator> postAggregators = 
grouping.getPostAggregators();
+      final OptionalInt maybeMaxCounter = postAggregators
+          .stream()
+          .mapToInt(postAggregator -> 
Integer.parseInt(postAggregator.getName().substring(1)))
+          .max();
+
+      final ProjectRowOrderAndPostAggregations 
projectRowOrderAndPostAggregations = computePostAggregations(
+          plannerContext,
+          sortingInputRowSignature,
+          sortProject,
+          maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
+      );
+
+      return new SortProject(
+          sortingInputRowSignature,
+          projectRowOrderAndPostAggregations.postAggregations,
+          RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, 
sortProject.getRowType())
+      );
+    }
+  }
+
+  private static class ProjectRowOrderAndPostAggregations
+  {
+    private final List<String> rowOrder;
+    private final List<PostAggregator> postAggregations;
+
+    ProjectRowOrderAndPostAggregations(List<String> rowOrder, 
List<PostAggregator> postAggregations)
+    {
+      this.rowOrder = rowOrder;
+      this.postAggregations = postAggregations;
+    }
+  }
+
+  private static ProjectRowOrderAndPostAggregations computePostAggregations(
+      PlannerContext plannerContext,
+      RowSignature inputRowSignature,
+      Project project,
+      int outputNameCounter
+  )
+  {
+    final List<String> rowOrder = new ArrayList<>();
+    final List<PostAggregator> aggregations = new ArrayList<>();
+
+    for (final RexNode postAggregatorRexNode : project.getChildExps()) {
+      // Attempt to convert to PostAggregator.
+      final DruidExpression postAggregatorExpression = 
Expressions.toDruidExpression(
+          plannerContext,
+          inputRowSignature,
+          postAggregatorRexNode
+      );
+
+      if (postAggregatorExpression == null) {
+        throw new CannotBuildQueryException(project, postAggregatorRexNode);
+      }
+
+      if (postAggregatorDirectColumnIsOk(inputRowSignature, 
postAggregatorExpression, postAggregatorRexNode)) {
+        // Direct column access, without any type cast as far as Druid's 
runtime is concerned.
+        // (There might be a SQL-level type cast that we don't care about)
+        rowOrder.add(postAggregatorExpression.getDirectColumn());
+      } else {
+        final String postAggregatorName = "p" + outputNameCounter++;
+        final PostAggregator postAggregator = new ExpressionPostAggregator(
+            postAggregatorName,
+            postAggregatorExpression.getExpression(),
+            null,
+            plannerContext.getExprMacroTable()
+        );
+        aggregations.add(postAggregator);
+        rowOrder.add(postAggregator.getName());
+      }
+    }
+
+    return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
+  }
+
   /**
    * Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in 
the same order.
    *
@@ -540,18 +615,20 @@ public VirtualColumns getVirtualColumns(final 
ExprMacroTable macroTable, final b
   {
     final List<VirtualColumn> retVal = new ArrayList<>();
 
-    if (grouping != null) {
-      if (includeDimensions) {
-        for (DimensionExpression dimensionExpression : 
grouping.getDimensions()) {
-          retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+    if (selectProjection != null) {
+      retVal.addAll(selectProjection.getVirtualColumns());
+    } else {
+      if (grouping != null) {
+        if (includeDimensions) {
+          for (DimensionExpression dimensionExpression : 
grouping.getDimensions()) {
+            retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+          }
         }
-      }
 
-      for (Aggregation aggregation : grouping.getAggregations()) {
-        retVal.addAll(aggregation.getVirtualColumns());
+        for (Aggregation aggregation : grouping.getAggregations()) {
+          retVal.addAll(aggregation.getVirtualColumns());
+        }
       }
-    } else if (selectProjection != null) {
-      retVal.addAll(selectProjection.getVirtualColumns());
     }
 
     return VirtualColumns.create(retVal);
@@ -567,6 +644,11 @@ public DefaultLimitSpec getLimitSpec()
     return limitSpec;
   }
 
+  public SortProject getSortProject()
+  {
+    return sortProject;
+  }
+
   public RelDataType getOutputRowType()
   {
     return outputRowType;
@@ -667,7 +749,6 @@ public TimeseriesQuery toTimeseriesQuery()
 
       if (limitSpec != null) {
         // If there is a limit spec, timeseries cannot LIMIT; and must be 
ORDER BY time (or nothing).
-
         if (limitSpec.isLimited()) {
           return null;
         }
@@ -797,6 +878,11 @@ public GroupByQuery toGroupByQuery()
 
     final Filtration filtration = 
Filtration.create(filter).optimize(sourceRowSignature);
 
+    final List<PostAggregator> postAggregators = new 
ArrayList<>(grouping.getPostAggregators());
+    if (sortProject != null) {
+      postAggregators.addAll(sortProject.getPostAggregators());
+    }
+
     return new GroupByQuery(
         dataSource,
         filtration.getQuerySegmentSpec(),
@@ -805,7 +891,7 @@ public GroupByQuery toGroupByQuery()
         Granularities.ALL,
         grouping.getDimensionSpecs(),
         grouping.getAggregatorFactories(),
-        grouping.getPostAggregators(),
+        postAggregators,
         grouping.getHavingFilter() != null ? new 
DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
         limitSpec,
         ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java 
b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
index 5d0ea438106..d49969174fc 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
@@ -220,14 +220,18 @@ public RelOptCost computeSelfCost(final RelOptPlanner 
planner, final RelMetadata
       cost += COST_PER_COLUMN * 
partialQuery.getAggregate().getAggCallList().size();
     }
 
-    if (partialQuery.getPostProject() != null) {
-      cost += COST_PER_COLUMN * 
partialQuery.getPostProject().getChildExps().size();
+    if (partialQuery.getAggregateProject() != null) {
+      cost += COST_PER_COLUMN * 
partialQuery.getAggregateProject().getChildExps().size();
     }
 
     if (partialQuery.getSort() != null && partialQuery.getSort().fetch != 
null) {
       cost *= COST_LIMIT_MULTIPLIER;
     }
 
+    if (partialQuery.getSortProject() != null) {
+      cost += COST_PER_COLUMN * 
partialQuery.getSortProject().getChildExps().size();
+    }
+
     if (partialQuery.getHavingFilter() != null) {
       cost *= COST_HAVING_MULTIPLIER;
     }
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java 
b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
index 5d6abfc1f01..0da912b9564 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
@@ -358,8 +358,12 @@ public RelOptCost computeSelfCost(final RelOptPlanner 
planner, final RelMetadata
         newPartialQuery = 
newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
       }
 
-      if (leftPartialQuery.getPostProject() != null) {
-        newPartialQuery = 
newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
+      if (leftPartialQuery.getAggregateProject() != null) {
+        newPartialQuery = 
newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
+      }
+
+      if (leftPartialQuery.getSortProject() != null) {
+        newPartialQuery = 
newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
       }
 
       if (leftPartialQuery.getSort() != null) {
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java 
b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
index 263c2e09a1c..ec94838908b 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
@@ -46,8 +46,9 @@
   private final Sort selectSort;
   private final Aggregate aggregate;
   private final Filter havingFilter;
-  private final Project postProject;
+  private final Project aggregateProject;
   private final Sort sort;
+  private final Project sortProject;
 
   public enum Stage
   {
@@ -57,8 +58,9 @@
     SELECT_SORT,
     AGGREGATE,
     HAVING_FILTER,
-    POST_PROJECT,
-    SORT
+    AGGREGATE_PROJECT,
+    SORT,
+    SORT_PROJECT
   }
 
   public PartialDruidQuery(
@@ -67,9 +69,10 @@ public PartialDruidQuery(
       final Project selectProject,
       final Sort selectSort,
       final Aggregate aggregate,
-      final Project postProject,
+      final Project aggregateProject,
       final Filter havingFilter,
-      final Sort sort
+      final Sort sort,
+      final Project sortProject
   )
   {
     this.scan = Preconditions.checkNotNull(scan, "scan");
@@ -77,14 +80,15 @@ public PartialDruidQuery(
     this.selectProject = selectProject;
     this.selectSort = selectSort;
     this.aggregate = aggregate;
-    this.postProject = postProject;
+    this.aggregateProject = aggregateProject;
     this.havingFilter = havingFilter;
     this.sort = sort;
+    this.sortProject = sortProject;
   }
 
   public static PartialDruidQuery create(final RelNode scanRel)
   {
-    return new PartialDruidQuery(scanRel, null, null, null, null, null, null, 
null);
+    return new PartialDruidQuery(scanRel, null, null, null, null, null, null, 
null, null);
   }
 
   public RelNode getScan()
@@ -117,9 +121,9 @@ public Filter getHavingFilter()
     return havingFilter;
   }
 
-  public Project getPostProject()
+  public Project getAggregateProject()
   {
-    return postProject;
+    return aggregateProject;
   }
 
   public Sort getSort()
@@ -127,6 +131,11 @@ public Sort getSort()
     return sort;
   }
 
+  public Project getSortProject()
+  {
+    return sortProject;
+  }
+
   public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
   {
     validateStage(Stage.WHERE_FILTER);
@@ -136,9 +145,10 @@ public PartialDruidQuery withWhereFilter(final Filter 
newWhereFilter)
         selectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -151,9 +161,10 @@ public PartialDruidQuery withSelectProject(final Project 
newSelectProject)
         newSelectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -166,9 +177,10 @@ public PartialDruidQuery withSelectSort(final Sort 
newSelectSort)
         selectProject,
         newSelectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -181,9 +193,10 @@ public PartialDruidQuery withAggregate(final Aggregate 
newAggregate)
         selectProject,
         selectSort,
         newAggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -196,24 +209,26 @@ public PartialDruidQuery withHavingFilter(final Filter 
newHavingFilter)
         selectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         newHavingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
-  public PartialDruidQuery withPostProject(final Project newPostProject)
+  public PartialDruidQuery withAggregateProject(final Project 
newAggregateProject)
   {
-    validateStage(Stage.POST_PROJECT);
+    validateStage(Stage.AGGREGATE_PROJECT);
     return new PartialDruidQuery(
         scan,
         whereFilter,
         selectProject,
         selectSort,
         aggregate,
-        newPostProject,
+        newAggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -226,9 +241,26 @@ public PartialDruidQuery withSort(final Sort newSort)
         selectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
+        havingFilter,
+        newSort,
+        sortProject
+    );
+  }
+
+  public PartialDruidQuery withSortProject(final Project newSortProject)
+  {
+    validateStage(Stage.SORT_PROJECT);
+    return new PartialDruidQuery(
+        scan,
+        whereFilter,
+        selectProject,
+        selectSort,
+        aggregate,
+        aggregateProject,
         havingFilter,
-        newSort
+        sort,
+        newSortProject
     );
   }
 
@@ -265,6 +297,9 @@ public boolean canAccept(final Stage stage)
     } else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
       // Cannot do any aggregations after a select + sort.
       return false;
+    } else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
+      // Cannot add sort project without a sort
+      return false;
     } else {
       // Looks good.
       return true;
@@ -277,12 +312,15 @@ public boolean canAccept(final Stage stage)
    *
    * @return stage
    */
+  @SuppressWarnings("VariableNotUsedInsideIf")
   public Stage stage()
   {
-    if (sort != null) {
+    if (sortProject != null) {
+      return Stage.SORT_PROJECT;
+    } else if (sort != null) {
       return Stage.SORT;
-    } else if (postProject != null) {
-      return Stage.POST_PROJECT;
+    } else if (aggregateProject != null) {
+      return Stage.AGGREGATE_PROJECT;
     } else if (havingFilter != null) {
       return Stage.HAVING_FILTER;
     } else if (aggregate != null) {
@@ -308,10 +346,12 @@ public RelNode leafRel()
     final Stage currentStage = stage();
 
     switch (currentStage) {
+      case SORT_PROJECT:
+        return sortProject;
       case SORT:
         return sort;
-      case POST_PROJECT:
-        return postProject;
+      case AGGREGATE_PROJECT:
+        return aggregateProject;
       case HAVING_FILTER:
         return havingFilter;
       case AGGREGATE:
@@ -352,14 +392,25 @@ public boolean equals(final Object o)
            Objects.equals(selectSort, that.selectSort) &&
            Objects.equals(aggregate, that.aggregate) &&
            Objects.equals(havingFilter, that.havingFilter) &&
-           Objects.equals(postProject, that.postProject) &&
-           Objects.equals(sort, that.sort);
+           Objects.equals(aggregateProject, that.aggregateProject) &&
+           Objects.equals(sort, that.sort) &&
+           Objects.equals(sortProject, that.sortProject);
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(scan, whereFilter, selectProject, selectSort, 
aggregate, havingFilter, postProject, sort);
+    return Objects.hash(
+        scan,
+        whereFilter,
+        selectProject,
+        selectSort,
+        aggregate,
+        havingFilter,
+        aggregateProject,
+        sort,
+        sortProject
+    );
   }
 
   @Override
@@ -372,8 +423,9 @@ public String toString()
            ", selectSort=" + selectSort +
            ", aggregate=" + aggregate +
            ", havingFilter=" + havingFilter +
-           ", postProject=" + postProject +
+           ", aggregateProject=" + aggregateProject +
            ", sort=" + sort +
+           ", sortProject=" + sortProject +
            '}';
   }
 }
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java 
b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
new file mode 100644
index 00000000000..c00aff97ee5
--- /dev/null
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.druid.sql.calcite.rel;
+
+import com.google.common.base.Preconditions;
+import io.druid.java.util.common.ISE;
+import io.druid.query.aggregation.PostAggregator;
+import io.druid.sql.calcite.table.RowSignature;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+public class SortProject
+{
+  private final RowSignature inputRowSignature;
+  private final List<PostAggregator> postAggregators;
+  private final RowSignature outputRowSignature;
+
+  SortProject(
+      RowSignature inputRowSignature,
+      List<PostAggregator> postAggregators,
+      RowSignature outputRowSignature
+  )
+  {
+    this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, 
"inputRowSignature");
+    this.postAggregators = Preconditions.checkNotNull(postAggregators, 
"postAggregators");
+    this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, 
"outputRowSignature");
+
+    // Verify no collisions.
+    final Set<String> seen = new HashSet<>();
+    inputRowSignature.getRowOrder().forEach(field -> {
+      if (!seen.add(field)) {
+        throw new ISE("Duplicate field name: %s", field);
+      }
+    });
+
+    for (PostAggregator postAggregator : postAggregators) {
+      if (postAggregator == null) {
+        throw new ISE("aggregation[%s] is not a postAggregator", 
postAggregator);
+      }
+      if (!seen.add(postAggregator.getName())) {
+        throw new ISE("Duplicate field name: %s", postAggregator.getName());
+      }
+    }
+
+    // Verify that items in the output signature exist.
+    outputRowSignature.getRowOrder().forEach(field -> {
+      if (!seen.contains(field)) {
+        throw new ISE("Missing field in rowOrder: %s", field);
+      }
+    });
+  }
+
+  public List<PostAggregator> getPostAggregators()
+  {
+    return postAggregators;
+  }
+
+  public RowSignature getOutputRowSignature()
+  {
+    return outputRowSignature;
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    SortProject sortProject = (SortProject) o;
+    return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
+           Objects.equals(postAggregators, sortProject.postAggregators) &&
+           Objects.equals(outputRowSignature, sortProject.outputRowSignature);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(inputRowSignature, postAggregators, 
outputRowSignature);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "SortProject{" +
+           "inputRowSignature=" + inputRowSignature +
+           ", postAggregators=" + postAggregators +
+           ", outputRowSignature=" + outputRowSignature +
+           '}';
+  }
+}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java 
b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
index 5ec46abfd01..d85b0c2cf23 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
@@ -68,8 +68,8 @@ private DruidRules()
         ),
         new DruidQueryRule<>(
             Project.class,
-            PartialDruidQuery.Stage.POST_PROJECT,
-            PartialDruidQuery::withPostProject
+            PartialDruidQuery.Stage.AGGREGATE_PROJECT,
+            PartialDruidQuery::withAggregateProject
         ),
         new DruidQueryRule<>(
             Filter.class,
@@ -81,10 +81,16 @@ private DruidRules()
             PartialDruidQuery.Stage.SORT,
             PartialDruidQuery::withSort
         ),
+        new DruidQueryRule<>(
+            Project.class,
+            PartialDruidQuery.Stage.SORT_PROJECT,
+            PartialDruidQuery::withSortProject
+        ),
         DruidOuterQueryRule.AGGREGATE,
         DruidOuterQueryRule.FILTER_AGGREGATE,
         DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
-        DruidOuterQueryRule.PROJECT_AGGREGATE
+        DruidOuterQueryRule.PROJECT_AGGREGATE,
+        DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
     );
   }
 
@@ -227,6 +233,32 @@ public void onMatch(final RelOptRuleCall call)
       }
     };
 
+    public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
+        operand(Project.class, operand(Sort.class, operand(Aggregate.class, 
operand(DruidRel.class, any())))),
+        "AGGREGATE_SORT_PROJECT"
+    )
+    {
+      @Override
+      public void onMatch(RelOptRuleCall call)
+      {
+        final Project sortProject = call.rel(0);
+        final Sort sort = call.rel(1);
+        final Aggregate aggregate = call.rel(2);
+        final DruidRel druidRel = call.rel(3);
+
+        final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
+            druidRel,
+            PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
+                             .withAggregate(aggregate)
+                             .withSort(sort)
+                             .withSortProject(sortProject)
+        );
+        if (outerQueryRel.isValidDruidQuery()) {
+          call.transformTo(outerQueryRel);
+        }
+      }
+    };
+
     public DruidOuterQueryRule(final RelOptRuleOperand op, final String 
description)
     {
       super(op, StringUtils.format("%s:%s", 
DruidOuterQueryRel.class.getSimpleName(), description));
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java 
b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
index 5376ff124f1..9ef0430932b 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
@@ -24,6 +24,7 @@
 import io.druid.sql.calcite.planner.PlannerConfig;
 import io.druid.sql.calcite.rel.DruidRel;
 import io.druid.sql.calcite.rel.DruidSemiJoin;
+import io.druid.sql.calcite.rel.PartialDruidQuery;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptUtil;
@@ -115,15 +116,18 @@ public void onMatch(RelOptRuleCall call)
       return;
     }
 
-    final Project rightPostProject = 
right.getPartialDruidQuery().getPostProject();
+    final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
+    final Project rightProject = rightQuery.getSortProject() != null ?
+                                 rightQuery.getSortProject() :
+                                 rightQuery.getAggregateProject();
     int i = 0;
     for (int joinRef : joinInfo.rightSet()) {
       final int aggregateRef;
 
-      if (rightPostProject == null) {
+      if (rightProject == null) {
         aggregateRef = joinRef;
       } else {
-        final RexNode projectExp = 
rightPostProject.getChildExps().get(joinRef);
+        final RexNode projectExp = rightProject.getChildExps().get(joinRef);
         if (projectExp.isA(SqlKind.INPUT_REF)) {
           aggregateRef = ((RexInputRef) projectExp).getIndex();
         } else {
diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
index 19722de6ee0..795ecabddc1 100644
--- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
@@ -72,6 +72,7 @@
 import io.druid.query.groupby.having.DimFilterHavingSpec;
 import io.druid.query.groupby.orderby.DefaultLimitSpec;
 import io.druid.query.groupby.orderby.OrderByColumnSpec;
+import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
 import io.druid.query.lookup.RegisteredLookupExtractionFn;
 import io.druid.query.ordering.StringComparator;
 import io.druid.query.ordering.StringComparators;
@@ -122,6 +123,7 @@
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -6377,6 +6379,193 @@ public void testUnicodeFilterAndGroupBy() throws Exception
     );
   }
 
+  @Test
+  public void testProjectAfterSort() throws Exception
+  {
+    testQuery(
+        "select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE1)
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            DIMS(
+                                new DefaultDimensionSpec("dim1", "d0"),
+                                new DefaultDimensionSpec("dim2", "d1")
+                            )
+                        )
+                        .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{""},
+            new Object[]{"1"},
+            new Object[]{"10.1"},
+            new Object[]{"2"},
+            new Object[]{"abc"},
+            new Object[]{"def"}
+        )
+    );
+  }
+
+  @Test
+  public void testProjectAfterSort2() throws Exception
+  {
+    testQuery(
+        "select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE1)
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            DIMS(
+                                new DefaultDimensionSpec("dim1", "d0"),
+                                new DefaultDimensionSpec("dim2", "d1")
+                            )
+                        )
+                        .setAggregatorSpecs(
+                            AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2"))
+                        )
+                        .setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{1.0, "", "a", 1.0},
+            new Object[]{4.0, "1", "a", 4.0},
+            new Object[]{2.0, "10.1", "", 2.0},
+            new Object[]{3.0, "2", "", 3.0},
+            new Object[]{6.0, "abc", "", 6.0},
+            new Object[]{5.0, "def", "abc", 5.0}
+        )
+    );
+  }
+
+  @Test
+  public void testProjectAfterSort3() throws Exception
+  {
+    testQuery(
+        "select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE1)
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            DIMS(
+                                new DefaultDimensionSpec("dim1", "d0")
+                            )
+                        )
+                        .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{""},
+            new Object[]{"1"},
+            new Object[]{"10.1"},
+            new Object[]{"2"},
+            new Object[]{"abc"},
+            new Object[]{"def"}
+        )
+    );
+  }
+
+  @Test
+  public void testSortProjectAfterNestedGroupBy() throws Exception
+  {
+    testQuery(
+        "SELECT "
+        + "  cnt "
+        + "FROM ("
+        + "  SELECT "
+        + "    __time, "
+        + "    dim1, "
+        + "    COUNT(m2) AS cnt "
+        + "  FROM ("
+        + "    SELECT "
+        + "        __time, "
+        + "        m2, "
+        + "        dim1 "
+        + "    FROM druid.foo "
+        + "    GROUP BY __time, m2, dim1 "
+        + "  ) "
+        + "  GROUP BY __time, dim1 "
+        + "  ORDER BY cnt"
+        + ")",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            GroupByQuery.builder()
+                                        .setDataSource(CalciteTests.DATASOURCE1)
+                                        .setInterval(QSS(Filtration.eternity()))
+                                        .setGranularity(Granularities.ALL)
+                                        .setDimensions(DIMS(
+                                            new DefaultDimensionSpec("__time", "d0", ValueType.LONG),
+                                            new DefaultDimensionSpec("dim1", "d1"),
+                                            new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE)
+                                        ))
+                                        .setContext(QUERY_CONTEXT_DEFAULT)
+                                        .build()
+                        )
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(DIMS(
+                            new DefaultDimensionSpec("d0", "_d0", ValueType.LONG),
+                            new DefaultDimensionSpec("d1", "_d1", ValueType.STRING)
+                        ))
+                        .setAggregatorSpecs(AGGS(
+                            new CountAggregatorFactory("a0")
+                        ))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L}
+        )
+    );
+  }
+
   private void testQuery(
       final String sql,
       final List<Query> expectedQueries,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to