This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 657e245 Add SegmentPartitionedDistinctCount aggregation function
(#5786)
657e245 is described below
commit 657e2452176c45edfe71a3c69d60ce3b7cec6982
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Mon Aug 3 15:24:16 2020 -0700
Add SegmentPartitionedDistinctCount aggregation function (#5786)
Add a new `SegmentPartitionedDistinctCountAggregationFunction` to calculate
the number of distinct values when the values are partitioned by segment
(i.e. no value appears in more than one segment).
This function calculates the exact number of distinct values (using the raw
values instead of their hash codes) within each segment, then simply sums up
the per-segment results to get the final result.
---
.../common/function/AggregationFunctionType.java | 1 +
.../query/DictionaryBasedAggregationOperator.java | 3 +
.../core/plan/maker/InstancePlanMakerImplV2.java | 8 +-
.../function/AggregationFunctionFactory.java | 2 +
.../function/AggregationFunctionUtils.java | 14 +-
.../function/AggregationFunctionVisitorBase.java | 4 +-
...artitionedDistinctCountAggregationFunction.java | 425 +++++++++++++++++++++
...SegmentPartitionedDistinctCountQueriesTest.java | 253 ++++++++++++
8 files changed, 699 insertions(+), 11 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
index cf48ea6..62704db 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
@@ -28,6 +28,7 @@ public enum AggregationFunctionType {
MINMAXRANGE("minMaxRange"),
DISTINCTCOUNT("distinctCount"),
DISTINCTCOUNTBITMAP("distinctCountBitmap"),
+ SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount"),
DISTINCTCOUNTHLL("distinctCountHLL"),
DISTINCTCOUNTRAWHLL("distinctCountRawHLL"),
FASTHLL("fastHLL"),
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
index 35abe86..7fa6798 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
@@ -114,6 +114,9 @@ public class DictionaryBasedAggregationOperator extends
BaseOperator<Intermediat
}
aggregationResults.add(set);
break;
+ case SEGMENTPARTITIONEDDISTINCTCOUNT:
+ aggregationResults.add((long) dictionarySize);
+ break;
default:
throw new IllegalStateException(
"Dictionary based aggregation operator does not support function
type: " + aggregationFunction.getType());
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
index d7f8bad..caae551 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
@@ -158,7 +158,7 @@ public class InstancePlanMakerImplV2 implements PlanMaker {
/**
* Returns {@code true} if the given aggregation-only without filter
QueryContext can be solved with dictionary,
* {@code false} otherwise.
- * <p>Aggregations supported: MIN, MAX, MINMAXRANGE, DISTINCTCOUNT
+ * <p>Aggregations supported: MIN, MAX, MIN_MAX_RANGE, DISTINCT_COUNT,
SEGMENT_PARTITIONED_DISTINCT_COUNT
*/
@VisibleForTesting
static boolean isFitForDictionaryBasedPlan(QueryContext queryContext,
IndexSegment indexSegment) {
@@ -179,8 +179,10 @@ public class InstancePlanMakerImplV2 implements PlanMaker {
if (dictionary == null) {
return false;
}
- // NOTE: DISTINCTCOUNT does not require sorted dictionary
- if (!dictionary.isSorted() &&
!functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())) {
+ // TODO: Remove this check because MutableDictionary maintains min/max
value
+ // NOTE: DISTINCT_COUNT and SEGMENT_PARTITIONED_DISTINCT_COUNT do not require a sorted dictionary
+ if (!dictionary.isSorted() &&
!functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())
+ &&
!functionName.equalsIgnoreCase(AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT.name()))
{
return false;
}
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 3a1bb01..08eed30 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -123,6 +123,8 @@ public class AggregationFunctionFactory {
return new DistinctCountAggregationFunction(firstArgument);
case DISTINCTCOUNTBITMAP:
return new DistinctCountBitmapAggregationFunction(firstArgument);
+ case SEGMENTPARTITIONEDDISTINCTCOUNT:
+ return new
SegmentPartitionedDistinctCountAggregationFunction(firstArgument);
case DISTINCTCOUNTHLL:
return new DistinctCountHLLAggregationFunction(arguments);
case DISTINCTCOUNTRAWHLL:
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 3822373..b5784fb 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -209,12 +209,12 @@ public class AggregationFunctionUtils {
}
public static boolean isFitForDictionaryBasedComputation(String
functionName) {
- if (functionName.equalsIgnoreCase(AggregationFunctionType.MIN.name()) ||
//
- functionName.equalsIgnoreCase(AggregationFunctionType.MAX.name()) || //
-
functionName.equalsIgnoreCase(AggregationFunctionType.MINMAXRANGE.name()) || //
-
functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())) {
- return true;
- }
- return false;
+ //@formatter:off
+ return functionName.equalsIgnoreCase(AggregationFunctionType.MIN.name())
+ || functionName.equalsIgnoreCase(AggregationFunctionType.MAX.name())
+ ||
functionName.equalsIgnoreCase(AggregationFunctionType.MINMAXRANGE.name())
+ ||
functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())
+ ||
functionName.equalsIgnoreCase(AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT.name());
+ //@formatter:on
}
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
index 72c8d4c..2710ba7 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
@@ -51,6 +51,9 @@ public class AggregationFunctionVisitorBase {
public void visit(DistinctCountBitmapMVAggregationFunction function) {
}
+ public void visit(SegmentPartitionedDistinctCountAggregationFunction
function) {
+ }
+
public void visit(DistinctCountHLLAggregationFunction function) {
}
@@ -106,7 +109,6 @@ public class AggregationFunctionVisitorBase {
}
public void visit(StUnionAggregationFunction function) {
-
}
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
new file mode 100644
index 0000000..221969d
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
@@ -0,0 +1,425 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
+import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
+import java.util.Collection;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.utils.ByteArray;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The {@code SegmentPartitionedDistinctCountAggregationFunction} calculates the number of distinct values for a given
+ * single-value expression.
+ * <p>IMPORTANT: This function relies on the expression values being partitioned by segment, i.e. no value appears in
+ * more than one segment. A value present in multiple segments is counted once per segment, so the final result
+ * over-counts when the values are not partitioned.
+ * <p>This function calculates the exact number of distinct values within each segment (using the raw values instead of
+ * their hash codes), then simply sums up the results from different segments to get the final result.
+ * <p>Inner segment results are kept in a {@link RoaringBitmap} for dictionary ids and INT values, and in a fastutil
+ * open hash set for LONG, FLOAT, DOUBLE, STRING and BYTES values.
+ */
+public class SegmentPartitionedDistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<Long, Long> {
+
+  public SegmentPartitionedDistinctCountAggregationFunction(ExpressionContext expression) {
+    super(expression);
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT;
+  }
+
+  @Override
+  public void accept(AggregationFunctionVisitorBase visitor) {
+    visitor.visit(this);
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+
+    // For dictionary-encoded expression, store dictionary ids into a RoaringBitmap
+    if (blockValSet.getDictionary() != null) {
+      int[] dictIds = blockValSet.getDictionaryIdsSV();
+      getOrCreateBitmap(aggregationResultHolder).addN(dictIds, 0, length);
+      return;
+    }
+
+    // For non-dictionary-encoded expression, store INT values into a RoaringBitmap, other types into an OpenHashSet
+    DataType valueType = blockValSet.getValueType();
+    switch (valueType) {
+      case INT:
+        int[] intValues = blockValSet.getIntValuesSV();
+        getOrCreateBitmap(aggregationResultHolder).addN(intValues, 0, length);
+        break;
+      case LONG:
+        long[] longValues = blockValSet.getLongValuesSV();
+        LongOpenHashSet longSet = aggregationResultHolder.getResult();
+        if (longSet == null) {
+          longSet = new LongOpenHashSet();
+          aggregationResultHolder.setValue(longSet);
+        }
+        for (int i = 0; i < length; i++) {
+          longSet.add(longValues[i]);
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        FloatOpenHashSet floatSet = aggregationResultHolder.getResult();
+        if (floatSet == null) {
+          floatSet = new FloatOpenHashSet();
+          aggregationResultHolder.setValue(floatSet);
+        }
+        for (int i = 0; i < length; i++) {
+          floatSet.add(floatValues[i]);
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        DoubleOpenHashSet doubleSet = aggregationResultHolder.getResult();
+        if (doubleSet == null) {
+          doubleSet = new DoubleOpenHashSet();
+          aggregationResultHolder.setValue(doubleSet);
+        }
+        for (int i = 0; i < length; i++) {
+          doubleSet.add(doubleValues[i]);
+        }
+        break;
+      case STRING:
+        String[] stringValues = blockValSet.getStringValuesSV();
+        ObjectOpenHashSet<String> stringSet = aggregationResultHolder.getResult();
+        if (stringSet == null) {
+          stringSet = new ObjectOpenHashSet<>();
+          aggregationResultHolder.setValue(stringSet);
+        }
+        //noinspection ManualArrayToCollectionCopy
+        for (int i = 0; i < length; i++) {
+          stringSet.add(stringValues[i]);
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = blockValSet.getBytesValuesSV();
+        ObjectOpenHashSet<ByteArray> bytesSet = aggregationResultHolder.getResult();
+        if (bytesSet == null) {
+          bytesSet = new ObjectOpenHashSet<>();
+          aggregationResultHolder.setValue(bytesSet);
+        }
+        // Wrap byte[] into ByteArray so that values are compared by content instead of by reference
+        for (int i = 0; i < length; i++) {
+          bytesSet.add(new ByteArray(bytesValues[i]));
+        }
+        break;
+      default:
+        throw getIllegalDataTypeException(valueType);
+    }
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+
+    // For dictionary-encoded expression, store dictionary ids into a RoaringBitmap
+    if (blockValSet.getDictionary() != null) {
+      int[] dictIds = blockValSet.getDictionaryIdsSV();
+      for (int i = 0; i < length; i++) {
+        setIntValueForGroup(groupByResultHolder, groupKeyArray[i], dictIds[i]);
+      }
+      return;
+    }
+
+    // For non-dictionary-encoded expression, store INT values into a RoaringBitmap, other types into an OpenHashSet
+    DataType valueType = blockValSet.getValueType();
+    switch (valueType) {
+      case INT:
+        int[] intValues = blockValSet.getIntValuesSV();
+        for (int i = 0; i < length; i++) {
+          setIntValueForGroup(groupByResultHolder, groupKeyArray[i], intValues[i]);
+        }
+        break;
+      case LONG:
+        long[] longValues = blockValSet.getLongValuesSV();
+        for (int i = 0; i < length; i++) {
+          setLongValueForGroup(groupByResultHolder, groupKeyArray[i], longValues[i]);
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        for (int i = 0; i < length; i++) {
+          setFloatValueForGroup(groupByResultHolder, groupKeyArray[i], floatValues[i]);
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        for (int i = 0; i < length; i++) {
+          setDoubleValueForGroup(groupByResultHolder, groupKeyArray[i], doubleValues[i]);
+        }
+        break;
+      case STRING:
+        String[] stringValues = blockValSet.getStringValuesSV();
+        for (int i = 0; i < length; i++) {
+          setStringValueForGroup(groupByResultHolder, groupKeyArray[i], stringValues[i]);
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = blockValSet.getBytesValuesSV();
+        for (int i = 0; i < length; i++) {
+          setBytesValueForGroup(groupByResultHolder, groupKeyArray[i], new ByteArray(bytesValues[i]));
+        }
+        break;
+      default:
+        throw getIllegalDataTypeException(valueType);
+    }
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+
+    // NOTE: The aggregated expression is single-valued; only the group keys are multi-valued, so each value is added
+    //       to every group key of its row.
+
+    // For dictionary-encoded expression, store dictionary ids into a RoaringBitmap
+    if (blockValSet.getDictionary() != null) {
+      int[] dictIds = blockValSet.getDictionaryIdsSV();
+      for (int i = 0; i < length; i++) {
+        int dictId = dictIds[i];
+        for (int groupKey : groupKeysArray[i]) {
+          setIntValueForGroup(groupByResultHolder, groupKey, dictId);
+        }
+      }
+      return;
+    }
+
+    // For non-dictionary-encoded expression, store INT values into a RoaringBitmap, other types into an OpenHashSet
+    DataType valueType = blockValSet.getValueType();
+    switch (valueType) {
+      case INT:
+        int[] intValues = blockValSet.getIntValuesSV();
+        for (int i = 0; i < length; i++) {
+          int value = intValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setIntValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case LONG:
+        long[] longValues = blockValSet.getLongValuesSV();
+        for (int i = 0; i < length; i++) {
+          long value = longValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setLongValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        for (int i = 0; i < length; i++) {
+          float value = floatValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setFloatValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        for (int i = 0; i < length; i++) {
+          double value = doubleValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setDoubleValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case STRING:
+        String[] stringValues = blockValSet.getStringValuesSV();
+        for (int i = 0; i < length; i++) {
+          String value = stringValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setStringValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = blockValSet.getBytesValuesSV();
+        for (int i = 0; i < length; i++) {
+          ByteArray value = new ByteArray(bytesValues[i]);
+          for (int groupKey : groupKeysArray[i]) {
+            setBytesValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      default:
+        throw getIllegalDataTypeException(valueType);
+    }
+  }
+
+  @Override
+  public Long extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    return extractIntermediateResult(aggregationResultHolder.getResult());
+  }
+
+  @Override
+  public Long extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return extractIntermediateResult(groupByResultHolder.getResult(groupKey));
+  }
+
+  /**
+   * Merges two per-segment distinct counts by summing them, which is only exact when the values are partitioned by
+   * segment (see class Javadoc).
+   */
+  @Override
+  public Long merge(Long intermediateResult1, Long intermediateResult2) {
+    return intermediateResult1 + intermediateResult2;
+  }
+
+  @Override
+  public boolean isIntermediateResultComparable() {
+    return true;
+  }
+
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.LONG;
+  }
+
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.LONG;
+  }
+
+  @Override
+  public Long extractFinalResult(Long intermediateResult) {
+    return intermediateResult;
+  }
+
+  /**
+   * Helper method to get the RoaringBitmap from the result holder, creating and registering a new one if absent.
+   */
+  private static RoaringBitmap getOrCreateBitmap(AggregationResultHolder aggregationResultHolder) {
+    RoaringBitmap bitmap = aggregationResultHolder.getResult();
+    if (bitmap == null) {
+      bitmap = new RoaringBitmap();
+      aggregationResultHolder.setValue(bitmap);
+    }
+    return bitmap;
+  }
+
+  /**
+   * Helper method to create the exception thrown when the expression has a value type this function does not support.
+   */
+  private static IllegalStateException getIllegalDataTypeException(DataType valueType) {
+    return new IllegalStateException(
+        "Illegal data type for SEGMENT_PARTITIONED_DISTINCT_COUNT aggregation function: " + valueType);
+  }
+
+  /**
+   * Helper method to set an INT value for the given group key into the result holder.
+   */
+  private static void setIntValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, int value) {
+    RoaringBitmap bitmap = groupByResultHolder.getResult(groupKey);
+    if (bitmap == null) {
+      bitmap = new RoaringBitmap();
+      groupByResultHolder.setValueForKey(groupKey, bitmap);
+    }
+    bitmap.add(value);
+  }
+
+  /**
+   * Helper method to set a LONG value for the given group key into the result holder.
+   */
+  private static void setLongValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, long value) {
+    LongOpenHashSet longSet = groupByResultHolder.getResult(groupKey);
+    if (longSet == null) {
+      longSet = new LongOpenHashSet();
+      groupByResultHolder.setValueForKey(groupKey, longSet);
+    }
+    longSet.add(value);
+  }
+
+  /**
+   * Helper method to set a FLOAT value for the given group key into the result holder.
+   */
+  private static void setFloatValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, float value) {
+    FloatOpenHashSet floatSet = groupByResultHolder.getResult(groupKey);
+    if (floatSet == null) {
+      floatSet = new FloatOpenHashSet();
+      groupByResultHolder.setValueForKey(groupKey, floatSet);
+    }
+    floatSet.add(value);
+  }
+
+  /**
+   * Helper method to set a DOUBLE value for the given group key into the result holder.
+   */
+  private static void setDoubleValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, double value) {
+    DoubleOpenHashSet doubleSet = groupByResultHolder.getResult(groupKey);
+    if (doubleSet == null) {
+      doubleSet = new DoubleOpenHashSet();
+      groupByResultHolder.setValueForKey(groupKey, doubleSet);
+    }
+    doubleSet.add(value);
+  }
+
+  /**
+   * Helper method to set a STRING value for the given group key into the result holder.
+   */
+  private static void setStringValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, String value) {
+    ObjectOpenHashSet<String> stringSet = groupByResultHolder.getResult(groupKey);
+    if (stringSet == null) {
+      stringSet = new ObjectOpenHashSet<>();
+      groupByResultHolder.setValueForKey(groupKey, stringSet);
+    }
+    stringSet.add(value);
+  }
+
+  /**
+   * Helper method to set a BYTES value for the given group key into the result holder.
+   */
+  private static void setBytesValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, ByteArray value) {
+    ObjectOpenHashSet<ByteArray> bytesSet = groupByResultHolder.getResult(groupKey);
+    if (bytesSet == null) {
+      bytesSet = new ObjectOpenHashSet<>();
+      groupByResultHolder.setValueForKey(groupKey, bytesSet);
+    }
+    bytesSet.add(value);
+  }
+
+  /**
+   * Helper method to extract segment level intermediate result from the inner segment result.
+   * <p>Returns 0 when no value was aggregated (result is {@code null}), the bitmap cardinality for a
+   * {@link RoaringBitmap}, or the size of the hash set otherwise.
+   */
+  private static long extractIntermediateResult(@Nullable Object result) {
+    if (result == null) {
+      return 0L;
+    }
+    if (result instanceof RoaringBitmap) {
+      return ((RoaringBitmap) result).getLongCardinality();
+    }
+    assert result instanceof Collection;
+    return ((Collection<?>) result).size();
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/SegmentPartitionedDistinctCountQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/SegmentPartitionedDistinctCountQueriesTest.java
new file mode 100644
index 0000000..bef3e57
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/SegmentPartitionedDistinctCountQueriesTest.java
@@ -0,0 +1,253 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.response.broker.AggregationResult;
+import org.apache.pinot.common.response.broker.BrokerResponseNative;
+import org.apache.pinot.common.response.broker.GroupByResult;
+import org.apache.pinot.common.segment.ReadMode;
+import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.common.utils.StringUtil;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.readers.GenericRowRecordReader;
+import org.apache.pinot.core.indexsegment.IndexSegment;
+import org.apache.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegment;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.operator.query.DictionaryBasedAggregationOperator;
+import
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import
org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Queries test for SEGMENT_PARTITIONED_DISTINCT_COUNT queries.
+ * <p>Every column stores the same random values (converted to the column type), so every aggregation in a query is
+ * expected to return the same result.
+ */
+@SuppressWarnings("rawtypes")
+public class SegmentPartitionedDistinctCountQueriesTest extends BaseQueriesTest {
+  private static final File INDEX_DIR =
+      new File(FileUtils.getTempDirectory(), "SegmentPartitionedDistinctCountQueriesTest");
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String SEGMENT_NAME = "testSegment";
+  private static final Random RANDOM = new Random();
+
+  private static final int NUM_RECORDS = 2000;
+  private static final int MAX_VALUE = 1000;
+  // Number of decimal digits of the largest generated value (MAX_VALUE - 1), used to left-pad the BYTES values
+  private static final int NUM_DIGITS = Integer.toString(MAX_VALUE - 1).length();
+  // Number of columns in the table; each test query aggregates on every column once
+  private static final int NUM_COLUMNS = 6;
+
+  private static final String INT_COLUMN = "intColumn";
+  private static final String LONG_COLUMN = "longColumn";
+  private static final String FLOAT_COLUMN = "floatColumn";
+  private static final String DOUBLE_COLUMN = "doubleColumn";
+  private static final String STRING_COLUMN = "stringColumn";
+  private static final String BYTES_COLUMN = "bytesColumn";
+  private static final Schema SCHEMA = new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN, DataType.INT)
+      .addSingleValueDimension(LONG_COLUMN, DataType.LONG).addSingleValueDimension(FLOAT_COLUMN, DataType.FLOAT)
+      .addSingleValueDimension(DOUBLE_COLUMN, DataType.DOUBLE).addSingleValueDimension(STRING_COLUMN, DataType.STRING)
+      .addSingleValueDimension(BYTES_COLUMN, DataType.BYTES).build();
+  private static final TableConfig TABLE_CONFIG =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+  private Set<Integer> _values;
+  private long _expectedResult;
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  @Override
+  protected String getFilter() {
+    // NOTE: Use a match all filter to switch between DictionaryBasedAggregationOperator and AggregationOperator
+    return " WHERE intColumn >= 0";
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE);
+    _values = new HashSet<>(hashMapCapacity);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      int value = RANDOM.nextInt(MAX_VALUE);
+      GenericRow record = new GenericRow();
+      record.putValue(INT_COLUMN, value);
+      _values.add(value);
+      record.putValue(LONG_COLUMN, (long) value);
+      record.putValue(FLOAT_COLUMN, (float) value);
+      record.putValue(DOUBLE_COLUMN, (double) value);
+      String stringValue = Integer.toString(value);
+      record.putValue(STRING_COLUMN, stringValue);
+      // NOTE: Create fixed-length bytes so that dictionary can be generated
+      byte[] bytesValue = StringUtil.encodeUtf8(StringUtils.leftPad(stringValue, NUM_DIGITS, '0'));
+      record.putValue(BYTES_COLUMN, bytesValue);
+      records.add(record);
+    }
+    _expectedResult = _values.size();
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build();
+
+    ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+    _indexSegment = immutableSegment;
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+  }
+
+  @Test
+  public void testAggregationOnly() {
+    String query = "SELECT SEGMENTPARTITIONEDDISTINCTCOUNT(intColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(longColumn), "
+        + "SEGMENTPARTITIONEDDISTINCTCOUNT(floatColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(doubleColumn), "
+        + "SEGMENTPARTITIONEDDISTINCTCOUNT(stringColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(bytesColumn) FROM testTable";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof DictionaryBasedAggregationOperator);
+    IntermediateResultsBlock resultsBlock = ((DictionaryBasedAggregationOperator) operator).nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, 0, NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+    operator = getOperatorForPqlQueryWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+            NUM_COLUMNS * NUM_RECORDS, NUM_RECORDS);
+    List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult();
+
+    assertNotNull(aggregationResult);
+    assertNotNull(aggregationResultWithFilter);
+    assertEquals(aggregationResult, aggregationResultWithFilter);
+    for (int i = 0; i < NUM_COLUMNS; i++) {
+      assertEquals((long) aggregationResult.get(i), _expectedResult);
+    }
+
+    // Inter segments (expect 4 * inner segment result)
+    String[] expectedResults = new String[NUM_COLUMNS];
+    Arrays.fill(expectedResults, Long.toString(4 * _expectedResult));
+    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+    QueriesTestUtils
+        .testInterSegmentAggregationResult(brokerResponse, 4 * NUM_RECORDS, 0, 0, 4 * NUM_RECORDS, expectedResults);
+    brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+    QueriesTestUtils
+        .testInterSegmentAggregationResult(brokerResponse, 4 * NUM_RECORDS, 0, 4 * NUM_COLUMNS * NUM_RECORDS,
+            4 * NUM_RECORDS, expectedResults);
+  }
+
+  @Test
+  public void testAggregationGroupBy() {
+    String query = "SELECT SEGMENTPARTITIONEDDISTINCTCOUNT(intColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(longColumn), "
+        + "SEGMENTPARTITIONEDDISTINCTCOUNT(floatColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(doubleColumn), "
+        + "SEGMENTPARTITIONEDDISTINCTCOUNT(stringColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(bytesColumn) FROM testTable "
+        + "GROUP BY intColumn";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof AggregationGroupByOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+            NUM_COLUMNS * NUM_RECORDS, NUM_RECORDS);
+    AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+    int numGroups = 0;
+    Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator();
+    while (groupKeyIterator.hasNext()) {
+      numGroups++;
+      GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
+      assertTrue(_values.contains(Integer.parseInt(groupKey._stringKey)));
+      // Grouping by the value itself, so each group contains exactly one distinct value
+      for (int i = 0; i < NUM_COLUMNS; i++) {
+        assertEquals((long) aggregationGroupByResult.getResultForKey(groupKey, i), 1);
+      }
+    }
+    assertEquals(numGroups, _values.size());
+
+    // Inter segments (expect 4 * inner segment result)
+    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+    assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+    assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+    assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * NUM_COLUMNS * NUM_RECORDS);
+    assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+    // The size of this list equals the number of aggregation functions because each aggregation function is returned
+    // separately
+    List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+    int numAggregationColumns = aggregationResults.size();
+    assertEquals(numAggregationColumns, NUM_COLUMNS);
+    for (AggregationResult aggregationResult : aggregationResults) {
+      Assert.assertNull(aggregationResult.getValue());
+      List<GroupByResult> groupByResults = aggregationResult.getGroupByResult();
+      numGroups = groupByResults.size();
+      for (int i = 0; i < numGroups; i++) {
+        GroupByResult groupByResult = groupByResults.get(i);
+        List<String> group = groupByResult.getGroup();
+        assertEquals(group.size(), 1);
+        assertTrue(_values.contains(Integer.parseInt(group.get(0))));
+        assertEquals(groupByResult.getValue(), Long.toString(4));
+      }
+    }
+  }
+
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    _indexSegment.destroy();
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]