fjy closed pull request #6228: [Backport] Support projection after sorting in
SQL
URL: https://github.com/apache/incubator-druid/pull/6228
This PR was merged from a forked repository (a foreign pull request).
Because GitHub hides the original diff of such a pull request once it
has been merged, the diff is reproduced below for the sake of
provenance:
diff --git
a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
index 2532c8d7f82..09436b96e9d 100644
--- a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
+++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
@@ -36,6 +36,7 @@
import io.druid.sql.calcite.table.RowSignature;
import javax.annotation.Nullable;
+import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@@ -112,7 +113,7 @@ public static Aggregation create(final AggregatorFactory
aggregatorFactory)
public static Aggregation create(final PostAggregator postAggregator)
{
- return new Aggregation(ImmutableList.of(), ImmutableList.of(),
postAggregator);
+ return new Aggregation(Collections.emptyList(), Collections.emptyList(),
postAggregator);
}
public static Aggregation create(
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
index bca4481992f..5503f50adf9 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
@@ -89,6 +89,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.OptionalInt;
import java.util.TreeSet;
import java.util.stream.Collectors;
@@ -105,9 +106,11 @@
private final DimFilter filter;
private final SelectProjection selectProjection;
private final Grouping grouping;
+ private final SortProject sortProject;
+ private final DefaultLimitSpec limitSpec;
private final RowSignature outputRowSignature;
private final RelDataType outputRowType;
- private final DefaultLimitSpec limitSpec;
+
private final Query query;
public DruidQuery(
@@ -128,15 +131,22 @@ public DruidQuery(
this.selectProjection = computeSelectProjection(partialQuery,
plannerContext, sourceRowSignature);
this.grouping = computeGrouping(partialQuery, plannerContext,
sourceRowSignature, rexBuilder);
+ final RowSignature sortingInputRowSignature;
+
if (this.selectProjection != null) {
- this.outputRowSignature = this.selectProjection.getOutputRowSignature();
+ sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
} else if (this.grouping != null) {
- this.outputRowSignature = this.grouping.getOutputRowSignature();
+ sortingInputRowSignature = this.grouping.getOutputRowSignature();
} else {
- this.outputRowSignature = sourceRowSignature;
+ sortingInputRowSignature = sourceRowSignature;
}
- this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
+ this.sortProject = computeSortProject(partialQuery, plannerContext,
sortingInputRowSignature, grouping);
+
+ // outputRowSignature is used only for scan and select query, and thus
sort and grouping must be null
+ this.outputRowSignature = sortProject == null ? sortingInputRowSignature :
sortProject.getOutputRowSignature();
+
+ this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
this.query = computeQuery();
}
@@ -235,7 +245,7 @@ private static Grouping computeGrouping(
)
{
final Aggregate aggregate = partialQuery.getAggregate();
- final Project postProject = partialQuery.getPostProject();
+ final Project aggregateProject = partialQuery.getAggregateProject();
if (aggregate == null) {
return null;
@@ -265,49 +275,27 @@ private static Grouping computeGrouping(
plannerContext
);
- if (postProject == null) {
+ if (aggregateProject == null) {
return Grouping.create(dimensions, aggregations, havingFilter,
aggregateRowSignature);
} else {
- final List<String> rowOrder = new ArrayList<>();
-
- int outputNameCounter = 0;
- for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
- // Attempt to convert to PostAggregator.
- final DruidExpression postAggregatorExpression =
Expressions.toDruidExpression(
- plannerContext,
- aggregateRowSignature,
- postAggregatorRexNode
- );
-
- if (postAggregatorExpression == null) {
- throw new CannotBuildQueryException(postProject,
postAggregatorRexNode);
- }
-
- if (postAggregatorDirectColumnIsOk(aggregateRowSignature,
postAggregatorExpression, postAggregatorRexNode)) {
- // Direct column access, without any type cast as far as Druid's
runtime is concerned.
- // (There might be a SQL-level type cast that we don't care about)
- rowOrder.add(postAggregatorExpression.getDirectColumn());
- } else {
- final String postAggregatorName = "p" + outputNameCounter++;
- final PostAggregator postAggregator = new ExpressionPostAggregator(
- postAggregatorName,
- postAggregatorExpression.getExpression(),
- null,
- plannerContext.getExprMacroTable()
- );
- aggregations.add(Aggregation.create(postAggregator));
- rowOrder.add(postAggregator.getName());
- }
- }
+ final ProjectRowOrderAndPostAggregations
projectRowOrderAndPostAggregations = computePostAggregations(
+ plannerContext,
+ aggregateRowSignature,
+ aggregateProject,
+ 0
+ );
+ projectRowOrderAndPostAggregations.postAggregations.forEach(
+ postAggregator ->
aggregations.add(Aggregation.create(postAggregator))
+ );
// Remove literal dimensions that did not appear in the projection. This
is useful for queries
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can
generate, and for which we don't
// actually want to include a dimension 'dummy'.
- final ImmutableBitSet postProjectBits =
RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
+ final ImmutableBitSet aggregateProjectBits =
RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(),
plannerContext.getExprMacroTable())
- .isLiteral() && !postProjectBits.get(i)) {
+ .isLiteral() && !aggregateProjectBits.get(i)) {
dimensions.remove(i);
}
}
@@ -316,11 +304,98 @@ private static Grouping computeGrouping(
dimensions,
aggregations,
havingFilter,
- RowSignature.from(rowOrder, postProject.getRowType())
+ RowSignature.from(projectRowOrderAndPostAggregations.rowOrder,
aggregateProject.getRowType())
);
}
}
+ @Nullable
+ private SortProject computeSortProject(
+ PartialDruidQuery partialQuery,
+ PlannerContext plannerContext,
+ RowSignature sortingInputRowSignature,
+ Grouping grouping
+ )
+ {
+ final Project sortProject = partialQuery.getSortProject();
+ if (sortProject == null) {
+ return null;
+ } else {
+ final List<PostAggregator> postAggregators =
grouping.getPostAggregators();
+ final OptionalInt maybeMaxCounter = postAggregators
+ .stream()
+ .mapToInt(postAggregator ->
Integer.parseInt(postAggregator.getName().substring(1)))
+ .max();
+
+ final ProjectRowOrderAndPostAggregations
projectRowOrderAndPostAggregations = computePostAggregations(
+ plannerContext,
+ sortingInputRowSignature,
+ sortProject,
+ maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
+ );
+
+ return new SortProject(
+ sortingInputRowSignature,
+ projectRowOrderAndPostAggregations.postAggregations,
+ RowSignature.from(projectRowOrderAndPostAggregations.rowOrder,
sortProject.getRowType())
+ );
+ }
+ }
+
+ private static class ProjectRowOrderAndPostAggregations
+ {
+ private final List<String> rowOrder;
+ private final List<PostAggregator> postAggregations;
+
+ ProjectRowOrderAndPostAggregations(List<String> rowOrder,
List<PostAggregator> postAggregations)
+ {
+ this.rowOrder = rowOrder;
+ this.postAggregations = postAggregations;
+ }
+ }
+
+ private static ProjectRowOrderAndPostAggregations computePostAggregations(
+ PlannerContext plannerContext,
+ RowSignature inputRowSignature,
+ Project project,
+ int outputNameCounter
+ )
+ {
+ final List<String> rowOrder = new ArrayList<>();
+ final List<PostAggregator> aggregations = new ArrayList<>();
+
+ for (final RexNode postAggregatorRexNode : project.getChildExps()) {
+ // Attempt to convert to PostAggregator.
+ final DruidExpression postAggregatorExpression =
Expressions.toDruidExpression(
+ plannerContext,
+ inputRowSignature,
+ postAggregatorRexNode
+ );
+
+ if (postAggregatorExpression == null) {
+ throw new CannotBuildQueryException(project, postAggregatorRexNode);
+ }
+
+ if (postAggregatorDirectColumnIsOk(inputRowSignature,
postAggregatorExpression, postAggregatorRexNode)) {
+ // Direct column access, without any type cast as far as Druid's
runtime is concerned.
+ // (There might be a SQL-level type cast that we don't care about)
+ rowOrder.add(postAggregatorExpression.getDirectColumn());
+ } else {
+ final String postAggregatorName = "p" + outputNameCounter++;
+ final PostAggregator postAggregator = new ExpressionPostAggregator(
+ postAggregatorName,
+ postAggregatorExpression.getExpression(),
+ null,
+ plannerContext.getExprMacroTable()
+ );
+ aggregations.add(postAggregator);
+ rowOrder.add(postAggregator.getName());
+ }
+ }
+
+ return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
+ }
+
/**
* Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in
the same order.
*
@@ -540,18 +615,20 @@ public VirtualColumns getVirtualColumns(final
ExprMacroTable macroTable, final b
{
final List<VirtualColumn> retVal = new ArrayList<>();
- if (grouping != null) {
- if (includeDimensions) {
- for (DimensionExpression dimensionExpression :
grouping.getDimensions()) {
- retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+ if (selectProjection != null) {
+ retVal.addAll(selectProjection.getVirtualColumns());
+ } else {
+ if (grouping != null) {
+ if (includeDimensions) {
+ for (DimensionExpression dimensionExpression :
grouping.getDimensions()) {
+ retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+ }
}
- }
- for (Aggregation aggregation : grouping.getAggregations()) {
- retVal.addAll(aggregation.getVirtualColumns());
+ for (Aggregation aggregation : grouping.getAggregations()) {
+ retVal.addAll(aggregation.getVirtualColumns());
+ }
}
- } else if (selectProjection != null) {
- retVal.addAll(selectProjection.getVirtualColumns());
}
return VirtualColumns.create(retVal);
@@ -567,6 +644,11 @@ public DefaultLimitSpec getLimitSpec()
return limitSpec;
}
+ public SortProject getSortProject()
+ {
+ return sortProject;
+ }
+
public RelDataType getOutputRowType()
{
return outputRowType;
@@ -667,7 +749,6 @@ public TimeseriesQuery toTimeseriesQuery()
if (limitSpec != null) {
// If there is a limit spec, timeseries cannot LIMIT; and must be
ORDER BY time (or nothing).
-
if (limitSpec.isLimited()) {
return null;
}
@@ -797,6 +878,11 @@ public GroupByQuery toGroupByQuery()
final Filtration filtration =
Filtration.create(filter).optimize(sourceRowSignature);
+ final List<PostAggregator> postAggregators = new
ArrayList<>(grouping.getPostAggregators());
+ if (sortProject != null) {
+ postAggregators.addAll(sortProject.getPostAggregators());
+ }
+
return new GroupByQuery(
dataSource,
filtration.getQuerySegmentSpec(),
@@ -805,7 +891,7 @@ public GroupByQuery toGroupByQuery()
Granularities.ALL,
grouping.getDimensionSpecs(),
grouping.getAggregatorFactories(),
- grouping.getPostAggregators(),
+ postAggregators,
grouping.getHavingFilter() != null ? new
DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
limitSpec,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
index 5d0ea438106..d49969174fc 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
@@ -220,14 +220,18 @@ public RelOptCost computeSelfCost(final RelOptPlanner
planner, final RelMetadata
cost += COST_PER_COLUMN *
partialQuery.getAggregate().getAggCallList().size();
}
- if (partialQuery.getPostProject() != null) {
- cost += COST_PER_COLUMN *
partialQuery.getPostProject().getChildExps().size();
+ if (partialQuery.getAggregateProject() != null) {
+ cost += COST_PER_COLUMN *
partialQuery.getAggregateProject().getChildExps().size();
}
if (partialQuery.getSort() != null && partialQuery.getSort().fetch !=
null) {
cost *= COST_LIMIT_MULTIPLIER;
}
+ if (partialQuery.getSortProject() != null) {
+ cost += COST_PER_COLUMN *
partialQuery.getSortProject().getChildExps().size();
+ }
+
if (partialQuery.getHavingFilter() != null) {
cost *= COST_HAVING_MULTIPLIER;
}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
index 5d6abfc1f01..0da912b9564 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
@@ -358,8 +358,12 @@ public RelOptCost computeSelfCost(final RelOptPlanner
planner, final RelMetadata
newPartialQuery =
newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
}
- if (leftPartialQuery.getPostProject() != null) {
- newPartialQuery =
newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
+ if (leftPartialQuery.getAggregateProject() != null) {
+ newPartialQuery =
newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
+ }
+
+ if (leftPartialQuery.getSortProject() != null) {
+ newPartialQuery =
newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
}
if (leftPartialQuery.getSort() != null) {
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
index 263c2e09a1c..ec94838908b 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
@@ -46,8 +46,9 @@
private final Sort selectSort;
private final Aggregate aggregate;
private final Filter havingFilter;
- private final Project postProject;
+ private final Project aggregateProject;
private final Sort sort;
+ private final Project sortProject;
public enum Stage
{
@@ -57,8 +58,9 @@
SELECT_SORT,
AGGREGATE,
HAVING_FILTER,
- POST_PROJECT,
- SORT
+ AGGREGATE_PROJECT,
+ SORT,
+ SORT_PROJECT
}
public PartialDruidQuery(
@@ -67,9 +69,10 @@ public PartialDruidQuery(
final Project selectProject,
final Sort selectSort,
final Aggregate aggregate,
- final Project postProject,
+ final Project aggregateProject,
final Filter havingFilter,
- final Sort sort
+ final Sort sort,
+ final Project sortProject
)
{
this.scan = Preconditions.checkNotNull(scan, "scan");
@@ -77,14 +80,15 @@ public PartialDruidQuery(
this.selectProject = selectProject;
this.selectSort = selectSort;
this.aggregate = aggregate;
- this.postProject = postProject;
+ this.aggregateProject = aggregateProject;
this.havingFilter = havingFilter;
this.sort = sort;
+ this.sortProject = sortProject;
}
public static PartialDruidQuery create(final RelNode scanRel)
{
- return new PartialDruidQuery(scanRel, null, null, null, null, null, null,
null);
+ return new PartialDruidQuery(scanRel, null, null, null, null, null, null,
null, null);
}
public RelNode getScan()
@@ -117,9 +121,9 @@ public Filter getHavingFilter()
return havingFilter;
}
- public Project getPostProject()
+ public Project getAggregateProject()
{
- return postProject;
+ return aggregateProject;
}
public Sort getSort()
@@ -127,6 +131,11 @@ public Sort getSort()
return sort;
}
+ public Project getSortProject()
+ {
+ return sortProject;
+ }
+
public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
{
validateStage(Stage.WHERE_FILTER);
@@ -136,9 +145,10 @@ public PartialDruidQuery withWhereFilter(final Filter
newWhereFilter)
selectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -151,9 +161,10 @@ public PartialDruidQuery withSelectProject(final Project
newSelectProject)
newSelectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -166,9 +177,10 @@ public PartialDruidQuery withSelectSort(final Sort
newSelectSort)
selectProject,
newSelectSort,
aggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -181,9 +193,10 @@ public PartialDruidQuery withAggregate(final Aggregate
newAggregate)
selectProject,
selectSort,
newAggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -196,24 +209,26 @@ public PartialDruidQuery withHavingFilter(final Filter
newHavingFilter)
selectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
newHavingFilter,
- sort
+ sort,
+ sortProject
);
}
- public PartialDruidQuery withPostProject(final Project newPostProject)
+ public PartialDruidQuery withAggregateProject(final Project
newAggregateProject)
{
- validateStage(Stage.POST_PROJECT);
+ validateStage(Stage.AGGREGATE_PROJECT);
return new PartialDruidQuery(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
- newPostProject,
+ newAggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -226,9 +241,26 @@ public PartialDruidQuery withSort(final Sort newSort)
selectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
+ havingFilter,
+ newSort,
+ sortProject
+ );
+ }
+
+ public PartialDruidQuery withSortProject(final Project newSortProject)
+ {
+ validateStage(Stage.SORT_PROJECT);
+ return new PartialDruidQuery(
+ scan,
+ whereFilter,
+ selectProject,
+ selectSort,
+ aggregate,
+ aggregateProject,
havingFilter,
- newSort
+ sort,
+ newSortProject
);
}
@@ -265,6 +297,9 @@ public boolean canAccept(final Stage stage)
} else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
// Cannot do any aggregations after a select + sort.
return false;
+ } else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
+ // Cannot add sort project without a sort
+ return false;
} else {
// Looks good.
return true;
@@ -277,12 +312,15 @@ public boolean canAccept(final Stage stage)
*
* @return stage
*/
+ @SuppressWarnings("VariableNotUsedInsideIf")
public Stage stage()
{
- if (sort != null) {
+ if (sortProject != null) {
+ return Stage.SORT_PROJECT;
+ } else if (sort != null) {
return Stage.SORT;
- } else if (postProject != null) {
- return Stage.POST_PROJECT;
+ } else if (aggregateProject != null) {
+ return Stage.AGGREGATE_PROJECT;
} else if (havingFilter != null) {
return Stage.HAVING_FILTER;
} else if (aggregate != null) {
@@ -308,10 +346,12 @@ public RelNode leafRel()
final Stage currentStage = stage();
switch (currentStage) {
+ case SORT_PROJECT:
+ return sortProject;
case SORT:
return sort;
- case POST_PROJECT:
- return postProject;
+ case AGGREGATE_PROJECT:
+ return aggregateProject;
case HAVING_FILTER:
return havingFilter;
case AGGREGATE:
@@ -352,14 +392,25 @@ public boolean equals(final Object o)
Objects.equals(selectSort, that.selectSort) &&
Objects.equals(aggregate, that.aggregate) &&
Objects.equals(havingFilter, that.havingFilter) &&
- Objects.equals(postProject, that.postProject) &&
- Objects.equals(sort, that.sort);
+ Objects.equals(aggregateProject, that.aggregateProject) &&
+ Objects.equals(sort, that.sort) &&
+ Objects.equals(sortProject, that.sortProject);
}
@Override
public int hashCode()
{
- return Objects.hash(scan, whereFilter, selectProject, selectSort,
aggregate, havingFilter, postProject, sort);
+ return Objects.hash(
+ scan,
+ whereFilter,
+ selectProject,
+ selectSort,
+ aggregate,
+ havingFilter,
+ aggregateProject,
+ sort,
+ sortProject
+ );
}
@Override
@@ -372,8 +423,9 @@ public String toString()
", selectSort=" + selectSort +
", aggregate=" + aggregate +
", havingFilter=" + havingFilter +
- ", postProject=" + postProject +
+ ", aggregateProject=" + aggregateProject +
", sort=" + sort +
+ ", sortProject=" + sortProject +
'}';
}
}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
new file mode 100644
index 00000000000..c00aff97ee5
--- /dev/null
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.druid.sql.calcite.rel;
+
+import com.google.common.base.Preconditions;
+import io.druid.java.util.common.ISE;
+import io.druid.query.aggregation.PostAggregator;
+import io.druid.sql.calcite.table.RowSignature;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+public class SortProject
+{
+ private final RowSignature inputRowSignature;
+ private final List<PostAggregator> postAggregators;
+ private final RowSignature outputRowSignature;
+
+ SortProject(
+ RowSignature inputRowSignature,
+ List<PostAggregator> postAggregators,
+ RowSignature outputRowSignature
+ )
+ {
+ this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature,
"inputRowSignature");
+ this.postAggregators = Preconditions.checkNotNull(postAggregators,
"postAggregators");
+ this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature,
"outputRowSignature");
+
+ // Verify no collisions.
+ final Set<String> seen = new HashSet<>();
+ inputRowSignature.getRowOrder().forEach(field -> {
+ if (!seen.add(field)) {
+ throw new ISE("Duplicate field name: %s", field);
+ }
+ });
+
+ for (PostAggregator postAggregator : postAggregators) {
+ if (postAggregator == null) {
+ throw new ISE("aggregation[%s] is not a postAggregator",
postAggregator);
+ }
+ if (!seen.add(postAggregator.getName())) {
+ throw new ISE("Duplicate field name: %s", postAggregator.getName());
+ }
+ }
+
+ // Verify that items in the output signature exist.
+ outputRowSignature.getRowOrder().forEach(field -> {
+ if (!seen.contains(field)) {
+ throw new ISE("Missing field in rowOrder: %s", field);
+ }
+ });
+ }
+
+ public List<PostAggregator> getPostAggregators()
+ {
+ return postAggregators;
+ }
+
+ public RowSignature getOutputRowSignature()
+ {
+ return outputRowSignature;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ SortProject sortProject = (SortProject) o;
+ return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
+ Objects.equals(postAggregators, sortProject.postAggregators) &&
+ Objects.equals(outputRowSignature, sortProject.outputRowSignature);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(inputRowSignature, postAggregators,
outputRowSignature);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "SortProject{" +
+ "inputRowSignature=" + inputRowSignature +
+ ", postAggregators=" + postAggregators +
+ ", outputRowSignature=" + outputRowSignature +
+ '}';
+ }
+}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
index 5ec46abfd01..d85b0c2cf23 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
@@ -68,8 +68,8 @@ private DruidRules()
),
new DruidQueryRule<>(
Project.class,
- PartialDruidQuery.Stage.POST_PROJECT,
- PartialDruidQuery::withPostProject
+ PartialDruidQuery.Stage.AGGREGATE_PROJECT,
+ PartialDruidQuery::withAggregateProject
),
new DruidQueryRule<>(
Filter.class,
@@ -81,10 +81,16 @@ private DruidRules()
PartialDruidQuery.Stage.SORT,
PartialDruidQuery::withSort
),
+ new DruidQueryRule<>(
+ Project.class,
+ PartialDruidQuery.Stage.SORT_PROJECT,
+ PartialDruidQuery::withSortProject
+ ),
DruidOuterQueryRule.AGGREGATE,
DruidOuterQueryRule.FILTER_AGGREGATE,
DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
- DruidOuterQueryRule.PROJECT_AGGREGATE
+ DruidOuterQueryRule.PROJECT_AGGREGATE,
+ DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
);
}
@@ -227,6 +233,32 @@ public void onMatch(final RelOptRuleCall call)
}
};
+ public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
+ operand(Project.class, operand(Sort.class, operand(Aggregate.class,
operand(DruidRel.class, any())))),
+ "AGGREGATE_SORT_PROJECT"
+ )
+ {
+ @Override
+ public void onMatch(RelOptRuleCall call)
+ {
+ final Project sortProject = call.rel(0);
+ final Sort sort = call.rel(1);
+ final Aggregate aggregate = call.rel(2);
+ final DruidRel druidRel = call.rel(3);
+
+ final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
+ druidRel,
+ PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
+ .withAggregate(aggregate)
+ .withSort(sort)
+ .withSortProject(sortProject)
+ );
+ if (outerQueryRel.isValidDruidQuery()) {
+ call.transformTo(outerQueryRel);
+ }
+ }
+ };
+
public DruidOuterQueryRule(final RelOptRuleOperand op, final String
description)
{
super(op, StringUtils.format("%s:%s",
DruidOuterQueryRel.class.getSimpleName(), description));
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
index 5376ff124f1..9ef0430932b 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
@@ -24,6 +24,7 @@
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.DruidRel;
import io.druid.sql.calcite.rel.DruidSemiJoin;
+import io.druid.sql.calcite.rel.PartialDruidQuery;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
@@ -115,15 +116,18 @@ public void onMatch(RelOptRuleCall call)
return;
}
- final Project rightPostProject =
right.getPartialDruidQuery().getPostProject();
+ final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
+ final Project rightProject = rightQuery.getSortProject() != null ?
+ rightQuery.getSortProject() :
+ rightQuery.getAggregateProject();
int i = 0;
for (int joinRef : joinInfo.rightSet()) {
final int aggregateRef;
- if (rightPostProject == null) {
+ if (rightProject == null) {
aggregateRef = joinRef;
} else {
- final RexNode projectExp =
rightPostProject.getChildExps().get(joinRef);
+ final RexNode projectExp = rightProject.getChildExps().get(joinRef);
if (projectExp.isA(SqlKind.INPUT_REF)) {
aggregateRef = ((RexInputRef) projectExp).getIndex();
} else {
diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
index 19722de6ee0..795ecabddc1 100644
--- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
@@ -72,6 +72,7 @@
import io.druid.query.groupby.having.DimFilterHavingSpec;
import io.druid.query.groupby.orderby.DefaultLimitSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec;
+import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
import io.druid.query.lookup.RegisteredLookupExtractionFn;
import io.druid.query.ordering.StringComparator;
import io.druid.query.ordering.StringComparators;
@@ -122,6 +123,7 @@
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -6377,6 +6379,193 @@ public void testUnicodeFilterAndGroupBy() throws
Exception
);
}
+ @Test
+ public void testProjectAfterSort() throws Exception
+ {
+ testQuery(
+ "select dim1 from (select dim1, dim2, count(*) cnt from druid.foo
group by dim1, dim2 order by cnt)",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(
+ DIMS(
+ new DefaultDimensionSpec("dim1", "d0"),
+ new DefaultDimensionSpec("dim2", "d1")
+ )
+ )
+ .setAggregatorSpecs(AGGS(new
CountAggregatorFactory("a0")))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0",
Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{""},
+ new Object[]{"1"},
+ new Object[]{"10.1"},
+ new Object[]{"2"},
+ new Object[]{"abc"},
+ new Object[]{"def"}
+ )
+ );
+ }
+
+ @Test
+ public void testProjectAfterSort2() throws Exception
+ {
+ testQuery(
+ "select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt,
sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(
+ DIMS(
+ new DefaultDimensionSpec("dim1", "d0"),
+ new DefaultDimensionSpec("dim2", "d1")
+ )
+ )
+ .setAggregatorSpecs(
+ AGGS(new CountAggregatorFactory("a0"), new
DoubleSumAggregatorFactory("a1", "m2"))
+ )
+
.setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0",
"(\"a1\" / \"a0\")")))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0",
Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{1.0, "", "a", 1.0},
+ new Object[]{4.0, "1", "a", 4.0},
+ new Object[]{2.0, "10.1", "", 2.0},
+ new Object[]{3.0, "2", "", 3.0},
+ new Object[]{6.0, "abc", "", 6.0},
+ new Object[]{5.0, "def", "abc", 5.0}
+ )
+ );
+ }
+
+ @Test
+ public void testProjectAfterSort3() throws Exception
+ {
+ testQuery(
+ "select dim1 from (select dim1, dim1, count(*) cnt from druid.foo
group by dim1, dim1 order by cnt)",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(
+ DIMS(
+ new DefaultDimensionSpec("dim1", "d0")
+ )
+ )
+ .setAggregatorSpecs(AGGS(new
CountAggregatorFactory("a0")))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0",
Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{""},
+ new Object[]{"1"},
+ new Object[]{"10.1"},
+ new Object[]{"2"},
+ new Object[]{"abc"},
+ new Object[]{"def"}
+ )
+ );
+ }
+
+ @Test
+ public void testSortProjectAfterNestedGroupBy() throws Exception
+ {
+ testQuery(
+ "SELECT "
+ + " cnt "
+ + "FROM ("
+ + " SELECT "
+ + " __time, "
+ + " dim1, "
+ + " COUNT(m2) AS cnt "
+ + " FROM ("
+ + " SELECT "
+ + " __time, "
+ + " m2, "
+ + " dim1 "
+ + " FROM druid.foo "
+ + " GROUP BY __time, m2, dim1 "
+ + " ) "
+ + " GROUP BY __time, dim1 "
+ + " ORDER BY cnt"
+ + ")",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(
+ GroupByQuery.builder()
+
.setDataSource(CalciteTests.DATASOURCE1)
+
.setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(DIMS(
+ new DefaultDimensionSpec("__time",
"d0", ValueType.LONG),
+ new DefaultDimensionSpec("dim1",
"d1"),
+ new DefaultDimensionSpec("m2",
"d2", ValueType.DOUBLE)
+ ))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ )
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(DIMS(
+ new DefaultDimensionSpec("d0", "_d0",
ValueType.LONG),
+ new DefaultDimensionSpec("d1", "_d1",
ValueType.STRING)
+ ))
+ .setAggregatorSpecs(AGGS(
+ new CountAggregatorFactory("a0")
+ ))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0",
Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L}
+ )
+ );
+ }
+
private void testQuery(
final String sql,
final List<Query> expectedQueries,
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]