This is an automated email from the ASF dual-hosted git repository.
tingchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 17ee2024dd FUNNEL_COUNT Aggregation Function (#10867)
17ee2024dd is described below
commit 17ee2024dd7d49ce2f5ebf00ef9d71b6ca589a00
Author: dario-liberman <[email protected]>
AuthorDate: Wed Jun 21 19:26:57 2023 +0200
FUNNEL_COUNT Aggregation Function (#10867)
* New Funnel Aggregation Function
* Funnel analytics support - FUNNEL_COUNT aggregation function
* Delete FunnelAggregationFunction.java
* Simplify Tests
* Simplify Tests
* Simplify Tests
* within -> across
Fix javadoc
* Update FunnelCountAggregationFunction.java
Address comments
---------
Co-authored-by: Dario Liberman <[email protected]>
---
.../org/apache/pinot/common/utils/DataSchema.java | 3 +
.../blocks/results/AggregationResultsBlock.java | 4 +
.../function/AggregationFunctionFactory.java | 3 +
.../function/FunnelCountAggregationFunction.java | 511 +++++++++++++++++++++
.../pinot/queries/BaseFunnelCountQueriesTest.java | 252 ++++++++++
.../queries/FunnelCountQueriesNonSortedTest.java | 57 +++
.../queries/FunnelCountQueriesSortedTest.java | 65 +++
.../org/apache/pinot/queries/QueriesTestUtils.java | 3 +-
.../pinot/segment/spi/AggregationFunctionType.java | 5 +-
9 files changed, 901 insertions(+), 2 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
index d8fedfc3d7..4020aff70a 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import com.google.common.collect.Ordering;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
@@ -469,6 +470,8 @@ public class DataSchema {
private static long[] toLongArray(Object value) {
if (value instanceof long[]) {
return (long[]) value;
+ } else if (value instanceof LongArrayList) {
+ return ((LongArrayList) value).elements();
} else {
int[] intValues = (int[]) value;
int length = intValues.length;
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
index b10ebdd3b9..2d816bffca 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
@@ -19,6 +19,7 @@
package org.apache.pinot.core.operator.blocks.results;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.Collections;
@@ -182,6 +183,9 @@ public class AggregationResultsBlock extends
BaseResultsBlock {
case DOUBLE_ARRAY:
dataTableBuilder.setColumn(index, ((DoubleArrayList)
result).elements());
break;
+ case LONG_ARRAY:
+ dataTableBuilder.setColumn(index, ((LongArrayList) result).elements());
+ break;
default:
throw new IllegalStateException("Illegal column data type in final
result: " + columnDataType);
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 06fbb1db66..7f96072c9e 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -358,6 +358,9 @@ public class AggregationFunctionFactory {
case ARGMIN:
throw new IllegalArgumentException(
"Aggregation function: " + function + " is only supported in
selection without alias.");
+ case FUNNELCOUNT:
+ return new FunnelCountAggregationFunction(arguments);
+
default:
throw new IllegalArgumentException();
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java
new file mode 100644
index 0000000000..4eecad1002
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FunnelCountAggregationFunction.java
@@ -0,0 +1,511 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The {@code FunnelCountAggregationFunction} calculates the number of step
conversions for a given partition column and
+ * a list of boolean expressions.
+ * <p>IMPORTANT: This function relies on the partition column being
partitioned for each segment, where there are no
+ * common values across different segments.
+ * <p>This function calculates the exact number of step matches per partition
key within the segment, then sums up the
+ * results from different segments.
+ *
+ * Example:
+ * SELECT
+ * dateTrunc('day', timestamp) AS ts,
+ * FUNNEL_COUNT(
+ * STEPS(url = '/addToCart', url = '/checkout', url =
'/orderConfirmation'),
+ * CORRELATED_BY(user)
+ * ) as step_counts
+ * FROM user_log
+ * WHERE url in ('/addToCart', '/checkout', '/orderConfirmation')
+ * GROUP BY 1
+ */
+public class FunnelCountAggregationFunction implements
AggregationFunction<List<Long>, LongArrayList> {
+ final List<ExpressionContext> _expressions;
+ final List<ExpressionContext> _stepExpressions;
+ final List<ExpressionContext> _correlateByExpressions;
+ final ExpressionContext _primaryCorrelationCol;
+ final int _numSteps;
+
+ final SegmentAggregationStrategy<?, List<Long>> _sortedAggregationStrategy;
+ final SegmentAggregationStrategy<?, List<Long>> _bitmapAggregationStrategy;
+
+ public FunnelCountAggregationFunction(List<ExpressionContext> expressions) {
+ _expressions = expressions;
+ _correlateByExpressions =
Option.CORRELATE_BY.getInputExpressions(expressions);
+ _primaryCorrelationCol =
Option.CORRELATE_BY.getFirstInputExpression(expressions);
+ _stepExpressions = Option.STEPS.getInputExpressions(expressions);
+ _numSteps = _stepExpressions.size();
+ _sortedAggregationStrategy = new SortedAggregationStrategy();
+ _bitmapAggregationStrategy = new BitmapAggregationStrategy();
+ }
+
+ @Override
+ public String getResultColumnName() {
+ return getType().getName().toLowerCase() + "(" +
_expressions.stream().map(ExpressionContext::toString)
+ .collect(Collectors.joining(",")) + ")";
+ }
+
+ @Override
+ public List<ExpressionContext> getInputExpressions() {
+ final List<ExpressionContext> inputs = new ArrayList<>();
+ inputs.addAll(_correlateByExpressions);
+ inputs.addAll(_stepExpressions);
+ return inputs;
+ }
+
+ @Override
+ public AggregationFunctionType getType() {
+ return AggregationFunctionType.FUNNELCOUNT;
+ }
+
+ @Override
+ public AggregationResultHolder createAggregationResultHolder() {
+ return new ObjectAggregationResultHolder();
+ }
+
+ @Override
+ public GroupByResultHolder createGroupByResultHolder(int initialCapacity,
int maxCapacity) {
+ return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+ }
+
+ @Override
+ public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregate(length,
aggregationResultHolder, blockValSetMap);
+ }
+
+ @Override
+ public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+
getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregateGroupBySV(length,
groupKeyArray,
+ groupByResultHolder, blockValSetMap);
+ }
+
+ @Override
+ public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+
getAggregationStrategyByBlockValSetMap(blockValSetMap).aggregateGroupByMV(length,
groupKeysArray,
+ groupByResultHolder, blockValSetMap);
+ }
+
+ @Override
+ public List<Long> extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ return
getAggregationStrategyByAggregationResult(aggregationResultHolder.getResult()).extractAggregationResult(
+ aggregationResultHolder);
+ }
+
+ @Override
+ public List<Long> extractGroupByResult(GroupByResultHolder
groupByResultHolder, int groupKey) {
+ return
getAggregationStrategyByAggregationResult(groupByResultHolder.getResult(groupKey)).extractGroupByResult(
+ groupByResultHolder, groupKey);
+ }
+
+ @Override
+ public List<Long> merge(List<Long> a, List<Long> b) {
+ int length = a.size();
+ Preconditions.checkState(length == b.size(), "The two operand arrays are
not of the same size! provided %s, %s",
+ length, b.size());
+
+ LongArrayList result = toLongArrayList(a);
+ long[] elements = result.elements();
+ for (int i = 0; i < length; i++) {
+ elements[i] += b.get(i);
+ }
+ return result;
+ }
+
+ @Override
+ public ColumnDataType getIntermediateResultColumnType() {
+ return ColumnDataType.OBJECT;
+ }
+
+ @Override
+ public ColumnDataType getFinalResultColumnType() {
+ return ColumnDataType.LONG_ARRAY;
+ }
+
+ @Override
+ public LongArrayList extractFinalResult(List<Long> result) {
+ return toLongArrayList(result);
+ }
+
+ @Override
+ public String toExplainString() {
+ StringBuilder stringBuilder = new
StringBuilder(getType().getName()).append('(');
+ int numArguments = getInputExpressions().size();
+ if (numArguments > 0) {
+ stringBuilder.append(getInputExpressions().get(0).toString());
+ for (int i = 1; i < numArguments; i++) {
+ stringBuilder.append(",
").append(getInputExpressions().get(i).toString());
+ }
+ }
+ return stringBuilder.append(')').toString();
+ }
+
+ private static LongArrayList toLongArrayList(List<Long> longList) {
+ return longList instanceof LongArrayList ? ((LongArrayList)
longList).clone() : new LongArrayList(longList);
+ }
+
+ private int[] getCorrelationIds(Map<ExpressionContext, BlockValSet>
blockValSetMap) {
+ return blockValSetMap.get(_primaryCorrelationCol).getDictionaryIdsSV();
+ }
+
+ private int[][] getSteps(Map<ExpressionContext, BlockValSet> blockValSetMap)
{
+ final int[][] steps = new int[_numSteps][];
+ for (int n = 0; n < _numSteps; n++) {
+ final BlockValSet stepBlockValSet =
blockValSetMap.get(_stepExpressions.get(n));
+ steps[n] = stepBlockValSet.getIntValuesSV();
+ }
+ return steps;
+ }
+
+ private boolean isSorted(Map<ExpressionContext, BlockValSet> blockValSetMap)
{
+ final Dictionary primaryCorrelationDictionary =
blockValSetMap.get(_primaryCorrelationCol).getDictionary();
+ if (primaryCorrelationDictionary == null) {
+ throw new IllegalArgumentException(
+ "CORRELATE_BY column in FUNNELCOUNT aggregation function not
supported, please use a dictionary encoded "
+ + "column.");
+ }
+ return primaryCorrelationDictionary.isSorted();
+ }
+
+ private SegmentAggregationStrategy<?, List<Long>>
getAggregationStrategyByBlockValSetMap(
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ return isSorted(blockValSetMap) ? _sortedAggregationStrategy :
_bitmapAggregationStrategy;
+ }
+
+ private SegmentAggregationStrategy<?, List<Long>>
getAggregationStrategyByAggregationResult(Object aggResult) {
+ return aggResult instanceof SortedAggregationResult ?
_sortedAggregationStrategy : _bitmapAggregationStrategy;
+ }
+
+ enum Option {
+ STEPS("steps"),
+ CORRELATE_BY("correlateby");
+
+ final String _name;
+
+ Option(String name) {
+ _name = name;
+ }
+
+ boolean matches(ExpressionContext expression) {
+ if (expression.getType() != ExpressionContext.Type.FUNCTION) {
+ return false;
+ }
+ return _name.equals(expression.getFunction().getFunctionName());
+ }
+
+ Optional<ExpressionContext> find(List<ExpressionContext> expressions) {
+ return expressions.stream().filter(this::matches).findFirst();
+ }
+
+ public List<ExpressionContext> getInputExpressions(List<ExpressionContext>
expressions) {
+ return this.find(expressions).map(exp ->
exp.getFunction().getArguments())
+ .orElseThrow(() -> new IllegalStateException("FUNNELCOUNT requires "
+ _name));
+ }
+
+ public ExpressionContext getFirstInputExpression(List<ExpressionContext>
expressions) {
+ return this.find(expressions)
+ .flatMap(exp ->
exp.getFunction().getArguments().stream().findFirst())
+ .orElseThrow(() -> new IllegalStateException("FUNNELCOUNT: " + _name
+ " requires an argument."));
+ }
+ }
+
+ /**
+ * Interface for segment aggregation strategy.
+ *
+ * <p>The implementation should be stateless, and can be shared among
multiple segments in multiple threads. The
+ * result for each segment should be stored and passed in via the result
holder.
+ * There should be no assumptions beyond segment boundaries, different
aggregation strategies may be utilized
+ * across different segments for a given query.
+ *
+ * @param <A> Aggregation result accumulated across blocks within segment,
kept by result holder.
+ * @param <I> Intermediate result at segment level (extracted from
aforementioned aggregation result).
+ */
+ @ThreadSafe
+ static abstract class SegmentAggregationStrategy<A, I> {
+
+ /**
+ * Returns an aggregation result for this aggregation strategy to be kept
in a result holder (aggregation only).
+ */
+ abstract A createAggregationResult();
+
+ public A getAggregationResultGroupBy(GroupByResultHolder
groupByResultHolder, int groupKey) {
+ A aggResult = groupByResultHolder.getResult(groupKey);
+ if (aggResult == null) {
+ aggResult = createAggregationResult();
+ groupByResultHolder.setValueForKey(groupKey, aggResult);
+ }
+ return aggResult;
+ }
+
+ public A getAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ A aggResult = aggregationResultHolder.getResult();
+ if (aggResult == null) {
+ aggResult = createAggregationResult();
+ aggregationResultHolder.setValue(aggResult);
+ }
+ return aggResult;
+ }
+
+ /**
+ * Performs aggregation on the given block value sets (aggregation only).
+ */
+ abstract void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+ /**
+ * Performs aggregation on the given group key array and block value sets
(aggregation group-by on single-value
+ * columns).
+ */
+ abstract void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+ /**
+ * Performs aggregation on the given group keys array and block value sets
(aggregation group-by on multi-value
+ * columns).
+ */
+ abstract void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap);
+
+ /**
+ * Extracts the intermediate result from the aggregation result holder
(aggregation only).
+ */
+ public I extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
+ return extractIntermediateResult(aggregationResultHolder.getResult());
+ }
+
+ /**
+ * Extracts the intermediate result from the group-by result holder for
the given group key (aggregation group-by).
+ */
+ public I extractGroupByResult(GroupByResultHolder groupByResultHolder, int
groupKey) {
+ return
extractIntermediateResult(groupByResultHolder.getResult(groupKey));
+ }
+
+ abstract I extractIntermediateResult(A aggregationResult);
+ }
+
+ /**
+ * Aggregation strategy leveraging roaring bitmap algebra
(unions/intersections).
+ */
+ class BitmapAggregationStrategy extends
SegmentAggregationStrategy<RoaringBitmap[], List<Long>> {
+
+ @Override
+ public RoaringBitmap[] createAggregationResult() {
+ final RoaringBitmap[] stepsBitmaps = new RoaringBitmap[_numSteps];
+ for (int n = 0; n < _numSteps; n++) {
+ stepsBitmaps[n] = new RoaringBitmap();
+ }
+ return stepsBitmaps;
+ }
+
+ @Override
+ public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ final int[] correlationIds = getCorrelationIds(blockValSetMap);
+ final int[][] steps = getSteps(blockValSetMap);
+
+ final RoaringBitmap[] stepsBitmaps =
getAggregationResult(aggregationResultHolder);
+
+ for (int n = 0; n < _numSteps; n++) {
+ for (int i = 0; i < length; i++) {
+ if (steps[n][i] > 0) {
+ stepsBitmaps[n].add(correlationIds[i]);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ final int[] correlationIds = getCorrelationIds(blockValSetMap);
+ final int[][] steps = getSteps(blockValSetMap);
+
+ for (int n = 0; n < _numSteps; n++) {
+ for (int i = 0; i < length; i++) {
+ final int groupKey = groupKeyArray[i];
+ if (steps[n][i] > 0) {
+ getAggregationResultGroupBy(groupByResultHolder,
groupKey)[n].add(correlationIds[i]);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ final int[] correlationIds = getCorrelationIds(blockValSetMap);
+ final int[][] steps = getSteps(blockValSetMap);
+
+ for (int n = 0; n < _numSteps; n++) {
+ for (int i = 0; i < length; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ if (steps[n][i] > 0) {
+ getAggregationResultGroupBy(groupByResultHolder,
groupKey)[n].add(correlationIds[i]);
+ }
+ }
+ }
+ }
+ }
+
+ @Override
+ public List<Long> extractIntermediateResult(RoaringBitmap[] stepsBitmaps) {
+ if (stepsBitmaps == null) {
+ return new LongArrayList(_numSteps);
+ }
+
+ long[] result = new long[_numSteps];
+ result[0] = stepsBitmaps[0].getCardinality();
+ for (int i = 1; i < _numSteps; i++) {
+ // intersect this step with previous step
+ stepsBitmaps[i].and(stepsBitmaps[i - 1]);
+ result[i] = stepsBitmaps[i].getCardinality();
+ }
+ return LongArrayList.wrap(result);
+ }
+ }
+
+ /**
+ * Aggregation strategy for segments sorted by the main correlation column.
+ */
+ class SortedAggregationStrategy extends
SegmentAggregationStrategy<SortedAggregationResult, List<Long>> {
+
+ @Override
+ public SortedAggregationResult createAggregationResult() {
+ return new SortedAggregationResult();
+ }
+
+ @Override
+ public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ final int[] correlationIds = getCorrelationIds(blockValSetMap);
+ final int[][] steps = getSteps(blockValSetMap);
+
+ final SortedAggregationResult agg =
getAggregationResult(aggregationResultHolder);
+
+ for (int i = 0; i < length; i++) {
+ agg.sortedCount(steps, i, correlationIds[i]);
+ }
+ }
+
+ @Override
+ public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ final int[] correlationIds = getCorrelationIds(blockValSetMap);
+ final int[][] steps = getSteps(blockValSetMap);
+
+ for (int i = 0; i < length; i++) {
+ final int groupKey = groupKeyArray[i];
+ final SortedAggregationResult agg =
getAggregationResultGroupBy(groupByResultHolder, groupKey);
+
+ agg.sortedCount(steps, i, correlationIds[i]);
+ }
+ }
+
+ @Override
+ public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ final int[] correlationIds = getCorrelationIds(blockValSetMap);
+ final int[][] steps = getSteps(blockValSetMap);
+
+ for (int i = 0; i < length; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ final SortedAggregationResult agg =
getAggregationResultGroupBy(groupByResultHolder, groupKey);
+
+ agg.sortedCount(steps, i, correlationIds[i]);
+ }
+ }
+ }
+
+ @Override
+ public List<Long> extractIntermediateResult(SortedAggregationResult agg) {
+ if (agg == null) {
+ return new LongArrayList(_numSteps);
+ }
+
+ return LongArrayList.wrap(agg.extractResult());
+ }
+ }
+
+ /**
+ * Aggregation result data structure leveraged by sorted aggregation
strategy.
+ */
+ class SortedAggregationResult {
+ public long[] _stepCounters = new long[_numSteps];
+ public int _lastCorrelationId = Integer.MIN_VALUE;
+ public boolean[] _correlatedSteps = new boolean[_numSteps];
+
+ public void sortedCount(int[][] steps, int i, int correlationId) {
+ if (correlationId == _lastCorrelationId) {
+ // same correlation as before, keep accumulating.
+ for (int n = 0; n < _numSteps; n++) {
+ _correlatedSteps[n] |= steps[n][i] > 0;
+ }
+ } else {
+ // End of correlation group, calculate funnel conversion counts
+ incrStepCounters();
+
+ // initialize next correlation group
+ for (int n = 0; n < _numSteps; n++) {
+ _correlatedSteps[n] = steps[n][i] > 0;
+ }
+ _lastCorrelationId = correlationId;
+ }
+ }
+
+ void incrStepCounters() {
+ for (int n = 0; n < _numSteps; n++) {
+ if (!_correlatedSteps[n]) {
+ break;
+ }
+ _stepCounters[n]++;
+ }
+ }
+
+ public long[] extractResult() {
+ // count last correlation id left open
+ incrStepCounters();
+
+ return _stepCounters;
+ }
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
new file mode 100644
index 0000000000..ef5c7d596f
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/BaseFunnelCountQueriesTest.java
@@ -0,0 +1,252 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.operator.query.GroupByOperator;
+import
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Base queries test for FUNNEL_COUNT queries.
+ * <p>Each concrete subclass builds the test segment in its own way (see {@link #buildSegment(List)}), so that each
+ * aggregation strategy gets its own test.
+ * <p>Records alternate between step "A" (even record index) and step "B" (odd record index), each with a random id in
+ * {@code [0, MAX_VALUE)} as correlation key. {@link #getIndexSegments()} returns the same segment twice, and
+ * inter-segment results are expected to be 4 times the inner-segment result.
+ */
+@SuppressWarnings("rawtypes")
+public abstract class BaseFunnelCountQueriesTest extends BaseQueriesTest {
+  protected static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "FunnelCountQueriesTest");
+  protected static final String RAW_TABLE_NAME = "testTable";
+  protected static final String SEGMENT_NAME = "testSegment";
+  protected static final Random RANDOM = new Random();
+
+  protected static final int NUM_RECORDS = 2000;
+  protected static final int MAX_VALUE = 1000;
+  protected static final int NUM_GROUPS = 100;
+  protected static final int FILTER_LIMIT = 50;
+  protected static final String ID_COLUMN = "idColumn";
+  protected static final String STEP_COLUMN = "stepColumn";
+  protected static final String[] STEPS = {"A", "B"};
+  protected static final Schema SCHEMA = new Schema.SchemaBuilder()
+      .addSingleValueDimension(ID_COLUMN, DataType.INT)
+      .addSingleValueDimension(STEP_COLUMN, DataType.STRING)
+      .build();
+  protected static final TableConfigBuilder TABLE_CONFIG_BUILDER =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME);
+
+  // Distinct ids generated for step "A" (index 0) and step "B" (index 1).
+  private final Set<Integer>[] _values = new Set[2];
+  // Every generated id, one entry per record (duplicates included).
+  private final List<Integer> _all = new ArrayList<>();
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  /** Returns the number of entries the filter is expected to scan in one segment. */
+  protected abstract int getExpectedNumEntriesScannedInFilter();
+
+  protected abstract TableConfig getTableConfig();
+
+  /** Builds the segment under test from the generated records. */
+  protected abstract IndexSegment buildSegment(List<GenericRow> records)
+      throws Exception;
+
+  @Override
+  protected String getFilter() {
+    return String.format(" WHERE idColumn >= %s", FILTER_LIMIT);
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    List<GenericRow> records = generateRows();
+    _indexSegment = buildSegment(records);
+    _indexSegments = Arrays.asList(_indexSegment, _indexSegment);
+  }
+
+  /**
+   * Generates {@link #NUM_RECORDS} records alternating between the two steps, recording the expected ids per step.
+   */
+  private List<GenericRow> generateRows() {
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE);
+    _values[0] = new HashSet<>(hashMapCapacity);
+    _values[1] = new HashSet<>(hashMapCapacity);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      int value = RANDOM.nextInt(MAX_VALUE);
+      int stepIndex = i % STEPS.length;
+      GenericRow record = new GenericRow();
+      record.putValue(ID_COLUMN, value);
+      record.putValue(STEP_COLUMN, STEPS[stepIndex]);
+      records.add(record);
+      _all.add(value);
+      _values[stepIndex].add(value);
+    }
+    return records;
+  }
+
+  @Test
+  public void testAggregationOnly() {
+    String query = "SELECT "
+        + "FUNNEL_COUNT("
+        + " STEPS(stepColumn = 'A', stepColumn = 'B'),"
+        + " CORRELATE_BY(idColumn)"
+        + ") FROM testTable";
+
+    // Inner segment
+    Predicate<Integer> filter = id -> id >= FILTER_LIMIT;
+    long expectedFilteredNumDocs = _all.stream().filter(filter).count();
+    Set<Integer> filteredA = _values[0].stream().filter(filter).collect(Collectors.toSet());
+    Set<Integer> filteredB = _values[1].stream().filter(filter).collect(Collectors.toSet());
+    Set<Integer> intersection = new HashSet<>(filteredA);
+    intersection.retainAll(filteredB);
+    long[] expectedResult = {filteredA.size(), intersection.size()};
+
+    Operator operator = getOperatorWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    AggregationResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
+        expectedFilteredNumDocs, getExpectedNumEntriesScannedInFilter(), 2 * expectedFilteredNumDocs, NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getResults();
+    assertNotNull(aggregationResult);
+    assertEquals(aggregationResult.size(), 1);
+    LongArrayList stepCounts = (LongArrayList) aggregationResult.get(0);
+    for (int step = 0; step < expectedResult.length; step++) {
+      assertEquals(stepCounts.getLong(step), expectedResult[step]);
+    }
+
+    // Inter segments (expect 4 * inner segment result)
+    for (int step = 0; step < expectedResult.length; step++) {
+      expectedResult[step] = 4 * expectedResult[step];
+    }
+    Object[] expectedResults = {expectedResult};
+
+    QueriesTestUtils.testInterSegmentsResult(getBrokerResponseWithFilter(query),
+        4 * expectedFilteredNumDocs, 4 * getExpectedNumEntriesScannedInFilter(), 4 * 2 * expectedFilteredNumDocs,
+        4 * NUM_RECORDS, expectedResults);
+  }
+
+  @Test
+  public void testAggregationGroupBy() {
+    String query = String.format("SELECT "
+        + "MOD(idColumn, %s), "
+        + "FUNNEL_COUNT("
+        + " STEPS(stepColumn = 'A', stepColumn = 'B'),"
+        + " CORRELATE_BY(idColumn)"
+        + ") FROM testTable "
+        + "WHERE idColumn >= %s "
+        + "GROUP BY 1 ORDER BY 1 LIMIT %s", NUM_GROUPS, FILTER_LIMIT, NUM_GROUPS);
+
+    // Expected per-group results; groups without any matching id stay null.
+    long[][] expectedResult = new long[NUM_GROUPS][];
+
+    long expectedFilteredNumDocs = _all.stream().filter(id -> id >= FILTER_LIMIT).count();
+
+    int expectedNumGroups = 0;
+    for (int group = 0; group < NUM_GROUPS; group++) {
+      final int groupId = group;
+      Predicate<Integer> filter = id -> id >= FILTER_LIMIT && id % NUM_GROUPS == groupId;
+      Set<Integer> filteredA = _values[0].stream().filter(filter).collect(Collectors.toSet());
+      Set<Integer> filteredB = _values[1].stream().filter(filter).collect(Collectors.toSet());
+      Set<Integer> intersection = new HashSet<>(filteredA);
+      intersection.retainAll(filteredB);
+      if (!filteredA.isEmpty() || !filteredB.isEmpty()) {
+        expectedNumGroups++;
+        expectedResult[group] = new long[]{filteredA.size(), intersection.size()};
+      }
+    }
+
+    // Inner segment
+    GroupByOperator groupByOperator = getOperator(query);
+    GroupByResultsBlock resultsBlock = groupByOperator.nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(groupByOperator.getExecutionStatistics(),
+        expectedFilteredNumDocs, getExpectedNumEntriesScannedInFilter(), 2 * expectedFilteredNumDocs, NUM_RECORDS);
+
+    AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+    int numGroups = 0;
+    Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator();
+    while (groupKeyIterator.hasNext()) {
+      numGroups++;
+      GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
+      int key = ((Double) groupKey._keys[0]).intValue();
+      assertEquals(aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId),
+          new LongArrayList(expectedResult[key]));
+    }
+    assertEquals(numGroups, expectedNumGroups);
+
+    // Inter segments (expect 4 * inner segment result)
+    List<Object[]> expectedRows = new ArrayList<>();
+    for (int group = 0; group < NUM_GROUPS; group++) {
+      if (expectedResult[group] == null) {
+        continue;
+      }
+      for (int step = 0; step < expectedResult[group].length; step++) {
+        expectedResult[group][step] = 4 * expectedResult[group][step];
+      }
+      expectedRows.add(new Object[]{Double.valueOf(group), expectedResult[group]});
+    }
+
+    QueriesTestUtils.testInterSegmentsResult(getBrokerResponse(query),
+        4 * expectedFilteredNumDocs, 4 * getExpectedNumEntriesScannedInFilter(), 4 * 2 * expectedFilteredNumDocs,
+        4 * NUM_RECORDS, expectedRows);
+  }
+
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    _indexSegment.destroy();
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
new file mode 100644
index 0000000000..c89a5d74c9
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesNonSortedTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.util.Collections;
+import java.util.List;
+import
org.apache.pinot.segment.local.indexsegment.mutable.MutableSegmentImplTestUtils;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.MutableSegment;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.data.readers.GenericRow;
+
+
+/**
+ * Queries test for FUNNEL_COUNT queries against a mutable segment whose table config declares no sorted column.
+ *
+ * <p>Records are indexed in the order they are supplied, and the filter is expected to scan {@code NUM_RECORDS}
+ * entries.
+ */
+public class FunnelCountQueriesNonSortedTest extends BaseFunnelCountQueriesTest {
+
+  /** Expected filter scan count for the non-sorted setup: one entry per record. */
+  @Override
+  protected int getExpectedNumEntriesScannedInFilter() {
+    return NUM_RECORDS;
+  }
+
+  /** Returns the base table config as-is, i.e. without configuring a sorted column. */
+  @Override
+  protected TableConfig getTableConfig() {
+    return TABLE_CONFIG_BUILDER.build();
+  }
+
+  /**
+   * Creates a mutable segment for {@code SCHEMA} and indexes every record into it, in the order given.
+   *
+   * @param records rows to index; the list is not modified
+   * @return the populated mutable segment
+   */
+  @Override
+  protected IndexSegment buildSegment(List<GenericRow> records)
+      throws Exception {
+    MutableSegment mutableSegment = MutableSegmentImplTestUtils.createMutableSegmentImpl(SCHEMA,
+        Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), false);
+    for (GenericRow record : records) {
+      mutableSegment.index(record, null);
+    }
+    return mutableSegment;
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java
new file mode 100644
index 0000000000..f06fe26637
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/FunnelCountQueriesSortedTest.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.util.Comparator;
+import java.util.List;
+import
org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import
org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+
+
+/**
+ * Queries test for FUNNEL_COUNT queries against an immutable segment sorted on {@code ID_COLUMN}.
+ *
+ * <p>The records are sorted by {@code ID_COLUMN} before the segment is created, and the table config declares
+ * {@code ID_COLUMN} as the sorted column. With this setup the filter is expected to scan no entries.
+ */
+public class FunnelCountQueriesSortedTest extends BaseFunnelCountQueriesTest {
+
+  /** Expected filter scan count for the sorted setup: no entries are scanned. */
+  @Override
+  protected int getExpectedNumEntriesScannedInFilter() {
+    return 0;
+  }
+
+  /**
+   * Returns the table config with {@code ID_COLUMN} configured as the sorted column.
+   *
+   * <p>NOTE(review): this calls {@code setSortedColumn} on {@code TABLE_CONFIG_BUILDER}, which is declared in the
+   * base class. If that builder is a shared static instance, the sorted column stays set for any other subclass that
+   * builds from it in the same JVM. Confirm against the base class.
+   */
+  @Override
+  protected TableConfig getTableConfig() {
+    return TABLE_CONFIG_BUILDER.setSortedColumn(ID_COLUMN).build();
+  }
+
+  /**
+   * Sorts {@code records} by {@code ID_COLUMN}, writes an immutable segment named {@code SEGMENT_NAME} under
+   * {@code INDEX_DIR}, and loads it back with {@link ReadMode#mmap}.
+   *
+   * @param records rows to index; NOTE: this list is sorted in place, so the caller observes the new order
+   * @return the loaded immutable segment
+   */
+  @Override
+  protected IndexSegment buildSegment(List<GenericRow> records)
+      throws Exception {
+    // Simulate PinotSegmentSorter, so the rows are already ordered on the sorted column when the segment is created.
+    records.sort(Comparator.comparingInt(rec -> (Integer) rec.getValue(ID_COLUMN)));
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(getTableConfig(), SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build();
+    return ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
index c5ba9cdb8b..c5c9eea9a3 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
@@ -133,7 +133,8 @@ public class QueriesTestUtils {
private static void validateRows(List<Object[]> actual, List<Object[]>
expected) {
assertEquals(actual.size(), expected.size());
for (int i = 0; i < actual.size(); i++) {
- assertEquals(actual.get(i), expected.get(i));
+    // Compare each result row with its expected row, in order. Casting to Object selects the generic
+    // assertEquals(Object, Object) overload, which delegates to assertArrayEquals and compares array values
+    // (including array-valued cells, e.g. per-step funnel counts) instead of comparing those cells by reference.
+ assertEquals((Object) actual.get(i), (Object) expected.get(i));
}
}
diff --git
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 7b2a02d666..5201fdcd2b 100644
---
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -116,7 +116,10 @@ public enum AggregationFunctionType {
PARENTARGMIN(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX +
ARGMIN.getName()),
PARENTARGMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX +
ARGMAX.getName()),
CHILDARGMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX
+ ARGMIN.getName()),
- CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX
+ ARGMAX.getName());
+ CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX
+ ARGMAX.getName()),
+
+ // funnel aggregate functions
+ FUNNELCOUNT("funnelCount");
private static final Set<String> NAMES =
Arrays.stream(values()).flatMap(func -> Stream.of(func.name(),
func.getName(),
func.getName().toLowerCase())).collect(Collectors.toSet());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]