This is an automated email from the ASF dual-hosted git repository.
kgyrtkirk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new 1a38434d8dd Restore usage of filtered SUM (#17378)
1a38434d8dd is described below
commit 1a38434d8ddf4532f4e03fa82f6d56df80f80837
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Thu Dec 12 10:30:42 2024 +0100
Restore usage of filtered SUM (#17378)
---
.../java/org/apache/druid/query/QueryContext.java | 9 +
.../java/org/apache/druid/query/QueryContexts.java | 17 +
.../org/apache/druid/query/QueryContextTest.java | 11 +
.../sql/calcite/planner/CalciteRulesManager.java | 3 +-
.../rule/DruidAggregateCaseToFilterRule.java | 349 +++++++++++++++++++++
.../apache/druid/sql/calcite/CalciteQueryTest.java | 132 +++++---
.../filtered_sum.iq | 120 +++++++
7 files changed, 593 insertions(+), 48 deletions(-)
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java
b/processing/src/main/java/org/apache/druid/query/QueryContext.java
index 1a79f524d5c..1aadd5f93b7 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContext.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java
@@ -625,6 +625,15 @@ public class QueryContext
);
}
+ public boolean isExtendedFilteredSumRewrite()
+ {
+ return getBoolean(
+ QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED,
+ QueryContexts.DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED
+ );
+ }
+
+
public QueryResourceId getQueryResourceId()
{
return new QueryResourceId(getString(QueryContexts.QUERY_RESOURCE_ID));
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
index 950dea52e90..74b45023d04 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
@@ -89,6 +89,22 @@ public class QueryContexts
public static final String UNCOVERED_INTERVALS_LIMIT_KEY =
"uncoveredIntervalsLimit";
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
public static final String CATALOG_VALIDATION_ENABLED =
"catalogValidationEnabled";
+ /**
+ * Context parameter to enable/disable the extended filtered sum rewrite
logic.
+ *
+ * Controls the rewrite of:
+ * <pre>
+ * SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
+ * to
+ * SUM(COL1) FILTER (WHERE COND)
+ * </pre>
+ * managed by {@link DruidAggregateCaseToFilterRule}. Defaults to true for
performance,
+ * but may produce incorrect results when the condition never matches
(expected 0).
+ * This is for testing and can be removed once a correct and
high-performance rewrite
+ * is implemented.
+ */
+ public static final String EXTENDED_FILTERED_SUM_REWRITE_ENABLED =
"extendedFilteredSumRewrite";
+
// projection context keys
public static final String NO_PROJECTIONS = "noProjections";
@@ -139,6 +155,7 @@ public class QueryContexts
public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false;
public static final boolean DEFAULT_CATALOG_VALIDATION_ENABLED = true;
public static final boolean DEFAULT_USE_NESTED_FOR_UNKNOWN_TYPE_IN_SUBQUERY
= false;
+ public static final boolean DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED =
true;
@SuppressWarnings("unused") // Used by Jackson serialization
diff --git
a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
index 0b32c391cc8..b932c6384c6 100644
--- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
+++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
@@ -394,6 +394,17 @@ public class QueryContextTest
);
}
+ @Test
+ public void testExtendedFilteredSumRewrite()
+ {
+ assertTrue(QueryContext.empty().isExtendedFilteredSumRewrite());
+ assertFalse(
+ QueryContext
+
.of(ImmutableMap.of(QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED, false))
+ .isExtendedFilteredSumRewrite()
+ );
+ }
+
// This test is a bit silly. It is retained because another test uses the
// LegacyContextQuery test.
@Test
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
index d6dd1310e6c..917f0cc204a 100644
---
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java
@@ -54,6 +54,7 @@ import
org.apache.druid.sql.calcite.external.ExternalTableScanRule;
import org.apache.druid.sql.calcite.rule.AggregatePullUpLookupRule;
import org.apache.druid.sql.calcite.rule.CaseToCoalesceRule;
import org.apache.druid.sql.calcite.rule.CoalesceLookupRule;
+import org.apache.druid.sql.calcite.rule.DruidAggregateCaseToFilterRule;
import org.apache.druid.sql.calcite.rule.DruidLogicalValuesRule;
import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
import org.apache.druid.sql.calcite.rule.DruidRules;
@@ -119,7 +120,6 @@ public class CalciteRulesManager
CoreRules.FILTER_PROJECT_TRANSPOSE,
CoreRules.JOIN_PUSH_EXPRESSIONS,
CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT,
- CoreRules.AGGREGATE_CASE_TO_FILTER,
CoreRules.FILTER_AGGREGATE_TRANSPOSE,
CoreRules.PROJECT_WINDOW_TRANSPOSE,
CoreRules.MATCH,
@@ -495,6 +495,7 @@ public class CalciteRulesManager
rules.addAll(BASE_RULES);
rules.addAll(ABSTRACT_RULES);
rules.addAll(ABSTRACT_RELATIONAL_RULES);
+ rules.add(new
DruidAggregateCaseToFilterRule(plannerContext.queryContext().isExtendedFilteredSumRewrite()));
rules.addAll(configurableRuleSet(plannerContext));
if (plannerContext.getJoinAlgorithm().requiresSubquery()) {
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidAggregateCaseToFilterRule.java
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidAggregateCaseToFilterRule.java
new file mode 100644
index 00000000000..7950f62aab3
--- /dev/null
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidAggregateCaseToFilterRule.java
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.rules.AggregateCaseToFilterRule;
+import org.apache.calcite.rel.rules.SubstitutionRule;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlPostfixOperator;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.tools.RelBuilder;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Druid extension of {@link AggregateCaseToFilterRule}.
+ *
+ * Turning on extendedFilteredSumRewrite enables rewrites of:
+ * <pre>
+ * SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
+ * </pre>
+ * to:
+ * <pre>
+ * SUM(COL1) FILTER (WHERE COND)
+ * </pre>
+ * <p>
+ * This rewrite improves performance but introduces a known inconsistency when
+ * the condition never matches, as the expected result (0) is replaced with
`null`.
+ * <p>
+ * Example behavior:
+ * <pre>
+ * +-----------------+--------------+----------+------+--------------+
+ * | input row count | cond matches | valueCol | orig | filtered-SUM |
+ * +-----------------+--------------+----------+------+--------------+
+ * | 0 | * | * | null | null |
+ * | >0 | none | * | 0 | null |
+ * | >0 | all | null | null | null |
+ * | >0 | N>0 | 1 | N | N |
+ * +-----------------+--------------+----------+------+--------------+
+ * </pre>
+ */
+public class DruidAggregateCaseToFilterRule extends RelOptRule implements
SubstitutionRule
+{
+ private boolean extendedFilteredSumRewrite;
+
+ public DruidAggregateCaseToFilterRule(boolean extendedFilteredSumRewrite)
+ {
+ super(operand(Aggregate.class, operand(Project.class, any())));
+ this.extendedFilteredSumRewrite = extendedFilteredSumRewrite;
+ }
+
+ @Override
+ public boolean matches(final RelOptRuleCall call)
+ {
+ final Aggregate aggregate = call.rel(0);
+ final Project project = call.rel(1);
+
+ for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+ final int singleArg = soleArgument(aggregateCall);
+ if (singleArg >= 0
+ && isThreeArgCase(project.getProjects().get(singleArg))) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call)
+ {
+ final Aggregate aggregate = call.rel(0);
+ final Project project = call.rel(1);
+ final List<AggregateCall> newCalls = new
ArrayList<>(aggregate.getAggCallList().size());
+ final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
+
+ for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+ AggregateCall newCall = transform(aggregateCall, project, newProjects);
+
+ if (newCall == null) {
+ newCalls.add(aggregateCall);
+ } else {
+ newCalls.add(newCall);
+ }
+ }
+
+ if (newCalls.equals(aggregate.getAggCallList())) {
+ return;
+ }
+
+ final RelBuilder relBuilder = call.builder()
+ .push(project.getInput())
+ .project(newProjects);
+
+ final RelBuilder.GroupKey groupKey =
relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
+
+ relBuilder.aggregate(groupKey, newCalls)
+ .convert(aggregate.getRowType(), false);
+
+ call.transformTo(relBuilder.build());
+ call.getPlanner().prune(aggregate);
+ }
+
+ private @Nullable AggregateCall transform(AggregateCall call,
+ Project project, List<RexNode> newProjects)
+ {
+ final int singleArg = soleArgument(call);
+ if (singleArg < 0) {
+ return null;
+ }
+
+ final RexNode rexNode = project.getProjects().get(singleArg);
+ if (!isThreeArgCase(rexNode)) {
+ return null;
+ }
+
+ final RelOptCluster cluster = project.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+ final RexCall caseCall = (RexCall) rexNode;
+
+ // If one arg is null and the other is not, reverse them and set "flip",
+ // which negates the filter.
+ final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1))
+ && !RexLiteral.isNullLiteral(caseCall.operands.get(2));
+ final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1);
+ final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2);
+
+ // Operand 1: Filter
+ final SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE :
SqlStdOperatorTable.IS_TRUE;
+ final RexNode filterFromCase = rexBuilder.makeCall(op,
caseCall.operands.get(0));
+
+ // Combine the CASE filter with an honest-to-goodness SQL FILTER, if the
+ // latter is present.
+ final RexNode filter;
+ if (call.filterArg >= 0) {
+ filter = rexBuilder.makeCall(
+ SqlStdOperatorTable.AND,
+ project.getProjects().get(call.filterArg),
+ filterFromCase
+ );
+ } else {
+ filter = filterFromCase;
+ }
+
+ RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
+ final SqlKind kind = call.getAggregation().getKind();
+ if (call.isDistinct()) {
+ // Just one style supported:
+ // COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END)
+ // =>
+ // COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
+
+ if (kind == SqlKind.COUNT
+ && RexLiteral.isNullLiteral(arg2)) {
+ newProjects.add(arg1);
+ newProjects.add(filter);
+ return AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ true,
+ false,
+ false,
+ call.rexList,
+ ImmutableList.of(newProjects.size() - 2),
+ newProjects.size() - 1,
+ null,
+ RelCollations.EMPTY,
+ call.getType(),
+ call.getName()
+ );
+ }
+ return null;
+ }
+
+ // Four styles supported:
+ //
+ // A1: AGG(CASE WHEN x = 'foo' THEN expr END)
+ // => AGG(expr) FILTER (x = 'foo')
+ // A2: SUM0(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
+ // => SUM0(cnt) FILTER (x = 'foo')
+ // B: SUM0(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
+ // => COUNT() FILTER (x = 'foo')
+ // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
+ // => COUNT() FILTER (x = 'foo')
+
+ if (kind == SqlKind.COUNT // Case C
+ && arg1.isA(SqlKind.LITERAL)
+ && !RexLiteral.isNullLiteral(arg1)
+ && RexLiteral.isNullLiteral(arg2)) {
+ newProjects.add(filter);
+ return AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ false,
+ false,
+ false,
+ call.rexList, ImmutableList.of(), newProjects.size() - 1, null,
+ RelCollations.EMPTY, call.getType(),
+ call.getName());
+ } else if (kind == SqlKind.SUM0 // Case B
+ && isIntLiteral(arg1, BigDecimal.ONE)
+ && isIntLiteral(arg2, BigDecimal.ZERO)) {
+
+ newProjects.add(filter);
+ final RelDataType dataType = typeFactory
+
.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT),
false);
+ return AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ false,
+ false,
+ false,
+ call.rexList,
+ ImmutableList.of(),
+ newProjects.size() - 1,
+ null,
+ RelCollations.EMPTY,
+ dataType,
+ call.getName()
+ );
+ } else if ((RexLiteral.isNullLiteral(arg2) // Case A1
+ && call.getAggregation().allowsFilter())
+ || (kind == SqlKind.SUM0 // Case A2
+ && isIntLiteral(arg2, BigDecimal.ZERO))) {
+ newProjects.add(arg1);
+ newProjects.add(filter);
+ return AggregateCall.create(
+ call.getAggregation(),
+ false,
+ false,
+ false,
+ call.rexList,
+ ImmutableList.of(newProjects.size() - 2),
+ newProjects.size() - 1,
+ null,
+ RelCollations.EMPTY,
+ call.getType(),
+ call.getName()
+ );
+ }
+
+ // Rewrites
+ // D1: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
+ // => SUM(cnt) FILTER (x = 'foo')
+ // D2: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
+ // => COUNT() FILTER (x = 'foo')
+ //
+ // https://issues.apache.org/jira/browse/CALCITE-5953
+ // has restricted this rewrite, as in case there are no rows it may not be
equivalent;
+ // however it may have some performance impact in Druid
+ if (extendedFilteredSumRewrite &&
+ kind == SqlKind.SUM && isIntLiteral(arg2, BigDecimal.ZERO)) {
+ if (isIntLiteral(arg1, BigDecimal.ONE)) { // D2
+ newProjects.add(filter);
+ final RelDataType dataType = typeFactory.createTypeWithNullability(
+ typeFactory.createSqlType(SqlTypeName.BIGINT), false
+ );
+ return AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ false,
+ false,
+ false,
+ call.rexList,
+ ImmutableList.of(),
+ newProjects.size() - 1,
+ null,
+ RelCollations.EMPTY,
+ dataType,
+ call.getName()
+ );
+
+ } else { // D1
+ newProjects.add(arg1);
+ newProjects.add(filter);
+
+ RelDataType newType =
typeFactory.createTypeWithNullability(call.getType(), true);
+ return AggregateCall.create(
+ call.getAggregation(),
+ false,
+ false,
+ false,
+ call.rexList,
+ ImmutableList.of(newProjects.size() - 2),
+ newProjects.size() - 1,
+ null,
+ RelCollations.EMPTY,
+ newType,
+ call.getName()
+ );
+ }
+ }
+
+ return null;
+ }
+
+ /**
+ * Returns the argument, if an aggregate call has a single argument,
otherwise
+ * -1.
+ */
+ private static int soleArgument(AggregateCall aggregateCall)
+ {
+ return aggregateCall.getArgList().size() == 1
+ ? aggregateCall.getArgList().get(0)
+ : -1;
+ }
+
+ private static boolean isThreeArgCase(final RexNode rexNode)
+ {
+ return rexNode.getKind() == SqlKind.CASE
+ && ((RexCall) rexNode).operands.size() == 3;
+ }
+
+ private static boolean isIntLiteral(RexNode rexNode, BigDecimal value)
+ {
+ return rexNode instanceof RexLiteral
+ && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
+ && value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
+ }
+}
diff --git
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index f2a85b4a4a9..12f17b015c1 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -5188,7 +5188,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testFilteredAggregations()
{
- cannotVectorizeUnlessFallback();
Druids.TimeseriesQueryBuilder builder =
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@@ -5196,18 +5195,9 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
.granularity(Granularities.ALL)
.context(QUERY_CONTEXT_DEFAULT);
if (NullHandling.sqlCompatible()) {
+ cannotVectorizeUnlessFallback();
builder = builder.virtualColumns(
- expressionVirtualColumn("v0", "substring(\"dim1\",
0, 1)", ColumnType.STRING),
- expressionVirtualColumn(
- "v1",
- "case_searched((\"dim1\" != '1'),1,0)",
- ColumnType.LONG
- ),
- expressionVirtualColumn(
- "v2",
- "case_searched((\"dim1\" != '1'),\"cnt\",0)",
- ColumnType.LONG
- )
+ expressionVirtualColumn("v0", "substring(\"dim1\",
0, 1)", ColumnType.STRING)
)
.aggregators(
aggregators(
@@ -5234,7 +5224,10 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
new CountAggregatorFactory("a4"),
not(equality("dim1", "1",
ColumnType.STRING))
),
- new LongSumAggregatorFactory("a5", "v1"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a5"),
+ not(equality("dim1", "1",
ColumnType.STRING))
+ ),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a6", "cnt"),
equality("dim2", "a", ColumnType.STRING)
@@ -5246,7 +5239,10 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
not(equality("dim1", "1",
ColumnType.STRING))
)
),
- new LongSumAggregatorFactory("a8", "v2"),
+ new FilteredAggregatorFactory(
+ new LongSumAggregatorFactory("a8", "cnt"),
+ not(equality("dim1", "1",
ColumnType.STRING))
+ ),
new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"),
not(equality("dim1", "1",
ColumnType.STRING))
@@ -5272,16 +5268,7 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
);
} else {
builder = builder.virtualColumns(
- expressionVirtualColumn(
- "v0",
- "case_searched((\"dim1\" != '1'),1,0)",
- ColumnType.LONG
- ),
- expressionVirtualColumn(
- "v1",
- "case_searched((\"dim1\" != '1'),\"cnt\",0)",
- ColumnType.LONG
- ))
+ )
.aggregators(
aggregators(
new FilteredAggregatorFactory(
@@ -5307,7 +5294,10 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
new CountAggregatorFactory("a4"),
not(equality("dim1", "1", ColumnType.STRING))
),
- new LongSumAggregatorFactory("a5", "v0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a5"),
+ not(equality("dim1", "1", ColumnType.STRING))
+ ),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a6", "cnt"),
equality("dim2", "a", ColumnType.STRING)
@@ -5319,7 +5309,10 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
not(equality("dim1", "1", ColumnType.STRING))
)
),
- new LongSumAggregatorFactory("a8", "v1"),
+ new FilteredAggregatorFactory(
+ new LongSumAggregatorFactory("a8", "cnt"),
+ not(equality("dim1", "1", ColumnType.STRING))
+ ),
new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"),
not(equality("dim1", "1", ColumnType.STRING))
@@ -5373,7 +5366,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testCaseFilteredAggregationWithGroupBy()
{
- cannotVectorizeUnlessFallback();
testQuery(
"SELECT\n"
+ " cnt,\n"
@@ -5386,15 +5378,11 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new
DefaultDimensionSpec("cnt", "d0", ColumnType.LONG)))
- .setVirtualColumns(
- expressionVirtualColumn(
- "v0",
- "case_searched((\"dim1\" != '1'),1,0)",
- ColumnType.LONG
- )
- )
.setAggregatorSpecs(aggregators(
- new LongSumAggregatorFactory("a0", "v0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ not(equality("dim1", "1", ColumnType.STRING))
+ ),
new LongSumAggregatorFactory("a1", "cnt")
))
.setPostAggregatorSpecs(
@@ -5409,6 +5397,52 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testCaseFilteredAggregationWithGroupRewriteToSum()
+ {
+ testBuilder()
+ .sql(
+ "SELECT\n"
+ + " cnt,\n"
+ + " SUM(CASE WHEN dim1 <> '1' THEN 2 ELSE 0 END) + SUM(cnt)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY cnt"
+ )
+ .expectedQueries(
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(dimensions(new DefaultDimensionSpec("cnt",
"d0", ColumnType.LONG)))
+ .setVirtualColumns(
+ expressionVirtualColumn("v0", "2", ColumnType.LONG)
+ )
+ .setAggregatorSpecs(
+ aggregators(
+ new FilteredAggregatorFactory(
+ new LongSumAggregatorFactory("a0", "v0"),
+ not(equality("dim1", "1", ColumnType.STRING))
+ ),
+ new LongSumAggregatorFactory("a1", "cnt")
+ )
+ )
+ .setPostAggregatorSpecs(
+ expressionPostAgg("p0", "(\"a0\" + \"a1\")",
ColumnType.LONG)
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ )
+ )
+ .expectedResults(
+ ImmutableList.of(
+ new Object[] {1L, 16L}
+ )
+ )
+ .run();
+ }
+
+
@Test
public void testFilteredAggregationWithNotIn()
{
@@ -9479,7 +9513,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename()
{
- cannotVectorizeUnlessFallback();
msqIncompatible();
testQuery(
PLANNER_CONFIG_NO_HLL.withOverrides(
@@ -9506,25 +9539,30 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
"v0",
"((\"__time\" >=
947005200000) && (\"__time\" < 1641402000000))",
ColumnType.LONG
- ),
- expressionVirtualColumn(
- "v1",
-
"case_searched(((\"__time\" >= 947005200000) && (\"__time\" <
1641402000000)),1,0)",
- ColumnType.LONG
)
)
.setDimensions(
dimensions(
- new
DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
- new
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+ new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
+ new
DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
)
)
.setAggregatorSpecs(
aggregators(
- new
LongSumAggregatorFactory("a0", "v1"),
+ new
FilteredAggregatorFactory(
+ new
CountAggregatorFactory("a0"),
+ range(
+ "__time",
+ ColumnType.LONG,
+
timestamp("2000-01-04T17:00:00"),
+
timestamp("2022-01-05T17:00:00"),
+ false,
+ true
+ )
+ ),
new
GroupingAggregatorFactory(
"a1",
-
ImmutableList.of("dim1", "v0")
+ ImmutableList.of("v0",
"dim1")
)
)
)
@@ -9549,9 +9587,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a1"),
and(
- notNull("d0"),
+ notNull("d1"),
equality("a1", 0L, ColumnType.LONG),
- expressionFilter("\"d1\"")
+ expressionFilter("\"d0\"")
)
)
)
diff --git
a/sql/src/test/quidem/org.apache.druid.quidem.SqlQuidemTest/filtered_sum.iq
b/sql/src/test/quidem/org.apache.druid.quidem.SqlQuidemTest/filtered_sum.iq
new file mode 100644
index 00000000000..7da2f4ffd9d
--- /dev/null
+++ b/sql/src/test/quidem/org.apache.druid.quidem.SqlQuidemTest/filtered_sum.iq
@@ -0,0 +1,120 @@
+!use druidtest://?numMergeBuffers=3
+!set outputformat mysql
+
+-- empty input
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is
null,
+ SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1)
FROM numfoo where l1 < -1;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 0 | 0 | true | | |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+-- 0=-1,0
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is
null,
+ SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1)
FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 0 | 1 | false | | |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+
+-- 0=0,0
+SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
+ SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0)
FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 1 | 0 | false | 0 | 0 |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+-- 7=7,null
+SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
+ SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7)
FROM numfoo where 0 < l1 and l1 < 10;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 1 | 0 | true | | |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3],
EXPR$4=[$4])
+ LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT()
FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
+ LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1],
$f4=[CASE(=($0, 7), $1, 0:BIGINT)])
+ LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
+ LogicalProject(l1=[$11], l2=[$12])
+ LogicalTableScan(table=[[druid, numfoo]])
+
+!druidPlan
+
+!set extendedFilteredSumRewrite false
+!use druidtest://?numMergeBuffers=3
+
+
+-- empty input
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is
null,
+ SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1)
FROM numfoo where l1 < -1;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 0 | 0 | true | | |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+-- 0=-1,0
+SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is
null,
+ SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1)
FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 0 | 1 | false | 0 | |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+
+-- 0=0,0
+SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
+ SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0)
FROM numfoo where l1 < 3;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 1 | 0 | false | 0 | 0 |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+-- 7=7,null
+SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
+ SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7)
FROM numfoo where 0 < l1 and l1 < 10;
++--------+--------+--------+--------+--------+
+| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
++--------+--------+--------+--------+--------+
+| 1 | 0 | true | | |
++--------+--------+--------+--------+--------+
+(1 row)
+
+!ok
+
+LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3],
EXPR$4=[$4])
+ LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT()
FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
+ LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1],
$f4=[CASE(=($0, 7), $1, 0:BIGINT)])
+ LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
+ LogicalProject(l1=[$11], l2=[$12])
+ LogicalTableScan(table=[[druid, numfoo]])
+
+!druidPlan
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]