This is an automated email from the ASF dual-hosted git repository.

abstractdog pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hive.git


The following commit(s) were added to refs/heads/master by this push:
     new 2975d7e  HIVE-23880: Bloom filters can be merged in a parallel way in 
VectorUDAFBloomFilterMerge (Laszlo Bodor reviewed by Panagiotis Garefalakis, 
Mustafa Iman, Stamatis Zampetakis, Rajesh Balamohan, David Mollitor)
2975d7e is described below

commit 2975d7e5277f56ede523346eb331a3fe149e932d
Author: Laszlo Bodor <[email protected]>
AuthorDate: Wed Aug 5 09:20:47 2020 +0200

    HIVE-23880: Bloom filters can be merged in a parallel way in 
VectorUDAFBloomFilterMerge (Laszlo Bodor reviewed by Panagiotis Garefalakis, 
Mustafa Iman, Stamatis Zampetakis, Rajesh Balamohan, David Mollitor)
    
    Change-Id: I28947c24cd8bb6a909e7925da38c448885ac6443
---
 .../java/org/apache/hadoop/hive/conf/HiveConf.java |   6 +
 .../hive/ql/exec/vector/VectorGroupByOperator.java | 124 +++++----
 .../aggregates/VectorAggregateExpression.java      |  15 +-
 .../aggregates/VectorUDAFBloomFilterMerge.java     | 279 ++++++++++++++++++---
 .../ql/exec/vector/TestVectorGroupByOperator.java  |  23 ++
 .../aggregates/TestVectorUDAFBloomFilterMerge.java |  98 ++++++++
 .../org/apache/hive/common/util/BloomKFilter.java  |  33 ++-
 .../apache/hive/common/util/TestBloomKFilter.java  |   2 +-
 8 files changed, 495 insertions(+), 85 deletions(-)

diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java 
b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index de355ad..63578aa 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -4331,6 +4331,12 @@ public class HiveConf extends Configuration {
             "Bloom filter should be of at max certain size to be effective"),
     TEZ_BLOOM_FILTER_FACTOR("hive.tez.bloom.filter.factor", (float) 1.0,
             "Bloom filter should be a multiple of this factor with nDV"),
+    TEZ_BLOOM_FILTER_MERGE_THREADS("hive.tez.bloom.filter.merge.threads", 1,
+        "How many threads are used for merging bloom filters in addition to 
task's main thread?\n"
+            + "-1: sanity check, it will fail if execution hits bloom filter 
merge codepath\n"
+            + " 0: feature is disabled, use only task's main thread for bloom 
filter merging\n"
+            + " 1: recommended value: there is only 1 merger thread 
(additionally to the task's main thread),"
+            + "according perf tests, this can lead to serious improvement \n"),
     
TEZ_BIGTABLE_MIN_SIZE_SEMIJOIN_REDUCTION("hive.tez.bigtable.minsize.semijoin.reduction",
 100000000L,
             "Big table for runtime filteting should be of atleast this size"),
     
TEZ_DYNAMIC_SEMIJOIN_REDUCTION_THRESHOLD("hive.tez.dynamic.semijoin.reduction.threshold",
 (float) 0.50,
diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java
index f6b38d6..e47a6f9 100644
--- 
a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java
+++ 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java
@@ -48,6 +48,7 @@ import 
org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
 import 
org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
 import 
org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
 import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFBloomFilterMerge;
 import org.apache.hadoop.hive.ql.exec.vector.wrapper.VectorHashKeyWrapperBase;
 import org.apache.hadoop.hive.ql.exec.vector.wrapper.VectorHashKeyWrapperBatch;
 import 
org.apache.hadoop.hive.ql.exec.vector.wrapper.VectorHashKeyWrapperGeneral;
@@ -173,6 +174,14 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
    * Base class for all processing modes
    */
   private abstract class ProcessingModeBase implements IProcessingMode {
+    /**
+     * The VectorAggregationBufferRow instance. This field can be shared with 
ProcessingMode
+     * subclasses, where only one VectorAggregationBufferRow instance is needed 
at a time. This
+     * is the case for ProcessingModeGlobalAggregate, 
ProcessingModeReduceMergePartial,
+     * ProcessingModeStreaming, but not for ProcessingModeHashAggregate where 
this field is not
+     * used.
+     */
+    protected VectorAggregationBufferRow aggregationBufferSet;
 
     // Overridden and used in ProcessingModeReduceMergePartial mode.
     @Override
@@ -252,6 +261,18 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
       return bufferSet;
     }
 
+    @Override
+    public void close(boolean aborted) throws HiveException {
+      finishAggregators(aggregationBufferSet, aborted);
+    }
+
+    public void finishAggregators(VectorAggregationBufferRow 
vectorAggregationBufferRow, boolean aborted) {
+      if (vectorAggregationBufferRow != null) {
+        for (int i = 0; i < aggregators.length; ++i) {
+          
aggregators[i].finish(vectorAggregationBufferRow.getAggregationBuffer(i), 
aborted);
+        }
+      }
+    }
   }
 
   /**
@@ -261,14 +282,9 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
    */
   final class ProcessingModeGlobalAggregate extends ProcessingModeBase {
 
-    /**
-     * In global processing mode there is only one set of aggregation buffers
-     */
-    private VectorAggregationBufferRow aggregationBuffers;
-
     @Override
     public void initialize(Configuration hconf) throws HiveException {
-      aggregationBuffers =  allocateAggregationBuffer();
+      aggregationBufferSet =  allocateAggregationBuffer();
       LOG.info("using global aggregation processing mode");
     }
 
@@ -281,14 +297,16 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
     public void doProcessBatch(VectorizedRowBatch batch, boolean 
isFirstGroupingSet,
         boolean[] currentGroupingSetsOverrideIsNulls) throws HiveException {
       for (int i = 0; i < aggregators.length; ++i) {
-        
aggregators[i].aggregateInput(aggregationBuffers.getAggregationBuffer(i), 
batch);
+        
aggregators[i].aggregateInput(aggregationBufferSet.getAggregationBuffer(i), 
batch);
       }
     }
 
     @Override
     public void close(boolean aborted) throws HiveException {
+      super.close(aborted);
+
       if (!aborted) {
-        writeSingleRow(null, aggregationBuffers);
+        writeSingleRow(null, aggregationBufferSet);
       }
     }
   }
@@ -494,6 +512,8 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
 
     @Override
     public void close(boolean aborted) throws HiveException {
+      super.close(aborted);
+
       reusableAggregationBufferRows.clear();
       if (reusableKeyWrapperBuffer != null) {
         reusableKeyWrapperBuffer.clear();
@@ -512,6 +532,7 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
           keyWrappersBatch.setLongValue(kw, pos, val);
         }
         VectorAggregationBufferRow groupAggregators = 
allocateAggregationBuffer();
+        finishAggregators(groupAggregators, false);
         writeSingleRow(kw, groupAggregators);
       }
 
@@ -652,6 +673,7 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
           }
         }
 
+        finishAggregators(bufferRow, false);
         writeSingleRow((VectorHashKeyWrapperBase) keyWrapper, bufferRow);
 
         if (!all) {
@@ -767,11 +789,6 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
   final class ProcessingModeStreaming extends ProcessingModeBase {
 
     /**
-     * The aggregation buffers used in streaming mode
-     */
-    private VectorAggregationBufferRow currentStreamingAggregators;
-
-    /**
      * The current key, used in streaming mode
      */
     private VectorHashKeyWrapperBase streamingKey;
@@ -843,7 +860,7 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
       final VectorHashKeyWrapperBase prevKey = streamingKey;
       if (streamingKey == null) {
         // This is the first batch we process after switching from hash mode
-        currentStreamingAggregators = 
streamAggregationBufferRowPool.getFromPool();
+        aggregationBufferSet = streamAggregationBufferRowPool.getFromPool();
         streamingKey = batchKeys[0];
       }
 
@@ -854,13 +871,13 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
         if (!batchKeys[i].equals(streamingKey)) {
           // We've encountered a new key, must save current one
           // We can't forward yet, the aggregators have not been evaluated
-          rowsToFlush[flushMark] = currentStreamingAggregators;
+          rowsToFlush[flushMark] = aggregationBufferSet;
           keysToFlush[flushMark] = streamingKey;
-          currentStreamingAggregators = 
streamAggregationBufferRowPool.getFromPool();
+          aggregationBufferSet = streamAggregationBufferRowPool.getFromPool();
           streamingKey = batchKeys[i];
           ++flushMark;
         }
-        
aggregationBatchInfo.mapAggregationBufferSet(currentStreamingAggregators, i);
+        aggregationBatchInfo.mapAggregationBufferSet(aggregationBufferSet, i);
       }
 
       // evaluate the aggregators
@@ -868,6 +885,7 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
 
       // Now flush/forward all keys/rows, except the last (current) one
       for (int i = 0; i < flushMark; ++i) {
+        finishAggregators(rowsToFlush[i], false); //finish aggregations before 
flushing
         writeSingleRow(keysToFlush[i], rowsToFlush[i]);
         rowsToFlush[i].reset();
         keysToFlush[i] = null;
@@ -881,8 +899,9 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
 
     @Override
     public void close(boolean aborted) throws HiveException {
+      super.close(aborted);
       if (!aborted && null != streamingKey) {
-        writeSingleRow(streamingKey, currentStreamingAggregators);
+        writeSingleRow(streamingKey, aggregationBufferSet);
       }
     }
   }
@@ -917,11 +936,6 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
     VectorGroupKeyHelper groupKeyHelper;
 
     /**
-     * The group vector aggregation buffers.
-     */
-    private VectorAggregationBufferRow groupAggregators;
-
-    /**
      * Buffer to hold string values.
      */
     private DataOutputBuffer buffer;
@@ -934,7 +948,7 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
       // instead of keyExpressions.length
       groupKeyHelper = new VectorGroupKeyHelper(outputKeyLength);
       groupKeyHelper.init(keyExpressions);
-      groupAggregators = allocateAggregationBuffer();
+      aggregationBufferSet = allocateAggregationBuffer();
       buffer = new DataOutputBuffer();
       LOG.info("using sorted group batch aggregation processing mode");
     }
@@ -966,19 +980,21 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
 
       // Aggregate this batch.
       for (int i = 0; i < aggregators.length; ++i) {
-        
aggregators[i].aggregateInput(groupAggregators.getAggregationBuffer(i), batch);
+        
aggregators[i].aggregateInput(aggregationBufferSet.getAggregationBuffer(i), 
batch);
       }
 
       if (isLastGroupBatch) {
-        writeGroupRow(groupAggregators, buffer);
-        groupAggregators.reset();
+        finishAggregators(aggregationBufferSet, false);
+        writeGroupRow(aggregationBufferSet, buffer);
+        aggregationBufferSet.reset();
       }
     }
 
     @Override
     public void close(boolean aborted) throws HiveException {
+      super.close(aborted);
       if (!aborted && !first && !isLastGroupBatch) {
-        writeGroupRow(groupAggregators, buffer);
+        writeGroupRow(aggregationBufferSet, buffer);
       }
     }
   }
@@ -1116,21 +1132,8 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
 
         Class<? extends VectorAggregateExpression> vecAggrClass = 
vecAggrDesc.getVecAggrClass();
 
-        Constructor<? extends VectorAggregateExpression> ctor = null;
-        try {
-          ctor = vecAggrClass.getConstructor(VectorAggregationDesc.class);
-        } catch (Exception e) {
-          throw new HiveException("Constructor " + 
vecAggrClass.getSimpleName() +
-              "(VectorAggregationDesc) not available");
-        }
-        VectorAggregateExpression vecAggrExpr = null;
-        try {
-          vecAggrExpr = ctor.newInstance(vecAggrDesc);
-        } catch (Exception e) {
-
-           throw new HiveException("Failed to create " + 
vecAggrClass.getSimpleName() +
-               "(VectorAggregationDesc) object ", e);
-        }
+        VectorAggregateExpression vecAggrExpr =
+            instantiateExpression(vecAggrDesc, hconf);
         VectorExpression.doTransientInit(vecAggrExpr.getInputExpression(), 
hconf);
         aggregators[i] = vecAggrExpr;
 
@@ -1194,6 +1197,41 @@ public class VectorGroupByOperator extends 
Operator<GroupByDesc>
     processingMode.initialize(hconf);
   }
 
+  @VisibleForTesting
+  VectorAggregateExpression instantiateExpression(VectorAggregationDesc 
vecAggrDesc,
+      Configuration hconf) throws HiveException {
+    Class<? extends VectorAggregateExpression> vecAggrClass = 
vecAggrDesc.getVecAggrClass();
+
+    Constructor<? extends VectorAggregateExpression> ctor = null;
+    try {
+      if (vecAggrDesc.getVecAggrClass() == VectorUDAFBloomFilterMerge.class) {
+        // VectorUDAFBloomFilterMerge is instantiated with the number of threads 
used for parallel processing
+        ctor = vecAggrClass.getConstructor(VectorAggregationDesc.class, 
int.class);
+      } else {
+        ctor = vecAggrClass.getConstructor(VectorAggregationDesc.class);
+      }
+    } catch (Exception e) {
+      throw new HiveException(
+          "Constructor " + vecAggrClass.getSimpleName() + 
"(VectorAggregationDesc) not available", e);
+    }
+    VectorAggregateExpression vecAggrExpr = null;
+    try {
+      if (vecAggrDesc.getVecAggrClass() == VectorUDAFBloomFilterMerge.class) {
+        vecAggrExpr = ctor.newInstance(vecAggrDesc,
+            
hconf.getInt(HiveConf.ConfVars.TEZ_BLOOM_FILTER_MERGE_THREADS.varname,
+                
HiveConf.ConfVars.TEZ_BLOOM_FILTER_MERGE_THREADS.defaultIntVal));
+      } else {
+        vecAggrExpr = ctor.newInstance(vecAggrDesc);
+      }
+    } catch (Exception e) {
+
+      throw new HiveException(
+          "Failed to create " + vecAggrClass.getSimpleName() + 
"(VectorAggregationDesc) object ",
+          e);
+    }
+    return vecAggrExpr;
+  }
+
   /**
    * changes the processing mode to streaming
    * This is done at the request of the hash agg mode, if the number of keys
diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java
 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java
index 2499f09..65a0df4 100644
--- 
a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java
+++ 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java
@@ -27,9 +27,7 @@ import 
org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.plan.AggregationDesc;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
 
@@ -37,7 +35,6 @@ import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
  * Base class for aggregation expressions.
  */
 public abstract class VectorAggregateExpression  implements Serializable {
-
   private static final long serialVersionUID = 1L;
 
   protected final VectorAggregationDesc vecAggrDesc;
@@ -145,5 +142,17 @@ public abstract class VectorAggregateExpression  
implements Serializable {
   public String toString() {
     return vecAggrDesc.toString();
   }
+
+  /**
+   * Optional method to implement in VectorAggregateExpression instances. It's 
called for all
+   * aggregators from ProcessingMode, when VectorGroupByOperator is closed 
(before flush) or a
+   * write/flush happens. Calling this method properly before writing rows is 
the responsibility of
+   * VectorGroupByOperator.
+   *
+   * @param aggregationBuffer
+   * @param aborted
+   */
+  public void finish(AggregationBuffer aggregationBuffer, boolean aborted) {
+  }
 }
 
diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java
 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java
index fe5e33a..862b052 100644
--- 
a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java
+++ 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java
@@ -18,53 +18,57 @@
 
 package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates;
 
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
 import java.util.Arrays;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
-import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression.AggregationBuffer;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.plan.AggregationDesc;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.IOUtils;
 import org.apache.hive.common.util.BloomKFilter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static 
org.apache.hive.common.util.BloomKFilter.START_OF_SERIALIZED_LONGS;
 
 public class VectorUDAFBloomFilterMerge extends VectorAggregateExpression {
   private static final long serialVersionUID = 1L;
+  private static final Logger LOG = 
LoggerFactory.getLogger(VectorUDAFBloomFilterMerge.class);
 
   private long expectedEntries = -1;
-  transient private int aggBufferSize;
+  private transient int aggBufferSize;
+  private transient int numThreads;
 
   /**
    * class for storing the current aggregate value.
    */
-  private static final class Aggregation implements AggregationBuffer {
+  static final class Aggregation implements AggregationBuffer {
     private static final long serialVersionUID = 1L;
 
     byte[] bfBytes;
+    private ExecutorService executor;
+    private int numThreads;
+    private BloomFilterMergeWorker[] workers;
+    private AtomicBoolean aborted = new AtomicBoolean(false);
 
-    public Aggregation(long expectedEntries) {
-      ByteArrayOutputStream bytesOut = null;
-      try {
-        BloomKFilter bf = new BloomKFilter(expectedEntries);
-        bytesOut = new ByteArrayOutputStream();
-        BloomKFilter.serialize(bytesOut, bf);
-        bfBytes = bytesOut.toByteArray();
-      } catch (Exception err) {
-        throw new IllegalArgumentException("Error creating aggregation 
buffer", err);
-      } finally {
-        IOUtils.closeStream(bytesOut);
+    public Aggregation(long expectedEntries, int numThreads) {
+      bfBytes = BloomKFilter.getInitialBytes(expectedEntries);
+
+      if (numThreads < 0) {
+        throw new RuntimeException(
+            "invalid number of threads for bloom filter merge: " + numThreads);
       }
+
+      this.numThreads = numThreads;
     }
 
     @Override
@@ -77,6 +81,185 @@ public class VectorUDAFBloomFilterMerge extends 
VectorAggregateExpression {
       // Do not change the initial bytes which contain 
NumHashFunctions/NumBits!
       Arrays.fill(bfBytes, BloomKFilter.START_OF_SERIALIZED_LONGS, 
bfBytes.length, (byte) 0);
     }
+
+    public void mergeBloomFilterBytesFromInputColumn(BytesColumnVector 
inputColumn,
+        int batchSize, boolean selectedInUse, int[] selected) {
+      if (executor == null) {
+        initExecutor();
+      }
+
+      // split every bloom filter (represented by a part of a byte[]) across 
workers
+      for (int j = 0; j < batchSize; j++) {
+        if (!selectedInUse) {
+          if (inputColumn.noNulls) {
+            splitVectorAcrossWorkers(workers, inputColumn.vector[j], 
inputColumn.start[j],
+                inputColumn.length[j]);
+          } else if (!inputColumn.isNull[j]) {
+            splitVectorAcrossWorkers(workers, inputColumn.vector[j], 
inputColumn.start[j],
+                inputColumn.length[j]);
+          }
+        } else if (inputColumn.noNulls) {
+          int i = selected[j];
+          splitVectorAcrossWorkers(workers, inputColumn.vector[i], 
inputColumn.start[i],
+              inputColumn.length[i]);
+        } else {
+          int i = selected[j];
+          if (!inputColumn.isNull[i]) {
+            splitVectorAcrossWorkers(workers, inputColumn.vector[i], 
inputColumn.start[i],
+                inputColumn.length[i]);
+          }
+        }
+      }
+    }
+
+    private void initExecutor() {
+      LOG.info("Number of threads used for bloom filter merge: {}", 
numThreads);
+
+      executor = Executors.newFixedThreadPool(numThreads);
+
+      workers = new BloomFilterMergeWorker[numThreads];
+      for (int f = 0; f < numThreads; f++) {
+        workers[f] = new BloomFilterMergeWorker(bfBytes, 0, bfBytes.length, 
aborted);
+        executor.submit(workers[f]);
+      }
+    }
+
+    public int getNumberOfWaitingMergeTasks(){
+      int size = 0;
+      for (BloomFilterMergeWorker w : workers){
+        size += w.queue.size();
+      }
+      return size;
+    }
+
+    private static void splitVectorAcrossWorkers(BloomFilterMergeWorker[] 
workers, byte[] bytes,
+        int start, int length) {
+      if (bytes == null || length == 0) {
+        return;
+      }
+      /*
+       * This will split a byte[] across workers as below:
+       * let's say there are 10 workers for 7813 bytes, in this case
+       * length: 7813, elementPerBatch: 781
+       * bytes assigned to workers: inclusive lower bound, exclusive upper 
bound
+       * 1. worker: 5 -> 786
+       * 2. worker: 786 -> 1567
+       * 3. worker: 1567 -> 2348
+       * 4. worker: 2348 -> 3129
+       * 5. worker: 3129 -> 3910
+       * 6. worker: 3910 -> 4691
+       * 7. worker: 4691 -> 5472
+       * 8. worker: 5472 -> 6253
+       * 9. worker: 6253 -> 7034
+       * 10. worker: 7034 -> 7813 (last element per batch is: 779)
+       *
+       * This way, a particular worker will be given the same part
+       * of all bloom filters along with the shared base bloom filter,
+       * so the bitwise OR function will not be subject to threading/sync 
issues.
+       */
+      int elementPerBatch =
+          (int) Math.ceil((double) (length - START_OF_SERIALIZED_LONGS) / 
workers.length);
+
+      for (int w = 0; w < workers.length; w++) {
+        int modifiedStart = START_OF_SERIALIZED_LONGS + w * elementPerBatch;
+        int modifiedLength = (w == workers.length - 1)
+          ? length - (START_OF_SERIALIZED_LONGS + w * elementPerBatch) : 
elementPerBatch;
+
+        ElementWrapper wrapper =
+            new ElementWrapper(bytes, start, length, modifiedStart, 
modifiedLength);
+        workers[w].add(wrapper);
+      }
+    }
+
+    public void shutdownAndWaitForMergeTasks(Aggregation agg, boolean aborted) 
{
+      if (aborted){
+        agg.aborted.set(true);
+      }
+      /**
+       * Executor.shutdownNow() is supposed to send Thread.interrupt to worker 
threads, and they are
+       * supposed to finish their work.
+       */
+      executor.shutdownNow();
+      try {
+        executor.awaitTermination(180, TimeUnit.SECONDS);
+      } catch (InterruptedException e) {
+        LOG.warn("Bloom filter merge is interrupted while waiting to finish, 
this is unexpected",
+            e);
+      }
+    }
+  }
+
+  private static class BloomFilterMergeWorker implements Runnable {
+    private BlockingQueue<ElementWrapper> queue;
+    private byte[] bfAggregation;
+    private int bfAggregationStart;
+    private int bfAggregationLength;
+    private AtomicBoolean aborted;
+
+    public BloomFilterMergeWorker(byte[] bfAggregation, int bfAggregationStart,
+        int bfAggregationLength, AtomicBoolean aborted) {
+      this.bfAggregation = bfAggregation;
+      this.bfAggregationStart = bfAggregationStart;
+      this.bfAggregationLength = bfAggregationLength;
+      this.queue = new LinkedBlockingDeque<>();
+      this.aborted = aborted;
+    }
+
+    public void add(ElementWrapper wrapper) {
+      queue.add(wrapper);
+    }
+
+    @Override
+    public void run() {
+      while (true) {
+        ElementWrapper currentBf = null;
+        try {
+          currentBf = queue.take();
+          // at this point we have a currentBf wrapper which contains the 
whole byte[] of the
+          // serialized bloomfilter, but we only want to merge a modified 
"start -> start+length"
+          // part of it, which is pointed by modifiedStart/modifiedLength 
fields by ElementWrapper
+          merge(currentBf);
+        } catch (InterruptedException e) {// Executor.shutdownNow() is called
+          if (!queue.isEmpty()){
+            LOG.debug(
+                "bloom filter merge was interrupted while processing and queue 
is still not empty"
+                    + ", this is fine in case of shutdownNow");
+          }
+          if (aborted.get()) {
+            LOG.info("bloom filter merge was aborted, won't finish 
merging...");
+            break;
+          }
+          while (!queue.isEmpty()) { // time to finish work if any
+            ElementWrapper lastBloomFilter = queue.poll();
+            merge(lastBloomFilter);
+          }
+          break;
+        }
+      }
+    }
+
+    private void merge(ElementWrapper bloomFilterWrapper) {
+      BloomKFilter.mergeBloomFilterBytes(bfAggregation, bfAggregationStart, 
bfAggregationLength,
+          bloomFilterWrapper.bytes, bloomFilterWrapper.start, 
bloomFilterWrapper.length,
+          bloomFilterWrapper.modifiedStart,
+          bloomFilterWrapper.modifiedStart + 
bloomFilterWrapper.modifiedLength);
+    }
+  }
+
+  public static class ElementWrapper {
+    public byte[] bytes;
+    public int start;
+    public int length;
+    public int modifiedStart;
+    public int modifiedLength;
+
+    public ElementWrapper(byte[] bytes, int start, int length, int 
modifiedStart, int modifiedLength) {
+      this.bytes = bytes;
+      this.start = start;
+      this.length = length;
+      this.modifiedStart = modifiedStart;
+      this.modifiedLength = modifiedLength;
+    }
   }
 
   // This constructor is used to momentarily create the object so match can be 
called.
@@ -89,6 +272,12 @@ public class VectorUDAFBloomFilterMerge extends 
VectorAggregateExpression {
     init();
   }
 
+  public VectorUDAFBloomFilterMerge(VectorAggregationDesc vecAggrDesc, int 
numThreads) {
+    super(vecAggrDesc);
+    this.numThreads = numThreads;
+    init();
+  }
+
   private void init() {
 
     GenericUDAFBloomFilterEvaluator udafBloomFilter =
@@ -99,11 +288,21 @@ public class VectorUDAFBloomFilterMerge extends 
VectorAggregateExpression {
   }
 
   @Override
+  public void finish(AggregationBuffer myagg, boolean aborted) {
+    VectorUDAFBloomFilterMerge.Aggregation agg = 
(VectorUDAFBloomFilterMerge.Aggregation) myagg;
+    if (agg.numThreads > 0) {
+      LOG.info("bloom filter merge: finishing aggregation, waiting tasks: {}",
+          agg.getNumberOfWaitingMergeTasks());
+      agg.shutdownAndWaitForMergeTasks(agg, aborted);
+    }
+  }
+
+  @Override
   public AggregationBuffer getNewAggregationBuffer() throws HiveException {
     if (expectedEntries < 0) {
       throw new IllegalStateException("expectedEntries not initialized");
     }
-    return new Aggregation(expectedEntries);
+    return new Aggregation(expectedEntries, numThreads);
   }
 
   @Override
@@ -129,20 +328,32 @@ public class VectorUDAFBloomFilterMerge extends 
VectorAggregateExpression {
       return;
     }
 
-    if (!batch.selectedInUse && inputColumn.noNulls) {
-      iterateNoSelectionNoNulls(myagg, inputColumn, batchSize);
-    }
-    else if (!batch.selectedInUse) {
-      iterateNoSelectionHasNulls(myagg, inputColumn, batchSize);
-    }
-    else if (inputColumn.noNulls){
-      iterateSelectionNoNulls(myagg, inputColumn, batchSize, batch.selected);
-    }
-    else {
-      iterateSelectionHasNulls(myagg, inputColumn, batchSize, batch.selected);
+    if (myagg.numThreads != 0) {
+      processValues(myagg, inputColumn, batchSize, batch.selectedInUse, 
batch.selected);
+    } else {
+      if (!batch.selectedInUse && inputColumn.noNulls) {
+        iterateNoSelectionNoNulls(myagg, inputColumn, batchSize);
+      } else if (!batch.selectedInUse) {
+        iterateNoSelectionHasNulls(myagg, inputColumn, batchSize);
+      } else if (inputColumn.noNulls) {
+        iterateSelectionNoNulls(myagg, inputColumn, batchSize, batch.selected);
+      } else {
+        iterateSelectionHasNulls(myagg, inputColumn, batchSize, 
batch.selected);
+      }
     }
   }
 
+  private void processValues(
+    Aggregation myagg,
+    ColumnVector inputColumn,
+    int batchSize, boolean selectedInUse, int[] selected){
+
+    VectorUDAFBloomFilterMerge.Aggregation agg = 
(VectorUDAFBloomFilterMerge.Aggregation)myagg;
+
+    agg.mergeBloomFilterBytesFromInputColumn((BytesColumnVector) inputColumn, 
batchSize,
+        selectedInUse, selected);
+  }
+
   private void iterateNoSelectionNoNulls(
       Aggregation myagg,
       ColumnVector inputColumn,
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
index b229292..8ec6b40 100644
--- 
a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
@@ -39,6 +39,7 @@ import java.util.Map;
 import java.util.Set;
 
 import org.apache.calcite.util.Pair;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.llap.io.api.LlapProxy;
@@ -47,6 +48,8 @@ import org.apache.hadoop.hive.ql.exec.KeyWrapper;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
 import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFBloomFilterMerge;
+import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCount;
 import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCountStar;
 import 
org.apache.hadoop.hive.ql.exec.vector.util.FakeCaptureVectorToRowOutputOperator;
 import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromConcat;
@@ -65,6 +68,7 @@ import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc;
 import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc.ProcessingMode;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax;
@@ -2327,6 +2331,25 @@ public class TestVectorGroupByOperator {
         (double)0);
   }
 
+  @Test
+  public void testInstantiateExpression() throws Exception {
+    VectorGroupByOperator op = new VectorGroupByOperator();
+
+    // VectorUDAFBloomFilterMerge with specific constructor
+    VectorAggregationDesc desc = Mockito.mock(VectorAggregationDesc.class);
+    Mockito.when(desc.getVecAggrClass()).thenReturn((Class) 
VectorUDAFBloomFilterMerge.class);
+    Mockito.when(desc.getEvaluator())
+        .thenReturn(new 
GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator());
+    VectorAggregateExpression expr = op.instantiateExpression(desc, new 
Configuration());
+    Assert.assertTrue(expr.getClass() == VectorUDAFBloomFilterMerge.class);
+
+    // regular constructor
+    desc = Mockito.mock(VectorAggregationDesc.class);
+    Mockito.when(desc.getVecAggrClass()).thenReturn((Class) 
VectorUDAFCount.class);
+    expr = op.instantiateExpression(desc, new Configuration());
+    Assert.assertTrue(expr.getClass() == VectorUDAFCount.class);
+  }
+
   private void testMultiKey(
       String aggregateName,
       FakeVectorRowBatchFromObjectIterables data,
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/TestVectorUDAFBloomFilterMerge.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/TestVectorUDAFBloomFilterMerge.java
new file mode 100644
index 0000000..d41f3b5
--- /dev/null
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/TestVectorUDAFBloomFilterMerge.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hive.common.util.BloomKFilter;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestVectorUDAFBloomFilterMerge {
+
+  @Test
+  public void testMergeBloomKFilterBytesParallel() throws Exception {
+    testMergeBloomKFilterBytesParallel(1);
+    testMergeBloomKFilterBytesParallel(2);
+    testMergeBloomKFilterBytesParallel(4);
+    testMergeBloomKFilterBytesParallel(8);
+  }
+
+  private void testMergeBloomKFilterBytesParallel(int threads) throws 
IOException {
+    Configuration conf = new Configuration();
+    conf.setInt(HiveConf.ConfVars.TEZ_BLOOM_FILTER_MERGE_THREADS.varname, 
threads);
+
+    int expectedEntries = 1000000;
+    byte[] bf1Bytes = getBloomKFilterBytesFromStringValues(expectedEntries, 
"bloo", "bloom fil",
+        "bloom filter", "cuckoo filter");
+    byte[] bf2Bytes = getBloomKFilterBytesFromStringValues(expectedEntries, 
"2_bloo", "2_bloom fil",
+        "2_bloom filter", "2_cuckoo filter");
+    byte[] bf3Bytes = getBloomKFilterBytesFromStringValues(expectedEntries, 
"3_bloo", "3_bloom fil",
+        "3_bloom filter", "3_cuckoo filter");
+    byte[] bf4Bytes = getBloomKFilterBytesFromStringValues(expectedEntries, 
"4_bloo", "4_bloom fil",
+        "4_bloom filter", "4_cuckoo filter");
+    byte[] bf5Bytes = getBloomKFilterBytesFromStringValues(expectedEntries, 
"5_bloo", "5_bloom fil",
+        "5_bloom filter", "5_cuckoo filter");
+
+    BytesColumnVector columnVector = new BytesColumnVector();
+    columnVector.reset(); // init buffers
+    columnVector.setVal(0, bf1Bytes);
+    columnVector.setVal(1, bf2Bytes);
+    columnVector.setVal(2, bf3Bytes);
+
+    BytesColumnVector columnVector2 = new BytesColumnVector();
+    columnVector2.reset(); // init buffers
+    columnVector2.setVal(0, bf4Bytes);
+    columnVector2.setVal(1, bf5Bytes);
+
+    VectorUDAFBloomFilterMerge.Aggregation agg =
+        new VectorUDAFBloomFilterMerge.Aggregation(expectedEntries, threads);
+    agg.mergeBloomFilterBytesFromInputColumn(columnVector, 1024, false, null);
+    agg.mergeBloomFilterBytesFromInputColumn(columnVector2, 1024, false, null);
+    new VectorUDAFBloomFilterMerge().finish(agg, false);
+
+    BloomKFilter merged = BloomKFilter.deserialize(new 
ByteArrayInputStream(agg.bfBytes));
+    Assert.assertTrue(merged.testBytes("bloo".getBytes()));
+    Assert.assertTrue(merged.testBytes("cuckoo filter".getBytes()));
+    Assert.assertTrue(merged.testBytes("2_bloo".getBytes()));
+    Assert.assertTrue(merged.testBytes("2_cuckoo filter".getBytes()));
+    Assert.assertTrue(merged.testBytes("3_bloo".getBytes()));
+    Assert.assertTrue(merged.testBytes("3_cuckoo filter".getBytes()));
+    Assert.assertTrue(merged.testBytes("4_bloo".getBytes()));
+    Assert.assertTrue(merged.testBytes("4_cuckoo filter".getBytes()));
+    Assert.assertTrue(merged.testBytes("5_bloo".getBytes()));
+    Assert.assertTrue(merged.testBytes("5_cuckoo filter".getBytes()));
+  }
+
+  private byte[] getBloomKFilterBytesFromStringValues(int expectedEntries, 
String... values)
+      throws IOException {
+    BloomKFilter bf = new BloomKFilter(expectedEntries);
+    for (String val : values) {
+      bf.addString(val);
+    }
+
+    ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
+    BloomKFilter.serialize(bytesOut, bf);
+    return bytesOut.toByteArray();
+  }
+}
diff --git a/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java 
b/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java
index 2386279..9f2f6b3 100644
--- a/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java
+++ b/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java
@@ -18,6 +18,7 @@
 
 package org.apache.hive.common.util;
 
+import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
@@ -26,6 +27,8 @@ import java.io.OutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 
+import org.apache.hadoop.io.IOUtils;
+
 /**
  * BloomKFilter is variation of {@link BloomFilter}. Unlike BloomFilter, 
BloomKFilter will spread
  * 'k' hash bits within same cache line for better L1 cache performance. The 
way it works is,
@@ -36,7 +39,7 @@ import java.util.Arrays;
  *
  * This implementation has much lesser L1 data cache misses than {@link 
BloomFilter}.
  */
-@SuppressWarnings({ "WeakerAccess", "unused" }) public class BloomKFilter {
+public class BloomKFilter {
   public static final float DEFAULT_FPP = 0.05f;
   private static final int DEFAULT_BLOCK_SIZE = 8;
   private static final int DEFAULT_BLOCK_SIZE_BITS = (int) 
(Math.log(DEFAULT_BLOCK_SIZE) / Math.log(2));
@@ -335,6 +338,13 @@ import java.util.Arrays;
   // NumHashFunctions (1 byte) + bitset array length (4 bytes)
   public static final int START_OF_SERIALIZED_LONGS = 5;
 
+  public static void mergeBloomFilterBytes(
+    byte[] bf1Bytes, int bf1Start, int bf1Length,
+    byte[] bf2Bytes, int bf2Start, int bf2Length) {
+    mergeBloomFilterBytes(bf1Bytes, bf1Start, bf1Length, bf2Bytes, bf2Start, 
bf2Length,
+        START_OF_SERIALIZED_LONGS, bf1Length);
+  }
+
   /**
    * Merges BloomKFilter bf2 into bf1.
    * Assumes 2 BloomKFilters with the same size/hash functions are serialized 
to byte arrays
@@ -348,7 +358,8 @@ import java.util.Arrays;
    */
   public static void mergeBloomFilterBytes(
     byte[] bf1Bytes, int bf1Start, int bf1Length,
-    byte[] bf2Bytes, int bf2Start, int bf2Length) {
+    byte[] bf2Bytes, int bf2Start, int bf2Length, int mergeStart, int 
mergeEnd) {
+
     if (bf1Length != bf2Length) {
       throw new IllegalArgumentException("bf1Length " + bf1Length + " does not 
match bf2Length " + bf2Length);
     }
@@ -362,16 +373,30 @@ import java.util.Arrays;
 
     // Just bitwise-OR the bits together - size/# functions should be the same,
     // rest of the data is serialized long values for the bitset which are 
supposed to be bitwise-ORed.
-    for (int idx = START_OF_SERIALIZED_LONGS; idx < bf1Length; ++idx) {
+    for (int idx = mergeStart; idx < mergeEnd; ++idx) {
       bf1Bytes[bf1Start + idx] |= bf2Bytes[bf2Start + idx];
     }
   }
 
+  public static byte[] getInitialBytes(long expectedEntries) {
+    ByteArrayOutputStream bytesOut = null;
+    try {
+      bytesOut = new ByteArrayOutputStream();
+      BloomKFilter bf = new BloomKFilter(expectedEntries);
+      BloomKFilter.serialize(bytesOut, bf);
+      return bytesOut.toByteArray();
+    } catch (Exception err) {
+      throw new IllegalArgumentException("Error creating aggregation buffer", 
err);
+    } finally {
+      IOUtils.closeStream(bytesOut);
+    }
+  }
+
   /**
    * Bare metal bit set implementation. For performance reasons, this 
implementation does not check
    * for index bounds nor expand the bit set size if the specified index is 
greater than the size.
    */
-  @SuppressWarnings("unused") public static class BitSet {
+  public static class BitSet {
     private final long[] data;
 
     public BitSet(long bits) {
diff --git 
a/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java 
b/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java
index 1b4e210..32f21f8 100644
--- a/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java
+++ b/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java
@@ -508,7 +508,7 @@ public class TestBloomKFilter {
     BloomKFilter.serialize(bytesOut, bf1);
     byte[] bf1Bytes = bytesOut.toByteArray();
     bytesOut.reset();
-    BloomKFilter.serialize(bytesOut, bf1);
+    BloomKFilter.serialize(bytesOut, bf2);
     byte[] bf2Bytes = bytesOut.toByteArray();
 
     // Merge bytes

Reply via email to