This is an automated email from the ASF dual-hosted git repository.

kgyrtkirk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new b8dd7478d07 Custom Calcite Rule to remove redundant references (#16402)
b8dd7478d07 is described below

commit b8dd7478d079fdce36d3ed94a86749c6b8162833
Author: Sree Charan Manamala <[email protected]>
AuthorDate: Tue May 14 10:08:05 2024 +0530

    Custom Calcite Rule to remove redundant references (#16402)
    
    Custom Calcite rule mimicking AggregateProjectMergeRule to extend support
    to expressions; the current Calcite rule returns null in such cases.
    In addition, this rule removes redundant references.
---
 .../aggregation/GroupingAggregatorFactory.java     |  23 +--
 .../aggregation/GroupingAggregatorFactoryTest.java |   8 +
 .../aggregation/builtin/GroupingSqlAggregator.java |  14 +-
 .../sql/calcite/planner/CalciteRulesManager.java   |   2 +
 .../DruidAggregateRemoveRedundancyRule.java        | 164 +++++++++++++++++++++
 .../apache/druid/sql/calcite/CalciteQueryTest.java |  74 ++++++++--
 6 files changed, 264 insertions(+), 21 deletions(-)

diff --git 
a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
 
b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
index 8f8f7be4a14..e87c23951db 100644
--- 
a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
+++ 
b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
@@ -102,6 +102,20 @@ public class GroupingAggregatorFactory extends 
AggregatorFactory
   )
   {
     Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator 
name");
+    Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), 
"Must have a non-empty grouping dimensions");
+    // (Long.SIZE - 1) is just a sanity check. In practice, it will be just 
few dimensions. This limit
+    // also makes sure that values are always positive.
+    Preconditions.checkArgument(
+        groupings.size() < Long.SIZE,
+        "Number of dimensions %s is more than supported %s",
+        groupings.size(),
+        Long.SIZE - 1
+    );
+    Preconditions.checkArgument(
+        groupings.stream().distinct().count() == groupings.size(),
+        "Encountered same dimension more than once in groupings"
+    );
+
     this.name = name;
     this.groupings = groupings;
     this.keyDimensions = keyDimensions;
@@ -254,15 +268,6 @@ public class GroupingAggregatorFactory extends 
AggregatorFactory
    */
   private long groupingId(List<String> groupings, @Nullable Set<String> 
keyDimensions)
   {
-    Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), 
"Must have a non-empty grouping dimensions");
-    // (Long.SIZE - 1) is just a sanity check. In practice, it will be just 
few dimensions. This limit
-    // also makes sure that values are always positive.
-    Preconditions.checkArgument(
-        groupings.size() < Long.SIZE,
-        "Number of dimensions %s is more than supported %s",
-        groupings.size(),
-        Long.SIZE - 1
-    );
     long temp = 0L;
     for (String groupingDimension : groupings) {
       temp = temp << 1;
diff --git 
a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
 
b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
index 2be56bd0f0e..c9772ab9534 100644
--- 
a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
+++ 
b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
@@ -131,6 +131,14 @@ public class GroupingAggregatorFactoryTest
       ));
       makeFactory(new String[Long.SIZE], null);
     }
+
+    @Test
+    public void testWithDuplicateGroupings()
+    {
+      // A groupings list naming the same dimension twice must be rejected with IllegalArgumentException.
+      exception.expect(IllegalArgumentException.class);
+      exception.expectMessage("Encountered same dimension more than once in 
groupings");
+      makeFactory(new String[]{"a", "a"}, null);
+    }
   }
 
   @RunWith(Parameterized.class)
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
index 7c123dab927..1209ee30eaf 100644
--- 
a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java
@@ -90,7 +90,19 @@ public class GroupingSqlAggregator implements SqlAggregator
         }
       }
     }
-    AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments);
+    AggregatorFactory factory;
+    try {
+      factory = new GroupingAggregatorFactory(name, arguments);
+    }
+    catch (Exception e) {
+      plannerContext.setPlanningError(
+          "Initialisation of Grouping Aggregator Factory in case of [%s] threw 
[%s]",
+          aggregateCall,
+          e.getMessage()
+      );
+      return null;
+    }
+
     return Aggregation.create(factory);
   }
 
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
index 4326f63340d..829c44b18c6 100644
--- 
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
@@ -66,6 +66,7 @@ import 
org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule;
 import org.apache.druid.sql.calcite.rule.ReverseLookupRule;
 import org.apache.druid.sql.calcite.rule.RewriteFirstValueLastValueRule;
 import org.apache.druid.sql.calcite.rule.SortCollapseRule;
+import 
org.apache.druid.sql.calcite.rule.logical.DruidAggregateRemoveRedundancyRule;
 import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules;
 import org.apache.druid.sql.calcite.run.EngineFeature;
 
@@ -496,6 +497,7 @@ public class CalciteRulesManager
     
rules.add(FilterJoinExcludePushToChildRule.FILTER_ON_JOIN_EXCLUDE_PUSH_TO_CHILD);
     rules.add(SortCollapseRule.instance());
     rules.add(ProjectAggregatePruneUnusedCallRule.instance());
+    rules.add(DruidAggregateRemoveRedundancyRule.instance());
 
     return rules.build();
   }
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java
new file mode 100644
index 00000000000..1ef91dcb6ba
--- /dev/null
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule.logical;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Aggregate.Group;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.rules.TransformationRule;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mappings;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Planner rule that matches an {@link Aggregate} on top of a {@link Project} and, where possible,
+ * rewrites the aggregate to read from a narrower projection: fields the aggregate does not use are
+ * dropped, and identical projected expressions are referenced only once.
+ * <p>
+ * This is an updated version of {@link org.apache.calcite.rel.rules.AggregateProjectMergeRule}.
+ * That rule gives up when a used field is not a plain {@link RexInputRef}; this one also handles
+ * arbitrary expressions.
+ */
+@Value.Enclosing
+public class DruidAggregateRemoveRedundancyRule
+    extends RelOptRule
+    implements TransformationRule
+{
+
+  /**
+   * Shared instance of this rule, returned by {@link #instance()}.
+   */
+  private static final DruidAggregateRemoveRedundancyRule INSTANCE = new DruidAggregateRemoveRedundancyRule();
+
+  private DruidAggregateRemoveRedundancyRule()
+  {
+    super(operand(Aggregate.class, operand(Project.class, any())));
+  }
+
+  /**
+   * Returns the shared instance of this rule.
+   */
+  public static DruidAggregateRemoveRedundancyRule instance()
+  {
+    return INSTANCE;
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call)
+  {
+    final Aggregate aggregate = call.rel(0);
+    final Project project = call.rel(1);
+    RelNode x = apply(call, aggregate, project);
+    if (x != null) {
+      call.transformTo(x);
+      // An equivalent, narrower plan now exists; stop firing rules on the original aggregate.
+      call.getPlanner().prune(aggregate);
+    }
+  }
+
+  /**
+   * Builds the rewritten plan for {@code aggregate} over {@code project}.
+   *
+   * @param call      rule call, used to obtain a {@link RelBuilder}
+   * @param aggregate the matched aggregate
+   * @param project   the project feeding {@code aggregate}
+   * @return the rewritten expression, or {@code null} when the rule does not apply: the aggregate
+   *         uses no input fields, the projection cannot be narrowed, or narrowing would merge
+   *         grouping keys in a way that changes grouping-set or GROUPING() semantics
+   */
+  public static @Nullable RelNode apply(RelOptRuleCall call, Aggregate aggregate, Project project)
+  {
+    final Set<Integer> interestingFields = RelOptUtil.getAllFields(aggregate);
+    if (interestingFields.isEmpty()) {
+      return null;
+    }
+
+    // Map each used field of the old project to its position in the new projection.
+    // Identical expressions share a single position.
+    final Map<Integer, Integer> map = new HashMap<>();
+    final Map<RexNode, Integer> assignedRefForExpr = new HashMap<>();
+    final List<RexNode> newRexNodes = new ArrayList<>();
+    for (int source : interestingFields) {
+      final RexNode rex = project.getProjects().get(source);
+      if (!assignedRefForExpr.containsKey(rex)) {
+        assignedRefForExpr.put(rex, newRexNodes.size());
+        newRexNodes.add(new RexInputRef(source, rex.getType()));
+      }
+      map.put(source, assignedRefForExpr.get(rex));
+    }
+
+    // Nothing was dropped or merged, so the rewrite would be a no-op.
+    if (newRexNodes.size() == project.getProjects().size()) {
+      return null;
+    }
+
+    final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
+    if (newGroupSet.cardinality() < aggregate.getGroupCount() && !canMergeGroupKeys(aggregate)) {
+      // Some grouping keys would be merged into one, which this aggregate cannot tolerate.
+      return null;
+    }
+
+    ImmutableList<ImmutableBitSet> newGroupingSets = null;
+    if (aggregate.getGroupType() != Group.SIMPLE) {
+      newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(
+          Sets.newTreeSet(ImmutableBitSet.permute(aggregate.getGroupSets(), map))
+      );
+    }
+
+    final ImmutableList.Builder<AggregateCall> aggCalls = ImmutableList.builder();
+    final int sourceCount = aggregate.getInput().getRowType().getFieldCount();
+    final int targetCount = newRexNodes.size();
+    final Mappings.TargetMapping targetMapping = Mappings.target(map, sourceCount, targetCount);
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      aggCalls.add(aggregateCall.transform(targetMapping));
+    }
+
+    final RelBuilder relBuilder = call.builder();
+    relBuilder.push(project);
+    relBuilder.project(newRexNodes);
+
+    final Aggregate newAggregate = aggregate.copy(
+        aggregate.getTraitSet(),
+        relBuilder.build(),
+        newGroupSet,
+        newGroupingSets,
+        aggCalls.build()
+    );
+    relBuilder.push(newAggregate);
+
+    final List<Integer> newKeys = Util.transform(
+        aggregate.getGroupSet().asList(),
+        key -> Objects.requireNonNull(map.get(key), () -> "no value found for key " + key + " in " + map)
+    );
+
+    // Add a project if the group set is not in the same order or contains duplicates, so that the
+    // output columns line up with those of the original aggregate.
+    if (!newKeys.equals(newGroupSet.asList())) {
+      final List<Integer> posList = new ArrayList<>();
+      for (int newKey : newKeys) {
+        posList.add(newGroupSet.indexOf(newKey));
+      }
+      for (int i = newAggregate.getGroupCount(); i < newAggregate.getRowType().getFieldCount(); i++) {
+        posList.add(i);
+      }
+      relBuilder.project(relBuilder.fields(posList));
+    }
+
+    return relBuilder.build();
+  }
+
+  /**
+   * Returns whether {@code aggregate} tolerates two of its grouping keys being merged into one.
+   * Merging is unsafe when the aggregate has grouping sets: permuting them with a non-injective
+   * mapping can make distinct sets equal, and the TreeSet above would then collapse them, changing
+   * which rows are produced. It is also unsafe when the aggregate calls GROUPING(), whose
+   * arguments identify grouping keys positionally and must stay distinct.
+   */
+  private static boolean canMergeGroupKeys(Aggregate aggregate)
+  {
+    if (aggregate.getGroupType() != Group.SIMPLE) {
+      return false;
+    }
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      if (aggregateCall.getAggregation().getKind() == SqlKind.GROUPING) {
+        return false;
+      }
+    }
+    return true;
+  }
+}
diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index b5a32f4301c..9b302534ef6 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -8788,8 +8788,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
                                             )
                                             .setDimensions(
                                                 dimensions(
-                                                    new 
DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
-                                                    new 
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    new 
DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
+                                                    new 
DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
                                                 )
                                             )
                                             .setAggregatorSpecs(
@@ -8832,9 +8832,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
                                 new FilteredAggregatorFactory(
                                     new CountAggregatorFactory("_a1"),
                                     and(
-                                        notNull("d0"),
+                                        notNull("d1"),
                                         equality("a1", 0L, ColumnType.LONG),
-                                        expressionFilter("\"d1\"")
+                                        expressionFilter("\"d0\"")
                                     )
                                 )
                             )
@@ -12938,8 +12938,7 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                         .setVirtualColumns(expressionVirtualColumn("v0", "1", 
ColumnType.LONG))
                         .setDimensions(
                             dimensions(
-                                new DefaultDimensionSpec("v0", "d0", 
ColumnType.LONG),
-                                new DefaultDimensionSpec("v0", "d1", 
ColumnType.LONG)
+                                new DefaultDimensionSpec("v0", "d0", 
ColumnType.LONG)
                             )
                         )
                         .setContext(QUERY_CONTEXT_DEFAULT)
@@ -15680,10 +15679,63 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                               .build()
               )
         ).expectedResults(
-              ResultMatchMode.RELAX_NULLS,
-              ImmutableList.of(
-                      new Object[]{null, null, null}
-              )
-      );
+            NullHandling.sqlCompatible() ? ImmutableList.of(
+                new Object[]{null, null, null}
+            ) : ImmutableList.of(
+                new Object[]{false, false, ""}
+            )
+        ).run();
+  }
+
+  /**
+   * Exact COUNT(DISTINCT ...) over a plain column, a CASE expression and a FILTER clause in one
+   * GROUP BY, with the planner told to use grouping sets for exact distinct counts (see queryContext).
+   * NOTE(review): "Aggrgate" in the method name is a typo for "Aggregate".
+   */
+  @SqlTestFrameworkConfig.NumMergeBuffers(4)
+  @Test
+  public void testGroupingSetsWithAggrgateCase()
+  {
+    cannotVectorize();
+    msqIncompatible();
+    // Disable approximate distinct counting and request the grouping-set plan for exact distinct.
+    final Map<String, Object> queryContext = ImmutableMap.of(
+        PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false,
+        PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT, true
+    );
+    testBuilder()
+        .sql(
+            "SELECT\n"
+            + "  TIME_FLOOR(\"__time\", 'PT1H') ,\n"
+            + "  COUNT(DISTINCT \"page\") ,\n"
+            + "  COUNT(DISTINCT CASE WHEN \"channel\" = '#it.wikipedia' THEN 
\"user\" END), \n"
+            + "  COUNT(DISTINCT \"user\") FILTER (WHERE \"channel\" = 
'#it.wikipedia'), "
+            + "  COUNT(DISTINCT \"user\") \n"
+            + "FROM \"wikipedia\"\n"
+            + "GROUP BY 1"
+        )
+        .queryContext(queryContext)
+        // One row per hour: the CASE-based and FILTER-based distinct counts are expected to agree.
+        .expectedResults(
+            ImmutableList.of(
+                new Object[]{1442016000000L, 264L, 5L, 5L, 149L},
+                new Object[]{1442019600000L, 1090L, 14L, 14L, 506L},
+                new Object[]{1442023200000L, 1045L, 10L, 10L, 459L},
+                new Object[]{1442026800000L, 766L, 10L, 10L, 427L},
+                new Object[]{1442030400000L, 781L, 6L, 6L, 427L},
+                new Object[]{1442034000000L, 1223L, 10L, 10L, 448L},
+                new Object[]{1442037600000L, 2092L, 13L, 13L, 498L},
+                new Object[]{1442041200000L, 2181L, 21L, 21L, 574L},
+                new Object[]{1442044800000L, 1552L, 36L, 36L, 707L},
+                new Object[]{1442048400000L, 1624L, 44L, 44L, 770L},
+                new Object[]{1442052000000L, 1710L, 37L, 37L, 785L},
+                new Object[]{1442055600000L, 1532L, 40L, 40L, 799L},
+                new Object[]{1442059200000L, 1633L, 45L, 45L, 855L},
+                new Object[]{1442062800000L, 1958L, 44L, 44L, 905L},
+                new Object[]{1442066400000L, 1779L, 48L, 48L, 886L},
+                new Object[]{1442070000000L, 1868L, 37L, 37L, 949L},
+                new Object[]{1442073600000L, 1846L, 50L, 50L, 969L},
+                new Object[]{1442077200000L, 2168L, 38L, 38L, 941L},
+                new Object[]{1442080800000L, 2043L, 40L, 40L, 925L},
+                new Object[]{1442084400000L, 1924L, 32L, 32L, 930L},
+                new Object[]{1442088000000L, 1736L, 31L, 31L, 882L},
+                new Object[]{1442091600000L, 1672L, 40L, 40L, 861L},
+                new Object[]{1442095200000L, 1504L, 28L, 28L, 716L},
+                new Object[]{1442098800000L, 1407L, 20L, 20L, 631L}
+            )
+        ).run();
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to