This is an automated email from the ASF dual-hosted git repository.

vitalii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ad0f9f  DRILL-6453: Fix deadlock caused by reading from left and right inputs in HashJoin simultaneously.
6ad0f9f is described below

commit 6ad0f9f1bab8bdda18f3eaaf29445bc94355156e
Author: Timothy Farkas <[email protected]>
AuthorDate: Mon Jul 16 15:33:23 2018 -0700

    DRILL-6453: Fix deadlock caused by reading from left and right inputs in HashJoin simultaneously.
    
    closes #1408
---
 .../exec/physical/impl/common/HashPartition.java   |  14 +
 .../physical/impl/join/BatchSizePredictor.java     |  79 ++++
 .../physical/impl/join/BatchSizePredictorImpl.java | 165 +++++++++
 .../exec/physical/impl/join/HashJoinBatch.java     | 358 +++++++++++--------
 .../join/HashJoinMechanicalMemoryCalculator.java   |  22 +-
 .../impl/join/HashJoinMemoryCalculator.java        |  21 +-
 .../impl/join/HashJoinMemoryCalculatorImpl.java    | 396 +++++++++++----------
 .../HashTableSizeCalculatorConservativeImpl.java   |   2 +-
 .../impl/join/HashTableSizeCalculatorLeanImpl.java |   2 +-
 .../exec/record/AbstractBinaryRecordBatch.java     |   5 +-
 .../org/apache/drill/exec/record/RecordBatch.java  |  24 +-
 ...orImpl.java => TestBatchSizePredictorImpl.java} |  79 ++--
 .../impl/join/TestBuildSidePartitioningImpl.java   | 275 +++++++++++---
 ...Impl.java => TestHashJoinMemoryCalculator.java} |  49 +--
 ...estHashTableSizeCalculatorConservativeImpl.java |   8 +-
 .../join/TestHashTableSizeCalculatorLeanImpl.java  |   8 +-
 .../impl/join/TestPostBuildCalculationsImpl.java   | 392 +++++++++++++++++---
 17 files changed, 1377 insertions(+), 522 deletions(-)

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
index eaccd33..fbdc4f3 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
@@ -17,6 +17,7 @@
  */
 package org.apache.drill.exec.physical.impl.common;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import org.apache.drill.common.exceptions.RetryAfterSpillException;
 import org.apache.drill.common.exceptions.UserException;
@@ -122,6 +123,7 @@ public class HashPartition implements 
HashJoinMemoryCalculator.PartitionStat {
   private List<HashJoinMemoryCalculator.BatchStat> inMemoryBatchStats = 
Lists.newArrayList();
   private long partitionInMemorySize;
   private long numInMemoryRecords;
+  private boolean updatedRecordsPerBatch = false;
 
   public HashPartition(FragmentContext context, BufferAllocator allocator, 
ChainedHashTable baseHashTable,
                        RecordBatch buildBatch, RecordBatch probeBatch,
@@ -156,6 +158,18 @@ public class HashPartition implements 
HashJoinMemoryCalculator.PartitionStat {
   }
 
   /**
+   * Configure a different temporary batch size when spilling probe batches.
+   * @param newRecordsPerBatch The new temporary batch size to use.
+   */
+  public void updateProbeRecordsPerBatch(int newRecordsPerBatch) {
+    Preconditions.checkArgument(newRecordsPerBatch > 0);
+    Preconditions.checkState(!updatedRecordsPerBatch); // Only allow updating 
once
+    Preconditions.checkState(processingOuter); // We can only update the 
records per batch when probing.
+
+    recordsPerBatch = newRecordsPerBatch;
+  }
+
+  /**
    * Allocate a new vector container for either right or left record batch
    * Add an additional special vector for the hash values
    * Note: this call may OOM !!
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictor.java
new file mode 100644
index 0000000..912e4fe
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictor.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.physical.impl.join;
+
+import org.apache.drill.exec.record.RecordBatch;
+
+/**
+ * This class predicts the sizes of batches given an input batch.
+ *
+ * <h4>Invariants</h4>
+ * <ul>
+ *   <li>The {@link BatchSizePredictor} assumes that a {@link RecordBatch} is 
in a state where it can return a valid record count.</li>
+ * </ul>
+ */
+public interface BatchSizePredictor {
+  /**
+   * Gets the batchSize computed in the call to {@link #updateStats()}. 
Returns 0 if {@link #hadDataLastTime()} is false.
+   * @return Gets the batchSize computed in the call to {@link 
#updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @throws IllegalStateException if {@link #updateStats()} was never called.
+   */
+  long getBatchSize();
+
+  /**
+   * Gets the number of records computed in the call to {@link 
#updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @return Gets the number of records computed in the call to {@link 
#updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @throws IllegalStateException if {@link #updateStats()} was never called.
+   */
+  int getNumRecords();
+
+  /**
+   * True if the input batch had records in the last call to {@link 
#updateStats()}. False otherwise.
+   * @return True if the input batch had records in the last call to {@link 
#updateStats()}. False otherwise.
+   */
+  boolean hadDataLastTime();
+
+  /**
+   * This method can be called multiple times to collect stats about the 
latest data in the provided record batch. These
+   * stats are used to predict batch sizes. If the batch currently has no 
data, this method is a noop. This method must be
+   * called at least once before {@link #predictBatchSize(int, boolean)}.
+   */
+  void updateStats();
+
+  /**
+   * Predicts the size of a batch using the current collected stats.
+   * @param desiredNumRecords The number of records contained in the batch 
whose size we want to predict.
+   * @param reserveHash Whether or not to include a column containing hash 
values.
+   * @return The size of the predicted batch.
+   * @throws IllegalStateException if {@link #hadDataLastTime()} is false or 
{@link #updateStats()} was not called.
+   */
+  long predictBatchSize(int desiredNumRecords, boolean reserveHash);
+
+  /**
+   * A factory for creating {@link BatchSizePredictor}s.
+   */
+  interface Factory {
+    /**
+     * Creates a predictor with a batch whose data needs to be used to predict 
other batch sizes.
+     * @param batch The batch whose data is used to predict batch sizes.
+     * @param fragmentationFactor A constant used to predict value vector 
doubling.
+     * @param safetyFactor A constant used to leave padding for unpredictable 
incoming batches.
+     */
+    BatchSizePredictor create(RecordBatch batch, double fragmentationFactor, 
double safetyFactor);
+  }
+}
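
A minimal usage sketch of the interface above (illustrative only, not part of this commit; "incoming" and the numeric constants are assumed placeholders):

    // incoming: any RecordBatch that can already report a valid record count (assumed to exist)
    double fragmentationFactor = 1.33;                       // illustrative constant
    double safetyFactor = 1.0;                               // illustrative constant
    BatchSizePredictor predictor =
        BatchSizePredictorImpl.Factory.INSTANCE.create(incoming, fragmentationFactor, safetyFactor);
    predictor.updateStats();                                 // must run before any prediction
    if (predictor.hadDataLastTime()) {
      // predict the size of a 4096-record batch, reserving room for the hash value column
      long predictedBytes = predictor.predictBatchSize(4096, true);
    }
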
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.java
new file mode 100644
index 0000000..bbebd2b
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.physical.impl.join;
+
+import com.google.common.base.Preconditions;
+import org.apache.drill.exec.record.RecordBatch;
+import org.apache.drill.exec.record.RecordBatchSizer;
+import org.apache.drill.exec.vector.IntVector;
+
+import java.util.Map;
+
+public class BatchSizePredictorImpl implements BatchSizePredictor {
+  private RecordBatch batch;
+  private double fragmentationFactor;
+  private double safetyFactor;
+
+  private long batchSize;
+  private int numRecords;
+  private boolean updatedStats;
+  private boolean hasData;
+
+  public BatchSizePredictorImpl(final RecordBatch batch,
+                                final double fragmentationFactor,
+                                final double safetyFactor) {
+    this.batch = Preconditions.checkNotNull(batch);
+    this.fragmentationFactor = fragmentationFactor;
+    this.safetyFactor = safetyFactor;
+  }
+
+  @Override
+  public long getBatchSize() {
+    Preconditions.checkState(updatedStats);
+    return hasData? batchSize: 0;
+  }
+
+  @Override
+  public int getNumRecords() {
+    Preconditions.checkState(updatedStats);
+    return hasData? numRecords: 0;
+  }
+
+  @Override
+  public boolean hadDataLastTime() {
+    return hasData;
+  }
+
+  @Override
+  public void updateStats() {
+    final RecordBatchSizer batchSizer = new RecordBatchSizer(batch);
+    numRecords = batchSizer.rowCount();
+    updatedStats = true;
+    hasData = numRecords > 0;
+
+    if (hasData) {
+      batchSize = getBatchSizeEstimate(batch);
+    }
+  }
+
+  @Override
+  public long predictBatchSize(int desiredNumRecords, boolean reserveHash) {
+    Preconditions.checkState(hasData);
+    // Safety factor can be multiplied at the end since these batches are 
coming from exchange operators, so no excess value vector doubling
+    return computeMaxBatchSize(batchSize,
+      numRecords,
+      desiredNumRecords,
+      fragmentationFactor,
+      safetyFactor,
+      reserveHash);
+  }
+
+  public static long computeValueVectorSize(long numRecords, long byteSize) {
+    long naiveSize = numRecords * byteSize;
+    return roundUpToPowerOf2(naiveSize);
+  }
+
+  public static long computeValueVectorSize(long numRecords, long byteSize, 
double safetyFactor) {
+    long naiveSize = RecordBatchSizer.multiplyByFactor(numRecords * byteSize, 
safetyFactor);
+    return roundUpToPowerOf2(naiveSize);
+  }
+
+  public static long roundUpToPowerOf2(long num) {
+    Preconditions.checkArgument(num >= 1);
+    return num == 1 ? 1 : Long.highestOneBit(num - 1) << 1;
+  }
+
+  public static long computeMaxBatchSizeNoHash(final long incomingBatchSize,
+                                         final int incomingNumRecords,
+                                         final int desiredNumRecords,
+                                         final double fragmentationFactor,
+                                         final double safetyFactor) {
+    long maxBatchSize = computePartitionBatchSize(incomingBatchSize, 
incomingNumRecords, desiredNumRecords);
+    // Multiply by fragmentation factor
+    return RecordBatchSizer.multiplyByFactors(maxBatchSize, 
fragmentationFactor, safetyFactor);
+  }
+
+  public static long computeMaxBatchSize(final long incomingBatchSize,
+                                         final int incomingNumRecords,
+                                         final int desiredNumRecords,
+                                         final double fragmentationFactor,
+                                         final double safetyFactor,
+                                         final boolean reserveHash) {
+    long size = computeMaxBatchSizeNoHash(incomingBatchSize,
+      incomingNumRecords,
+      desiredNumRecords,
+      fragmentationFactor,
+      safetyFactor);
+
+    if (!reserveHash) {
+      return size;
+    }
+
+    long hashSize = desiredNumRecords * ((long) IntVector.VALUE_WIDTH);
+    hashSize = RecordBatchSizer.multiplyByFactors(hashSize, 
fragmentationFactor);
+
+    return size + hashSize;
+  }
+
+  public static long computePartitionBatchSize(final long incomingBatchSize,
+                                               final int incomingNumRecords,
+                                               final int desiredNumRecords) {
+    return (long) Math.ceil((((double) incomingBatchSize) /
+      ((double) incomingNumRecords)) *
+      ((double) desiredNumRecords));
+  }
+
+  public static long getBatchSizeEstimate(final RecordBatch recordBatch) {
+    final RecordBatchSizer sizer = new RecordBatchSizer(recordBatch);
+    long size = 0L;
+
+    for (Map.Entry<String, RecordBatchSizer.ColumnSize> column : 
sizer.columns().entrySet()) {
+      size += computeValueVectorSize(recordBatch.getRecordCount(), 
column.getValue().getStdNetOrNetSizePerEntry());
+    }
+
+    return size;
+  }
+
+  public static class Factory implements BatchSizePredictor.Factory {
+    public static final Factory INSTANCE = new Factory();
+
+    private Factory() {
+    }
+
+    @Override
+    public BatchSizePredictor create(final RecordBatch batch,
+                                     final double fragmentationFactor,
+                                     final double safetyFactor) {
+      return new BatchSizePredictorImpl(batch, fragmentationFactor, 
safetyFactor);
+    }
+  }
+}
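
For intuition, a small worked example of the static sizing helpers above (the numbers are illustrative, not taken from the commit):

    // A sampled incoming batch of 4 MB holding 1000 records, re-sliced into 4096-record partition batches:
    long perPartitionBytes = BatchSizePredictorImpl.computePartitionBatchSize(4L * 1024 * 1024, 1000, 4096);
    // = ceil(4194304 / 1000 * 4096) = 17179870 bytes
    long vectorBytes = BatchSizePredictorImpl.roundUpToPowerOf2(perPartitionBytes);
    // = 33554432 bytes (32 MB), matching the power-of-2 growth of value vectors
    // computeMaxBatchSize(...) additionally multiplies by the fragmentation and safety factors and,
    // when reserveHash is true, adds desiredNumRecords * IntVector.VALUE_WIDTH bytes
    // (scaled by the fragmentation factor) for the hash value column.
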
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java
index b1ea96f..0bd6fe6 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java
@@ -27,6 +27,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
 import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.mutable.MutableBoolean;
 import org.apache.drill.common.exceptions.UserException;
 import org.apache.drill.common.expression.FieldReference;
 import org.apache.drill.common.expression.PathSegment;
@@ -68,6 +69,9 @@ import org.apache.drill.exec.vector.ValueVector;
 import org.apache.drill.exec.vector.complex.AbstractContainerVector;
 import org.apache.calcite.rel.core.JoinRelType;
 
+import static org.apache.drill.exec.record.RecordBatch.IterOutcome.EMIT;
+import static 
org.apache.drill.exec.record.RecordBatch.IterOutcome.OK_NEW_SCHEMA;
+
 /**
  *   This class implements the runtime execution for the Hash-Join operator
  *   supporting INNER, LEFT OUTER, RIGHT OUTER, and FULL OUTER joins
@@ -114,6 +118,7 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
 
   // Fields used for partitioning
 
+  private long maxIncomingBatchSize;
   /**
    * The number of {@link HashPartition}s. This is configured via a system 
option and set in {@link #partitionNumTuning(int, 
HashJoinMemoryCalculator.BuildSidePartitioning)}.
    */
@@ -125,7 +130,8 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
    * The master class used to generate {@link HashTable}s.
    */
   private ChainedHashTable baseHashTable;
-  private boolean buildSideIsEmpty = true;
+  private MutableBoolean buildSideIsEmpty = new MutableBoolean(false);
+  private MutableBoolean probeSideIsEmpty = new MutableBoolean(false);
   private boolean canSpill = true;
   private boolean wasKilled; // a kill was received, may need to clean spilled 
partns
 
@@ -138,7 +144,7 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
   private int outputRecords;
 
   // Schema of the build side
-  private BatchSchema rightSchema;
+  private BatchSchema buildSchema;
   // Schema of the probe side
   private BatchSchema probeSchema;
 
@@ -150,9 +156,13 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
   private RecordBatch probeBatch;
 
   /**
-   * Flag indicating whether or not the first data holding batch needs to be 
fetched.
+   * Flag indicating whether or not the first data holding build batch needs 
to be fetched.
+   */
+  private MutableBoolean prefetchedBuild = new MutableBoolean(false);
+  /**
+   * Flag indicating whether or not the first data holding probe batch needs 
to be fetched.
    */
-  private boolean prefetched;
+  private MutableBoolean prefetchedProbe = new MutableBoolean(false);
 
   // For handling spilling
   private SpillSet spillSet;
@@ -220,123 +230,120 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
   protected void buildSchema() throws SchemaChangeException {
     // We must first get the schemas from upstream operators before we can 
build
     // our schema.
-    boolean validSchema = sniffNewSchemas();
+    boolean validSchema = prefetchFirstBatchFromBothSides();
 
     if (validSchema) {
       // We are able to construct a valid schema from the upstream data.
       // Setting the state here makes sure AbstractRecordBatch returns 
OK_NEW_SCHEMA
       state = BatchState.BUILD_SCHEMA;
-    } else {
-      verifyOutcomeToSetBatchState(leftUpstream, rightUpstream);
+
+      if (leftUpstream == OK_NEW_SCHEMA) {
+        probeSchema = left.getSchema();
+      }
+
+      if (rightUpstream == OK_NEW_SCHEMA) {
+        buildSchema = right.getSchema();
+        // position of the new "column" for keeping the hash values (after the 
real columns)
+        rightHVColPosition = right.getContainer().getNumberOfColumns();
+        // We only need the hash tables if we have data on the build side.
+        setupHashTable();
+      }
+
+      try {
+        hashJoinProbe = setupHashJoinProbe();
+      } catch (IOException | ClassTransformationException e) {
+        throw new SchemaChangeException(e);
+      }
     }
 
     // If we have a valid schema, this will build a valid container. If we 
were unable to obtain a valid schema,
-    // we still need to build a dummy schema. These code handles both cases 
for us.
+    // we still need to build a dummy schema. This code handles both cases for 
us.
     setupOutputContainerSchema();
     container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
-
-    // Initialize the hash join helper context
-    if (rightUpstream == IterOutcome.OK_NEW_SCHEMA) {
-      // We only need the hash tables if we have data on the build side.
-      setupHashTable();
-    }
-
-    try {
-      hashJoinProbe = setupHashJoinProbe();
-    } catch (IOException | ClassTransformationException e) {
-      throw new SchemaChangeException(e);
-    }
   }
 
-  @Override
-  protected boolean prefetchFirstBatchFromBothSides() {
-    if (leftUpstream != IterOutcome.NONE) {
-      // We can only get data if there is data available
-      leftUpstream = sniffNonEmptyBatch(leftUpstream, LEFT_INDEX, left);
-    }
-
-    if (rightUpstream != IterOutcome.NONE) {
-      // We can only get data if there is data available
-      rightUpstream = sniffNonEmptyBatch(rightUpstream, RIGHT_INDEX, right);
-    }
-
-    buildSideIsEmpty = rightUpstream == IterOutcome.NONE;
-
-    if (verifyOutcomeToSetBatchState(leftUpstream, rightUpstream)) {
-      // For build side, use aggregate i.e. average row width across batches
-      batchMemoryManager.update(LEFT_INDEX, 0);
-      batchMemoryManager.update(RIGHT_INDEX, 0, true);
-
-      logger.debug("BATCH_STATS, incoming left: {}", 
batchMemoryManager.getRecordBatchSizer(LEFT_INDEX));
-      logger.debug("BATCH_STATS, incoming right: {}", 
batchMemoryManager.getRecordBatchSizer(RIGHT_INDEX));
+  /**
+   * Prefetches the first build side data holding batch.
+   */
+  private void prefetchFirstBuildBatch() {
+    rightUpstream = prefetchFirstBatch(rightUpstream,
+      prefetchedBuild,
+      buildSideIsEmpty,
+      RIGHT_INDEX,
+      right,
+      () -> {
+        batchMemoryManager.update(RIGHT_INDEX, 0, true);
+        logger.debug("BATCH_STATS, incoming right: {}", 
batchMemoryManager.getRecordBatchSizer(RIGHT_INDEX));
+      });
+  }
 
-      // Got our first batche(s)
-      state = BatchState.FIRST;
-      return true;
-    } else {
-      return false;
-    }
+  /**
+   * Prefetches the first probe side data holding batch.
+   */
+  private void prefetchFirstProbeBatch() {
+    leftUpstream =  prefetchFirstBatch(leftUpstream,
+      prefetchedProbe,
+      probeSideIsEmpty,
+      LEFT_INDEX,
+      left,
+      () -> {
+        batchMemoryManager.update(LEFT_INDEX, 0);
+        logger.debug("BATCH_STATS, incoming left: {}", 
batchMemoryManager.getRecordBatchSizer(LEFT_INDEX));
+      });
   }
 
   /**
-   * Sniffs all data necessary to construct a schema.
-   * @return True if all the data necessary to construct a schema has been 
retrieved. False otherwise.
+   * Used to fetch the first data holding batch from either the build or probe 
side.
+   * @param outcome The current upstream outcome for either the build or probe 
side.
+   * @param prefetched A flag indicating if we have already done a prefetch of 
the first data holding batch for the probe or build side.
+   * @param isEmpty A flag indicating if the probe or build side is empty.
+   * @param index The upstream index of the probe or build batch.
+   * @param batch The probe or build batch itself.
+   * @param memoryManagerUpdate A lambda function to execute the memory 
manager update for the probe or build batch.
+   * @return The current {@link 
org.apache.drill.exec.record.RecordBatch.IterOutcome}.
    */
-  private boolean sniffNewSchemas() {
-    do {
-      // Ask for data until we get a valid result.
-      leftUpstream = next(LEFT_INDEX, left);
-    } while (leftUpstream == IterOutcome.NOT_YET);
+  private IterOutcome prefetchFirstBatch(IterOutcome outcome,
+                                         final MutableBoolean prefetched,
+                                         final MutableBoolean isEmpty,
+                                         final int index,
+                                         final RecordBatch batch,
+                                         final Runnable memoryManagerUpdate) {
+    if (prefetched.booleanValue()) {
+      // We have already prefetched the first data holding batch
+      return outcome;
+    }
 
-    boolean isValidLeft = false;
+    // If we didn't retrieve our first data holding batch, we need to do it 
now.
+    prefetched.setValue(true);
 
-    switch (leftUpstream) {
-      case OK_NEW_SCHEMA:
-        probeSchema = probeBatch.getSchema();
-      case NONE:
-        isValidLeft = true;
-        break;
-      case OK:
-      case EMIT:
-        throw new IllegalStateException("Unsupported outcome while building 
schema " + leftUpstream);
-      default:
-        // Termination condition
+    if (outcome != IterOutcome.NONE) {
+      // We can only get data if there is data available
+      outcome = sniffNonEmptyBatch(outcome, index, batch);
     }
 
-    do {
-      // Ask for data until we get a valid result.
-      rightUpstream = next(RIGHT_INDEX, right);
-    } while (rightUpstream == IterOutcome.NOT_YET);
-
-    boolean isValidRight = false;
+    isEmpty.setValue(outcome == IterOutcome.NONE); // If we received NONE there is no data.
 
-    switch (rightUpstream) {
-      case OK_NEW_SCHEMA:
-        // We need to have the schema of the build side even when the build 
side is empty
-        rightSchema = buildBatch.getSchema();
-        // position of the new "column" for keeping the hash values (after the 
real columns)
-        rightHVColPosition = buildBatch.getContainer().getNumberOfColumns();
-      case NONE:
-        isValidRight = true;
-        break;
-      case OK:
-      case EMIT:
-        throw new IllegalStateException("Unsupported outcome while building 
schema " + leftUpstream);
-      default:
-        // Termination condition
+    if (outcome == IterOutcome.OUT_OF_MEMORY) {
+      // We reached a termination state
+      state = BatchState.OUT_OF_MEMORY;
+    } else if (outcome == IterOutcome.STOP) {
+      // We reached a termination state
+      state = BatchState.STOP;
+    } else {
+      // Got our first batch(es)
+      memoryManagerUpdate.run();
+      state = BatchState.FIRST;
     }
 
-    // Left and right sides must return a valid response and both sides cannot 
be NONE.
-    return (isValidLeft && isValidRight) &&
-      (leftUpstream != IterOutcome.NONE && rightUpstream != IterOutcome.NONE);
+    return outcome;
   }
 
   /**
-   * Currently in order to accurately predict memory usage for spilling, the 
first non-empty build side and probe side batches are needed. This method
-   * fetches the first non-empty batch from the left or right side.
+   * Currently in order to accurately predict memory usage for spilling, the 
first non-empty build or probe side batch is needed. This method
+   * fetches the first non-empty batch from the probe or build side.
    * @param curr The current outcome.
-   * @param inputIndex Index specifying whether to work with the left or right 
input.
-   * @param recordBatch The left or right record batch.
+   * @param inputIndex Index specifying whether to work with the probe or build input.
+   * @param recordBatch The probe or build record batch.
    * @return The {@link org.apache.drill.exec.record.RecordBatch.IterOutcome} 
for the left or right record batch.
    */
   private IterOutcome sniffNonEmptyBatch(IterOutcome curr, int inputIndex, 
RecordBatch recordBatch) {
@@ -354,8 +361,10 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
         case NOT_YET:
           // We need to try again
           break;
+        case EMIT:
+          throw new UnsupportedOperationException("We do not support " + EMIT);
         default:
-          // Other cases termination conditions
+          // Other cases are termination conditions
           return curr;
       }
     }
@@ -381,96 +390,119 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
 
   @Override
   public IterOutcome innerNext() {
-    if (!prefetched) {
-      // If we didn't retrieve our first data hold batch, we need to do it now.
-      prefetched = true;
-      prefetchFirstBatchFromBothSides();
-
-      // Handle emitting the correct outcome for termination conditions
-      // Use the state set by prefetchFirstBatchFromBothSides to emit the 
correct termination outcome.
-      switch (state) {
-        case DONE:
-          return IterOutcome.NONE;
-        case STOP:
-          return IterOutcome.STOP;
-        case OUT_OF_MEMORY:
-          return IterOutcome.OUT_OF_MEMORY;
-        default:
-          // No termination condition so continue processing.
-      }
-    }
-
-    if ( wasKilled ) {
+    if (wasKilled) {
+      // We have received a kill signal. We need to stop processing.
       this.cleanup();
       super.close();
       return IterOutcome.NONE;
     }
 
+    prefetchFirstBuildBatch();
+
+    if (rightUpstream.isError()) {
+      // A termination condition was reached while prefetching the first build 
side data holding batch.
+      // We need to terminate.
+      return rightUpstream;
+    }
+
     try {
       /* If we are here for the first time, execute the build phase of the
        * hash join and setup the run time generated class for the probe side
        */
       if (state == BatchState.FIRST) {
         // Build the hash table, using the build side record batches.
-        executeBuildPhase();
+        final IterOutcome buildExecuteTermination = executeBuildPhase();
+
+        if (buildExecuteTermination != null) {
+          // A termination condition was reached while executing the build 
phase.
+          // We need to terminate.
+          return buildExecuteTermination;
+        }
+
         // Update the hash table related stats for the operator
         updateStats();
-        // Initialize various settings for the probe side
-        hashJoinProbe.setupHashJoinProbe(probeBatch, this, joinType, 
leftUpstream, partitions, cycleNum, container, spilledInners, buildSideIsEmpty, 
numPartitions, rightHVColPosition);
       }
 
       // Try to probe and project, or recursively handle a spilled partition
-      if ( ! buildSideIsEmpty ||  // If there are build-side rows
-           joinType != JoinRelType.INNER) {  // or if this is a left/full 
outer join
-
-        // Allocate the memory for the vectors in the output container
-        batchMemoryManager.allocateVectors(container);
-        
hashJoinProbe.setTargetOutputCount(batchMemoryManager.getOutputRowCount());
+      if (!buildSideIsEmpty.booleanValue() ||  // If there are build-side rows
+        joinType != JoinRelType.INNER) {  // or if this is a left/full outer 
join
 
-        outputRecords = hashJoinProbe.probeAndProject();
+        prefetchFirstProbeBatch();
 
-        for (final VectorWrapper<?> v : container) {
-          v.getValueVector().getMutator().setValueCount(outputRecords);
+        if (leftUpstream.isError()) {
+          // A termination condition was reached while prefetching the first 
probe side data holding batch.
+          // We need to terminate.
+          return leftUpstream;
         }
-        container.setRecordCount(outputRecords);
 
-        batchMemoryManager.updateOutgoingStats(outputRecords);
-        if (logger.isDebugEnabled()) {
-          logger.debug("BATCH_STATS, outgoing: {}", new 
RecordBatchSizer(this));
-        }
+        if (!buildSideIsEmpty.booleanValue() || 
!probeSideIsEmpty.booleanValue()) {
+          // Only allocate outgoing vectors and execute probing logic if there 
is data
 
-        /* We are here because of one the following
-         * 1. Completed processing of all the records and we are done
-         * 2. We've filled up the outgoing batch to the maximum and we need to 
return upstream
-         * Either case build the output container's schema and return
-         */
-        if (outputRecords > 0 || state == BatchState.FIRST) {
           if (state == BatchState.FIRST) {
-            state = BatchState.NOT_FIRST;
+            // Initialize various settings for the probe side
+            hashJoinProbe.setupHashJoinProbe(probeBatch,
+              this,
+              joinType,
+              leftUpstream,
+              partitions,
+              cycleNum,
+              container,
+              spilledInners,
+              buildSideIsEmpty.booleanValue(),
+              numPartitions,
+              rightHVColPosition);
+          }
+
+          // Allocate the memory for the vectors in the output container
+          batchMemoryManager.allocateVectors(container);
+
+          
hashJoinProbe.setTargetOutputCount(batchMemoryManager.getOutputRowCount());
+
+          outputRecords = hashJoinProbe.probeAndProject();
+
+          for (final VectorWrapper<?> v : container) {
+            v.getValueVector().getMutator().setValueCount(outputRecords);
+          }
+          container.setRecordCount(outputRecords);
+
+          batchMemoryManager.updateOutgoingStats(outputRecords);
+          if (logger.isDebugEnabled()) {
+            logger.debug("BATCH_STATS, outgoing: {}", new 
RecordBatchSizer(this));
           }
 
-          return IterOutcome.OK;
+          /* We are here because of one of the following
+           * 1. Completed processing of all the records and we are done
+           * 2. We've filled up the outgoing batch to the maximum and we need 
to return upstream
+           * Either case build the output container's schema and return
+           */
+          if (outputRecords > 0 || state == BatchState.FIRST) {
+            if (state == BatchState.FIRST) {
+              state = BatchState.NOT_FIRST;
+            }
+
+            return IterOutcome.OK;
+          }
         }
 
         // Free all partitions' in-memory data structures
         // (In case need to start processing spilled partitions)
-        for ( HashPartition partn : partitions ) {
+        for (HashPartition partn : partitions) {
           partn.cleanup(false); // clean, but do not delete the spill files !!
         }
 
         //
         //  (recursively) Handle the spilled partitions, if any
         //
-        if ( !buildSideIsEmpty && !spilledPartitionsList.isEmpty()) {
+        if (!buildSideIsEmpty.booleanValue() && 
!spilledPartitionsList.isEmpty()) {
           // Get the next (previously) spilled partition to handle as incoming
           HJSpilledPartition currSp = spilledPartitionsList.remove(0);
 
           // Create a BUILD-side "incoming" out of the inner spill file of 
that partition
-          buildBatch = new SpilledRecordbatch(currSp.innerSpillFile, 
currSp.innerSpilledBatches, context, rightSchema, oContext, spillSet);
+          buildBatch = new SpilledRecordbatch(currSp.innerSpillFile, 
currSp.innerSpilledBatches, context, buildSchema, oContext, spillSet);
           // The above ctor call also got the first batch; need to update the 
outcome
           rightUpstream = ((SpilledRecordbatch) 
buildBatch).getInitialOutcome();
 
-          if ( currSp.outerSpilledBatches > 0 ) {
+          if (currSp.outerSpilledBatches > 0) {
             // Create a PROBE-side "incoming" out of the outer spill file of 
that partition
             probeBatch = new SpilledRecordbatch(currSp.outerSpillFile, 
currSp.outerSpilledBatches, context, probeSchema, oContext, spillSet);
             // The above ctor call also got the first batch; need to update 
the outcome
@@ -644,13 +676,14 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
         buildBatch,
         probeBatch,
         buildJoinColumns,
+        leftUpstream == IterOutcome.NONE, // probeEmpty
         allocator.getLimit(),
+        maxIncomingBatchSize,
         numPartitions,
         RECORDS_PER_BATCH,
         RECORDS_PER_BATCH,
         maxBatchSize,
         maxBatchSize,
-        batchMemoryManager.getOutputRowCount(),
         batchMemoryManager.getOutputBatchSize(),
         HashTable.DEFAULT_LOAD_FACTOR);
 
@@ -689,12 +722,13 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
    *  Execute the BUILD phase; first read incoming and split rows into 
partitions;
    *  may decide to spill some of the partitions
    *
+   * @return Returns an {@link 
org.apache.drill.exec.record.RecordBatch.IterOutcome} if a termination 
condition is reached. Otherwise returns null.
    * @throws SchemaChangeException
    */
-  public void executeBuildPhase() throws SchemaChangeException {
-    if (rightUpstream == IterOutcome.NONE) {
+  public IterOutcome executeBuildPhase() throws SchemaChangeException {
+    if (buildSideIsEmpty.booleanValue()) {
       // empty right
-      return;
+      return null;
     }
 
     HashJoinMemoryCalculator.BuildSidePartitioning buildCalc;
@@ -716,13 +750,14 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
         buildBatch,
         probeBatch,
         buildJoinColumns,
+        leftUpstream == IterOutcome.NONE, // probeEmpty
         allocator.getLimit(),
+        maxIncomingBatchSize,
         numPartitions,
         RECORDS_PER_BATCH,
         RECORDS_PER_BATCH,
         maxBatchSize,
         maxBatchSize,
-        batchMemoryManager.getOutputRowCount(),
         batchMemoryManager.getOutputBatchSize(),
         HashTable.DEFAULT_LOAD_FACTOR);
 
@@ -754,8 +789,8 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
         continue;
 
       case OK_NEW_SCHEMA:
-        if (!rightSchema.equals(buildBatch.getSchema())) {
-          throw SchemaChangeException.schemaChanged("Hash join does not 
support schema changes in build side.", rightSchema, buildBatch.getSchema());
+        if (!buildSchema.equals(buildBatch.getSchema())) {
+          throw SchemaChangeException.schemaChanged("Hash join does not 
support schema changes in build side.", buildSchema, buildBatch.getSchema());
         }
         for (HashPartition partn : partitions) { partn.updateBatches(); }
         // Fall through
@@ -801,8 +836,16 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
       }
     }
 
+    prefetchFirstProbeBatch();
+
+    if (leftUpstream.isError()) {
+      // A termination condition was reached while prefetching the first probe side data holding batch.
+      // We need to terminate.
+      return leftUpstream;
+    }
+
     HashJoinMemoryCalculator.PostBuildCalculations postBuildCalc = 
buildCalc.next();
-    postBuildCalc.initialize();
+    postBuildCalc.initialize(probeSideIsEmpty.booleanValue()); // probeEmpty
 
     //
     //  Traverse all the in-memory partitions' incoming batches, and build 
their hash tables
@@ -849,14 +892,18 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
 
         spilledInners[partn.getPartitionNum()] = sp; // for the outer to find 
the SP later
         partn.closeWriter();
+
+        
partn.updateProbeRecordsPerBatch(postBuildCalc.getProbeRecordsPerBatch());
       }
     }
+
+    return null;
   }
 
   private void setupOutputContainerSchema() {
 
-    if (rightSchema != null) {
-      for (final MaterializedField field : rightSchema) {
+    if (buildSchema != null) {
+      for (final MaterializedField field : buildSchema) {
         final MajorType inputType = field.getType();
         final MajorType outputType;
         // If left or full outer join, then the output type must be nullable. 
However, map types are
@@ -938,6 +985,7 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
 
     this.allocator = oContext.getAllocator();
 
+    maxIncomingBatchSize = 
context.getOptions().getLong(ExecConstants.OUTPUT_BATCH_SIZE);
     numPartitions = 
(int)context.getOptions().getOption(ExecConstants.HASHJOIN_NUM_PARTITIONS_VALIDATOR);
     if ( numPartitions == 1 ) { //
       disableSpilling("Spilling is disabled due to configuration setting of 
num_partitions to 1");
@@ -976,7 +1024,7 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
    * spillSet.
    */
   private void cleanup() {
-    if ( buildSideIsEmpty ) { return; } // not set up; nothing to clean
+    if ( buildSideIsEmpty.booleanValue() ) { return; } // not set up; nothing 
to clean
     if ( spillSet.getWriteBytes() > 0 ) {
       stats.setLongStat(Metric.SPILL_MB, // update stats - total MB spilled
         (int) Math.round(spillSet.getWriteBytes() / 1024.0D / 1024.0));
@@ -1027,7 +1075,7 @@ public class HashJoinBatch extends 
AbstractBinaryRecordBatch<HashJoinPOP> {
    * written is updated at close time in {@link #cleanup()}.
    */
   private void updateStats() {
-    if ( buildSideIsEmpty ) { return; } // no stats when the right side is 
empty
+    if ( buildSideIsEmpty.booleanValue() ) { return; } // no stats when the 
right side is empty
     if ( cycleNum > 0 ) { return; } // These stats are only for before 
processing spilled files
 
     final HashTableStats htStats = new HashTableStats();
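
Condensed view of the reordered prefetching in innerNext() above (a pseudocode-level sketch of calls already shown in this diff, not new code): the first data-holding build batch is fetched before any probe-side data is requested, so the operator no longer waits on both inputs at once, which is the deadlock described in DRILL-6453.

    prefetchFirstBuildBatch();                 // sniff the first data-holding build batch only
    if (state == BatchState.FIRST) {
      executeBuildPhase();                     // consume the build side; calls prefetchFirstProbeBatch() near its end
    }
    prefetchFirstProbeBatch();                 // no-op if the probe side was already prefetched
    outputRecords = hashJoinProbe.probeAndProject();
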
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java
index fb087a0..af6be8b 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java
@@ -59,6 +59,7 @@ public class HashJoinMechanicalMemoryCalculator implements 
HashJoinMemoryCalcula
 
     private int initialPartitions;
     private PartitionStatSet partitionStatSet;
+    private int recordsPerPartitionBatchProbe;
 
     public MechanicalBuildSidePartitioning(int maxNumInMemBatches) {
       this.maxNumInMemBatches = maxNumInMemBatches;
@@ -70,16 +71,18 @@ public class HashJoinMechanicalMemoryCalculator implements 
HashJoinMemoryCalcula
                            RecordBatch buildSideBatch,
                            RecordBatch probeSideBatch,
                            Set<String> joinColumns,
+                           boolean probeEmpty,
                            long memoryAvailable,
+                           long maxIncomingBatchSize,
                            int initialPartitions,
                            int recordsPerPartitionBatchBuild,
                            int recordsPerPartitionBatchProbe,
                            int maxBatchNumRecordsBuild,
                            int maxBatchNumRecordsProbe,
-                           int outputBatchNumRecords,
                            int outputBatchSize,
                            double loadFactor) {
       this.initialPartitions = initialPartitions;
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
     }
 
     @Override
@@ -115,7 +118,7 @@ public class HashJoinMechanicalMemoryCalculator implements 
HashJoinMemoryCalcula
     @Nullable
     @Override
     public PostBuildCalculations next() {
-      return new MechanicalPostBuildCalculations(maxNumInMemBatches, 
partitionStatSet);
+      return new MechanicalPostBuildCalculations(maxNumInMemBatches, 
partitionStatSet, recordsPerPartitionBatchProbe);
     }
 
     @Override
@@ -127,16 +130,23 @@ public class HashJoinMechanicalMemoryCalculator 
implements HashJoinMemoryCalcula
   public static class MechanicalPostBuildCalculations implements 
PostBuildCalculations {
     private final int maxNumInMemBatches;
     private final PartitionStatSet partitionStatSet;
+    private final int recordsPerPartitionBatchProbe;
 
-    public MechanicalPostBuildCalculations(int maxNumInMemBatches,
-                                           PartitionStatSet partitionStatSet) {
+    public MechanicalPostBuildCalculations(final int maxNumInMemBatches,
+                                           final PartitionStatSet 
partitionStatSet,
+                                           final int 
recordsPerPartitionBatchProbe) {
       this.maxNumInMemBatches = maxNumInMemBatches;
       this.partitionStatSet = Preconditions.checkNotNull(partitionStatSet);
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
     }
 
     @Override
-    public void initialize() {
-      // Do nothing
+    public void initialize(boolean probeEmty) {
+    }
+
+    @Override
+    public int getProbeRecordsPerBatch() {
+      return recordsPerPartitionBatchProbe;
     }
 
     @Override
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java
index 868fbfd..0ccd912 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java
@@ -34,7 +34,7 @@ import java.util.Set;
  * different memory calculations at each phase. The phases of execution have 
been broken down
  * into an explicit state machine diagram below. What ocurrs in each state is 
described in
  * the documentation of the {@link HashJoinState} class below. <b>Note:</b> 
the transition from Probing
- * and Partitioning back to Build Side Partitioning. This happens we had to 
spill probe side
+ * and Partitioning back to Build Side Partitioning. This happens when we had 
to spill probe side
  * partitions and we needed to recursively process spilled partitions. This 
recursion is
  * described in more detail in the example below.
  * </p>
@@ -86,6 +86,14 @@ public interface HashJoinMemoryCalculator extends 
HashJoinStateCalculator<HashJo
   /**
    * The interface representing the {@link HashJoinStateCalculator} 
corresponding to the
    * {@link HashJoinState#BUILD_SIDE_PARTITIONING} state.
+   *
+   * <h4>Invariants</h4>
+   * <ul>
+   *   <li>
+   *     This calculator will only be used when there is build side data. If 
there is no build side data, the caller
+   *     should not invoke this calculator.
+   *   </li>
+   * </ul>
    */
   interface BuildSidePartitioning extends 
HashJoinStateCalculator<PostBuildCalculations> {
     void initialize(boolean autoTune,
@@ -93,13 +101,14 @@ public interface HashJoinMemoryCalculator extends 
HashJoinStateCalculator<HashJo
                     RecordBatch buildSideBatch,
                     RecordBatch probeSideBatch,
                     Set<String> joinColumns,
+                    boolean probeEmpty,
                     long memoryAvailable,
+                    long maxIncomingBatchSize,
                     int initialPartitions,
                     int recordsPerPartitionBatchBuild,
                     int recordsPerPartitionBatchProbe,
                     int maxBatchNumRecordsBuild,
                     int maxBatchNumRecordsProbe,
-                    int outputBatchNumRecords,
                     int outputBatchSize,
                     double loadFactor);
 
@@ -121,7 +130,13 @@ public interface HashJoinMemoryCalculator extends 
HashJoinStateCalculator<HashJo
    * {@link HashJoinState#POST_BUILD_CALCULATIONS} state.
    */
   interface PostBuildCalculations extends 
HashJoinStateCalculator<HashJoinMemoryCalculator> {
-    void initialize();
+    /**
+     * Initializes the calculator with additional information needed.
+     * @param probeEmty True if the probe is empty. False otherwise.
+     */
+    void initialize(boolean probeEmty);
+
+    int getProbeRecordsPerBatch();
 
     boolean shouldSpill();
 
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java
index 37f3329..a351cbc 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java
@@ -73,7 +73,9 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
         throw new IllegalArgumentException("Invalid calc type: " + 
hashTableCalculatorType);
       }
 
-      return new BuildSidePartitioningImpl(hashTableSizeCalculator,
+      return new BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        hashTableSizeCalculator,
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
         fragmentationFactor, safetyFactor);
     } else {
@@ -86,65 +88,28 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     return INITIALIZING;
   }
 
-  public static long computeMaxBatchSizeNoHash(final long incomingBatchSize,
-                                         final int incomingNumRecords,
-                                         final int desiredNumRecords,
-                                         final double fragmentationFactor,
-                                         final double safetyFactor) {
-    long maxBatchSize = HashJoinMemoryCalculatorImpl
-      .computePartitionBatchSize(incomingBatchSize, incomingNumRecords, 
desiredNumRecords);
-    // Multiple by fragmentation factor
-    return RecordBatchSizer.multiplyByFactors(maxBatchSize, 
fragmentationFactor, safetyFactor);
-  }
-
-  public static long computeMaxBatchSize(final long incomingBatchSize,
-                                         final int incomingNumRecords,
-                                         final int desiredNumRecords,
-                                         final double fragmentationFactor,
-                                         final double safetyFactor,
-                                         final boolean reserveHash) {
-    long size = computeMaxBatchSizeNoHash(incomingBatchSize,
-      incomingNumRecords,
-      desiredNumRecords,
-      fragmentationFactor,
-      safetyFactor);
-
-    if (!reserveHash) {
-      return size;
-    }
-
-    long hashSize = desiredNumRecords * ((long) IntVector.VALUE_WIDTH);
-    hashSize = RecordBatchSizer.multiplyByFactors(hashSize, 
fragmentationFactor);
-
-    return size + hashSize;
-  }
-
-  public static long computePartitionBatchSize(final long incomingBatchSize,
-                                               final int incomingNumRecords,
-                                               final int desiredNumRecords) {
-    return (long) Math.ceil((((double) incomingBatchSize) /
-      ((double) incomingNumRecords)) *
-      ((double) desiredNumRecords));
-  }
-
   public static class NoopBuildSidePartitioningImpl implements 
BuildSidePartitioning {
     private int initialPartitions;
+    private int recordsPerPartitionBatchProbe;
 
     @Override
     public void initialize(boolean autoTune,
                            boolean reserveHash,
                            RecordBatch buildSideBatch,
-                           RecordBatch probeSideBatch, Set<String> joinColumns,
+                           RecordBatch probeSideBatch,
+                           Set<String> joinColumns,
+                           boolean probeEmpty,
                            long memoryAvailable,
+                           long maxIncomingBatchSize,
                            int initialPartitions,
                            int recordsPerPartitionBatchBuild,
                            int recordsPerPartitionBatchProbe,
                            int maxBatchNumRecordsBuild,
                            int maxBatchNumRecordsProbe,
-                           int outputBatchNumRecords,
                            int outputBatchSize,
                            double loadFactor) {
       this.initialPartitions = initialPartitions;
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
     }
 
     @Override
@@ -180,7 +145,7 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     @Nullable
     @Override
     public PostBuildCalculations next() {
-      return new NoopPostBuildCalculationsImpl();
+      return new NoopPostBuildCalculationsImpl(recordsPerPartitionBatchProbe);
     }
 
     @Override
@@ -204,7 +169,7 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
    * <h1>Life Cycle</h1>
    * <p>
    *   <ul>
-   *     <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, 
RecordBatch, RecordBatch, Set, long, int, int, int, int, int, int, int, 
double)}.
+   *     <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, 
RecordBatch, RecordBatch, Set, boolean, long, long, int, int, int, int, int, 
int, double)}.
    *     This will initialize the StateCalculate with the additional 
information it needs.</li>
    *     <li><b>Step 1:</b> Call {@link #getNumPartitions()} to see the number 
of partitions that fit in memory.</li>
    *     <li><b>Step 2:</b> Call {@link #shouldSpill()} To determine if 
spilling needs to occurr.</li>
@@ -215,6 +180,7 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
   public static class BuildSidePartitioningImpl implements 
BuildSidePartitioning {
     public static final Logger log = 
LoggerFactory.getLogger(BuildSidePartitioning.class);
 
+    private final BatchSizePredictor.Factory batchSizePredictorFactory;
     private final HashTableSizeCalculator hashTableSizeCalculator;
     private final HashJoinHelperSizeCalculator hashJoinHelperSizeCalculator;
     private final double fragmentationFactor;
@@ -223,10 +189,8 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     private int maxBatchNumRecordsBuild;
     private int maxBatchNumRecordsProbe;
     private long memoryAvailable;
-    private long buildBatchSize;
-    private long probeBatchSize;
-    private int buildNumRecords;
-    private int probeNumRecords;
+    private boolean probeEmpty;
+    private long maxIncomingBatchSize;
     private long maxBuildBatchSize;
     private long maxProbeBatchSize;
     private long maxOutputBatchSize;
@@ -246,13 +210,17 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     private long reservedMemory;
     private long maxReservedMemory;
 
+    private BatchSizePredictor buildSizePredictor;
+    private BatchSizePredictor probeSizePredictor;
     private boolean firstInitialized;
     private boolean initialized;
 
-    public BuildSidePartitioningImpl(final HashTableSizeCalculator 
hashTableSizeCalculator,
+    public BuildSidePartitioningImpl(final BatchSizePredictor.Factory 
batchSizePredictorFactory,
+                                     final HashTableSizeCalculator 
hashTableSizeCalculator,
                                      final HashJoinHelperSizeCalculator 
hashJoinHelperSizeCalculator,
                                      final double fragmentationFactor,
                                      final double safetyFactor) {
+      this.batchSizePredictorFactory = 
Preconditions.checkNotNull(batchSizePredictorFactory);
       this.hashTableSizeCalculator = 
Preconditions.checkNotNull(hashTableSizeCalculator);
       this.hashJoinHelperSizeCalculator = 
Preconditions.checkNotNull(hashJoinHelperSizeCalculator);
       this.fragmentationFactor = fragmentationFactor;
@@ -262,35 +230,33 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     @Override
     public void initialize(boolean autoTune,
                            boolean reserveHash,
-                           RecordBatch buildSideBatch,
-                           RecordBatch probeSideBatch,
+                           RecordBatch buildBatch,
+                           RecordBatch probeBatch,
                            Set<String> joinColumns,
+                           boolean probeEmpty,
                            long memoryAvailable,
+                           long maxIncomingBatchSize,
                            int initialPartitions,
                            int recordsPerPartitionBatchBuild,
                            int recordsPerPartitionBatchProbe,
                            int maxBatchNumRecordsBuild,
                            int maxBatchNumRecordsProbe,
-                           int outputBatchNumRecords,
                            int outputBatchSize,
                            double loadFactor) {
-      Preconditions.checkNotNull(buildSideBatch);
-      Preconditions.checkNotNull(probeSideBatch);
+      Preconditions.checkNotNull(probeBatch);
+      Preconditions.checkNotNull(buildBatch);
       Preconditions.checkNotNull(joinColumns);
 
-      final RecordBatchSizer buildSizer = new RecordBatchSizer(buildSideBatch);
-      final RecordBatchSizer probeSizer = new RecordBatchSizer(probeSideBatch);
+      final BatchSizePredictor buildSizePredictor =
+        batchSizePredictorFactory.create(buildBatch, fragmentationFactor, 
safetyFactor);
+      final BatchSizePredictor probeSizePredictor =
+        batchSizePredictorFactory.create(probeBatch, fragmentationFactor, 
safetyFactor);
 
-      long buildBatchSize = getBatchSizeEstimate(buildSideBatch);
-      long probeBatchSize = getBatchSizeEstimate(probeSideBatch);
+      buildSizePredictor.updateStats();
+      probeSizePredictor.updateStats();
 
-      int buildNumRecords = buildSizer.rowCount();
-      int probeNumRecords = probeSizer.rowCount();
+      final RecordBatchSizer buildSizer = new RecordBatchSizer(buildBatch);
 
-      final CaseInsensitiveMap<Long> buildValueSizes = 
getNotExcludedColumnSizes(
-        joinColumns, buildSizer);
-      final CaseInsensitiveMap<Long> probeValueSizes = 
getNotExcludedColumnSizes(
-        joinColumns, probeSizer);
       final CaseInsensitiveMap<Long> keySizes = 
CaseInsensitiveMap.newHashMap();
 
       for (String joinColumn: joinColumns) {
@@ -302,11 +268,11 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
         reserveHash,
         keySizes,
         memoryAvailable,
+        maxIncomingBatchSize,
         initialPartitions,
-        buildBatchSize,
-        probeBatchSize,
-        buildNumRecords,
-        probeNumRecords,
+        probeEmpty,
+        buildSizePredictor,
+        probeSizePredictor,
         recordsPerPartitionBatchBuild,
         recordsPerPartitionBatchProbe,
         maxBatchNumRecordsBuild,
@@ -316,47 +282,15 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     }
 
     @VisibleForTesting
-    protected static CaseInsensitiveMap<Long> getNotExcludedColumnSizes(
-        final Set<String> excludedColumns,
-        final RecordBatchSizer batchSizer) {
-      final CaseInsensitiveMap<Long> columnSizes = 
CaseInsensitiveMap.newHashMap();
-      final CaseInsensitiveMap<Boolean> excludedSet = 
CaseInsensitiveMap.newHashMap();
-
-      for (final String excludedColumn: excludedColumns) {
-        excludedSet.put(excludedColumn, true);
-      }
-
-      for (final Map.Entry<String, RecordBatchSizer.ColumnSize> entry: 
batchSizer.columns().entrySet()) {
-        final String columnName = entry.getKey();
-        final RecordBatchSizer.ColumnSize columnSize = entry.getValue();
-
-        columnSizes.put(columnName, (long) 
columnSize.getStdNetOrNetSizePerEntry());
-      }
-
-      return columnSizes;
-    }
-
-    public static long getBatchSizeEstimate(final RecordBatch recordBatch) {
-      final RecordBatchSizer sizer = new RecordBatchSizer(recordBatch);
-      long size = 0L;
-
-      for (Map.Entry<String, RecordBatchSizer.ColumnSize> column: 
sizer.columns().entrySet()) {
-        size += 
PostBuildCalculationsImpl.computeValueVectorSize(recordBatch.getRecordCount(), 
column.getValue().getStdNetOrNetSizePerEntry());
-      }
-
-      return size;
-    }
-
-    @VisibleForTesting
     protected void initialize(boolean autoTune,
                               boolean reserveHash,
                               CaseInsensitiveMap<Long> keySizes,
                               long memoryAvailable,
+                              long maxIncomingBatchSize,
                               int initialPartitions,
-                              long buildBatchSize,
-                              long probeBatchSize,
-                              int buildNumRecords,
-                              int probeNumRecords,
+                              boolean probeEmpty,
+                              BatchSizePredictor buildSizePredictor,
+                              BatchSizePredictor probeSizePredictor,
                               int recordsPerPartitionBatchBuild,
                               int recordsPerPartitionBatchProbe,
                               int maxBatchNumRecordsBuild,
@@ -365,6 +299,9 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
                               double loadFactor) {
       Preconditions.checkState(!firstInitialized);
       Preconditions.checkArgument(initialPartitions >= 1);
+      // If we had probe data before, there should still be probe data now.
+      // If we didn't have probe data before, we could get some new data now.
+      Preconditions.checkState(!(probeEmpty && 
probeSizePredictor.hadDataLastTime()));
       firstInitialized = true;
 
       this.loadFactor = loadFactor;
@@ -372,10 +309,10 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
       this.reserveHash = reserveHash;
       this.keySizes = Preconditions.checkNotNull(keySizes);
       this.memoryAvailable = memoryAvailable;
-      this.buildBatchSize = buildBatchSize;
-      this.probeBatchSize = probeBatchSize;
-      this.buildNumRecords = buildNumRecords;
-      this.probeNumRecords = probeNumRecords;
+      this.probeEmpty = probeEmpty;
+      this.maxIncomingBatchSize = maxIncomingBatchSize;
+      this.buildSizePredictor = buildSizePredictor;
+      this.probeSizePredictor = probeSizePredictor;
       this.initialPartitions = initialPartitions;
       this.recordsPerPartitionBatchBuild = recordsPerPartitionBatchBuild;
       this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
@@ -420,31 +357,32 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     private void calculateMemoryUsage()
     {
       // Adjust based on number of records
-      maxBuildBatchSize = computeMaxBatchSizeNoHash(buildBatchSize, 
buildNumRecords,
-        maxBatchNumRecordsBuild, fragmentationFactor, safetyFactor);
-      maxProbeBatchSize = computeMaxBatchSizeNoHash(probeBatchSize, 
probeNumRecords,
-        maxBatchNumRecordsProbe, fragmentationFactor, safetyFactor);
-
-      // Safety factor can be multiplied at the end since these batches are 
coming from exchange operators, so no excess value vector doubling
-      partitionBuildBatchSize = computeMaxBatchSize(buildBatchSize,
-        buildNumRecords,
-        recordsPerPartitionBatchBuild,
-        fragmentationFactor,
-        safetyFactor,
-        reserveHash);
+      maxBuildBatchSize = 
buildSizePredictor.predictBatchSize(maxBatchNumRecordsBuild, false);
 
-      // Safety factor can be multiplied at the end since these batches are 
coming from exchange operators, so no excess value vector doubling
-      partitionProbeBatchSize = computeMaxBatchSize(
-        probeBatchSize,
-        probeNumRecords,
-        recordsPerPartitionBatchProbe,
-        fragmentationFactor,
-        safetyFactor,
-        reserveHash);
+      if (probeSizePredictor.hadDataLastTime()) {
+        // We have probe data and we can compute the max incoming size.
+        maxProbeBatchSize = 
probeSizePredictor.predictBatchSize(maxBatchNumRecordsProbe, false);
+      } else {
+        // We don't have probe data
+        if (probeEmpty) {
+          // We know the probe has no data, so we don't need to reserve any 
space for the incoming probe
+          maxProbeBatchSize = 0;
+        } else {
+          // The probe side may have data, so assume its batches are as large as the max incoming batch size.
+          // This assumption can fail in some cases since the batch sizing project is incomplete.
+          maxProbeBatchSize = maxIncomingBatchSize;
+        }
+      }
+
+      partitionBuildBatchSize = 
buildSizePredictor.predictBatchSize(recordsPerPartitionBatchBuild, reserveHash);
+
+      if (probeSizePredictor.hadDataLastTime()) {
+        partitionProbeBatchSize = 
probeSizePredictor.predictBatchSize(recordsPerPartitionBatchProbe, reserveHash);
+      }
 
       maxOutputBatchSize = (long) ((double)outputBatchSize * 
fragmentationFactor * safetyFactor);
 
-      long probeReservedMemory;
+      long probeReservedMemory = 0;
 
       for (partitions = initialPartitions;; partitions /= 2) {
         // The total amount of memory to reserve for incomplete batches across 
all partitions
@@ -455,13 +393,19 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
         // they will have a well defined size.
         reservedMemory = incompletePartitionsBatchSizes + maxBuildBatchSize + 
maxProbeBatchSize;
 
-        probeReservedMemory = 
PostBuildCalculationsImpl.calculateReservedMemory(
-          partitions,
-          maxProbeBatchSize,
-          maxOutputBatchSize,
-          partitionProbeBatchSize);
+        if (probeSizePredictor.hadDataLastTime()) {
+          // If we have probe data, use it in our memory reservation 
calculations.
+          probeReservedMemory = 
PostBuildCalculationsImpl.calculateReservedMemory(
+            partitions,
+            maxProbeBatchSize,
+            maxOutputBatchSize,
+            partitionProbeBatchSize);
 
-        maxReservedMemory = Math.max(reservedMemory, probeReservedMemory);
+          maxReservedMemory = Math.max(reservedMemory, probeReservedMemory);
+        } else {
+          // If we do not have probe data, make a best-effort estimate of the number of partitions without it.
+          maxReservedMemory = reservedMemory;
+        }
 
         if (!autoTune || maxReservedMemory <= memoryAvailable) {
           // Stop the tuning loop if we are not doing auto tuning, or if we 
are living within our memory limit
@@ -488,19 +432,19 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
           "partitionProbeBatchSize = %d\n" +
           "recordsPerPartitionBatchProbe = %d\n",
           reservedMemory, memoryAvailable, partitions, initialPartitions,
-          buildBatchSize,
-          buildNumRecords,
+          buildSizePredictor.getBatchSize(),
+          buildSizePredictor.getNumRecords(),
           partitionBuildBatchSize,
           recordsPerPartitionBatchBuild,
-          probeBatchSize,
-          probeNumRecords,
+          probeSizePredictor.getBatchSize(),
+          probeSizePredictor.getNumRecords(),
           partitionProbeBatchSize,
           recordsPerPartitionBatchProbe);
 
         String phase = "Probe phase: ";
 
         if (reservedMemory > memoryAvailable) {
-          if (probeReservedMemory > memoryAvailable) {
+          if (probeSizePredictor.hadDataLastTime() && probeReservedMemory > 
memoryAvailable) {
             phase = "Build and Probe phases: ";
           } else {
             phase = "Build phase: ";
@@ -531,10 +475,12 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     public PostBuildCalculations next() {
       Preconditions.checkState(initialized);
 
-      return new PostBuildCalculationsImpl(memoryAvailable,
-        partitionProbeBatchSize,
-        maxProbeBatchSize,
+      return new PostBuildCalculationsImpl(
+        probeSizePredictor,
+        memoryAvailable,
         maxOutputBatchSize,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         partitionStatsSet,
         keySizes,
         hashTableSizeCalculator,
@@ -572,9 +518,19 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
   }
 
   public static class NoopPostBuildCalculationsImpl implements 
PostBuildCalculations {
+    private final int recordsPerPartitionBatchProbe;
+
+    public NoopPostBuildCalculationsImpl(final int 
recordsPerPartitionBatchProbe) {
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
+    }
+
     @Override
-    public void initialize() {
+    public void initialize(boolean probeEmpty) {
+    }
 
+    @Override
+    public int getProbeRecordsPerBatch() {
+      return recordsPerPartitionBatchProbe;
     }
 
     @Override
@@ -610,7 +566,7 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
    * <h1>Lifecycle</h1>
    * <p>
    *   <ul>
-   *     <li><b>Step 1:</b> Call {@link #initialize()}. This
+   *     <li><b>Step 1:</b> Call {@link #initialize(boolean)}. This
    *     gives the {@link HashJoinStateCalculator} additional information it 
needs to compute memory requirements.</li>
    *     <li><b>Step 2:</b> Call {@link #shouldSpill()}. This tells
    *     you which build side partitions need to be spilled in order to make 
room for probing.</li>
@@ -620,10 +576,15 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
    * </p>
    */
   public static class PostBuildCalculationsImpl implements 
PostBuildCalculations {
+    private static final Logger log = 
LoggerFactory.getLogger(PostBuildCalculationsImpl.class);
+
+    public static final int MIN_RECORDS_PER_PARTITION_BATCH_PROBE = 10;
+
+    private final BatchSizePredictor probeSizePredictor;
     private final long memoryAvailable;
-    private final long partitionProbeBatchSize;
-    private final long maxProbeBatchSize;
     private final long maxOutputBatchSize;
+    private final int maxBatchNumRecordsProbe;
+    private final int recordsPerPartitionBatchProbe;
     private final PartitionStatSet buildPartitionStatSet;
     private final Map<String, Long> keySizes;
     private final HashTableSizeCalculator hashTableSizeCalculator;
@@ -632,26 +593,30 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     private final double safetyFactor;
     private final double loadFactor;
     private final boolean reserveHash;
-    // private final long maxOutputBatchSize;
 
     private boolean initialized;
     private long consumedMemory;
+    private boolean probeEmpty;
+    private long maxProbeBatchSize;
+    private long partitionProbeBatchSize;
+    private int computedProbeRecordsPerBatch;
 
-    public PostBuildCalculationsImpl(final long memoryAvailable,
-                                     final long partitionProbeBatchSize,
-                                     final long maxProbeBatchSize,
-                                     final long maxOutputBatchSize,
-                                     final PartitionStatSet 
buildPartitionStatSet,
-                                     final Map<String, Long> keySizes,
-                                     final HashTableSizeCalculator 
hashTableSizeCalculator,
-                                     final HashJoinHelperSizeCalculator 
hashJoinHelperSizeCalculator,
-                                     final double fragmentationFactor,
-                                     final double safetyFactor,
-                                     final double loadFactor,
-                                     final boolean reserveHash) {
+    @VisibleForTesting
+    public PostBuildCalculationsImpl(final BatchSizePredictor 
probeSizePredictor,
+                                      final long memoryAvailable,
+                                      final long maxOutputBatchSize,
+                                      final int maxBatchNumRecordsProbe,
+                                      final int recordsPerPartitionBatchProbe,
+                                      final PartitionStatSet 
buildPartitionStatSet,
+                                      final Map<String, Long> keySizes,
+                                      final HashTableSizeCalculator 
hashTableSizeCalculator,
+                                      final HashJoinHelperSizeCalculator 
hashJoinHelperSizeCalculator,
+                                      final double fragmentationFactor,
+                                      final double safetyFactor,
+                                      final double loadFactor,
+                                      final boolean reserveHash) {
+      this.probeSizePredictor = Preconditions.checkNotNull(probeSizePredictor);
       this.memoryAvailable = memoryAvailable;
-      this.partitionProbeBatchSize = partitionProbeBatchSize;
-      this.maxProbeBatchSize = maxProbeBatchSize;
       this.maxOutputBatchSize = maxOutputBatchSize;
       this.buildPartitionStatSet = 
Preconditions.checkNotNull(buildPartitionStatSet);
       this.keySizes = Preconditions.checkNotNull(keySizes);
@@ -661,38 +626,100 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
       this.safetyFactor = safetyFactor;
       this.loadFactor = loadFactor;
       this.reserveHash = reserveHash;
+      this.maxBatchNumRecordsProbe = maxBatchNumRecordsProbe;
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
+      this.computedProbeRecordsPerBatch = recordsPerPartitionBatchProbe;
     }
 
-    // TODO take an incoming Probe RecordBatch
     @Override
-    public void initialize() {
+    public void initialize(boolean probeEmpty) {
       Preconditions.checkState(!initialized);
+      // If we had probe data before, there should still be probe data now.
+      // If we didn't have probe data before, we could get some new data now.
+      Preconditions.checkState(probeSizePredictor.hadDataLastTime() && 
!probeEmpty || !probeSizePredictor.hadDataLastTime());
       initialized = true;
+      this.probeEmpty = probeEmpty;
+
+      if (probeEmpty) {
+        // We know there is no probe side data, so we don't need to calculate 
anything.
+        return;
+      }
+
+      // We need to compute the sizes of the probe side data.
+      if (!probeSizePredictor.hadDataLastTime()) {
+        probeSizePredictor.updateStats();
+      }
+
+      maxProbeBatchSize = 
probeSizePredictor.predictBatchSize(maxBatchNumRecordsProbe, false);
+      partitionProbeBatchSize = 
probeSizePredictor.predictBatchSize(recordsPerPartitionBatchProbe, reserveHash);
+
+      long worstCaseProbeMemory = calculateReservedMemory(
+        buildPartitionStatSet.getSize(),
+        maxProbeBatchSize,
+        maxOutputBatchSize,
+        partitionProbeBatchSize);
+
+      if (worstCaseProbeMemory > memoryAvailable) {
+        // We don't have enough memory for the probe data if all the partitions are spilled,
+        // so we need to adjust the records per probe partition batch to make this work.
+
+        computedProbeRecordsPerBatch = 
computeProbeRecordsPerBatch(memoryAvailable,
+          buildPartitionStatSet.getSize(),
+          recordsPerPartitionBatchProbe,
+          MIN_RECORDS_PER_PARTITION_BATCH_PROBE,
+          maxProbeBatchSize,
+          maxOutputBatchSize,
+          partitionProbeBatchSize);
+
+        partitionProbeBatchSize = 
probeSizePredictor.predictBatchSize(computedProbeRecordsPerBatch, reserveHash);
+      }
     }
 
-    public long getConsumedMemory() {
+    @Override
+    public int getProbeRecordsPerBatch() {
       Preconditions.checkState(initialized);
-      return consumedMemory;
+      return computedProbeRecordsPerBatch;
     }
 
-    // TODO move this somewhere else that makes sense
-    public static long computeValueVectorSize(long numRecords, long byteSize)
-    {
-      long naiveSize = numRecords * byteSize;
-      return roundUpToPowerOf2(naiveSize);
+    @VisibleForTesting
+    public long getMaxProbeBatchSize() {
+      return maxProbeBatchSize;
     }
 
-    public static long computeValueVectorSize(long numRecords, long byteSize, 
double safetyFactor)
-    {
-      long naiveSize = RecordBatchSizer.multiplyByFactor(numRecords * 
byteSize, safetyFactor);
-      return roundUpToPowerOf2(naiveSize);
+    @VisibleForTesting
+    public long getPartitionProbeBatchSize() {
+      return partitionProbeBatchSize;
     }
 
-    // TODO move to drill common
-    public static long roundUpToPowerOf2(long num)
-    {
-      Preconditions.checkArgument(num >= 1);
-      return num == 1 ? 1 : Long.highestOneBit(num - 1) << 1;
+    public long getConsumedMemory() {
+      Preconditions.checkState(initialized);
+      return consumedMemory;
+    }
+
+    public static int computeProbeRecordsPerBatch(final long memoryAvailable,
+                                                  final int numPartitions,
+                                                  final int 
defaultProbeRecordsPerBatch,
+                                                  final int 
minProbeRecordsPerBatch,
+                                                  final long maxProbeBatchSize,
+                                                  final long 
maxOutputBatchSize,
+                                                  final long 
defaultPartitionProbeBatchSize) {
+      long memoryForPartitionBatches = memoryAvailable - maxProbeBatchSize - 
maxOutputBatchSize;
+
+      if (memoryForPartitionBatches < 0) {
+        // We just don't have enough memory. Do our best by falling back to the minimum batch size.
+        log.warn("Not enough memory for probing:\n" +
+          "Memory available: {}\n" +
+          "Max probe batch size: {}\n" +
+          "Max output batch size: {}",
+          memoryAvailable,
+          maxProbeBatchSize,
+          maxOutputBatchSize);
+        return minProbeRecordsPerBatch;
+      }
+
+      long memoryForPartitionBatch = (memoryForPartitionBatches + 
numPartitions - 1) / numPartitions;
+      long scaleFactor = (defaultPartitionProbeBatchSize + 
memoryForPartitionBatch - 1) / memoryForPartitionBatch;
+      return Math.max((int) (defaultProbeRecordsPerBatch / scaleFactor), 
minProbeRecordsPerBatch);
     }
 
     public static long calculateReservedMemory(final int numSpilledPartitions,
@@ -710,6 +737,11 @@ public class HashJoinMemoryCalculatorImpl implements 
HashJoinMemoryCalculator {
     public boolean shouldSpill() {
       Preconditions.checkState(initialized);
 
+      if (probeEmpty) {
+        // If the probe is empty, we should not trigger any spills.
+        return false;
+      }
+
       long reservedMemory = calculateReservedMemory(
         buildPartitionStatSet.getNumSpilledPartitions(),
         maxProbeBatchSize,
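
A worked sketch of computeProbeRecordsPerBatch above, using the inputs of the
testComputedProbeRecordsPerBatch case added later in this patch (memoryAvailable = 200,
numPartitions = 2, defaultProbeRecordsPerBatch = 100, minProbeRecordsPerBatch = 10,
maxProbeBatchSize = 50, maxOutputBatchSize = 50, defaultPartitionProbeBatchSize = 200):

    memoryForPartitionBatches    = 200 - 50 - 50     = 100
    memoryForPartitionBatch      = ceil(100 / 2)     = 50
    scaleFactor                  = ceil(200 / 50)    = 4
    computedProbeRecordsPerBatch = max(100 / 4, 10)  = 25
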
diff --git 
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
 
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
index 8575021..a366eea 100644
--- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
+++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
@@ -23,7 +23,7 @@ import org.apache.drill.exec.vector.IntVector;
 
 import java.util.Map;
 
-import static 
org.apache.drill.exec.physical.impl.join.HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize;
+import static 
org.apache.drill.exec.physical.impl.join.BatchSizePredictorImpl.computeValueVectorSize;
 
 public class HashTableSizeCalculatorConservativeImpl implements 
HashTableSizeCalculator {
   public static final String TYPE = "CONSERVATIVE";
diff --git 
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
 
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
index 4f9e585..265b0e3 100644
--- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
+++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
@@ -23,7 +23,7 @@ import org.apache.drill.exec.vector.IntVector;
 
 import java.util.Map;
 
-import static 
org.apache.drill.exec.physical.impl.join.HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize;
+import static 
org.apache.drill.exec.physical.impl.join.BatchSizePredictorImpl.computeValueVectorSize;
 
 public class HashTableSizeCalculatorLeanImpl implements 
HashTableSizeCalculator {
   public static final String TYPE = "LEAN";
diff --git 
a/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
 
b/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
index d75463b..e7fa4e6 100644
--- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
+++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
@@ -76,12 +76,12 @@ public abstract class AbstractBinaryRecordBatch<T extends 
PhysicalOperator> exte
   }
 
   protected boolean verifyOutcomeToSetBatchState(IterOutcome leftOutcome, 
IterOutcome rightOutcome) {
-    if (leftOutcome == IterOutcome.STOP || rightUpstream == IterOutcome.STOP) {
+    if (leftOutcome == IterOutcome.STOP || rightOutcome == IterOutcome.STOP) {
       state = BatchState.STOP;
       return false;
     }
 
-    if (leftOutcome == IterOutcome.OUT_OF_MEMORY || rightUpstream == 
IterOutcome.OUT_OF_MEMORY) {
+    if (leftOutcome == IterOutcome.OUT_OF_MEMORY || rightOutcome == 
IterOutcome.OUT_OF_MEMORY) {
       state = BatchState.OUT_OF_MEMORY;
       return false;
     }
@@ -97,6 +97,7 @@ public abstract class AbstractBinaryRecordBatch<T extends 
PhysicalOperator> exte
       throw new IllegalStateException("Unexpected IterOutcome.EMIT received 
either from left or right side in " +
         "buildSchema phase");
     }
+
     return true;
   }
 
diff --git 
a/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java 
b/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java
index 6954374..f0cab26 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java
@@ -116,7 +116,7 @@ public interface RecordBatch extends VectorAccessible {
      *   returned at least once (not necessarily <em>immediately</em> after).
      * </p>
      */
-    NONE,
+    NONE(false),
 
     /**
      * Zero or more records with same schema.
@@ -134,7 +134,7 @@ public interface RecordBatch extends VectorAccessible {
      *   returned at least once (not necessarily <em>immediately</em> after).
      * </p>
      */
-    OK,
+    OK(false),
 
     /**
      * New schema, maybe with records.
@@ -147,7 +147,7 @@ public interface RecordBatch extends VectorAccessible {
      *     ({@code next()} should be called again.)
      * </p>
      */
-    OK_NEW_SCHEMA,
+    OK_NEW_SCHEMA(false),
 
     /**
      * Non-completion (abnormal) termination.
@@ -162,7 +162,7 @@ public interface RecordBatch extends VectorAccessible {
      *   of things.
      * </p>
      */
-    STOP,
+    STOP(true),
 
     /**
      * No data yet.
@@ -184,7 +184,7 @@ public interface RecordBatch extends VectorAccessible {
      *   Used by batches that haven't received incoming data yet.
      * </p>
      */
-    NOT_YET,
+    NOT_YET(false),
 
     /**
      * Out of memory (not fatal).
@@ -198,7 +198,7 @@ public interface RecordBatch extends VectorAccessible {
      *     {@code OUT_OF_MEMORY} to its caller) and call {@code next()} again.
      * </p>
      */
-    OUT_OF_MEMORY,
+    OUT_OF_MEMORY(true),
 
     /**
      * Emit record to produce output batches.
@@ -223,7 +223,17 @@ public interface RecordBatch extends VectorAccessible {
      *   input and again start from build side.
      * </p>
      */
-    EMIT,
+    EMIT(false);
+
+    private boolean error;
+
+    IterOutcome(boolean error) {
+      this.error = error;
+    }
+
+    public boolean isError() {
+      return error;
+    }
   }
 
   /**
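
For illustration only (not part of this patch), the new flag lets a caller collapse the
usual STOP / OUT_OF_MEMORY checks into a single test; the names below (incoming,
handleFailure) are placeholders:

    IterOutcome outcome = incoming.next();
    if (outcome.isError()) {           // true only for STOP and OUT_OF_MEMORY
      return handleFailure(outcome);   // hypothetical error path
    }
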
diff --git 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBatchSizePredictorImpl.java
similarity index 57%
copy from 
exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
copy to 
exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBatchSizePredictorImpl.java
index 4fe1fa4..e16cdf6 100644
--- 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
+++ 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBatchSizePredictorImpl.java
@@ -21,62 +21,81 @@ import org.apache.drill.exec.vector.IntVector;
 import org.junit.Assert;
 import org.junit.Test;
 
-public class TestHashJoinMemoryCalculatorImpl {
+public class TestBatchSizePredictorImpl {
   @Test
-  public void testComputeMaxBatchSizeNoHash() {
-    final long expected = 1200;
-    final long actual = HashJoinMemoryCalculatorImpl.computeMaxBatchSize(
+  public void testComputeMaxBatchSizeHash()
+  {
+    long expected = BatchSizePredictorImpl.computeMaxBatchSizeNoHash(
       100,
       25,
       100,
       2.0,
-      1.5,
-      false);
-    final long actualNoHash = 
HashJoinMemoryCalculatorImpl.computeMaxBatchSizeNoHash(
+      4.0) +
+      100 * IntVector.VALUE_WIDTH * 2;
+
+    final long actual = BatchSizePredictorImpl.computeMaxBatchSize(
       100,
       25,
       100,
       2.0,
-      1.5);
+      4.0,
+      true);
 
     Assert.assertEquals(expected, actual);
-    Assert.assertEquals(expected, actualNoHash);
   }
 
   @Test
-  public void testComputeMaxBatchSizeHash()
-  {
-    long expected = HashJoinMemoryCalculatorImpl.computeMaxBatchSizeNoHash(
+  public void testComputeMaxBatchSizeNoHash() {
+    final long expected = 1200;
+    final long actual = BatchSizePredictorImpl.computeMaxBatchSize(
       100,
       25,
       100,
       2.0,
-      4.0) +
-      100 * IntVector.VALUE_WIDTH * 2;
-
-    final long actual = HashJoinMemoryCalculatorImpl.computeMaxBatchSize(
+      1.5,
+      false);
+    final long actualNoHash = BatchSizePredictorImpl.computeMaxBatchSizeNoHash(
       100,
       25,
       100,
       2.0,
-      4.0,
-      true);
+      1.5);
 
     Assert.assertEquals(expected, actual);
+    Assert.assertEquals(expected, actualNoHash);
   }
 
-  @Test // Make sure no exception is thrown
-  public void testMakeDebugString()
-  {
-    final PartitionStatImpl partitionStat1 = new PartitionStatImpl();
-    final PartitionStatImpl partitionStat2 = new PartitionStatImpl();
-    final PartitionStatImpl partitionStat3 = new PartitionStatImpl();
-    final PartitionStatImpl partitionStat4 = new PartitionStatImpl();
+  @Test
+  public void testRoundUpPowerOf2() {
+    long expected = 32;
+    long actual = BatchSizePredictorImpl.roundUpToPowerOf2(expected);
+
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testRounUpNonPowerOf2ToPowerOf2() {
+    long expected = 32;
+    long actual = BatchSizePredictorImpl.roundUpToPowerOf2(31);
 
-    final HashJoinMemoryCalculator.PartitionStatSet partitionStatSet =
-      new HashJoinMemoryCalculator.PartitionStatSet(partitionStat1, 
partitionStat2, partitionStat3, partitionStat4);
-    partitionStat1.add(new HashJoinMemoryCalculator.BatchStat(10, 7));
-    partitionStat2.add(new HashJoinMemoryCalculator.BatchStat(11, 20));
-    partitionStat3.spill();
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testComputeValueVectorSizePowerOf2() {
+    long expected = 4;
+    long actual =
+      BatchSizePredictorImpl.computeValueVectorSize(2, 2);
+
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testComputeValueVectorSizeNonPowerOf2() {
+    long expected = 16;
+    long actual =
+      BatchSizePredictorImpl.computeValueVectorSize(3, 3);
+
+    Assert.assertEquals(expected, actual);
   }
 }
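
The rounding behaviour these tests pin down can be read directly off the values used
(computeValueVectorSize(numRecords, byteSize) rounds numRecords * byteSize up to the
next power of two, per BatchSizePredictorImpl in this patch):

    roundUpToPowerOf2(32)        = 32    already a power of two
    roundUpToPowerOf2(31)        = 32    Long.highestOneBit(30) << 1
    computeValueVectorSize(2, 2) = 4     2 * 2 = 4, already a power of two
    computeValueVectorSize(3, 3) = 16    3 * 3 = 9, rounded up to 16
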
diff --git 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
index 2a44edb..ceebc81 100644
--- 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
+++ 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
@@ -17,6 +17,7 @@
  */
 package org.apache.drill.exec.physical.impl.join;
 
+import com.google.common.base.Preconditions;
 import org.apache.drill.common.map.CaseInsensitiveMap;
 import org.apache.drill.exec.record.RecordBatch;
 import org.junit.Assert;
@@ -26,26 +27,28 @@ public class TestBuildSidePartitioningImpl {
   @Test
   public void testSimpleReserveMemoryCalculationNoHash() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = 
CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = 
CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(true,
       false,
       keySizes,
       200,
+      100,
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -69,26 +72,28 @@ public class TestBuildSidePartitioningImpl {
   @Test
   public void testSimpleReserveMemoryCalculationHash() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = 
CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = 
CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(false,
       true,
       keySizes,
       350,
+      100, // Ignored for test
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -112,15 +117,17 @@ public class TestBuildSidePartitioningImpl {
   @Test
   public void testAdjustInitialPartitions() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = 
CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = 
CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(
@@ -128,11 +135,11 @@ public class TestBuildSidePartitioningImpl {
       false,
       keySizes,
       200,
+      100, // Ignored for test
       4,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -154,19 +161,148 @@ public class TestBuildSidePartitioningImpl {
     Assert.assertEquals(2, calc.getNumPartitions());
   }
 
+  @Test(expected = IllegalStateException.class)
+  public void testHasDataProbeEmpty() {
+    final int maxIncomingBatchSize = 100;
+    final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
+      new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
+        HashJoinHelperSizeCalculatorImpl.INSTANCE,
+        fragmentationFactor,
+        safetyFactor);
+
+    final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
+
+    calc.initialize(
+      true,
+      false,
+      keySizes,
+      240,
+      maxIncomingBatchSize,
+      4,
+      true,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
+      10,
+      5,
+      maxBatchNumRecords,
+      maxBatchNumRecords,
+      16000,
+      .75);
+  }
+
+  @Test
+  public void testNoProbeDataForStats() {
+    final int maxIncomingBatchSize = 100;
+    final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
+      new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
+        HashJoinHelperSizeCalculatorImpl.INSTANCE,
+        fragmentationFactor,
+        safetyFactor);
+
+    final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
+
+    calc.initialize(
+      true,
+      false,
+      keySizes,
+      240,
+      maxIncomingBatchSize,
+      4,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(),
+      10,
+      5,
+      maxBatchNumRecords,
+      maxBatchNumRecords,
+      16000,
+      .75);
+
+    final HashJoinMemoryCalculator.PartitionStatSet partitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(new PartitionStatImpl(), 
new PartitionStatImpl());
+    calc.setPartitionStatSet(partitionStatSet);
+
+    long expectedReservedMemory = 60 // max incoming build batch size
+      + 2 * 30 // incomplete build-side partition batch for each partition
+      + maxIncomingBatchSize; // reserved for the probe side since no probe stats are available yet
+    long actualReservedMemory = calc.getBuildReservedMemory();
+
+    Assert.assertEquals(expectedReservedMemory, actualReservedMemory);
+    Assert.assertEquals(2, calc.getNumPartitions());
+  }
+
+  @Test
+  public void testProbeEmpty() {
+    final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
+      new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
+        HashJoinHelperSizeCalculatorImpl.INSTANCE,
+        fragmentationFactor,
+        safetyFactor);
+
+    final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
+
+    calc.initialize(
+      true,
+      false,
+      keySizes,
+      200,
+      100, // Ignored for test
+      4,
+      true,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(),
+      10,
+      5,
+      maxBatchNumRecords,
+      maxBatchNumRecords,
+      16000,
+      .75);
+
+    final HashJoinMemoryCalculator.PartitionStatSet partitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(new PartitionStatImpl(), 
new PartitionStatImpl(),
+        new PartitionStatImpl(), new PartitionStatImpl());
+    calc.setPartitionStatSet(partitionStatSet);
+
+    long expectedReservedMemory = 60 // max incoming build batch size
+      + 4 * 30; // incomplete build-side partition batch for each partition; probe is empty, so nothing is reserved for it
+    long actualReservedMemory = calc.getBuildReservedMemory();
+
+    Assert.assertEquals(expectedReservedMemory, actualReservedMemory);
+    Assert.assertEquals(4, calc.getNumPartitions());
+  }
+
   @Test
   public void testNoRoomInMemoryForBatch1() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
 
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = 
CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = 
CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(
@@ -174,11 +310,11 @@ public class TestBuildSidePartitioningImpl {
       false,
       keySizes,
       180,
+      100, // Ignored for test
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -207,15 +343,17 @@ public class TestBuildSidePartitioningImpl {
   @Test
   public void testCompleteLifeCycle() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new 
HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = 
CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = 
CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(
@@ -223,11 +361,11 @@ public class TestBuildSidePartitioningImpl {
       false,
       keySizes,
       210,
+      100, // Ignored for test
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -276,4 +414,61 @@ public class TestBuildSidePartitioningImpl {
 
     Assert.assertNotNull(calc.next());
   }
+
+  public static class MockBatchSizePredictor implements BatchSizePredictor {
+    private final boolean hasData;
+    private final long batchSize;
+    private final int numRecords;
+    private final double fragmentationFactor;
+    private final double safetyFactor;
+
+    public MockBatchSizePredictor() {
+      hasData = false;
+      batchSize = 0;
+      numRecords = 0;
+      fragmentationFactor = 0;
+      safetyFactor = 0;
+    }
+
+    public MockBatchSizePredictor(final long batchSize,
+                                  final int numRecords,
+                                  final double fragmentationFactor,
+                                  final double safetyFactor) {
+      hasData = true;
+      this.batchSize = batchSize;
+      this.numRecords = numRecords;
+      this.fragmentationFactor = fragmentationFactor;
+      this.safetyFactor = safetyFactor;
+    }
+
+    @Override
+    public long getBatchSize() {
+      return batchSize;
+    }
+
+    @Override
+    public int getNumRecords() {
+      return numRecords;
+    }
+
+    @Override
+    public boolean hadDataLastTime() {
+      return hasData;
+    }
+
+    @Override
+    public void updateStats() {
+    }
+
+    @Override
+    public long predictBatchSize(int desiredNumRecords, boolean reserveHash) {
+      Preconditions.checkState(hasData);
+      return BatchSizePredictorImpl.computeMaxBatchSize(batchSize,
+        numRecords,
+        desiredNumRecords,
+        fragmentationFactor,
+        safetyFactor,
+        reserveHash);
+    }
+  }
 }
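
A rough walk through the auto-tuning loop exercised by testNoProbeDataForStats above
(the mock build predictor yields roughly a 60-byte max build batch and 30-byte partition
batches, and with no probe stats the probe reservation falls back to
maxIncomingBatchSize = 100, against 240 bytes of available memory):

    4 partitions: 4 * 30 + 60 + 100 = 280 > 240   -> halve the partition count
    2 partitions: 2 * 30 + 60 + 100 = 220 <= 240  -> stop; reserved memory = 220
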
diff --git 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculator.java
similarity index 61%
rename from 
exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
rename to 
exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculator.java
index 4fe1fa4..b13829b 100644
--- 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
+++ 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculator.java
@@ -17,54 +17,9 @@
  */
 package org.apache.drill.exec.physical.impl.join;
 
-import org.apache.drill.exec.vector.IntVector;
-import org.junit.Assert;
 import org.junit.Test;
 
-public class TestHashJoinMemoryCalculatorImpl {
-  @Test
-  public void testComputeMaxBatchSizeNoHash() {
-    final long expected = 1200;
-    final long actual = HashJoinMemoryCalculatorImpl.computeMaxBatchSize(
-      100,
-      25,
-      100,
-      2.0,
-      1.5,
-      false);
-    final long actualNoHash = 
HashJoinMemoryCalculatorImpl.computeMaxBatchSizeNoHash(
-      100,
-      25,
-      100,
-      2.0,
-      1.5);
-
-    Assert.assertEquals(expected, actual);
-    Assert.assertEquals(expected, actualNoHash);
-  }
-
-  @Test
-  public void testComputeMaxBatchSizeHash()
-  {
-    long expected = HashJoinMemoryCalculatorImpl.computeMaxBatchSizeNoHash(
-      100,
-      25,
-      100,
-      2.0,
-      4.0) +
-      100 * IntVector.VALUE_WIDTH * 2;
-
-    final long actual = HashJoinMemoryCalculatorImpl.computeMaxBatchSize(
-      100,
-      25,
-      100,
-      2.0,
-      4.0,
-      true);
-
-    Assert.assertEquals(expected, actual);
-  }
-
+public class TestHashJoinMemoryCalculator {
   @Test // Make sure no exception is thrown
   public void testMakeDebugString()
   {
@@ -78,5 +33,7 @@ public class TestHashJoinMemoryCalculatorImpl {
     partitionStat1.add(new HashJoinMemoryCalculator.BatchStat(10, 7));
     partitionStat2.add(new HashJoinMemoryCalculator.BatchStat(11, 20));
     partitionStat3.spill();
+
+    partitionStatSet.makeDebugString();
   }
 }
diff --git 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
index 3f01bca..813fc35 100644
--- 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
+++ 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
@@ -42,14 +42,14 @@ public class TestHashTableSizeCalculatorConservativeImpl {
     long expected = RecordBatchSizer.multiplyByFactor(
       UInt4Vector.VALUE_WIDTH * 128, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
     // First bucket key value vector sizes
-    expected += 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords,
 3L);
-    expected += 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords,
 8L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 
3L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 
8L);
 
     // Second bucket key value vector sizes
     expected += RecordBatchSizer.multiplyByFactor(
-      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20,
 3L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
+      BatchSizePredictorImpl.computeValueVectorSize(20, 3L), 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
     expected += RecordBatchSizer.multiplyByFactor(
-      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20,
 8L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
+      BatchSizePredictorImpl.computeValueVectorSize(20, 8L), 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
 
     // Overhead vectors for links and hash values for each batchHolder
     expected += 2 * UInt4Vector.VALUE_WIDTH // links and hash values */
diff --git 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
index 1bd51fc..3390cea 100644
--- 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
+++ 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
@@ -42,13 +42,13 @@ public class TestHashTableSizeCalculatorLeanImpl {
     long expected = RecordBatchSizer.multiplyByFactor(
       UInt4Vector.VALUE_WIDTH * 128, 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
     // First bucket key value vector sizes
-    expected += 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords,
 3L);
-    expected += 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords,
 8L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 
3L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 
8L);
 
     // Second bucket key value vector sizes
-    expected += 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20,
 3L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(20, 3L);
     expected += RecordBatchSizer.multiplyByFactor(
-      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20,
 8L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
+      BatchSizePredictorImpl.computeValueVectorSize(20, 8L), 
HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
 
     // Overhead vectors for links and hash values for each batchHolder
     expected += 2 * UInt4Vector.VALUE_WIDTH // links and hash values */
diff --git 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
index 5cf7eca..aa7a435 100644
--- 
a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
+++ 
b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
@@ -17,44 +17,229 @@
  */
 package org.apache.drill.exec.physical.impl.join;
 
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 
 public class TestPostBuildCalculationsImpl {
   @Test
-  public void testRoundUpPowerOf2() {
-    long expected = 32;
-    long actual = 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.roundUpToPowerOf2(expected);
+  public void testProbeTooBig() {
+    final int minProbeRecordsPerBatch = 10;
 
-    Assert.assertEquals(expected, actual);
+    final int computedProbeRecordsPerBatch =
+      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        100,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        70,
+        40,
+        200);
+
+    Assert.assertEquals(minProbeRecordsPerBatch, computedProbeRecordsPerBatch);
+  }
+
+  @Test
+  public void testComputedShouldBeMin() {
+    final int minProbeRecordsPerBatch = 10;
+
+    final int computedProbeRecordsPerBatch =
+      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        100,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        50,
+        40,
+        200);
+
+    Assert.assertEquals(minProbeRecordsPerBatch, computedProbeRecordsPerBatch);
   }
 
   @Test
-  public void testRounUpNonPowerOf2ToPowerOf2() {
-    long expected = 32;
-    long actual = 
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.roundUpToPowerOf2(31);
+  public void testComputedProbeRecordsPerBatch() {
+    final int minProbeRecordsPerBatch = 10;
+
+    final int computedProbeRecordsPerBatch =
+      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        200,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        50,
+        50,
+        200);
+
+    Assert.assertEquals(25, computedProbeRecordsPerBatch);
+  }
+
+  @Test
+  public void testComputedProbeRecordsPerBatchRoundUp() {
+    final int minProbeRecordsPerBatch = 10;
+
+    final int computedProbeRecordsPerBatch =
+      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        200,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        50,
+        51,
+        199);
+
+    Assert.assertEquals(25, computedProbeRecordsPerBatch);
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testHasProbeDataButProbeEmpty() {
+    final Map<String, Long> keySizes = 
org.apache.drill.common.map.CaseInsensitiveMap.newHashMap();
+
+    final PartitionStatImpl partition1 = new PartitionStatImpl();
+    final PartitionStatImpl partition2 = new PartitionStatImpl();
+    final HashJoinMemoryCalculator.PartitionStatSet buildPartitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(partition1, partition2);
+
+    final int recordsPerPartitionBatchBuild = 10;
+
+    addBatches(partition1, recordsPerPartitionBatchBuild,
+      10, 4);
+    addBatches(partition2, recordsPerPartitionBatchBuild,
+      10, 4);
 
-    Assert.assertEquals(expected, actual);
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+      new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, 
recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
+        290, // memoryAvailable
+        20, // maxOutputBatchSize
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
+        buildPartitionStatSet, // buildPartitionStatSet
+        keySizes, // keySizes
+        new MockHashTableSizeCalculator(10), // hashTableSizeCalculator
+        new MockHashJoinHelperSizeCalculator(10), // 
hashJoinHelperSizeCalculator
+        fragmentationFactor, // fragmentationFactor
+        safetyFactor, // safetyFactor
+        .75, // loadFactor
+        false); // reserveHash
+
+    calc.initialize(true);
   }
 
   @Test
-  public void testComputeValueVectorSizePowerOf2() {
-    long expected = 4;
-    long actual =
-      
HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(2,
 2);
+  public void testProbeEmpty() {
+    final Map<String, Long> keySizes = 
org.apache.drill.common.map.CaseInsensitiveMap.newHashMap();
+
+    final PartitionStatImpl partition1 = new PartitionStatImpl();
+    final PartitionStatImpl partition2 = new PartitionStatImpl();
+    final HashJoinMemoryCalculator.PartitionStatSet buildPartitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(partition1, partition2);
+
+    final int recordsPerPartitionBatchBuild = 10;
+
+    addBatches(partition1, recordsPerPartitionBatchBuild,
+      10, 4);
+    addBatches(partition2, recordsPerPartitionBatchBuild,
+      10, 4);
 
-    Assert.assertEquals(expected, actual);
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 40;
+    final long maxProbeBatchSize = 10000;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+      new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(),
+        50,
+        1000,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
+        buildPartitionStatSet,
+        keySizes,
+        new MockHashTableSizeCalculator(10),
+        new MockHashJoinHelperSizeCalculator(10),
+        fragmentationFactor,
+        safetyFactor,
+        .75,
+        true);
+
+    calc.initialize(true);
+
+    Assert.assertFalse(calc.shouldSpill());
+    Assert.assertFalse(calc.shouldSpill());
   }
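
Here the default-constructed mock reports hadDataLastTime() == false and initialize(true) confirms the probe side is empty, so the calculator needs no probe-side memory and shouldSpill() stays false even with only 50 bytes available; asserting twice checks the answer is stable across calls. A rough sketch of the calling pattern this models, an illustration of the loop a caller would run rather than code quoted from HashJoinBatch:

    calc.initialize(true);            // probe side known to be empty
    while (calc.shouldSpill()) {      // never entered in this scenario
      // spill a build-side partition, then re-evaluate
    }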
 
   @Test
-  public void testComputeValueVectorSizeNonPowerOf2() {
-    long expected = 16;
-    long actual =
-      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(3, 3);
+  public void testHasNoProbeDataButProbeNonEmpty() {
+    final Map<String, Long> keySizes = org.apache.drill.common.map.CaseInsensitiveMap.newHashMap();
 
-    Assert.assertEquals(expected, actual);
+    final PartitionStatImpl partition1 = new PartitionStatImpl();
+    final PartitionStatImpl partition2 = new PartitionStatImpl();
+    final HashJoinMemoryCalculator.PartitionStatSet buildPartitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(partition1, partition2);
+
+    final int recordsPerPartitionBatchBuild = 10;
+
+    addBatches(partition1, recordsPerPartitionBatchBuild,
+      10, 4);
+    addBatches(partition2, recordsPerPartitionBatchBuild,
+      10, 4);
+
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+      new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          false),
+        290,
+        20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
+        buildPartitionStatSet,
+        keySizes,
+        new MockHashTableSizeCalculator(10),
+        new MockHashJoinHelperSizeCalculator(10),
+        fragmentationFactor,
+        safetyFactor,
+        .75,
+        false);
+
+    calc.initialize(false);
+
+    long expected = 60 // maxProbeBatchSize
+      + 160 // in memory partitions
+      + 20 // max output batch size
+      + 2 * 10 // Hash Table
+      + 2 * 10; // Hash join helper
+    Assert.assertFalse(calc.shouldSpill());
+    Assert.assertEquals(expected, calc.getConsumedMemory());
+    Assert.assertNull(calc.next());
   }
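
The expected figure sums to 280 bytes, which fits inside the 290 bytes passed as memoryAvailable, so shouldSpill() is false; the 160-byte in-memory partition term is consistent with 2 partitions * 4 batches * 10 bytes * 2.0 fragmentation, though that breakdown is inferred from addBatches() rather than stated in the patch:

    60 + 160 + 20 + 2 * 10 + 2 * 10 = 280  <=  290 available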
 
   @Test
@@ -76,12 +261,21 @@ public class TestPostBuildCalculationsImpl {
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
 
-    HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         290,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(10),
@@ -91,7 +285,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 160 // in memory partitions
@@ -122,12 +316,21 @@ public class TestPostBuildCalculationsImpl {
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
 
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         270,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(10),
@@ -137,7 +340,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 160 // in memory partitions
@@ -174,12 +377,21 @@ public class TestPostBuildCalculationsImpl {
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
 
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         180,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(10),
@@ -189,7 +401,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         true);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 2 * 5 * 3 // partition batches
@@ -215,15 +427,24 @@ public class TestPostBuildCalculationsImpl {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         200,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -233,7 +454,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 80 // in memory partition
@@ -269,15 +490,24 @@ public class TestPostBuildCalculationsImpl {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         230,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -287,7 +517,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 80 // in memory partition
@@ -317,15 +547,24 @@ public class TestPostBuildCalculationsImpl {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
-        100,
-        15,
-        60,
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
+        110,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -335,7 +574,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
     Assert.assertFalse(calc.shouldSpill());
     Assert.assertEquals(110, calc.getConsumedMemory());
     Assert.assertNotNull(calc.next());
@@ -362,15 +601,24 @@ public class TestPostBuildCalculationsImpl {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         230,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -380,7 +628,7 @@ public class TestPostBuildCalculationsImpl {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
   }
 
   private void addBatches(PartitionStatImpl partitionStat,
@@ -431,4 +679,66 @@ public class TestPostBuildCalculationsImpl {
       return size;
     }
   }
+
+  public static class ConditionalMockBatchSizePredictor implements BatchSizePredictor {
+    private final List<Integer> recordsPerBatch;
+    private final List<Long> batchSize;
+
+    private boolean hasData;
+    private boolean updateable;
+
+    public ConditionalMockBatchSizePredictor() {
+      recordsPerBatch = new ArrayList<>();
+      batchSize = new ArrayList<>();
+      hasData = false;
+      updateable = true;
+    }
+
+    public ConditionalMockBatchSizePredictor(final List<Integer> recordsPerBatch,
+                                             final List<Long> batchSize,
+                                             final boolean hasData) {
+      this.recordsPerBatch = Preconditions.checkNotNull(recordsPerBatch);
+      this.batchSize = Preconditions.checkNotNull(batchSize);
+
+      Preconditions.checkArgument(recordsPerBatch.size() == batchSize.size());
+
+      this.hasData = hasData;
+      updateable = true;
+    }
+
+    @Override
+    public long getBatchSize() {
+      return 0;
+    }
+
+    @Override
+    public int getNumRecords() {
+      return 0;
+    }
+
+    @Override
+    public boolean hadDataLastTime() {
+      return hasData;
+    }
+
+    @Override
+    public void updateStats() {
+      Preconditions.checkState(updateable);
+      updateable = false;
+      hasData = true;
+    }
+
+    @Override
+    public long predictBatchSize(int desiredNumRecords, boolean reserveHash) {
+      Preconditions.checkState(hasData);
+
+      for (int index = 0; index < recordsPerBatch.size(); index++) {
+        if (desiredNumRecords == recordsPerBatch.get(index)) {
+          return batchSize.get(index);
+        }
+      }
+
+      throw new IllegalArgumentException();
+    }
+  }
 }
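
The mock above resolves predictBatchSize() by matching the requested record count against the recordsPerBatch list and returning the batch size at the same index, which is how the tests pin the probe-side estimates to known values. A minimal usage sketch with the constants used throughout the tests:

    BatchSizePredictor predictor = new ConditionalMockBatchSizePredictor(
      Lists.newArrayList(3, 5),      // maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe
      Lists.newArrayList(60L, 15L),  // maxProbeBatchSize, partitionProbeBatchSize
      true);                         // hadDataLastTime() reports true

    predictor.predictBatchSize(3, false);   // returns 60
    predictor.predictBatchSize(5, false);   // returns 15
    predictor.predictBatchSize(7, false);   // throws IllegalArgumentException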
