This is an automated email from the ASF dual-hosted git repository.
kgyrtkirk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new b8dd7478d07 Custom Calcite Rule to remove redundant references (#16402)
b8dd7478d07 is described below
commit b8dd7478d079fdce36d3ed94a86749c6b8162833
Author: Sree Charan Manamala <[email protected]>
AuthorDate: Tue May 14 10:08:05 2024 +0530
Custom Calcite Rule to remove redundant references (#16402)
Custom Calcite rule mimicking AggregateProjectMergeRule, extended to support
expressions.
The current Calcite rule returns null in such cases.
In addition, this rule removes redundant references.
---
.../aggregation/GroupingAggregatorFactory.java | 23 +--
.../aggregation/GroupingAggregatorFactoryTest.java | 8 +
.../aggregation/builtin/GroupingSqlAggregator.java | 14 +-
.../sql/calcite/planner/CalciteRulesManager.java | 2 +
.../DruidAggregateRemoveRedundancyRule.java | 164 +++++++++++++++++++++
.../apache/druid/sql/calcite/CalciteQueryTest.java | 74 ++++++++--
6 files changed, 264 insertions(+), 21 deletions(-)
diff --git
a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
index 8f8f7be4a14..e87c23951db 100644
---
a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
+++
b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
@@ -102,6 +102,20 @@ public class GroupingAggregatorFactory extends
AggregatorFactory
)
{
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator
name");
+ Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings),
"Must have a non-empty grouping dimensions");
+ // (Long.SIZE - 1) is just a sanity check. In practice, it will be just
few dimensions. This limit
+ // also makes sure that values are always positive.
+ Preconditions.checkArgument(
+ groupings.size() < Long.SIZE,
+ "Number of dimensions %s is more than supported %s",
+ groupings.size(),
+ Long.SIZE - 1
+ );
+ Preconditions.checkArgument(
+ groupings.stream().distinct().count() == groupings.size(),
+ "Encountered same dimension more than once in groupings"
+ );
+
this.name = name;
this.groupings = groupings;
this.keyDimensions = keyDimensions;
@@ -254,15 +268,6 @@ public class GroupingAggregatorFactory extends
AggregatorFactory
*/
private long groupingId(List<String> groupings, @Nullable Set<String>
keyDimensions)
{
- Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings),
"Must have a non-empty grouping dimensions");
- // (Long.SIZE - 1) is just a sanity check. In practice, it will be just
few dimensions. This limit
- // also makes sure that values are always positive.
- Preconditions.checkArgument(
- groupings.size() < Long.SIZE,
- "Number of dimensions %s is more than supported %s",
- groupings.size(),
- Long.SIZE - 1
- );
long temp = 0L;
for (String groupingDimension : groupings) {
temp = temp << 1;
diff --git
a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
index 2be56bd0f0e..c9772ab9534 100644
---
a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
+++
b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
@@ -131,6 +131,14 @@ public class GroupingAggregatorFactoryTest
));
makeFactory(new String[Long.SIZE], null);
}
+
+ @Test
+ public void testWithDuplicateGroupings()
+ {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("Encountered same dimension more than once in
groupings");
+ makeFactory(new String[]{"a", "a"}, null);
+ }
}
@RunWith(Parameterized.class)
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
index 7c123dab927..1209ee30eaf 100644
---
a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
@@ -90,7 +90,19 @@ public class GroupingSqlAggregator implements SqlAggregator
}
}
}
- AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments);
+ AggregatorFactory factory;
+ try {
+ factory = new GroupingAggregatorFactory(name, arguments);
+ }
+ catch (Exception e) {
+ plannerContext.setPlanningError(
+ "Initialisation of Grouping Aggregator Factory in case of [%s] threw
[%s]",
+ aggregateCall,
+ e.getMessage()
+ );
+ return null;
+ }
+
return Aggregation.create(factory);
}
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
index 4326f63340d..829c44b18c6 100644
---
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
@@ -66,6 +66,7 @@ import
org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule;
import org.apache.druid.sql.calcite.rule.ReverseLookupRule;
import org.apache.druid.sql.calcite.rule.RewriteFirstValueLastValueRule;
import org.apache.druid.sql.calcite.rule.SortCollapseRule;
+import
org.apache.druid.sql.calcite.rule.logical.DruidAggregateRemoveRedundancyRule;
import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules;
import org.apache.druid.sql.calcite.run.EngineFeature;
@@ -496,6 +497,7 @@ public class CalciteRulesManager
rules.add(FilterJoinExcludePushToChildRule.FILTER_ON_JOIN_EXCLUDE_PUSH_TO_CHILD);
rules.add(SortCollapseRule.instance());
rules.add(ProjectAggregatePruneUnusedCallRule.instance());
+ rules.add(DruidAggregateRemoveRedundancyRule.instance());
return rules.build();
}
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java
new file mode 100644
index 00000000000..1ef91dcb6ba
--- /dev/null
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule.logical;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Aggregate.Group;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.rules.TransformationRule;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mappings;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Planner rule that matches an {@link Aggregate} on top of a {@link Project} and narrows the project down to the
+ * expressions the aggregate actually references, collapsing equal expressions into a single field.
+ * <p>
+ * This is an extended version of {@link org.apache.calcite.rel.rules.AggregateProjectMergeRule}. The Calcite rule
+ * only handles projects whose referenced fields are plain input references; this rule also handles arbitrary
+ * expressions by referring to them by position, and drops redundant (duplicate or unused) fields.
+ */
+@Value.Enclosing
+public class DruidAggregateRemoveRedundancyRule
+    extends RelOptRule
+    implements TransformationRule
+{
+  /**
+   * Shared instance; the rule keeps no per-instance state.
+   */
+  private static final DruidAggregateRemoveRedundancyRule INSTANCE = new DruidAggregateRemoveRedundancyRule();
+
+  private DruidAggregateRemoveRedundancyRule()
+  {
+    super(operand(Aggregate.class, operand(Project.class, any())));
+  }
+
+  /**
+   * Returns the shared instance of this rule.
+   */
+  public static DruidAggregateRemoveRedundancyRule instance()
+  {
+    return INSTANCE;
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call)
+  {
+    final Aggregate aggregate = call.rel(0);
+    final Project project = call.rel(1);
+    final RelNode rewritten = apply(call, aggregate, project);
+    if (rewritten != null) {
+      call.transformTo(rewritten);
+      // Same as AggregateProjectMergeRule: once the rewrite is registered, prune the original aggregate.
+      call.getPlanner().prune(aggregate);
+    }
+  }
+
+  /**
+   * Rewrites {@code aggregate(project)} into {@code aggregate'(project'(project))}, where {@code project'} keeps
+   * one field per distinct expression referenced by the aggregate. A final project is added on top when the group
+   * keys of the new aggregate are not in the original order or contain duplicates.
+   *
+   * @param call      the rule call, used to obtain a {@link RelBuilder}
+   * @param aggregate the matched aggregate
+   * @param project   the project feeding {@code aggregate}
+   *
+   * @return the rewritten plan, or null when nothing can be removed or the rewrite would not be safe
+   */
+  public static @Nullable RelNode apply(RelOptRuleCall call, Aggregate aggregate, Project project)
+  {
+    // Project fields referenced by the aggregate (group keys, call arguments, filter arguments, ...).
+    final Set<Integer> interestingFields = RelOptUtil.getAllFields(aggregate);
+    if (interestingFields.isEmpty()) {
+      return null;
+    }
+
+    // Old field index -> new field index. Fields holding equal expressions share one new field.
+    final Map<Integer, Integer> map = new HashMap<>();
+    final Map<RexNode, Integer> assignedRefForExpr = new HashMap<>();
+    final List<RexNode> newRexNodes = new ArrayList<>();
+    for (int source : interestingFields) {
+      final RexNode rex = project.getProjects().get(source);
+      Integer target = assignedRefForExpr.get(rex);
+      if (target == null) {
+        target = newRexNodes.size();
+        assignedRefForExpr.put(rex, target);
+        // Reference the existing project output by position instead of copying the expression.
+        newRexNodes.add(new RexInputRef(source, rex.getType()));
+      }
+      map.put(source, target);
+    }
+
+    if (newRexNodes.size() == project.getProjects().size()) {
+      // Every project field is referenced and no two are equal: nothing to remove.
+      return null;
+    }
+
+    final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
+    ImmutableList<ImmutableBitSet> newGroupingSets = null;
+    if (aggregate.getGroupType() != Group.SIMPLE) {
+      if (newGroupSet.cardinality() != aggregate.getGroupSet().cardinality()) {
+        // Some group keys were collapsed into one field. With grouping sets this is not a safe rewrite: distinct
+        // grouping sets may become equal (and be de-duplicated by the TreeSet below), a key that a grouping set
+        // leaves NULL would be filled from its collapsed twin by the final project, and GROUPING(...) calls would
+        // receive duplicate arguments. All of these change the result, so leave the plan alone.
+        return null;
+      }
+      newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(
+          Sets.newTreeSet(ImmutableBitSet.permute(aggregate.getGroupSets(), map))
+      );
+    }
+
+    final ImmutableList.Builder<AggregateCall> aggCalls = ImmutableList.builder();
+    final int sourceCount = aggregate.getInput().getRowType().getFieldCount();
+    final int targetCount = newRexNodes.size();
+    final Mappings.TargetMapping targetMapping = Mappings.target(map, sourceCount, targetCount);
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      aggCalls.add(aggregateCall.transform(targetMapping));
+    }
+
+    final RelBuilder relBuilder = call.builder();
+    relBuilder.push(project);
+    relBuilder.project(newRexNodes);
+
+    final Aggregate newAggregate = aggregate.copy(
+        aggregate.getTraitSet(),
+        relBuilder.build(),
+        newGroupSet,
+        newGroupingSets,
+        aggCalls.build()
+    );
+    relBuilder.push(newAggregate);
+
+    final List<Integer> newKeys = Util.transform(
+        aggregate.getGroupSet().asList(),
+        key -> Objects.requireNonNull(
+            map.get(key),
+            () -> "no value found for key " + key + " in " + map
+        )
+    );
+
+    // Add a project if the group set is not in the same order or contains duplicates, so the output row type
+    // still lines up with the original aggregate's group keys followed by its aggregate calls.
+    if (!newKeys.equals(newGroupSet.asList())) {
+      final List<Integer> posList = new ArrayList<>();
+      for (int newKey : newKeys) {
+        posList.add(newGroupSet.indexOf(newKey));
+      }
+      for (int i = newAggregate.getGroupCount(); i < newAggregate.getRowType().getFieldCount(); i++) {
+        posList.add(i);
+      }
+      relBuilder.project(relBuilder.fields(posList));
+    }
+
+    return relBuilder.build();
+  }
+}
diff --git
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index b5a32f4301c..9b302534ef6 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -8788,8 +8788,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setDimensions(
dimensions(
- new
DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
- new
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+ new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
+ new
DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
)
)
.setAggregatorSpecs(
@@ -8832,9 +8832,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a1"),
and(
- notNull("d0"),
+ notNull("d1"),
equality("a1", 0L, ColumnType.LONG),
- expressionFilter("\"d1\"")
+ expressionFilter("\"d0\"")
)
)
)
@@ -12938,8 +12938,7 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
.setVirtualColumns(expressionVirtualColumn("v0", "1",
ColumnType.LONG))
.setDimensions(
dimensions(
- new DefaultDimensionSpec("v0", "d0",
ColumnType.LONG),
- new DefaultDimensionSpec("v0", "d1",
ColumnType.LONG)
+ new DefaultDimensionSpec("v0", "d0",
ColumnType.LONG)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
@@ -15680,10 +15679,63 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
.build()
)
).expectedResults(
- ResultMatchMode.RELAX_NULLS,
- ImmutableList.of(
- new Object[]{null, null, null}
- )
- );
+ NullHandling.sqlCompatible() ? ImmutableList.of(
+ new Object[]{null, null, null}
+ ) : ImmutableList.of(
+ new Object[]{false, false, ""}
+ )
+ ).run();
+ }
+
+ @SqlTestFrameworkConfig.NumMergeBuffers(4)
+ @Test
+ public void testGroupingSetsWithAggrgateCase()
+ {
+ cannotVectorize();
+ msqIncompatible();
+ final Map<String, Object> queryContext = ImmutableMap.of(
+ PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false,
+ PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT, true
+ );
+ testBuilder()
+ .sql(
+ "SELECT\n"
+ + " TIME_FLOOR(\"__time\", 'PT1H') ,\n"
+ + " COUNT(DISTINCT \"page\") ,\n"
+ + " COUNT(DISTINCT CASE WHEN \"channel\" = '#it.wikipedia' THEN
\"user\" END), \n"
+ + " COUNT(DISTINCT \"user\") FILTER (WHERE \"channel\" =
'#it.wikipedia'), "
+ + " COUNT(DISTINCT \"user\") \n"
+ + "FROM \"wikipedia\"\n"
+ + "GROUP BY 1"
+ )
+ .queryContext(queryContext)
+ .expectedResults(
+ ImmutableList.of(
+ new Object[]{1442016000000L, 264L, 5L, 5L, 149L},
+ new Object[]{1442019600000L, 1090L, 14L, 14L, 506L},
+ new Object[]{1442023200000L, 1045L, 10L, 10L, 459L},
+ new Object[]{1442026800000L, 766L, 10L, 10L, 427L},
+ new Object[]{1442030400000L, 781L, 6L, 6L, 427L},
+ new Object[]{1442034000000L, 1223L, 10L, 10L, 448L},
+ new Object[]{1442037600000L, 2092L, 13L, 13L, 498L},
+ new Object[]{1442041200000L, 2181L, 21L, 21L, 574L},
+ new Object[]{1442044800000L, 1552L, 36L, 36L, 707L},
+ new Object[]{1442048400000L, 1624L, 44L, 44L, 770L},
+ new Object[]{1442052000000L, 1710L, 37L, 37L, 785L},
+ new Object[]{1442055600000L, 1532L, 40L, 40L, 799L},
+ new Object[]{1442059200000L, 1633L, 45L, 45L, 855L},
+ new Object[]{1442062800000L, 1958L, 44L, 44L, 905L},
+ new Object[]{1442066400000L, 1779L, 48L, 48L, 886L},
+ new Object[]{1442070000000L, 1868L, 37L, 37L, 949L},
+ new Object[]{1442073600000L, 1846L, 50L, 50L, 969L},
+ new Object[]{1442077200000L, 2168L, 38L, 38L, 941L},
+ new Object[]{1442080800000L, 2043L, 40L, 40L, 925L},
+ new Object[]{1442084400000L, 1924L, 32L, 32L, 930L},
+ new Object[]{1442088000000L, 1736L, 31L, 31L, 882L},
+ new Object[]{1442091600000L, 1672L, 40L, 40L, 861L},
+ new Object[]{1442095200000L, 1504L, 28L, 28L, 716L},
+ new Object[]{1442098800000L, 1407L, 20L, 20L, 631L}
+ )
+ ).run();
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]