This is an automated email from the ASF dual-hosted git repository.
cwylie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new ee41cc770f fix issue with SQL sum aggregator due to bug with
DruidTypeSystem and AggregateRemoveRule (#12880)
ee41cc770f is described below
commit ee41cc770f3a2b666e380d25ffbc5469a42a1d36
Author: Clint Wylie <[email protected]>
AuthorDate: Tue Aug 9 15:17:45 2022 -0700
fix issue with SQL sum aggregator due to bug with DruidTypeSystem and
AggregateRemoveRule (#12880)
* fix issue with SQL sum aggregator due to bug with DruidTypeSystem and
AggregateRemoveRule
* fix style
* add comment about using custom sum function
---
.../aggregation/builtin/SumSqlAggregator.java | 87 +++++++++++++++++++++-
.../apache/druid/sql/calcite/CalciteQueryTest.java | 46 ++++++++++++
2 files changed, 131 insertions(+), 2 deletions(-)
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java
index cd7a13d935..f4dcad3ed5 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java
@@ -20,8 +20,17 @@
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.util.Optionality;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@@ -35,10 +44,18 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
public class SumSqlAggregator extends SimpleSqlAggregator
{
+ /**
+ * We are using a custom SUM function instead of {@link org.apache.calcite.sql.fun.SqlStdOperatorTable#SUM} to
+ * work around the issue described in https://issues.apache.org/jira/browse/CALCITE-4609. Once we upgrade Calcite
+ * to 1.27.0+ we can return to using the built-in SUM function, and {@link DruidSumAggFunction} and
+ * {@link DruidSumSplitter} can be removed.
+ *
+ * This single shared instance is returned both by {@link #calciteFunction()} and by
+ * {@link DruidSumSplitter#getMergeAggFunctionOfTopSplit()}, so aggregates produced when a SUM is split and merged
+ * keep using the Druid-aware splitter.
+ */
+ private static final SqlAggFunction DRUID_SUM = new DruidSumAggFunction();
+
@Override
public SqlAggFunction calciteFunction()
{
- return SqlStdOperatorTable.SUM;
+ // Druid's own SUM instead of Calcite's built-in one, so that reducing a single-row SUM to its input casts the
+ // input to the SUM output type when the two differ.
+ return DRUID_SUM;
}
@Override
@@ -74,4 +91,70 @@ public class SumSqlAggregator extends SimpleSqlAggregator
throw new UnsupportedSQLQueryException("Sum aggregation is not
supported for '%s' type", aggregationType);
}
}
+
+ /**
+ * Customized version of {@link org.apache.calcite.sql.fun.SqlSumAggFunction} with a customized
+ * implementation of {@link #unwrap(Class)} to provide a customized {@link SqlSplittableAggFunction} that correctly
+ * honors Druid's type system. The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can
+ * reduce its output to its input in the case of a single row, which means that it doesn't necessarily reflect the
+ * output type as if it were run through the SUM function (e.g. {@code INTEGER -> BIGINT}).
+ */
+ private static class DruidSumAggFunction extends SqlAggFunction
+ {
+ public DruidSumAggFunction()
+ {
+ // NOTE(review): positional arguments assumed to follow Calcite's SqlAggFunction constructor order
+ // (name, sqlIdentifier, kind, returnTypeInference, operandTypeInference, operandTypeChecker, funcType,
+ // requiresOrder, requiresOver, requiresGroupOrder) -- confirm against the Calcite version in use.
+ super(
+ "SUM",
+ null,
+ SqlKind.SUM,
+ ReturnTypes.AGG_SUM,
+ null,
+ OperandTypes.NUMERIC,
+ SqlFunctionCategory.NUMERIC,
+ false,
+ false,
+ Optionality.FORBIDDEN
+ );
+ }
+
+ /**
+ * Returns the Druid-aware {@link DruidSumSplitter} when a {@link SqlSplittableAggFunction} is requested;
+ * every other unwrap request is delegated to the superclass implementation.
+ */
+ @Override
+ public <T> T unwrap(Class<T> clazz)
+ {
+ if (clazz == SqlSplittableAggFunction.class) {
+ return clazz.cast(DruidSumSplitter.INSTANCE);
+ }
+ return super.unwrap(clazz);
+ }
+ }
+
+ /**
+ * The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can reduce its output to its
+ * input in the case of a single row for the {@link #singleton(RexBuilder, RelDataType, AggregateCall)} method, which
+ * is fine for the default type system where the output type of SUM is the same numeric type as the inputs, but
+ * Druid SUM always produces DOUBLE or BIGINT, so this is incorrect for
+ * {@link org.apache.druid.sql.calcite.planner.DruidTypeSystem}. This splitter casts the input to the SUM's
+ * output type whenever the two differ.
+ */
+ private static class DruidSumSplitter extends
SqlSplittableAggFunction.AbstractSumSplitter
+ {
+ // Shared instance returned by DruidSumAggFunction#unwrap.
+ // NOTE(review): not declared final, so it could be reassigned; consider making it final.
+ public static DruidSumSplitter INSTANCE = new DruidSumSplitter();
+
+ /**
+ * Replaces a SUM evaluated over a single row with a reference to its argument column (the first entry of the
+ * call's argument list). The reference is cast to the aggregate's declared output type when that type differs
+ * from the column type.
+ */
+ @Override
+ public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
AggregateCall aggregateCall)
+ {
+ final int arg = aggregateCall.getArgList().get(0);
+ final RelDataTypeField field = inputRowType.getFieldList().get(arg);
+ final RexNode inputRef = rexBuilder.makeInputRef(field.getType(), arg);
+ // if input and output do not agree, we must cast the input to the output type
+ if (!aggregateCall.getType().equals(field.getType())) {
+ return rexBuilder.makeCast(aggregateCall.getType(), inputRef);
+ }
+ return inputRef;
+ }
+
+ /**
+ * Returns the aggregate used for the top (merging) half of a split SUM. It is the custom {@code DRUID_SUM}
+ * rather than Calcite's built-in SUM, so the merged aggregate still unwraps to this splitter.
+ */
+ @Override
+ protected SqlAggFunction getMergeAggFunctionOfTopSplit()
+ {
+ return DRUID_SUM;
+ }
+ }
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 610590237a..39e5dcd799 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -14044,4 +14044,50 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
);
}
+
+ /**
+ * Outer SUM over a CASE that compares an inner SUM(l1) with the literal 0; both levels group on dim1.
+ * The expected native query is a single GroupBy. The inner SUM becomes a long-sum aggregator, and the outer
+ * SUM is planned as an expression post-aggregator over it.
+ * NOTE(review): the "type mismatch" in the name presumably refers to the integer literals in the CASE versus the
+ * SUM output type; confirm.
+ */
+ @Test
+ public void testSubqueryTypeMismatchWithLiterals() throws Exception
+ {
+ testQuery(
+ "SELECT \n"
+ + " dim1,\n"
+ + " SUM(CASE WHEN sum_l1 = 0 THEN 1 ELSE 0 END) AS outer_l1\n"
+ + "from (\n"
+ + " select \n"
+ + " dim1,\n"
+ + " SUM(l1) as sum_l1\n"
+ + " from numfoo\n"
+ + " group by dim1\n"
+ + ")\n"
+ + "group by 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE3)
+ .setInterval(querySegmentSpec(Intervals.ETERNITY))
+ .setGranularity(Granularities.ALL)
+ .addDimension(new DefaultDimensionSpec("dim1", "_d0",
ColumnType.STRING))
+ .addAggregator(new LongSumAggregatorFactory("a0",
"l1"))
+ .setPostAggregatorSpecs(ImmutableList.of(
+ expressionPostAgg("p0", "case_searched((\"a0\" ==
0),1,0)")
+ ))
+ .build()
+ ),
+ useDefault ? ImmutableList.of(
+ new Object[]{"", 0L},
+ new Object[]{"1", 1L},
+ new Object[]{"10.1", 0L},
+ new Object[]{"2", 1L},
+ new Object[]{"abc", 1L},
+ new Object[]{"def", 1L}
+ ) : ImmutableList.of(
+ // in SQL compatible mode, null does not equal 0, so groups whose sum_l1 is null now yield 0 instead of 1;
+ // only "2", whose sum is a non-null 0, still yields 1
+ new Object[]{"", 0L},
+ new Object[]{"1", 0L},
+ new Object[]{"10.1", 0L},
+ new Object[]{"2", 1L},
+ new Object[]{"abc", 0L},
+ new Object[]{"def", 0L}
+ )
+ );
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]