This is an automated email from the ASF dual-hosted git repository.
gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 01b0f8c272 [null-aggr] Add null handling support in `mode` aggregation
(#12227)
01b0f8c272 is described below
commit 01b0f8c272472e37fdaebd3288d6cd924088d8f5
Author: Gonzalo Ortiz Jaureguizar <[email protected]>
AuthorDate: Mon Mar 4 10:54:57 2024 +0100
[null-aggr] Add null handling support in `mode` aggregation (#12227)
Support null handling in the `mode` aggregation function. When null handling is
enabled, null values are ignored when the mode is calculated.
---
.../function/AggregationFunctionFactory.java | 4 +-
.../BaseSingleInputAggregationFunction.java | 21 +
.../function/ModeAggregationFunction.java | 147 +++--
.../NullableSingleInputAggregationFunction.java | 141 ++++
.../function/ModeAggregationFunctionTest.java | 273 ++++++++
.../apache/pinot/queries/AllNullQueriesTest.java | 729 ++++++++++++---------
.../pinot/queries/NullEnabledQueriesTest.java | 6 +-
.../perf/AbstractAggregationFunctionBenchmark.java | 217 ++++++
.../pinot/perf/BenchmarkModeAggregation.java | 175 +++++
.../apache/pinot/perf/SyntheticBlockValSets.java | 256 ++++++++
.../pinot/perf/SyntheticNullBitmapFactories.java | 89 +++
11 files changed, 1675 insertions(+), 383 deletions(-)
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index f6d069a564..8db0d730d7 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -49,7 +49,7 @@ public class AggregationFunctionFactory {
/**
* Given the function information, returns a new instance of the
corresponding aggregation function.
- * <p>NOTE: Underscores in the function name are ignored.
+ * <p>NOTE: Underscores in the function name are ignored in V1.
*/
public static AggregationFunction getAggregationFunction(FunctionContext
function, boolean nullHandlingEnabled) {
try {
@@ -208,7 +208,7 @@ public class AggregationFunctionFactory {
case AVG:
return new AvgAggregationFunction(arguments, nullHandlingEnabled);
case MODE:
- return new ModeAggregationFunction(arguments);
+ return new ModeAggregationFunction(arguments, nullHandlingEnabled);
case FIRSTWITHTIME: {
Preconditions.checkArgument(numArguments == 3,
"FIRST_WITH_TIME expects 3 arguments, got: %s. The function
can be used as "
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
index 6c15d06e08..e11ad8492a 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
@@ -21,7 +21,10 @@ package org.apache.pinot.core.query.aggregation.function;
import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.List;
+import java.util.function.Supplier;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
/**
@@ -67,4 +70,22 @@ public abstract class BaseSingleInputAggregationFunction<I,
F extends Comparable
arguments.size());
return arguments.get(0);
}
+
+ protected static <E> E getValue(AggregationResultHolder
aggregationResultHolder, Supplier<E> orCreate) {
+ E result = aggregationResultHolder.getResult();
+ if (result == null) {
+ result = orCreate.get();
+ aggregationResultHolder.setValue(result);
+ }
+ return result;
+ }
+
+ protected static <E> E getValue(GroupByResultHolder groupByResultHolder, int
groupKey, Supplier<E> orCreate) {
+ E result = groupByResultHolder.getResult(groupKey);
+ if (result == null) {
+ result = orCreate.get();
+ groupByResultHolder.setValueForKey(groupKey, result);
+ }
+ return result;
+ }
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
index 882167385d..dbef6d8dad 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java
@@ -56,14 +56,15 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
* </ul>
*/
@SuppressWarnings({"rawtypes", "unchecked"})
-public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<Map<? extends Number, Long>, Double> {
+public class ModeAggregationFunction
+ extends NullableSingleInputAggregationFunction<Map<? extends Number,
Long>, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
private final MultiModeReducerType _multiModeReducerType;
- public ModeAggregationFunction(List<ExpressionContext> arguments) {
- super(arguments.get(0));
+ public ModeAggregationFunction(List<ExpressionContext> arguments, boolean
nullHandlingEnabled) {
+ super(arguments.get(0), nullHandlingEnabled);
int numArguments = arguments.size();
Preconditions.checkArgument(numArguments <= 2, "Mode expects at most 2
arguments, got: %s", numArguments);
@@ -263,11 +264,14 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
// For dictionary-encoded expression, store dictionary ids into the dictId
map
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
- int[] dictIds = blockValSet.getDictionaryIdsSV();
+
Int2IntOpenHashMap dictIdValueMap =
getDictIdCountMap(aggregationResultHolder, dictionary);
- for (int i = 0; i < length; i++) {
- dictIdValueMap.merge(dictIds[i], 1, Integer::sum);
- }
+ int[] dictIds = blockValSet.getDictionaryIdsSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ dictIdValueMap.merge(dictIds[i], 1, Integer::sum);
+ }
+ });
return;
}
@@ -278,30 +282,38 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
case INT:
Int2LongOpenHashMap intMap = (Int2LongOpenHashMap) valueMap;
int[] intValues = blockValSet.getIntValuesSV();
- for (int i = 0; i < length; i++) {
- intMap.merge(intValues[i], 1, Long::sum);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ intMap.merge(intValues[i], 1, Long::sum);
+ }
+ });
break;
case LONG:
Long2LongOpenHashMap longMap = (Long2LongOpenHashMap) valueMap;
long[] longValues = blockValSet.getLongValuesSV();
- for (int i = 0; i < length; i++) {
- longMap.merge(longValues[i], 1, Long::sum);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ longMap.merge(longValues[i], 1, Long::sum);
+ }
+ });
break;
case FLOAT:
Float2LongOpenHashMap floatMap = (Float2LongOpenHashMap) valueMap;
float[] floatValues = blockValSet.getFloatValuesSV();
- for (int i = 0; i < length; i++) {
- floatMap.merge(floatValues[i], 1, Long::sum);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ floatMap.merge(floatValues[i], 1, Long::sum);
+ }
+ });
break;
case DOUBLE:
Double2LongOpenHashMap doubleMap = (Double2LongOpenHashMap) valueMap;
double[] doubleValues = blockValSet.getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- doubleMap.merge(doubleValues[i], 1, Long::sum);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ doubleMap.merge(doubleValues[i], 1, Long::sum);
+ }
+ });
break;
default:
throw new IllegalStateException("Illegal data type for MODE
aggregation function: " + storedType);
@@ -317,9 +329,12 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
- for (int i = 0; i < length; i++) {
- getDictIdCountMap(groupByResultHolder, groupKeyArray[i],
dictionary).merge(dictIds[i], 1, Integer::sum);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ Int2IntOpenHashMap dictIdCountMap =
getDictIdCountMap(groupByResultHolder, groupKeyArray[i], dictionary);
+ dictIdCountMap.merge(dictIds[i], 1, Integer::sum);
+ }
+ });
return;
}
@@ -328,27 +343,35 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
- for (int i = 0; i < length; i++) {
- setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
intValues[i]);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
intValues[i]);
+ }
+ });
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
- for (int i = 0; i < length; i++) {
- setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
longValues[i]);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
longValues[i]);
+ }
+ });
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
- for (int i = 0; i < length; i++) {
- setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
floatValues[i]);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
floatValues[i]);
+ }
+ });
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
doubleValues[i]);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setValueForGroupKeys(groupByResultHolder, groupKeyArray[i],
doubleValues[i]);
+ }
+ });
break;
default:
throw new IllegalStateException("Illegal data type for MODE
aggregation function: " + storedType);
@@ -364,11 +387,13 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- getDictIdCountMap(groupByResultHolder, groupKey,
dictionary).merge(dictIds[i], 1, Integer::sum);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ getDictIdCountMap(groupByResultHolder, groupKey,
dictionary).merge(dictIds[i], 1, Integer::sum);
+ }
}
- }
+ });
return;
}
@@ -377,35 +402,43 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- setValueForGroupKeys(groupByResultHolder, groupKey, intValues[i]);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ setValueForGroupKeys(groupByResultHolder, groupKey,
intValues[i]);
+ }
}
- }
+ });
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- setValueForGroupKeys(groupByResultHolder, groupKey, longValues[i]);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ setValueForGroupKeys(groupByResultHolder, groupKey,
longValues[i]);
+ }
}
- }
+ });
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- setValueForGroupKeys(groupByResultHolder, groupKey,
floatValues[i]);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ setValueForGroupKeys(groupByResultHolder, groupKey,
floatValues[i]);
+ }
}
- }
+ });
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- setValueForGroupKeys(groupByResultHolder, groupKey,
doubleValues[i]);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ setValueForGroupKeys(groupByResultHolder, groupKey,
doubleValues[i]);
+ }
}
- }
+ });
break;
default:
throw new IllegalStateException("Illegal data type for MODE
aggregation function: " + storedType);
@@ -467,7 +500,11 @@ public class ModeAggregationFunction extends
BaseSingleInputAggregationFunction<
@Override
public Double extractFinalResult(Map<? extends Number, Long>
intermediateResult) {
if (intermediateResult.isEmpty()) {
- return DEFAULT_FINAL_RESULT;
+ if (_nullHandlingEnabled) {
+ return null;
+ } else {
+ return DEFAULT_FINAL_RESULT;
+ }
} else if (intermediateResult instanceof Int2LongOpenHashMap) {
return extractFinalResult((Int2LongOpenHashMap) intermediateResult);
} else if (intermediateResult instanceof Long2LongOpenHashMap) {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java
new file mode 100644
index 0000000000..0a42db7442
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java
@@ -0,0 +1,141 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import javax.annotation.Nullable;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.BlockValSet;
+import org.roaringbitmap.IntIterator;
+import org.roaringbitmap.RoaringBitmap;
+
+
+public abstract class NullableSingleInputAggregationFunction<I, F extends
Comparable>
+ extends BaseSingleInputAggregationFunction<I, F> {
+ protected final boolean _nullHandlingEnabled;
+
+ public NullableSingleInputAggregationFunction(ExpressionContext expression,
boolean nullHandlingEnabled) {
+ super(expression);
+ _nullHandlingEnabled = nullHandlingEnabled;
+ }
+
+ /**
+ * A consumer that is used to consume a batch of indexes.
+ */
+ @FunctionalInterface
+ public interface BatchConsumer {
+ /**
+ * Consumes a batch of indexes.
+ * @param fromInclusive the start index (inclusive)
+ * @param toExclusive the end index (exclusive)
+ */
+ void consume(int fromInclusive, int toExclusive);
+ }
+
+ /**
+ * A reducer that is used to fold over consecutive indexes.
+ * @param <A> the type of the accumulator
+ */
+ @FunctionalInterface
+ public interface Reducer<A> {
+ /**
+ * Applies the reducer to the range of indexes.
+ * @param acum the current value of the accumulator
+ * @param fromInclusive the start index (inclusive)
+ * @param toExclusive the end index (exclusive)
+ * @return the next value of the accumulator (may be the same as the input)
+ */
+ A apply(A acum, int fromInclusive, int toExclusive);
+ }
+
+ /**
+ * Iterates over the non-null ranges of the blockValSet and calls the
consumer for each range.
+ * @param blockValSet the blockValSet to iterate over
+ * @param consumer the consumer to call for each non-null range
+ */
+ public void forEachNotNull(int length, BlockValSet blockValSet,
BatchConsumer consumer) {
+ if (!_nullHandlingEnabled) {
+ consumer.consume(0, length);
+ return;
+ }
+
+ RoaringBitmap roaringBitmap = blockValSet.getNullBitmap();
+ if (roaringBitmap == null) {
+ consumer.consume(0, length);
+ return;
+ }
+
+ forEachNotNull(length, roaringBitmap.getIntIterator(), consumer);
+ }
+
+ /**
+ * Iterates over the non-null ranges of the nullIndexIterator and calls the
consumer for each range.
+ * @param nullIndexIterator an int iterator that returns values in ascending
order whose min value is 0.
+ * Rows are considered null if and only if their
index is emitted.
+ */
+ public void forEachNotNull(int length, IntIterator nullIndexIterator,
BatchConsumer consumer) {
+ int prev = 0;
+ while (nullIndexIterator.hasNext() && prev < length) {
+ int nextNull = Math.min(nullIndexIterator.next(), length);
+ if (nextNull > prev) {
+ consumer.consume(prev, nextNull);
+ }
+ prev = nextNull + 1;
+ }
+ if (prev < length) {
+ consumer.consume(prev, length);
+ }
+ }
+
+ /**
+ * Folds over the non-null ranges described by the null bitmap using the reducer.
+ * @param initialAcum the initial value of the accumulator
+ * @param <A> The type of the accumulator
+ */
+ public <A> A foldNotNull(int length, @Nullable RoaringBitmap roaringBitmap,
A initialAcum, Reducer<A> reducer) {
+ IntIterator intIterator = roaringBitmap == null ? null :
roaringBitmap.getIntIterator();
+ return foldNotNull(length, intIterator, initialAcum, reducer);
+ }
+
+ /**
+ * Folds over the non-null ranges of the nullIndexIterator using the reducer.
+ * @param nullIndexIterator an int iterator that returns values in ascending
order whose min value is 0.
+ * Rows are considered null if and only if their
index is emitted.
+ * @param initialAcum the initial value of the accumulator
+ * @param <A> The type of the accumulator
+ */
+ public <A> A foldNotNull(int length, @Nullable IntIterator
nullIndexIterator, A initialAcum, Reducer<A> reducer) {
+ A acum = initialAcum;
+ if (!_nullHandlingEnabled || nullIndexIterator == null ||
!nullIndexIterator.hasNext()) {
+ return reducer.apply(initialAcum, 0, length);
+ }
+
+ int prev = 0;
+ while (nullIndexIterator.hasNext() && prev < length) {
+ int nextNull = Math.min(nullIndexIterator.next(), length);
+ if (nextNull > prev) {
+ acum = reducer.apply(acum, prev, nextNull);
+ }
+ prev = nextNull + 1;
+ }
+ if (prev < length) {
+ acum = reducer.apply(acum, prev, length);
+ }
+ return acum;
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunctionTest.java
new file mode 100644
index 0000000000..fb637bedc1
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunctionTest.java
@@ -0,0 +1,273 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pinot.core.query.aggregation.function;
+
+import org.apache.pinot.common.utils.PinotDataType;
+import org.apache.pinot.queries.FluentQueryTest;
+import org.apache.pinot.spi.config.table.FieldConfig;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+
+public class ModeAggregationFunctionTest extends
AbstractAggregationFunctionTest {
+
+ @DataProvider(name = "scenarios")
+ Object[] scenarios() {
+ return new Object[] {
+ new Scenario(FieldSpec.DataType.INT, true),
+
+ new Scenario(FieldSpec.DataType.INT, false),
+ new Scenario(FieldSpec.DataType.LONG, false),
+ new Scenario(FieldSpec.DataType.FLOAT, false),
+ new Scenario(FieldSpec.DataType.DOUBLE, false),
+ };
+ }
+
+ public class Scenario {
+ private final FieldSpec.DataType _dataType;
+ private final boolean _dictionary;
+
+ public Scenario(FieldSpec.DataType dataType, boolean dictionary) {
+ _dataType = dataType;
+ _dictionary = dictionary;
+ }
+
+ public FluentQueryTest.DeclaringTable getDeclaringTable(boolean
nullHandlingEnabled) {
+ FieldConfig.EncodingType encodingType =
+ _dictionary ? FieldConfig.EncodingType.DICTIONARY :
FieldConfig.EncodingType.RAW;
+ return givenSingleNullableFieldTable(_dataType, nullHandlingEnabled,
builder -> {
+ builder.withEncodingType(encodingType);
+
builder.withCompressionCodec(FieldConfig.CompressionCodec.PASS_THROUGH);
+ });
+ }
+
+ @Override
+ public String toString() {
+ return "Scenario{" + "dt=" + _dataType + ", dict=" + _dictionary + '}';
+ }
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrWithoutNullAndEmptySegments(Scenario scenario) {
+ scenario.getDeclaringTable(false)
+ .onFirstInstance("myField",
+ "null",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "null"
+ ).whenQuery("select mode(myField) as mode from testTable")
+ .thenResultIs("DOUBLE", aggrWithoutNullResult(scenario._dataType));
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrWithNullAndEmptySegments(Scenario scenario) {
+ scenario.getDeclaringTable(true)
+ .onFirstInstance("myField",
+ "null",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "null"
+ ).whenQuery("select mode(myField) as mode from testTable")
+ .thenResultIs("DOUBLE", "null");
+ }
+
+ String aggrWithoutNullResult(FieldSpec.DataType dt) {
+ switch (dt) {
+ case INT: return "-2.147483648E9";
+ case LONG: return "-9.223372036854776E18";
+ case FLOAT: return "-Infinity";
+ case DOUBLE: return "-Infinity";
+ default: throw new IllegalArgumentException(dt.toString());
+ }
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrWithoutNull(Scenario scenario) {
+ scenario.getDeclaringTable(false)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "null"
+ )
+ .whenQuery("select mode(myField) as mode from testTable")
+ .thenResultIs("DOUBLE", aggrWithoutNullResult(scenario._dataType));
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrWithNull(Scenario scenario) {
+ scenario.getDeclaringTable(true)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).whenQuery("select mode(myField) as mode from testTable")
+ .thenResultIs("DOUBLE", "1");
+ }
+
+ String aggrSvWithoutNullResult(FieldSpec.DataType dt) {
+ switch (dt) {
+ case INT: return "-2.147483648E9";
+ case LONG: return "-9.223372036854776E18";
+ case FLOAT: return "-Infinity";
+ case DOUBLE: return "-Infinity";
+ default: throw new IllegalArgumentException(dt.toString());
+ }
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrSvWithoutNull(Scenario scenario) {
+ scenario.getDeclaringTable(false)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).whenQuery("select 'cte', mode(myField) as mode from testTable group
by 'cte'")
+ .thenResultIs("STRING | DOUBLE", "cte | " +
aggrSvWithoutNullResult(scenario._dataType));
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrSvWithNull(Scenario scenario) {
+ scenario.getDeclaringTable(true)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).whenQuery("select 'cte', mode(myField) as mode from testTable group
by 'cte'")
+ .thenResultIs("STRING | DOUBLE", "cte | 1");
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrSvSelfWithoutNull(Scenario scenario) {
+ PinotDataType pinotDataType = scenario._dataType == FieldSpec.DataType.INT
+ ? PinotDataType.INTEGER :
PinotDataType.valueOf(scenario._dataType.name());
+
+ Object defaultNullValue;
+ switch (scenario._dataType) {
+ case INT:
+ defaultNullValue = Integer.MIN_VALUE;
+ break;
+ case LONG:
+ defaultNullValue = Long.MIN_VALUE;
+ break;
+ case FLOAT:
+ defaultNullValue = Float.NEGATIVE_INFINITY;
+ break;
+ case DOUBLE:
+ defaultNullValue = Double.NEGATIVE_INFINITY;
+ break;
+ default:
+ throw new IllegalArgumentException("Unexpected scenario data type " +
scenario._dataType);
+ }
+
+ scenario.getDeclaringTable(false)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "2"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "2"
+ ).whenQuery("select myField, mode(myField) as mode from testTable
group by myField order by myField")
+ .thenResultIs(pinotDataType + " | DOUBLE",
+ defaultNullValue + " | " +
aggrSvWithoutNullResult(scenario._dataType),
+ "1 | 1",
+ "2 | 2");
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrSvSelfWithNull(Scenario scenario) {
+ PinotDataType pinotDataType = scenario._dataType == FieldSpec.DataType.INT
+ ? PinotDataType.INTEGER :
PinotDataType.valueOf(scenario._dataType.name());
+
+ scenario.getDeclaringTable(true)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "2"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "2"
+ ).whenQuery("select myField, mode(myField) as mode from testTable
group by myField order by myField")
+ .thenResultIs(pinotDataType + " | DOUBLE", "1 | 1", "2 | 2", "null |
null");
+ }
+
+ String aggrMvWithoutNullResult(FieldSpec.DataType dt) {
+ switch (dt) {
+ case INT: return "-2.147483648E9";
+ case LONG: return "-9.223372036854776E18";
+ case FLOAT: return "-Infinity";
+ case DOUBLE: return "-Infinity";
+ default: throw new IllegalArgumentException(dt.toString());
+ }
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrMvWithoutNull(Scenario scenario) {
+ // TODO: This test is not actually exercising aggregateGroupByMV
+ scenario.getDeclaringTable(false)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).whenQuery("select 'cte1' as cte1, 'cte2' as cte2, mode(myField) as
mode from testTable group by cte1, cte2")
+ .thenResultIs("STRING | STRING | DOUBLE", "cte1 | cte2 | " +
aggrMvWithoutNullResult(scenario._dataType));
+ }
+
+ @Test(dataProvider = "scenarios")
+ void aggrMvWithNull(Scenario scenario) {
+ // TODO: This test is not actually exercising aggregateGroupByMV
+ scenario.getDeclaringTable(true)
+ .onFirstInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).andOnSecondInstance("myField",
+ "null",
+ "1",
+ "null"
+ ).whenQuery("select 'cte1' as cte1, 'cte2' as cte2, mode(myField) as
mode from testTable group by cte1, cte2")
+ .thenResultIs("STRING | STRING | DOUBLE", "cte1 | cte2 | 1");
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
index ed8a5a716d..a2a720edfc 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/AllNullQueriesTest.java
@@ -45,6 +45,7 @@ import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.ReadMode;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
@@ -120,8 +121,8 @@ public class AllNullQueriesTest extends BaseQueriesTest {
_indexSegments = Arrays.asList(immutableSegment, immutableSegment);
}
- @Test
- public void testQueriesWithDictLongColumn()
+ @Test(dataProvider = "queries")
+ public void testQueriesWithDictLongColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.LONG;
TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
@@ -129,11 +130,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithDictLongColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 1)
- public void testQueriesWithNoDictLongColumn()
+ @Test(priority = 1, dataProvider = "queries")
+ public void testQueriesWithNoDictLongColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.LONG;
List<String> noDictionaryColumns = new ArrayList<String>();
@@ -144,11 +145,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithNoDictLongColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 2)
- public void testQueriesWithDictFloatColumn()
+ @Test(priority = 2, dataProvider = "queries")
+ public void testQueriesWithDictFloatColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.FLOAT;
TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
@@ -156,11 +157,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithDictFloatColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 3)
- public void testQueriesWithNoDictFloatColumn()
+ @Test(priority = 3, dataProvider = "queries")
+ public void testQueriesWithNoDictFloatColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.FLOAT;
List<String> noDictionaryColumns = new ArrayList<String>();
@@ -171,11 +172,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithNoDictFloatColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 4)
- public void testQueriesWithDictDoubleColumn()
+ @Test(priority = 4, dataProvider = "queries")
+ public void testQueriesWithDictDoubleColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.DOUBLE;
TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
@@ -183,11 +184,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithDictDoubleColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 5)
- public void testQueriesWithNoDictDoubleColumn()
+ @Test(priority = 5, dataProvider = "queries")
+ public void testQueriesWithNoDictDoubleColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.DOUBLE;
List<String> noDictionaryColumns = new ArrayList<String>();
@@ -198,11 +199,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithNoDictDoubleColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 6)
- public void testQueriesWithDictIntColumn()
+ @Test(priority = 6, dataProvider = "queries")
+ public void testQueriesWithDictIntColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.INT;
TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
@@ -210,11 +211,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithDictIntColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 7)
- public void testQueriesWithNoDictIntColumn()
+ @Test(priority = 7, dataProvider = "queries")
+ public void testQueriesWithNoDictIntColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.INT;
List<String> noDictionaryColumns = new ArrayList<String>();
@@ -225,11 +226,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithNoDictIntColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 8)
- public void testQueriesWithDictBigDecimalColumn()
+ @Test(priority = 8, dataProvider = "queries")
+ public void testQueriesWithDictBigDecimalColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.BIG_DECIMAL;
TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
@@ -237,11 +238,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithDictBigDecimalColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 9)
- public void testQueriesWithNoDictBigDecimalColumn()
+ @Test(priority = 9, dataProvider = "queries")
+ public void testQueriesWithNoDictBigDecimalColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.BIG_DECIMAL;
List<String> noDictionaryColumns = new ArrayList<String>();
@@ -252,11 +253,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithNoDictBigDecimalColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 10)
- public void testQueriesWithDictStringColumn()
+ @Test(priority = 10, dataProvider = "queries")
+ public void testQueriesWithDictStringColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.STRING;
TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
@@ -264,11 +265,11 @@ public class AllNullQueriesTest extends BaseQueriesTest {
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithDictStringColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- @Test(priority = 11)
- public void testQueriesWithNoDictStringColumn()
+ @Test(priority = 11, dataProvider = "queries")
+ public void testQueriesWithNoDictStringColumn(Query query)
throws Exception {
ColumnDataType columnDataType = ColumnDataType.STRING;
List<String> noDictionaryColumns = new ArrayList<String>();
@@ -279,304 +280,384 @@ public class AllNullQueriesTest extends BaseQueriesTest
{
.build();
File indexDir = new File(FileUtils.getTempDirectory(),
"AllNullWithNoDictStringColumnQueriesTest");
setUp(tableConfig, columnDataType.toDataType(), indexDir);
- testQueries(columnDataType, indexDir);
+ testQueries(columnDataType, indexDir, query);
}
- public void testQueries(ColumnDataType columnDataType, File indexDir)
- throws IOException {
- Map<String, String> queryOptions = new HashMap<>();
- queryOptions.put("enableNullHandling", "true");
- DataType dataType = columnDataType.toDataType();
- {
- String query = String.format("SELECT %s FROM testTable WHERE %s is null
limit 5000", COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 4000);
- for (Object[] row : rows) {
- assertNull(row[0]);
- }
- }
- {
- String query = String.format("SELECT %s FROM testTable WHERE %s is not
null limit 5000",
- COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 0);
- }
- if (columnDataType != ColumnDataType.STRING) {
- {
- String query = String.format(
- "SELECT count(*) as count1, count(%s) as count2, min(%s) as min,
max(%s) as max FROM testTable",
- COLUMN_NAME, COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"count1",
"count2", "min", "max"}, new ColumnDataType[]{
- ColumnDataType.LONG, ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE
- }));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- Object[] row = rows.get(0);
- assertEquals(row.length, 4);
- // Note: count(*) returns total number of docs (nullable and
non-nullable).
- assertEquals((long) row[0], 1000 * 4);
- // count(col) returns the count of non-nullable docs.
- assertEquals((long) row[1], 0);
- assertNull(row[2]);
- assertNull(row[3]);
- }
- }
- {
- String query = "SELECT * FROM testTable";
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 10);
- for (int i = 0; i < 10; i++) {
- Object[] row = rows.get(i);
- assertEquals(row.length, 1);
- if (row[0] != null) {
- assertEquals(row[0], i);
- }
- }
- }
- {
- String query = String.format("SELECT * FROM testTable ORDER BY %s DESC
LIMIT 4000", COLUMN_NAME);
- // getBrokerResponseForSqlQuery(query) runs SQL query on multiple index
segments. The result should be equivalent
- // to querying 4 identical index segments.
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 4000);
- for (int i = 0; i < 4000; i += 4) {
- for (int j = 0; j < 4; j++) {
- Object[] values = rows.get(i + j);
- assertEquals(values.length, 1);
- assertNull(values[0]);
- }
- }
- }
- {
- String query = String.format("SELECT DISTINCT %s FROM testTable ORDER BY
%s", COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- for (Object[] row : rows) {
- assertEquals(row.length, 1);
- assertNull(row[0]);
- }
- }
- {
- int limit = 40;
- String query = String.format("SELECT DISTINCT %s FROM testTable ORDER BY
%s LIMIT %d", COLUMN_NAME, COLUMN_NAME,
- limit);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- for (Object[] row : rows) {
- assertEquals(row.length, 1);
- assertNull(row[0]);
- }
- }
- {
- // This test case was added to validate path-code for distinct w/o order
by.
- int limit = 40;
- String query = String.format("SELECT DISTINCT %s FROM testTable LIMIT
%d", COLUMN_NAME, limit);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- }
- if (columnDataType != ColumnDataType.STRING) {
- {
- String query = String.format("SELECT COUNT(%s) AS count, MIN(%s) AS
min, MAX(%s) AS max, AVG(%s) AS avg,"
- + " SUM(%s) AS sum FROM testTable LIMIT 1000", COLUMN_NAME,
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
- COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"count", "min",
"max", "avg", "sum"},
- new ColumnDataType[] {
- ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
- ColumnDataType.DOUBLE
- }));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- assertEquals((long) rows.get(0)[0], 0);
- assertNull(rows.get(0)[1]);
- assertNull(rows.get(0)[2]);
- assertNull(rows.get(0)[3]);
- assertNull(rows.get(0)[4]);
- }
- }
- {
- String query = String.format("SELECT %s FROM testTable GROUP BY %s ORDER
BY %s DESC", COLUMN_NAME, COLUMN_NAME,
- COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- assertEquals(rows.get(0).length, 1);
- assertNull(rows.get(0)[0]);
- }
- {
- String query = String.format(
- "SELECT COUNT(*) AS count, %s FROM testTable GROUP BY %s ORDER BY %s
DESC LIMIT 1000", COLUMN_NAME,
- COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"count",
COLUMN_NAME},
- new ColumnDataType[]{ColumnDataType.LONG, columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- Object[] row = rows.get(0);
- assertEquals(row[0], 4000L);
- assertNull(row[1]);
- }
- {
- String query = String.format("SELECT SUMPRECISION(%s) AS sum FROM
testTable", COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"sum"}, new
ColumnDataType[]{ColumnDataType.STRING}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- assertNull(rows.get(0)[0]);
+ public static abstract class Query {
+ private final String _query;
+
+ public Query(String query) {
+ _query = query;
}
- if (columnDataType != ColumnDataType.STRING) {
- {
- // Note: in Presto, inequality, equality, and IN comparison with nulls
always returns false:
- long lowerLimit = 69;
- String query =
- String.format("SELECT %s FROM testTable WHERE %s > '%s' LIMIT 50",
COLUMN_NAME, COLUMN_NAME, lowerLimit);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- // Pinot loops through the column values from smallest to biggest.
Null comparison always returns false.
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 0);
- }
+
+ @Override
+ public String toString() {
+ return "Query{" + _query + '}';
}
- {
- String query = String.format("SELECT %s FROM testTable WHERE %s = '%s'",
COLUMN_NAME, COLUMN_NAME, 68);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 0);
+
+ public String getQuery() {
+ return _query;
}
- {
- String query = String.format("SELECT %s FROM testTable WHERE %s = '%s'",
COLUMN_NAME, COLUMN_NAME, 69);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema,
- new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 0);
+
+ /**
+ * Tells {@code testQueries} whether this query is executed for the given column type.
+ *
+ * <p>Despite the name, {@code true} means "run the query": {@code testQueries} returns early when
+ * {@code !query.skip(columnDataType)}, and the overrides in {@code queries()} return {@code true} for the
+ * types they support (for example {@code columnDataType != ColumnDataType.STRING}). The default must
+ * therefore be {@code true}; with {@code false}, every query that does not override this method would
+ * silently never be executed or verified.
+ *
+ * @param columnDataType type of the column under test
+ * @return {@code true} if the query must be executed and verified for {@code columnDataType}
+ */
+ public boolean skip(ColumnDataType columnDataType) {
+ return true;
}
- if (columnDataType != ColumnDataType.STRING) {
- {
- String query = String.format("SELECT COUNT(%s) AS count, MIN(%s) AS
min, MAX(%s) AS max, SUM(%s) AS sum" + " "
- + "FROM testTable GROUP BY %s ORDER BY max", COLUMN_NAME,
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
- COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"count", "min",
"max", "sum"}, new ColumnDataType[]{
- ColumnDataType.LONG, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE
- }));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- Object[] row = rows.get(0);
- assertEquals(row.length, 4);
- // Count(column) return 0 if all values are nulls.
- assertEquals(row[0], 0L);
- assertNull(row[1]);
- assertNull(row[2]);
- assertNull(row[3]);
- }
- {
- String query = String.format(
- "SELECT AVG(%s) AS avg FROM testTable GROUP BY %s ORDER BY avg
LIMIT 20", COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"avg"}, new
ColumnDataType[]{ColumnDataType.DOUBLE}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- Object[] row = rows.get(0);
- assertEquals(row.length, 1);
- assertNull(row[0]);
- }
- {
- // MODE cannot handle BIG_DECIMAL yet.
- if (dataType != DataType.BIG_DECIMAL) {
- String query = String.format("SELECT AVG(%s) AS avg, MODE(%s) AS
mode, DISTINCTCOUNT(%s) as distinct_count"
- + " FROM testTable GROUP BY %s ORDER BY %s LIMIT 200",
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
- COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"avg", "mode",
"distinct_count"},
- new ColumnDataType[]{ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.INT}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- Object[] row = rows.get(0);
- assertEquals(row.length, 3);
- assertNull(row[0]);
- // TODO: this should return null instead of default value.
- if (dataType == DataType.DOUBLE || dataType == DataType.FLOAT) {
- assertEquals(row[1], Double.NEGATIVE_INFINITY);
- } else if (dataType == DataType.LONG) {
- assertEquals(((Double) row[1]).longValue(), Long.MIN_VALUE);
+
+ public abstract void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse);
+ }
+
+ @DataProvider(name = "queries")
+ public static Query[] queries() {
+ return new Query[] {
+ new Query(String.format("SELECT %s FROM testTable WHERE %s is null
limit 5000", COLUMN_NAME, COLUMN_NAME)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 4000);
+ for (Object[] row : rows) {
+ assertNull(row[0]);
+ }
+ }
+ },
+ new Query(String.format("SELECT %s FROM testTable WHERE %s is not null
limit 5000", COLUMN_NAME, COLUMN_NAME)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ },
+ new Query(String.format(
+ "SELECT count(*) as count1, count(%s) as count2, min(%s) as min,
max(%s) as max FROM testTable",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME)) {
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING;
+ }
+
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count1",
"count2", "min", "max"},
+ new ColumnDataType[]{
+ ColumnDataType.LONG, ColumnDataType.LONG,
ColumnDataType.DOUBLE, ColumnDataType.DOUBLE
+ }));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 4);
+ // Note: count(*) returns total number of docs (nullable and
non-nullable).
+ assertEquals((long) row[0], 1000 * 4);
+ // count(col) returns the count of non-nullable docs.
+ assertEquals((long) row[1], 0);
+ assertNull(row[2]);
+ assertNull(row[3]);
+ }
+ },
+ new Query("SELECT * FROM testTable") {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME},
new ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 10);
+ for (int i = 0; i < 10; i++) {
+ Object[] row = rows.get(i);
+ assertEquals(row.length, 1);
+ if (row[0] != null) {
+ assertEquals(row[0], i);
+ }
+ }
+ }
+ },
+ new Query(String.format("SELECT * FROM testTable ORDER BY %s DESC
LIMIT 4000", COLUMN_NAME)) {
+ // getBrokerResponseForSqlQuery(query) runs SQL query on multiple
index segments. The result should
+ // be equivalent to querying 4 identical index segments.
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 4000);
+ for (int i = 0; i < 4000; i += 4) {
+ for (int j = 0; j < 4; j++) {
+ Object[] values = rows.get(i + j);
+ assertEquals(values.length, 1);
+ assertNull(values[0]);
+ }
+ }
+ }
+ },
+ new Query(String.format("SELECT DISTINCT %s FROM testTable ORDER BY
%s", COLUMN_NAME, COLUMN_NAME)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ for (Object[] row : rows) {
+ assertEquals(row.length, 1);
+ assertNull(row[0]);
+ }
+ }
+ },
+ new Query(String.format("SELECT DISTINCT %s FROM testTable ORDER BY %s
LIMIT %d",
+ COLUMN_NAME, COLUMN_NAME, 40)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ for (Object[] row : rows) {
+ assertEquals(row.length, 1);
+ assertNull(row[0]);
+ }
+ }
+ },
+ new Query(String.format("SELECT DISTINCT %s FROM testTable LIMIT %d",
COLUMN_NAME, 40)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ }
+ },
+ new Query(
+ String.format("SELECT COUNT(%s) AS count, MIN(%s) AS min, MAX(%s)
AS max, AVG(%s) AS avg,"
+ + " SUM(%s) AS sum FROM testTable LIMIT 1000",
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME)
+ ) {
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING;
+ }
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count",
"min", "max", "avg", "sum"},
+ new ColumnDataType[] {
+ ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
+ ColumnDataType.DOUBLE
+ }));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals((long) rows.get(0)[0], 0);
+ assertNull(rows.get(0)[1]);
+ assertNull(rows.get(0)[2]);
+ assertNull(rows.get(0)[3]);
+ assertNull(rows.get(0)[4]);
+ }
+ },
+ new Query(String.format("SELECT %s FROM testTable GROUP BY %s ORDER BY
%s DESC", COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME},
new ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals(rows.get(0).length, 1);
+ assertNull(rows.get(0)[0]);
+ }
+ },
+ new Query(String.format(
+ "SELECT COUNT(*) AS count, %s FROM testTable GROUP BY %s ORDER BY
%s DESC LIMIT 1000", COLUMN_NAME,
+ COLUMN_NAME, COLUMN_NAME)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count",
COLUMN_NAME},
+ new ColumnDataType[]{ColumnDataType.LONG, columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row[0], 4000L);
+ assertNull(row[1]);
+ }
+ },
+ new Query(String.format("SELECT SUMPRECISION(%s) AS sum FROM
testTable", COLUMN_NAME)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"sum"}, new
ColumnDataType[]{ColumnDataType.STRING}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertNull(rows.get(0)[0]);
+ }
+ },
+ new Query(String.format("SELECT %s FROM testTable WHERE %s > '%s'
LIMIT 50", COLUMN_NAME, COLUMN_NAME, 69)) {
+ // Note: in Presto, inequality, equality, and IN comparison with
nulls always returns false:
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING;
+ }
+
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{COLUMN_NAME},
new ColumnDataType[]{columnDataType}));
+ // Pinot loops through the column values from smallest to biggest.
Null comparison always returns false.
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ },
+ new Query(String.format("SELECT %s FROM testTable WHERE %s = '%s'",
COLUMN_NAME, COLUMN_NAME, 68)) {
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new
ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ },
+ new Query(String.format("SELECT %s FROM testTable WHERE %s = '%s'", COLUMN_NAME, COLUMN_NAME, 69)) {
+ // Equality against a literal can never match here: every value of the column is null.
+ @Override
+ public void verify(ColumnDataType columnDataType, BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ // A plain projection returns the single test column with its own type, not aggregation columns.
+ assertEquals(dataSchema,
+ new DataSchema(new String[]{COLUMN_NAME}, new ColumnDataType[]{columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 0);
+ }
+ },
+ new Query(
+ String.format("SELECT COUNT(%s) AS count, MIN(%s) AS min, MAX(%s)
AS max, SUM(%s) AS sum" + " "
+ + "FROM testTable GROUP BY %s ORDER BY max", COLUMN_NAME,
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME)
+ ) {
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING;
+ }
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"count",
"min", "max", "sum"}, new ColumnDataType[]{
+ ColumnDataType.LONG, ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.DOUBLE
+ }));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 4);
+ // Count(column) return 0 if all values are nulls.
+ assertEquals(row[0], 0L);
+ assertNull(row[1]);
+ assertNull(row[2]);
+ assertNull(row[3]);
+ }
+ },
+ new Query(String.format(
+ "SELECT AVG(%s) AS avg FROM testTable GROUP BY %s ORDER BY avg
LIMIT 20", COLUMN_NAME, COLUMN_NAME)) {
+
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING;
+ }
+
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"avg"}, new
ColumnDataType[]{ColumnDataType.DOUBLE}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 1);
+ assertNull(row[0]);
+ }
+ },
+ new Query(String.format("SELECT AVG(%s) AS avg, MODE(%s) AS mode,
DISTINCTCOUNT(%s) as distinct_count"
+ + " FROM testTable GROUP BY %s ORDER BY %s LIMIT 200",
COLUMN_NAME, COLUMN_NAME, COLUMN_NAME,
+ COLUMN_NAME, COLUMN_NAME)) {
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING && columnDataType
!= ColumnDataType.BIG_DECIMAL;
+ }
+
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"avg",
"mode", "distinct_count"},
+ new ColumnDataType[]{ColumnDataType.DOUBLE,
ColumnDataType.DOUBLE, ColumnDataType.INT}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ Object[] row = rows.get(0);
+ assertEquals(row.length, 3);
+ assertNull(row[0]);
+ assertNull(row[1]);
+ assertEquals(row[2], 0);
+ }
+ },
+ new Query(String.format("SELECT MAX(%s) AS max, %s FROM testTable
GROUP BY %s ORDER BY max LIMIT 501",
+ COLUMN_NAME, COLUMN_NAME, COLUMN_NAME)) {
+
+ @Override
+ public boolean skip(ColumnDataType columnDataType) {
+ return columnDataType != ColumnDataType.STRING;
+ }
+ @Override
+ public void verify(ColumnDataType columnDataType,
BrokerResponseNative brokerResponse) {
+ ResultTable resultTable = brokerResponse.getResultTable();
+ DataSchema dataSchema = resultTable.getDataSchema();
+ assertEquals(dataSchema, new DataSchema(new String[]{"max",
COLUMN_NAME},
+ new ColumnDataType[]{ColumnDataType.DOUBLE, columnDataType}));
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertNull(rows.get(0)[0]);
}
- assertEquals(row[2], 0);
}
- }
- {
- // If updated limit to include all records, I get back results
unsorted.
- String query = String.format("SELECT MAX(%s) AS max, %s FROM testTable
GROUP BY %s ORDER BY max LIMIT 501",
- COLUMN_NAME, COLUMN_NAME, COLUMN_NAME);
- BrokerResponseNative brokerResponse = getBrokerResponse(query,
queryOptions);
- ResultTable resultTable = brokerResponse.getResultTable();
- DataSchema dataSchema = resultTable.getDataSchema();
- assertEquals(dataSchema, new DataSchema(new String[]{"max",
COLUMN_NAME},
- new ColumnDataType[]{ColumnDataType.DOUBLE, columnDataType}));
- List<Object[]> rows = resultTable.getRows();
- assertEquals(rows.size(), 1);
- assertNull(rows.get(0)[0]);
- }
+ };
+ }
+
+ public void testQueries(ColumnDataType columnDataType, File indexDir, Query
query)
+ throws IOException {
+ Map<String, String> queryOptions = new HashMap<>();
+ queryOptions.put("enableNullHandling", "true");
+
+ if (!query.skip(columnDataType)) {
+ return;
}
+ String queryStr = query.getQuery();
+ BrokerResponseNative brokerResponse = getBrokerResponse(queryStr,
queryOptions);
+
+ query.verify(columnDataType, brokerResponse);
+
DataTableBuilderFactory.setDataTableVersion(DataTableBuilderFactory.DEFAULT_VERSION);
_indexSegment.destroy();
FileUtils.deleteDirectory(indexDir);
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
index 5a57945ca7..41a664c5f1 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/NullEnabledQueriesTest.java
@@ -697,8 +697,10 @@ public class NullEnabledQueriesTest extends
BaseQueriesTest {
}
Object[] row = rows.get(index);
assertEquals(row.length, 3);
- assertTrue(Math.abs((Double) row[0] - (baseValue.doubleValue() + i)) <
1e-1);
- assertTrue(Math.abs((Double) row[1] - (baseValue.doubleValue() + i)) <
1e-1);
+
+ double expected = baseValue.doubleValue() + i;
+ assertTrue(Math.abs((Double) row[0] - expected) < 1e-1, "Col 0:
Expected " + expected + " found " + row[0]);
+ assertTrue(Math.abs((Double) row[1] - expected) < 1e-1, "Col 1:
Expected " + expected + " found " + row[1]);
assertEquals(row[2], 1);
i++;
}
diff --git
a/pinot-perf/src/main/java/org/apache/pinot/perf/AbstractAggregationFunctionBenchmark.java
b/pinot-perf/src/main/java/org/apache/pinot/perf/AbstractAggregationFunctionBenchmark.java
new file mode 100644
index 0000000000..c23b4ce89b
--- /dev/null
+++
b/pinot-perf/src/main/java/org/apache/pinot/perf/AbstractAggregationFunctionBenchmark.java
@@ -0,0 +1,217 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.perf;
+
+import java.util.Map;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.testng.Assert;
+
+
+public abstract class AbstractAggregationFunctionBenchmark {
+
+ /**
+ * Returns the aggregation function to benchmark.
+ *
+ * This method will be called in the benchmark method, so it must be fast.
+ */
+ protected abstract AggregationFunction<?, ?> getAggregationFunction();
+
+ /**
+ * Returns the result holder to use for the aggregation function.
+ *
+ * This method will be called in the benchmark method, so it must be fast.
+ *
+ * @return the holder that {@link #test(Blackhole)} resets, hands to the aggregation function and then reads the
+ * aggregated result from
+ */
+ protected abstract AggregationResultHolder getResultHolder();
+
+ /**
+ * Resets the result holder to prepare for the next aggregation.
+ *
+ * This method will be called in the benchmark method, so it must be fast.
+ */
+ protected abstract void resetResultHolder(AggregationResultHolder
resultHolder);
+
+ /**
+ * Returns the expected result of the aggregation function.
+ *
+ * This method will be called in the benchmark method, so it must be fast.
+ */
+ protected abstract Object getExpectedResult();
+
+ /**
+ * Returns the block value set map to use for the aggregation function.
+ *
+ * This method will be called in the benchmark method, so it must be fast.
+ */
+ protected abstract Map<ExpressionContext, BlockValSet> getBlockValSetMap();
+
+ /**
+ * Verifies the final result of the aggregation function.
+ *
+ * This method will be called in the benchmark method, so it must be fast.
+ */
+ protected void verifyResult(Blackhole bh, Comparable finalResult, Object
expectedResult) {
+ Assert.assertEquals(finalResult, expectedResult);
+ bh.consume(finalResult);
+ }
+
+ /**
+ * Base class for benchmarks that are stable on each {@link Level#Trial} or
{@link Level#Iteration}.
+ */
+ public static abstract class Stable extends
AbstractAggregationFunctionBenchmark {
+ protected AggregationFunction<?, ?> _aggregationFunction;
+ protected AggregationResultHolder _resultHolder;
+ protected Object _expectedResult;
+ protected Map<ExpressionContext, BlockValSet> _blockValSetMap;
+
+ /**
+ * Returns the level at which the aggregation function should be created.
+ *
+ * By default, the aggregation function is created at the {@link
Level#Trial} level.
+ */
+ protected Level getAggregationFunctionLevel() {
+ return Level.Trial;
+ }
+
+ /**
+ * Creates the aggregation function to benchmark.
+ *
+ * This method will be called at the level returned by {@link
#getAggregationFunctionLevel()}.
+ * Therefore, time spent here is not counted towards the benchmark time.
+ */
+ protected abstract AggregationFunction<?, ?> createAggregationFunction();
+
+ /**
+ * Returns the level at which the result holder should be created.
+ *
+ * By default, the result holder is created at the {@link Level#Trial}
level.
+ */
+ protected Level getResultHolderLevel() {
+ return Level.Trial;
+ }
+
+ /**
+ * Creates the result holder to use for the aggregation function.
+ *
+ * This method will be called at the level returned by {@link
#getResultHolderLevel()}.
+ * Therefore, time spent here is not counted towards the benchmark time.
+ */
+ protected abstract AggregationResultHolder createResultHolder();
+
+ /**
+ * Returns the level at which the block value set map should be created.
+ *
+ * By default, the block value set map is created at the {@link
Level#Trial} level.
+ */
+ protected Level getBlockValSetMapLevel() {
+ return Level.Trial;
+ }
+
+ /**
+ * Creates the block value set map to use for the aggregation function.
+ *
+ * This method will be called at the level returned by {@link
#getBlockValSetMapLevel()}.
+ * Therefore, time spent here is not counted towards the benchmark time.
+ */
+ protected abstract Map<ExpressionContext, BlockValSet>
createBlockValSetMap();
+
+ /**
+ * Returns the level at which the expected result should be created.
+ *
+ * By default, the expected result is created at the {@link Level#Trial}
level.
+ */
+ protected Level getExpectedResultLevel() {
+ return Level.Trial;
+ }
+
+ /**
+ * Creates the expected result of the aggregation function.
+ *
+ * This method will be called at the level returned by {@link
#getExpectedResultLevel()}.
+ * Therefore, time spent here is not counted towards the benchmark time.
+ */
+ protected abstract Object createExpectedResult(Map<ExpressionContext,
BlockValSet> map);
+
+ @Override
+ protected AggregationFunction<?, ?> getAggregationFunction() {
+ return _aggregationFunction;
+ }
+
+ @Override
+ protected Object getExpectedResult() {
+ return _expectedResult;
+ }
+
+ @Override
+ protected AggregationResultHolder getResultHolder() {
+ return _resultHolder;
+ }
+
+ @Override
+ public Map<ExpressionContext, BlockValSet> getBlockValSetMap() {
+ return _blockValSetMap;
+ }
+
+ @Setup(Level.Trial)
+ public void setupTrial() {
+ onSetupLevel(Level.Trial);
+ }
+
+ @Setup(Level.Iteration)
+ public void setupIteration() {
+ onSetupLevel(Level.Iteration);
+ }
+
+ private void onSetupLevel(Level level) {
+ if (getAggregationFunctionLevel() == level) {
+ _aggregationFunction = createAggregationFunction();
+ }
+ if (getResultHolderLevel() == level) {
+ _resultHolder = createResultHolder();
+ }
+ if (getBlockValSetMapLevel() == level) {
+ _blockValSetMap = createBlockValSetMap();
+ }
+ if (getExpectedResultLevel() == level) {
+ _expectedResult = createExpectedResult(_blockValSetMap);
+ }
+ }
+ }
+
+ @Benchmark
+ public void test(Blackhole bh) {
+ AggregationResultHolder resultHolder = getResultHolder();
+ resetResultHolder(resultHolder);
+ Map<ExpressionContext, BlockValSet> blockValSetMap = getBlockValSetMap();
+
+ getAggregationFunction().aggregate(DocIdSetPlanNode.MAX_DOC_PER_CALL,
resultHolder, blockValSetMap);
+
+ Comparable finalResult =
getAggregationFunction().extractFinalResult(resultHolder.getResult());
+
+ verifyResult(bh, finalResult, getExpectedResult());
+ }
+}
diff --git
a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkModeAggregation.java
b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkModeAggregation.java
new file mode 100644
index 0000000000..b1051a1476
--- /dev/null
+++
b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkModeAggregation.java
@@ -0,0 +1,175 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.perf;
+
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.function.LongSupplier;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import
org.apache.pinot.core.query.aggregation.function.ModeAggregationFunction;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import org.roaringbitmap.RoaringBitmap;
+
+@Fork(1)
+@BenchmarkMode(Mode.Throughput)
+@Warmup(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS)
+@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+public class BenchmarkModeAggregation extends
AbstractAggregationFunctionBenchmark.Stable {
+ private static final ExpressionContext EXPR =
ExpressionContext.forIdentifier("col");
+ @Param({"100", "50", "0"})
+ public int _nullHandlingEnabledPerCent;
+ private boolean _nullHandlingEnabled;
+ @Param({"2", "4", "8", "16", "32", "64", "128"})
+ protected int _nullPeriod;
+ private final Random _segmentNullRandomGenerator = new Random(42);
+ private double _modeIgnoringNull;
+ private double _modeNullAware;
+ private final int _numDocs = DocIdSetPlanNode.MAX_DOC_PER_CALL;
+
+ public static void main(String[] args)
+ throws RunnerException {
+ Options opt = new
OptionsBuilder().include(BenchmarkModeAggregation.class.getSimpleName())
+ // .addProfiler(LinuxPerfAsmProfiler.class)
+ .build();
+
+ new Runner(opt).run();
+ }
+
+ @Override
+ protected Map<ExpressionContext, BlockValSet> createBlockValSetMap() {
+ Random valueRandom = new Random(42);
+ int tries = 3;
+ LongSupplier longSupplier = () -> {
+ // just a simple distribution that will have some kind of normal but
limited distribution
+ int sum = 0;
+ for (int i = 0; i < tries; i++) {
+ sum += valueRandom.nextInt(_numDocs);
+ }
+ return sum / tries;
+ };
+ RoaringBitmap nullBitmap =
SyntheticNullBitmapFactories.Periodic.randomInPeriod(_numDocs, _nullPeriod);
+
+ BlockValSet block = SyntheticBlockValSets.Long.create(_numDocs,
nullBitmap, longSupplier);
+ return Map.of(EXPR, block);
+ }
+
+ @Override
+ public void setupTrial() {
+ super.setupTrial();
+
+ HashMap<Long, Integer> ignoringNullDistribution = new HashMap<>();
+ HashMap<Long, Integer> nullAwareDistribution = new HashMap<>();
+
+ BlockValSet blockValSet = getBlockValSetMap().get(EXPR);
+ long[] longValuesSV = blockValSet.getLongValuesSV();
+ RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
+ if (nullBitmap != null) {
+ for (int i = 0; i < _numDocs; i++) {
+ long value = longValuesSV[i];
+ ignoringNullDistribution.merge(value, 1, Integer::sum);
+ if (!nullBitmap.contains(i)) {
+ nullAwareDistribution.merge(value, 1, Integer::sum);
+ }
+ }
+ } else {
+ for (int i = 0; i < _numDocs; i++) {
+ long value = longValuesSV[i];
+ ignoringNullDistribution.merge(value, 1, Integer::sum);
+ nullAwareDistribution.merge(value, 1, Integer::sum);
+ }
+ }
+ _modeIgnoringNull = ignoringNullDistribution.entrySet().stream()
+ .max(Comparator.comparingInt(Map.Entry::getValue))
+ .get()
+ .getKey()
+ .doubleValue();
+ _modeNullAware = nullAwareDistribution.entrySet().stream()
+ .max(Comparator.comparingInt(Map.Entry::getValue))
+ .get()
+ .getKey()
+ .doubleValue();
+ }
+
+ @Override
+ public void setupIteration() {
+ _nullHandlingEnabled = _segmentNullRandomGenerator.nextInt(100) <
_nullHandlingEnabledPerCent;
+ super.setupIteration();
+ }
+
+ @Override
+ protected Level getAggregationFunctionLevel() {
+ return Level.Iteration;
+ }
+
+ @Override
+ protected AggregationFunction<?, ?> createAggregationFunction() {
+ return new ModeAggregationFunction(Collections.singletonList(EXPR),
_nullHandlingEnabled);
+ }
+
+ @Override
+ protected Level getResultHolderLevel() {
+ return Level.Iteration;
+ }
+
+ @Override
+ protected AggregationResultHolder createResultHolder() {
+ return getAggregationFunction().createAggregationResultHolder();
+ }
+
+ @Override
+ protected void resetResultHolder(AggregationResultHolder resultHolder) {
+ Map<? extends Number, Long> result = resultHolder.getResult();
+ if (result != null) {
+ result.clear();
+ }
+ }
+
+ @Override
+ protected Level getExpectedResultLevel() {
+ return Level.Iteration;
+ }
+
+ @Override
+ protected Object createExpectedResult(Map<ExpressionContext, BlockValSet>
map) {
+ return _nullHandlingEnabled ? _modeNullAware : _modeIgnoringNull;
+ }
+}
diff --git
a/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java
b/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java
new file mode 100644
index 0000000000..0214a82c2f
--- /dev/null
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java
@@ -0,0 +1,256 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.perf;
+
+import com.google.common.base.Preconditions;
+import java.math.BigDecimal;
+import java.util.function.DoubleSupplier;
+import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.roaringbitmap.RoaringBitmap;
+
+/**
+ * Synthetic {@link BlockValSet} for testing and benchmarking.
+ */
+public class SyntheticBlockValSets {
+ private SyntheticBlockValSets() {
+ }
+
+ /**
+ * Base class for synthetic {@link BlockValSet}.
+ *
+ * Most of its methods throw {@link UnsupportedOperationException} and
should be overridden by subclasses if they
+ * need to be used.
+ */
+ public static abstract class Base implements BlockValSet {
+ @Nullable
+ @Override
+ public RoaringBitmap getNullBitmap() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Nullable
+ @Override
+ public Dictionary getDictionary() {
+ return null;
+ }
+
+ @Override
+ public int[] getDictionaryIdsSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[] getIntValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long[] getLongValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[] getFloatValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double[] getDoubleValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BigDecimal[] getBigDecimalValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String[] getStringValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte[][] getBytesValuesSV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[][] getDictionaryIdsMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[][] getIntValuesMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long[][] getLongValuesMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[][] getFloatValuesMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double[][] getDoubleValuesMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String[][] getStringValuesMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte[][][] getBytesValuesMV() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[] getNumMVEntries() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ /**
+ * A simple {@link BlockValSet} for nullable, not dictionary-encoded long
values.
+ */
+ public static class Long extends Base {
+
+ @Nullable
+ final RoaringBitmap _nullBitmap;
+ final long[] _values;
+
+ private Long(@Nullable RoaringBitmap nullBitmap, long[] values) {
+ _nullBitmap = nullBitmap;
+ _values = values;
+ }
+
+ public static Long create() {
+ return create(Distribution.createLongSupplier(42, "EXP(0.5)"));
+ }
+
+ public static Long create(LongSupplier supplier) {
+ return create(DocIdSetPlanNode.MAX_DOC_PER_CALL, null, supplier);
+ }
+
+ public static Long create(@Nullable RoaringBitmap nullBitmap, LongSupplier
supplier) {
+ return create(DocIdSetPlanNode.MAX_DOC_PER_CALL, nullBitmap, supplier);
+ }
+
+ /**
+ * Creates a block with {@code numDocs} values read sequentially from the given supplier.
+ *
+ * @param numDocs Number of values to generate.
+ * @param nullBitmap Bitmap returned by {@link #getNullBitmap()}, or {@code null} if the block has none. An empty
+ * bitmap is accepted.
+ * @param supplier Source of the values. It is called exactly {@code numDocs} times, also for docs marked as null.
+ * @throws IllegalArgumentException if the bitmap contains a doc id that is not smaller than {@code numDocs}
+ */
+ public static Long create(int numDocs, @Nullable RoaringBitmap nullBitmap, LongSupplier supplier) {
+ // RoaringBitmap.last() throws NoSuchElementException on an empty bitmap, so emptiness is checked first.
+ Preconditions.checkArgument(nullBitmap == null || nullBitmap.isEmpty() || nullBitmap.last() < numDocs,
+ "null bitmap larger than numDocs");
+ long[] values = new long[numDocs];
+ for (int i = 0; i < numDocs; i++) {
+ values[i] = supplier.getAsLong();
+ }
+
+ return new Long(nullBitmap, values);
+ }
+
+ @Nullable
+ @Override
+ public RoaringBitmap getNullBitmap() {
+ return _nullBitmap;
+ }
+
+ @Override
+ public FieldSpec.DataType getValueType() {
+ return FieldSpec.DataType.LONG;
+ }
+
+ @Override
+ public boolean isSingleValue() {
+ return true;
+ }
+
+ @Override
+ public long[] getLongValuesSV() {
+ return _values;
+ }
+ }
+
+ /**
+ * A simple {@link BlockValSet} for nullable, not dictionary-encoded double
values.
+ */
+ public static class Double extends Base {
+
+ @Nullable
+ final RoaringBitmap _nullBitmap;
+ final double[] _values;
+
+ private Double(@Nullable RoaringBitmap nullBitmap, double[] values) {
+ _nullBitmap = nullBitmap;
+ _values = values;
+ }
+
+ public static Double create() {
+ return create(Distribution.createDoubleSupplier(42, "EXP(0.5)"));
+ }
+
+ public static Double create(DoubleSupplier supplier) {
+ return create(DocIdSetPlanNode.MAX_DOC_PER_CALL, null, supplier);
+ }
+
+ public static Double create(@Nullable RoaringBitmap nullBitmap,
DoubleSupplier supplier) {
+ return create(DocIdSetPlanNode.MAX_DOC_PER_CALL, nullBitmap, supplier);
+ }
+
+ /**
+ * Creates a block with {@code numDocs} values read sequentially from the given supplier.
+ *
+ * @param numDocs Number of values to generate.
+ * @param nullBitmap Bitmap returned by {@link #getNullBitmap()}, or {@code null} if the block has none. An empty
+ * bitmap is accepted.
+ * @param supplier Source of the values. It is called exactly {@code numDocs} times, also for docs marked as null.
+ * @throws IllegalArgumentException if the bitmap contains a doc id that is not smaller than {@code numDocs}
+ */
+ public static Double create(int numDocs, @Nullable RoaringBitmap nullBitmap, DoubleSupplier supplier) {
+ // RoaringBitmap.last() throws NoSuchElementException on an empty bitmap, so emptiness is checked first.
+ Preconditions.checkArgument(nullBitmap == null || nullBitmap.isEmpty() || nullBitmap.last() < numDocs,
+ "null bitmap larger than numDocs");
+ double[] values = new double[numDocs];
+ for (int i = 0; i < numDocs; i++) {
+ values[i] = supplier.getAsDouble();
+ }
+
+ return new Double(nullBitmap, values);
+ }
+
+ @Nullable
+ @Override
+ public RoaringBitmap getNullBitmap() {
+ return _nullBitmap;
+ }
+
+ @Override
+ public FieldSpec.DataType getValueType() {
+ // This block only exposes its values through getDoubleValuesSV(), so it must advertise DOUBLE rather than LONG.
+ return FieldSpec.DataType.DOUBLE;
+ }
+
+ @Override
+ public boolean isSingleValue() {
+ return true;
+ }
+
+ @Override
+ public double[] getDoubleValuesSV() {
+ return _values;
+ }
+ }
+}
diff --git
a/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticNullBitmapFactories.java
b/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticNullBitmapFactories.java
new file mode 100644
index 0000000000..6f67c2a962
--- /dev/null
+++
b/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticNullBitmapFactories.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.perf;
+
+import java.util.Random;
+import java.util.function.IntSupplier;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * Synthetic null bitmap suppliers for testing and benchmarking.
+ */
+public class SyntheticNullBitmapFactories {
+ private SyntheticNullBitmapFactories() {
+ }
+
+ /**
+ * Null bitmap factories that generate null bitmaps with periodic patterns.
+ */
+ public static class Periodic {
+ private Periodic() {
+ }
+
+ /**
+ * Returns a null bitmap with the first doc in each period set as null.
+ */
+ public static RoaringBitmap firstInPeriod(int numDocs, int period) {
+ return periodic(numDocs, period, () -> 0);
+ }
+
+ /**
+ * Returns a null bitmap with the last doc in each period set as null.
+ */
+ public static RoaringBitmap lastInPeriod(int numDocs, int period) {
+ return periodic(numDocs, period, () -> period - 1);
+ }
+
+ /**
+ * Returns a null bitmap with a random doc in each period set as null.
+ */
+ public static RoaringBitmap randomInPeriod(int numDocs, int period) {
+ Random random = new Random(42);
+ return randomInPeriod(numDocs, period, random);
+ }
+
+ /**
+ * Returns a null bitmap with a random doc in each period set as null.
+ */
+ public static RoaringBitmap randomInPeriod(int numDocs, int period, Random
random) {
+ return periodic(numDocs, period, () -> random.nextInt(period));
+ }
+
+ /**
+ * Returns a null bitmap with a doc in each period set as null, with the doc position in the period determined by
+ * the given supplier.
+ *
+ * Periods are laid out back to back starting at doc 0 and cover the whole {@code [0, numDocs)} range, so
+ * {@code numDocs} is not limited to {@link DocIdSetPlanNode#MAX_DOC_PER_CALL}. Positions that land at or beyond
+ * {@code numDocs} (possible in the last, partial period) are skipped.
+ *
+ * @param numDocs Number of docs the bitmap has to cover.
+ * @param period Length of each period. Must be positive.
+ * @param inIntervalSupplier Supplier for the position of the doc in the period.
+ * The supplier should return a value in the range {@code [0, period)}.
+ * @throws IllegalArgumentException if {@code period} is not positive
+ */
+ public static RoaringBitmap periodic(int numDocs, int period, IntSupplier inIntervalSupplier) {
+ if (period <= 0) {
+ // A non-positive period would keep the loop below from making forward progress.
+ throw new IllegalArgumentException("period must be positive, got " + period);
+ }
+ RoaringBitmap nullBitmap = new RoaringBitmap();
+ // Bound the loop by numDocs instead of a fixed block size, so every period of the requested range is visited.
+ for (int i = 0; i < numDocs; i += period) {
+ int pos = i + inIntervalSupplier.getAsInt();
+ if (pos < numDocs) {
+ nullBitmap.add(pos);
+ }
+ }
+ nullBitmap.runOptimize();
+ return nullBitmap;
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]