Posted to dev@drill.apache.org by GitBox <gi...@apache.org> on 2018/08/13 11:53:01 UTC

[GitHub] asfgit closed pull request #1408: DRILL-6453: Resolve deadlock when reading from build and probe sides simultaneously in HashJoin

URL: https://github.com/apache/drill/pull/1408

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
index eaccd335527..fbdc4f3b8a1 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
@@ -17,6 +17,7 @@
  */
 package org.apache.drill.exec.physical.impl.common;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import org.apache.drill.common.exceptions.RetryAfterSpillException;
 import org.apache.drill.common.exceptions.UserException;
@@ -122,6 +123,7 @@
   private List<HashJoinMemoryCalculator.BatchStat> inMemoryBatchStats = Lists.newArrayList();
   private long partitionInMemorySize;
   private long numInMemoryRecords;
+  private boolean updatedRecordsPerBatch = false;
 
   public HashPartition(FragmentContext context, BufferAllocator allocator, ChainedHashTable baseHashTable,
                        RecordBatch buildBatch, RecordBatch probeBatch,
@@ -155,6 +157,18 @@ public HashPartition(FragmentContext context, BufferAllocator allocator, Chained
     }
   }
 
+  /**
+   * Configure a different temporary batch size when spilling probe batches.
+   * @param newRecordsPerBatch The new temporary batch size to use.
+   */
+  public void updateProbeRecordsPerBatch(int newRecordsPerBatch) {
+    Preconditions.checkArgument(newRecordsPerBatch > 0);
+    Preconditions.checkState(!updatedRecordsPerBatch); // Only allow updating once
+    Preconditions.checkState(processingOuter); // We can only update the records per batch when probing.
+
+    recordsPerBatch = newRecordsPerBatch;
+    updatedRecordsPerBatch = true;
+  }
+
   /**
    * Allocate a new vector container for either right or left record batch
    * Add an additional special vector for the hash values
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictor.java
new file mode 100644
index 00000000000..912e4feaf3c
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictor.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.physical.impl.join;
+
+import org.apache.drill.exec.record.RecordBatch;
+
+/**
+ * This class predicts the sizes of batches given an input batch.
+ *
+ * <h4>Invariants</h4>
+ * <ul>
+ *   <li>The {@link BatchSizePredictor} assumes that a {@link RecordBatch} is in a state where it can return a valid record count.</li>
+ * </ul>
+ */
+public interface BatchSizePredictor {
+  /**
+   * Gets the batchSize computed in the call to {@link #updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @return Gets the batchSize computed in the call to {@link #updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @throws IllegalStateException if {@link #updateStats()} was never called.
+   */
+  long getBatchSize();
+
+  /**
+   * Gets the number of records computed in the call to {@link #updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @return Gets the number of records computed in the call to {@link #updateStats()}. Returns 0 if {@link #hadDataLastTime()} is false.
+   * @throws IllegalStateException if {@link #updateStats()} was never called.
+   */
+  int getNumRecords();
+
+  /**
+   * True if the input batch had records in the last call to {@link #updateStats()}. False otherwise.
+   * @return True if the input batch had records in the last call to {@link #updateStats()}. False otherwise.
+   */
+  boolean hadDataLastTime();
+
+  /**
+   * This method can be called multiple times to collect stats about the latest data in the provided record batch. These
+   * stats are used to predict batch sizes. If the batch currently has no data, this method is a noop. This method must be
+   * called at least once before {@link #predictBatchSize(int, boolean)}.
+   */
+  void updateStats();
+
+  /**
+   * Predicts the size of a batch using the current collected stats.
+   * @param desiredNumRecords The number of records contained in the batch whose size we want to predict.
+   * @param reserveHash Whether or not to include a column containing hash values.
+   * @return The size of the predicted batch.
+   * @throws IllegalStateException if {@link #hadDataLastTime()} is false or {@link #updateStats()} was not called.
+   */
+  long predictBatchSize(int desiredNumRecords, boolean reserveHash);
+
+  /**
+   * A factory for creating {@link BatchSizePredictor}s.
+   */
+  interface Factory {
+    /**
+     * Creates a predictor with a batch whose data needs to be used to predict other batch sizes.
+     * @param batch The batch whose size needs to be predicted.
+     * @param fragmentationFactor A constant used to predict value vector doubling.
+     * @param safetyFactor A constant used to leave padding for unpredictable incoming batches.
+     */
+    BatchSizePredictor create(RecordBatch batch, double fragmentationFactor, double safetyFactor);
+  }
+}
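
The interface above is easiest to read alongside a small caller-side sketch. The snippet below is illustrative only (it is not part of the patch); it assumes the concrete BatchSizePredictorImpl added in the next file, a hypothetical incoming RecordBatch, and a hypothetical helper name:

    // Hypothetical helper showing the intended call order: create -> updateStats -> predict.
    static long predictPartitionBatchSize(RecordBatch incoming,
                                          double fragmentationFactor,
                                          double safetyFactor,
                                          int recordsPerPartitionBatch) {
      BatchSizePredictor predictor = BatchSizePredictorImpl.Factory.INSTANCE
          .create(incoming, fragmentationFactor, safetyFactor);
      predictor.updateStats();                    // must be called before any prediction
      if (!predictor.hadDataLastTime()) {
        return 0;                                 // no data seen yet, so nothing to predict
      }
      // Predict the size of a partition batch; no hash column reserved in this example.
      return predictor.predictBatchSize(recordsPerPartitionBatch, false);
    }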
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.java
new file mode 100644
index 00000000000..bbebd2bd7eb
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/BatchSizePredictorImpl.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.physical.impl.join;
+
+import com.google.common.base.Preconditions;
+import org.apache.drill.exec.record.RecordBatch;
+import org.apache.drill.exec.record.RecordBatchSizer;
+import org.apache.drill.exec.vector.IntVector;
+
+import java.util.Map;
+
+public class BatchSizePredictorImpl implements BatchSizePredictor {
+  private RecordBatch batch;
+  private double fragmentationFactor;
+  private double safetyFactor;
+
+  private long batchSize;
+  private int numRecords;
+  private boolean updatedStats;
+  private boolean hasData;
+
+  public BatchSizePredictorImpl(final RecordBatch batch,
+                                final double fragmentationFactor,
+                                final double safetyFactor) {
+    this.batch = Preconditions.checkNotNull(batch);
+    this.fragmentationFactor = fragmentationFactor;
+    this.safetyFactor = safetyFactor;
+  }
+
+  @Override
+  public long getBatchSize() {
+    Preconditions.checkState(updatedStats);
+    return hasData? batchSize: 0;
+  }
+
+  @Override
+  public int getNumRecords() {
+    Preconditions.checkState(updatedStats);
+    return hasData? numRecords: 0;
+  }
+
+  @Override
+  public boolean hadDataLastTime() {
+    return hasData;
+  }
+
+  @Override
+  public void updateStats() {
+    final RecordBatchSizer batchSizer = new RecordBatchSizer(batch);
+    numRecords = batchSizer.rowCount();
+    updatedStats = true;
+    hasData = numRecords > 0;
+
+    if (hasData) {
+      batchSize = getBatchSizeEstimate(batch);
+    }
+  }
+
+  @Override
+  public long predictBatchSize(int desiredNumRecords, boolean reserveHash) {
+    Preconditions.checkState(hasData);
+    // Safety factor can be multiplied at the end since these batches are coming from exchange operators, so no excess value vector doubling
+    return computeMaxBatchSize(batchSize,
+      numRecords,
+      desiredNumRecords,
+      fragmentationFactor,
+      safetyFactor,
+      reserveHash);
+  }
+
+  public static long computeValueVectorSize(long numRecords, long byteSize) {
+    long naiveSize = numRecords * byteSize;
+    return roundUpToPowerOf2(naiveSize);
+  }
+
+  public static long computeValueVectorSize(long numRecords, long byteSize, double safetyFactor) {
+    long naiveSize = RecordBatchSizer.multiplyByFactor(numRecords * byteSize, safetyFactor);
+    return roundUpToPowerOf2(naiveSize);
+  }
+
+  public static long roundUpToPowerOf2(long num) {
+    Preconditions.checkArgument(num >= 1);
+    return num == 1 ? 1 : Long.highestOneBit(num - 1) << 1;
+  }
+
+  public static long computeMaxBatchSizeNoHash(final long incomingBatchSize,
+                                         final int incomingNumRecords,
+                                         final int desiredNumRecords,
+                                         final double fragmentationFactor,
+                                         final double safetyFactor) {
+    long maxBatchSize = computePartitionBatchSize(incomingBatchSize, incomingNumRecords, desiredNumRecords);
+    // Multiply by the fragmentation and safety factors
+    return RecordBatchSizer.multiplyByFactors(maxBatchSize, fragmentationFactor, safetyFactor);
+  }
+
+  public static long computeMaxBatchSize(final long incomingBatchSize,
+                                         final int incomingNumRecords,
+                                         final int desiredNumRecords,
+                                         final double fragmentationFactor,
+                                         final double safetyFactor,
+                                         final boolean reserveHash) {
+    long size = computeMaxBatchSizeNoHash(incomingBatchSize,
+      incomingNumRecords,
+      desiredNumRecords,
+      fragmentationFactor,
+      safetyFactor);
+
+    if (!reserveHash) {
+      return size;
+    }
+
+    long hashSize = desiredNumRecords * ((long) IntVector.VALUE_WIDTH);
+    hashSize = RecordBatchSizer.multiplyByFactors(hashSize, fragmentationFactor);
+
+    return size + hashSize;
+  }
+
+  public static long computePartitionBatchSize(final long incomingBatchSize,
+                                               final int incomingNumRecords,
+                                               final int desiredNumRecords) {
+    return (long) Math.ceil((((double) incomingBatchSize) /
+      ((double) incomingNumRecords)) *
+      ((double) desiredNumRecords));
+  }
+
+  public static long getBatchSizeEstimate(final RecordBatch recordBatch) {
+    final RecordBatchSizer sizer = new RecordBatchSizer(recordBatch);
+    long size = 0L;
+
+    for (Map.Entry<String, RecordBatchSizer.ColumnSize> column : sizer.columns().entrySet()) {
+      size += computeValueVectorSize(recordBatch.getRecordCount(), column.getValue().getStdNetOrNetSizePerEntry());
+    }
+
+    return size;
+  }
+
+  public static class Factory implements BatchSizePredictor.Factory {
+    public static final Factory INSTANCE = new Factory();
+
+    private Factory() {
+    }
+
+    @Override
+    public BatchSizePredictor create(final RecordBatch batch,
+                                     final double fragmentationFactor,
+                                     final double safetyFactor) {
+      return new BatchSizePredictorImpl(batch, fragmentationFactor, safetyFactor);
+    }
+  }
+}
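
To make the static sizing helpers above concrete, here is a small worked example with made-up numbers (the 1.33 fragmentation factor and 1.15 safety factor are purely illustrative, not values taken from this patch):

    // Illustrative arithmetic only; not part of the patch.
    long incomingBatchSize  = 4_000_000L;  // ~4 MB incoming batch
    int  incomingNumRecords = 8_192;
    int  desiredNumRecords  = 1_024;

    // Linear scaling: 4,000,000 * (1024 / 8192) = 500,000 bytes.
    long partitionSize = BatchSizePredictorImpl.computePartitionBatchSize(
        incomingBatchSize, incomingNumRecords, desiredNumRecords);

    // Apply the fragmentation and safety factors on top of the scaled size.
    long noHash = BatchSizePredictorImpl.computeMaxBatchSizeNoHash(
        incomingBatchSize, incomingNumRecords, desiredNumRecords, 1.33, 1.15);

    // Reserving the hash column adds desiredNumRecords * IntVector.VALUE_WIDTH (4) bytes,
    // scaled by the fragmentation factor.
    long withHash = BatchSizePredictorImpl.computeMaxBatchSize(
        incomingBatchSize, incomingNumRecords, desiredNumRecords, 1.33, 1.15, true);

    // Value vector sizes are rounded up to a power of two to mirror vector doubling:
    long rounded = BatchSizePredictorImpl.roundUpToPowerOf2(500_000);  // 524,288 (2^19)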
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java
index d4d4f927e3f..0a040c1a9c7 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinBatch.java
@@ -27,6 +27,7 @@
 import com.google.common.collect.Sets;
 
 import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.mutable.MutableBoolean;
 import org.apache.drill.common.exceptions.UserException;
 import org.apache.drill.common.expression.FieldReference;
 import org.apache.drill.common.expression.PathSegment;
@@ -68,6 +69,9 @@
 import org.apache.drill.exec.vector.complex.AbstractContainerVector;
 import org.apache.calcite.rel.core.JoinRelType;
 
+import static org.apache.drill.exec.record.RecordBatch.IterOutcome.EMIT;
+import static org.apache.drill.exec.record.RecordBatch.IterOutcome.OK_NEW_SCHEMA;
+
 /**
  *   This class implements the runtime execution for the Hash-Join operator
  *   supporting INNER, LEFT OUTER, RIGHT OUTER, and FULL OUTER joins
@@ -114,6 +118,7 @@
 
   // Fields used for partitioning
 
+  private long maxIncomingBatchSize;
   /**
    * The number of {@link HashPartition}s. This is configured via a system option and set in {@link #partitionNumTuning(int, HashJoinMemoryCalculator.BuildSidePartitioning)}.
    */
@@ -125,7 +130,8 @@
    * The master class used to generate {@link HashTable}s.
    */
   private ChainedHashTable baseHashTable;
-  private boolean buildSideIsEmpty = true;
+  private MutableBoolean buildSideIsEmpty = new MutableBoolean(false);
+  private MutableBoolean probeSideIsEmpty = new MutableBoolean(false);
   private boolean canSpill = true;
   private boolean wasKilled; // a kill was received, may need to clean spilled partns
 
@@ -138,7 +144,7 @@
   private int outputRecords;
 
   // Schema of the build side
-  private BatchSchema rightSchema;
+  private BatchSchema buildSchema;
   // Schema of the probe side
   private BatchSchema probeSchema;
 
@@ -150,9 +156,13 @@
   private RecordBatch probeBatch;
 
   /**
-   * Flag indicating whether or not the first data holding batch needs to be fetched.
+   * Flag indicating whether or not the first data holding build batch needs to be fetched.
+   */
+  private MutableBoolean prefetchedBuild = new MutableBoolean(false);
+  /**
+   * Flag indicating whether or not the first data holding probe batch needs to be fetched.
    */
-  private boolean prefetched;
+  private MutableBoolean prefetchedProbe = new MutableBoolean(false);
 
   // For handling spilling
   private SpillSet spillSet;
@@ -220,123 +230,120 @@ public int getRecordCount() {
   protected void buildSchema() throws SchemaChangeException {
     // We must first get the schemas from upstream operators before we can build
     // our schema.
-    boolean validSchema = sniffNewSchemas();
+    boolean validSchema = prefetchFirstBatchFromBothSides();
 
     if (validSchema) {
       // We are able to construct a valid schema from the upstream data.
       // Setting the state here makes sure AbstractRecordBatch returns OK_NEW_SCHEMA
       state = BatchState.BUILD_SCHEMA;
-    } else {
-      verifyOutcomeToSetBatchState(leftUpstream, rightUpstream);
+
+      if (leftUpstream == OK_NEW_SCHEMA) {
+        probeSchema = left.getSchema();
+      }
+
+      if (rightUpstream == OK_NEW_SCHEMA) {
+        buildSchema = right.getSchema();
+        // position of the new "column" for keeping the hash values (after the real columns)
+        rightHVColPosition = right.getContainer().getNumberOfColumns();
+        // We only need the hash tables if we have data on the build side.
+        setupHashTable();
+      }
+
+      try {
+        hashJoinProbe = setupHashJoinProbe();
+      } catch (IOException | ClassTransformationException e) {
+        throw new SchemaChangeException(e);
+      }
     }
 
     // If we have a valid schema, this will build a valid container. If we were unable to obtain a valid schema,
-    // we still need to build a dummy schema. These code handles both cases for us.
+    // we still need to build a dummy schema. This code handles both cases for us.
     setupOutputContainerSchema();
     container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
-
-    // Initialize the hash join helper context
-    if (rightUpstream == IterOutcome.OK_NEW_SCHEMA) {
-      // We only need the hash tables if we have data on the build side.
-      setupHashTable();
-    }
-
-    try {
-      hashJoinProbe = setupHashJoinProbe();
-    } catch (IOException | ClassTransformationException e) {
-      throw new SchemaChangeException(e);
-    }
   }
 
-  @Override
-  protected boolean prefetchFirstBatchFromBothSides() {
-    if (leftUpstream != IterOutcome.NONE) {
-      // We can only get data if there is data available
-      leftUpstream = sniffNonEmptyBatch(leftUpstream, LEFT_INDEX, left);
-    }
-
-    if (rightUpstream != IterOutcome.NONE) {
-      // We can only get data if there is data available
-      rightUpstream = sniffNonEmptyBatch(rightUpstream, RIGHT_INDEX, right);
-    }
-
-    buildSideIsEmpty = rightUpstream == IterOutcome.NONE;
-
-    if (verifyOutcomeToSetBatchState(leftUpstream, rightUpstream)) {
-      // For build side, use aggregate i.e. average row width across batches
-      batchMemoryManager.update(LEFT_INDEX, 0);
-      batchMemoryManager.update(RIGHT_INDEX, 0, true);
-
-      logger.debug("BATCH_STATS, incoming left: {}", batchMemoryManager.getRecordBatchSizer(LEFT_INDEX));
-      logger.debug("BATCH_STATS, incoming right: {}", batchMemoryManager.getRecordBatchSizer(RIGHT_INDEX));
+  /**
+   * Prefetches the first build side data holding batch.
+   */
+  private void prefetchFirstBuildBatch() {
+    rightUpstream = prefetchFirstBatch(rightUpstream,
+      prefetchedBuild,
+      buildSideIsEmpty,
+      RIGHT_INDEX,
+      right,
+      () -> {
+        batchMemoryManager.update(RIGHT_INDEX, 0, true);
+        logger.debug("BATCH_STATS, incoming right: {}", batchMemoryManager.getRecordBatchSizer(RIGHT_INDEX));
+      });
+  }
 
-      // Got our first batche(s)
-      state = BatchState.FIRST;
-      return true;
-    } else {
-      return false;
-    }
+  /**
+   * Prefetches the first probe side data holding batch.
+   */
+  private void prefetchFirstProbeBatch() {
+    leftUpstream = prefetchFirstBatch(leftUpstream,
+      prefetchedProbe,
+      probeSideIsEmpty,
+      LEFT_INDEX,
+      left,
+      () -> {
+        batchMemoryManager.update(LEFT_INDEX, 0);
+        logger.debug("BATCH_STATS, incoming left: {}", batchMemoryManager.getRecordBatchSizer(LEFT_INDEX));
+      });
   }
 
   /**
-   * Sniffs all data necessary to construct a schema.
-   * @return True if all the data necessary to construct a schema has been retrieved. False otherwise.
+   * Used to fetch the first data holding batch from either the build or probe side.
+   * @param outcome The current upstream outcome for either the build or probe side.
+   * @param prefetched A flag indicating if we have already done a prefetch of the first data holding batch for the probe or build side.
+   * @param isEmpty A flag indicating if the probe or build side is empty.
+   * @param index The upstream index of the probe or build batch.
+   * @param batch The probe or build batch itself.
+   * @param memoryManagerUpdate A lambda function to execute the memory manager update for the probe or build batch.
+   * @return The current {@link org.apache.drill.exec.record.RecordBatch.IterOutcome}.
    */
-  private boolean sniffNewSchemas() {
-    do {
-      // Ask for data until we get a valid result.
-      leftUpstream = next(LEFT_INDEX, left);
-    } while (leftUpstream == IterOutcome.NOT_YET);
+  private IterOutcome prefetchFirstBatch(IterOutcome outcome,
+                                         final MutableBoolean prefetched,
+                                         final MutableBoolean isEmpty,
+                                         final int index,
+                                         final RecordBatch batch,
+                                         final Runnable memoryManagerUpdate) {
+    if (prefetched.booleanValue()) {
+      // We have already prefetched the first data holding batch
+      return outcome;
+    }
 
-    boolean isValidLeft = false;
+    // If we didn't retrieve our first data holding batch, we need to do it now.
+    prefetched.setValue(true);
 
-    switch (leftUpstream) {
-      case OK_NEW_SCHEMA:
-        probeSchema = probeBatch.getSchema();
-      case NONE:
-        isValidLeft = true;
-        break;
-      case OK:
-      case EMIT:
-        throw new IllegalStateException("Unsupported outcome while building schema " + leftUpstream);
-      default:
-        // Termination condition
+    if (outcome != IterOutcome.NONE) {
+      // We can only get data if there is data available
+      outcome = sniffNonEmptyBatch(outcome, index, batch);
     }
 
-    do {
-      // Ask for data until we get a valid result.
-      rightUpstream = next(RIGHT_INDEX, right);
-    } while (rightUpstream == IterOutcome.NOT_YET);
-
-    boolean isValidRight = false;
+    isEmpty.setValue(outcome == IterOutcome.NONE); // If we received NONE there is no data.
 
-    switch (rightUpstream) {
-      case OK_NEW_SCHEMA:
-        // We need to have the schema of the build side even when the build side is empty
-        rightSchema = buildBatch.getSchema();
-        // position of the new "column" for keeping the hash values (after the real columns)
-        rightHVColPosition = buildBatch.getContainer().getNumberOfColumns();
-      case NONE:
-        isValidRight = true;
-        break;
-      case OK:
-      case EMIT:
-        throw new IllegalStateException("Unsupported outcome while building schema " + leftUpstream);
-      default:
-        // Termination condition
+    if (outcome == IterOutcome.OUT_OF_MEMORY) {
+      // We reached a termination state
+      state = BatchState.OUT_OF_MEMORY;
+    } else if (outcome == IterOutcome.STOP) {
+      // We reached a termination state
+      state = BatchState.STOP;
+    } else {
+      // Got our first batch(es)
+      memoryManagerUpdate.run();
+      state = BatchState.FIRST;
     }
 
-    // Left and right sides must return a valid response and both sides cannot be NONE.
-    return (isValidLeft && isValidRight) &&
-      (leftUpstream != IterOutcome.NONE && rightUpstream != IterOutcome.NONE);
+    return outcome;
   }
 
   /**
-   * Currently in order to accurately predict memory usage for spilling, the first non-empty build side and probe side batches are needed. This method
-   * fetches the first non-empty batch from the left or right side.
+   * Currently in order to accurately predict memory usage for spilling, the first non-empty build or probe side batch is needed. This method
+   * fetches the first non-empty batch from the probe or build side.
    * @param curr The current outcome.
-   * @param inputIndex Index specifying whether to work with the left or right input.
-   * @param recordBatch The left or right record batch.
+   * @param inputIndex Index specifying whether to work with the probe or build input.
+   * @param recordBatch The probe or build record batch.
    * @return The {@link org.apache.drill.exec.record.RecordBatch.IterOutcome} for the left or right record batch.
    */
   private IterOutcome sniffNonEmptyBatch(IterOutcome curr, int inputIndex, RecordBatch recordBatch) {
@@ -354,8 +361,10 @@ private IterOutcome sniffNonEmptyBatch(IterOutcome curr, int inputIndex, RecordB
         case NOT_YET:
           // We need to try again
           break;
+        case EMIT:
+          throw new UnsupportedOperationException("We do not support " + EMIT);
         default:
-          // Other cases termination conditions
+          // Other cases are termination conditions
           return curr;
       }
     }
@@ -381,96 +390,119 @@ public HashJoinMemoryCalculator getCalculatorImpl() {
 
   @Override
   public IterOutcome innerNext() {
-    if (!prefetched) {
-      // If we didn't retrieve our first data hold batch, we need to do it now.
-      prefetched = true;
-      prefetchFirstBatchFromBothSides();
-
-      // Handle emitting the correct outcome for termination conditions
-      // Use the state set by prefetchFirstBatchFromBothSides to emit the correct termination outcome.
-      switch (state) {
-        case DONE:
-          return IterOutcome.NONE;
-        case STOP:
-          return IterOutcome.STOP;
-        case OUT_OF_MEMORY:
-          return IterOutcome.OUT_OF_MEMORY;
-        default:
-          // No termination condition so continue processing.
-      }
-    }
-
-    if ( wasKilled ) {
+    if (wasKilled) {
+      // We have received a kill signal. We need to stop processing.
       this.cleanup();
       super.close();
       return IterOutcome.NONE;
     }
 
+    prefetchFirstBuildBatch();
+
+    if (rightUpstream.isError()) {
+      // A termination condition was reached while prefetching the first build side data holding batch.
+      // We need to terminate.
+      return rightUpstream;
+    }
+
     try {
       /* If we are here for the first time, execute the build phase of the
        * hash join and setup the run time generated class for the probe side
        */
       if (state == BatchState.FIRST) {
         // Build the hash table, using the build side record batches.
-        executeBuildPhase();
+        final IterOutcome buildExecuteTermination = executeBuildPhase();
+
+        if (buildExecuteTermination != null) {
+          // A termination condition was reached while executing the build phase.
+          // We need to terminate.
+          return buildExecuteTermination;
+        }
+
         // Update the hash table related stats for the operator
         updateStats();
-        // Initialize various settings for the probe side
-        hashJoinProbe.setupHashJoinProbe(probeBatch, this, joinType, leftUpstream, partitions, cycleNum, container, spilledInners, buildSideIsEmpty, numPartitions, rightHVColPosition);
       }
 
       // Try to probe and project, or recursively handle a spilled partition
-      if ( ! buildSideIsEmpty ||  // If there are build-side rows
-           joinType != JoinRelType.INNER) {  // or if this is a left/full outer join
-
-        // Allocate the memory for the vectors in the output container
-        batchMemoryManager.allocateVectors(container);
-        hashJoinProbe.setTargetOutputCount(batchMemoryManager.getOutputRowCount());
+      if (!buildSideIsEmpty.booleanValue() ||  // If there are build-side rows
+        joinType != JoinRelType.INNER) {  // or if this is a left/full outer join
 
-        outputRecords = hashJoinProbe.probeAndProject();
+        prefetchFirstProbeBatch();
 
-        for (final VectorWrapper<?> v : container) {
-          v.getValueVector().getMutator().setValueCount(outputRecords);
+        if (leftUpstream.isError()) {
+          // A termination condition was reached while prefetching the first probe side data holding batch.
+          // We need to terminate.
+          return leftUpstream;
         }
-        container.setRecordCount(outputRecords);
 
-        batchMemoryManager.updateOutgoingStats(outputRecords);
-        if (logger.isDebugEnabled()) {
-          logger.debug("BATCH_STATS, outgoing: {}", new RecordBatchSizer(this));
-        }
+        if (!buildSideIsEmpty.booleanValue() || !probeSideIsEmpty.booleanValue()) {
+          // Only allocate outgoing vectors and execute probing logic if there is data
 
-        /* We are here because of one the following
-         * 1. Completed processing of all the records and we are done
-         * 2. We've filled up the outgoing batch to the maximum and we need to return upstream
-         * Either case build the output container's schema and return
-         */
-        if (outputRecords > 0 || state == BatchState.FIRST) {
           if (state == BatchState.FIRST) {
-            state = BatchState.NOT_FIRST;
+            // Initialize various settings for the probe side
+            hashJoinProbe.setupHashJoinProbe(probeBatch,
+              this,
+              joinType,
+              leftUpstream,
+              partitions,
+              cycleNum,
+              container,
+              spilledInners,
+              buildSideIsEmpty.booleanValue(),
+              numPartitions,
+              rightHVColPosition);
+          }
+
+          // Allocate the memory for the vectors in the output container
+          batchMemoryManager.allocateVectors(container);
+
+          hashJoinProbe.setTargetOutputCount(batchMemoryManager.getOutputRowCount());
+
+          outputRecords = hashJoinProbe.probeAndProject();
+
+          for (final VectorWrapper<?> v : container) {
+            v.getValueVector().getMutator().setValueCount(outputRecords);
+          }
+          container.setRecordCount(outputRecords);
+
+          batchMemoryManager.updateOutgoingStats(outputRecords);
+          if (logger.isDebugEnabled()) {
+            logger.debug("BATCH_STATS, outgoing: {}", new RecordBatchSizer(this));
           }
 
-          return IterOutcome.OK;
+          /* We are here because of one of the following
+           * 1. Completed processing of all the records and we are done
+           * 2. We've filled up the outgoing batch to the maximum and we need to return upstream
+           * In either case, build the output container's schema and return
+           */
+          if (outputRecords > 0 || state == BatchState.FIRST) {
+            if (state == BatchState.FIRST) {
+              state = BatchState.NOT_FIRST;
+            }
+
+            return IterOutcome.OK;
+          }
         }
 
         // Free all partitions' in-memory data structures
         // (In case need to start processing spilled partitions)
-        for ( HashPartition partn : partitions ) {
+        for (HashPartition partn : partitions) {
           partn.cleanup(false); // clean, but do not delete the spill files !!
         }
 
         //
         //  (recursively) Handle the spilled partitions, if any
         //
-        if ( !buildSideIsEmpty && !spilledPartitionsList.isEmpty()) {
+        if (!buildSideIsEmpty.booleanValue() && !spilledPartitionsList.isEmpty()) {
           // Get the next (previously) spilled partition to handle as incoming
           HJSpilledPartition currSp = spilledPartitionsList.remove(0);
 
           // Create a BUILD-side "incoming" out of the inner spill file of that partition
-          buildBatch = new SpilledRecordbatch(currSp.innerSpillFile, currSp.innerSpilledBatches, context, rightSchema, oContext, spillSet);
+          buildBatch = new SpilledRecordbatch(currSp.innerSpillFile, currSp.innerSpilledBatches, context, buildSchema, oContext, spillSet);
           // The above ctor call also got the first batch; need to update the outcome
           rightUpstream = ((SpilledRecordbatch) buildBatch).getInitialOutcome();
 
-          if ( currSp.outerSpilledBatches > 0 ) {
+          if (currSp.outerSpilledBatches > 0) {
             // Create a PROBE-side "incoming" out of the outer spill file of that partition
             probeBatch = new SpilledRecordbatch(currSp.outerSpillFile, currSp.outerSpilledBatches, context, probeSchema, oContext, spillSet);
             // The above ctor call also got the first batch; need to update the outcome
@@ -644,13 +676,14 @@ private void initializeBuild() {
         buildBatch,
         probeBatch,
         buildJoinColumns,
+        leftUpstream == IterOutcome.NONE, // probeEmpty
         allocator.getLimit(),
+        maxIncomingBatchSize,
         numPartitions,
         RECORDS_PER_BATCH,
         RECORDS_PER_BATCH,
         maxBatchSize,
         maxBatchSize,
-        batchMemoryManager.getOutputRowCount(),
         batchMemoryManager.getOutputBatchSize(),
         HashTable.DEFAULT_LOAD_FACTOR);
 
@@ -689,12 +722,13 @@ private void disableSpilling(String reason) {
    *  Execute the BUILD phase; first read incoming and split rows into partitions;
    *  may decide to spill some of the partitions
    *
+   * @return Returns an {@link org.apache.drill.exec.record.RecordBatch.IterOutcome} if a termination condition is reached. Otherwise returns null.
    * @throws SchemaChangeException
    */
-  public void executeBuildPhase() throws SchemaChangeException {
-    if (rightUpstream == IterOutcome.NONE) {
+  public IterOutcome executeBuildPhase() throws SchemaChangeException {
+    if (buildSideIsEmpty.booleanValue()) {
       // empty right
-      return;
+      return null;
     }
 
     HashJoinMemoryCalculator.BuildSidePartitioning buildCalc;
@@ -716,13 +750,14 @@ public void executeBuildPhase() throws SchemaChangeException {
         buildBatch,
         probeBatch,
         buildJoinColumns,
+        leftUpstream == IterOutcome.NONE, // probeEmpty
         allocator.getLimit(),
+        maxIncomingBatchSize,
         numPartitions,
         RECORDS_PER_BATCH,
         RECORDS_PER_BATCH,
         maxBatchSize,
         maxBatchSize,
-        batchMemoryManager.getOutputRowCount(),
         batchMemoryManager.getOutputBatchSize(),
         HashTable.DEFAULT_LOAD_FACTOR);
 
@@ -754,8 +789,8 @@ public void executeBuildPhase() throws SchemaChangeException {
         continue;
 
       case OK_NEW_SCHEMA:
-        if (!rightSchema.equals(buildBatch.getSchema())) {
-          throw SchemaChangeException.schemaChanged("Hash join does not support schema changes in build side.", rightSchema, buildBatch.getSchema());
+        if (!buildSchema.equals(buildBatch.getSchema())) {
+          throw SchemaChangeException.schemaChanged("Hash join does not support schema changes in build side.", buildSchema, buildBatch.getSchema());
         }
         for (HashPartition partn : partitions) { partn.updateBatches(); }
         // Fall through
@@ -801,8 +836,16 @@ public void executeBuildPhase() throws SchemaChangeException {
       }
     }
 
+    prefetchFirstProbeBatch();
+
+    if (leftUpstream.isError()) {
+      // A termination condition was reached while prefetching the first probe side data holding batch.
+      // We need to terminate.
+      return leftUpstream;
+    }
+
     HashJoinMemoryCalculator.PostBuildCalculations postBuildCalc = buildCalc.next();
-    postBuildCalc.initialize();
+    postBuildCalc.initialize(probeSideIsEmpty.booleanValue()); // probeEmpty
 
     //
     //  Traverse all the in-memory partitions' incoming batches, and build their hash tables
@@ -849,14 +892,18 @@ public void executeBuildPhase() throws SchemaChangeException {
 
         spilledInners[partn.getPartitionNum()] = sp; // for the outer to find the SP later
         partn.closeWriter();
+
+        partn.updateProbeRecordsPerBatch(postBuildCalc.getProbeRecordsPerBatch());
       }
     }
+
+    return null;
   }
 
   private void setupOutputContainerSchema() {
 
-    if (rightSchema != null) {
-      for (final MaterializedField field : rightSchema) {
+    if (buildSchema != null) {
+      for (final MaterializedField field : buildSchema) {
         final MajorType inputType = field.getType();
         final MajorType outputType;
         // If left or full outer join, then the output type must be nullable. However, map types are
@@ -938,6 +985,7 @@ public HashJoinBatch(HashJoinPOP popConfig, FragmentContext context,
 
     this.allocator = oContext.getAllocator();
 
+    maxIncomingBatchSize = context.getOptions().getLong(ExecConstants.OUTPUT_BATCH_SIZE);
     numPartitions = (int)context.getOptions().getOption(ExecConstants.HASHJOIN_NUM_PARTITIONS_VALIDATOR);
     if ( numPartitions == 1 ) { //
       disableSpilling("Spilling is disabled due to configuration setting of num_partitions to 1");
@@ -976,7 +1024,7 @@ public HashJoinBatch(HashJoinPOP popConfig, FragmentContext context,
    * spillSet.
    */
   private void cleanup() {
-    if ( buildSideIsEmpty ) { return; } // not set up; nothing to clean
+    if ( buildSideIsEmpty.booleanValue() ) { return; } // not set up; nothing to clean
     if ( spillSet.getWriteBytes() > 0 ) {
       stats.setLongStat(Metric.SPILL_MB, // update stats - total MB spilled
         (int) Math.round(spillSet.getWriteBytes() / 1024.0D / 1024.0));
@@ -1027,7 +1075,7 @@ public String makeDebugString() {
    * written is updated at close time in {@link #cleanup()}.
    */
   private void updateStats() {
-    if ( buildSideIsEmpty ) { return; } // no stats when the right side is empty
+    if ( buildSideIsEmpty.booleanValue() ) { return; } // no stats when the right side is empty
     if ( cycleNum > 0 ) { return; } // These stats are only for before processing spilled files
 
     final HashTableStats htStats = new HashTableStats();
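
Taken together, the HashJoinBatch changes above defer reading the probe side until the build side has been consumed, so the two upstreams are no longer read simultaneously during prefetch/schema setup; that reordering is what the PR title describes as resolving the DRILL-6453 deadlock. A condensed sketch of the resulting control flow (written as if it were a member of HashJoinBatch, with error handling and spill recursion elided; not the literal patch code) is:

    // Simplified sketch; hypothetical method name.
    public IterOutcome innerNextSketch() throws SchemaChangeException {
      prefetchFirstBuildBatch();                       // read only the build side up front
      if (rightUpstream.isError()) {
        return rightUpstream;                          // terminate on build-side error
      }
      if (state == BatchState.FIRST) {
        IterOutcome termination = executeBuildPhase(); // consumes build side, then prefetches probe
        if (termination != null) {
          return termination;
        }
        updateStats();
      }
      prefetchFirstProbeBatch();                       // probe side is touched only after the build
      if (leftUpstream.isError()) {
        return leftUpstream;                           // terminate on probe-side error
      }
      // ... probe-and-project and spilled-partition handling as in the hunks above ...
      return IterOutcome.OK;
    }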
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java
index fb087a0fd0d..af6be8bfe3c 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMechanicalMemoryCalculator.java
@@ -59,6 +59,7 @@ public HashJoinState getState() {
 
     private int initialPartitions;
     private PartitionStatSet partitionStatSet;
+    private int recordsPerPartitionBatchProbe;
 
     public MechanicalBuildSidePartitioning(int maxNumInMemBatches) {
       this.maxNumInMemBatches = maxNumInMemBatches;
@@ -70,16 +71,18 @@ public void initialize(boolean autoTune,
                            RecordBatch buildSideBatch,
                            RecordBatch probeSideBatch,
                            Set<String> joinColumns,
+                           boolean probeEmpty,
                            long memoryAvailable,
+                           long maxIncomingBatchSize,
                            int initialPartitions,
                            int recordsPerPartitionBatchBuild,
                            int recordsPerPartitionBatchProbe,
                            int maxBatchNumRecordsBuild,
                            int maxBatchNumRecordsProbe,
-                           int outputBatchNumRecords,
                            int outputBatchSize,
                            double loadFactor) {
       this.initialPartitions = initialPartitions;
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
     }
 
     @Override
@@ -115,7 +118,7 @@ public String makeDebugString() {
     @Nullable
     @Override
     public PostBuildCalculations next() {
-      return new MechanicalPostBuildCalculations(maxNumInMemBatches, partitionStatSet);
+      return new MechanicalPostBuildCalculations(maxNumInMemBatches, partitionStatSet, recordsPerPartitionBatchProbe);
     }
 
     @Override
@@ -127,16 +130,23 @@ public HashJoinState getState() {
   public static class MechanicalPostBuildCalculations implements PostBuildCalculations {
     private final int maxNumInMemBatches;
     private final PartitionStatSet partitionStatSet;
+    private final int recordsPerPartitionBatchProbe;
 
-    public MechanicalPostBuildCalculations(int maxNumInMemBatches,
-                                           PartitionStatSet partitionStatSet) {
+    public MechanicalPostBuildCalculations(final int maxNumInMemBatches,
+                                           final PartitionStatSet partitionStatSet,
+                                           final int recordsPerPartitionBatchProbe) {
       this.maxNumInMemBatches = maxNumInMemBatches;
       this.partitionStatSet = Preconditions.checkNotNull(partitionStatSet);
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
     }
 
     @Override
-    public void initialize() {
-      // Do nothing
+    public void initialize(boolean probeEmpty) {
+    }
+
+    @Override
+    public int getProbeRecordsPerBatch() {
+      return recordsPerPartitionBatchProbe;
     }
 
     @Override
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java
index 868fbfd10ba..0ccd912d4ba 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculator.java
@@ -34,7 +34,7 @@
  * different memory calculations at each phase. The phases of execution have been broken down
  * into an explicit state machine diagram below. What ocurrs in each state is described in
  * the documentation of the {@link HashJoinState} class below. <b>Note:</b> the transition from Probing
- * and Partitioning back to Build Side Partitioning. This happens we had to spill probe side
+ * and Partitioning back to Build Side Partitioning. This happens when we had to spill probe side
  * partitions and we needed to recursively process spilled partitions. This recursion is
  * described in more detail in the example below.
  * </p>
@@ -86,6 +86,14 @@
   /**
    * The interface representing the {@link HashJoinStateCalculator} corresponding to the
    * {@link HashJoinState#BUILD_SIDE_PARTITIONING} state.
+   *
+   * <h4>Invariants</h4>
+   * <ul>
+   *   <li>
+   *     This calculator will only be used when there is build side data. If there is no build side data, the caller
+   *     should not invoke this calculator.
+   *   </li>
+   * </ul>
    */
   interface BuildSidePartitioning extends HashJoinStateCalculator<PostBuildCalculations> {
     void initialize(boolean autoTune,
@@ -93,13 +101,14 @@ void initialize(boolean autoTune,
                     RecordBatch buildSideBatch,
                     RecordBatch probeSideBatch,
                     Set<String> joinColumns,
+                    boolean probeEmpty,
                     long memoryAvailable,
+                    long maxIncomingBatchSize,
                     int initialPartitions,
                     int recordsPerPartitionBatchBuild,
                     int recordsPerPartitionBatchProbe,
                     int maxBatchNumRecordsBuild,
                     int maxBatchNumRecordsProbe,
-                    int outputBatchNumRecords,
                     int outputBatchSize,
                     double loadFactor);
 
@@ -121,7 +130,13 @@ void initialize(boolean autoTune,
    * {@link HashJoinState#POST_BUILD_CALCULATIONS} state.
    */
   interface PostBuildCalculations extends HashJoinStateCalculator<HashJoinMemoryCalculator> {
-    void initialize();
+    /**
+     * Initializes the calculator with additional information needed.
+     * @param probeEmpty True if the probe is empty. False otherwise.
+     */
+    void initialize(boolean probeEmpty);
+
+    int getProbeRecordsPerBatch();
 
     boolean shouldSpill();
 
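
With the interface change above, a caller now hands the probe-side emptiness to the post-build calculator and reads back the (possibly recomputed) probe records-per-batch value. A rough sketch of how a caller such as HashJoinBatch uses the extended contract, mirroring the executeBuildPhase() hunk earlier (simplified, not literal patch code; probeSideIsEmpty and spilledPartition are hypothetical local names):

    HashJoinMemoryCalculator.PostBuildCalculations postBuildCalc = buildCalc.next();
    postBuildCalc.initialize(probeSideIsEmpty);        // new: tell the calculator whether probe data exists

    if (postBuildCalc.shouldSpill()) {
      // ... spill a partition and re-check, as before ...
    }

    // New accessor: a partition that was spilled picks up the recomputed probe batch size,
    // so its spilled probe batches are written at a size the memory budget can tolerate.
    int probeRecordsPerBatch = postBuildCalc.getProbeRecordsPerBatch();
    spilledPartition.updateProbeRecordsPerBatch(probeRecordsPerBatch);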
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java
index 37f33295ee2..a351cbcaf1c 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashJoinMemoryCalculatorImpl.java
@@ -73,7 +73,9 @@ public BuildSidePartitioning next() {
         throw new IllegalArgumentException("Invalid calc type: " + hashTableCalculatorType);
       }
 
-      return new BuildSidePartitioningImpl(hashTableSizeCalculator,
+      return new BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        hashTableSizeCalculator,
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
         fragmentationFactor, safetyFactor);
     } else {
@@ -86,65 +88,28 @@ public HashJoinState getState() {
     return INITIALIZING;
   }
 
-  public static long computeMaxBatchSizeNoHash(final long incomingBatchSize,
-                                         final int incomingNumRecords,
-                                         final int desiredNumRecords,
-                                         final double fragmentationFactor,
-                                         final double safetyFactor) {
-    long maxBatchSize = HashJoinMemoryCalculatorImpl
-      .computePartitionBatchSize(incomingBatchSize, incomingNumRecords, desiredNumRecords);
-    // Multiple by fragmentation factor
-    return RecordBatchSizer.multiplyByFactors(maxBatchSize, fragmentationFactor, safetyFactor);
-  }
-
-  public static long computeMaxBatchSize(final long incomingBatchSize,
-                                         final int incomingNumRecords,
-                                         final int desiredNumRecords,
-                                         final double fragmentationFactor,
-                                         final double safetyFactor,
-                                         final boolean reserveHash) {
-    long size = computeMaxBatchSizeNoHash(incomingBatchSize,
-      incomingNumRecords,
-      desiredNumRecords,
-      fragmentationFactor,
-      safetyFactor);
-
-    if (!reserveHash) {
-      return size;
-    }
-
-    long hashSize = desiredNumRecords * ((long) IntVector.VALUE_WIDTH);
-    hashSize = RecordBatchSizer.multiplyByFactors(hashSize, fragmentationFactor);
-
-    return size + hashSize;
-  }
-
-  public static long computePartitionBatchSize(final long incomingBatchSize,
-                                               final int incomingNumRecords,
-                                               final int desiredNumRecords) {
-    return (long) Math.ceil((((double) incomingBatchSize) /
-      ((double) incomingNumRecords)) *
-      ((double) desiredNumRecords));
-  }
-
   public static class NoopBuildSidePartitioningImpl implements BuildSidePartitioning {
     private int initialPartitions;
+    private int recordsPerPartitionBatchProbe;
 
     @Override
     public void initialize(boolean autoTune,
                            boolean reserveHash,
                            RecordBatch buildSideBatch,
-                           RecordBatch probeSideBatch, Set<String> joinColumns,
+                           RecordBatch probeSideBatch,
+                           Set<String> joinColumns,
+                           boolean probeEmpty,
                            long memoryAvailable,
+                           long maxIncomingBatchSize,
                            int initialPartitions,
                            int recordsPerPartitionBatchBuild,
                            int recordsPerPartitionBatchProbe,
                            int maxBatchNumRecordsBuild,
                            int maxBatchNumRecordsProbe,
-                           int outputBatchNumRecords,
                            int outputBatchSize,
                            double loadFactor) {
       this.initialPartitions = initialPartitions;
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
     }
 
     @Override
@@ -180,7 +145,7 @@ public String makeDebugString() {
     @Nullable
     @Override
     public PostBuildCalculations next() {
-      return new NoopPostBuildCalculationsImpl();
+      return new NoopPostBuildCalculationsImpl(recordsPerPartitionBatchProbe);
     }
 
     @Override
@@ -204,7 +169,7 @@ public HashJoinState getState() {
    * <h1>Life Cycle</h1>
    * <p>
    *   <ul>
-   *     <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, RecordBatch, RecordBatch, Set, long, int, int, int, int, int, int, int, double)}.
+   *     <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, RecordBatch, RecordBatch, Set, boolean, long, long, int, int, int, int, int, int, double)}.
    *     This will initialize the StateCalculate with the additional information it needs.</li>
    *     <li><b>Step 1:</b> Call {@link #getNumPartitions()} to see the number of partitions that fit in memory.</li>
    *     <li><b>Step 2:</b> Call {@link #shouldSpill()} To determine if spilling needs to occurr.</li>
@@ -215,6 +180,7 @@ public HashJoinState getState() {
   public static class BuildSidePartitioningImpl implements BuildSidePartitioning {
     public static final Logger log = LoggerFactory.getLogger(BuildSidePartitioning.class);
 
+    private final BatchSizePredictor.Factory batchSizePredictorFactory;
     private final HashTableSizeCalculator hashTableSizeCalculator;
     private final HashJoinHelperSizeCalculator hashJoinHelperSizeCalculator;
     private final double fragmentationFactor;
@@ -223,10 +189,8 @@ public HashJoinState getState() {
     private int maxBatchNumRecordsBuild;
     private int maxBatchNumRecordsProbe;
     private long memoryAvailable;
-    private long buildBatchSize;
-    private long probeBatchSize;
-    private int buildNumRecords;
-    private int probeNumRecords;
+    private boolean probeEmpty;
+    private long maxIncomingBatchSize;
     private long maxBuildBatchSize;
     private long maxProbeBatchSize;
     private long maxOutputBatchSize;
@@ -246,13 +210,17 @@ public HashJoinState getState() {
     private long reservedMemory;
     private long maxReservedMemory;
 
+    private BatchSizePredictor buildSizePredictor;
+    private BatchSizePredictor probeSizePredictor;
     private boolean firstInitialized;
     private boolean initialized;
 
-    public BuildSidePartitioningImpl(final HashTableSizeCalculator hashTableSizeCalculator,
+    public BuildSidePartitioningImpl(final BatchSizePredictor.Factory batchSizePredictorFactory,
+                                     final HashTableSizeCalculator hashTableSizeCalculator,
                                      final HashJoinHelperSizeCalculator hashJoinHelperSizeCalculator,
                                      final double fragmentationFactor,
                                      final double safetyFactor) {
+      this.batchSizePredictorFactory = Preconditions.checkNotNull(batchSizePredictorFactory);
       this.hashTableSizeCalculator = Preconditions.checkNotNull(hashTableSizeCalculator);
       this.hashJoinHelperSizeCalculator = Preconditions.checkNotNull(hashJoinHelperSizeCalculator);
       this.fragmentationFactor = fragmentationFactor;
@@ -262,35 +230,33 @@ public BuildSidePartitioningImpl(final HashTableSizeCalculator hashTableSizeCalc
     @Override
     public void initialize(boolean autoTune,
                            boolean reserveHash,
-                           RecordBatch buildSideBatch,
-                           RecordBatch probeSideBatch,
+                           RecordBatch buildBatch,
+                           RecordBatch probeBatch,
                            Set<String> joinColumns,
+                           boolean probeEmpty,
                            long memoryAvailable,
+                           long maxIncomingBatchSize,
                            int initialPartitions,
                            int recordsPerPartitionBatchBuild,
                            int recordsPerPartitionBatchProbe,
                            int maxBatchNumRecordsBuild,
                            int maxBatchNumRecordsProbe,
-                           int outputBatchNumRecords,
                            int outputBatchSize,
                            double loadFactor) {
-      Preconditions.checkNotNull(buildSideBatch);
-      Preconditions.checkNotNull(probeSideBatch);
+      Preconditions.checkNotNull(probeBatch);
+      Preconditions.checkNotNull(buildBatch);
       Preconditions.checkNotNull(joinColumns);
 
-      final RecordBatchSizer buildSizer = new RecordBatchSizer(buildSideBatch);
-      final RecordBatchSizer probeSizer = new RecordBatchSizer(probeSideBatch);
+      final BatchSizePredictor buildSizePredictor =
+        batchSizePredictorFactory.create(buildBatch, fragmentationFactor, safetyFactor);
+      final BatchSizePredictor probeSizePredictor =
+        batchSizePredictorFactory.create(probeBatch, fragmentationFactor, safetyFactor);
 
-      long buildBatchSize = getBatchSizeEstimate(buildSideBatch);
-      long probeBatchSize = getBatchSizeEstimate(probeSideBatch);
+      buildSizePredictor.updateStats();
+      probeSizePredictor.updateStats();
 
-      int buildNumRecords = buildSizer.rowCount();
-      int probeNumRecords = probeSizer.rowCount();
+      final RecordBatchSizer buildSizer = new RecordBatchSizer(buildBatch);
 
-      final CaseInsensitiveMap<Long> buildValueSizes = getNotExcludedColumnSizes(
-        joinColumns, buildSizer);
-      final CaseInsensitiveMap<Long> probeValueSizes = getNotExcludedColumnSizes(
-        joinColumns, probeSizer);
       final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
       for (String joinColumn: joinColumns) {
@@ -302,11 +268,11 @@ public void initialize(boolean autoTune,
         reserveHash,
         keySizes,
         memoryAvailable,
+        maxIncomingBatchSize,
         initialPartitions,
-        buildBatchSize,
-        probeBatchSize,
-        buildNumRecords,
-        probeNumRecords,
+        probeEmpty,
+        buildSizePredictor,
+        probeSizePredictor,
         recordsPerPartitionBatchBuild,
         recordsPerPartitionBatchProbe,
         maxBatchNumRecordsBuild,
@@ -315,48 +281,16 @@ public void initialize(boolean autoTune,
         loadFactor);
     }
 
-    @VisibleForTesting
-    protected static CaseInsensitiveMap<Long> getNotExcludedColumnSizes(
-        final Set<String> excludedColumns,
-        final RecordBatchSizer batchSizer) {
-      final CaseInsensitiveMap<Long> columnSizes = CaseInsensitiveMap.newHashMap();
-      final CaseInsensitiveMap<Boolean> excludedSet = CaseInsensitiveMap.newHashMap();
-
-      for (final String excludedColumn: excludedColumns) {
-        excludedSet.put(excludedColumn, true);
-      }
-
-      for (final Map.Entry<String, RecordBatchSizer.ColumnSize> entry: batchSizer.columns().entrySet()) {
-        final String columnName = entry.getKey();
-        final RecordBatchSizer.ColumnSize columnSize = entry.getValue();
-
-        columnSizes.put(columnName, (long) columnSize.getStdNetOrNetSizePerEntry());
-      }
-
-      return columnSizes;
-    }
-
-    public static long getBatchSizeEstimate(final RecordBatch recordBatch) {
-      final RecordBatchSizer sizer = new RecordBatchSizer(recordBatch);
-      long size = 0L;
-
-      for (Map.Entry<String, RecordBatchSizer.ColumnSize> column: sizer.columns().entrySet()) {
-        size += PostBuildCalculationsImpl.computeValueVectorSize(recordBatch.getRecordCount(), column.getValue().getStdNetOrNetSizePerEntry());
-      }
-
-      return size;
-    }
-
     @VisibleForTesting
     protected void initialize(boolean autoTune,
                               boolean reserveHash,
                               CaseInsensitiveMap<Long> keySizes,
                               long memoryAvailable,
+                              long maxIncomingBatchSize,
                               int initialPartitions,
-                              long buildBatchSize,
-                              long probeBatchSize,
-                              int buildNumRecords,
-                              int probeNumRecords,
+                              boolean probeEmpty,
+                              BatchSizePredictor buildSizePredictor,
+                              BatchSizePredictor probeSizePredictor,
                               int recordsPerPartitionBatchBuild,
                               int recordsPerPartitionBatchProbe,
                               int maxBatchNumRecordsBuild,
@@ -365,6 +299,9 @@ protected void initialize(boolean autoTune,
                               double loadFactor) {
       Preconditions.checkState(!firstInitialized);
       Preconditions.checkArgument(initialPartitions >= 1);
+      // If we had probe data before, there should still be probe data now.
+      // If we didn't have probe data before, we could get some new data now.
+      Preconditions.checkState(!(probeEmpty && probeSizePredictor.hadDataLastTime()));
       firstInitialized = true;
 
       this.loadFactor = loadFactor;
@@ -372,10 +309,10 @@ protected void initialize(boolean autoTune,
       this.reserveHash = reserveHash;
       this.keySizes = Preconditions.checkNotNull(keySizes);
       this.memoryAvailable = memoryAvailable;
-      this.buildBatchSize = buildBatchSize;
-      this.probeBatchSize = probeBatchSize;
-      this.buildNumRecords = buildNumRecords;
-      this.probeNumRecords = probeNumRecords;
+      this.probeEmpty = probeEmpty;
+      this.maxIncomingBatchSize = maxIncomingBatchSize;
+      this.buildSizePredictor = buildSizePredictor;
+      this.probeSizePredictor = probeSizePredictor;
       this.initialPartitions = initialPartitions;
       this.recordsPerPartitionBatchBuild = recordsPerPartitionBatchBuild;
       this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
@@ -420,31 +357,32 @@ public long getMaxReservedMemory() {
     private void calculateMemoryUsage()
     {
       // Adjust based on number of records
-      maxBuildBatchSize = computeMaxBatchSizeNoHash(buildBatchSize, buildNumRecords,
-        maxBatchNumRecordsBuild, fragmentationFactor, safetyFactor);
-      maxProbeBatchSize = computeMaxBatchSizeNoHash(probeBatchSize, probeNumRecords,
-        maxBatchNumRecordsProbe, fragmentationFactor, safetyFactor);
-
-      // Safety factor can be multiplied at the end since these batches are coming from exchange operators, so no excess value vector doubling
-      partitionBuildBatchSize = computeMaxBatchSize(buildBatchSize,
-        buildNumRecords,
-        recordsPerPartitionBatchBuild,
-        fragmentationFactor,
-        safetyFactor,
-        reserveHash);
+      maxBuildBatchSize = buildSizePredictor.predictBatchSize(maxBatchNumRecordsBuild, false);
 
-      // Safety factor can be multiplied at the end since these batches are coming from exchange operators, so no excess value vector doubling
-      partitionProbeBatchSize = computeMaxBatchSize(
-        probeBatchSize,
-        probeNumRecords,
-        recordsPerPartitionBatchProbe,
-        fragmentationFactor,
-        safetyFactor,
-        reserveHash);
+      if (probeSizePredictor.hadDataLastTime()) {
+        // We have probe data, so we can compute the max incoming size.
+        maxProbeBatchSize = probeSizePredictor.predictBatchSize(maxBatchNumRecordsProbe, false);
+      } else {
+        // We don't have probe data
+        if (probeEmpty) {
+          // We know the probe has no data, so we don't need to reserve any space for the incoming probe
+          maxProbeBatchSize = 0;
+        } else {
+          // The probe side may have data, so assume it is the max incoming batch size. This assumption
+          // can fail in some cases since the batch sizing project is incomplete.
+          maxProbeBatchSize = maxIncomingBatchSize;
+        }
+      }
+
+      partitionBuildBatchSize = buildSizePredictor.predictBatchSize(recordsPerPartitionBatchBuild, reserveHash);
+
+      if (probeSizePredictor.hadDataLastTime()) {
+        partitionProbeBatchSize = probeSizePredictor.predictBatchSize(recordsPerPartitionBatchProbe, reserveHash);
+      }
 
       maxOutputBatchSize = (long) ((double)outputBatchSize * fragmentationFactor * safetyFactor);
 
-      long probeReservedMemory;
+      long probeReservedMemory = 0;
 
       for (partitions = initialPartitions;; partitions /= 2) {
         // The total amount of memory to reserve for incomplete batches across all partitions
@@ -455,13 +393,19 @@ private void calculateMemoryUsage()
         // they will have a well defined size.
         reservedMemory = incompletePartitionsBatchSizes + maxBuildBatchSize + maxProbeBatchSize;
 
-        probeReservedMemory = PostBuildCalculationsImpl.calculateReservedMemory(
-          partitions,
-          maxProbeBatchSize,
-          maxOutputBatchSize,
-          partitionProbeBatchSize);
+        if (probeSizePredictor.hadDataLastTime()) {
+          // If we have probe data, use it in our memory reservation calculations.
+          probeReservedMemory = PostBuildCalculationsImpl.calculateReservedMemory(
+            partitions,
+            maxProbeBatchSize,
+            maxOutputBatchSize,
+            partitionProbeBatchSize);
 
-        maxReservedMemory = Math.max(reservedMemory, probeReservedMemory);
+          maxReservedMemory = Math.max(reservedMemory, probeReservedMemory);
+        } else {
+          // If we do not have probe data, make a best-effort estimate of the number of partitions without it.
+          maxReservedMemory = reservedMemory;
+        }
 
         if (!autoTune || maxReservedMemory <= memoryAvailable) {
           // Stop the tuning loop if we are not doing auto tuning, or if we are living within our memory limit
@@ -488,19 +432,19 @@ private void calculateMemoryUsage()
           "partitionProbeBatchSize = %d\n" +
           "recordsPerPartitionBatchProbe = %d\n",
           reservedMemory, memoryAvailable, partitions, initialPartitions,
-          buildBatchSize,
-          buildNumRecords,
+          buildSizePredictor.getBatchSize(),
+          buildSizePredictor.getNumRecords(),
           partitionBuildBatchSize,
           recordsPerPartitionBatchBuild,
-          probeBatchSize,
-          probeNumRecords,
+          probeSizePredictor.getBatchSize(),
+          probeSizePredictor.getNumRecords(),
           partitionProbeBatchSize,
           recordsPerPartitionBatchProbe);
 
         String phase = "Probe phase: ";
 
         if (reservedMemory > memoryAvailable) {
-          if (probeReservedMemory > memoryAvailable) {
+          if (probeSizePredictor.hadDataLastTime() && probeReservedMemory > memoryAvailable) {
             phase = "Build and Probe phases: ";
           } else {
             phase = "Build phase: ";
@@ -531,10 +475,12 @@ public boolean shouldSpill() {
     public PostBuildCalculations next() {
       Preconditions.checkState(initialized);
 
-      return new PostBuildCalculationsImpl(memoryAvailable,
-        partitionProbeBatchSize,
-        maxProbeBatchSize,
+      return new PostBuildCalculationsImpl(
+        probeSizePredictor,
+        memoryAvailable,
         maxOutputBatchSize,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         partitionStatsSet,
         keySizes,
         hashTableSizeCalculator,
@@ -572,9 +518,19 @@ public String makeDebugString() {
   }
 
   public static class NoopPostBuildCalculationsImpl implements PostBuildCalculations {
+    private final int recordsPerPartitionBatchProbe;
+
+    public NoopPostBuildCalculationsImpl(final int recordsPerPartitionBatchProbe) {
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
+    }
+
     @Override
-    public void initialize() {
+    public void initialize(boolean hasProbeData) {
+    }
 
+    @Override
+    public int getProbeRecordsPerBatch() {
+      return recordsPerPartitionBatchProbe;
     }
 
     @Override
@@ -610,7 +566,7 @@ public String makeDebugString() {
    * <h1>Lifecycle</h1>
    * <p>
    *   <ul>
-   *     <li><b>Step 1:</b> Call {@link #initialize()}. This
+   *     <li><b>Step 1:</b> Call {@link #initialize(boolean)}. This
    *     gives the {@link HashJoinStateCalculator} additional information it needs to compute memory requirements.</li>
    *     <li><b>Step 2:</b> Call {@link #shouldSpill()}. This tells
    *     you which build side partitions need to be spilled in order to make room for probing.</li>
@@ -620,10 +576,15 @@ public String makeDebugString() {
    * </p>
    */
   public static class PostBuildCalculationsImpl implements PostBuildCalculations {
+    private static final Logger log = LoggerFactory.getLogger(PostBuildCalculationsImpl.class);
+
+    public static final int MIN_RECORDS_PER_PARTITION_BATCH_PROBE = 10;
+
+    private final BatchSizePredictor probeSizePredictor;
     private final long memoryAvailable;
-    private final long partitionProbeBatchSize;
-    private final long maxProbeBatchSize;
     private final long maxOutputBatchSize;
+    private final int maxBatchNumRecordsProbe;
+    private final int recordsPerPartitionBatchProbe;
     private final PartitionStatSet buildPartitionStatSet;
     private final Map<String, Long> keySizes;
     private final HashTableSizeCalculator hashTableSizeCalculator;
@@ -632,26 +593,30 @@ public String makeDebugString() {
     private final double safetyFactor;
     private final double loadFactor;
     private final boolean reserveHash;
-    // private final long maxOutputBatchSize;
 
     private boolean initialized;
     private long consumedMemory;
+    private boolean probeEmpty;
+    private long maxProbeBatchSize;
+    private long partitionProbeBatchSize;
+    private int computedProbeRecordsPerBatch;
 
-    public PostBuildCalculationsImpl(final long memoryAvailable,
-                                     final long partitionProbeBatchSize,
-                                     final long maxProbeBatchSize,
-                                     final long maxOutputBatchSize,
-                                     final PartitionStatSet buildPartitionStatSet,
-                                     final Map<String, Long> keySizes,
-                                     final HashTableSizeCalculator hashTableSizeCalculator,
-                                     final HashJoinHelperSizeCalculator hashJoinHelperSizeCalculator,
-                                     final double fragmentationFactor,
-                                     final double safetyFactor,
-                                     final double loadFactor,
-                                     final boolean reserveHash) {
+    @VisibleForTesting
+    public PostBuildCalculationsImpl(final BatchSizePredictor probeSizePredictor,
+                                      final long memoryAvailable,
+                                      final long maxOutputBatchSize,
+                                      final int maxBatchNumRecordsProbe,
+                                      final int recordsPerPartitionBatchProbe,
+                                      final PartitionStatSet buildPartitionStatSet,
+                                      final Map<String, Long> keySizes,
+                                      final HashTableSizeCalculator hashTableSizeCalculator,
+                                      final HashJoinHelperSizeCalculator hashJoinHelperSizeCalculator,
+                                      final double fragmentationFactor,
+                                      final double safetyFactor,
+                                      final double loadFactor,
+                                      final boolean reserveHash) {
+      this.probeSizePredictor = Preconditions.checkNotNull(probeSizePredictor);
       this.memoryAvailable = memoryAvailable;
-      this.partitionProbeBatchSize = partitionProbeBatchSize;
-      this.maxProbeBatchSize = maxProbeBatchSize;
       this.maxOutputBatchSize = maxOutputBatchSize;
       this.buildPartitionStatSet = Preconditions.checkNotNull(buildPartitionStatSet);
       this.keySizes = Preconditions.checkNotNull(keySizes);
@@ -661,38 +626,100 @@ public PostBuildCalculationsImpl(final long memoryAvailable,
       this.safetyFactor = safetyFactor;
       this.loadFactor = loadFactor;
       this.reserveHash = reserveHash;
+      this.maxBatchNumRecordsProbe = maxBatchNumRecordsProbe;
+      this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
+      this.computedProbeRecordsPerBatch = recordsPerPartitionBatchProbe;
     }
 
-    // TODO take an incoming Probe RecordBatch
     @Override
-    public void initialize() {
+    public void initialize(boolean probeEmpty) {
       Preconditions.checkState(!initialized);
+      // If we had probe data before, there should still be probe data now.
+      // If we didn't have probe data before, we could get some new data now.
+      Preconditions.checkState(probeSizePredictor.hadDataLastTime() && !probeEmpty || !probeSizePredictor.hadDataLastTime());
       initialized = true;
+      this.probeEmpty = probeEmpty;
+
+      if (probeEmpty) {
+        // We know there is no probe side data, so we don't need to calculate anything.
+        return;
+      }
+
+      // We need to compute sizes of probe side data.
+      if (!probeSizePredictor.hadDataLastTime()) {
+        probeSizePredictor.updateStats();
+      }
+
+      maxProbeBatchSize = probeSizePredictor.predictBatchSize(maxBatchNumRecordsProbe, false);
+      partitionProbeBatchSize = probeSizePredictor.predictBatchSize(recordsPerPartitionBatchProbe, reserveHash);
+
+      long worstCaseProbeMemory = calculateReservedMemory(
+        buildPartitionStatSet.getSize(),
+        maxProbeBatchSize,
+        maxOutputBatchSize,
+        partitionProbeBatchSize);
+
+      if (worstCaseProbeMemory > memoryAvailable) {
+        // We don't have enough memory for the probe data if all the partitions are spilled, so we need to adjust
+        // the records per probe partition batch to make this work.
+
+        computedProbeRecordsPerBatch = computeProbeRecordsPerBatch(memoryAvailable,
+          buildPartitionStatSet.getSize(),
+          recordsPerPartitionBatchProbe,
+          MIN_RECORDS_PER_PARTITION_BATCH_PROBE,
+          maxProbeBatchSize,
+          maxOutputBatchSize,
+          partitionProbeBatchSize);
+
+        partitionProbeBatchSize = probeSizePredictor.predictBatchSize(computedProbeRecordsPerBatch, reserveHash);
+      }
     }
 
-    public long getConsumedMemory() {
+    @Override
+    public int getProbeRecordsPerBatch() {
       Preconditions.checkState(initialized);
-      return consumedMemory;
+      return computedProbeRecordsPerBatch;
     }
 
-    // TODO move this somewhere else that makes sense
-    public static long computeValueVectorSize(long numRecords, long byteSize)
-    {
-      long naiveSize = numRecords * byteSize;
-      return roundUpToPowerOf2(naiveSize);
+    @VisibleForTesting
+    public long getMaxProbeBatchSize() {
+      return maxProbeBatchSize;
     }
 
-    public static long computeValueVectorSize(long numRecords, long byteSize, double safetyFactor)
-    {
-      long naiveSize = RecordBatchSizer.multiplyByFactor(numRecords * byteSize, safetyFactor);
-      return roundUpToPowerOf2(naiveSize);
+    @VisibleForTesting
+    public long getPartitionProbeBatchSize() {
+      return partitionProbeBatchSize;
     }
 
-    // TODO move to drill common
-    public static long roundUpToPowerOf2(long num)
-    {
-      Preconditions.checkArgument(num >= 1);
-      return num == 1 ? 1 : Long.highestOneBit(num - 1) << 1;
+    public long getConsumedMemory() {
+      Preconditions.checkState(initialized);
+      return consumedMemory;
+    }
+
+    public static int computeProbeRecordsPerBatch(final long memoryAvailable,
+                                                  final int numPartitions,
+                                                  final int defaultProbeRecordsPerBatch,
+                                                  final int minProbeRecordsPerBatch,
+                                                  final long maxProbeBatchSize,
+                                                  final long maxOutputBatchSize,
+                                                  final long defaultPartitionProbeBatchSize) {
+      long memoryForPartitionBatches = memoryAvailable - maxProbeBatchSize - maxOutputBatchSize;
+
+      if (memoryForPartitionBatches < 0) {
+        // We just don't have enough memory; do the best we can by using the minimum batch size.
+        log.warn("Not enough memory for probing:\n" +
+          "Memory available: {}\n" +
+          "Max probe batch size: {}\n" +
+          "Max output batch size: {}",
+          memoryAvailable,
+          maxProbeBatchSize,
+          maxOutputBatchSize);
+        return minProbeRecordsPerBatch;
+      }
+
+      long memoryForPartitionBatch = (memoryForPartitionBatches + numPartitions - 1) / numPartitions;
+      long scaleFactor = (defaultPartitionProbeBatchSize + memoryForPartitionBatch - 1) / memoryForPartitionBatch;
+      return Math.max((int) (defaultProbeRecordsPerBatch / scaleFactor), minProbeRecordsPerBatch);
     }
 
     public static long calculateReservedMemory(final int numSpilledPartitions,
@@ -710,6 +737,11 @@ public static long calculateReservedMemory(final int numSpilledPartitions,
     public boolean shouldSpill() {
       Preconditions.checkState(initialized);
 
+      if (probeEmpty) {
+        // If the probe is empty, we should not trigger any spills.
+        return false;
+      }
+
       long reservedMemory = calculateReservedMemory(
         buildPartitionStatSet.getNumSpilledPartitions(),
         maxProbeBatchSize,
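
A standalone sketch (not part of the patch, with hypothetical class/method names) of the probe-records-per-batch scaling that PostBuildCalculationsImpl.computeProbeRecordsPerBatch performs above: reserve room for one incoming probe batch and one output batch, split the remainder across the partitions with ceiling division, and shrink the default records per batch by the resulting scale factor, never going below the minimum. The example numbers in main() are taken from testComputedProbeRecordsPerBatch further down in this diff.

// Standalone sketch of the scaling logic; names here are illustrative only.
public class ProbeRecordsPerBatchSketch {
  static int computeProbeRecordsPerBatch(long memoryAvailable,
                                         int numPartitions,
                                         int defaultProbeRecordsPerBatch,
                                         int minProbeRecordsPerBatch,
                                         long maxProbeBatchSize,
                                         long maxOutputBatchSize,
                                         long defaultPartitionProbeBatchSize) {
    // Reserve memory for one incoming probe batch and one output batch first.
    long memoryForPartitionBatches = memoryAvailable - maxProbeBatchSize - maxOutputBatchSize;
    if (memoryForPartitionBatches < 0) {
      // Not enough memory even before partition batches: fall back to the floor.
      return minProbeRecordsPerBatch;
    }
    // Ceiling divisions, as in the patch.
    long memoryForPartitionBatch = (memoryForPartitionBatches + numPartitions - 1) / numPartitions;
    long scaleFactor = (defaultPartitionProbeBatchSize + memoryForPartitionBatch - 1) / memoryForPartitionBatch;
    return Math.max((int) (defaultProbeRecordsPerBatch / scaleFactor), minProbeRecordsPerBatch);
  }

  public static void main(String[] args) {
    // Numbers from testComputedProbeRecordsPerBatch below:
    // (200 - 50 - 50) / 2 partitions = 50 bytes per partition batch,
    // scale factor = ceil(200 / 50) = 4, so 100 / 4 = 25 records per probe partition batch.
    System.out.println(computeProbeRecordsPerBatch(200, 2, 100, 10, 50, 50, 200)); // 25
  }
}
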
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
index 85750210ca8..a366eeafcbd 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorConservativeImpl.java
@@ -23,7 +23,7 @@
 
 import java.util.Map;
 
-import static org.apache.drill.exec.physical.impl.join.HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize;
+import static org.apache.drill.exec.physical.impl.join.BatchSizePredictorImpl.computeValueVectorSize;
 
 public class HashTableSizeCalculatorConservativeImpl implements HashTableSizeCalculator {
   public static final String TYPE = "CONSERVATIVE";
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
index 4f9e5855ed1..265b0e33768 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/join/HashTableSizeCalculatorLeanImpl.java
@@ -23,7 +23,7 @@
 
 import java.util.Map;
 
-import static org.apache.drill.exec.physical.impl.join.HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize;
+import static org.apache.drill.exec.physical.impl.join.BatchSizePredictorImpl.computeValueVectorSize;
 
 public class HashTableSizeCalculatorLeanImpl implements HashTableSizeCalculator {
   public static final String TYPE = "LEAN";
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
index d75463b0a7a..e7fa4e6b57b 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/record/AbstractBinaryRecordBatch.java
@@ -76,12 +76,12 @@ protected AbstractBinaryRecordBatch(final T popConfig, final FragmentContext con
   }
 
   protected boolean verifyOutcomeToSetBatchState(IterOutcome leftOutcome, IterOutcome rightOutcome) {
-    if (leftOutcome == IterOutcome.STOP || rightUpstream == IterOutcome.STOP) {
+    if (leftOutcome == IterOutcome.STOP || rightOutcome == IterOutcome.STOP) {
       state = BatchState.STOP;
       return false;
     }
 
-    if (leftOutcome == IterOutcome.OUT_OF_MEMORY || rightUpstream == IterOutcome.OUT_OF_MEMORY) {
+    if (leftOutcome == IterOutcome.OUT_OF_MEMORY || rightOutcome == IterOutcome.OUT_OF_MEMORY) {
       state = BatchState.OUT_OF_MEMORY;
       return false;
     }
@@ -97,6 +97,7 @@ protected boolean verifyOutcomeToSetBatchState(IterOutcome leftOutcome, IterOutc
       throw new IllegalStateException("Unexpected IterOutcome.EMIT received either from left or right side in " +
         "buildSchema phase");
     }
+
     return true;
   }
 
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java
index 6954374ce75..f0cab26c432 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatch.java
@@ -116,7 +116,7 @@
      *   returned at least once (not necessarily <em>immediately</em> after).
      * </p>
      */
-    NONE,
+    NONE(false),
 
     /**
      * Zero or more records with same schema.
@@ -134,7 +134,7 @@
      *   returned at least once (not necessarily <em>immediately</em> after).
      * </p>
      */
-    OK,
+    OK(false),
 
     /**
      * New schema, maybe with records.
@@ -147,7 +147,7 @@
      *     ({@code next()} should be called again.)
      * </p>
      */
-    OK_NEW_SCHEMA,
+    OK_NEW_SCHEMA(false),
 
     /**
      * Non-completion (abnormal) termination.
@@ -162,7 +162,7 @@
      *   of things.
      * </p>
      */
-    STOP,
+    STOP(true),
 
     /**
      * No data yet.
@@ -184,7 +184,7 @@
      *   Used by batches that haven't received incoming data yet.
      * </p>
      */
-    NOT_YET,
+    NOT_YET(false),
 
     /**
      * Out of memory (not fatal).
@@ -198,7 +198,7 @@
      *     {@code OUT_OF_MEMORY} to its caller) and call {@code next()} again.
      * </p>
      */
-    OUT_OF_MEMORY,
+    OUT_OF_MEMORY(true),
 
     /**
      * Emit record to produce output batches.
@@ -223,7 +223,17 @@
      *   input and again start from build side.
      * </p>
      */
-    EMIT,
+    EMIT(false);
+
+    private boolean error;
+
+    IterOutcome(boolean error) {
+      this.error = error;
+    }
+
+    public boolean isError() {
+      return error;
+    }
   }
 
   /**
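
The IterOutcome change above attaches an error flag to each constant so that callers can ask outcome.isError() instead of enumerating STOP and OUT_OF_MEMORY at every call site (as the AbstractBinaryRecordBatch fix earlier in this diff had to). Below is a minimal, self-contained sketch of that enum-with-flag pattern; the names are hypothetical and Drill's real IterOutcome carries more states and javadoc.

// Minimal sketch of the enum-with-flag pattern; not Drill's actual enum.
public class OutcomeFlagSketch {
  enum Outcome {
    OK(false), STOP(true), OUT_OF_MEMORY(true);

    private final boolean error;

    Outcome(boolean error) { this.error = error; }

    boolean isError() { return error; }
  }

  public static void main(String[] args) {
    Outcome left = Outcome.OK;
    Outcome right = Outcome.STOP;
    // Callers can collapse the per-state checks into a single test.
    System.out.println(left.isError() || right.isError()); // true
  }
}
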
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBatchSizePredictorImpl.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBatchSizePredictorImpl.java
new file mode 100644
index 00000000000..e16cdf64c52
--- /dev/null
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBatchSizePredictorImpl.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.physical.impl.join;
+
+import org.apache.drill.exec.vector.IntVector;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestBatchSizePredictorImpl {
+  @Test
+  public void testComputeMaxBatchSizeHash()
+  {
+    long expected = BatchSizePredictorImpl.computeMaxBatchSizeNoHash(
+      100,
+      25,
+      100,
+      2.0,
+      4.0) +
+      100 * IntVector.VALUE_WIDTH * 2;
+
+    final long actual = BatchSizePredictorImpl.computeMaxBatchSize(
+      100,
+      25,
+      100,
+      2.0,
+      4.0,
+      true);
+
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testComputeMaxBatchSizeNoHash() {
+    final long expected = 1200;
+    final long actual = BatchSizePredictorImpl.computeMaxBatchSize(
+      100,
+      25,
+      100,
+      2.0,
+      1.5,
+      false);
+    final long actualNoHash = BatchSizePredictorImpl.computeMaxBatchSizeNoHash(
+      100,
+      25,
+      100,
+      2.0,
+      1.5);
+
+    Assert.assertEquals(expected, actual);
+    Assert.assertEquals(expected, actualNoHash);
+  }
+
+  @Test
+  public void testRoundUpPowerOf2() {
+    long expected = 32;
+    long actual = BatchSizePredictorImpl.roundUpToPowerOf2(expected);
+
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testRoundUpNonPowerOf2ToPowerOf2() {
+    long expected = 32;
+    long actual = BatchSizePredictorImpl.roundUpToPowerOf2(31);
+
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testComputeValueVectorSizePowerOf2() {
+    long expected = 4;
+    long actual =
+      BatchSizePredictorImpl.computeValueVectorSize(2, 2);
+
+    Assert.assertEquals(expected, actual);
+  }
+
+  @Test
+  public void testComputeValueVectorSizeNonPowerOf2() {
+    long expected = 16;
+    long actual =
+      BatchSizePredictorImpl.computeValueVectorSize(3, 3);
+
+    Assert.assertEquals(expected, actual);
+  }
+}
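
The new test above exercises the power-of-two rounding and value-vector sizing that this patch moves from PostBuildCalculationsImpl into BatchSizePredictorImpl. For reference, here is a standalone sketch (illustrative names, not Drill code) reproducing the same arithmetic as the removed helpers shown earlier in this diff, with the test's expected values as output.

// Standalone sketch matching the removed roundUpToPowerOf2/computeValueVectorSize helpers.
public class ValueVectorSizeSketch {
  static long roundUpToPowerOf2(long num) {
    if (num < 1) {
      throw new IllegalArgumentException("num must be >= 1");
    }
    return num == 1 ? 1 : Long.highestOneBit(num - 1) << 1;
  }

  static long computeValueVectorSize(long numRecords, long byteSize) {
    // Value vectors allocate power-of-two buffers, so round the naive size up.
    return roundUpToPowerOf2(numRecords * byteSize);
  }

  public static void main(String[] args) {
    System.out.println(roundUpToPowerOf2(32));         // 32 (already a power of two)
    System.out.println(roundUpToPowerOf2(31));         // 32
    System.out.println(computeValueVectorSize(2, 2));  // 4
    System.out.println(computeValueVectorSize(3, 3));  // 16 (9 rounded up)
  }
}
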
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
index 2a44edb1795..ceebc811c0f 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestBuildSidePartitioningImpl.java
@@ -17,6 +17,7 @@
  */
 package org.apache.drill.exec.physical.impl.join;
 
+import com.google.common.base.Preconditions;
 import org.apache.drill.common.map.CaseInsensitiveMap;
 import org.apache.drill.exec.record.RecordBatch;
 import org.junit.Assert;
@@ -26,26 +27,28 @@
   @Test
   public void testSimpleReserveMemoryCalculationNoHash() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(true,
       false,
       keySizes,
       200,
+      100,
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -69,26 +72,28 @@ public void testSimpleReserveMemoryCalculationNoHash() {
   @Test
   public void testSimpleReserveMemoryCalculationHash() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(false,
       true,
       keySizes,
       350,
+      100, // Ignored for test
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -112,15 +117,17 @@ public void testSimpleReserveMemoryCalculationHash() {
   @Test
   public void testAdjustInitialPartitions() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(
@@ -128,11 +135,11 @@ public void testAdjustInitialPartitions() {
       false,
       keySizes,
       200,
+      100, // Ignored for test
       4,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -154,19 +161,148 @@ public void testAdjustInitialPartitions() {
     Assert.assertEquals(2, calc.getNumPartitions());
   }
 
+  @Test(expected = IllegalStateException.class)
+  public void testHasDataProbeEmpty() {
+    final int maxIncomingBatchSize = 100;
+    final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
+      new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
+        HashJoinHelperSizeCalculatorImpl.INSTANCE,
+        fragmentationFactor,
+        safetyFactor);
+
+    final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
+
+    calc.initialize(
+      true,
+      false,
+      keySizes,
+      240,
+      maxIncomingBatchSize,
+      4,
+      true,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
+      10,
+      5,
+      maxBatchNumRecords,
+      maxBatchNumRecords,
+      16000,
+      .75);
+  }
+
+  @Test
+  public void testNoProbeDataForStats() {
+    final int maxIncomingBatchSize = 100;
+    final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
+      new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
+        HashJoinHelperSizeCalculatorImpl.INSTANCE,
+        fragmentationFactor,
+        safetyFactor);
+
+    final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
+
+    calc.initialize(
+      true,
+      false,
+      keySizes,
+      240,
+      maxIncomingBatchSize,
+      4,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(),
+      10,
+      5,
+      maxBatchNumRecords,
+      maxBatchNumRecords,
+      16000,
+      .75);
+
+    final HashJoinMemoryCalculator.PartitionStatSet partitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(new PartitionStatImpl(), new PartitionStatImpl());
+    calc.setPartitionStatSet(partitionStatSet);
+
+    long expectedReservedMemory = 60 // Max incoming build batch size
+      + 2 * 30 // build side batch for each spilled partition
+      + maxIncomingBatchSize;
+    long actualReservedMemory = calc.getBuildReservedMemory();
+
+    Assert.assertEquals(expectedReservedMemory, actualReservedMemory);
+    Assert.assertEquals(2, calc.getNumPartitions());
+  }
+
+  @Test
+  public void testProbeEmpty() {
+    final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
+      new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
+        new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
+        HashJoinHelperSizeCalculatorImpl.INSTANCE,
+        fragmentationFactor,
+        safetyFactor);
+
+    final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
+
+    calc.initialize(
+      true,
+      false,
+      keySizes,
+      200,
+      100, // Ignored for test
+      4,
+      true,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(),
+      10,
+      5,
+      maxBatchNumRecords,
+      maxBatchNumRecords,
+      16000,
+      .75);
+
+    final HashJoinMemoryCalculator.PartitionStatSet partitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(new PartitionStatImpl(), new PartitionStatImpl(),
+        new PartitionStatImpl(), new PartitionStatImpl());
+    calc.setPartitionStatSet(partitionStatSet);
+
+    long expectedReservedMemory = 60 // Max incoming build batch size
+      + 4 * 30; // build side batch for each spilled partition
+    long actualReservedMemory = calc.getBuildReservedMemory();
+
+    Assert.assertEquals(expectedReservedMemory, actualReservedMemory);
+    Assert.assertEquals(4, calc.getNumPartitions());
+  }
+
   @Test
   public void testNoRoomInMemoryForBatch1() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
 
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(
@@ -174,11 +310,11 @@ public void testNoRoomInMemoryForBatch1() {
       false,
       keySizes,
       180,
+      100, // Ignored for test
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -207,15 +343,17 @@ public void testNoRoomInMemoryForBatch1() {
   @Test
   public void testCompleteLifeCycle() {
     final int maxBatchNumRecords = 20;
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
     final HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl calc =
       new HashJoinMemoryCalculatorImpl.BuildSidePartitioningImpl(
+        BatchSizePredictorImpl.Factory.INSTANCE,
         new HashTableSizeCalculatorConservativeImpl(RecordBatch.MAX_BATCH_SIZE, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR),
         HashJoinHelperSizeCalculatorImpl.INSTANCE,
-        2.0,
-        1.5);
+        fragmentationFactor,
+        safetyFactor);
 
-    final CaseInsensitiveMap<Long> buildValueSizes = CaseInsensitiveMap.newHashMap();
-    final CaseInsensitiveMap<Long> probeValueSizes = CaseInsensitiveMap.newHashMap();
     final CaseInsensitiveMap<Long> keySizes = CaseInsensitiveMap.newHashMap();
 
     calc.initialize(
@@ -223,11 +361,11 @@ public void testCompleteLifeCycle() {
       false,
       keySizes,
       210,
+      100, // Ignored for test
       2,
-      20,
-      10,
-      20,
-      10,
+      false,
+      new MockBatchSizePredictor(20, 20, fragmentationFactor, safetyFactor),
+      new MockBatchSizePredictor(10, 10, fragmentationFactor, safetyFactor),
       10,
       5,
       maxBatchNumRecords,
@@ -276,4 +414,61 @@ public void testCompleteLifeCycle() {
 
     Assert.assertNotNull(calc.next());
   }
+
+  public static class MockBatchSizePredictor implements BatchSizePredictor {
+    private final boolean hasData;
+    private final long batchSize;
+    private final int numRecords;
+    private final double fragmentationFactor;
+    private final double safetyFactor;
+
+    public MockBatchSizePredictor() {
+      hasData = false;
+      batchSize = 0;
+      numRecords = 0;
+      fragmentationFactor = 0;
+      safetyFactor = 0;
+    }
+
+    public MockBatchSizePredictor(final long batchSize,
+                                  final int numRecords,
+                                  final double fragmentationFactor,
+                                  final double safetyFactor) {
+      hasData = true;
+      this.batchSize = batchSize;
+      this.numRecords = numRecords;
+      this.fragmentationFactor = fragmentationFactor;
+      this.safetyFactor = safetyFactor;
+    }
+
+    @Override
+    public long getBatchSize() {
+      return batchSize;
+    }
+
+    @Override
+    public int getNumRecords() {
+      return numRecords;
+    }
+
+    @Override
+    public boolean hadDataLastTime() {
+      return hasData;
+    }
+
+    @Override
+    public void updateStats() {
+    }
+
+    @Override
+    public long predictBatchSize(int desiredNumRecords, boolean reserveHash) {
+      Preconditions.checkState(hasData);
+      return BatchSizePredictorImpl.computeMaxBatchSize(batchSize,
+        numRecords,
+        desiredNumRecords,
+        fragmentationFactor,
+        safetyFactor,
+        reserveHash);
+    }
+  }
 }
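
The MockBatchSizePredictor above pins down the batch-size arithmetic so the expected reserved-memory figures in testProbeEmpty and testNoProbeDataForStats can be derived by hand. The sketch below is a worked example under the proportional-scaling assumption implied by the predictor tests (size * desiredRecords/records * fragmentationFactor * safetyFactor); the class and method names are illustrative only.

// Worked example of the reserved-memory figures asserted in the tests above.
public class ReservedMemorySketch {
  static long predict(long batchSize, int numRecords, int desiredRecords,
                      double fragmentationFactor, double safetyFactor) {
    return (long) (batchSize * ((double) desiredRecords / numRecords) * fragmentationFactor * safetyFactor);
  }

  public static void main(String[] args) {
    long maxBuildBatchSize = predict(20, 20, 20, 2.0, 1.5);       // 60
    long partitionBuildBatchSize = predict(20, 20, 10, 2.0, 1.5); // 30
    // testProbeEmpty: the probe is known to be empty, so nothing is reserved for incoming probe data.
    System.out.println(maxBuildBatchSize + 4 * partitionBuildBatchSize);       // 180 <= 200, keeps 4 partitions
    // testNoProbeDataForStats: the probe side is unknown, so the max incoming batch size (100) is reserved.
    System.out.println(maxBuildBatchSize + 2 * partitionBuildBatchSize + 100); // 220 <= 240 after halving to 2 partitions
  }
}
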
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculator.java
similarity index 61%
rename from exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
rename to exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculator.java
index 4fe1fa4a8c4..b13829baa4d 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculatorImpl.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashJoinMemoryCalculator.java
@@ -17,54 +17,9 @@
  */
 package org.apache.drill.exec.physical.impl.join;
 
-import org.apache.drill.exec.vector.IntVector;
-import org.junit.Assert;
 import org.junit.Test;
 
-public class TestHashJoinMemoryCalculatorImpl {
-  @Test
-  public void testComputeMaxBatchSizeNoHash() {
-    final long expected = 1200;
-    final long actual = HashJoinMemoryCalculatorImpl.computeMaxBatchSize(
-      100,
-      25,
-      100,
-      2.0,
-      1.5,
-      false);
-    final long actualNoHash = HashJoinMemoryCalculatorImpl.computeMaxBatchSizeNoHash(
-      100,
-      25,
-      100,
-      2.0,
-      1.5);
-
-    Assert.assertEquals(expected, actual);
-    Assert.assertEquals(expected, actualNoHash);
-  }
-
-  @Test
-  public void testComputeMaxBatchSizeHash()
-  {
-    long expected = HashJoinMemoryCalculatorImpl.computeMaxBatchSizeNoHash(
-      100,
-      25,
-      100,
-      2.0,
-      4.0) +
-      100 * IntVector.VALUE_WIDTH * 2;
-
-    final long actual = HashJoinMemoryCalculatorImpl.computeMaxBatchSize(
-      100,
-      25,
-      100,
-      2.0,
-      4.0,
-      true);
-
-    Assert.assertEquals(expected, actual);
-  }
-
+public class TestHashJoinMemoryCalculator {
   @Test // Make sure no exception is thrown
   public void testMakeDebugString()
   {
@@ -78,5 +33,7 @@ public void testMakeDebugString()
     partitionStat1.add(new HashJoinMemoryCalculator.BatchStat(10, 7));
     partitionStat2.add(new HashJoinMemoryCalculator.BatchStat(11, 20));
     partitionStat3.spill();
+
+    partitionStatSet.makeDebugString();
   }
 }
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
index 3f01bca511d..813fc353c41 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorConservativeImpl.java
@@ -42,14 +42,14 @@ public void testCalculateHashTableSize() {
     long expected = RecordBatchSizer.multiplyByFactor(
       UInt4Vector.VALUE_WIDTH * 128, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
     // First bucket key value vector sizes
-    expected += HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords, 3L);
-    expected += HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords, 8L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 3L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 8L);
 
     // Second bucket key value vector sizes
     expected += RecordBatchSizer.multiplyByFactor(
-      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20, 3L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
+      BatchSizePredictorImpl.computeValueVectorSize(20, 3L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
     expected += RecordBatchSizer.multiplyByFactor(
-      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20, 8L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
+      BatchSizePredictorImpl.computeValueVectorSize(20, 8L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
 
     // Overhead vectors for links and hash values for each batchHolder
     expected += 2 * UInt4Vector.VALUE_WIDTH // links and hash values */
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
index 1bd51fc0f0c..3390ceaadcf 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestHashTableSizeCalculatorLeanImpl.java
@@ -42,13 +42,13 @@ public void testCalculateHashTableSize() {
     long expected = RecordBatchSizer.multiplyByFactor(
       UInt4Vector.VALUE_WIDTH * 128, HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
     // First bucket key value vector sizes
-    expected += HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords, 3L);
-    expected += HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(maxNumRecords, 8L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 3L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(maxNumRecords, 8L);
 
     // Second bucket key value vector sizes
-    expected += HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20, 3L);
+    expected += BatchSizePredictorImpl.computeValueVectorSize(20, 3L);
     expected += RecordBatchSizer.multiplyByFactor(
-      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(20, 8L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
+      BatchSizePredictorImpl.computeValueVectorSize(20, 8L), HashTableSizeCalculatorConservativeImpl.HASHTABLE_DOUBLING_FACTOR);
 
     // Overhead vectors for links and hash values for each batchHolder
     expected += 2 * UInt4Vector.VALUE_WIDTH // links and hash values */
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
index 5cf7eca2b27..aa7a4354ba6 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/join/TestPostBuildCalculationsImpl.java
@@ -17,44 +17,229 @@
  */
 package org.apache.drill.exec.physical.impl.join;
 
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 
 public class TestPostBuildCalculationsImpl {
   @Test
-  public void testRoundUpPowerOf2() {
-    long expected = 32;
-    long actual = HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.roundUpToPowerOf2(expected);
+  public void testProbeTooBig() {
+    final int minProbeRecordsPerBatch = 10;
 
-    Assert.assertEquals(expected, actual);
+    final int computedProbeRecordsPerBatch =
+      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        100,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        70,
+        40,
+        200);
+
+    Assert.assertEquals(minProbeRecordsPerBatch, computedProbeRecordsPerBatch);
+  }
+
+  @Test
+  public void testComputedShouldBeMin() {
+    final int minProbeRecordsPerBatch = 10;
+
+    final int computedProbeRecordsPerBatch =
+      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        100,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        50,
+        40,
+        200);
+
+    Assert.assertEquals(minProbeRecordsPerBatch, computedProbeRecordsPerBatch);
   }
 
   @Test
-  public void testRounUpNonPowerOf2ToPowerOf2() {
-    long expected = 32;
-    long actual = HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.roundUpToPowerOf2(31);
+  public void testComputedProbeRecordsPerBatch() {
+    final int minProbeRecordsPerBatch = 10;
+
+    final int computedProbeRecordsPerBatch =
+      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        200,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        50,
+        50,
+        200);
+
+    Assert.assertEquals(25, computedProbeRecordsPerBatch);
+  }
+
+  @Test
+  public void testComputedProbeRecordsPerBatchRoundUp() {
+    final int minProbeRecordsPerBatch = 10;
+
+    final int computedProbeRecordsPerBatch =
+      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeProbeRecordsPerBatch(
+        200,
+        2,
+        100,
+        minProbeRecordsPerBatch,
+        50,
+        51,
+        199);
+
+    Assert.assertEquals(25, computedProbeRecordsPerBatch);
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testHasProbeDataButProbeEmpty() {
+    final Map<String, Long> keySizes = org.apache.drill.common.map.CaseInsensitiveMap.newHashMap();
+
+    final PartitionStatImpl partition1 = new PartitionStatImpl();
+    final PartitionStatImpl partition2 = new PartitionStatImpl();
+    final HashJoinMemoryCalculator.PartitionStatSet buildPartitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(partition1, partition2);
+
+    final int recordsPerPartitionBatchBuild = 10;
+
+    addBatches(partition1, recordsPerPartitionBatchBuild,
+      10, 4);
+    addBatches(partition2, recordsPerPartitionBatchBuild,
+      10, 4);
 
-    Assert.assertEquals(expected, actual);
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+      new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
+        290, // memoryAvailable
+        20, // maxOutputBatchSize
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
+        buildPartitionStatSet, // buildPartitionStatSet
+        keySizes, // keySizes
+        new MockHashTableSizeCalculator(10), // hashTableSizeCalculator
+        new MockHashJoinHelperSizeCalculator(10), // hashJoinHelperSizeCalculator
+        fragmentationFactor, // fragmentationFactor
+        safetyFactor, // safetyFactor
+        .75, // loadFactor
+        false); // reserveHash
+
+    calc.initialize(true);
   }
 
   @Test
-  public void testComputeValueVectorSizePowerOf2() {
-    long expected = 4;
-    long actual =
-      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(2, 2);
+  public void testProbeEmpty() {
+    final Map<String, Long> keySizes = org.apache.drill.common.map.CaseInsensitiveMap.newHashMap();
+
+    final PartitionStatImpl partition1 = new PartitionStatImpl();
+    final PartitionStatImpl partition2 = new PartitionStatImpl();
+    final HashJoinMemoryCalculator.PartitionStatSet buildPartitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(partition1, partition2);
+
+    final int recordsPerPartitionBatchBuild = 10;
+
+    addBatches(partition1, recordsPerPartitionBatchBuild,
+      10, 4);
+    addBatches(partition2, recordsPerPartitionBatchBuild,
+      10, 4);
 
-    Assert.assertEquals(expected, actual);
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 40;
+    final long maxProbeBatchSize = 10000;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+      new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(),
+        50,
+        1000,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
+        buildPartitionStatSet,
+        keySizes,
+        new MockHashTableSizeCalculator(10),
+        new MockHashJoinHelperSizeCalculator(10),
+        fragmentationFactor,
+        safetyFactor,
+        .75,
+        true);
+
+    calc.initialize(true);
+
+    Assert.assertFalse(calc.shouldSpill());
+    Assert.assertFalse(calc.shouldSpill());
   }
 
   @Test
-  public void testComputeValueVectorSizeNonPowerOf2() {
-    long expected = 16;
-    long actual =
-      HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl.computeValueVectorSize(3, 3);
+  public void testHasNoProbeDataButProbeNonEmpty() {
+    final Map<String, Long> keySizes = org.apache.drill.common.map.CaseInsensitiveMap.newHashMap();
 
-    Assert.assertEquals(expected, actual);
+    final PartitionStatImpl partition1 = new PartitionStatImpl();
+    final PartitionStatImpl partition2 = new PartitionStatImpl();
+    final HashJoinMemoryCalculator.PartitionStatSet buildPartitionStatSet =
+      new HashJoinMemoryCalculator.PartitionStatSet(partition1, partition2);
+
+    final int recordsPerPartitionBatchBuild = 10;
+
+    addBatches(partition1, recordsPerPartitionBatchBuild,
+      10, 4);
+    addBatches(partition2, recordsPerPartitionBatchBuild,
+      10, 4);
+
+    final double fragmentationFactor = 2.0;
+    final double safetyFactor = 1.5;
+
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+      new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          false),
+        290,
+        20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
+        buildPartitionStatSet,
+        keySizes,
+        new MockHashTableSizeCalculator(10),
+        new MockHashJoinHelperSizeCalculator(10),
+        fragmentationFactor,
+        safetyFactor,
+        .75,
+        false);
+
+    calc.initialize(false);
+
+    long expected = 60 // maxProbeBatchSize
+      + 160 // in memory partitions
+      + 20 // max output batch size
+      + 2 * 10 // Hash Table
+      + 2 * 10; // Hash join helper
+    Assert.assertFalse(calc.shouldSpill());
+    Assert.assertEquals(expected, calc.getConsumedMemory());
+    Assert.assertNull(calc.next());
   }
 
   @Test
@@ -76,12 +261,21 @@ public void testProbingAndPartitioningBuildAllInMemoryNoSpill() {
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
 
-    HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
+    final HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         290,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(10),
@@ -91,7 +285,7 @@ public void testProbingAndPartitioningBuildAllInMemoryNoSpill() {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 160 // in memory partitions
@@ -122,12 +316,21 @@ public void testProbingAndPartitioningBuildAllInMemorySpill() {
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
 
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         270,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(10),
@@ -137,7 +340,7 @@ public void testProbingAndPartitioningBuildAllInMemorySpill() {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 160 // in memory partitions
@@ -174,12 +377,21 @@ public void testProbingAndPartitioningBuildAllInMemoryNoSpillWithHash() {
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
 
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
+
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         180,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(10),
@@ -189,7 +401,7 @@ public void testProbingAndPartitioningBuildAllInMemoryNoSpillWithHash() {
         .75,
         true);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 2 * 5 * 3 // partition batches
@@ -215,15 +427,24 @@ public void testProbingAndPartitioningBuildAllInMemoryWithSpill() {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         200,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -233,7 +454,7 @@ public void testProbingAndPartitioningBuildAllInMemoryWithSpill() {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 80 // in memory partition
@@ -269,15 +490,24 @@ public void testProbingAndPartitioningBuildSomeInMemory() {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         230,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -287,7 +517,7 @@ public void testProbingAndPartitioningBuildSomeInMemory() {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
 
     long expected = 60 // maxProbeBatchSize
       + 80 // in memory partition
@@ -317,15 +547,24 @@ public void testProbingAndPartitioningBuildNoneInMemory() {
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
-        100,
-        15,
-        60,
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
+        110,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -335,7 +574,7 @@ public void testProbingAndPartitioningBuildNoneInMemory() {
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
     Assert.assertFalse(calc.shouldSpill());
     Assert.assertEquals(110, calc.getConsumedMemory());
     Assert.assertNotNull(calc.next());
@@ -362,15 +601,24 @@ public void testMakeDebugString()
 
     final double fragmentationFactor = 2.0;
     final double safetyFactor = 1.5;
+
     final long hashTableSize = 10;
     final long hashJoinHelperSize = 10;
+    final int maxBatchNumRecordsProbe = 3;
+    final int recordsPerPartitionBatchProbe = 5;
+    final long partitionProbeBatchSize = 15;
+    final long maxProbeBatchSize = 60;
 
     HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl calc =
       new HashJoinMemoryCalculatorImpl.PostBuildCalculationsImpl(
+        new ConditionalMockBatchSizePredictor(
+          Lists.newArrayList(maxBatchNumRecordsProbe, recordsPerPartitionBatchProbe),
+          Lists.newArrayList(maxProbeBatchSize, partitionProbeBatchSize),
+          true),
         230,
-        15,
-        60,
         20,
+        maxBatchNumRecordsProbe,
+        recordsPerPartitionBatchProbe,
         buildPartitionStatSet,
         keySizes,
         new MockHashTableSizeCalculator(hashTableSize),
@@ -380,7 +628,7 @@ public void testMakeDebugString()
         .75,
         false);
 
-    calc.initialize();
+    calc.initialize(false);
   }
 
   private void addBatches(PartitionStatImpl partitionStat,
@@ -431,4 +679,66 @@ public long calculateSize(HashJoinMemoryCalculator.PartitionStat partitionStat,
       return size;
     }
   }
+
+  public static class ConditionalMockBatchSizePredictor implements BatchSizePredictor {
+    private final List<Integer> recordsPerBatch;
+    private final List<Long> batchSize;
+
+    private boolean hasData;
+    private boolean updateable;
+
+    public ConditionalMockBatchSizePredictor() {
+      recordsPerBatch = new ArrayList<>();
+      batchSize = new ArrayList<>();
+      hasData = false;
+      updateable = true;
+    }
+
+    public ConditionalMockBatchSizePredictor(final List<Integer> recordsPerBatch,
+                                             final List<Long> batchSize,
+                                             final boolean hasData) {
+      this.recordsPerBatch = Preconditions.checkNotNull(recordsPerBatch);
+      this.batchSize = Preconditions.checkNotNull(batchSize);
+
+      Preconditions.checkArgument(recordsPerBatch.size() == batchSize.size());
+
+      this.hasData = hasData;
+      updateable = true;
+    }
+
+    @Override
+    public long getBatchSize() {
+      return 0;
+    }
+
+    @Override
+    public int getNumRecords() {
+      return 0;
+    }
+
+    @Override
+    public boolean hadDataLastTime() {
+      return hasData;
+    }
+
+    @Override
+    public void updateStats() {
+      Preconditions.checkState(updateable);
+      updateable = false;
+      hasData = true;
+    }
+
+    @Override
+    public long predictBatchSize(int desiredNumRecords, boolean reserveHash) {
+      Preconditions.checkState(hasData);
+
+      for (int index = 0; index < recordsPerBatch.size(); index++) {
+        if (desiredNumRecords == recordsPerBatch.get(index)) {
+          return batchSize.get(index);
+        }
+      }
+
+      throw new IllegalArgumentException();
+    }
+  }
 }
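
The getConsumedMemory() assertions in the tests above all follow the same accounting: the expected value is the sum of the components named in the inline comments, and shouldSpill() is expected to stay false while that sum fits under the memory limit passed to PostBuildCalculationsImpl (290 in the first test shown above). A minimal standalone sketch of that arithmetic for the no-spill case follows; the class and variable names below are illustrative only and are not part of the patch or the Drill code base.

    // Mirrors the commented arithmetic of the no-spill test case above.
    // All component values are taken from the inline comments in the diff;
    // this class itself is an editorial illustration, not Drill code.
    public final class PostBuildMemoryArithmetic {
      public static void main(String[] args) {
        final long maxProbeBatchSize = 60;    // predicted probe-side batch size
        final long inMemoryPartitions = 160;  // build-side partitions kept in memory
        final long maxOutputBatchSize = 20;   // reservation for the outgoing batch
        final long hashTables = 2 * 10;       // one mock hash table per in-memory partition
        final long hashJoinHelpers = 2 * 10;  // one mock hash join helper per in-memory partition

        final long consumed = maxProbeBatchSize + inMemoryPartitions
            + maxOutputBatchSize + hashTables + hashJoinHelpers; // = 260

        final long memoryLimit = 290; // limit given to the calculator in the no-spill case

        // 260 fits under 290, so the test expects shouldSpill() == false
        // and getConsumedMemory() == 260.
        System.out.println("consumed=" + consumed
            + " limit=" + memoryLimit
            + " shouldSpill=" + (consumed > memoryLimit));
      }
    }

The ConditionalMockBatchSizePredictor added at the end of the file supplies the probe-side half of that accounting: predictBatchSize(desiredNumRecords, reserveHash) returns the batch size paired with the requested record count, so the tests can make the calculator see a maximum probe batch size of 60 (for maxBatchNumRecordsProbe) and a per-partition probe batch size of 15 (for recordsPerPartitionBatchProbe) without building real record batches.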