This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 22b345526b [multistage] [bugfix] Fix hash collision (#9729)
22b345526b is described below
commit 22b345526b2abada6cea219484696b1b78669e08
Author: Yao Liu <[email protected]>
AuthorDate: Sat Nov 5 08:30:35 2022 -0700
[multistage] [bugfix] Fix hash collision (#9729)
* fix agg function hashkey collision bug
* fix hash join hashkey collision bug
---
.../query/runtime/blocks/TransferableBlock.java | 8 --
.../query/runtime/operator/AggregateOperator.java | 86 +++++--------
.../query/runtime/operator/HashJoinOperator.java | 24 ++--
.../query/runtime/plan/PhysicalPlanVisitor.java | 5 +-
.../runtime/operator/AggregateOperatorTest.java | 65 ++++++++++
.../runtime/operator/HashJoinOperatorTest.java | 139 +++++++++++++++++++++
.../query/runtime/operator/OperatorTestUtil.java | 50 ++++++++
7 files changed, 301 insertions(+), 76 deletions(-)
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
index 46dd9dc967..93508ffac6 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
@@ -161,14 +161,6 @@ public class TransferableBlock implements Block {
return _isErrorBlock;
}
- boolean isContainerBlock() {
- return _container != null;
- }
-
- boolean isDataBlock() {
- return _dataBlock != null;
- }
-
@Override
public BlockValSet getBlockValueSet() {
throw new UnsupportedOperationException();
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 0c261d26e3..56f2b9087b 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -26,15 +26,10 @@ import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.pinot.common.datablock.BaseDataBlock;
-import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.core.operator.BaseOperator;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.MinAggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -44,39 +39,49 @@ import org.apache.pinot.spi.data.FieldSpec;
/**
*
+ * AggregateOperator is used to aggregate values over a set of group by keys.
+ * Output data will be in the format of [group by key, aggregate result1, ...
aggregate resultN]
+ * Currently, we only support SUM/COUNT/MIN/MAX aggregation.
+ *
+ * When the list of aggregation calls is empty, this class is used to
calculate distinct result based on group by keys.
+ * In this case, the input can be any type.
+ *
+ * If the list of aggregation calls is not empty, the input of aggregation has
to be a number.
+ *
+ * Note: This class performs aggregation over the double value of input.
+ * If the input is single value, the output type will be input type.
Otherwise, the output type will be double.
*/
public class AggregateOperator extends BaseOperator<TransferableBlock> {
private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
private Operator<TransferableBlock> _inputOperator;
+ // TODO: Deal with the case where _aggCalls is empty but we have groupSet
setup, which means this is a Distinct call.
private List<RexExpression> _aggCalls;
private List<RexExpression> _groupSet;
- private final AggregationFunction[] _aggregationFunctions;
private final int[] _aggregationFunctionInputRefs;
private final Object[] _aggregationFunctionLiterals;
private final DataSchema _resultSchema;
- private final Map<Integer, Object>[] _groupByResultHolders;
- private final Map<Integer, Object[]> _groupByKeyHolder;
-
- private DataSchema _upstreamDataSchema;
+ private final Map<Key, Object>[] _groupByResultHolders;
+ private final Map<Key, Object[]> _groupByKeyHolder;
private TransferableBlock _upstreamErrorBlock;
private boolean _isCumulativeBlockConstructed;
// TODO: refactor Pinot Reducer code to support the intermediate stage agg
operator.
+ // aggCalls has to be a list of FunctionCall and cannot be null
+ // groupSet has to be a list of InputRef and cannot be null
+ // TODO: Add these two checks when we confirm we can handle error in
upstream ctor call.
public AggregateOperator(Operator<TransferableBlock> inputOperator,
DataSchema dataSchema,
- List<RexExpression> aggCalls, List<RexExpression> groupSet, DataSchema
upstreamDataSchema) {
+ List<RexExpression> aggCalls, List<RexExpression> groupSet) {
_inputOperator = inputOperator;
_aggCalls = aggCalls;
_groupSet = groupSet;
- _upstreamDataSchema = upstreamDataSchema;
_upstreamErrorBlock = null;
- _aggregationFunctions = new AggregationFunction[_aggCalls.size()];
_aggregationFunctionInputRefs = new int[_aggCalls.size()];
_aggregationFunctionLiterals = new Object[_aggCalls.size()];
_groupByResultHolders = new Map[_aggCalls.size()];
- _groupByKeyHolder = new HashMap<Integer, Object[]>();
+ _groupByKeyHolder = new HashMap<Key, Object[]>();
for (int i = 0; i < aggCalls.size(); i++) {
// agg function operand should either be a InputRef or a Literal
RexExpression rexExpression =
toAggregationFunctionOperand(aggCalls.get(i));
@@ -86,8 +91,7 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
_aggregationFunctionInputRefs[i] = -1;
_aggregationFunctionLiterals[i] = ((RexExpression.Literal)
rexExpression).getValue();
}
- _aggregationFunctions[i] = toAggregationFunction(aggCalls.get(i),
_aggregationFunctionInputRefs[i]);
- _groupByResultHolders[i] = new HashMap<Integer, Object>();
+ _groupByResultHolders[i] = new HashMap<Key, Object>();
}
_resultSchema = dataSchema;
@@ -129,8 +133,8 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
}
if (!_isCumulativeBlockConstructed) {
List<Object[]> rows = new ArrayList<>(_groupByKeyHolder.size());
- for (Map.Entry<Integer, Object[]> e : _groupByKeyHolder.entrySet()) {
- Object[] row = new Object[_aggregationFunctions.length +
_groupSet.size()];
+ for (Map.Entry<Key, Object[]> e : _groupByKeyHolder.entrySet()) {
+ Object[] row = new Object[_aggCalls.size() + _groupSet.size()];
Object[] keyElements = e.getValue();
for (int i = 0; i < keyElements.length; i++) {
row[i] = keyElements[i];
@@ -160,17 +164,17 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
for (int rowId = 0; rowId < numRows; rowId++) {
Object[] row =
SelectionOperatorUtils.extractRowFromDataTable(dataBlock, rowId);
Key key = extraRowKey(row, _groupSet);
- int keyHashCode = key.hashCode();
- _groupByKeyHolder.put(keyHashCode, key.getValues());
- for (int i = 0; i < _aggregationFunctions.length; i++) {
- Object currentRes = _groupByResultHolders[i].get(keyHashCode);
+ _groupByKeyHolder.put(key, key.getValues());
+ for (int i = 0; i < _aggCalls.size(); i++) {
+ Object currentRes = _groupByResultHolders[i].get(key);
+ // TODO: fix that single agg result (original type) has different
type from multiple agg results (double).
if (currentRes == null) {
- _groupByResultHolders[i].put(keyHashCode,
_aggregationFunctionInputRefs[i] == -1
- ? _aggregationFunctionLiterals[i] :
row[_aggregationFunctionInputRefs[i]]);
+ _groupByResultHolders[i].put(key,
_aggregationFunctionInputRefs[i] == -1 ? _aggregationFunctionLiterals[i]
+ : row[_aggregationFunctionInputRefs[i]]);
} else {
- _groupByResultHolders[i].put(keyHashCode,
- merge(_aggCalls.get(i), currentRes,
_aggregationFunctionInputRefs[i] == -1
- ? _aggregationFunctionLiterals[i] :
row[_aggregationFunctionInputRefs[i]]));
+ _groupByResultHolders[i].put(key, merge(_aggCalls.get(i),
currentRes,
+ _aggregationFunctionInputRefs[i] == -1 ?
_aggregationFunctionLiterals[i]
+ : row[_aggregationFunctionInputRefs[i]]));
}
}
}
@@ -183,34 +187,6 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
}
}
- private AggregationFunction toAggregationFunction(RexExpression aggCall, int
aggregationFunctionInputRef) {
- Preconditions.checkState(aggCall instanceof RexExpression.FunctionCall);
- // TODO(Rong Rong): query options are not supported by the new engine at
this moment.
- switch (((RexExpression.FunctionCall) aggCall).getFunctionName()) {
- case "$SUM":
- case "$SUM0":
- case "SUM":
- return new SumAggregationFunction(
-
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
- case "$MIN":
- case "$MIN0":
- case "MIN":
- return new MinAggregationFunction(
-
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
- case "$MAX":
- case "$MAX0":
- case "MAX":
- return new MaxAggregationFunction(
-
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
- // COUNT(*) is rewritten to SUM(1)
- case "COUNT":
- return new
SumAggregationFunction(ExpressionContext.forLiteralContext(FieldSpec.DataType.INT,
1));
- default:
- throw new IllegalStateException(
- "Unexpected value: " + ((RexExpression.FunctionCall)
aggCall).getFunctionName());
- }
- }
-
private Object merge(RexExpression aggCall, Object left, Object right) {
Preconditions.checkState(aggCall instanceof RexExpression.FunctionCall);
switch (((RexExpression.FunctionCall) aggCall).getFunctionName()) {
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index bcf6807dd9..61cd18ff5a 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -28,6 +28,7 @@ import org.apache.pinot.common.datablock.BaseDataBlock;
import org.apache.pinot.common.datablock.DataBlockUtils;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.partitioning.KeySelector;
@@ -44,17 +45,18 @@ import
org.apache.pinot.query.runtime.operator.operands.FilterOperand;
* it looks up for the corresponding row(s) from the hash table and create a
joint row.
*
* <p>For each of the data block received from the left table, it will
generate a joint data block.
+ *
+ * We currently support left join, inner join and semi join.
+ * The output is in the format of [left_row, right_row]
*/
public class HashJoinOperator extends BaseOperator<TransferableBlock> {
private static final String EXPLAIN_NAME = "BROADCAST_JOIN";
- private final HashMap<Integer, List<Object[]>> _broadcastHashTable;
+ private final HashMap<Key, List<Object[]>> _broadcastHashTable;
private final Operator<TransferableBlock> _leftTableOperator;
private final Operator<TransferableBlock> _rightTableOperator;
private final JoinRelType _joinType;
private final DataSchema _resultSchema;
- private final DataSchema _leftTableSchema;
- private final DataSchema _rightTableSchema;
private final int _resultRowSize;
private final List<FilterOperand> _joinClauseEvaluators;
private boolean _isHashTableBuilt;
@@ -62,16 +64,16 @@ public class HashJoinOperator extends
BaseOperator<TransferableBlock> {
private KeySelector<Object[], Object[]> _leftKeySelector;
private KeySelector<Object[], Object[]> _rightKeySelector;
- public HashJoinOperator(Operator<TransferableBlock> leftTableOperator,
DataSchema leftSchema,
- Operator<TransferableBlock> rightTableOperator, DataSchema rightSchema,
DataSchema outputSchema,
- JoinNode.JoinKeys joinKeys, List<RexExpression> joinClauses, JoinRelType
joinType) {
+ // TODO: Fix inequi join bug. (https://github.com/apache/pinot/issues/9728)
+ // TODO: Double check semi join logic.
+ public HashJoinOperator(Operator<TransferableBlock> leftTableOperator,
Operator<TransferableBlock> rightTableOperator,
+ DataSchema outputSchema, JoinNode.JoinKeys joinKeys, List<RexExpression>
joinClauses, JoinRelType joinType) {
+ // TODO: Handle the case where _leftKeySelector and _rightKeySelector
could be null.
_leftKeySelector = joinKeys.getLeftJoinKeySelector();
_rightKeySelector = joinKeys.getRightJoinKeySelector();
_leftTableOperator = leftTableOperator;
_rightTableOperator = rightTableOperator;
_resultSchema = outputSchema;
- _leftTableSchema = leftSchema;
- _rightTableSchema = rightSchema;
_joinClauseEvaluators = new ArrayList<>(joinClauses.size());
for (RexExpression joinClause : joinClauses) {
_joinClauseEvaluators.add(FilterOperand.toFilterOperand(joinClause,
_resultSchema));
@@ -118,7 +120,7 @@ public class HashJoinOperator extends
BaseOperator<TransferableBlock> {
// put all the rows into corresponding hash collections keyed by the
key selector function.
for (Object[] row : container) {
List<Object[]> hashCollection =
-
_broadcastHashTable.computeIfAbsent(_rightKeySelector.computeHash(row), k ->
new ArrayList<>());
+ _broadcastHashTable.computeIfAbsent(new
Key(_rightKeySelector.getKey(row)), k -> new ArrayList<>());
hashCollection.add(row);
}
rightBlock = _rightTableOperator.nextBlock();
@@ -137,10 +139,12 @@ public class HashJoinOperator extends
BaseOperator<TransferableBlock> {
List<Object[]> container = leftBlock.getContainer();
for (Object[] leftRow : container) {
List<Object[]> hashCollection = _broadcastHashTable.getOrDefault(
- _leftKeySelector.computeHash(leftRow), Collections.emptyList());
+ new Key(_leftKeySelector.getKey(leftRow)),
Collections.emptyList());
+ // If it is a left join and right table is empty, we return left rows.
if (hashCollection.isEmpty() && _joinType == JoinRelType.LEFT) {
rows.add(joinRow(leftRow, null));
} else {
+ // If it is other type of join.
for (Object[] rightRow : hashCollection) {
Object[] resultRow = joinRow(leftRow, rightRow);
if (_joinClauseEvaluators.isEmpty() ||
_joinClauseEvaluators.stream().allMatch(
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
index 6dab5865ba..6ba6d6b4c1 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
@@ -83,7 +83,7 @@ public class PhysicalPlanVisitor implements
StageNodeVisitor<Operator<Transferab
public Operator<TransferableBlock> visitAggregate(AggregateNode node,
PlanRequestContext context) {
Operator<TransferableBlock> nextOperator =
node.getInputs().get(0).visit(this, context);
return new AggregateOperator(nextOperator, node.getDataSchema(),
node.getAggCalls(),
- node.getGroupSet(), node.getInputs().get(0).getDataSchema());
+ node.getGroupSet());
}
@Override
@@ -100,8 +100,7 @@ public class PhysicalPlanVisitor implements
StageNodeVisitor<Operator<Transferab
Operator<TransferableBlock> leftOperator = left.visit(this, context);
Operator<TransferableBlock> rightOperator = right.visit(this, context);
- return new HashJoinOperator(leftOperator, left.getDataSchema(),
rightOperator,
- right.getDataSchema(), node.getDataSchema(), node.getJoinKeys(),
+ return new HashJoinOperator(leftOperator, rightOperator,
node.getDataSchema(), node.getJoinKeys(),
node.getJoinClauses(), node.getJoinRelType());
}
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
new file mode 100644
index 0000000000..97e27e10ab
--- /dev/null
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testng.Assert;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.when;
+
+
/**
 * Unit tests for AggregateOperator. The upstream operator is a Mockito mock
 * whose nextBlock() is stubbed per test to return one row block and then an
 * end-of-stream block, both built with OperatorTestUtil.
 */
+public class AggregateOperatorTest {
+ @Mock
+ Operator<TransferableBlock> _upstreamOperator;
+
// Re-creates the @Mock fields before every test method (TestNG @BeforeMethod).
+ @BeforeMethod
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ }
+
// Regression test for the group-by hash-collision bug: AggregateOperator used
// to key its group maps by key.hashCode(), so distinct keys with equal hashes
// were merged into one group. It now keys them by Key itself.
+ @Test
+ public void testGroupByAggregateWithHashCollision() {
+ // "Aa" and "BB" have the same String.hashCode() (2112) in Java.
+ List<Object[]> rows = Arrays.asList(new Object[]{1, "Aa"}, new Object[]{2,
"BB"}, new Object[]{3, "BB"});
+
when(_upstreamOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+ // Create an aggregation call with sum for first column and group by
second column.
+ RexExpression.FunctionCall agg = new
RexExpression.FunctionCall(SqlKind.SUM, FieldSpec.DataType.INT, "SUM",
+ Arrays.asList(new RexExpression.InputRef(0)));
+ AggregateOperator sum0GroupBy1 =
+ new AggregateOperator(_upstreamOperator,
OperatorTestUtil.TEST_DATA_SCHEMA, Arrays.asList(agg),
+ Arrays.asList(new RexExpression.InputRef(1)));
+ TransferableBlock result = sum0GroupBy1.getNextBlock();
+ List<Object[]> resultRows = result.getContainer();
// Output rows are [group key, SUM]. "Aa" has a single input row, so its SUM is
// the raw INT value 1. "BB" merges two rows (2 + 3), so its SUM comes back as
// the double 5.0. This mirrors the TODO in AggregateOperator about single vs.
// merged results having different types.
+ List<Object[]> expectedRows = Arrays.asList(new Object[]{"Aa", 1}, new
Object[]{"BB", 5.0});
// NOTE(review): the positional checks below assume "Aa" is emitted before
// "BB". AggregateOperator builds its output by iterating a HashMap, so this
// order is not guaranteed by contract. Consider an order-insensitive compare.
+ Assert.assertEquals(resultRows.size(), expectedRows.size());
+ Assert.assertEquals(resultRows.get(0), expectedRows.get(0));
+ Assert.assertEquals(resultRows.get(1), expectedRows.get(1));
+ }
+}
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
new file mode 100644
index 0000000000..768371636d
--- /dev/null
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
+import org.apache.pinot.query.planner.stage.JoinNode;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testng.Assert;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.when;
+
+
/**
 * Unit tests for HashJoinOperator. Each test stubs the mocked left and right
 * operators to return one row block (OperatorTestUtil.TEST_DATA_SCHEMA:
 * foo INT, bar STRING) followed by an end-of-stream block. It then joins the
 * two sides with key index 1 on both, presumably column "bar", and passes no
 * extra join clauses.
 *
 * NOTE(review): TestNG's Assert.assertEquals is declared as (actual, expected),
 * but the assertions in this class pass (expected, actual). Failure messages
 * will therefore label the two values backwards; confirm and swap. The results
 * are also compared positionally. That relies on HashJoinOperator walking left
 * rows in input order, and each key's right rows in insertion order.
 */
+public class HashJoinOperatorTest {
// Builds JoinKeys from FieldSelectionKeySelectors over the given column indexes
// of the left and right rows.
+ private static JoinNode.JoinKeys getJoinKeys(List<Integer> leftIdx,
List<Integer> rightIdx) {
+ FieldSelectionKeySelector leftSelect = new
FieldSelectionKeySelector(leftIdx);
+ FieldSelectionKeySelector rightSelect = new
FieldSelectionKeySelector(rightIdx);
+ return new JoinNode.JoinKeys(leftSelect, rightSelect);
+ }
// Mocked join inputs; setup() re-creates them before every test method.
+ @Mock
+ Operator<TransferableBlock> _leftOperator;
+
+ @Mock
+ Operator<TransferableBlock> _rightOperator;
+
+ @BeforeMethod
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ }
+
// Regression test for the join hash-collision bug: the broadcast hash table
// used to be keyed by the Integer from computeHash(row), so "Aa" and "BB" rows
// could share a bucket when their hashes collided. It is now keyed by Key.
+ @Test
+ public void testHashJoinKeyCollisionInnerJoin() {
+ // "Aa" and "BB" have the same String.hashCode() (2112) in Java.
+ List<Object[]> rows = Arrays.asList(new Object[]{1, "Aa"}, new Object[]{2,
"BB"}, new Object[]{3, "BB"});
+
when(_leftOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
when(_rightOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
+ List<RexExpression> joinClauses = new ArrayList<>();
+ DataSchema resultSchema = new DataSchema(new String[]{"foo", "bar", "foo",
"bar"}, new DataSchema.ColumnDataType[]{
+ DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING,
DataSchema.ColumnDataType.INT,
+ DataSchema.ColumnDataType.STRING
+ });
+ HashJoinOperator join = new HashJoinOperator(_leftOperator,
_rightOperator, resultSchema,
+ getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses,
JoinRelType.INNER);
+
+ TransferableBlock result = join.getNextBlock();
+ List<Object[]> resultRows = result.getContainer();
// Self-join: "Aa" matches only itself, and each of the two "BB" left rows
// matches both "BB" right rows, so 1 + 2 * 2 = 5 joined rows.
+ List<Object[]> expectedRows =
+ Arrays.asList(new Object[]{1, "Aa", 1, "Aa"}, new Object[]{2, "BB", 2,
"BB"}, new Object[]{2, "BB", 3, "BB"},
+ new Object[]{3, "BB", 2, "BB"}, new Object[]{3, "BB", 3, "BB"});
+ Assert.assertEquals(expectedRows.size(), resultRows.size());
+ Assert.assertEquals(expectedRows.get(0), resultRows.get(0));
+ Assert.assertEquals(expectedRows.get(1), resultRows.get(1));
+ Assert.assertEquals(expectedRows.get(2), resultRows.get(2));
+ Assert.assertEquals(expectedRows.get(3), resultRows.get(3));
+ Assert.assertEquals(expectedRows.get(4), resultRows.get(4));
+ }
+
// Inner join where the right keys are "AA" and "Aa": key matching is exact, so
// only left row {1, "Aa"} matches right row {2, "Aa"}. The "BB" left rows have
// no match and are dropped.
+ @Test
+ public void testInnerJoin() {
+ List<Object[]> leftRows = Arrays.asList(new Object[]{1, "Aa"}, new
Object[]{2, "BB"}, new Object[]{3, "BB"});
+
when(_leftOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(leftRows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+ List<Object[]> rightRows = Arrays.asList(new Object[]{1, "AA"}, new
Object[]{2, "Aa"});
+
when(_rightOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rightRows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
+ List<RexExpression> joinClauses = new ArrayList<>();
+ DataSchema resultSchema = new DataSchema(new String[]{"foo", "bar", "foo",
"bar"}, new DataSchema.ColumnDataType[]{
+ DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING,
DataSchema.ColumnDataType.INT,
+ DataSchema.ColumnDataType.STRING
+ });
+ HashJoinOperator join = new HashJoinOperator(_leftOperator,
_rightOperator, resultSchema,
+ getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses,
JoinRelType.INNER);
+
+ TransferableBlock result = join.getNextBlock();
+ List<Object[]> resultRows = result.getContainer();
+ Object[] expRow = new Object[]{1, "Aa", 2, "Aa"};
+ List<Object[]> expectedRows = new ArrayList<>();
+ expectedRows.add(expRow);
+ Assert.assertEquals(expectedRows.size(), resultRows.size());
+ Assert.assertEquals(expectedRows.get(0), resultRows.get(0));
+ }
+
// Same inputs as testInnerJoin, but as a LEFT join: left rows without a
// matching right row are still emitted, with the right-side columns null.
+ @Test
+ public void testLeftJoin() {
+ List<Object[]> leftRows = Arrays.asList(new Object[]{1, "Aa"}, new
Object[]{2, "BB"}, new Object[]{3, "BB"});
+
when(_leftOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(leftRows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+ List<Object[]> rightRows = Arrays.asList(new Object[]{1, "AA"}, new
Object[]{2, "Aa"});
+
when(_rightOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rightRows))
+ .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
+ List<RexExpression> joinClauses = new ArrayList<>();
+ DataSchema resultSchema = new DataSchema(new String[]{"foo", "bar", "foo",
"bar"}, new DataSchema.ColumnDataType[]{
+ DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING,
DataSchema.ColumnDataType.INT,
+ DataSchema.ColumnDataType.STRING
+ });
+ HashJoinOperator join = new HashJoinOperator(_leftOperator,
_rightOperator, resultSchema,
+ getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses,
JoinRelType.LEFT);
+
+ TransferableBlock result = join.getNextBlock();
+ List<Object[]> resultRows = result.getContainer();
+ List<Object[]> expectedRows = Arrays.asList(new Object[]{1, "Aa", 2,
"Aa"}, new Object[]{2, "BB", null, null},
+ new Object[]{3, "BB", null, null});
+ Assert.assertEquals(expectedRows.size(), resultRows.size());
+ Assert.assertEquals(expectedRows.get(0), resultRows.get(0));
+ Assert.assertEquals(expectedRows.get(1), resultRows.get(1));
+ Assert.assertEquals(expectedRows.get(2), resultRows.get(2));
+ }
+}
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
new file mode 100644
index 0000000000..da8a9cc599
--- /dev/null
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.List;
+import org.apache.pinot.common.datablock.BaseDataBlock;
+import org.apache.pinot.common.datablock.DataBlockUtils;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+
+
/**
 * Shared helpers for operator unit tests. Provides a fixed two-column schema
 * (foo INT, bar STRING) and factories for ROW-type data blocks and
 * end-of-stream blocks, each wrapped in a TransferableBlock.
 */
+public class OperatorTestUtil {
// Holder for static helpers only; not meant to be instantiated.
+ private OperatorTestUtil() {
+ }
+
// Default schema for the helpers that take no schema argument: foo INT, bar STRING.
+ public static final DataSchema TEST_DATA_SCHEMA = new DataSchema(new
String[]{"foo", "bar"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT,
DataSchema.ColumnDataType.STRING});
+
// Returns an end-of-stream block for TEST_DATA_SCHEMA.
+ public static TransferableBlock getEndOfStreamRowBlock() {
+ return getEndOfStreamRowBlockWithSchema(TEST_DATA_SCHEMA);
+ }
+
// Wraps DataBlockUtils.getEndOfStreamDataBlock(schema) in a TransferableBlock.
+ public static TransferableBlock getEndOfStreamRowBlockWithSchema(DataSchema
schema) {
+ return new
TransferableBlock(DataBlockUtils.getEndOfStreamDataBlock(schema));
+ }
+
// Returns a ROW-type block holding the given rows under TEST_DATA_SCHEMA.
+ public static TransferableBlock getRowDataBlock(List<Object[]> rows) {
+ return getRowDataBlockWithSchema(rows, TEST_DATA_SCHEMA);
+ }
+
// Wraps rows in a ROW-type TransferableBlock with the given schema. Rows are
// presumably expected to match the schema's column count and types; nothing
// here checks that.
+ public static TransferableBlock getRowDataBlockWithSchema(List<Object[]>
rows, DataSchema schema) {
+ return new TransferableBlock(rows, schema, BaseDataBlock.Type.ROW);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]