This is an automated email from the ASF dual-hosted git repository.

kgyrtkirk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a38434d8dd Restore usage of filtered SUM (#17378)
1a38434d8dd is described below

commit 1a38434d8ddf4532f4e03fa82f6d56df80f80837
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Thu Dec 12 10:30:42 2024 +0100

    Restore usage of filtered SUM (#17378)
---
 .../java/org/apache/druid/query/QueryContext.java  |   9 +
 .../java/org/apache/druid/query/QueryContexts.java |  17 +
 .../org/apache/druid/query/QueryContextTest.java   |  11 +
 .../sql/calcite/planner/CalciteRulesManager.java   |   3 +-
 .../rule/DruidAggregateCaseToFilterRule.java       | 349 +++++++++++++++++++++
 .../apache/druid/sql/calcite/CalciteQueryTest.java | 132 +++++---
 .../filtered_sum.iq                                | 120 +++++++
 7 files changed, 593 insertions(+), 48 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java 
b/processing/src/main/java/org/apache/druid/query/QueryContext.java
index 1a79f524d5c..1aadd5f93b7 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContext.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java
@@ -625,6 +625,15 @@ public class QueryContext
     );
   }
 
+  public boolean isExtendedFilteredSumRewrite()
+  {
+    return getBoolean(
+        QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED,
+        QueryContexts.DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED
+    );
+  }
+
+
   public QueryResourceId getQueryResourceId()
   {
     return new QueryResourceId(getString(QueryContexts.QUERY_RESOURCE_ID));
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java 
b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
index 950dea52e90..74b45023d04 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
@@ -89,6 +89,22 @@ public class QueryContexts
   public static final String UNCOVERED_INTERVALS_LIMIT_KEY = 
"uncoveredIntervalsLimit";
   public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
   public static final String CATALOG_VALIDATION_ENABLED = 
"catalogValidationEnabled";
+  /**
+   * Context parameter to enable/disable the extended filtered sum rewrite 
logic.
+   *
+   * Controls the rewrite of:
+   * <pre>
+   *    SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
+   * to
+   *    SUM(COL1) FILTER (COND)
+   * </pre>
+   * managed by {@code DruidAggregateCaseToFilterRule}. Defaults to true for 
performance,
+   * but may produce incorrect results when the condition never matches 
(expected 0).
+   * This is for testing and can be removed once a correct and 
high-performance rewrite
+   * is implemented.
+   */
+  public static final String EXTENDED_FILTERED_SUM_REWRITE_ENABLED = 
"extendedFilteredSumRewrite";
+
 
   // projection context keys
   public static final String NO_PROJECTIONS = "noProjections";
@@ -139,6 +155,7 @@ public class QueryContexts
   public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false;
   public static final boolean DEFAULT_CATALOG_VALIDATION_ENABLED = true;
   public static final boolean DEFAULT_USE_NESTED_FOR_UNKNOWN_TYPE_IN_SUBQUERY 
= false;
+  public static final boolean DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED = 
true;
 
 
   @SuppressWarnings("unused") // Used by Jackson serialization
diff --git 
a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java 
b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
index 0b32c391cc8..b932c6384c6 100644
--- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
+++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
@@ -394,6 +394,17 @@ public class QueryContextTest
     );
   }
 
+  @Test
+  public void testExtendedFilteredSumRewrite()
+  {
+    assertTrue(QueryContext.empty().isExtendedFilteredSumRewrite());
+    assertFalse(
+        QueryContext
+            
.of(ImmutableMap.of(QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED, false))
+            .isExtendedFilteredSumRewrite()
+    );
+  }
+
   // This test is a bit silly. It is retained because another test uses the
   // LegacyContextQuery test.
   @Test
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
index d6dd1310e6c..917f0cc204a 100644
--- 
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
@@ -54,6 +54,7 @@ import 
org.apache.druid.sql.calcite.external.ExternalTableScanRule;
 import org.apache.druid.sql.calcite.rule.AggregatePullUpLookupRule;
 import org.apache.druid.sql.calcite.rule.CaseToCoalesceRule;
 import org.apache.druid.sql.calcite.rule.CoalesceLookupRule;
+import org.apache.druid.sql.calcite.rule.DruidAggregateCaseToFilterRule;
 import org.apache.druid.sql.calcite.rule.DruidLogicalValuesRule;
 import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
 import org.apache.druid.sql.calcite.rule.DruidRules;
@@ -119,7 +120,6 @@ public class CalciteRulesManager
           CoreRules.FILTER_PROJECT_TRANSPOSE,
           CoreRules.JOIN_PUSH_EXPRESSIONS,
           CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT,
-          CoreRules.AGGREGATE_CASE_TO_FILTER,
           CoreRules.FILTER_AGGREGATE_TRANSPOSE,
           CoreRules.PROJECT_WINDOW_TRANSPOSE,
           CoreRules.MATCH,
@@ -495,6 +495,7 @@ public class CalciteRulesManager
     rules.addAll(BASE_RULES);
     rules.addAll(ABSTRACT_RULES);
     rules.addAll(ABSTRACT_RELATIONAL_RULES);
+    rules.add(new 
DruidAggregateCaseToFilterRule(plannerContext.queryContext().isExtendedFilteredSumRewrite()));
     rules.addAll(configurableRuleSet(plannerContext));
 
     if (plannerContext.getJoinAlgorithm().requiresSubquery()) {
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidAggregateCaseToFilterRule.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidAggregateCaseToFilterRule.java
new file mode 100644
index 00000000000..7950f62aab3
--- /dev/null
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidAggregateCaseToFilterRule.java
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.rules.AggregateCaseToFilterRule;
+import org.apache.calcite.rel.rules.SubstitutionRule;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlPostfixOperator;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.tools.RelBuilder;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Druid extension of {@link AggregateCaseToFilterRule}.
+ *
+ * Turning on extendedFilteredSumRewrite enables rewrites of:
+ * <pre>
+ *    SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
+ * </pre>
+ * to:
+ * <pre>
+ *    SUM(COL1) FILTER (WHERE COND)
+ * </pre>
+ * <p>
+ * This rewrite improves performance but introduces a known inconsistency when
+ * the condition never matches, as the expected result (0) is replaced with 
`null`.
+ * <p>
+ * Example behavior:
+ * <pre>
+ * +-----------------+--------------+----------+------+--------------+
+ * | input row count | cond matches | valueCol | orig | filtered-SUM |
+ * +-----------------+--------------+----------+------+--------------+
+ * | 0               | *            | *        | null | null         |
+ * | >0              | none         | *        | 0    | null         |
+ * | >0              | all          | null     | null | null         |
+ * | >0              | N>0          | 1        | N    | N            |
+ * +-----------------+--------------+----------+------+--------------+
+ * </pre>
+ */
+public class DruidAggregateCaseToFilterRule extends RelOptRule implements 
SubstitutionRule
+{
+  private boolean extendedFilteredSumRewrite;
+
+  public DruidAggregateCaseToFilterRule(boolean extendedFilteredSumRewrite)
+  {
+    super(operand(Aggregate.class, operand(Project.class, any())));
+    this.extendedFilteredSumRewrite = extendedFilteredSumRewrite;
+  }
+
+  @Override
+  public boolean matches(final RelOptRuleCall call)
+  {
+    final Aggregate aggregate = call.rel(0);
+    final Project project = call.rel(1);
+
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      final int singleArg = soleArgument(aggregateCall);
+      if (singleArg >= 0
+          && isThreeArgCase(project.getProjects().get(singleArg))) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call)
+  {
+    final Aggregate aggregate = call.rel(0);
+    final Project project = call.rel(1);
+    final List<AggregateCall> newCalls = new 
ArrayList<>(aggregate.getAggCallList().size());
+    final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
+
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      AggregateCall newCall = transform(aggregateCall, project, newProjects);
+
+      if (newCall == null) {
+        newCalls.add(aggregateCall);
+      } else {
+        newCalls.add(newCall);
+      }
+    }
+
+    if (newCalls.equals(aggregate.getAggCallList())) {
+      return;
+    }
+
+    final RelBuilder relBuilder = call.builder()
+        .push(project.getInput())
+        .project(newProjects);
+
+    final RelBuilder.GroupKey groupKey = 
relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
+
+    relBuilder.aggregate(groupKey, newCalls)
+        .convert(aggregate.getRowType(), false);
+
+    call.transformTo(relBuilder.build());
+    call.getPlanner().prune(aggregate);
+  }
+
+  private @Nullable AggregateCall transform(AggregateCall call,
+      Project project, List<RexNode> newProjects)
+  {
+    final int singleArg = soleArgument(call);
+    if (singleArg < 0) {
+      return null;
+    }
+
+    final RexNode rexNode = project.getProjects().get(singleArg);
+    if (!isThreeArgCase(rexNode)) {
+      return null;
+    }
+
+    final RelOptCluster cluster = project.getCluster();
+    final RexBuilder rexBuilder = cluster.getRexBuilder();
+    final RexCall caseCall = (RexCall) rexNode;
+
+    // If one arg is null and the other is not, reverse them and set "flip",
+    // which negates the filter.
+    final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1))
+        && !RexLiteral.isNullLiteral(caseCall.operands.get(2));
+    final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1);
+    final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2);
+
+    // Operand 1: Filter
+    final SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE : 
SqlStdOperatorTable.IS_TRUE;
+    final RexNode filterFromCase = rexBuilder.makeCall(op, 
caseCall.operands.get(0));
+
+    // Combine the CASE filter with an honest-to-goodness SQL FILTER, if the
+    // latter is present.
+    final RexNode filter;
+    if (call.filterArg >= 0) {
+      filter = rexBuilder.makeCall(
+          SqlStdOperatorTable.AND,
+          project.getProjects().get(call.filterArg),
+          filterFromCase
+      );
+    } else {
+      filter = filterFromCase;
+    }
+
+    RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
+    final SqlKind kind = call.getAggregation().getKind();
+    if (call.isDistinct()) {
+      // Just one style supported:
+      //   COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END)
+      // =>
+      //   COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
+
+      if (kind == SqlKind.COUNT
+          && RexLiteral.isNullLiteral(arg2)) {
+        newProjects.add(arg1);
+        newProjects.add(filter);
+        return AggregateCall.create(
+            SqlStdOperatorTable.COUNT,
+            true,
+            false,
+            false,
+            call.rexList,
+            ImmutableList.of(newProjects.size() - 2),
+            newProjects.size() - 1,
+            null,
+            RelCollations.EMPTY,
+            call.getType(),
+            call.getName()
+        );
+      }
+      return null;
+    }
+
+    // Four styles supported:
+    //
+    // A1: AGG(CASE WHEN x = 'foo' THEN expr END)
+    //   => AGG(expr) FILTER (x = 'foo')
+    // A2: SUM0(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
+    //   => SUM0(cnt) FILTER (x = 'foo')
+    // B: SUM0(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
+    //   => COUNT() FILTER (x = 'foo')
+    // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
+    //   => COUNT() FILTER (x = 'foo')
+
+    if (kind == SqlKind.COUNT // Case C
+        && arg1.isA(SqlKind.LITERAL)
+        && !RexLiteral.isNullLiteral(arg1)
+        && RexLiteral.isNullLiteral(arg2)) {
+      newProjects.add(filter);
+      return AggregateCall.create(
+          SqlStdOperatorTable.COUNT,
+          false,
+          false,
+          false,
+          call.rexList, ImmutableList.of(), newProjects.size() - 1, null,
+          RelCollations.EMPTY, call.getType(),
+          call.getName());
+    } else if (kind == SqlKind.SUM0 // Case B
+        && isIntLiteral(arg1, BigDecimal.ONE)
+        && isIntLiteral(arg2, BigDecimal.ZERO)) {
+
+      newProjects.add(filter);
+      final RelDataType dataType = typeFactory
+          
.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), 
false);
+      return AggregateCall.create(
+          SqlStdOperatorTable.COUNT,
+          false,
+          false,
+          false,
+          call.rexList,
+          ImmutableList.of(),
+          newProjects.size() - 1,
+          null,
+          RelCollations.EMPTY,
+          dataType,
+          call.getName()
+      );
+    } else if ((RexLiteral.isNullLiteral(arg2) // Case A1
+            && call.getAggregation().allowsFilter())
+        || (kind == SqlKind.SUM0 // Case A2
+            && isIntLiteral(arg2, BigDecimal.ZERO))) {
+      newProjects.add(arg1);
+      newProjects.add(filter);
+      return AggregateCall.create(
+          call.getAggregation(),
+          false,
+          false,
+          false,
+          call.rexList,
+          ImmutableList.of(newProjects.size() - 2),
+          newProjects.size() - 1,
+          null,
+          RelCollations.EMPTY,
+          call.getType(),
+          call.getName()
+      );
+    }
+
+    // Rewrites
+    // D1: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
+    //   => SUM(cnt) FILTER (x = 'foo')
+    // D2: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
+    //   => COUNT() FILTER (x = 'foo')
+    //
+    // https://issues.apache.org/jira/browse/CALCITE-5953
+    // has restricted this rewrite because, when no rows match, it may not be 
equivalent;
+    // however, restricting it may have some performance impact in Druid
+    if (extendedFilteredSumRewrite &&
+        kind == SqlKind.SUM && isIntLiteral(arg2, BigDecimal.ZERO)) {
+      if (isIntLiteral(arg1, BigDecimal.ONE)) { // D2
+        newProjects.add(filter);
+        final RelDataType dataType = typeFactory.createTypeWithNullability(
+            typeFactory.createSqlType(SqlTypeName.BIGINT), false
+        );
+        return AggregateCall.create(
+            SqlStdOperatorTable.COUNT,
+            false,
+            false,
+            false,
+            call.rexList,
+            ImmutableList.of(),
+            newProjects.size() - 1,
+            null,
+            RelCollations.EMPTY,
+            dataType,
+            call.getName()
+        );
+
+      } else { // D1
+        newProjects.add(arg1);
+        newProjects.add(filter);
+
+        RelDataType newType = 
typeFactory.createTypeWithNullability(call.getType(), true);
+        return AggregateCall.create(
+            call.getAggregation(),
+            false,
+            false,
+            false,
+            call.rexList,
+            ImmutableList.of(newProjects.size() - 2),
+            newProjects.size() - 1,
+            null,
+            RelCollations.EMPTY,
+            newType,
+            call.getName()
+        );
+      }
+    }
+
+    return null;
+  }
+
+  /**
+   * Returns the argument, if an aggregate call has a single argument, 
otherwise
+   * -1.
+   */
+  private static int soleArgument(AggregateCall aggregateCall)
+  {
+    return aggregateCall.getArgList().size() == 1
+        ? aggregateCall.getArgList().get(0)
+        : -1;
+  }
+
+  private static boolean isThreeArgCase(final RexNode rexNode)
+  {
+    return rexNode.getKind() == SqlKind.CASE
+        && ((RexCall) rexNode).operands.size() == 3;
+  }
+
+  private static boolean isIntLiteral(RexNode rexNode, BigDecimal value)
+  {
+    return rexNode instanceof RexLiteral
+        && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
+        && value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
+  }
+}
diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index f2a85b4a4a9..12f17b015c1 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -5188,7 +5188,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testFilteredAggregations()
   {
-    cannotVectorizeUnlessFallback();
     Druids.TimeseriesQueryBuilder builder =
         Druids.newTimeseriesQueryBuilder()
               .dataSource(CalciteTests.DATASOURCE1)
@@ -5196,18 +5195,9 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
               .granularity(Granularities.ALL)
               .context(QUERY_CONTEXT_DEFAULT);
     if (NullHandling.sqlCompatible()) {
+      cannotVectorizeUnlessFallback();
       builder = builder.virtualColumns(
-                          expressionVirtualColumn("v0", "substring(\"dim1\", 
0, 1)", ColumnType.STRING),
-                          expressionVirtualColumn(
-                              "v1",
-                              "case_searched((\"dim1\" != '1'),1,0)",
-                              ColumnType.LONG
-                          ),
-                          expressionVirtualColumn(
-                              "v2",
-                              "case_searched((\"dim1\" != '1'),\"cnt\",0)",
-                              ColumnType.LONG
-                          )
+                          expressionVirtualColumn("v0", "substring(\"dim1\", 
0, 1)", ColumnType.STRING)
                         )
                        .aggregators(
                            aggregators(
@@ -5234,7 +5224,10 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                                    new CountAggregatorFactory("a4"),
                                    not(equality("dim1", "1", 
ColumnType.STRING))
                                ),
-                               new LongSumAggregatorFactory("a5", "v1"),
+                               new FilteredAggregatorFactory(
+                                   new CountAggregatorFactory("a5"),
+                                   not(equality("dim1", "1", 
ColumnType.STRING))
+                               ),
                                new FilteredAggregatorFactory(
                                    new LongSumAggregatorFactory("a6", "cnt"),
                                    equality("dim2", "a", ColumnType.STRING)
@@ -5246,7 +5239,10 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                                        not(equality("dim1", "1", 
ColumnType.STRING))
                                    )
                                ),
-                               new LongSumAggregatorFactory("a8", "v2"),
+                               new FilteredAggregatorFactory(
+                                   new LongSumAggregatorFactory("a8", "cnt"),
+                                   not(equality("dim1", "1", 
ColumnType.STRING))
+                               ),
                                new FilteredAggregatorFactory(
                                    new LongMaxAggregatorFactory("a9", "cnt"),
                                    not(equality("dim1", "1", 
ColumnType.STRING))
@@ -5272,16 +5268,7 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                        );
     } else {
       builder = builder.virtualColumns(
-              expressionVirtualColumn(
-              "v0",
-              "case_searched((\"dim1\" != '1'),1,0)",
-              ColumnType.LONG
-          ),
-          expressionVirtualColumn(
-              "v1",
-              "case_searched((\"dim1\" != '1'),\"cnt\",0)",
-              ColumnType.LONG
-          ))
+          )
           .aggregators(
               aggregators(
               new FilteredAggregatorFactory(
@@ -5307,7 +5294,10 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                   new CountAggregatorFactory("a4"),
                   not(equality("dim1", "1", ColumnType.STRING))
               ),
-              new LongSumAggregatorFactory("a5", "v0"),
+              new FilteredAggregatorFactory(
+                  new CountAggregatorFactory("a5"),
+                  not(equality("dim1", "1", ColumnType.STRING))
+              ),
               new FilteredAggregatorFactory(
                   new LongSumAggregatorFactory("a6", "cnt"),
                   equality("dim2", "a", ColumnType.STRING)
@@ -5319,7 +5309,10 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                       not(equality("dim1", "1", ColumnType.STRING))
                   )
               ),
-              new LongSumAggregatorFactory("a8", "v1"),
+              new FilteredAggregatorFactory(
+                  new LongSumAggregatorFactory("a8", "cnt"),
+                  not(equality("dim1", "1", ColumnType.STRING))
+              ),
               new FilteredAggregatorFactory(
                   new LongMaxAggregatorFactory("a9", "cnt"),
                   not(equality("dim1", "1", ColumnType.STRING))
@@ -5373,7 +5366,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testCaseFilteredAggregationWithGroupBy()
   {
-    cannotVectorizeUnlessFallback();
     testQuery(
         "SELECT\n"
         + "  cnt,\n"
@@ -5386,15 +5378,11 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                         .setInterval(querySegmentSpec(Filtration.eternity()))
                         .setGranularity(Granularities.ALL)
                         .setDimensions(dimensions(new 
DefaultDimensionSpec("cnt", "d0", ColumnType.LONG)))
-                        .setVirtualColumns(
-                            expressionVirtualColumn(
-                                "v0",
-                                "case_searched((\"dim1\" != '1'),1,0)",
-                                ColumnType.LONG
-                            )
-                        )
                         .setAggregatorSpecs(aggregators(
-                            new LongSumAggregatorFactory("a0", "v0"),
+                            new FilteredAggregatorFactory(
+                                new CountAggregatorFactory("a0"),
+                                not(equality("dim1", "1", ColumnType.STRING))
+                            ),
                             new LongSumAggregatorFactory("a1", "cnt")
                         ))
                         .setPostAggregatorSpecs(
@@ -5409,6 +5397,52 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
     );
   }
 
+  @Test
+  public void testCaseFilteredAggregationWithGroupRewriteToSum()
+  {
+    testBuilder()
+        .sql(
+            "SELECT\n"
+                + "  cnt,\n"
+                + "  SUM(CASE WHEN dim1 <> '1' THEN 2 ELSE 0 END) + SUM(cnt)\n"
+                + "FROM druid.foo\n"
+                + "GROUP BY cnt"
+        )
+        .expectedQueries(
+            ImmutableList.of(
+                GroupByQuery.builder()
+                    .setDataSource(CalciteTests.DATASOURCE1)
+                    .setInterval(querySegmentSpec(Filtration.eternity()))
+                    .setGranularity(Granularities.ALL)
+                    .setDimensions(dimensions(new DefaultDimensionSpec("cnt", 
"d0", ColumnType.LONG)))
+                    .setVirtualColumns(
+                        expressionVirtualColumn("v0", "2", ColumnType.LONG)
+                    )
+                    .setAggregatorSpecs(
+                        aggregators(
+                            new FilteredAggregatorFactory(
+                                new LongSumAggregatorFactory("a0", "v0"),
+                                not(equality("dim1", "1", ColumnType.STRING))
+                            ),
+                            new LongSumAggregatorFactory("a1", "cnt")
+                        )
+                    )
+                    .setPostAggregatorSpecs(
+                        expressionPostAgg("p0", "(\"a0\" + \"a1\")", 
ColumnType.LONG)
+                    )
+                    .setContext(QUERY_CONTEXT_DEFAULT)
+                    .build()
+            )
+        )
+        .expectedResults(
+            ImmutableList.of(
+                new Object[] {1L, 16L}
+            )
+        )
+        .run();
+  }
+
+
   @Test
   public void testFilteredAggregationWithNotIn()
   {
@@ -9479,7 +9513,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename()
   {
-    cannotVectorizeUnlessFallback();
     msqIncompatible();
     testQuery(
         PLANNER_CONFIG_NO_HLL.withOverrides(
@@ -9506,25 +9539,30 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
                                                     "v0",
                                                     "((\"__time\" >= 
947005200000) && (\"__time\" < 1641402000000))",
                                                     ColumnType.LONG
-                                                ),
-                                                expressionVirtualColumn(
-                                                    "v1",
-                                                    
"case_searched(((\"__time\" >= 947005200000) && (\"__time\" < 
1641402000000)),1,0)",
-                                                    ColumnType.LONG
                                                 )
                                             )
                                             .setDimensions(
                                                 dimensions(
-                                                    new 
DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
-                                                    new 
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    new 
DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
+                                                    new 
DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
                                                 )
                                             )
                                             .setAggregatorSpecs(
                                                 aggregators(
-                                                    new 
LongSumAggregatorFactory("a0", "v1"),
+                                                    new 
FilteredAggregatorFactory(
+                                                        new 
CountAggregatorFactory("a0"),
+                                                        range(
+                                                            "__time",
+                                                            ColumnType.LONG,
+                                                            
timestamp("2000-01-04T17:00:00"),
+                                                            
timestamp("2022-01-05T17:00:00"),
+                                                            false,
+                                                            true
+                                                        )
+                                                    ),
                                                     new 
GroupingAggregatorFactory(
                                                         "a1",
-                                                        
ImmutableList.of("dim1", "v0")
+                                                        ImmutableList.of("v0", 
"dim1")
                                                     )
                                                 )
                                             )
@@ -9549,9 +9587,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
                                 new FilteredAggregatorFactory(
                                     new CountAggregatorFactory("_a1"),
                                     and(
-                                        notNull("d0"),
+                                        notNull("d1"),
                                         equality("a1", 0L, ColumnType.LONG),
-                                        expressionFilter("\"d1\"")
+                                        expressionFilter("\"d0\"")
                                     )
                                 )
                             )
diff --git 
a/sql/src/test/quidem/org.apache.druid.quidem.SqlQuidemTest/filtered_sum.iq 
b/sql/src/test/quidem/org.apache.druid.quidem.SqlQuidemTest/filtered_sum.iq
new file mode 100644
index 00000000000..7da2f4ffd9d
--- /dev/null
+++ b/sql/src/test/quidem/org.apache.druid.quidem.SqlQuidemTest/filtered_sum.iq
@@ -0,0 +1,120 @@
+!use druidtest://?numMergeBuffers=3
+!set outputformat mysql
+
+-- empty input
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
+       SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < -1;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      0 |      0 | true   |        |        |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+-- 0=-1,0
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
+       SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      0 |      1 | false  |        |        |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+
+-- 0=0,0
+SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
+       SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0) FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      1 |      0 | false  |      0 |      0 |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+-- 7=7,null
+SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
+       SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7) FROM numfoo where 0 < l1 and l1 < 10;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      1 |      0 | true   |        |        |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3], EXPR$4=[$4])
+  LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT() FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
+    LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1], $f4=[CASE(=($0, 7), $1, 0:BIGINT)])
+      LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
+        LogicalProject(l1=[$11], l2=[$12])
+          LogicalTableScan(table=[[druid, numfoo]])
+
+!druidPlan
+
+!set extendedFilteredSumRewrite false
+!use druidtest://?numMergeBuffers=3
+
+
+-- empty input
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
+       SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < -1;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      0 |      0 | true   |        |        |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+-- 0=-1,0
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
+       SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      0 |      1 | false  |      0 |        |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+
+-- 0=0,0
+SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
+       SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0) FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      1 |      0 | false  |      0 |      0 |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+-- 7=7,null
+SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
+       SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7) FROM numfoo where 0 < l1 and l1 < 10;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+|      1 |      0 | true   |        |        |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3], EXPR$4=[$4])
+  LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT() FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
+    LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1], $f4=[CASE(=($0, 7), $1, 0:BIGINT)])
+      LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
+        LogicalProject(l1=[$11], l2=[$12])
+          LogicalTableScan(table=[[druid, numfoo]])
+
+!druidPlan


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to