This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new f1da16473b Add DISTINCT_COUNT_OFF_HEAP aggregate function (#15469)
f1da16473b is described below
commit f1da16473b3f46f2eb1afeb1191e586172e67902
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Mon Apr 7 12:27:24 2025 -0600
Add DISTINCT_COUNT_OFF_HEAP aggregate function (#15469)
---
.../query/NonScanBasedAggregationOperator.java | 5 +
.../pinot/core/plan/AggregationPlanNode.java | 8 +-
.../function/AggregationFunctionFactory.java | 4 +-
.../function/DistinctCountAggregationFunction.java | 2 +-
.../DistinctCountMVAggregationFunction.java | 2 +-
.../DistinctCountOffHeapAggregationFunction.java | 533 +++++++++++++++++++++
.../pinot/queries/DistinctCountQueriesTest.java | 77 +++
.../pinot/segment/spi/AggregationFunctionType.java | 3 +
8 files changed, 627 insertions(+), 7 deletions(-)
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
index fb6e8d02f6..04f388cd46 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
@@ -38,6 +38,7 @@ import
org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.DistinctCountHLLAggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.DistinctCountHLLPlusAggregationFunction;
+import
org.apache.pinot.core.query.aggregation.function.DistinctCountOffHeapAggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.DistinctCountRawHLLAggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.DistinctCountRawHLLPlusAggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.DistinctCountSmartHLLAggregationFunction;
@@ -111,6 +112,10 @@ public class NonScanBasedAggregationOperator extends
BaseOperator<AggregationRes
case DISTINCTAVGMV:
result =
getDistinctValueSet(Objects.requireNonNull(dataSource.getDictionary()));
break;
+ case DISTINCTCOUNTOFFHEAP:
+ result = ((DistinctCountOffHeapAggregationFunction)
aggregationFunction).extractAggregationResult(
+ Objects.requireNonNull(dataSource.getDictionary()));
+ break;
case DISTINCTCOUNTHLL:
case DISTINCTCOUNTHLLMV:
result =
getDistinctCountHLLResult(Objects.requireNonNull(dataSource.getDictionary()),
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index f5157112a8..cca14f2704 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -49,10 +49,10 @@ import static
org.apache.pinot.segment.spi.AggregationFunctionType.*;
@SuppressWarnings("rawtypes")
public class AggregationPlanNode implements PlanNode {
private static final EnumSet<AggregationFunctionType>
DICTIONARY_BASED_FUNCTIONS =
- EnumSet.of(MIN, MINMV, MAX, MAXMV, MINMAXRANGE, MINMAXRANGEMV,
DISTINCTCOUNT, DISTINCTCOUNTMV, DISTINCTCOUNTHLL,
- DISTINCTCOUNTHLLMV, DISTINCTCOUNTRAWHLL, DISTINCTCOUNTRAWHLLMV,
SEGMENTPARTITIONEDDISTINCTCOUNT,
- DISTINCTCOUNTSMARTHLL, DISTINCTSUM, DISTINCTAVG, DISTINCTSUMMV,
DISTINCTAVGMV, DISTINCTCOUNTHLLPLUS,
- DISTINCTCOUNTHLLPLUSMV, DISTINCTCOUNTRAWHLLPLUS,
DISTINCTCOUNTRAWHLLPLUSMV);
+ EnumSet.of(MIN, MINMV, MAX, MAXMV, MINMAXRANGE, MINMAXRANGEMV,
DISTINCTCOUNT, DISTINCTCOUNTMV, DISTINCTSUM,
+ DISTINCTSUMMV, DISTINCTAVG, DISTINCTAVGMV, DISTINCTCOUNTOFFHEAP,
DISTINCTCOUNTHLL, DISTINCTCOUNTHLLMV,
+ DISTINCTCOUNTRAWHLL, DISTINCTCOUNTRAWHLLMV, DISTINCTCOUNTHLLPLUS,
DISTINCTCOUNTHLLPLUSMV,
+ DISTINCTCOUNTRAWHLLPLUS, DISTINCTCOUNTRAWHLLPLUSMV,
SEGMENTPARTITIONEDDISTINCTCOUNT, DISTINCTCOUNTSMARTHLL);
// DISTINCTCOUNT excluded because consuming segment metadata contains
unknown cardinality when there is no dictionary
private static final EnumSet<AggregationFunctionType>
METADATA_BASED_FUNCTIONS =
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 205f1ae71a..b15896de7c 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -58,7 +58,7 @@ public class AggregationFunctionFactory {
/**
* Given the function information, returns a new instance of the
corresponding aggregation function.
- * <p>NOTE: Underscores in the function name are ignored in V1.
+ * <p>NOTE: Underscores in the function name are ignored.
*/
public static AggregationFunction getAggregationFunction(FunctionContext
function, boolean nullHandlingEnabled) {
try {
@@ -360,6 +360,8 @@ public class AggregationFunctionFactory {
return new MinMaxRangeAggregationFunction(arguments,
nullHandlingEnabled);
case DISTINCTCOUNT:
return new DistinctCountAggregationFunction(arguments,
nullHandlingEnabled);
+ case DISTINCTCOUNTOFFHEAP:
+ return new DistinctCountOffHeapAggregationFunction(arguments,
nullHandlingEnabled);
case DISTINCTCOUNTBITMAP:
return new DistinctCountBitmapAggregationFunction(arguments);
case SEGMENTPARTITIONEDDISTINCTCOUNT:
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
index 076bc2ccda..7d1b1b5592 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
@@ -30,7 +30,7 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
/**
- * Aggregation function to compute the average of distinct values for an SV
column
+ * Aggregation function to compute the count of distinct values for an SV
column.
*/
public class DistinctCountAggregationFunction extends
BaseDistinctAggregateAggregationFunction<Integer> {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
index aa1cd6da66..9940ec3080 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
@@ -30,7 +30,7 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
/**
- * Aggregation function to compute the average of distinct values for an MV
column
+ * Aggregation function to compute the count of distinct values for an MV
column.
*/
public class DistinctCountMVAggregationFunction extends
BaseDistinctAggregateAggregationFunction<Integer> {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountOffHeapAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountOffHeapAggregationFunction.java
new file mode 100644
index 0000000000..f0352e79d0
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountOffHeapAggregationFunction.java
@@ -0,0 +1,533 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import java.util.BitSet;
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.CustomObject;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import
org.apache.pinot.core.query.aggregation.function.distinct.BaseOffHeapSet;
+import
org.apache.pinot.core.query.aggregation.function.distinct.OffHeap128BitSet;
+import
org.apache.pinot.core.query.aggregation.function.distinct.OffHeap32BitSet;
+import
org.apache.pinot.core.query.aggregation.function.distinct.OffHeap64BitSet;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+
+
+/// Aggregation function to compute the count of distinct values for a column
using off-heap memory.
+public class DistinctCountOffHeapAggregationFunction
+ extends NullableSingleInputAggregationFunction<BaseOffHeapSet, Integer> {
+ // Use empty OffHeap32BitSet as a placeholder for empty result
+ // NOTE: It is okay to close it (multiple times) since we are never adding
values into it
+ private static final OffHeap32BitSet EMPTY_PLACEHOLDER = new
OffHeap32BitSet(0);
+
+ private final int _initialCapacity;
+ private final int _hashBits;
+
+ public DistinctCountOffHeapAggregationFunction(List<ExpressionContext>
arguments, boolean nullHandlingEnabled) {
+ super(arguments.get(0), nullHandlingEnabled);
+ if (arguments.size() > 1) {
+ Parameters parameters = new
Parameters(arguments.get(1).getLiteral().getStringValue());
+ _initialCapacity = parameters._initialCapacity;
+ _hashBits = parameters._hashBits;
+ } else {
+ _initialCapacity = Parameters.DEFAULT_INITIAL_CAPACITY;
+ _hashBits = Parameters.DEFAULT_HASH_BITS;
+ }
+ }
+
+ @Override
+ public AggregationFunctionType getType() {
+ return AggregationFunctionType.DISTINCTCOUNTOFFHEAP;
+ }
+
+ @Override
+ public AggregationResultHolder createAggregationResultHolder() {
+ return new ObjectAggregationResultHolder();
+ }
+
+ @Override
+ public GroupByResultHolder createGroupByResultHolder(int initialCapacity,
int maxCapacity) {
+ throw new UnsupportedOperationException(
+ "DISTINCT_COUNT_OFF_HEAP cannot be applied to group-by queries. Use
DISTINCT_COUNT instead.");
+ }
+
+ @Override
+ public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ Dictionary dictionary = blockValSet.getDictionary();
+ if (dictionary != null) {
+ // For dictionary-encoded expression, store dictionary ids into the
bitmap
+ if (blockValSet.isSingleValue()) {
+ int[] dictIds = blockValSet.getDictionaryIdsSV();
+ BitSet dictIdBitSet = getDictIdBitSet(aggregationResultHolder,
dictionary);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ dictIdBitSet.set(dictIds[i]);
+ }
+ });
+ } else {
+ int[][] dictIds = blockValSet.getDictionaryIdsMV();
+ BitSet dictIdBitSet = getDictIdBitSet(aggregationResultHolder,
dictionary);
+ for (int i = 0; i < length; i++) {
+ for (int dictId : dictIds[i]) {
+ dictIdBitSet.set(dictId);
+ }
+ }
+ }
+ } else {
+ // For non-dictionary-encoded expression, add values into the value set
+ BaseOffHeapSet valueSet = aggregationResultHolder.getResult();
+ if (valueSet == null) {
+ valueSet = createValueSet(blockValSet.getValueType().getStoredType());
+ aggregationResultHolder.setValue(valueSet);
+ }
+ if (blockValSet.isSingleValue()) {
+ addToValueSetSV(length, blockValSet, valueSet);
+ } else {
+ addToValueSetMV(length, blockValSet, valueSet);
+ }
+ }
+ }
+
+ private static BitSet getDictIdBitSet(AggregationResultHolder
aggregationResultHolder, Dictionary dictionary) {
+ DictIdsWrapper dictIdsWrapper = aggregationResultHolder.getResult();
+ if (dictIdsWrapper == null) {
+ dictIdsWrapper = new DictIdsWrapper(dictionary);
+ aggregationResultHolder.setValue(dictIdsWrapper);
+ }
+ return dictIdsWrapper._bitSet;
+ }
+
+ private BaseOffHeapSet createValueSet(DataType storedType) {
+ switch (storedType) {
+ case INT:
+ case FLOAT:
+ return new OffHeap32BitSet(_initialCapacity);
+ case LONG:
+ case DOUBLE:
+ return new OffHeap64BitSet(_initialCapacity);
+ default:
+ switch (_hashBits) {
+ case 32:
+ return new OffHeap32BitSet(_initialCapacity);
+ case 64:
+ return new OffHeap64BitSet(_initialCapacity);
+ case 128:
+ return new OffHeap128BitSet(_initialCapacity);
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ }
+
+ private void addToValueSetSV(int length, BlockValSet blockValSet,
BaseOffHeapSet valueSet) {
+ DataType storedType = blockValSet.getValueType().getStoredType();
+ switch (storedType) {
+ case INT:
+ OffHeap32BitSet intSet = (OffHeap32BitSet) valueSet;
+ int[] intValues = blockValSet.getIntValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ intSet.add(intValues[i]);
+ }
+ });
+ break;
+ case LONG:
+ OffHeap64BitSet longSet = (OffHeap64BitSet) valueSet;
+ long[] longValues = blockValSet.getLongValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ longSet.add(longValues[i]);
+ }
+ });
+ break;
+ case FLOAT:
+ OffHeap32BitSet floatSet = (OffHeap32BitSet) valueSet;
+ float[] floatValues = blockValSet.getFloatValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ floatSet.add(Float.floatToRawIntBits(floatValues[i]));
+ }
+ });
+ break;
+ case DOUBLE:
+ OffHeap64BitSet doubleSet = (OffHeap64BitSet) valueSet;
+ double[] doubleValues = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ doubleSet.add(Double.doubleToRawLongBits(doubleValues[i]));
+ }
+ });
+ break;
+ default:
+ switch (_hashBits) {
+ case 32:
+ OffHeap32BitSet valueSet32 = (OffHeap32BitSet) valueSet;
+ int[] hashValues32 = blockValSet.get32BitsMurmur3HashValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet32.add(hashValues32[i]);
+ }
+ });
+ break;
+ case 64:
+ OffHeap64BitSet valueSet64 = (OffHeap64BitSet) valueSet;
+ long[] hashValues64 = blockValSet.get64BitsMurmur3HashValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet64.add(hashValues64[i]);
+ }
+ });
+ break;
+ case 128:
+ OffHeap128BitSet valueSet128 = (OffHeap128BitSet) valueSet;
+ long[][] hashValues128 =
blockValSet.get128BitsMurmur3HashValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ long[] hashValue = hashValues128[i];
+ valueSet128.add(hashValue[0], hashValue[1]);
+ }
+ });
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ break;
+ }
+ }
+
+ private void addToValueSetMV(int length, BlockValSet blockValSet,
BaseOffHeapSet valueSet) {
+ DataType storedType = blockValSet.getValueType().getStoredType();
+ switch (storedType) {
+ case INT:
+ OffHeap32BitSet intSet = (OffHeap32BitSet) valueSet;
+ int[][] intValues = blockValSet.getIntValuesMV();
+ for (int i = 0; i < length; i++) {
+ for (int intValue : intValues[i]) {
+ intSet.add(intValue);
+ }
+ }
+ break;
+ case LONG:
+ OffHeap64BitSet longSet = (OffHeap64BitSet) valueSet;
+ long[][] longValues = blockValSet.getLongValuesMV();
+ for (int i = 0; i < length; i++) {
+ for (long longValue : longValues[i]) {
+ longSet.add(longValue);
+ }
+ }
+ break;
+ case FLOAT:
+ OffHeap32BitSet floatSet = (OffHeap32BitSet) valueSet;
+ float[][] floatValues = blockValSet.getFloatValuesMV();
+ for (int i = 0; i < length; i++) {
+ for (float floatValue : floatValues[i]) {
+ floatSet.add(Float.floatToRawIntBits(floatValue));
+ }
+ }
+ break;
+ case DOUBLE:
+ OffHeap64BitSet doubleSet = (OffHeap64BitSet) valueSet;
+ double[][] doubleValues = blockValSet.getDoubleValuesMV();
+ for (int i = 0; i < length; i++) {
+ for (double doubleValue : doubleValues[i]) {
+ doubleSet.add(Double.doubleToRawLongBits(doubleValue));
+ }
+ }
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "DISTINCT_COUNT_OFF_HEAP does not support MV columns of type: " +
blockValSet.getValueType()
+ + ". Use DISTINCT_COUNT instead.");
+ }
+ }
+
+ @Override
+ public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BaseOffHeapSet extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ Object result = aggregationResultHolder.getResult();
+ if (result == null) {
+ return EMPTY_PLACEHOLDER;
+ }
+ if (result instanceof DictIdsWrapper) {
+ return extractAggregationResult((DictIdsWrapper) result);
+ } else {
+ return (BaseOffHeapSet) result;
+ }
+ }
+
+ private BaseOffHeapSet extractAggregationResult(DictIdsWrapper
dictIdsWrapper) {
+ BitSet bitSet = dictIdsWrapper._bitSet;
+ int length = bitSet.cardinality();
+ Dictionary dictionary = dictIdsWrapper._dictionary;
+ DataType storedType = dictionary.getValueType();
+ switch (storedType) {
+ case INT:
+ OffHeap32BitSet intSet = new OffHeap32BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i +
1)) {
+ intSet.add(dictionary.getIntValue(i));
+ }
+ return intSet;
+ case LONG:
+ OffHeap64BitSet longSet = new OffHeap64BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i +
1)) {
+ longSet.add(dictionary.getLongValue(i));
+ }
+ return longSet;
+ case FLOAT:
+ OffHeap32BitSet floatSet = new OffHeap32BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i +
1)) {
+ floatSet.add(Float.floatToRawIntBits(dictionary.getFloatValue(i)));
+ }
+ return floatSet;
+ case DOUBLE:
+ OffHeap64BitSet doubleSet = new OffHeap64BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i +
1)) {
+
doubleSet.add(Double.doubleToRawLongBits(dictionary.getDoubleValue(i)));
+ }
+ return doubleSet;
+ default:
+ switch (_hashBits) {
+ case 32:
+ OffHeap32BitSet valueSet32 = new OffHeap32BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i
+ 1)) {
+ valueSet32.add(dictionary.get32BitsMurmur3HashValue(i));
+ }
+ return valueSet32;
+ case 64:
+ OffHeap64BitSet valueSet64 = new OffHeap64BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i
+ 1)) {
+ valueSet64.add(dictionary.get64BitsMurmur3HashValue(i));
+ }
+ return valueSet64;
+ case 128:
+ OffHeap128BitSet valueSet128 = new OffHeap128BitSet(length);
+ for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i
+ 1)) {
+ long[] hashValue = dictionary.get128BitsMurmur3HashValue(i);
+ valueSet128.add(hashValue[0], hashValue[1]);
+ }
+ return valueSet128;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ }
+
+ /// Extracts the value set from the dictionary.
+ public BaseOffHeapSet extractAggregationResult(Dictionary dictionary) {
+ int length = dictionary.length();
+ DataType storedType = dictionary.getValueType();
+ switch (storedType) {
+ case INT:
+ OffHeap32BitSet intSet = new OffHeap32BitSet(length);
+ for (int i = 0; i < length; i++) {
+ intSet.add(dictionary.getIntValue(i));
+ }
+ return intSet;
+ case LONG:
+ OffHeap64BitSet longSet = new OffHeap64BitSet(length);
+ for (int i = 0; i < length; i++) {
+ longSet.add(dictionary.getLongValue(i));
+ }
+ return longSet;
+ case FLOAT:
+ OffHeap32BitSet floatSet = new OffHeap32BitSet(length);
+ for (int i = 0; i < length; i++) {
+ floatSet.add(Float.floatToRawIntBits(dictionary.getFloatValue(i)));
+ }
+ return floatSet;
+ case DOUBLE:
+ OffHeap64BitSet doubleSet = new OffHeap64BitSet(length);
+ for (int i = 0; i < length; i++) {
+
doubleSet.add(Double.doubleToRawLongBits(dictionary.getDoubleValue(i)));
+ }
+ return doubleSet;
+ default:
+ switch (_hashBits) {
+ case 32:
+ OffHeap32BitSet valueSet32 = new OffHeap32BitSet(length);
+ for (int i = 0; i < length; i++) {
+ valueSet32.add(dictionary.get32BitsMurmur3HashValue(i));
+ }
+ return valueSet32;
+ case 64:
+ OffHeap64BitSet valueSet64 = new OffHeap64BitSet(length);
+ for (int i = 0; i < length; i++) {
+ valueSet64.add(dictionary.get64BitsMurmur3HashValue(i));
+ }
+ return valueSet64;
+ case 128:
+ OffHeap128BitSet valueSet128 = new OffHeap128BitSet(length);
+ for (int i = 0; i < length; i++) {
+ long[] hashValue = dictionary.get128BitsMurmur3HashValue(i);
+ valueSet128.add(hashValue[0], hashValue[1]);
+ }
+ return valueSet128;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ }
+
+ @Override
+ public BaseOffHeapSet extractGroupByResult(GroupByResultHolder
groupByResultHolder, int groupKey) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BaseOffHeapSet merge(BaseOffHeapSet intermediateResult1,
BaseOffHeapSet intermediateResult2) {
+ assert intermediateResult1 != null && intermediateResult2 != null;
+ if (intermediateResult1.isEmpty()) {
+ intermediateResult1.close();
+ return intermediateResult2;
+ }
+ intermediateResult1.merge(intermediateResult2);
+ intermediateResult2.close();
+ return intermediateResult1;
+ }
+
+ @Override
+ public ColumnDataType getIntermediateResultColumnType() {
+ return ColumnDataType.OBJECT;
+ }
+
+ @Override
+ public SerializedIntermediateResult
serializeIntermediateResult(BaseOffHeapSet set) {
+ int type;
+ if (set instanceof OffHeap32BitSet) {
+ type = 0;
+ } else if (set instanceof OffHeap64BitSet) {
+ type = 1;
+ } else if (set instanceof OffHeap128BitSet) {
+ type = 2;
+ } else {
+ throw new IllegalStateException();
+ }
+ byte[] bytes = set.serialize();
+ set.close();
+ return new SerializedIntermediateResult(type, bytes);
+ }
+
+ @Override
+ public BaseOffHeapSet deserializeIntermediateResult(CustomObject
customObject) {
+ switch (customObject.getType()) {
+ case 0:
+ return OffHeap32BitSet.deserialize(customObject.getBuffer());
+ case 1:
+ return OffHeap64BitSet.deserialize(customObject.getBuffer());
+ case 2:
+ return OffHeap128BitSet.deserialize(customObject.getBuffer());
+ default:
+ throw new IllegalStateException();
+ }
+ }
+
+ @Override
+ public ColumnDataType getFinalResultColumnType() {
+ return ColumnDataType.INT;
+ }
+
+ @Override
+ public Integer extractFinalResult(BaseOffHeapSet set) {
+ assert set != null;
+ int size = set.size();
+ set.close();
+ return size;
+ }
+
+ @Override
+ public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+ return finalResult1 + finalResult2;
+ }
+
+ /// Helper class to wrap the dictionary ids.
+ /// Different from the
BaseDistinctAggregateAggregationFunction.DictIdsWrapper, here we use a
pre-allocated BitSet
+ /// instead of RoaringBitmap for better performance on high cardinality
distinct count.
+ private static final class DictIdsWrapper {
+ final Dictionary _dictionary;
+ final BitSet _bitSet;
+
+ DictIdsWrapper(Dictionary dictionary) {
+ _dictionary = dictionary;
+ _bitSet = new BitSet(dictionary.length());
+ }
+ }
+
+ /// Helper class to wrap the parameters.
+ private static class Parameters {
+ static final char PARAMETER_DELIMITER = ';';
+ static final char PARAMETER_KEY_VALUE_SEPARATOR = '=';
+
+ static final String INITIAL_CAPACITY_KEY = "INITIALCAPACITY";
+ static final int DEFAULT_INITIAL_CAPACITY = 10_000;
+
+ static final String HASH_BITS_KEY = "HASHBITS";
+ static final int DEFAULT_HASH_BITS = 64;
+
+ int _initialCapacity = DEFAULT_INITIAL_CAPACITY;
+ int _hashBits = DEFAULT_HASH_BITS;
+
+ Parameters(String parametersString) {
+ StringUtils.deleteWhitespace(parametersString);
+ String[] keyValuePairs = StringUtils.split(parametersString,
PARAMETER_DELIMITER);
+ for (String keyValuePair : keyValuePairs) {
+ String[] keyAndValue = StringUtils.split(keyValuePair,
PARAMETER_KEY_VALUE_SEPARATOR);
+ Preconditions.checkArgument(keyAndValue.length == 2, "Invalid
parameter: %s", keyValuePair);
+ String key = keyAndValue[0];
+ String value = keyAndValue[1];
+ switch (key.toUpperCase()) {
+ case INITIAL_CAPACITY_KEY:
+ _initialCapacity = Integer.parseInt(value);
+ Preconditions.checkArgument(_initialCapacity > 0, "Initial
capacity must be > 0, got: %s",
+ _initialCapacity);
+ break;
+ case HASH_BITS_KEY:
+ _hashBits = Integer.parseInt(value);
+ Preconditions.checkArgument(_hashBits == 32 || _hashBits == 64 ||
_hashBits == 128,
+ "Hash bits must be 32, 64 or 128, got: %s", _hashBits);
+ break;
+ default:
+ throw new IllegalArgumentException("Invalid parameter key: " +
key);
+ }
+ }
+ }
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java
index 598e15debf..e3c6b6240f 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java
@@ -40,6 +40,9 @@ import
org.apache.pinot.core.operator.query.AggregationOperator;
import org.apache.pinot.core.operator.query.GroupByOperator;
import org.apache.pinot.core.operator.query.NonScanBasedAggregationOperator;
import
org.apache.pinot.core.query.aggregation.function.DistinctCountSmartHLLAggregationFunction;
+import
org.apache.pinot.core.query.aggregation.function.distinct.BaseOffHeapSet;
+import
org.apache.pinot.core.query.aggregation.function.distinct.OffHeap128BitSet;
+import
org.apache.pinot.core.query.aggregation.function.distinct.OffHeap64BitSet;
import
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.core.query.request.context.QueryContext;
@@ -269,6 +272,80 @@ public class DistinctCountQueriesTest extends
BaseQueriesTest {
4 * NUM_RECORDS, expectedRows);
}
@Test
public void testOffHeap() {
  // Dictionary based: all columns have dictionaries, so the plan should pick the
  // non-scan based aggregation operator and read distinct values straight from the dictionaries.
  String query = "SELECT "
      + "DISTINCTCOUNTOFFHEAP(intColumn), "
      + "DISTINCTCOUNTOFFHEAP(longColumn), "
      + "DISTINCTCOUNTOFFHEAP(floatColumn), "
      + "DISTINCTCOUNTOFFHEAP(doubleColumn), "
      + "DISTINCTCOUNTOFFHEAP(stringColumn), "
      + "DISTINCTCOUNTOFFHEAP(bytesColumn) "
      + "FROM testTable";

  // Inner segment: with or without a filter that matches everything, the operator should be
  // dictionary based and each of the 6 results should contain every distinct value.
  for (Object operator : Arrays.asList(getOperator(query), getOperatorWithFilter(query))) {
    assertTrue(operator instanceof NonScanBasedAggregationOperator);
    AggregationResultsBlock resultsBlock = ((NonScanBasedAggregationOperator) operator).nextBlock();
    QueriesTestUtils.testInnerSegmentExecutionStatistics(((Operator) operator).getExecutionStatistics(), NUM_RECORDS,
        0, 0, NUM_RECORDS);
    List<Object> aggregationResult = resultsBlock.getResults();
    assertNotNull(aggregationResult);
    assertEquals(aggregationResult.size(), 6);
    for (int i = 0; i < 6; i++) {
      // Intermediate result is an off-heap set; its size is the distinct count
      assertEquals(((BaseOffHeapSet) aggregationResult.get(i)).size(), _values.size());
    }
  }

  // Inter segments: broker-level result over 4 identical segments
  Object[] expectedResults = Collections.nCopies(6, _values.size()).toArray();
  for (BrokerResponseNative brokerResponse : Arrays.asList(getBrokerResponse(query),
      getBrokerResponseWithFilter(query))) {
    QueriesTestUtils.testInterSegmentsResult(brokerResponse, 4 * NUM_RECORDS, 0, 0, 4 * NUM_RECORDS, expectedResults);
  }

  // Regular aggregation: a selective filter forces the scan-based aggregation path
  query = query + " WHERE intColumn >= 500";

  // Inner segment: compute the expected distinct count over the filtered values
  int expectedResult = 0;
  for (Integer value : _values) {
    if (value >= 500) {
      expectedResult++;
    }
  }
  AggregationOperator aggregationOperator = getOperator(query);
  List<Object> aggregationResult = aggregationOperator.nextBlock().getResults();
  assertNotNull(aggregationResult);
  assertEquals(aggregationResult.size(), 6);
  for (int i = 0; i < 6; i++) {
    assertEquals(((BaseOffHeapSet) aggregationResult.get(i)).size(), expectedResult);
  }

  // Inter segment
  expectedResults = Collections.nCopies(6, expectedResult).toArray();
  QueriesTestUtils.testInterSegmentsResult(getBrokerResponse(query), expectedResults);

  // Change parameters: hashbits=128 should make a STRING column use the 128-bit hash set
  query = "SELECT DISTINCTCOUNTOFFHEAP(stringColumn, 'initialcapacity=10;hashbits=128') FROM testTable";
  NonScanBasedAggregationOperator nonScanOperator = getOperator(query);
  aggregationResult = nonScanOperator.nextBlock().getResults();
  assertNotNull(aggregationResult);
  assertEquals(aggregationResult.size(), 1);
  assertTrue(aggregationResult.get(0) instanceof OffHeap128BitSet);
  assertEquals(((OffHeap128BitSet) aggregationResult.get(0)).size(), _values.size());

  // Default hashbits (64) with a custom initial capacity on a BYTES column
  query = "SELECT DISTINCTCOUNTOFFHEAP(bytesColumn, 'initialcapacity=100') FROM testTable "
      + "WHERE intColumn >= 500";
  aggregationOperator = getOperator(query);
  aggregationResult = aggregationOperator.nextBlock().getResults();
  assertNotNull(aggregationResult);
  assertEquals(aggregationResult.size(), 1);
  assertTrue(aggregationResult.get(0) instanceof OffHeap64BitSet);
  assertEquals(((OffHeap64BitSet) aggregationResult.get(0)).size(), expectedResult);
}
+
@Test
public void testHLL() {
// Dictionary based
diff --git
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 623ecaf8ba..cace39ff1a 100644
---
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -70,6 +70,9 @@ public enum AggregationFunctionType {
* (2) count(distinct ...) support multi-argument and will be converted into
DISTINCT + COUNT
*/
DISTINCTCOUNT("distinctCount", ReturnTypes.BIGINT, OperandTypes.ANY,
SqlTypeName.OTHER, SqlTypeName.INTEGER),
+ DISTINCTCOUNTOFFHEAP("distinctCountOffHeap", ReturnTypes.BIGINT,
+ OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER),
i -> i == 1), SqlTypeName.OTHER,
+ SqlTypeName.INTEGER),
DISTINCTSUM("distinctSum", ReturnTypes.AGG_SUM, OperandTypes.NUMERIC,
SqlTypeName.OTHER, SqlTypeName.DOUBLE),
DISTINCTAVG("distinctAvg", ReturnTypes.DOUBLE, OperandTypes.NUMERIC,
SqlTypeName.OTHER),
DISTINCTCOUNTBITMAP("distinctCountBitmap", ReturnTypes.BIGINT,
OperandTypes.ANY, SqlTypeName.OTHER,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]