This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 22b345526b [multistage] [bugfix] Fix hash collision (#9729)
22b345526b is described below

commit 22b345526b2abada6cea219484696b1b78669e08
Author: Yao Liu <[email protected]>
AuthorDate: Sat Nov 5 08:30:35 2022 -0700

    [multistage] [bugfix] Fix hash collision (#9729)
    
    * fix agg function hashkey collision bug
    * fix hash join hashkey collision bug
---
 .../query/runtime/blocks/TransferableBlock.java    |   8 --
 .../query/runtime/operator/AggregateOperator.java  |  86 +++++--------
 .../query/runtime/operator/HashJoinOperator.java   |  24 ++--
 .../query/runtime/plan/PhysicalPlanVisitor.java    |   5 +-
 .../runtime/operator/AggregateOperatorTest.java    |  65 ++++++++++
 .../runtime/operator/HashJoinOperatorTest.java     | 139 +++++++++++++++++++++
 .../query/runtime/operator/OperatorTestUtil.java   |  50 ++++++++
 7 files changed, 301 insertions(+), 76 deletions(-)

diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
index 46dd9dc967..93508ffac6 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
@@ -161,14 +161,6 @@ public class TransferableBlock implements Block {
     return _isErrorBlock;
   }
 
-  boolean isContainerBlock() {
-    return _container != null;
-  }
-
-  boolean isDataBlock() {
-    return _dataBlock != null;
-  }
-
   @Override
   public BlockValSet getBlockValueSet() {
     throw new UnsupportedOperationException();
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 0c261d26e3..56f2b9087b 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -26,15 +26,10 @@ import java.util.List;
 import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.datablock.BaseDataBlock;
-import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.core.operator.BaseOperator;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.MinAggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
 import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -44,39 +39,49 @@ import org.apache.pinot.spi.data.FieldSpec;
 
 /**
  *
+ * AggregateOperator is used to aggregate values over a set of group by keys.
+ * Output data will be in the format of [group by key, aggregate result1, ... 
aggregate resultN]
+ * Currently, we only support SUM/COUNT/MIN/MAX aggregation.
+ *
+ * When the list of aggregation calls is empty, this class is used to 
calculate distinct result based on group by keys.
+ * In this case, the input can be any type.
+ *
+ * If the list of aggregation calls is not empty, the input of aggregation has 
to be a number.
+ *
+ * Note: This class performs aggregation over the double value of input.
+ * If a group receives only a single input value, the output keeps the input 
type. Otherwise (values were merged), the output type will be double.
  */
 public class AggregateOperator extends BaseOperator<TransferableBlock> {
   private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
 
   private Operator<TransferableBlock> _inputOperator;
+  // TODO: Deal with the case where _aggCalls is empty but we have groupSet 
setup, which means this is a Distinct call.
   private List<RexExpression> _aggCalls;
   private List<RexExpression> _groupSet;
 
-  private final AggregationFunction[] _aggregationFunctions;
   private final int[] _aggregationFunctionInputRefs;
   private final Object[] _aggregationFunctionLiterals;
   private final DataSchema _resultSchema;
-  private final Map<Integer, Object>[] _groupByResultHolders;
-  private final Map<Integer, Object[]> _groupByKeyHolder;
-
-  private DataSchema _upstreamDataSchema;
+  private final Map<Key, Object>[] _groupByResultHolders;
+  private final Map<Key, Object[]> _groupByKeyHolder;
   private TransferableBlock _upstreamErrorBlock;
   private boolean _isCumulativeBlockConstructed;
 
   // TODO: refactor Pinot Reducer code to support the intermediate stage agg 
operator.
+  // aggCalls has to be a list of FunctionCall and cannot be null
+  // groupSet has to be a list of InputRef and cannot be null
+  // TODO: Add these two checks when we confirm we can handle error in 
upstream ctor call.
   public AggregateOperator(Operator<TransferableBlock> inputOperator, 
DataSchema dataSchema,
-      List<RexExpression> aggCalls, List<RexExpression> groupSet, DataSchema 
upstreamDataSchema) {
+      List<RexExpression> aggCalls, List<RexExpression> groupSet) {
     _inputOperator = inputOperator;
     _aggCalls = aggCalls;
     _groupSet = groupSet;
-    _upstreamDataSchema = upstreamDataSchema;
     _upstreamErrorBlock = null;
 
-    _aggregationFunctions = new AggregationFunction[_aggCalls.size()];
     _aggregationFunctionInputRefs = new int[_aggCalls.size()];
     _aggregationFunctionLiterals = new Object[_aggCalls.size()];
     _groupByResultHolders = new Map[_aggCalls.size()];
-    _groupByKeyHolder = new HashMap<Integer, Object[]>();
+    _groupByKeyHolder = new HashMap<Key, Object[]>();
     for (int i = 0; i < aggCalls.size(); i++) {
       // agg function operand should be either an InputRef or a Literal
       RexExpression rexExpression = 
toAggregationFunctionOperand(aggCalls.get(i));
@@ -86,8 +91,7 @@ public class AggregateOperator extends 
BaseOperator<TransferableBlock> {
         _aggregationFunctionInputRefs[i] = -1;
         _aggregationFunctionLiterals[i] = ((RexExpression.Literal) 
rexExpression).getValue();
       }
-      _aggregationFunctions[i] = toAggregationFunction(aggCalls.get(i), 
_aggregationFunctionInputRefs[i]);
-      _groupByResultHolders[i] = new HashMap<Integer, Object>();
+      _groupByResultHolders[i] = new HashMap<Key, Object>();
     }
     _resultSchema = dataSchema;
 
@@ -129,8 +133,8 @@ public class AggregateOperator extends 
BaseOperator<TransferableBlock> {
     }
     if (!_isCumulativeBlockConstructed) {
       List<Object[]> rows = new ArrayList<>(_groupByKeyHolder.size());
-      for (Map.Entry<Integer, Object[]> e : _groupByKeyHolder.entrySet()) {
-        Object[] row = new Object[_aggregationFunctions.length + 
_groupSet.size()];
+      for (Map.Entry<Key, Object[]> e : _groupByKeyHolder.entrySet()) {
+        Object[] row = new Object[_aggCalls.size() + _groupSet.size()];
         Object[] keyElements = e.getValue();
         for (int i = 0; i < keyElements.length; i++) {
           row[i] = keyElements[i];
@@ -160,17 +164,17 @@ public class AggregateOperator extends 
BaseOperator<TransferableBlock> {
         for (int rowId = 0; rowId < numRows; rowId++) {
           Object[] row = 
SelectionOperatorUtils.extractRowFromDataTable(dataBlock, rowId);
           Key key = extraRowKey(row, _groupSet);
-          int keyHashCode = key.hashCode();
-          _groupByKeyHolder.put(keyHashCode, key.getValues());
-          for (int i = 0; i < _aggregationFunctions.length; i++) {
-            Object currentRes = _groupByResultHolders[i].get(keyHashCode);
+          _groupByKeyHolder.put(key, key.getValues());
+          for (int i = 0; i < _aggCalls.size(); i++) {
+            Object currentRes = _groupByResultHolders[i].get(key);
+            // TODO: fix that single agg result (original type) has different 
type from multiple agg results (double).
             if (currentRes == null) {
-              _groupByResultHolders[i].put(keyHashCode, 
_aggregationFunctionInputRefs[i] == -1
-                  ? _aggregationFunctionLiterals[i] : 
row[_aggregationFunctionInputRefs[i]]);
+              _groupByResultHolders[i].put(key, 
_aggregationFunctionInputRefs[i] == -1 ? _aggregationFunctionLiterals[i]
+                  : row[_aggregationFunctionInputRefs[i]]);
             } else {
-              _groupByResultHolders[i].put(keyHashCode,
-                  merge(_aggCalls.get(i), currentRes, 
_aggregationFunctionInputRefs[i] == -1
-                      ? _aggregationFunctionLiterals[i] : 
row[_aggregationFunctionInputRefs[i]]));
+              _groupByResultHolders[i].put(key, merge(_aggCalls.get(i), 
currentRes,
+                  _aggregationFunctionInputRefs[i] == -1 ? 
_aggregationFunctionLiterals[i]
+                      : row[_aggregationFunctionInputRefs[i]]));
             }
           }
         }
@@ -183,34 +187,6 @@ public class AggregateOperator extends 
BaseOperator<TransferableBlock> {
     }
   }
 
-  private AggregationFunction toAggregationFunction(RexExpression aggCall, int 
aggregationFunctionInputRef) {
-    Preconditions.checkState(aggCall instanceof RexExpression.FunctionCall);
-    // TODO(Rong Rong): query options are not supported by the new engine at 
this moment.
-    switch (((RexExpression.FunctionCall) aggCall).getFunctionName()) {
-      case "$SUM":
-      case "$SUM0":
-      case "SUM":
-        return new SumAggregationFunction(
-            
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
-      case "$MIN":
-      case "$MIN0":
-      case "MIN":
-        return new MinAggregationFunction(
-            
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
-      case "$MAX":
-      case "$MAX0":
-      case "MAX":
-        return new MaxAggregationFunction(
-            
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
-      // COUNT(*) is rewritten to SUM(1)
-      case "COUNT":
-        return new 
SumAggregationFunction(ExpressionContext.forLiteralContext(FieldSpec.DataType.INT,
 1));
-      default:
-        throw new IllegalStateException(
-            "Unexpected value: " + ((RexExpression.FunctionCall) 
aggCall).getFunctionName());
-    }
-  }
-
   private Object merge(RexExpression aggCall, Object left, Object right) {
     Preconditions.checkState(aggCall instanceof RexExpression.FunctionCall);
     switch (((RexExpression.FunctionCall) aggCall).getFunctionName()) {
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index bcf6807dd9..61cd18ff5a 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -28,6 +28,7 @@ import org.apache.pinot.common.datablock.BaseDataBlock;
 import org.apache.pinot.common.datablock.DataBlockUtils;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
@@ -44,17 +45,18 @@ import 
org.apache.pinot.query.runtime.operator.operands.FilterOperand;
  * it looks up for the corresponding row(s) from the hash table and create a 
joint row.
  *
  * <p>For each of the data block received from the left table, it will 
generate a joint data block.
+ *
+ * We currently support left join, inner join and semi join.
+ * The output is in the format of [left_row, right_row]
  */
 public class HashJoinOperator extends BaseOperator<TransferableBlock> {
   private static final String EXPLAIN_NAME = "BROADCAST_JOIN";
 
-  private final HashMap<Integer, List<Object[]>> _broadcastHashTable;
+  private final HashMap<Key, List<Object[]>> _broadcastHashTable;
   private final Operator<TransferableBlock> _leftTableOperator;
   private final Operator<TransferableBlock> _rightTableOperator;
   private final JoinRelType _joinType;
   private final DataSchema _resultSchema;
-  private final DataSchema _leftTableSchema;
-  private final DataSchema _rightTableSchema;
   private final int _resultRowSize;
   private final List<FilterOperand> _joinClauseEvaluators;
   private boolean _isHashTableBuilt;
@@ -62,16 +64,16 @@ public class HashJoinOperator extends 
BaseOperator<TransferableBlock> {
   private KeySelector<Object[], Object[]> _leftKeySelector;
   private KeySelector<Object[], Object[]> _rightKeySelector;
 
-  public HashJoinOperator(Operator<TransferableBlock> leftTableOperator, 
DataSchema leftSchema,
-      Operator<TransferableBlock> rightTableOperator, DataSchema rightSchema, 
DataSchema outputSchema,
-      JoinNode.JoinKeys joinKeys, List<RexExpression> joinClauses, JoinRelType 
joinType) {
+  // TODO: Fix inequi join bug. (https://github.com/apache/pinot/issues/9728)
+  // TODO: Double check semi join logic.
+  public HashJoinOperator(Operator<TransferableBlock> leftTableOperator, 
Operator<TransferableBlock> rightTableOperator,
+      DataSchema outputSchema, JoinNode.JoinKeys joinKeys, List<RexExpression> 
joinClauses, JoinRelType joinType) {
+    // TODO: Handle the case where _leftKeySelector and _rightKeySelector 
could be null.
     _leftKeySelector = joinKeys.getLeftJoinKeySelector();
     _rightKeySelector = joinKeys.getRightJoinKeySelector();
     _leftTableOperator = leftTableOperator;
     _rightTableOperator = rightTableOperator;
     _resultSchema = outputSchema;
-    _leftTableSchema = leftSchema;
-    _rightTableSchema = rightSchema;
     _joinClauseEvaluators = new ArrayList<>(joinClauses.size());
     for (RexExpression joinClause : joinClauses) {
       _joinClauseEvaluators.add(FilterOperand.toFilterOperand(joinClause, 
_resultSchema));
@@ -118,7 +120,7 @@ public class HashJoinOperator extends 
BaseOperator<TransferableBlock> {
         // put all the rows into corresponding hash collections keyed by the 
key selector function.
         for (Object[] row : container) {
           List<Object[]> hashCollection =
-              
_broadcastHashTable.computeIfAbsent(_rightKeySelector.computeHash(row), k -> 
new ArrayList<>());
+              _broadcastHashTable.computeIfAbsent(new 
Key(_rightKeySelector.getKey(row)), k -> new ArrayList<>());
           hashCollection.add(row);
         }
         rightBlock = _rightTableOperator.nextBlock();
@@ -137,10 +139,12 @@ public class HashJoinOperator extends 
BaseOperator<TransferableBlock> {
       List<Object[]> container = leftBlock.getContainer();
       for (Object[] leftRow : container) {
         List<Object[]> hashCollection = _broadcastHashTable.getOrDefault(
-            _leftKeySelector.computeHash(leftRow), Collections.emptyList());
+            new Key(_leftKeySelector.getKey(leftRow)), 
Collections.emptyList());
+        // For a left join with no right row matching this key, emit the left row with a null right side.
         if (hashCollection.isEmpty() && _joinType == JoinRelType.LEFT) {
           rows.add(joinRow(leftRow, null));
         } else {
+          // Otherwise, pair the left row with each right row that shares its join key.
           for (Object[] rightRow : hashCollection) {
             Object[] resultRow = joinRow(leftRow, rightRow);
             if (_joinClauseEvaluators.isEmpty() || 
_joinClauseEvaluators.stream().allMatch(
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
index 6dab5865ba..6ba6d6b4c1 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
@@ -83,7 +83,7 @@ public class PhysicalPlanVisitor implements 
StageNodeVisitor<Operator<Transferab
   public Operator<TransferableBlock> visitAggregate(AggregateNode node, 
PlanRequestContext context) {
     Operator<TransferableBlock> nextOperator = 
node.getInputs().get(0).visit(this, context);
     return new AggregateOperator(nextOperator, node.getDataSchema(), 
node.getAggCalls(),
-        node.getGroupSet(), node.getInputs().get(0).getDataSchema());
+        node.getGroupSet());
   }
 
   @Override
@@ -100,8 +100,7 @@ public class PhysicalPlanVisitor implements 
StageNodeVisitor<Operator<Transferab
     Operator<TransferableBlock> leftOperator = left.visit(this, context);
     Operator<TransferableBlock> rightOperator = right.visit(this, context);
 
-    return new HashJoinOperator(leftOperator, left.getDataSchema(), 
rightOperator,
-        right.getDataSchema(), node.getDataSchema(), node.getJoinKeys(),
+    return new HashJoinOperator(leftOperator, rightOperator, 
node.getDataSchema(), node.getJoinKeys(),
         node.getJoinClauses(), node.getJoinRelType());
   }
 
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
new file mode 100644
index 0000000000..97e27e10ab
--- /dev/null
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testng.Assert;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.when;
+
+
+public class AggregateOperatorTest {
+  @Mock
+  Operator<TransferableBlock> _upstreamOperator;
+
+  @BeforeMethod
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+  }
+
+  @Test
+  public void testGroupByAggregateWithHashCollision() {
+    // "Aa" and "BB" have the same hash code in Java.
+    List<Object[]> rows = Arrays.asList(new Object[]{1, "Aa"}, new Object[]{2, 
"BB"}, new Object[]{3, "BB"});
+    
when(_upstreamOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+    // Create an aggregation call with sum for first column and group by 
second column.
+    RexExpression.FunctionCall agg = new 
RexExpression.FunctionCall(SqlKind.SUM, FieldSpec.DataType.INT, "SUM",
+        Arrays.asList(new RexExpression.InputRef(0)));
+    AggregateOperator sum0GroupBy1 =
+        new AggregateOperator(_upstreamOperator, 
OperatorTestUtil.TEST_DATA_SCHEMA, Arrays.asList(agg),
+            Arrays.asList(new RexExpression.InputRef(1)));
+    TransferableBlock result = sum0GroupBy1.getNextBlock();
+    List<Object[]> resultRows = result.getContainer();
+    List<Object[]> expectedRows = Arrays.asList(new Object[]{"Aa", 1}, new 
Object[]{"BB", 5.0});
+    Assert.assertEquals(resultRows.size(), expectedRows.size());
+    Assert.assertEquals(resultRows.get(0), expectedRows.get(0));
+    Assert.assertEquals(resultRows.get(1), expectedRows.get(1));
+  }
+}
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
new file mode 100644
index 0000000000..768371636d
--- /dev/null
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
+import org.apache.pinot.query.planner.stage.JoinNode;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testng.Assert;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.when;
+
+
+public class HashJoinOperatorTest {
+  private static JoinNode.JoinKeys getJoinKeys(List<Integer> leftIdx, 
List<Integer> rightIdx) {
+    FieldSelectionKeySelector leftSelect = new 
FieldSelectionKeySelector(leftIdx);
+    FieldSelectionKeySelector rightSelect = new 
FieldSelectionKeySelector(rightIdx);
+    return new JoinNode.JoinKeys(leftSelect, rightSelect);
+  }
+  @Mock
+  Operator<TransferableBlock> _leftOperator;
+
+  @Mock
+  Operator<TransferableBlock> _rightOperator;
+
+  @BeforeMethod
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+  }
+
+  @Test
+  public void testHashJoinKeyCollisionInnerJoin() {
+    // "Aa" and "BB" have the same hash code in Java.
+    List<Object[]> rows = Arrays.asList(new Object[]{1, "Aa"}, new Object[]{2, 
"BB"}, new Object[]{3, "BB"});
+    
when(_leftOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+    
when(_rightOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
+    List<RexExpression> joinClauses = new ArrayList<>();
+    DataSchema resultSchema = new DataSchema(new String[]{"foo", "bar", "foo", 
"bar"}, new DataSchema.ColumnDataType[]{
+        DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, 
DataSchema.ColumnDataType.INT,
+        DataSchema.ColumnDataType.STRING
+    });
+    HashJoinOperator join = new HashJoinOperator(_leftOperator, 
_rightOperator, resultSchema,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, 
JoinRelType.INNER);
+
+    TransferableBlock result = join.getNextBlock();
+    List<Object[]> resultRows = result.getContainer();
+    List<Object[]> expectedRows =
+        Arrays.asList(new Object[]{1, "Aa", 1, "Aa"}, new Object[]{2, "BB", 2, 
"BB"}, new Object[]{2, "BB", 3, "BB"},
+            new Object[]{3, "BB", 2, "BB"}, new Object[]{3, "BB", 3, "BB"});
+    Assert.assertEquals(expectedRows.size(), resultRows.size());
+    Assert.assertEquals(expectedRows.get(0), resultRows.get(0));
+    Assert.assertEquals(expectedRows.get(1), resultRows.get(1));
+    Assert.assertEquals(expectedRows.get(2), resultRows.get(2));
+    Assert.assertEquals(expectedRows.get(3), resultRows.get(3));
+    Assert.assertEquals(expectedRows.get(4), resultRows.get(4));
+  }
+
+  @Test
+  public void testInnerJoin() {
+    List<Object[]> leftRows = Arrays.asList(new Object[]{1, "Aa"}, new 
Object[]{2, "BB"}, new Object[]{3, "BB"});
+    
when(_leftOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(leftRows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+    List<Object[]> rightRows = Arrays.asList(new Object[]{1, "AA"}, new 
Object[]{2, "Aa"});
+    
when(_rightOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rightRows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
+    List<RexExpression> joinClauses = new ArrayList<>();
+    DataSchema resultSchema = new DataSchema(new String[]{"foo", "bar", "foo", 
"bar"}, new DataSchema.ColumnDataType[]{
+        DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, 
DataSchema.ColumnDataType.INT,
+        DataSchema.ColumnDataType.STRING
+    });
+    HashJoinOperator join = new HashJoinOperator(_leftOperator, 
_rightOperator, resultSchema,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, 
JoinRelType.INNER);
+
+    TransferableBlock result = join.getNextBlock();
+    List<Object[]> resultRows = result.getContainer();
+    Object[] expRow = new Object[]{1, "Aa", 2, "Aa"};
+    List<Object[]> expectedRows = new ArrayList<>();
+    expectedRows.add(expRow);
+    Assert.assertEquals(expectedRows.size(), resultRows.size());
+    Assert.assertEquals(expectedRows.get(0), resultRows.get(0));
+  }
+
+  @Test
+  public void testLeftJoin() {
+    List<Object[]> leftRows = Arrays.asList(new Object[]{1, "Aa"}, new 
Object[]{2, "BB"}, new Object[]{3, "BB"});
+    
when(_leftOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(leftRows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+    List<Object[]> rightRows = Arrays.asList(new Object[]{1, "AA"}, new 
Object[]{2, "Aa"});
+    
when(_rightOperator.nextBlock()).thenReturn(OperatorTestUtil.getRowDataBlock(rightRows))
+        .thenReturn(OperatorTestUtil.getEndOfStreamRowBlock());
+
+    List<RexExpression> joinClauses = new ArrayList<>();
+    DataSchema resultSchema = new DataSchema(new String[]{"foo", "bar", "foo", 
"bar"}, new DataSchema.ColumnDataType[]{
+        DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, 
DataSchema.ColumnDataType.INT,
+        DataSchema.ColumnDataType.STRING
+    });
+    HashJoinOperator join = new HashJoinOperator(_leftOperator, 
_rightOperator, resultSchema,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, 
JoinRelType.LEFT);
+
+    TransferableBlock result = join.getNextBlock();
+    List<Object[]> resultRows = result.getContainer();
+    List<Object[]> expectedRows = Arrays.asList(new Object[]{1, "Aa", 2, 
"Aa"}, new Object[]{2, "BB", null, null},
+        new Object[]{3, "BB", null, null});
+    Assert.assertEquals(expectedRows.size(), resultRows.size());
+    Assert.assertEquals(expectedRows.get(0), resultRows.get(0));
+    Assert.assertEquals(expectedRows.get(1), resultRows.get(1));
+    Assert.assertEquals(expectedRows.get(2), resultRows.get(2));
+  }
+}
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
new file mode 100644
index 0000000000..da8a9cc599
--- /dev/null
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.List;
+import org.apache.pinot.common.datablock.BaseDataBlock;
+import org.apache.pinot.common.datablock.DataBlockUtils;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+
+
+public class OperatorTestUtil {
+  private OperatorTestUtil() {
+  }
+
+  public static final DataSchema TEST_DATA_SCHEMA = new DataSchema(new 
String[]{"foo", "bar"},
+      new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, 
DataSchema.ColumnDataType.STRING});
+
+  public static TransferableBlock getEndOfStreamRowBlock() {
+    return getEndOfStreamRowBlockWithSchema(TEST_DATA_SCHEMA);
+  }
+
+  public static TransferableBlock getEndOfStreamRowBlockWithSchema(DataSchema 
schema) {
+    return new 
TransferableBlock(DataBlockUtils.getEndOfStreamDataBlock(schema));
+  }
+
+  public static TransferableBlock getRowDataBlock(List<Object[]> rows) {
+    return getRowDataBlockWithSchema(rows, TEST_DATA_SCHEMA);
+  }
+
+  public static TransferableBlock getRowDataBlockWithSchema(List<Object[]> 
rows, DataSchema schema) {
+    return new TransferableBlock(rows, schema, BaseDataBlock.Type.ROW);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to