This is an automated email from the ASF dual-hosted git repository.
cwylie pushed a commit to branch 0.19.0
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/0.19.0 by this push:
new d310682 Fix avg sql aggregator (#10135) (#10162)
d310682 is described below
commit d310682a39b2f69c4216d8a42b1a8d4e45f80423
Author: Clint Wylie <[email protected]>
AuthorDate: Thu Jul 9 14:40:18 2020 -0700
Fix avg sql aggregator (#10135) (#10162)
* new average aggregator
* method to create count aggregator factory
* test everything
* update other usages
* fix style
* fix more tests
* fix datasketches tests
Co-authored-by: Franklyn Dsouza <[email protected]>
---
.../hll/sql/HllSketchSqlAggregatorTest.java | 18 ++-
.../theta/sql/ThetaSketchSqlAggregatorTest.java | 18 ++-
.../aggregation/builtin/AvgSqlAggregator.java | 65 +++++++--
.../aggregation/builtin/CountSqlAggregator.java | 65 ++++++---
.../apache/druid/sql/calcite/CalciteQueryTest.java | 153 ++++++++++++++++-----
5 files changed, 247 insertions(+), 72 deletions(-)
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java
index 6f4c8b6..039801c 100644
--- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java
@@ -388,10 +388,20 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase
)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(Arrays.asList(
- new LongSumAggregatorFactory("_a0:sum", "a0"),
- new CountAggregatorFactory("_a0:count")
- ))
+ .setAggregatorSpecs(
+ NullHandling.replaceWithDefault()
+ ? Arrays.asList(
+ new LongSumAggregatorFactory("_a0:sum", "a0"),
+ new CountAggregatorFactory("_a0:count")
+ )
+ : Arrays.asList(
+ new LongSumAggregatorFactory("_a0:sum", "a0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("_a0:count"),
+ BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
+ )
+ )
+ )
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java
index 201380a..9201c91 100644
--- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java
@@ -385,10 +385,20 @@ public class ThetaSketchSqlAggregatorTest extends CalciteTestBase
)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(Arrays.asList(
- new LongSumAggregatorFactory("_a0:sum", "a0"),
- new CountAggregatorFactory("_a0:count")
- ))
+ .setAggregatorSpecs(
+ NullHandling.replaceWithDefault()
+ ? Arrays.asList(
+ new LongSumAggregatorFactory("_a0:sum", "a0"),
+ new CountAggregatorFactory("_a0:count")
+ )
+ : Arrays.asList(
+ new LongSumAggregatorFactory("_a0:sum", "a0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("_a0:count"),
+ BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
+ )
+ )
+ )
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
index 2761d0c..3b97344 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
@@ -20,20 +20,31 @@
package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
-import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
+import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
+import org.apache.druid.sql.calcite.aggregation.Aggregations;
+import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
+import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.Calcites;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
+import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
-public class AvgSqlAggregator extends SimpleSqlAggregator
+import javax.annotation.Nullable;
+import java.util.List;
+
+public class AvgSqlAggregator implements SqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
@@ -41,15 +52,46 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
return SqlStdOperatorTable.AVG;
}
+ @Nullable
@Override
- Aggregation getAggregation(
+ public Aggregation toDruidAggregation(
+ final PlannerContext plannerContext,
+ final RowSignature rowSignature,
+ final VirtualColumnRegistry virtualColumnRegistry,
+ final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
- final ExprMacroTable macroTable,
- final String fieldName,
- final String expression
+ final Project project,
+ final List<Aggregation> existingAggregations,
+ final boolean finalizeAggregations
)
{
+
+ final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
+ plannerContext,
+ rowSignature,
+ aggregateCall,
+ project
+ );
+
+ if (arguments == null) {
+ return null;
+ }
+
+ final String fieldName;
+ final String expression;
+ final DruidExpression arg = Iterables.getOnlyElement(arguments);
+
+ if (arg.isDirectColumnAccess()) {
+ fieldName = arg.getDirectColumn();
+ expression = null;
+ } else {
+ fieldName = null;
+ expression = arg.getExpression();
+ }
+
+ final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
+
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
@@ -67,8 +109,15 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
expression,
macroTable
);
-
- final AggregatorFactory count = new CountAggregatorFactory(countName);
+ final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
+ countName,
+ plannerContext,
+ rowSignature,
+ virtualColumnRegistry,
+ rexBuilder,
+ aggregateCall,
+ project
+ );
return Aggregation.create(
ImmutableList.of(sum, count),
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
index 6bf8b60..f674798 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
@@ -28,7 +28,9 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
@@ -52,6 +54,41 @@ public class CountSqlAggregator implements SqlAggregator
return SqlStdOperatorTable.COUNT;
}
+ static AggregatorFactory createCountAggregatorFactory(
+ final String countName,
+ final PlannerContext plannerContext,
+ final RowSignature rowSignature,
+ final VirtualColumnRegistry virtualColumnRegistry,
+ final RexBuilder rexBuilder,
+ final AggregateCall aggregateCall,
+ final Project project
+ )
+ {
+ final RexNode rexNode = Expressions.fromFieldAccess(
+ rowSignature,
+ project,
+ Iterables.getOnlyElement(aggregateCall.getArgList())
+ );
+
+ if (rexNode.getType().isNullable()) {
+ final DimFilter nonNullFilter = Expressions.toFilter(
+ plannerContext,
+ rowSignature,
+ virtualColumnRegistry,
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode))
+ );
+
+ if (nonNullFilter == null) {
+ // Don't expect this to happen.
+ throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode);
+ }
+
+ return new FilteredAggregatorFactory(new CountAggregatorFactory(countName), nonNullFilter);
+ } else {
+ return new CountAggregatorFactory(countName);
+ }
+ }
+
@Nullable
@Override
public Aggregation toDruidAggregation(
@@ -96,32 +133,16 @@ public class CountSqlAggregator implements SqlAggregator
}
} else {
// Not COUNT(*), not distinct
-
// COUNT(x) should count all non-null values of x.
- final RexNode rexNode = Expressions.fromFieldAccess(
- rowSignature,
- project,
- Iterables.getOnlyElement(aggregateCall.getArgList())
- );
-
- if (rexNode.getType().isNullable()) {
- final DimFilter nonNullFilter = Expressions.toFilter(
+ return Aggregation.create(createCountAggregatorFactory(
+ name,
plannerContext,
rowSignature,
virtualColumnRegistry,
- rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode))
- );
-
- if (nonNullFilter == null) {
- // Don't expect this to happen.
- throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode);
- }
-
- return Aggregation.create(new CountAggregatorFactory(name))
- .filter(rowSignature, virtualColumnRegistry, nonNullFilter);
- } else {
- return Aggregation.create(new CountAggregatorFactory(name));
- }
+ rexBuilder,
+ aggregateCall,
+ project
+ ));
}
}
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index e51e81d..deffe20 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -244,10 +244,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(aggregators(
- new DoubleSumAggregatorFactory("a0:sum", "m2"),
- new CountAggregatorFactory("a0:count")
- )
+ .setAggregatorSpecs(
+ useDefault
+ ? aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new CountAggregatorFactory("a0:count")
+ )
+ : aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0:count"),
+ not(selector("m2", null, null))
+ )
+ )
)
.setPostAggregatorSpecs(
ImmutableList.of(
@@ -313,10 +322,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(aggregators(
- new DoubleSumAggregatorFactory("a0:sum", "m2"),
- new CountAggregatorFactory("a0:count")
- )
+ .setAggregatorSpecs(
+ useDefault
+ ? aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new CountAggregatorFactory("a0:count")
+ )
+ : aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0:count"),
+ not(selector("m2", null, null))
+ )
+ )
)
.setPostAggregatorSpecs(
ImmutableList.of(
@@ -390,10 +408,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(aggregators(
- new DoubleSumAggregatorFactory("a0:sum", "m2"),
- new CountAggregatorFactory("a0:count")
- )
+ .setAggregatorSpecs(
+ useDefault
+ ? aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new CountAggregatorFactory("a0:count")
+ )
+ : aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0:count"),
+ not(selector("m2", null, null))
+ )
+ )
)
.setPostAggregatorSpecs(
ImmutableList.of(
@@ -4730,11 +4757,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
cannotVectorize();
testQuery(
- "SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2) FROM druid.foo",
+ "SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2), COUNT(d1), AVG(d1) FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
- .dataSource(CalciteTests.DATASOURCE1)
+ .dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(
@@ -4753,7 +4780,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FilteredAggregatorFactory(
new CountAggregatorFactory("a6"),
not(selector("dim2", null, null))
- )
+ ),
+ new DoubleSumAggregatorFactory("a7:sum", "d1"),
+ new CountAggregatorFactory("a7:count")
)
: aggregators(
new CountAggregatorFactory("a0"),
@@ -4766,13 +4795,25 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
not(selector("dim1", null, null))
),
new LongSumAggregatorFactory("a3:sum", "cnt"),
- new CountAggregatorFactory("a3:count"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a3:count"),
+ not(selector("cnt", null, null))
+ ),
new LongSumAggregatorFactory("a4", "cnt"),
new LongMinAggregatorFactory("a5", "cnt"),
new LongMaxAggregatorFactory("a6", "cnt"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a7"),
not(selector("dim2", null, null))
+ ),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a8"),
+ not(selector("d1", null, null))
+ ),
+ new DoubleSumAggregatorFactory("a9:sum", "d1"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a9:count"),
+ not(selector("d1", null, null))
)
)
)
@@ -4785,6 +4826,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FieldAccessPostAggregator(null, useDefault ? "a2:count" : "a3:count")
)
),
+ new ArithmeticPostAggregator(
+ useDefault ? "a7" : "a9",
+ "quotient",
+ ImmutableList.of(
+ new FieldAccessPostAggregator(null, useDefault ? "a7:sum" : "a9:sum"),
+ new FieldAccessPostAggregator(null, useDefault ? "a7:count" : "a9:count")
+ )
+ ),
expressionPostAgg(
"p0",
useDefault ? "((\"a3\" + \"a4\") + \"a5\")" : "((\"a4\" + \"a5\") + \"a6\")"
@@ -4795,10 +4844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
),
NullHandling.replaceWithDefault() ?
ImmutableList.of(
- new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L}
+ new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
) :
ImmutableList.of(
- new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L}
+ new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
)
);
}
@@ -6801,14 +6850,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(aggregators(
- new LongMaxAggregatorFactory("_a0", "a0"),
- new LongMinAggregatorFactory("_a1", "a0"),
- new LongSumAggregatorFactory("_a2:sum", "a0"),
- new CountAggregatorFactory("_a2:count"),
- new LongMaxAggregatorFactory("_a3", "d0"),
- new CountAggregatorFactory("_a4")
- ))
+ .setAggregatorSpecs(
+ useDefault
+ ? aggregators(
+ new LongMaxAggregatorFactory("_a0", "a0"),
+ new LongMinAggregatorFactory("_a1", "a0"),
+ new LongSumAggregatorFactory("_a2:sum", "a0"),
+ new CountAggregatorFactory("_a2:count"),
+ new LongMaxAggregatorFactory("_a3", "d0"),
+ new CountAggregatorFactory("_a4")
+ )
+ : aggregators(
+ new LongMaxAggregatorFactory("_a0", "a0"),
+ new LongMinAggregatorFactory("_a1", "a0"),
+ new LongSumAggregatorFactory("_a2:sum", "a0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("_a2:count"),
+ not(selector("a0", null, null))
+ ),
+ new LongMaxAggregatorFactory("_a3", "d0"),
+ new CountAggregatorFactory("_a4")
+ )
+ )
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
@@ -6872,10 +6935,20 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
- .setAggregatorSpecs(aggregators(
- new LongSumAggregatorFactory("_a0:sum", "a0"),
- new CountAggregatorFactory("_a0:count")
- ))
+ .setAggregatorSpecs(
+ useDefault
+ ? aggregators(
+ new LongSumAggregatorFactory("_a0:sum", "a0"),
+ new CountAggregatorFactory("_a0:count")
+ )
+ : aggregators(
+ new LongSumAggregatorFactory("_a0:sum", "a0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("_a0:count"),
+ not(selector("a0", null, null))
+ )
+ )
+ )
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
@@ -12935,10 +13008,22 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.dimension(new DefaultDimensionSpec("m1", "d0", ValueType.FLOAT))
.filters("dim2", "a")
.aggregators(
- new DoubleSumAggregatorFactory("a0:sum", "m2"),
- new CountAggregatorFactory("a0:count"),
- new DoubleSumAggregatorFactory("a1", "m1"),
- new DoubleSumAggregatorFactory("a2", "m2")
+ useDefault
+ ? aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new CountAggregatorFactory("a0:count"),
+ new DoubleSumAggregatorFactory("a1", "m1"),
+ new DoubleSumAggregatorFactory("a2", "m2")
+ )
+ : aggregators(
+ new DoubleSumAggregatorFactory("a0:sum", "m2"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0:count"),
+ not(selector("m2", null, null))
+ ),
+ new DoubleSumAggregatorFactory("a1", "m1"),
+ new DoubleSumAggregatorFactory("a2", "m2")
+ )
)
.postAggregators(
new ArithmeticPostAggregator(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]