This is an automated email from the ASF dual-hosted git repository. xiangfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 70bfd4185c Adding batch api support for WindowFunction (#12993) 70bfd4185c is described below commit 70bfd4185ca18eb64658c5a35813e3578f7c843d Author: Xiang Fu <xiangfu.1...@gmail.com> AuthorDate: Tue May 7 03:44:29 2024 -0700 Adding batch api support for WindowFunction (#12993) --- .../runtime/operator/WindowAggregateOperator.java | 376 ++++----------------- .../runtime/operator/utils/AggregationUtils.java | 22 +- .../operator/window/ValueWindowFunction.java | 54 --- .../runtime/operator/window/WindowFunction.java | 38 ++- .../operator/window/WindowFunctionFactory.java | 60 ++++ .../window/aggregate/AggregateWindowFunction.java | 124 +++++++ .../DenseRankWindowFunction.java} | 35 +- .../operator/window/range/RangeWindowFunction.java | 67 ++++ .../RankWindowFunction.java} | 33 +- .../RowNumberWindowFunction.java} | 21 +- .../{ => value}/FirstValueWindowFunction.java | 18 +- .../window/{ => value}/LagValueWindowFunction.java | 22 +- .../{ => value}/LastValueWindowFunction.java | 18 +- .../{ => value}/LeadValueWindowFunction.java | 22 +- .../operator/window/value/ValueWindowFunction.java | 47 +++ .../operator/WindowAggregateOperatorTest.java | 48 +-- 16 files changed, 496 insertions(+), 509 deletions(-) diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java index c797607660..e2b989b76d 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java @@ -21,7 +21,6 @@ package org.apache.pinot.query.runtime.operator; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.HashMap; @@ -29,7 +28,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.calcite.rel.RelFieldCollation; @@ -45,7 +43,8 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils; import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; import org.apache.pinot.query.runtime.operator.utils.TypeUtils; -import org.apache.pinot.query.runtime.operator.window.ValueWindowFunction; +import org.apache.pinot.query.runtime.operator.window.WindowFunction; +import org.apache.pinot.query.runtime.operator.window.WindowFunctionFactory; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -85,9 +84,9 @@ public class WindowAggregateOperator extends MultiStageOperator { private static final Logger LOGGER = LoggerFactory.getLogger(WindowAggregateOperator.class); // List of window functions which can only be applied as ROWS window frame type - private static final Set<String> ROWS_ONLY_FUNCTION_NAMES = ImmutableSet.of("ROW_NUMBER"); + public static final Set<String> ROWS_ONLY_FUNCTION_NAMES = ImmutableSet.of("ROW_NUMBER"); // List of ranking window functions whose output depends on the ordering of input rows and not on the actual values - private static final Set<String> RANKING_FUNCTION_NAMES = ImmutableSet.of("RANK", "DENSE_RANK"); + public static final Set<String> RANKING_FUNCTION_NAMES = ImmutableSet.of("RANK", "DENSE_RANK"); private final MultiStageOperator _inputOperator; private final List<RexExpression> _groupSet; @@ -96,7 +95,7 @@ public class WindowAggregateOperator extends MultiStageOperator { private final List<RexExpression.FunctionCall> _aggCalls; private final List<RexExpression> _constants; private final DataSchema _resultSchema; - private final WindowAggregateAccumulator[] _windowAccumulators; + private final WindowFunction[] _windowFunctions; private final Map<Key, List<Object[]>> _partitionRows; private final boolean _isPartitionByOnly; @@ -106,22 +105,12 @@ public class WindowAggregateOperator extends MultiStageOperator { private TransferableBlock _eosBlock = null; private final StatMap<StatKey> _statMap = new StatMap<>(StatKey.class); - public WindowAggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, - List<RexExpression> groupSet, List<RexExpression> orderSet, List<RelFieldCollation.Direction> orderSetDirection, - List<RelFieldCollation.NullDirection> orderSetNullDirection, List<RexExpression> aggCalls, int lowerBound, - int upperBound, WindowNode.WindowFrameType windowFrameType, List<RexExpression> constants, - DataSchema resultSchema, DataSchema inputSchema) { - this(context, inputOperator, groupSet, orderSet, orderSetDirection, orderSetNullDirection, aggCalls, lowerBound, - upperBound, windowFrameType, constants, resultSchema, inputSchema, WindowAggregateAccumulator.WIN_AGG_MERGERS); - } - @VisibleForTesting public WindowAggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, List<RexExpression> groupSet, List<RexExpression> orderSet, List<RelFieldCollation.Direction> orderSetDirection, List<RelFieldCollation.NullDirection> orderSetNullDirection, List<RexExpression> aggCalls, int lowerBound, int upperBound, WindowNode.WindowFrameType windowFrameType, List<RexExpression> constants, - DataSchema resultSchema, DataSchema inputSchema, - Map<String, Function<ColumnDataType, AggregationUtils.Merger>> mergers) { + DataSchema resultSchema, DataSchema inputSchema) { super(context); _inputOperator = inputOperator; @@ -140,13 +129,13 @@ public class WindowAggregateOperator extends MultiStageOperator { _constants = constants; _resultSchema = resultSchema; - _windowAccumulators = new WindowAggregateAccumulator[_aggCalls.size()]; + _windowFunctions = new WindowFunction[_aggCalls.size()]; int aggCallsSize = _aggCalls.size(); for (int i = 0; i < aggCallsSize; i++) { RexExpression.FunctionCall agg = _aggCalls.get(i); String functionName = agg.getFunctionName(); - validateAggregationCalls(functionName, mergers); - _windowAccumulators[i] = new WindowAggregateAccumulator(agg, mergers, functionName, inputSchema, _orderSetInfo); + validateAggregationCalls(functionName); + _windowFunctions[i] = WindowFunctionFactory.construnctWindowFunction(agg, inputSchema, _orderSetInfo); } _partitionRows = new HashMap<>(); @@ -187,25 +176,10 @@ public class WindowAggregateOperator extends MultiStageOperator { if (_hasReturnedWindowAggregateBlock) { return _eosBlock; } - TransferableBlock finalBlock = consumeInputBlocks(); - if (finalBlock.isErrorBlock()) { - return finalBlock; - } - _eosBlock = updateEosBlock(finalBlock, _statMap); - return produceWindowAggregatedBlock(); + return computeBlocks(); } - private void validateAggregationCalls(String functionName, - Map<String, Function<ColumnDataType, AggregationUtils.Merger>> mergers) { - if (ValueWindowFunction.VALUE_WINDOW_FUNCTION_MAP.containsKey(functionName)) { - Preconditions.checkState(_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.RANGE, - String.format("Only RANGE type frames are supported at present for VALUE function: %s", functionName)); - return; - } - if (!mergers.containsKey(functionName)) { - throw new IllegalStateException("Unexpected aggregation function name: " + functionName); - } - + private void validateAggregationCalls(String functionName) { if (ROWS_ONLY_FUNCTION_NAMES.contains(functionName)) { Preconditions.checkState( _windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.ROWS && _windowFrame.isUpperBoundCurrentRow(), @@ -236,60 +210,54 @@ public class WindowAggregateOperator extends MultiStageOperator { return partitionByInputRefIndexes.equals(orderByInputRefIndexes); } - private TransferableBlock produceWindowAggregatedBlock() { - Key emptyOrderKey = AggregationUtils.extractEmptyKey(); + /** + * @return the final block, which must be either an end of stream or an error. + */ + private TransferableBlock computeBlocks() { + TransferableBlock block = _inputOperator.nextBlock(); + while (!TransferableBlockUtils.isEndOfStream(block)) { + List<Object[]> container = block.getContainer(); + for (Object[] row : container) { + _numRows++; + // TODO: Revisit null direction handling for all query types + Key key = AggregationUtils.extractRowKey(row, _groupSet); + _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row); + } + block = _inputOperator.nextBlock(); + } + // Early termination if the block is an error block + if (block.isErrorBlock()) { + return block; + } + _eosBlock = updateEosBlock(block, _statMap); + ColumnDataType[] resultStoredTypes = _resultSchema.getStoredColumnDataTypes(); List<Object[]> rows = new ArrayList<>(_numRows); - if (_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.RANGE) { - // All aggregation window functions only support RANGE type today (SUM/AVG/MIN/MAX/COUNT/BOOL_AND/BOOL_OR) - // RANK and DENSE_RANK ranking window functions also only support RANGE type today - for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) { - Key partitionKey = e.getKey(); - List<Object[]> rowList = e.getValue(); - for (int rowId = 0; rowId < rowList.size(); rowId++) { - Object[] existingRow = rowList.get(rowId); - Object[] row = new Object[existingRow.length + _aggCalls.size()]; - Key orderKey = (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey - : AggregationUtils.extractRowKey(existingRow, _orderSetInfo.getOrderSet()); - System.arraycopy(existingRow, 0, row, 0, existingRow.length); - for (int i = 0; i < _windowAccumulators.length; i++) { - if (_windowAccumulators[i]._valueWindowFunction == null) { - row[i + existingRow.length] = _windowAccumulators[i].getRangeResultForKeys(partitionKey, orderKey); - } else { - row[i + existingRow.length] = _windowAccumulators[i].getValueResultForKeys(orderKey, rowId, rowList); - } - } - // Convert the results from Accumulator to the desired type - TypeUtils.convertRow(row, resultStoredTypes); - rows.add(row); - } + for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) { + List<Object[]> rowList = e.getValue(); + + // Each window function will return a list of results for each row in the input set + List<List<Object>> windowFunctionResults = new ArrayList<>(); + for (WindowFunction windowFunction : _windowFunctions) { + List<Object> processRows = windowFunction.processRows(rowList); + Preconditions.checkState(processRows.size() == rowList.size(), + "Number of rows in the result set must match the number of rows in the input set"); + windowFunctionResults.add(processRows); } - } else { - // Only ROW_NUMBER() window function is supported as ROWS type today - Key previousPartitionKey = null; - Object[] previousRowValues = new Object[_windowAccumulators.length]; - for (int i = 0; i < _windowAccumulators.length; i++) { - previousRowValues[i] = null; - } - for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) { - Key partitionKey = e.getKey(); - List<Object[]> rowList = e.getValue(); - for (Object[] existingRow : rowList) { - Object[] row = new Object[existingRow.length + _aggCalls.size()]; - System.arraycopy(existingRow, 0, row, 0, existingRow.length); - for (int i = 0; i < _windowAccumulators.length; i++) { - row[i + existingRow.length] = - _windowAccumulators[i].computeRowResultForCurrentRow(partitionKey, previousPartitionKey, row, - previousRowValues[i]); - previousRowValues[i] = row[i + existingRow.length]; - } - // Convert the results from Accumulator to the desired type - TypeUtils.convertRow(row, resultStoredTypes); - rows.add(row); - previousPartitionKey = partitionKey; + + for (int rowId = 0; rowId < rowList.size(); rowId++) { + Object[] existingRow = rowList.get(rowId); + Object[] row = new Object[existingRow.length + _aggCalls.size()]; + System.arraycopy(existingRow, 0, row, 0, existingRow.length); + for (int i = 0; i < _windowFunctions.length; i++) { + row[i + existingRow.length] = windowFunctionResults.get(i).get(rowId); } + // Convert the results from WindowFunction to the desired type + TypeUtils.convertRow(row, resultStoredTypes); + rows.add(row); } } + _hasReturnedWindowAggregateBlock = true; if (rows.isEmpty()) { return _eosBlock; @@ -298,60 +266,20 @@ public class WindowAggregateOperator extends MultiStageOperator { } } - /** - * @return the final block, which must be either an end of stream or an error. - */ - private TransferableBlock consumeInputBlocks() { - Key emptyOrderKey = AggregationUtils.extractEmptyKey(); - TransferableBlock block = _inputOperator.nextBlock(); - while (!TransferableBlockUtils.isEndOfStream(block)) { - List<Object[]> container = block.getContainer(); - if (_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.RANGE) { - // Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as - // we output the rows since the aggregation value depends on the neighboring rows. - for (Object[] row : container) { - _numRows++; - // TODO: Revisit null direction handling for all query types - Key key = AggregationUtils.extractRowKey(row, _groupSet); - _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row); - // Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as - // we output the rows since the aggregation value depends on the neighboring rows. - Key orderKey = (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey - : AggregationUtils.extractRowKey(row, _orderSetInfo.getOrderSet()); - int aggCallsSize = _aggCalls.size(); - for (int i = 0; i < aggCallsSize; i++) { - if (_windowAccumulators[i]._valueWindowFunction == null) { - _windowAccumulators[i].accumulateRangeResults(key, orderKey, row); - } - } - } - } else { - for (Object[] row : container) { - _numRows++; - // TODO: Revisit null direction handling for all query types - Key key = AggregationUtils.extractRowKey(row, _groupSet); - _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row); - } - } - block = _inputOperator.nextBlock(); - } - return block; - } - /** * Contains all the ORDER BY key related information such as the keys, direction, and null direction */ - private static class OrderSetInfo { + public static class OrderSetInfo { // List of order keys - final List<RexExpression> _orderSet; + public final List<RexExpression> _orderSet; // List of order direction for each key - final List<RelFieldCollation.Direction> _orderSetDirection; + public final List<RelFieldCollation.Direction> _orderSetDirection; // List of null direction for each key - final List<RelFieldCollation.NullDirection> _orderSetNullDirection; + public final List<RelFieldCollation.NullDirection> _orderSetNullDirection; // Set to 'true' if this is a partition by only query - final boolean _isPartitionByOnly; + public final boolean _isPartitionByOnly; - OrderSetInfo(List<RexExpression> orderSet, List<RelFieldCollation.Direction> orderSetDirection, + public OrderSetInfo(List<RexExpression> orderSet, List<RelFieldCollation.Direction> orderSetDirection, List<RelFieldCollation.NullDirection> orderSetNullDirection, boolean isPartitionByOnly) { _orderSet = orderSet; _orderSetDirection = orderSetDirection; @@ -359,19 +287,19 @@ public class WindowAggregateOperator extends MultiStageOperator { _isPartitionByOnly = isPartitionByOnly; } - List<RexExpression> getOrderSet() { + public List<RexExpression> getOrderSet() { return _orderSet; } - List<RelFieldCollation.Direction> getOrderSetDirection() { + public List<RelFieldCollation.Direction> getOrderSetDirection() { return _orderSetDirection; } - List<RelFieldCollation.NullDirection> getOrderSetNullDirection() { + public List<RelFieldCollation.NullDirection> getOrderSetNullDirection() { return _orderSetNullDirection; } - boolean isPartitionByOnly() { + public boolean isPartitionByOnly() { return _isPartitionByOnly; } } @@ -419,184 +347,6 @@ public class WindowAggregateOperator extends MultiStageOperator { } } - private static class MergeRowNumber implements AggregationUtils.Merger { - - @Override - public Long init(@Nullable Object value, ColumnDataType dataType) { - return 1L; - } - - @Override - public Long merge(Object agg, @Nullable Object value) { - return (long) agg + 1; - } - } - - private static class MergeRank implements AggregationUtils.Merger { - - @Override - public Long init(Object other, ColumnDataType dataType) { - return 1L; - } - - @Override - public Long merge(Object left, Object right) { - // RANK always increase by the number of duplicate entries seen for the given ORDER BY key. - return ((Number) left).longValue() + ((Number) right).longValue(); - } - } - - private static class MergeDenseRank implements AggregationUtils.Merger { - - @Override - public Long init(Object other, ColumnDataType dataType) { - return 1L; - } - - @Override - public Long merge(Object left, Object right) { - long rightValueInLong = ((Number) right).longValue(); - // DENSE_RANK always increase the rank by 1, irrespective of the number of duplicate ORDER BY keys seen - return (rightValueInLong == 0L) ? ((Number) left).longValue() : ((Number) left).longValue() + 1L; - } - } - - private static class WindowAggregateAccumulator extends AggregationUtils.Accumulator { - private static final Map<String, Function<ColumnDataType, AggregationUtils.Merger>> WIN_AGG_MERGERS = - ImmutableMap.<String, Function<ColumnDataType, AggregationUtils.Merger>>builder() - .putAll(AggregationUtils.Accumulator.MERGERS) - .put("ROW_NUMBER", cdt -> new MergeRowNumber()) - .put("RANK", cdt -> new MergeRank()) - .put("DENSE_RANK", cdt -> new MergeDenseRank()) - .build(); - - private final boolean _isPartitionByOnly; - private final boolean _isRankingWindowFunction; - private final ValueWindowFunction _valueWindowFunction; - - // Fields needed only for RANGE frame type queries (ORDER BY) - private final Map<Key, OrderKeyResult> _orderByResults = new HashMap<>(); - - WindowAggregateAccumulator(RexExpression.FunctionCall aggCall, - Map<String, Function<ColumnDataType, AggregationUtils.Merger>> merger, String functionName, - DataSchema inputSchema, OrderSetInfo orderSetInfo) { - super(aggCall, merger, functionName, inputSchema); - _isPartitionByOnly = CollectionUtils.isEmpty(orderSetInfo.getOrderSet()) || orderSetInfo.isPartitionByOnly(); - _isRankingWindowFunction = RANKING_FUNCTION_NAMES.contains(functionName); - _valueWindowFunction = ValueWindowFunction.construnctValueWindowFunction(functionName); - } - - /** - * For ROW type queries the aggregation function value depends on the order of the rows rather than on the actual - * keys. For such queries compute the current row value based on the previous row and previous partition key. - * This should only be called for ROW type queries. - */ - public Object computeRowResultForCurrentRow(Key currentPartitionKey, Key previousPartitionKey, Object[] row, - Object previousRowOutputValue) { - Object value = _inputRef == -1 ? _literal : row[_inputRef]; - if (previousPartitionKey == null || !currentPartitionKey.equals(previousPartitionKey)) { - return _merger.init(currentPartitionKey, _dataType); - } else { - return _merger.merge(previousRowOutputValue, value); - } - } - - /** - * For RANGE type queries, accumulate the function values for each PARTITION BY key and ORDER BY key based on - * the current row. Should only be called for RANGE type queries where the aggregation values are tied to the - * RANGE key and not to the row ordering. This should only be called for RANGE type queries. - */ - public void accumulateRangeResults(Key key, Key orderKey, Object[] row) { - // Ranking functions don't use the row value, thus cannot reuse the AggregationUtils accumulate function for them - if (_isPartitionByOnly && !_isRankingWindowFunction) { - accumulate(key, row); - return; - } - - // TODO: fix that single agg result (original type) has different type from multiple agg results (double). - Key previousOrderKeyIfPresent = - _orderByResults.get(key) == null ? null : _orderByResults.get(key).getPreviousOrderByKey(); - Object currentRes = previousOrderKeyIfPresent == null ? null - : _orderByResults.get(key).getOrderByResults().get(previousOrderKeyIfPresent); - Object value = _inputRef == -1 ? _literal : row[_inputRef]; - - // The ranking functions do not depend on the actual value of the data, but are calculated based on the - // position of the data ordered by the ORDER BY key. Thus they need to be handled differently and require setting - // whether the rank has changed or not and if changed then by how much. - _orderByResults.putIfAbsent(key, new OrderKeyResult()); - if (currentRes == null) { - value = _isRankingWindowFunction ? 0 : value; - _orderByResults.get(key).addOrderByResult(orderKey, _merger.init(value, _dataType)); - } else { - Object mergedResult; - if (orderKey.equals(previousOrderKeyIfPresent)) { - value = _isRankingWindowFunction ? 0 : value; - mergedResult = _merger.merge(currentRes, value); - } else { - Object previousValue = _orderByResults.get(key).getOrderByResults().get(previousOrderKeyIfPresent); - value = _isRankingWindowFunction ? _orderByResults.get(key).getCountOfDuplicateOrderByKeys() : value; - mergedResult = _merger.merge(previousValue, value); - } - _orderByResults.get(key).addOrderByResult(orderKey, mergedResult); - } - } - - public Object getRangeResultForKeys(Key key, Key orderKey) { - if (_isPartitionByOnly && !_isRankingWindowFunction) { - return _results.get(key); - } else { - return _orderByResults.get(key).getOrderByResults().get(orderKey); - } - } - - public Map<Key, OrderKeyResult> getRangeOrderByResults() { - return _orderByResults; - } - - public Object getValueResultForKeys(Key orderKey, int rowId, List<Object[]> partitionRows) { - Object[] row = _valueWindowFunction.processRow(rowId, partitionRows); - if (row == null) { - return null; - } - return _inputRef == -1 ? _literal : row[_inputRef]; - } - - static class OrderKeyResult { - final Map<Key, Object> _orderByResults; - Key _previousOrderByKey; - // Store the counts of duplicate ORDER BY keys seen for this PARTITION BY key for calculating RANK/DENSE_RANK - long _countOfDuplicateOrderByKeys; - - OrderKeyResult() { - _orderByResults = new HashMap<>(); - _previousOrderByKey = null; - _countOfDuplicateOrderByKeys = 0; - } - - public void addOrderByResult(Key orderByKey, Object value) { - // We expect to get the rows in order based on the ORDER BY key so it is safe to blindly assign the - // current key as the previous key - _orderByResults.put(orderByKey, value); - _countOfDuplicateOrderByKeys = - (_previousOrderByKey != null && _previousOrderByKey.equals(orderByKey)) ? _countOfDuplicateOrderByKeys + 1 - : 1; - _previousOrderByKey = orderByKey; - } - - public Map<Key, Object> getOrderByResults() { - return _orderByResults; - } - - public Key getPreviousOrderByKey() { - return _previousOrderByKey; - } - - public long getCountOfDuplicateOrderByKeys() { - return _countOfDuplicateOrderByKeys; - } - } - } - public enum StatKey implements StatMap.Key { EXECUTION_TIME_MS(StatMap.Type.LONG) { @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java index 049da05220..0ea7b5df87 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java @@ -194,23 +194,17 @@ public class AggregationUtils { protected final int _inputRef; protected final Object _literal; protected final Map<Key, Object> _results = new HashMap<>(); - protected final Merger _merger; protected final ColumnDataType _dataType; public Map<Key, Object> getResults() { return _results; } - public Merger getMerger() { - return _merger; - } - public ColumnDataType getDataType() { return _dataType; } - public Accumulator(RexExpression.FunctionCall aggCall, - Map<String, Function<ColumnDataType, AggregationUtils.Merger>> merger, String functionName, + public Accumulator(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema) { // agg function operand should either be a InputRef or a Literal RexExpression rexExpression = toAggregationFunctionOperand(aggCall); @@ -223,20 +217,6 @@ public class AggregationUtils { _literal = ((RexExpression.Literal) rexExpression).getValue(); _dataType = rexExpression.getDataType(); } - _merger = merger.containsKey(functionName) ? merger.get(functionName).apply(_dataType) : null; - } - - public void accumulate(Key key, Object[] row) { - // TODO: fix that single agg result (original type) has different type from multiple agg results (double). - Object currentRes = _results.get(key); - Object value = _inputRef == -1 ? _literal : row[_inputRef]; - - if (currentRes == null) { - _results.put(key, _merger.init(value, _dataType)); - } else { - Object mergedResult = _merger.merge(currentRes, value); - _results.put(key, mergedResult); - } } private RexExpression toAggregationFunctionOperand(RexExpression.FunctionCall rexExpression) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java deleted file mode 100644 index c327bcf0ba..0000000000 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.query.runtime.operator.window; - -import com.google.common.collect.ImmutableMap; -import java.lang.reflect.InvocationTargetException; -import java.util.List; -import java.util.Map; - - -public abstract class ValueWindowFunction implements WindowFunction { - public static final Map<String, Class<? extends ValueWindowFunction>> VALUE_WINDOW_FUNCTION_MAP = - ImmutableMap.<String, Class<? extends ValueWindowFunction>>builder() - .put("LEAD", LeadValueWindowFunction.class) - .put("LAG", LagValueWindowFunction.class) - .put("FIRST_VALUE", FirstValueWindowFunction.class) - .put("LAST_VALUE", LastValueWindowFunction.class) - .build(); - - /** - * @param rowId Row id to process - * @param partitionedRows List of rows for reference - * @return Row with the window function applied - */ - public abstract Object[] processRow(int rowId, List<Object[]> partitionedRows); - - public static ValueWindowFunction construnctValueWindowFunction(String functionName) { - Class<? extends ValueWindowFunction> valueWindowFunctionClass = VALUE_WINDOW_FUNCTION_MAP.get(functionName); - if (valueWindowFunctionClass == null) { - return null; - } - try { - return valueWindowFunctionClass.getDeclaredConstructor().newInstance(); - } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { - throw new RuntimeException("Failed to instantiate ValueWindowFunction for function: " + functionName, e); - } - } -} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java index 56d893badf..6221caeae7 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java @@ -19,13 +19,47 @@ package org.apache.pinot.query.runtime.operator.window; import java.util.List; +import org.apache.commons.collections.CollectionUtils; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; +import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; -public interface WindowFunction { +/** + * This class provides the basic structure for window functions. It provides the batch row processing API: + * processRows(List<Object[]> rows) which processes a batch of rows at a time. + * + */ +public abstract class WindowFunction extends AggregationUtils.Accumulator { + protected final String _functionName; + protected final int[] _inputRefs; + protected final boolean _isPartitionByOnly; + protected final List<RexExpression> _orderSet; + + public WindowFunction(RexExpression.FunctionCall aggCall, String functionName, + DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema); + _isPartitionByOnly = CollectionUtils.isEmpty(orderSetInfo.getOrderSet()) || orderSetInfo.isPartitionByOnly(); + boolean isRankingWindowFunction = WindowAggregateOperator.RANKING_FUNCTION_NAMES.contains(functionName); + int[] inputRefs = new int[]{_inputRef}; + if (isRankingWindowFunction) { + inputRefs = orderSetInfo._orderSet.stream().map(RexExpression.InputRef.class::cast) + .mapToInt(RexExpression.InputRef::getIndex).toArray(); + } + _functionName = functionName; + _inputRefs = inputRefs; + _orderSet = orderSetInfo._orderSet; + } /** + * Batch processing API for Window functions. + * This method processes a batch of rows at a time. + * Each row generates one object as output. + * Note, the input and output list size should be the same. + * * @param rows List of rows to process * @return List of rows with the window function applied */ - List<Object[]> processRows(List<Object[]> rows); + public abstract List<Object> processRows(List<Object[]> rows); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java new file mode 100644 index 0000000000..7f2806b757 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window; + +import com.google.common.collect.ImmutableMap; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.Map; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; +import org.apache.pinot.query.runtime.operator.window.aggregate.AggregateWindowFunction; +import org.apache.pinot.query.runtime.operator.window.range.RangeWindowFunction; +import org.apache.pinot.query.runtime.operator.window.value.ValueWindowFunction; + + +/** + * Factory class to construct WindowFunction instances. + */ +public class WindowFunctionFactory { + private WindowFunctionFactory() { + } + + public static final Map<String, Class<? extends WindowFunction>> WINDOW_FUNCTION_MAP = + ImmutableMap.<String, Class<? extends WindowFunction>>builder() + .putAll(RangeWindowFunction.WINDOW_FUNCTION_MAP) + .putAll(ValueWindowFunction.WINDOW_FUNCTION_MAP) + .build(); + + public static WindowFunction construnctWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + String functionName = aggCall.getFunctionName(); + Class<? extends WindowFunction> windowFunctionClass = + WINDOW_FUNCTION_MAP.getOrDefault(functionName, AggregateWindowFunction.class); + try { + Constructor<? extends WindowFunction> constructor = + windowFunctionClass.getConstructor(RexExpression.FunctionCall.class, String.class, DataSchema.class, + WindowAggregateOperator.OrderSetInfo.class); + return constructor.newInstance(aggCall, functionName, inputSchema, orderSetInfo); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new RuntimeException("Failed to instantiate WindowFunction for function name: " + functionName, e); + } + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java new file mode 100644 index 0000000000..8dd5c791e4 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window.aggregate; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.commons.collections.CollectionUtils; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.data.table.Key; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; +import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; +import org.apache.pinot.query.runtime.operator.window.WindowFunction; + + +public class AggregateWindowFunction extends WindowFunction { + private final AggregationUtils.Merger _merger; + + public AggregateWindowFunction(RexExpression.FunctionCall aggCall, String functionName, + DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); + _merger = AggregationUtils.Accumulator.MERGERS.get(_functionName).apply(_dataType); + } + + @Override + public final List<Object> processRows(List<Object[]> rows) { + if (_isPartitionByOnly) { + return processPartitionOnlyRows(rows); + } else { + return processRowsInternal(rows); + } + } + + protected List<Object> processPartitionOnlyRows(List<Object[]> rows) { + Object mergedResult = null; + for (Object[] row : rows) { + Object value = _inputRef == -1 ? _literal : row[_inputRef]; + if (value == null) { + continue; + } + if (mergedResult == null) { + mergedResult = _merger.init(value, _dataType); + } else { + mergedResult = _merger.merge(mergedResult, value); + } + } + return Collections.nCopies(rows.size(), mergedResult); + } + + protected List<Object> processRowsInternal(List<Object[]> rows) { + Key emptyOrderKey = AggregationUtils.extractEmptyKey(); + OrderKeyResult orderByResult = new OrderKeyResult(); + for (Object[] row : rows) { + // Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as + // we output the rows since the aggregation value depends on the neighboring rows. + Key orderKey = (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSet)) ? emptyOrderKey + : AggregationUtils.extractRowKey(row, _orderSet); + + Key previousOrderKeyIfPresent = orderByResult.getPreviousOrderByKey(); + Object currentRes = previousOrderKeyIfPresent == null ? null + : orderByResult.getOrderByResults().get(previousOrderKeyIfPresent); + Object value = _inputRef == -1 ? _literal : row[_inputRef]; + if (currentRes == null) { + orderByResult.addOrderByResult(orderKey, _merger.init(value, _dataType)); + } else { + orderByResult.addOrderByResult(orderKey, _merger.merge(currentRes, value)); + } + } + List<Object> results = new ArrayList<>(rows.size()); + for (Object[] row : rows) { + // Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as + // we output the rows since the aggregation value depends on the neighboring rows. + Key orderKey = (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSet)) ? emptyOrderKey + : AggregationUtils.extractRowKey(row, _orderSet); + Object value = orderByResult.getOrderByResults().get(orderKey); + results.add(value); + } + return results; + } + + static class OrderKeyResult { + final Map<Key, Object> _orderByResults; + Key _previousOrderByKey; + + OrderKeyResult() { + _orderByResults = new HashMap<>(); + _previousOrderByKey = null; + } + + public void addOrderByResult(Key orderByKey, Object value) { + // We expect to get the rows in order based on the ORDER BY key so it is safe to blindly assign the + // current key as the previous key + _orderByResults.put(orderByKey, value); + _previousOrderByKey = orderByKey; + } + + public Map<Key, Object> getOrderByResults() { + return _orderByResults; + } + + public Key getPreviousOrderByKey() { + return _previousOrderByKey; + } + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java similarity index 51% copy from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java copy to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java index bd8a50ea48..00f23f851a 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java @@ -16,32 +16,37 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.range; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; -public class LeadValueWindowFunction extends ValueWindowFunction { +public class DenseRankWindowFunction extends RangeWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - if (rowId == partitionedRows.size() - 1) { - return null; - } else { - return partitionedRows.get(rowId + 1); - } + public DenseRankWindowFunction(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); - for (int i = 0; i < rows.size(); i++) { - if (i == rows.size() - 1) { - result.add(null); + public List<Object> processRows(List<Object[]> rows) { + List<Object> result = new ArrayList<>(); + int rank = 1; + Object[] lastRow = null; + for (Object[] row : rows) { + if (lastRow == null) { + result.add(rank); } else { - result.add(rows.get(i + 1)); + if (compareRows(row, lastRow) != 0) { + rank++; + } + result.add(rank); } + lastRow = row; } return result; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java new file mode 100644 index 0000000000..a4ac37318f --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window.range; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; +import org.apache.pinot.query.runtime.operator.window.WindowFunction; + + +public abstract class RangeWindowFunction extends WindowFunction { + public static final Map<String, Class<? extends WindowFunction>> WINDOW_FUNCTION_MAP = + ImmutableMap.<String, Class<? extends WindowFunction>>builder() + // Range window functions + .put("ROW_NUMBER", RowNumberWindowFunction.class) + .put("RANK", RankWindowFunction.class) + .put("DENSE_RANK", DenseRankWindowFunction.class) + .build(); + + public RangeWindowFunction(RexExpression.FunctionCall aggCall, String functionName, + DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); + } + + protected int compareRows(Object[] leftRow, Object[] rightRow) { + for (int inputRef : _inputRefs) { + if (inputRef < 0) { + continue; + } + Object leftValue = leftRow[inputRef]; + Object rightValue = rightRow[inputRef]; + if (leftValue == null) { + if (rightValue != null) { + return -1; + } + } else { + if (rightValue == null) { + return 1; + } else { + int result = ((Comparable) leftValue).compareTo(rightValue); + if (result != 0) { + return result; + } + } + } + } + return 0; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java similarity index 52% copy from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java copy to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java index bd8a50ea48..8688f70216 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java @@ -16,32 +16,35 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.range; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; -public class LeadValueWindowFunction extends ValueWindowFunction { +public class RankWindowFunction extends RangeWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - if (rowId == partitionedRows.size() - 1) { - return null; - } else { - return partitionedRows.get(rowId + 1); - } + public RankWindowFunction(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); + public List<Object> processRows(List<Object[]> rows) { + int rank = 1; + List<Object> result = new ArrayList<>(); for (int i = 0; i < rows.size(); i++) { - if (i == rows.size() - 1) { - result.add(null); - } else { - result.add(rows.get(i + 1)); + if (i > 0) { + Object[] prevRow = rows.get(i - 1); + Object[] currentRow = rows.get(i); + if (compareRows(prevRow, currentRow) != 0) { + rank = i + 1; + } } + result.add(rank); } return result; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java similarity index 56% copy from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java copy to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java index cc7db910d2..dd75d17f6c 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java @@ -16,24 +16,27 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.range; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; -public class LastValueWindowFunction extends ValueWindowFunction { +public class RowNumberWindowFunction extends RangeWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - return partitionedRows.get(partitionedRows.size() - 1); + public RowNumberWindowFunction(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); - for (int i = 0; i < rows.size(); i++) { - result.add(rows.get(rows.size() - 1)); + public List<Object> processRows(List<Object[]> rows) { + List<Object> result = new ArrayList<>(); + for (long i = 1; i <= rows.size(); i++) { + result.add(i); } return result; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java similarity index 61% rename from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java index 5d2ae75950..6894a156d6 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java @@ -16,24 +16,28 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.value; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; public class FirstValueWindowFunction extends ValueWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - return partitionedRows.get(0); + public FirstValueWindowFunction(RexExpression.FunctionCall aggCall, + String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); + public List<Object> processRows(List<Object[]> rows) { + List<Object> result = new ArrayList<>(); for (int i = 0; i < rows.size(); i++) { - result.add(rows.get(0)); + result.add(extractValueFromRow(rows.get(0))); } return result; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java similarity index 62% rename from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java index 9bca8ec930..7e093ed792 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java @@ -16,31 +16,31 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.value; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; public class LagValueWindowFunction extends ValueWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - if (rowId == 0) { - return null; - } else { - return partitionedRows.get(rowId - 1); - } + public LagValueWindowFunction(RexExpression.FunctionCall aggCall, + String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); + public List<Object> processRows(List<Object[]> rows) { + List<Object> result = new ArrayList<>(); for (int i = 0; i < rows.size(); i++) { if (i == 0) { result.add(null); } else { - result.add(rows.get(i - 1)); + result.add(extractValueFromRow(rows.get(i - 1))); } } return result; diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java similarity index 61% rename from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java index cc7db910d2..bccafccf8a 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java @@ -16,24 +16,28 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.value; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; public class LastValueWindowFunction extends ValueWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - return partitionedRows.get(partitionedRows.size() - 1); + public LastValueWindowFunction(RexExpression.FunctionCall aggCall, + String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); + public List<Object> processRows(List<Object[]> rows) { + List<Object> result = new ArrayList<>(); for (int i = 0; i < rows.size(); i++) { - result.add(rows.get(rows.size() - 1)); + result.add(extractValueFromRow(rows.get(rows.size() - 1))); } return result; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java similarity index 63% rename from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java index bd8a50ea48..4cbd917274 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java @@ -16,31 +16,31 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.query.runtime.operator.window; +package org.apache.pinot.query.runtime.operator.window.value; import java.util.ArrayList; import java.util.List; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; public class LeadValueWindowFunction extends ValueWindowFunction { - @Override - public Object[] processRow(int rowId, List<Object[]> partitionedRows) { - if (rowId == partitionedRows.size() - 1) { - return null; - } else { - return partitionedRows.get(rowId + 1); - } + public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, + String functionName, DataSchema inputSchema, + WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); } @Override - public List<Object[]> processRows(List<Object[]> rows) { - List<Object[]> result = new ArrayList<>(); + public List<Object> processRows(List<Object[]> rows) { + List<Object> result = new ArrayList<>(); for (int i = 0; i < rows.size(); i++) { if (i == rows.size() - 1) { result.add(null); } else { - result.add(rows.get(i + 1)); + result.add(extractValueFromRow(rows.get(i + 1))); } } return result; diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java new file mode 100644 index 0000000000..7226a926d4 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window.value; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; +import org.apache.pinot.query.runtime.operator.window.WindowFunction; + + +public abstract class ValueWindowFunction extends WindowFunction { + public static final Map<String, Class<? extends WindowFunction>> WINDOW_FUNCTION_MAP = + ImmutableMap.<String, Class<? extends WindowFunction>>builder() + // Value window functions + .put("LEAD", LeadValueWindowFunction.class) + .put("LAG", LagValueWindowFunction.class) + .put("FIRST_VALUE", FirstValueWindowFunction.class) + .put("LAST_VALUE", LastValueWindowFunction.class) + .build(); + + public ValueWindowFunction(RexExpression.FunctionCall aggCall, String functionName, + DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) { + super(aggCall, functionName, inputSchema, orderSetInfo); + } + + protected Object extractValueFromRow(Object[] row) { + return _inputRef == -1 ? _literal : (row == null ? null : row[_inputRef]); + } +} diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index 2bfca7c149..61df71d9ad 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -19,7 +19,6 @@ package org.apache.pinot.query.runtime.operator; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -36,7 +35,6 @@ import org.apache.pinot.query.routing.VirtualServerAddress; import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils; import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils; -import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -230,44 +228,6 @@ public class WindowAggregateOperatorTest { Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done processing)"); } - @Test - public void testShouldCallMergerWhenWindowAggregatingMultipleRows() { - // Given: - List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.InputRef(1))); - List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0)); - - DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT}); - Mockito.when(_input.nextBlock()) - .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{1, 1}, new Object[]{1, 2})) - .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{1, 3})) - .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - - AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class); - Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d); - Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d); - DataSchema outSchema = new DataSchema(new String[]{"group", "arg", "sum"}, new ColumnDataType[]{INT, INT, DOUBLE}); - WindowAggregateOperator operator = - new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), _input, group, Collections.emptyList(), - Collections.emptyList(), Collections.emptyList(), calls, Integer.MIN_VALUE, Integer.MAX_VALUE, - WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema, - ImmutableMap.of("SUM", cdt -> merger)); - - // When: - TransferableBlock resultBlock = operator.nextBlock(); // (output result) - - // Then: - // should call merger twice, one from second row in first block and two from the first row - // in second block - Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any()); - Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any()); - Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 1, 12d}, - "Expected three columns (original two columns, agg literal value)"); - Assert.assertEquals(resultBlock.getContainer().get(1), new Object[]{1, 2, 12d}, - "Expected three columns (original two columns, agg literal value)"); - Assert.assertEquals(resultBlock.getContainer().get(2), new Object[]{1, 3, 12d}, - "Expected three columns (original two columns, agg literal value)"); - } - @Test public void testPartitionByWindowAggregateWithHashCollision() { MultiStageOperator upstreamOperator = OperatorTestUtil.getOperator(OperatorTestUtil.OP_1); @@ -292,8 +252,8 @@ public class WindowAggregateOperatorTest { Assert.assertEquals(resultRows.get(2), expectedRows.get(2)); } - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".*Unexpected aggregation " - + "function name: AVERAGE.*") + @Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = ".*Failed to instantiate " + + "WindowFunction for function name: AVERAGE.*") public void testShouldThrowOnUnknownAggFunction() { // Given: List<RexExpression> calls = ImmutableList.of( @@ -309,8 +269,8 @@ public class WindowAggregateOperatorTest { WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema); } - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".*Unexpected aggregation " - + "function name: NTILE.*") + @Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = ".*Failed to instantiate " + + "WindowFunction for function name: NTILE.*") public void testShouldThrowOnUnknownRankAggFunction() { // TODO: Remove this test when support is added for NTILE function // Given: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org