This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new d617f3b6c8 Proper null handling in Aggregation functions for SV data
types (#9086)
d617f3b6c8 is described below
commit d617f3b6c89261164163028789d16697b10a7ff0
Author: nizarhejazi <[email protected]>
AuthorDate: Wed Aug 3 16:10:29 2022 -0700
Proper null handling in Aggregation functions for SV data types (#9086)
Proper null handling for SV data types in the following aggregation
functions:
- AVG
- MIN
- MAX
- COUNT
- SUM
- SUMPRECISION
---
.../request/context/RequestContextUtils.java | 10 +-
.../operator/blocks/IntermediateResultsBlock.java | 86 +++++--
.../query/AggregationGroupByOrderByOperator.java | 4 +-
.../core/operator/query/AggregationOperator.java | 6 +-
.../query/DictionaryBasedDistinctOperator.java | 2 +-
.../core/operator/query/DistinctOperator.java | 2 +-
.../operator/query/FastFilteredCountOperator.java | 2 +-
.../query/FilteredAggregationOperator.java | 7 +-
.../query/NonScanBasedAggregationOperator.java | 2 +-
.../pinot/core/plan/AggregationPlanNode.java | 14 +-
.../apache/pinot/core/plan/DistinctPlanNode.java | 2 +-
.../aggregation/ObjectAggregationResultHolder.java | 2 +-
.../function/AggregationFunctionFactory.java | 12 +-
.../function/AvgAggregationFunction.java | 106 +++++++-
.../function/CountAggregationFunction.java | 65 ++++-
.../function/CountMVAggregationFunction.java | 4 +-
.../function/MaxAggregationFunction.java | 150 ++++++++++-
.../function/MinAggregationFunction.java | 150 ++++++++++-
.../function/SumAggregationFunction.java | 142 ++++++++++-
.../function/SumPrecisionAggregationFunction.java | 213 +++++++++++++++-
.../groupby/ObjectGroupByResultHolder.java | 4 +-
.../query/reduce/AggregationDataTableReducer.java | 47 ++--
.../function/AggregationFunctionFactoryTest.java | 7 +-
.../core/startree/v2/CountStarTreeV2Test.java | 4 +-
.../apache/pinot/queries/AllNullQueriesTest.java | 144 +++++++++++
.../pinot/queries/BigDecimalQueriesTest.java | 51 +---
.../queries/BooleanNullEnabledQueriesTest.java | 38 +--
.../pinot/queries/NullEnabledQueriesTest.java | 277 +++++++++++++++++++--
.../pinot/queries/SumPrecisionQueriesTest.java | 4 +-
.../query/runtime/operator/AggregateOperator.java | 4 +-
.../local/aggregator/CountValueAggregator.java | 6 +-
31 files changed, 1382 insertions(+), 185 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
index 6553a405ed..97668db6a9 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
@@ -81,11 +81,6 @@ public class RequestContextUtils {
*/
public static FunctionContext getFunction(Function thriftFunction) {
String functionName = thriftFunction.getOperator();
- if
(functionName.equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
- // NOTE: COUNT always take one single argument "*"
- return new FunctionContext(FunctionContext.Type.AGGREGATION,
AggregationFunctionType.COUNT.getName(),
- new
ArrayList<>(Collections.singletonList(ExpressionContext.forIdentifier("*"))));
- }
FunctionContext.Type functionType =
AggregationFunctionType.isAggregationFunction(functionName) ?
FunctionContext.Type.AGGREGATION
: FunctionContext.Type.TRANSFORM;
@@ -95,6 +90,11 @@ public class RequestContextUtils {
for (Expression operand : operands) {
arguments.add(getExpression(operand));
}
+ // TODO(walterddr): a work-around for multi-stage query engine which
might pass COUNT without argument, and
+ // should be removed once that issue is fixed.
+ if (arguments.isEmpty() &&
functionName.equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
+ arguments.add(ExpressionContext.forIdentifier("*"));
+ }
return new FunctionContext(functionType, functionName, arguments);
} else {
return new FunctionContext(functionType, functionName,
Collections.emptyList());
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
index 1658e2d198..93ef4abb80 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
@@ -98,19 +98,23 @@ public class IntermediateResultsBlock implements Block {
* <p>For aggregation only, the result is a list of values.
* <p>For aggregation group-by, the result is a list of maps from group keys
to aggregation values.
*/
- public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
List<Object> aggregationResult) {
+ public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
List<Object> aggregationResult,
+ boolean nullHandlingEnabled) {
_aggregationFunctions = aggregationFunctions;
_aggregationResult = aggregationResult;
+ _nullHandlingEnabled = nullHandlingEnabled;
}
/**
* Constructor for aggregation group-by order-by result with {@link
AggregationGroupByResult}.
*/
public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
- @Nullable AggregationGroupByResult aggregationGroupByResults, DataSchema
dataSchema) {
+ @Nullable AggregationGroupByResult aggregationGroupByResults, DataSchema
dataSchema,
+ boolean nullHandlingEnabled) {
_aggregationFunctions = aggregationFunctions;
_aggregationGroupByResult = aggregationGroupByResults;
_dataSchema = dataSchema;
+ _nullHandlingEnabled = nullHandlingEnabled;
}
/**
@@ -118,10 +122,11 @@ public class IntermediateResultsBlock implements Block {
* with a collection of intermediate records.
*/
public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
- Collection<IntermediateRecord> intermediateRecords, DataSchema
dataSchema) {
+ Collection<IntermediateRecord> intermediateRecords, DataSchema
dataSchema, boolean nullHandlingEnabled) {
_aggregationFunctions = aggregationFunctions;
_dataSchema = dataSchema;
_intermediateRecords = intermediateRecords;
+ _nullHandlingEnabled = nullHandlingEnabled;
}
public IntermediateResultsBlock(Table table, boolean nullHandlingEnabled) {
@@ -458,28 +463,75 @@ public class IntermediateResultsBlock implements Block {
columnNames[i] = aggregationFunction.getColumnName();
columnDataTypes[i] =
aggregationFunction.getIntermediateResultColumnType();
}
+ RoaringBitmap[] nullBitmaps = null;
+ Object[] colDefaultNullValues = null;
+ if (_nullHandlingEnabled) {
+ colDefaultNullValues = new Object[numAggregationFunctions];
+ nullBitmaps = new RoaringBitmap[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ if (columnDataTypes[i] != ColumnDataType.OBJECT) {
+ colDefaultNullValues[i] =
NullValueUtils.getDefaultNullValue(columnDataTypes[i].toDataType());
+ }
+ nullBitmaps[i] = new RoaringBitmap();
+ }
+ }
// Build the data table.
DataTableBuilder dataTableBuilder =
DataTableFactory.getDataTableBuilder(new DataSchema(columnNames,
columnDataTypes));
dataTableBuilder.startRow();
- for (int i = 0; i < numAggregationFunctions; i++) {
- switch (columnDataTypes[i]) {
- case LONG:
- dataTableBuilder.setColumn(i, ((Number)
_aggregationResult.get(i)).longValue());
- break;
- case DOUBLE:
- dataTableBuilder.setColumn(i, ((Double)
_aggregationResult.get(i)).doubleValue());
- break;
- case OBJECT:
- dataTableBuilder.setColumn(i, _aggregationResult.get(i));
- break;
- default:
- throw new UnsupportedOperationException(
- "Unsupported aggregation column data type: " +
columnDataTypes[i] + " for column: " + columnNames[i]);
+ if (_nullHandlingEnabled) {
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ Object value = _aggregationResult.get(i);
+ // OBJECT (e.g. DistinctTable) calls toBytes() (e.g.
DistinctTable.toBytes()) which takes care of replacing
+ // nulls with default values, and building presence vector and
serializing both.
+ if (columnDataTypes[i] != ColumnDataType.OBJECT) {
+ if (value == null) {
+ value = colDefaultNullValues[i];
+ nullBitmaps[i].add(0);
+ }
+ }
+
+ switch (columnDataTypes[i]) {
+ case LONG:
+ dataTableBuilder.setColumn(i, ((Number) value).longValue());
+ break;
+ case DOUBLE:
+ dataTableBuilder.setColumn(i, ((Double) value).doubleValue());
+ break;
+ case OBJECT:
+ dataTableBuilder.setColumn(i, value);
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported aggregation column data type: " +
columnDataTypes[i] + " for column: " + columnNames[i]);
+ }
+ }
+ } else {
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ switch (columnDataTypes[i]) {
+ case LONG:
+ dataTableBuilder.setColumn(i, ((Number)
_aggregationResult.get(i)).longValue());
+ break;
+ case DOUBLE:
+ dataTableBuilder.setColumn(i, ((Double)
_aggregationResult.get(i)).doubleValue());
+ break;
+ case OBJECT:
+ dataTableBuilder.setColumn(i, _aggregationResult.get(i));
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported aggregation column data type: " +
columnDataTypes[i] + " for column: " + columnNames[i]);
+ }
}
}
+
dataTableBuilder.finishRow();
+ if (_nullHandlingEnabled) {
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ dataTableBuilder.setNullRowIds(nullBitmaps[i]);
+ }
+ }
DataTable dataTable = dataTableBuilder.build();
return attachMetadataToDataTable(dataTable);
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
index 0206fc89fe..5f9a33b2dc 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
@@ -126,14 +126,14 @@ public class AggregationGroupByOrderByOperator extends
BaseOperator<Intermediate
TableResizer tableResizer = new TableResizer(_dataSchema,
_queryContext);
Collection<IntermediateRecord> intermediateRecords =
groupByExecutor.trimGroupByResult(trimSize, tableResizer);
IntermediateResultsBlock resultsBlock = new IntermediateResultsBlock(
- _aggregationFunctions, intermediateRecords, _dataSchema);
+ _aggregationFunctions, intermediateRecords, _dataSchema,
_queryContext.isNullHandlingEnabled());
resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
return resultsBlock;
}
}
IntermediateResultsBlock resultsBlock = new IntermediateResultsBlock(
- _aggregationFunctions, groupByExecutor.getResult(), _dataSchema);
+ _aggregationFunctions, groupByExecutor.getResult(), _dataSchema,
_queryContext.isNullHandlingEnabled());
resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
return resultsBlock;
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
index c0cd84e2a8..8fa02db6f2 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
@@ -43,15 +43,17 @@ public class AggregationOperator extends
BaseOperator<IntermediateResultsBlock>
private final TransformOperator _transformOperator;
private final long _numTotalDocs;
private final boolean _useStarTree;
+ private final boolean _nullHandlingEnabled;
private int _numDocsScanned = 0;
public AggregationOperator(AggregationFunction[] aggregationFunctions,
TransformOperator transformOperator,
- long numTotalDocs, boolean useStarTree) {
+ long numTotalDocs, boolean useStarTree, boolean nullHandlingEnabled) {
_aggregationFunctions = aggregationFunctions;
_transformOperator = transformOperator;
_numTotalDocs = numTotalDocs;
_useStarTree = useStarTree;
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -70,7 +72,7 @@ public class AggregationOperator extends
BaseOperator<IntermediateResultsBlock>
}
// Build intermediate result block based on aggregation result from the
executor
- return new IntermediateResultsBlock(_aggregationFunctions,
aggregationExecutor.getResult());
+ return new IntermediateResultsBlock(_aggregationFunctions,
aggregationExecutor.getResult(), _nullHandlingEnabled);
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedDistinctOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedDistinctOperator.java
index 36ef028080..5fef677cbd 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedDistinctOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedDistinctOperator.java
@@ -77,7 +77,7 @@ public class DictionaryBasedDistinctOperator extends
BaseOperator<IntermediateRe
protected IntermediateResultsBlock getNextBlock() {
DistinctTable distinctTable = buildResult();
return new IntermediateResultsBlock(new
AggregationFunction[]{_distinctAggregationFunction},
- Collections.singletonList(distinctTable));
+ Collections.singletonList(distinctTable), _nullHandlingEnabled);
}
/**
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DistinctOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DistinctOperator.java
index e559e50b6f..1e31f8eb13 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DistinctOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DistinctOperator.java
@@ -71,7 +71,7 @@ public class DistinctOperator extends
BaseOperator<IntermediateResultsBlock> {
DistinctTable distinctTable = _distinctExecutor.getResult();
// TODO: Use a separate way to represent DISTINCT instead of aggregation.
return new IntermediateResultsBlock(new
AggregationFunction[]{_distinctAggregationFunction},
- Collections.singletonList(distinctTable));
+ Collections.singletonList(distinctTable),
_queryContext.isNullHandlingEnabled());
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FastFilteredCountOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FastFilteredCountOperator.java
index d6c12d0d6a..01b241c20f 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FastFilteredCountOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FastFilteredCountOperator.java
@@ -65,7 +65,7 @@ public class FastFilteredCountOperator extends
BaseOperator<IntermediateResultsB
List<Object> aggregates = new ArrayList<>(1);
aggregates.add(count);
_docsCounted += count;
- return new IntermediateResultsBlock(_aggregationFunctions, aggregates);
+ return new IntermediateResultsBlock(_aggregationFunctions, aggregates,
false);
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
index 06aec147c6..c7e075db8b 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
@@ -47,6 +47,7 @@ public class FilteredAggregationOperator extends
BaseOperator<IntermediateResult
private final AggregationFunction[] _aggregationFunctions;
private final List<Pair<AggregationFunction[], TransformOperator>>
_aggFunctionsWithTransformOperator;
private final long _numTotalDocs;
+ private final boolean _nullHandlingEnabled;
private long _numDocsScanned;
private long _numEntriesScannedInFilter;
@@ -55,10 +56,12 @@ public class FilteredAggregationOperator extends
BaseOperator<IntermediateResult
// We can potentially do away with aggregationFunctions parameter, but it's
cleaner to pass it in than to construct
// it from aggFunctionsWithTransformOperator
public FilteredAggregationOperator(AggregationFunction[]
aggregationFunctions,
- List<Pair<AggregationFunction[], TransformOperator>>
aggFunctionsWithTransformOperator, long numTotalDocs) {
+ List<Pair<AggregationFunction[], TransformOperator>>
aggFunctionsWithTransformOperator, long numTotalDocs,
+ boolean nullHandlingEnabled) {
_aggregationFunctions = aggregationFunctions;
_aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator;
_numTotalDocs = numTotalDocs;
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -89,7 +92,7 @@ public class FilteredAggregationOperator extends
BaseOperator<IntermediateResult
_numEntriesScannedInFilter +=
transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter();
_numEntriesScannedPostFilter += (long) numDocsScanned *
transformOperator.getNumColumnsProjected();
}
- return new IntermediateResultsBlock(_aggregationFunctions,
Arrays.asList(result));
+ return new IntermediateResultsBlock(_aggregationFunctions,
Arrays.asList(result), _nullHandlingEnabled);
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
index 138e1e43c6..b045c51a48 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
@@ -128,7 +128,7 @@ public class NonScanBasedAggregationOperator extends
BaseOperator<IntermediateRe
}
// Build intermediate result block based on aggregation result from the
executor.
- return new IntermediateResultsBlock(_aggregationFunctions,
aggregationResults);
+ return new IntermediateResultsBlock(_aggregationFunctions,
aggregationResults, false);
}
private static Double getMinValue(DataSource dataSource) {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index 74ea6af8de..b706d38b54 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -141,7 +141,8 @@ public class AggregationPlanNode implements PlanNode {
aggToTransformOpList.add(
Pair.of(nonFilteredAggregationFunctions.toArray(new
AggregationFunction[0]), mainTransformOperator));
- return new
FilteredAggregationOperator(_queryContext.getAggregationFunctions(),
aggToTransformOpList, numTotalDocs);
+ return new
FilteredAggregationOperator(_queryContext.getAggregationFunctions(),
aggToTransformOpList, numTotalDocs,
+ _queryContext.isNullHandlingEnabled());
}
/**
@@ -178,11 +179,11 @@ public class AggregationPlanNode implements PlanNode {
FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment,
_queryContext);
BaseFilterOperator filterOperator = filterPlanNode.run();
- if (canOptimizeFilteredCount(filterOperator, aggregationFunctions)) {
+ if (canOptimizeFilteredCount(filterOperator, aggregationFunctions) &&
!_queryContext.isNullHandlingEnabled()) {
return new FastFilteredCountOperator(aggregationFunctions,
filterOperator, _indexSegment.getSegmentMetadata());
}
- if (filterOperator.isResultMatchingAll()) {
+ if (filterOperator.isResultMatchingAll() &&
!_queryContext.isNullHandlingEnabled()) {
if (isFitForNonScanBasedPlan(aggregationFunctions, _indexSegment)) {
DataSource[] dataSources = new DataSource[aggregationFunctions.length];
for (int i = 0; i < aggregationFunctions.length; i++) {
@@ -198,7 +199,7 @@ public class AggregationPlanNode implements PlanNode {
// Use star-tree to solve the query if possible
List<StarTreeV2> starTrees = _indexSegment.getStarTrees();
- if (starTrees != null && !_queryContext.isSkipStarTree()) {
+ if (starTrees != null && !_queryContext.isSkipStarTree() &&
!_queryContext.isNullHandlingEnabled()) {
AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
StarTreeUtils.extractAggregationFunctionPairs(aggregationFunctions);
if (aggregationFunctionColumnPairs != null) {
@@ -212,7 +213,7 @@ public class AggregationPlanNode implements PlanNode {
TransformOperator transformOperator =
new StarTreeTransformPlanNode(_queryContext, starTreeV2,
aggregationFunctionColumnPairs, null,
predicateEvaluatorsMap).run();
- return new AggregationOperator(aggregationFunctions,
transformOperator, numTotalDocs, true);
+ return new AggregationOperator(aggregationFunctions,
transformOperator, numTotalDocs, true, false);
}
}
}
@@ -224,7 +225,8 @@ public class AggregationPlanNode implements PlanNode {
TransformOperator transformOperator =
new TransformPlanNode(_indexSegment, _queryContext,
expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL,
filterOperator).run();
- return new AggregationOperator(aggregationFunctions, transformOperator,
numTotalDocs, false);
+ return new AggregationOperator(aggregationFunctions, transformOperator,
numTotalDocs, false,
+ _queryContext.isNullHandlingEnabled());
}
/**
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/DistinctPlanNode.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/DistinctPlanNode.java
index 8bb67dc0fb..03f6cd7184 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/DistinctPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/DistinctPlanNode.java
@@ -70,7 +70,7 @@ public class DistinctPlanNode implements PlanNode {
// DictionaryBasedDistinctOperator can be reused since it is more
efficient than DistinctOperator for
// dictionary-encoded columns.
NullValueVectorReader nullValueReader =
dataSource.getNullValueVector();
- if (nullValueReader == null ||
nullValueReader.getNullBitmap().getCardinality() == 0) {
+ if (nullValueReader == null ||
nullValueReader.getNullBitmap().isEmpty()) {
return new
DictionaryBasedDistinctOperator(dataSourceMetadata.getDataType(),
distinctAggregationFunction,
dictionary, dataSourceMetadata.getNumDocs(),
_queryContext.isNullHandlingEnabled());
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/ObjectAggregationResultHolder.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/ObjectAggregationResultHolder.java
index d01be84ded..fdca760d13 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/ObjectAggregationResultHolder.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/ObjectAggregationResultHolder.java
@@ -30,7 +30,7 @@ public class ObjectAggregationResultHolder implements
AggregationResultHolder {
*/
@Override
public void setValue(double value) {
- throw new RuntimeException("Method 'setValue' (with double value) not
supported for class " + getClass().getName());
+ _value = value;
}
/**
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index f1d3de353f..7ae11d3857 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -147,17 +147,17 @@ public class AggregationFunctionFactory {
} else {
switch (AggregationFunctionType.valueOf(upperCaseFunctionName)) {
case COUNT:
- return new CountAggregationFunction();
+ return new CountAggregationFunction(firstArgument,
queryContext.isNullHandlingEnabled());
case MIN:
- return new MinAggregationFunction(firstArgument);
+ return new MinAggregationFunction(firstArgument,
queryContext.isNullHandlingEnabled());
case MAX:
- return new MaxAggregationFunction(firstArgument);
+ return new MaxAggregationFunction(firstArgument,
queryContext.isNullHandlingEnabled());
case SUM:
- return new SumAggregationFunction(firstArgument);
+ return new SumAggregationFunction(firstArgument,
queryContext.isNullHandlingEnabled());
case SUMPRECISION:
- return new SumPrecisionAggregationFunction(arguments);
+ return new SumPrecisionAggregationFunction(arguments,
queryContext.isNullHandlingEnabled());
case AVG:
- return new AvgAggregationFunction(firstArgument);
+ return new AvgAggregationFunction(firstArgument,
queryContext.isNullHandlingEnabled());
case MODE:
return new ModeAggregationFunction(arguments);
case FIRSTWITHTIME:
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
index 8a2af50f4f..5226176267 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
@@ -30,13 +30,20 @@ import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
import org.apache.pinot.segment.local.customobject.AvgPair;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.roaringbitmap.RoaringBitmap;
public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<AvgPair, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
+ private final boolean _nullHandlingEnabled;
public AvgAggregationFunction(ExpressionContext expression) {
+ this(expression, false);
+ }
+
+ public AvgAggregationFunction(ExpressionContext expression, boolean
nullHandlingEnabled) {
super(expression);
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -58,6 +65,13 @@ public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<A
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateNullHandlingEnabled(length, aggregationResultHolder,
blockValSet, nullBitmap);
+ return;
+ }
+ }
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
@@ -80,6 +94,41 @@ public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<A
}
}
+ private void aggregateNullHandlingEnabled(int length,
AggregationResultHolder aggregationResultHolder,
+ BlockValSet blockValSet, RoaringBitmap nullBitmap) {
+ if (blockValSet.getValueType() != DataType.BYTES) {
+ double[] doubleValues = blockValSet.getDoubleValuesSV();
+ if (nullBitmap.getCardinality() < length) {
+ double sum = 0.0;
+      // TODO: need to update the for-loop terminating condition to: i <
length && i < doubleValues.length?
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum += doubleValues[i];
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum, length);
+ }
+      // Note: when all input values are null (nullBitmap.getCardinality() ==
values.length), avg is null. As a result,
+ // we don't call setAggregationResult.
+ } else {
+ // Serialized AvgPair
+ byte[][] bytesValues = blockValSet.getBytesValuesSV();
+ if (nullBitmap.getCardinality() < length) {
+ double sum = 0.0;
+ long count = 0L;
+      // TODO: need to update the for-loop terminating condition to: i <
length && i < bytesValues.length?
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ AvgPair value =
ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
+ sum += value.getSum();
+ count += value.getCount();
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum, count);
+ }
+ }
+ }
+
protected void setAggregationResult(AggregationResultHolder
aggregationResultHolder, double sum, long count) {
AvgPair avgPair = aggregationResultHolder.getResult();
if (avgPair == null) {
@@ -93,6 +142,13 @@ public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<A
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray,
groupByResultHolder, blockValSet, nullBitmap);
+ return;
+ }
+ }
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
@@ -109,6 +165,35 @@ public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<A
}
}
+ private void aggregateGroupBySVNullHandlingEnabled(int length, int[]
groupKeyArray,
+ GroupByResultHolder groupByResultHolder, BlockValSet blockValSet,
RoaringBitmap nullBitmap) {
+ if (blockValSet.getValueType() != DataType.BYTES) {
+ double[] doubleValues = blockValSet.getDoubleValuesSV();
+      // TODO: need to update the for-loop terminating condition to: i <
length && i < valueArray.length?
+ if (nullBitmap.getCardinality() < length) {
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ int groupKey = groupKeyArray[i];
+ setGroupByResult(groupKey, groupByResultHolder, doubleValues[i],
1L);
+ }
+ }
+ }
+ } else {
+ // Serialized AvgPair
+ byte[][] bytesValues = blockValSet.getBytesValuesSV();
+      // TODO: need to update the for-loop terminating condition to: i <
length && i < valueArray.length?
+ if (nullBitmap.getCardinality() < length) {
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ int groupKey = groupKeyArray[i];
+ AvgPair avgPair =
ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
+ setGroupByResult(groupKey, groupByResultHolder, avgPair.getSum(),
avgPair.getCount());
+ }
+ }
+ }
+ }
+ }
+
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
@@ -149,24 +234,30 @@ public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<A
public AvgPair extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
AvgPair avgPair = aggregationResultHolder.getResult();
if (avgPair == null) {
- return new AvgPair(0.0, 0L);
- } else {
- return avgPair;
+ return _nullHandlingEnabled ? null : new AvgPair(0.0, 0L);
}
+ return avgPair;
}
@Override
public AvgPair extractGroupByResult(GroupByResultHolder groupByResultHolder,
int groupKey) {
AvgPair avgPair = groupByResultHolder.getResult(groupKey);
if (avgPair == null) {
- return new AvgPair(0.0, 0L);
- } else {
- return avgPair;
+ return _nullHandlingEnabled ? null : new AvgPair(0.0, 0L);
}
+ return avgPair;
}
@Override
public AvgPair merge(AvgPair intermediateResult1, AvgPair
intermediateResult2) {
+ if (_nullHandlingEnabled) {
+ if (intermediateResult1 == null) {
+ return intermediateResult2;
+ }
+ if (intermediateResult2 == null) {
+ return intermediateResult1;
+ }
+ }
intermediateResult1.apply(intermediateResult2);
return intermediateResult1;
}
@@ -183,6 +274,9 @@ public class AvgAggregationFunction extends
BaseSingleInputAggregationFunction<A
@Override
public Double extractFinalResult(AvgPair intermediateResult) {
+ if (intermediateResult == null) {
+ return null;
+ }
long count = intermediateResult.getCount();
if (count == 0L) {
return DEFAULT_FINAL_RESULT;
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
index f81c232695..53f36f4703 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
@@ -30,16 +30,29 @@ import
org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import
org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
+import org.roaringbitmap.RoaringBitmap;
-public class CountAggregationFunction implements AggregationFunction<Long,
Long> {
- private static final String COLUMN_NAME = "count_star";
- private static final String RESULT_COLUMN_NAME = "count(*)";
+public class CountAggregationFunction extends
BaseSingleInputAggregationFunction<Long, Long> {
+ private static final String COUNT_STAR_COLUMN_NAME = "count_star";
+ private static final String COUNT_STAR_RESULT_COLUMN_NAME = "count(*)";
private static final double DEFAULT_INITIAL_VALUE = 0.0;
// Special expression used by star-tree to pass in BlockValSet
private static final ExpressionContext STAR_TREE_COUNT_STAR_EXPRESSION =
ExpressionContext.forIdentifier(AggregationFunctionColumnPair.STAR);
+ private final boolean _nullHandlingEnabled;
+
+ public CountAggregationFunction(ExpressionContext expression) {
+ this(expression, false);
+ }
+
+ public CountAggregationFunction(ExpressionContext expression, boolean
nullHandlingEnabled) {
+ super(expression);
+ // Consider null values only when null handling is enabled and function is
not COUNT(*)
+ _nullHandlingEnabled = nullHandlingEnabled &&
!expression.getIdentifier().equals("*");
+ }
+
@Override
public AggregationFunctionType getType() {
return AggregationFunctionType.COUNT;
@@ -47,17 +60,17 @@ public class CountAggregationFunction implements
AggregationFunction<Long, Long>
@Override
public String getColumnName() {
- return COLUMN_NAME;
+ return _nullHandlingEnabled ? super.getColumnName() :
COUNT_STAR_COLUMN_NAME;
}
@Override
public String getResultColumnName() {
- return RESULT_COLUMN_NAME;
+ return _nullHandlingEnabled ? super.getResultColumnName() :
COUNT_STAR_RESULT_COLUMN_NAME;
}
@Override
public List<ExpressionContext> getInputExpressions() {
- return Collections.emptyList();
+ return _nullHandlingEnabled ? super.getInputExpressions() :
Collections.emptyList();
}
@Override
@@ -75,6 +88,16 @@ public class CountAggregationFunction implements
AggregationFunction<Long, Long>
Map<ExpressionContext, BlockValSet> blockValSetMap) {
if (blockValSetMap.isEmpty()) {
aggregationResultHolder.setValue(aggregationResultHolder.getDoubleResult() +
length);
+ } else if (_nullHandlingEnabled) {
+ assert blockValSetMap.size() == 1;
+ BlockValSet blockValSet = blockValSetMap.values().iterator().next();
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ int numNulls = 0;
+ if (nullBitmap != null) {
+ numNulls = nullBitmap.getCardinality();
+ }
+ assert numNulls <= length;
+
aggregationResultHolder.setValue(aggregationResultHolder.getDoubleResult() +
(length - numNulls));
} else {
// Star-tree pre-aggregated values
long[] valueArray =
blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
@@ -94,6 +117,34 @@ public class CountAggregationFunction implements
AggregationFunction<Long, Long>
int groupKey = groupKeyArray[i];
groupByResultHolder.setValueForKey(groupKey,
groupByResultHolder.getDoubleResult(groupKey) + 1);
}
+ } else if (_nullHandlingEnabled) {
+ // In Presto, null values are not counted:
+ // SELECT count(id) as count, key FROM (VALUES (null, 1), (null, 1),
(null, 2), (1, 3), (null, 3)) AS t(id, key)
+ // GROUP BY key ORDER BY key DESC;
+ // count | key
+ //-------+-----
+ // 1 | 3
+ // 0 | 2
+ // 0 | 1
+ assert blockValSetMap.size() == 1;
+ BlockValSet blockValSet = blockValSetMap.values().iterator().next();
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ if (nullBitmap.getCardinality() == length) {
+ return;
+ }
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ int groupKey = groupKeyArray[i];
+ groupByResultHolder.setValueForKey(groupKey,
groupByResultHolder.getDoubleResult(groupKey) + 1);
+ }
+ }
+ } else {
+ for (int i = 0; i < length; i++) {
+ int groupKey = groupKeyArray[i];
+ groupByResultHolder.setValueForKey(groupKey,
groupByResultHolder.getDoubleResult(groupKey) + 1);
+ }
+ }
} else {
// Star-tree pre-aggregated values
long[] valueArray =
blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
@@ -107,7 +158,7 @@ public class CountAggregationFunction implements
AggregationFunction<Long, Long>
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- if (blockValSetMap.isEmpty()) {
+ if (blockValSetMap.isEmpty() ||
!blockValSetMap.containsKey(STAR_TREE_COUNT_STAR_EXPRESSION)) {
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
groupByResultHolder.setValueForKey(groupKey,
groupByResultHolder.getDoubleResult(groupKey) + 1);
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
index df94b26054..061111daaf 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
@@ -29,7 +29,6 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
public class CountMVAggregationFunction extends CountAggregationFunction {
- private final ExpressionContext _expression;
/**
* Constructor for the class.
@@ -37,7 +36,8 @@ public class CountMVAggregationFunction extends
CountAggregationFunction {
* @param expression Expression to aggregate on.
*/
public CountMVAggregationFunction(ExpressionContext expression) {
- _expression = expression;
+ // TODO(nhejazi): support proper null handling for aggregation functions
on MV columns.
+ super(expression);
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
index f6bc1fffcf..2d0bac2803 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
@@ -25,16 +25,25 @@ import
org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import
org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.roaringbitmap.RoaringBitmap;
public class MaxAggregationFunction extends
BaseSingleInputAggregationFunction<Double, Double> {
private static final double DEFAULT_INITIAL_VALUE = Double.NEGATIVE_INFINITY;
+ private final boolean _nullHandlingEnabled;
public MaxAggregationFunction(ExpressionContext expression) {
+ this(expression, false);
+ }
+
+ public MaxAggregationFunction(ExpressionContext expression, boolean
nullHandlingEnabled) {
super(expression);
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -44,11 +53,17 @@ public class MaxAggregationFunction extends
BaseSingleInputAggregationFunction<D
@Override
public AggregationResultHolder createAggregationResultHolder() {
+ if (_nullHandlingEnabled) {
+ return new ObjectAggregationResultHolder();
+ }
return new DoubleAggregationResultHolder(DEFAULT_INITIAL_VALUE);
}
@Override
public GroupByResultHolder createGroupByResultHolder(int initialCapacity,
int maxCapacity) {
+ if (_nullHandlingEnabled) {
+ return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+ }
return new DoubleGroupByResultHolder(initialCapacity, maxCapacity,
DEFAULT_INITIAL_VALUE);
}
@@ -56,6 +71,14 @@ public class MaxAggregationFunction extends
BaseSingleInputAggregationFunction<D
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateNullHandlingEnabled(length, aggregationResultHolder,
blockValSet, nullBitmap);
+ return;
+ }
+ }
+
switch (blockValSet.getValueType().getStoredType()) {
case INT: {
int[] values = blockValSet.getIntValuesSV();
@@ -108,10 +131,111 @@ public class MaxAggregationFunction extends
BaseSingleInputAggregationFunction<D
}
}
+ private void aggregateNullHandlingEnabled(int length,
AggregationResultHolder aggregationResultHolder,
+ BlockValSet blockValSet, RoaringBitmap nullBitmap) {
+ switch (blockValSet.getValueType().getStoredType()) {
+ case INT: {
+ if (nullBitmap.getCardinality() < length) {
+ int[] values = blockValSet.getIntValuesSV();
+ int max = Integer.MIN_VALUE;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ max = Math.max(values[i], max);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, max);
+ }
+ // Note: when all input values are null (nullBitmap.getCardinality() ==
values.length), max is null. As a result,
+ // we don't update the value of aggregationResultHolder.
+ break;
+ }
+ case LONG: {
+ if (nullBitmap.getCardinality() < length) {
+ long[] values = blockValSet.getLongValuesSV();
+ long max = Long.MIN_VALUE;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ max = Math.max(values[i], max);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, max);
+ }
+ break;
+ }
+ case FLOAT: {
+ if (nullBitmap.getCardinality() < length) {
+ float[] values = blockValSet.getFloatValuesSV();
+ float max = Float.NEGATIVE_INFINITY;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ max = Math.max(values[i], max);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, max);
+ }
+ break;
+ }
+ case DOUBLE: {
+ if (nullBitmap.getCardinality() < length) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ double max = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ max = Math.max(values[i], max);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, max);
+ }
+ break;
+ }
+ case BIG_DECIMAL: {
+ if (nullBitmap.getCardinality() < length) {
+ BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
+ BigDecimal max = null;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ max = max == null ? values[i] : values[i].max(max);
+ }
+ }
+ // TODO: even though the source data has BIG_DECIMAL type, we still
only support double precision.
+ assert max != null;
+ updateAggregationResultHolder(aggregationResultHolder,
max.doubleValue());
+ }
+ break;
+ }
+ default:
+ throw new IllegalStateException("Cannot compute max for non-numeric
type: " + blockValSet.getValueType());
+ }
+ }
+
+ private void updateAggregationResultHolder(AggregationResultHolder
aggregationResultHolder, double max) {
+ Double otherMax = aggregationResultHolder.getResult();
+ aggregationResultHolder.setValue(otherMax == null ? max : Math.max(max,
otherMax));
+ }
+
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ if (nullBitmap.getCardinality() < length) {
+ double[] valueArray = blockValSet.getDoubleValuesSV();
+ for (int i = 0; i < length; i++) {
+ double value = valueArray[i];
+ int groupKey = groupKeyArray[i];
+ Double result = groupByResultHolder.getResult(groupKey);
+ if (!nullBitmap.contains(i) && (result == null || value > result))
{
+ groupByResultHolder.setValueForKey(groupKey, value);
+ }
+ }
+ }
+ return;
+ }
+ }
+
+ double[] valueArray = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
int groupKey = groupKeyArray[i];
@@ -137,21 +261,35 @@ public class MaxAggregationFunction extends
BaseSingleInputAggregationFunction<D
@Override
public Double extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ if (_nullHandlingEnabled) {
+ return aggregationResultHolder.getResult();
+ }
return aggregationResultHolder.getDoubleResult();
}
@Override
public Double extractGroupByResult(GroupByResultHolder groupByResultHolder,
int groupKey) {
+ if (_nullHandlingEnabled) {
+ return groupByResultHolder.getResult(groupKey);
+ }
return groupByResultHolder.getDoubleResult(groupKey);
}
@Override
- public Double merge(Double intermediateResult1, Double intermediateResult2) {
- if (intermediateResult1 > intermediateResult2) {
- return intermediateResult1;
- } else {
- return intermediateResult2;
+ public Double merge(Double intermediateMaxResult1, Double
intermediateMaxResult2) {
+ if (_nullHandlingEnabled) {
+ if (intermediateMaxResult1 == null) {
+ return intermediateMaxResult2;
+ }
+ if (intermediateMaxResult2 == null) {
+ return intermediateMaxResult1;
+ }
+ }
+
+ if (intermediateMaxResult1 > intermediateMaxResult2) {
+ return intermediateMaxResult1;
}
+ return intermediateMaxResult2;
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
index ab04f2ba80..2e80387cd2 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
@@ -25,16 +25,25 @@ import
org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import
org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.roaringbitmap.RoaringBitmap;
public class MinAggregationFunction extends
BaseSingleInputAggregationFunction<Double, Double> {
private static final double DEFAULT_VALUE = Double.POSITIVE_INFINITY;
+ private final boolean _nullHandlingEnabled;
public MinAggregationFunction(ExpressionContext expression) {
+ this(expression, false);
+ }
+
+ public MinAggregationFunction(ExpressionContext expression, boolean
nullHandlingEnabled) {
super(expression);
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -44,11 +53,17 @@ public class MinAggregationFunction extends
BaseSingleInputAggregationFunction<D
@Override
public AggregationResultHolder createAggregationResultHolder() {
+ if (_nullHandlingEnabled) {
+ return new ObjectAggregationResultHolder();
+ }
return new DoubleAggregationResultHolder(DEFAULT_VALUE);
}
@Override
public GroupByResultHolder createGroupByResultHolder(int initialCapacity,
int maxCapacity) {
+ if (_nullHandlingEnabled) {
+ return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+ }
return new DoubleGroupByResultHolder(initialCapacity, maxCapacity,
DEFAULT_VALUE);
}
@@ -56,6 +71,14 @@ public class MinAggregationFunction extends
BaseSingleInputAggregationFunction<D
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateNullHandlingEnabled(length, aggregationResultHolder,
blockValSet, nullBitmap);
+ return;
+ }
+ }
+
switch (blockValSet.getValueType().getStoredType()) {
case INT: {
int[] values = blockValSet.getIntValuesSV();
@@ -108,10 +131,111 @@ public class MinAggregationFunction extends
BaseSingleInputAggregationFunction<D
}
}
+ private void aggregateNullHandlingEnabled(int length,
AggregationResultHolder aggregationResultHolder,
+ BlockValSet blockValSet, RoaringBitmap nullBitmap) {
+ switch (blockValSet.getValueType().getStoredType()) {
+ case INT: {
+ if (nullBitmap.getCardinality() < length) {
+ int[] values = blockValSet.getIntValuesSV();
+ int min = Integer.MAX_VALUE;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ min = Math.min(values[i], min);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, min);
+ }
+ // Note: when all input values are null (nullBitmap.getCardinality() ==
values.length), min is null. As a result,
+ // we don't update the value of aggregationResultHolder.
+ break;
+ }
+ case LONG: {
+ if (nullBitmap.getCardinality() < length) {
+ long[] values = blockValSet.getLongValuesSV();
+ long min = Long.MAX_VALUE;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ min = Math.min(values[i], min);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, min);
+ }
+ break;
+ }
+ case FLOAT: {
+ if (nullBitmap.getCardinality() < length) {
+ float[] values = blockValSet.getFloatValuesSV();
+ float min = Float.POSITIVE_INFINITY;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ min = Math.min(values[i], min);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, min);
+ }
+ break;
+ }
+ case DOUBLE: {
+ if (nullBitmap.getCardinality() < length) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ double min = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ min = Math.min(values[i], min);
+ }
+ }
+ updateAggregationResultHolder(aggregationResultHolder, min);
+ }
+ break;
+ }
+ case BIG_DECIMAL: {
+ if (nullBitmap.getCardinality() < length) {
+ BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
+ BigDecimal min = null;
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ min = min == null ? values[i] : values[i].min(min);
+ }
+ }
+ assert min != null;
+ // TODO: even though the source data has BIG_DECIMAL type, we still
only support double precision.
+ updateAggregationResultHolder(aggregationResultHolder,
min.doubleValue());
+ }
+ break;
+ }
+ default:
+ throw new IllegalStateException("Cannot compute min for non-numeric
type: " + blockValSet.getValueType());
+ }
+ }
+
+ private void updateAggregationResultHolder(AggregationResultHolder
aggregationResultHolder, double min) {
+ Double otherMin = aggregationResultHolder.getResult();
+ aggregationResultHolder.setValue(otherMin == null ? min : Math.min(min,
otherMin));
+ }
+
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ if (nullBitmap.getCardinality() < length) {
+ double[] valueArray = blockValSet.getDoubleValuesSV();
+ for (int i = 0; i < length; i++) {
+ double value = valueArray[i];
+ int groupKey = groupKeyArray[i];
+ Double result = groupByResultHolder.getResult(groupKey);
+ if (!nullBitmap.contains(i) && (result == null || value < result))
{
+ groupByResultHolder.setValueForKey(groupKey, value);
+ }
+ }
+ }
+ return;
+ }
+ }
+
+ double[] valueArray = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
int groupKey = groupKeyArray[i];
@@ -137,21 +261,35 @@ public class MinAggregationFunction extends
BaseSingleInputAggregationFunction<D
@Override
public Double extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ if (_nullHandlingEnabled) {
+ return aggregationResultHolder.getResult();
+ }
return aggregationResultHolder.getDoubleResult();
}
@Override
public Double extractGroupByResult(GroupByResultHolder groupByResultHolder,
int groupKey) {
+ if (_nullHandlingEnabled) {
+ return groupByResultHolder.getResult(groupKey);
+ }
return groupByResultHolder.getDoubleResult(groupKey);
}
@Override
- public Double merge(Double intermediateResult1, Double intermediateResult2) {
- if (intermediateResult1 < intermediateResult2) {
- return intermediateResult1;
- } else {
- return intermediateResult2;
+ public Double merge(Double intermediateMinResult1, Double
intermediateMinResult2) {
+ if (_nullHandlingEnabled) {
+ if (intermediateMinResult1 == null) {
+ return intermediateMinResult2;
+ }
+ if (intermediateMinResult2 == null) {
+ return intermediateMinResult1;
+ }
+ }
+
+ if (intermediateMinResult1 < intermediateMinResult2) {
+ return intermediateMinResult1;
}
+ return intermediateMinResult2;
}
@Override
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
index 0ce91986c4..8456eda436 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
@@ -25,16 +25,25 @@ import
org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import
org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.roaringbitmap.RoaringBitmap;
public class SumAggregationFunction extends
BaseSingleInputAggregationFunction<Double, Double> {
private static final double DEFAULT_VALUE = 0.0;
+ private final boolean _nullHandlingEnabled;
public SumAggregationFunction(ExpressionContext expression) {
+ this(expression, false);
+ }
+
+ public SumAggregationFunction(ExpressionContext expression, boolean
nullHandlingEnabled) {
super(expression);
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -44,19 +53,33 @@ public class SumAggregationFunction extends
BaseSingleInputAggregationFunction<D
@Override
public AggregationResultHolder createAggregationResultHolder() {
+ if (_nullHandlingEnabled) {
+ return new ObjectAggregationResultHolder();
+ }
return new DoubleAggregationResultHolder(DEFAULT_VALUE);
}
@Override
public GroupByResultHolder createGroupByResultHolder(int initialCapacity,
int maxCapacity) {
+ if (_nullHandlingEnabled) {
+ return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+ }
return new DoubleGroupByResultHolder(initialCapacity, maxCapacity,
DEFAULT_VALUE);
}
@Override
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- double sum = aggregationResultHolder.getDoubleResult();
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateNullHandlingEnabled(length, aggregationResultHolder,
blockValSet, nullBitmap);
+ return;
+ }
+ }
+
+ double sum = aggregationResultHolder.getDoubleResult();
switch (blockValSet.getValueType().getStoredType()) {
case INT: {
int[] values = blockValSet.getIntValuesSV();
@@ -102,10 +125,111 @@ public class SumAggregationFunction extends
BaseSingleInputAggregationFunction<D
aggregationResultHolder.setValue(sum);
}
+ private void aggregateNullHandlingEnabled(int length,
AggregationResultHolder aggregationResultHolder,
+ BlockValSet blockValSet, RoaringBitmap nullBitmap) {
+ double sum = 0;
+ switch (blockValSet.getValueType().getStoredType()) {
+ case INT: {
+ if (nullBitmap.getCardinality() < length) {
+ int[] values = blockValSet.getIntValuesSV();
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum += values[i];
+ }
+ }
+ setAggregationResultHolder(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case LONG: {
+ if (nullBitmap.getCardinality() < length) {
+ long[] values = blockValSet.getLongValuesSV();
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum += values[i];
+ }
+ }
+ setAggregationResultHolder(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case FLOAT: {
+ if (nullBitmap.getCardinality() < length) {
+ float[] values = blockValSet.getFloatValuesSV();
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum += values[i];
+ }
+ }
+ setAggregationResultHolder(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case DOUBLE: {
+ if (nullBitmap.getCardinality() < length) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum += values[i];
+ }
+ }
+ setAggregationResultHolder(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case BIG_DECIMAL: {
+ if (nullBitmap.getCardinality() < length) {
+ BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
+ BigDecimal decimalSum = BigDecimal.valueOf(sum);
+ for (int i = 0; i < length & i < values.length; i++) {
+ if (!nullBitmap.contains(i)) {
+ decimalSum = decimalSum.add(values[i]);
+ }
+ }
+ // TODO: even though the source data has BIG_DECIMAL type, we still
only support double precision.
+ setAggregationResultHolder(aggregationResultHolder,
decimalSum.doubleValue());
+ }
+ break;
+ }
+ default:
+ throw new IllegalStateException("Cannot compute sum for non-numeric
type: " + blockValSet.getValueType());
+ }
+ }
+
+ private void setAggregationResultHolder(AggregationResultHolder
aggregationResultHolder, double sum) {
+ Double otherSum = aggregationResultHolder.getResult();
+ aggregationResultHolder.setValue(otherSum == null ? sum : sum + otherSum);
+ }
+
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ if (nullBitmap.getCardinality() < length) {
+ double[] valueArray = blockValSet.getDoubleValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ int groupKey = groupKeyArray[i];
+ Double result = groupByResultHolder.getResult(groupKey);
+ groupByResultHolder.setValueForKey(groupKey, result == null ?
valueArray[i] : result + valueArray[i]);
+ // In Presto:
+ // SELECT sum (cast(id AS DOUBLE)) as sum, min(id) as min,
max(id) as max, key FROM (VALUES (null, 1),
+ // (null, 2)) AS t(id, key) GROUP BY key ORDER BY max DESC;
+ // sum | min | max | key
+ //------+------+------+-----
+ // NULL | NULL | NULL | 2
+ // NULL | NULL | NULL | 1
+ }
+ }
+ }
+ return;
+ }
+ }
+
+ double[] valueArray = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
groupByResultHolder.setValueForKey(groupKey,
groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -126,16 +250,30 @@ public class SumAggregationFunction extends
BaseSingleInputAggregationFunction<D
@Override
public Double extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ if (_nullHandlingEnabled) {
+ return aggregationResultHolder.getResult();
+ }
return aggregationResultHolder.getDoubleResult();
}
@Override
public Double extractGroupByResult(GroupByResultHolder groupByResultHolder,
int groupKey) {
+ if (_nullHandlingEnabled) {
+ return groupByResultHolder.getResult(groupKey);
+ }
return groupByResultHolder.getDoubleResult(groupKey);
}
@Override
public Double merge(Double intermediateResult1, Double intermediateResult2) {
+ if (_nullHandlingEnabled) {
+ if (intermediateResult1 == null) {
+ return intermediateResult2;
+ }
+ if (intermediateResult2 == null) {
+ return intermediateResult1;
+ }
+ }
return intermediateResult1 + intermediateResult2;
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
index f6aee3cc82..a460d30d26 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
@@ -33,6 +33,7 @@ import
org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.utils.BigDecimalUtils;
+import org.roaringbitmap.RoaringBitmap;
/**
@@ -48,8 +49,9 @@ import org.apache.pinot.spi.utils.BigDecimalUtils;
public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationFunction<BigDecimal, BigDecimal> {
private final Integer _precision;
private final Integer _scale;
+ private final boolean _nullHandlingEnabled;
- public SumPrecisionAggregationFunction(List<ExpressionContext> arguments) {
+ public SumPrecisionAggregationFunction(List<ExpressionContext> arguments,
boolean nullHandlingEnabled) {
super(arguments.get(0));
int numArguments = arguments.size();
@@ -65,6 +67,7 @@ public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationF
_precision = null;
_scale = null;
}
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -85,8 +88,16 @@ public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationF
@Override
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
- BigDecimal sum = getDefaultResult(aggregationResultHolder);
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateNullHandlingEnabled(length, aggregationResultHolder,
blockValSet, nullBitmap);
+ return;
+ }
+ }
+
+ BigDecimal sum = getDefaultResult(aggregationResultHolder);
switch (blockValSet.getValueType().getStoredType()) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
@@ -126,10 +137,119 @@ public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationF
aggregationResultHolder.setValue(sum);
}
+ private void aggregateNullHandlingEnabled(int length,
AggregationResultHolder aggregationResultHolder,
+ BlockValSet blockValSet, RoaringBitmap nullBitmap) {
+ BigDecimal sum = BigDecimal.ZERO;
+ switch (blockValSet.getValueType().getStoredType()) {
+ case INT: {
+ if (nullBitmap.getCardinality() < length) {
+ int[] intValues = blockValSet.getIntValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum = sum.add(BigDecimal.valueOf(intValues[i]));
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case LONG: {
+ if (nullBitmap.getCardinality() < length) {
+ long[] longValues = blockValSet.getLongValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum = sum.add(BigDecimal.valueOf(longValues[i]));
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case FLOAT: {
+ if (nullBitmap.getCardinality() < length) {
+ float[] floatValues = blockValSet.getFloatValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ if (Float.isFinite(floatValues[i])) {
+ sum = sum.add(BigDecimal.valueOf(floatValues[i]));
+ }
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case DOUBLE: {
+ if (nullBitmap.getCardinality() < length) {
+ double[] doubleValues = blockValSet.getDoubleValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ // TODO(nhejazi): throw an exception here instead of ignoring
infinite values?
+ if (Double.isFinite(doubleValues[i])) {
+ sum = sum.add(BigDecimal.valueOf(doubleValues[i]));
+ }
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case STRING:
+ if (nullBitmap.getCardinality() < length) {
+ String[] stringValues = blockValSet.getStringValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum = sum.add(new BigDecimal(stringValues[i]));
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ case BIG_DECIMAL: {
+ if (nullBitmap.getCardinality() < length) {
+ BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum = sum.add(bigDecimalValues[i]);
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ }
+ case BYTES:
+ if (nullBitmap.getCardinality() < length) {
+ byte[][] bytesValues = blockValSet.getBytesValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ sum = sum.add(BigDecimalUtils.deserialize(bytesValues[i]));
+ }
+ }
+ setAggregationResult(aggregationResultHolder, sum);
+ }
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+
+ protected void setAggregationResult(AggregationResultHolder
aggregationResultHolder, BigDecimal sum) {
+ BigDecimal otherSum = aggregationResultHolder.getResult();
+ aggregationResultHolder.setValue(otherSum == null ? sum :
sum.add(otherSum));
+ }
+
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
+ if (_nullHandlingEnabled) {
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray,
groupByResultHolder, blockValSet, nullBitmap);
+ return;
+ }
+ }
+
switch (blockValSet.getValueType().getStoredType()) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
@@ -183,6 +303,72 @@ public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationF
}
}
+ private void aggregateGroupBySVNullHandlingEnabled(int length, int[]
groupKeyArray,
+ GroupByResultHolder groupByResultHolder, BlockValSet blockValSet,
RoaringBitmap nullBitmap) {
+ switch (blockValSet.getValueType().getStoredType()) {
+ case INT:
+ if (nullBitmap.getCardinality() < length) {
+ int[] intValues = blockValSet.getIntValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder,
BigDecimal.valueOf(intValues[i]));
+ }
+ }
+ }
+ break;
+ case LONG:
+ if (nullBitmap.getCardinality() < length) {
+ long[] longValues = blockValSet.getLongValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder,
BigDecimal.valueOf(longValues[i]));
+ }
+ }
+ }
+ break;
+ case FLOAT:
+ case DOUBLE:
+ case STRING:
+ if (nullBitmap.getCardinality() < length) {
+ String[] stringValues = blockValSet.getStringValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder, new
BigDecimal(stringValues[i]));
+ }
+ }
+ }
+ break;
+ case BIG_DECIMAL:
+ if (nullBitmap.getCardinality() < length) {
+ BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder,
bigDecimalValues[i]);
+ }
+ }
+ }
+ break;
+ case BYTES:
+ if (nullBitmap.getCardinality() < length) {
+ byte[][] bytesValues = blockValSet.getBytesValuesSV();
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder,
BigDecimalUtils.deserialize(bytesValues[i]));
+ }
+ }
+ }
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+
+ private void setGroupByResult(int groupKey, GroupByResultHolder
groupByResultHolder, BigDecimal value) {
+ BigDecimal sum = groupByResultHolder.getResult(groupKey);
+ sum = sum == null ? value : sum.add(value);
+ groupByResultHolder.setValueForKey(groupKey, sum);
+ }
+
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
@@ -252,16 +438,32 @@ public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationF
@Override
public BigDecimal extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
- return getDefaultResult(aggregationResultHolder);
+ BigDecimal result = aggregationResultHolder.getResult();
+ if (result == null) {
+ return _nullHandlingEnabled ? null : BigDecimal.ZERO;
+ }
+ return result;
}
@Override
public BigDecimal extractGroupByResult(GroupByResultHolder
groupByResultHolder, int groupKey) {
- return getDefaultResult(groupByResultHolder, groupKey);
+ BigDecimal result = groupByResultHolder.getResult(groupKey);
+ if (result == null) {
+ return _nullHandlingEnabled ? null : BigDecimal.ZERO;
+ }
+ return result;
}
@Override
public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal
intermediateResult2) {
+ if (_nullHandlingEnabled) {
+ if (intermediateResult1 == null) {
+ return intermediateResult2;
+ }
+ if (intermediateResult2 == null) {
+ return intermediateResult1;
+ }
+ }
return intermediateResult1.add(intermediateResult2);
}
@@ -277,6 +479,9 @@ public class SumPrecisionAggregationFunction extends
BaseSingleInputAggregationF
@Override
public BigDecimal extractFinalResult(BigDecimal intermediateResult) {
+ if (intermediateResult == null) {
+ return null;
+ }
if (_precision == null) {
return intermediateResult;
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/ObjectGroupByResultHolder.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/ObjectGroupByResultHolder.java
index 0c3374ca7f..1f807a9e7c 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/ObjectGroupByResultHolder.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/ObjectGroupByResultHolder.java
@@ -77,7 +77,9 @@ public class ObjectGroupByResultHolder implements
GroupByResultHolder {
@Override
public void setValueForKey(int groupKey, double newValue) {
- throw new UnsupportedOperationException();
+ if (groupKey != GroupKeyGenerator.INVALID_ID) {
+ _resultArray[groupKey] = newValue;
+ }
}
@SuppressWarnings("unchecked")
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index 5f6c671952..dd29a24c76 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -29,6 +29,7 @@ import org.apache.pinot.common.utils.DataTable;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.transport.ServerRoutingInstance;
+import org.roaringbitmap.RoaringBitmap;
/**
@@ -65,18 +66,36 @@ public class AggregationDataTableReducer implements
DataTableReducer {
for (int i = 0; i < numAggregationFunctions; i++) {
Object intermediateResultToMerge;
ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
- switch (columnDataType) {
- case LONG:
- intermediateResultToMerge = dataTable.getLong(0, i);
- break;
- case DOUBLE:
- intermediateResultToMerge = dataTable.getDouble(0, i);
- break;
- case OBJECT:
- intermediateResultToMerge = dataTable.getObject(0, i);
- break;
- default:
- throw new IllegalStateException("Illegal column data type in
aggregation results: " + columnDataType);
+ if (_queryContext.isNullHandlingEnabled()) {
+ RoaringBitmap nullBitmap = dataTable.getNullRowIds(i);
+ boolean isNull = nullBitmap != null && nullBitmap.contains(0);
+ switch (columnDataType) {
+ case LONG:
+ intermediateResultToMerge = isNull ? null : dataTable.getLong(0,
i);
+ break;
+ case DOUBLE:
+ intermediateResultToMerge = isNull ? null :
dataTable.getDouble(0, i);
+ break;
+ case OBJECT:
+ intermediateResultToMerge = isNull ? null :
dataTable.getObject(0, i);
+ break;
+ default:
+ throw new IllegalStateException("Illegal column data type in
aggregation results: " + columnDataType);
+ }
+ } else {
+ switch (columnDataType) {
+ case LONG:
+ intermediateResultToMerge = dataTable.getLong(0, i);
+ break;
+ case DOUBLE:
+ intermediateResultToMerge = dataTable.getDouble(0, i);
+ break;
+ case OBJECT:
+ intermediateResultToMerge = dataTable.getObject(0, i);
+ break;
+ default:
+ throw new IllegalStateException("Illegal column data type in
aggregation results: " + columnDataType);
+ }
}
Object mergedIntermediateResult = intermediateResults[i];
if (mergedIntermediateResult == null) {
@@ -89,8 +108,8 @@ public class AggregationDataTableReducer implements
DataTableReducer {
Object[] finalResults = new Object[numAggregationFunctions];
for (int i = 0; i < numAggregationFunctions; i++) {
AggregationFunction aggregationFunction = _aggregationFunctions[i];
- finalResults[i] = aggregationFunction.getFinalResultColumnType()
-
.convert(aggregationFunction.extractFinalResult(intermediateResults[i]));
+ Comparable result =
aggregationFunction.extractFinalResult(intermediateResults[i]);
+ finalResults[i] = result == null ? null :
aggregationFunction.getFinalResultColumnType().convert(result);
}
brokerResponseNative.setResultTable(reduceToResultTable(finalResults));
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
index 58deac76ee..d40f17fce2 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
@@ -31,13 +31,14 @@ import static org.testng.Assert.assertTrue;
@SuppressWarnings("rawtypes")
public class AggregationFunctionFactoryTest {
- private static final String ARGUMENT = "(column)";
+ private static final String ARGUMENT_COLUMN = "(column)";
+ private static final String ARGUMENT_STAR = "(*)";
private static final QueryContext DUMMY_QUERY_CONTEXT =
QueryContextConverterUtils.getQueryContext("SELECT * FROM testTable");
@Test
public void testGetAggregationFunction() {
- FunctionContext function = getFunction("CoUnT");
+ FunctionContext function = getFunction("CoUnT", ARGUMENT_STAR);
AggregationFunction aggregationFunction =
AggregationFunctionFactory.getAggregationFunction(function,
DUMMY_QUERY_CONTEXT);
assertTrue(aggregationFunction instanceof CountAggregationFunction);
@@ -446,7 +447,7 @@ public class AggregationFunctionFactoryTest {
}
private FunctionContext getFunction(String functionName) {
- return getFunction(functionName, ARGUMENT);
+ return getFunction(functionName, ARGUMENT_COLUMN);
}
private FunctionContext getFunction(String functionName, String args) {
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/CountStarTreeV2Test.java
b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/CountStarTreeV2Test.java
index 321c173938..4627e8bd3c 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/CountStarTreeV2Test.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/CountStarTreeV2Test.java
@@ -26,10 +26,10 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
import static org.testng.Assert.assertEquals;
-public class CountStarTreeV2Test extends BaseStarTreeV2Test<Void, Long> {
+public class CountStarTreeV2Test extends BaseStarTreeV2Test<Object, Long> {
@Override
- ValueAggregator<Void, Long> getValueAggregator() {
+ ValueAggregator<Object, Long> getValueAggregator() {
return new CountValueAggregator();
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
index 9fc8124919..c3e957b382 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
@@ -287,6 +287,30 @@ public class AllNullQueriesTest extends BaseQueriesTest {
DataTableFactory.setDataTableVersion(DataTableFactory.VERSION_4);
Map<String, String> queryOptions = new HashMap<>();
queryOptions.put("enableNullHandling", "true");
+ DataType dataType = columnDataType.toDataType();
+ if (columnDataType != ColumnDataType.STRING) {
+ {
+ String query = String.format(
+ "SELECT count(*) as count1, count(%s) as count2, min(%s) as min,
max(%s) as max FROM testTable",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count1",
"count2", "min", "max"}, new ColumnDataType[]{
+ ColumnDataType.LONG, ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE
+ }));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 4);
+ // Note: count(*) returns the total number of docs (both null and
non-null values are counted).
+ assertEquals((long) row[0], 1000 * 4);
+ // count(col) returns the count of docs with a non-null value.
+ assertEquals((long) row[1], 0);
+ assertNull(row[2]);
+ assertNull(row[3]);
+ }
+ }
{
String query = "SELECT * FROM testTable";
BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
@@ -363,6 +387,28 @@ public class AllNullQueriesTest extends BaseQueriesTest {
List<Object[]> rows = resultTable.getRows();
assertEquals(rows.size(), 1);
}
+ if (columnDataType != ColumnDataType.STRING) {
+ {
+ String query = String.format("SELECT COUNT(%s) AS count, MIN(%s) AS
min, MAX(%s) AS max, AVG(%s) AS avg,"
+ + " SUM(%s) AS sum FROM testTable LIMIT 1000", COLUMN_NAME,
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count", "min",
"max", "avg", "sum"},
+ new ColumnDataType[] {
+ ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
+ ColumnDataType.DOUBLE
+ }));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals((long) rows.get(0)[0], 0);
+ assertNull(rows.get(0)[1]);
+ assertNull(rows.get(0)[2]);
+ assertNull(rows.get(0)[3]);
+ assertNull(rows.get(0)[4]);
+ }
+ }
{
String query = String.format("SELECT %s FROM testTable GROUP BY %s ORDER
BY %s DESC", COLUMN_NAME, COLUMN_NAME,
COLUMN_NAME);
@@ -390,6 +436,104 @@ public class AllNullQueriesTest extends BaseQueriesTest {
assertEquals(row[0], 4000L);
assertNull(row[1]);
}
+ {
+ String query = String.format("SELECT SUMPRECISION(%s) AS sum FROM
testTable", COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"sum"}, new
ColumnDataType[]{ColumnDataType.STRING}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertNull(rows.get(0)[0]);
+ }
+ if (columnDataType != ColumnDataType.STRING) {
+ {
+ // Note: in Presto, inequality, equality, and IN comparisons with nulls
always return false.
+ long lowerLimit = 69;
+ String query =
+ String.format("SELECT %s FROM testTable WHERE %s > '%s' LIMIT 50",
COLUMN_NAME, COLUMN_NAME, lowerLimit);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ // Pinot loops through the column values from smallest to biggest.
Null comparison always returns false.
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ }
+ {
+ String query = String.format("SELECT %s FROM testTable WHERE %s = '%s'",
COLUMN_NAME, COLUMN_NAME, 68);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ {
+ String query = String.format("SELECT %s FROM testTable WHERE %s = '%s'",
COLUMN_NAME, COLUMN_NAME, 69);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ if (columnDataType != ColumnDataType.STRING) {
+ {
+ String query = String.format(
+ "SELECT AVG(%s) AS avg FROM testTable GROUP BY %s ORDER BY avg
LIMIT 20", COLUMN_NAME, COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"avg"}, new
ColumnDataType[]{ColumnDataType.DOUBLE}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 1);
+ assertNull(row[0]);
+ }
+ {
+ // MODE cannot handle BIG_DECIMAL yet.
+ if (dataType != DataType.BIG_DECIMAL) {
+ String query = String.format("SELECT AVG(%s) AS avg, MODE(%s) AS
mode, DISTINCTCOUNT(%s) as distinct_count"
+ + " FROM testTable GROUP BY %s ORDER BY %s LIMIT 200",
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME, COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"avg", "mode",
"distinct_count"},
+ new ColumnDataType[]{ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.INT}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 3);
+ assertNull(row[0]);
+ // TODO: this should return null instead of default value.
+ if (dataType == DataType.DOUBLE || dataType == DataType.FLOAT) {
+ assertEquals(row[1], Double.NEGATIVE_INFINITY);
+ } else if (dataType == DataType.LONG) {
+ assertEquals(((Double) row[1]).longValue(), Long.MIN_VALUE);
+ }
+ assertEquals(row[2], 1);
+ }
+ }
+ {
+ // If the limit is updated to include all records, the results come back
unsorted.
+ String query = String.format("SELECT MAX(%s) AS max, %s FROM testTable
GROUP BY %s ORDER BY max LIMIT 501",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"max",
COLUMN_NAME},
+ new ColumnDataType[]{ColumnDataType.DOUBLE, columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertNull(rows.get(0)[0]);
+ }
+ }
DataTableFactory.setDataTableVersion(DataTableFactory.VERSION_3);
_indexSegment.destroy();
FileUtils.deleteDirectory(indexDir);
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/BigDecimalQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/BigDecimalQueriesTest.java
index 1c5eaad62d..a600a4d8e6 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/BigDecimalQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/BigDecimalQueriesTest.java
@@ -285,7 +285,8 @@ public class BigDecimalQueriesTest extends BaseQueriesTest {
assertEquals(dataSchema, new DataSchema(new String[]{"count"}, new
ColumnDataType[]{ColumnDataType.LONG}));
List<Object[]> rows = resultTable.getRows();
assertEquals(rows.size(), 1);
- assertEquals((long) rows.get(0)[0], 4 * NUM_RECORDS);
+ // A quarter of the data is null and hence the count is 3 * NUM_RECORDS,
not 4 * NUM_RECORDS.
+ assertEquals((long) rows.get(0)[0], 3 * NUM_RECORDS);
}
{
String query = String.format("SELECT %s FROM testTable GROUP BY %s ORDER
BY %s DESC",
@@ -387,54 +388,6 @@ public class BigDecimalQueriesTest extends BaseQueriesTest
{
assertEquals(row[0], BASE_BIG_DECIMAL.add(BigDecimal.valueOf(69)));
}
}
- {
- String query = String.format(
- "SELECT MAX(%s) AS maxValue FROM testTable GROUP BY %s HAVING
maxValue < %s ORDER BY maxValue",
- BIG_DECIMAL_COLUMN, BIG_DECIMAL_COLUMN,
BASE_BIG_DECIMAL.add(BigDecimal.valueOf(5)));
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{"maxValue"}, new
ColumnDataType[]{ColumnDataType.DOUBLE}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 5);
- assertEquals(rows.get(0)[0], 0.0);
- int i = 0;
- for (int index = 1; index < 5; index++) {
- Object[] row = rows.get(index);
- assertEquals(row.length, 1);
- if (i % 4 == 3) {
- // Null values are inserted at: index % 4 == 3.
- i++;
- }
- assertEquals(row[0],
BASE_BIG_DECIMAL.add(BigDecimal.valueOf(i)).doubleValue());
- i++;
- }
- }
- {
- int lowerLimit = 991;
- String query = String.format(
- "SELECT MAX(%s) AS maxValue FROM testTable GROUP BY %s HAVING
maxValue > %s ORDER BY maxValue",
- BIG_DECIMAL_COLUMN, BIG_DECIMAL_COLUMN,
BASE_BIG_DECIMAL.add(BigDecimal.valueOf(lowerLimit)));
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{"maxValue"}, new
ColumnDataType[]{ColumnDataType.DOUBLE}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 6);
- int i = lowerLimit;
- for (int index = 0; index < 6; index++) {
- if (i % 4 == 3) {
- // Null values are inserted at: index % 4 == 3.
- i++;
- }
- Object[] row = rows.get(index);
- assertEquals(row.length, 1);
- assertEquals(row[0],
BASE_BIG_DECIMAL.add(BigDecimal.valueOf(i)).doubleValue());
- i++;
- }
- }
{
// This returns currently 25 rows instead of a single row!
// int limit = 25;
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/BooleanNullEnabledQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/BooleanNullEnabledQueriesTest.java
index cd524d4a8c..8f080ecc92 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/BooleanNullEnabledQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/BooleanNullEnabledQueriesTest.java
@@ -194,21 +194,6 @@ public class BooleanNullEnabledQueriesTest extends
BaseQueriesTest {
}
}
}
- {
- String query = "SELECT booleanColumn FROM testTable WHERE booleanColumn";
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{"booleanColumn"}, new
ColumnDataType[]{ColumnDataType.BOOLEAN}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 10);
- for (int i = 0; i < 10; i++) {
- Object[] row = rows.get(i);
- assertEquals(row.length, 1);
- assertEquals(row[0], true);
- }
- }
{
String query = "SELECT * FROM testTable ORDER BY booleanColumn DESC
LIMIT 4000";
BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
@@ -275,6 +260,29 @@ public class BooleanNullEnabledQueriesTest extends
BaseQueriesTest {
assertEquals(thirdRow.length, 1);
assertEquals(thirdRow[0], false);
}
+ {
+ String query =
+ "SELECT COUNT(*) AS count, booleanColumn FROM testTable GROUP BY
booleanColumn ORDER BY booleanColumn";
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count",
"booleanColumn"},
+ new ColumnDataType[]{ColumnDataType.LONG, ColumnDataType.BOOLEAN}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 3);
+ Object[] firstRow = rows.get(0);
+ assertEquals(firstRow.length, 2);
+ assertEquals(firstRow[0], (long) 1716);
+ assertEquals(firstRow[1], false);
+ Object[] secondRow = rows.get(1);
+ assertEquals(secondRow.length, 2);
+ assertEquals(secondRow[0], (long) 1716);
+ assertEquals(secondRow[1], true);
+ Object[] thirdRow = rows.get(2);
+ assertEquals(thirdRow.length, 2);
+ assertEquals(thirdRow[0], (long) 568);
+ assertNull(thirdRow[1]);
+ }
DataTableFactory.setDataTableVersion(DataTableFactory.VERSION_3);
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
index 430f561f1f..f617209fc5 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
@@ -20,6 +20,7 @@ package org.apache.pinot.queries;
import java.io.File;
import java.io.IOException;
+import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -64,8 +65,13 @@ public class NullEnabledQueriesTest extends BaseQueriesTest {
private static final int NUM_RECORDS = 1000;
private static List<GenericRow> _records;
+ private static BigDecimal _sumPrecision;
+ private static double _sum;
+ private static double _sumKey1;
+ private static double _sumKey2;
private static final String COLUMN_NAME = "column";
+ private static final String KEY_COLUMN = "key";
private IndexSegment _indexSegment;
private List<IndexSegment> _indexSegments;
@@ -89,13 +95,27 @@ public class NullEnabledQueriesTest extends BaseQueriesTest
{
throws Exception {
FileUtils.deleteDirectory(INDEX_DIR);
+ _sumPrecision = BigDecimal.ZERO;
+ _sum = 0;
+ _sumKey1 = 0;
+ _sumKey2 = 0;
_records = new ArrayList<>(NUM_RECORDS);
for (int i = 0; i < NUM_RECORDS; i++) {
GenericRow record = new GenericRow();
double value = baseValue.doubleValue() + i;
if (i % 2 == 0) {
record.putValue(COLUMN_NAME, value);
+ _sumPrecision = _sumPrecision.add(BigDecimal.valueOf(value));
+ _sum += value;
+ if (i < NUM_RECORDS / 2) {
+ record.putValue(KEY_COLUMN, 1);
+ _sumKey1 += value;
+ } else {
+ record.putValue(KEY_COLUMN, 2);
+ _sumKey2 += value;
+ }
} else {
+ // Key column value here is null.
record.putValue(COLUMN_NAME, null);
}
_records.add(record);
@@ -108,9 +128,15 @@ public class NullEnabledQueriesTest extends
BaseQueriesTest {
Schema schema;
if (dataType == DataType.BIG_DECIMAL) {
- schema = new Schema.SchemaBuilder().addMetric(COLUMN_NAME,
dataType).build();
+ schema = new Schema.SchemaBuilder()
+ .addMetric(COLUMN_NAME, dataType)
+ .addMetric(KEY_COLUMN, DataType.INT)
+ .build();
} else {
- schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN_NAME,
dataType).build();
+ schema = new Schema.SchemaBuilder()
+ .addSingleValueDimension(COLUMN_NAME, dataType)
+ .addMetric(KEY_COLUMN, DataType.INT)
+ .build();
}
SegmentGeneratorConfig segmentGeneratorConfig = new
SegmentGeneratorConfig(tableConfig, schema);
@@ -190,19 +216,89 @@ public class NullEnabledQueriesTest extends
BaseQueriesTest {
DataTableFactory.setDataTableVersion(DataTableFactory.VERSION_4);
Map<String, String> queryOptions = new HashMap<>();
queryOptions.put("enableNullHandling", "true");
+ {
+ String query = String.format("SELECT SUM(%s) as sum, MIN(%s) AS min,
MAX(%s) AS max, COUNT(%s) AS count, %s "
+ + "FROM testTable GROUP BY %s ORDER BY sum",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME, COLUMN_NAME, KEY_COLUMN,
KEY_COLUMN);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"sum", "min",
"max", "count", "key"}, new ColumnDataType[]{
+ ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
ColumnDataType.LONG, ColumnDataType.INT
+ }));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 3);
+ for (int index = 0; index < 3; index++) {
+ Object[] row = rows.get(index);
+ assertEquals(row.length, 5);
+ int keyColumnIdx = 4;
+ if (row[keyColumnIdx] == null) {
+ for (int i = 0; i < 3; i++) {
+ assertNull(row[i]);
+ }
+ // We have 500 nulls and 4 * 500 = 2000. Nevertheless, count should
be 0 similar to Presto.
+ // In Presto:
+ // SELECT count(id) as count, key FROM (VALUES (null, 1), (null, 1),
(null, 2), (1, 3), (null, 3)) AS t(id,
+ // key) GROUP BY key ORDER BY key DESC;
+ // count | key
+ //-------+-----
+ // 1 | 3
+ // 0 | 2
+ // 0 | 1
+ //(3 rows)
+ assertEquals(row[3], 0L);
+ } else if ((int) row[keyColumnIdx] == 1) {
+ assertTrue(Math.abs(((Double) row[0]) - 4 * _sumKey1) < 1e-1);
+ assertTrue(Math.abs(((Double) row[1]) - baseValue.doubleValue()) <
1e-1);
+ assertTrue(Math.abs(((Double) row[2]) - (baseValue.doubleValue() +
Math.ceil(NUM_RECORDS / 2.0) - 2))
+ < 1e-1);
+ assertEquals(row[3], (long) (4 * (Math.ceil(NUM_RECORDS / 2.0) /
2)));
+ } else {
+ assertEquals(row[keyColumnIdx], 2);
+ assertTrue(Math.abs(((Double) row[0]) - 4 * _sumKey2) < 1e-1);
+ assertTrue(Math.abs(((Double) row[1]) - (baseValue.doubleValue() +
Math.ceil(NUM_RECORDS / 2.0))) < 1e-1);
+ assertTrue(Math.abs(((Double) row[2]) - (baseValue.doubleValue() +
NUM_RECORDS - 2)) < 1e-1);
+ assertEquals(row[3], (long) (4 * (Math.ceil(NUM_RECORDS / 2.0) /
2)));
+ }
+ }
+ }
+ {
+ String query = String.format(
+ "SELECT count(*) as count1, count(%s) as count2, min(%s) as min,
max(%s) as max FROM testTable", COLUMN_NAME,
+ COLUMN_NAME, COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count1", "count2",
"min", "max"}, new ColumnDataType[]{
+ ColumnDataType.LONG, ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 4);
+      // Note: count(*) returns the total number of docs (both null and
non-null values).
+ assertEquals((long) row[0], 1000 * 4);
+      // count(col) returns the count of docs with non-null values.
+ assertEquals((long) row[1], 500 * 4);
+ assertEquals(row[2], baseValue.doubleValue());
+ assertTrue(Math.abs((Double) row[3] - (baseValue.doubleValue() + 998)) <
1e-1);
+ }
{
String query = "SELECT * FROM testTable";
BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
ResultTable resultTable = brokerResponse.getResultTable();
DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{dataType}));
+ assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME,
KEY_COLUMN},
+ new ColumnDataType[]{dataType, ColumnDataType.INT}));
List<Object[]> rows = resultTable.getRows();
assertEquals(rows.size(), 10);
for (int i = 0; i < 10; i++) {
Object[] row = rows.get(i);
- assertEquals(row.length, 1);
+ assertEquals(row.length, 2);
if (row[0] != null) {
assertTrue(Math.abs(((Number) row[0]).doubleValue() -
(baseValue.doubleValue() + i)) < 1e-1);
+ assertEquals(row[1], 1);
+ } else {
+ assertNull(row[1]);
}
}
}
@@ -214,7 +310,7 @@ public class NullEnabledQueriesTest extends BaseQueriesTest
{
ResultTable resultTable = brokerResponse.getResultTable();
DataSchema dataSchema = resultTable.getDataSchema();
assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{dataType}));
+ new DataSchema(new String[]{COLUMN_NAME, KEY_COLUMN}, new
ColumnDataType[]{dataType, ColumnDataType.INT}));
List<Object[]> rows = resultTable.getRows();
assertEquals(rows.size(), 4000);
int k = 0;
@@ -225,7 +321,7 @@ public class NullEnabledQueriesTest extends BaseQueriesTest
{
}
for (int j = 0; j < 4; j++) {
Object[] values = rows.get(i + j);
- assertEquals(values.length, 1);
+ assertEquals(values.length, 2);
assertTrue(Math.abs(((Number) values[0]).doubleValue() -
(baseValue.doubleValue() + (NUM_RECORDS - 1 - k)))
< 1e-1);
}
@@ -236,7 +332,7 @@ public class NullEnabledQueriesTest extends BaseQueriesTest
{
// Note 2: The default null ordering is 'NULLS LAST', regardless of the
ordering direction.
for (int i = 2000; i < 4000; i++) {
Object[] values = rows.get(i);
- assertEquals(values.length, 1);
+ assertEquals(values.length, 2);
assertNull(values[0]);
}
}
@@ -305,6 +401,28 @@ public class NullEnabledQueriesTest extends
BaseQueriesTest {
List<Object[]> rows = resultTable.getRows();
assertEquals(rows.size(), limit);
}
+ {
+ String query = String.format("SELECT COUNT(%s) AS count, MIN(%s) AS min,
MAX(%s) AS max, AVG(%s) AS avg,"
+ + " SUM(%s) AS sum FROM testTable LIMIT 1000", COLUMN_NAME,
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count", "min",
"max", "avg", "sum"}, new ColumnDataType[]{
+ ColumnDataType.LONG, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE,
+ ColumnDataType.DOUBLE}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ int count = 4 * 500;
+ assertEquals((long) rows.get(0)[0], count);
+ double min = baseValue.doubleValue();
+ assertTrue(Math.abs((Double) rows.get(0)[1] - min) < 1e-1);
+ double max = baseValue.doubleValue() + 998;
+ assertTrue(Math.abs((Double) rows.get(0)[2] - max) < 1e-1);
+ double avg = _sum / (double) _records.size();
+ assertTrue(Math.abs((Double) rows.get(0)[3] - avg) < 1e-1);
+ assertTrue(Math.abs((Double) rows.get(0)[4] - (4 * _sum)) < 1e-1);
+ }
{
String query = String.format("SELECT %s FROM testTable GROUP BY %s ORDER
BY %s DESC", COLUMN_NAME, COLUMN_NAME,
COLUMN_NAME);
@@ -359,6 +477,72 @@ public class NullEnabledQueriesTest extends
BaseQueriesTest {
assertEquals(row[0], 2000L);
assertNull(row[1]);
}
+ {
+ String query = String.format("SELECT SUMPRECISION(%s) AS sum FROM
testTable", COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"sum"}, new
ColumnDataType[]{ColumnDataType.STRING}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertTrue(Math.abs((new BigDecimal((String)
rows.get(0)[0])).doubleValue()
+ - _sumPrecision.multiply(BigDecimal.valueOf(4)).doubleValue()) <
1e-1);
+ }
+ {
+ // Note: in Presto, inequality, equality, and IN comparison with nulls
always returns false:
+ // Example 1:
+ // SELECT id FROM (VALUES (1), (2), (3), (null), (4), (5), (null), (6),
(null), (7), (8), (null), (9)) AS t (id)
+ // WHERE id > 6;
+ //
+ // Returns:
+ // id
+ //----
+ // 7
+ // 8
+ // 9
+ //
+ // Example 2:
+ // SELECT id FROM (VALUES (1), (2), (3), (null), (4), (5), (null), (6),
(null), (7), (8), (null), (9)) AS t (id)
+ // WHERE id = NULL;
+ // id
+ //----
+ //(0 rows)
+ //
+ // Example 3:
+ // SELECT id FROM (VALUES (1), (2), (3), (null), (4), (5), (null), (6),
(null), (7), (8), (null), (9)) AS t (id)
+ // WHERE id != NULL;
+ // id
+ //----
+ //(0 rows)
+ //
+ // SELECT id FROM (VALUES (1.3), (2.6), (3.6), (null), (4.2), (5.666),
(null), (6.83), (null), (7.66), (8.0),
+ // (null), (9.5)) AS t (id) WHERE id in (9.5, null);
+ // id
+ //-------
+ // 9.500
+ //(1 row)
+ //
+ String query = String.format("SELECT %s FROM testTable WHERE %s > '%s'
LIMIT 50", COLUMN_NAME, COLUMN_NAME,
+ baseValue.doubleValue() + 69);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{dataType}));
+ // Pinot loops through the column values from smallest to biggest. Null
comparison always returns false.
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 50);
+ int i = 0;
+ for (int index = 0; index < 50; index++) {
+ Object[] row = rows.get(index);
+ assertEquals(row.length, 1);
+ if ((69 + i + 1) % 2 == 1) {
+          // Null values are inserted at: index % 2 == 1. However, nulls are
not returned by a comparison operator.
+ i++;
+ }
+ assertTrue(Math.abs(((Number) row[0]).doubleValue() -
(baseValue.doubleValue() + (69 + i + 1))) < 1e-1);
+ i++;
+ }
+ }
{
String query = String.format("SELECT %s FROM testTable WHERE %s = '%s'",
COLUMN_NAME, COLUMN_NAME,
baseValue.doubleValue() + 68);
@@ -385,28 +569,89 @@ public class NullEnabledQueriesTest extends
BaseQueriesTest {
// 69 % 2 == 1 (and so a null was inserted instead of 69 + BASE_FLOAT).
assertEquals(rows.size(), 0);
}
+ // TODO(nhejazi): uncomment this test after null is handled in inequality
operators (max < %s).
+// {
+// String query = String.format("SELECT COUNT(%s) AS count, MIN(%s) AS
min, MAX(%s) AS max, SUM(%s) AS sum"
+// + " FROM testTable GROUP BY %s HAVING max < %s ORDER BY max",
+// COLUMN_NAME, COLUMN_NAME, COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+// baseValue.doubleValue() + 20);
+// BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+// ResultTable resultTable = brokerResponse.getResultTable();
+// DataSchema dataSchema = resultTable.getDataSchema();
+// assertEquals(dataSchema, new DataSchema(new String[]{"count", "min",
"max", "sum"}, new ColumnDataType[]{
+// ColumnDataType.LONG, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE
+// }));
+// List<Object[]> rows = resultTable.getRows();
+// assertEquals(rows.size(), 10);
+// int i = 0;
+// for (int index = 0; index < 10; index++) {
+// if (i % 2 == 1) {
+// // Null values are inserted at: index % 2 == 1.
+// i++;
+// }
+// Object[] row = rows.get(index);
+// assertEquals(row.length, 4);
+// assertEquals(row[0], 4L);
+// System.out.println("min = " + row[1]);
+// assertTrue(Math.abs(((Double) row[1]) - (baseValue.doubleValue() +
i)) < 1e-1);
+// System.out.println("max = " + row[2]);
+// assertTrue(Math.abs((Double) row[2] - (baseValue.doubleValue() + i))
< 1e-1);
+// System.out.println("sum = " + row[3]);
+// assertTrue(Math.abs((Double) row[3] - (4 * (baseValue.doubleValue()
+ i))) < 1e-1);
+// i++;
+// }
+// }
{
- int lowerLimit = 991;
- String query = String.format(
- "SELECT MAX(%s) AS max FROM testTable GROUP BY %s HAVING max > %s
ORDER BY max", COLUMN_NAME, COLUMN_NAME,
- baseValue.doubleValue() + lowerLimit);
+ String query = String.format("SELECT AVG(%s) AS avg, MODE(%s) AS mode,
DISTINCTCOUNT(%s) as distinct_count"
+ + " FROM testTable GROUP BY %s HAVING avg < %s ORDER BY %s LIMIT
200",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ baseValue.doubleValue() + 400, COLUMN_NAME);
BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
ResultTable resultTable = brokerResponse.getResultTable();
DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{"max"}, new
ColumnDataType[]{ColumnDataType.DOUBLE}));
+ assertEquals(dataSchema, new DataSchema(new String[]{"avg", "mode",
"distinct_count"},
+ new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
ColumnDataType.INT}));
List<Object[]> rows = resultTable.getRows();
- int i = lowerLimit;
- for (Object[] row : rows) {
+ assertEquals(rows.size(), 200);
+ int i = 0;
+ for (int index = 0; index < 200; index++) {
if (i % 2 == 1) {
// Null values are inserted at: index % 2 == 1.
i++;
}
- assertEquals(row.length, 1);
+ Object[] row = rows.get(index);
+ assertEquals(row.length, 3);
assertTrue(Math.abs((Double) row[0] - (baseValue.doubleValue() + i)) <
1e-1);
+ assertTrue(Math.abs((Double) row[1] - (baseValue.doubleValue() + i)) <
1e-1);
+ assertEquals(row[2], 1);
i++;
}
}
+ {
+      // Note: if the limit is updated to include all records, the results come back unsorted.
+ String query = String.format("SELECT MAX(%s) AS max, %s FROM testTable
GROUP BY %s ORDER BY max LIMIT 501",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME);
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"max", COLUMN_NAME},
+ new ColumnDataType[]{ColumnDataType.DOUBLE, dataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 501);
+ int i = 0;
+ for (int index = 0; index < 500; index++) {
+ if (i % 2 == 1) {
+ // Null values are inserted at: index % 2 == 1.
+ i++;
+ }
+ Object[] row = rows.get(index);
+ assertEquals(row.length, 2);
+ assertTrue(Math.abs((Double) row[0] - (baseValue.doubleValue() + i)) <
1e-1);
+ assertTrue(Math.abs(((Number) row[1]).doubleValue() -
(baseValue.doubleValue() + i)) < 1e-1);
+ i++;
+ }
+ assertNull(rows.get(rows.size() - 1)[0]);
+ }
DataTableFactory.setDataTableVersion(DataTableFactory.VERSION_3);
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/SumPrecisionQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/SumPrecisionQueriesTest.java
index 3b4aa59416..7177537845 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/SumPrecisionQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/SumPrecisionQueriesTest.java
@@ -118,10 +118,10 @@ public class SumPrecisionQueriesTest extends
BaseQueriesTest {
long longValue = RANDOM.nextLong();
_longSum = _longSum.add(BigDecimal.valueOf(longValue));
float floatValue = RANDOM.nextFloat();
- _floatSum = _floatSum.add(new BigDecimal(Float.toString(floatValue)));
+ _floatSum = _floatSum.add(new BigDecimal(String.valueOf(floatValue)));
double doubleValue = RANDOM.nextDouble();
String stringValue = Double.toString(doubleValue);
- BigDecimal bigDecimalValue = new BigDecimal(stringValue);
+ BigDecimal bigDecimalValue = BigDecimal.valueOf(doubleValue);
_doubleSum = _doubleSum.add(bigDecimalValue);
byte[] bytesValue = BigDecimalUtils.serialize(bigDecimalValue);
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 200336ed77..25168a52ba 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -174,6 +174,7 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
private AggregationFunction toAggregationFunction(RexExpression aggCall, int
aggregationFunctionInputRef) {
Preconditions.checkState(aggCall instanceof RexExpression.FunctionCall);
+ // TODO(Rong Rong): query options are not supported by the new engine at
this moment.
switch (((RexExpression.FunctionCall) aggCall).getFunctionName()) {
case "$SUM":
case "$SUM0":
@@ -182,7 +183,8 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
case "$COUNT":
case "COUNT":
- return new CountAggregationFunction();
+ return new CountAggregationFunction(
+
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
case "$MIN":
case "$MIN0":
case "MIN":
diff --git
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/CountValueAggregator.java
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/CountValueAggregator.java
index e74687074b..177578935b 100644
---
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/CountValueAggregator.java
+++
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/CountValueAggregator.java
@@ -22,7 +22,7 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec.DataType;
-public class CountValueAggregator implements ValueAggregator<Void, Long> {
+public class CountValueAggregator implements ValueAggregator<Object, Long> {
public static final DataType AGGREGATED_VALUE_TYPE = DataType.LONG;
@Override
@@ -36,12 +36,12 @@ public class CountValueAggregator implements
ValueAggregator<Void, Long> {
}
@Override
- public Long getInitialAggregatedValue(Void rawValue) {
+ public Long getInitialAggregatedValue(Object rawValue) {
return 1L;
}
@Override
- public Long applyRawValue(Long value, Void rawValue) {
+ public Long applyRawValue(Long value, Object rawValue) {
return value + 1;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]