Posted to dev@drill.apache.org by GitBox <gi...@apache.org> on 2018/07/02 07:45:46 UTC

[GitHub] asfgit closed pull request #1324: DRILL-6310: limit batch size for hash aggregate

asfgit closed pull request #1324: DRILL-6310: limit batch size for hash aggregate
URL: https://github.com/apache/drill/pull/1324

This is a PR merged from a forked repository. Because GitHub hides the original diff of a foreign (forked) pull request once it is merged, the diff is reproduced below for the sake of provenance:

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggBatch.java
index 57e9bd7d0c..d37631be45 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggBatch.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggBatch.java
@@ -19,15 +19,19 @@
 
 import java.io.IOException;
 import java.util.List;
+import java.util.Map;
 
 import com.google.common.collect.Lists;
 import org.apache.drill.common.exceptions.UserException;
 import org.apache.drill.common.expression.ErrorCollector;
 import org.apache.drill.common.expression.ErrorCollectorImpl;
+import org.apache.drill.common.expression.FunctionCall;
 import org.apache.drill.common.expression.FunctionHolderExpression;
 import org.apache.drill.common.expression.IfExpression;
 import org.apache.drill.common.expression.LogicalExpression;
+import org.apache.drill.common.expression.SchemaPath;
 import org.apache.drill.common.logical.data.NamedExpression;
+import org.apache.drill.common.map.CaseInsensitiveMap;
 import org.apache.drill.exec.ExecConstants;
 import org.apache.drill.exec.compile.sig.GeneratorMapping;
 import org.apache.drill.exec.compile.sig.MappingSet;
@@ -49,11 +53,14 @@
 import org.apache.drill.exec.record.BatchSchema.SelectionVectorMode;
 import org.apache.drill.exec.record.MaterializedField;
 import org.apache.drill.exec.record.RecordBatch;
+import org.apache.drill.exec.record.RecordBatchMemoryManager;
+import org.apache.drill.exec.record.RecordBatchSizer;
 import org.apache.drill.exec.record.TypedFieldId;
 import org.apache.drill.exec.record.VectorWrapper;
 import org.apache.drill.exec.record.selection.SelectionVector2;
 import org.apache.drill.exec.record.selection.SelectionVector4;
 import org.apache.drill.exec.vector.AllocationHelper;
+import org.apache.drill.exec.vector.FixedWidthVector;
 import org.apache.drill.exec.vector.ValueVector;
 
 import com.sun.codemodel.JExpr;
@@ -71,6 +78,12 @@
   private BatchSchema incomingSchema;
   private boolean wasKilled;
 
+  private int numGroupByExprs, numAggrExprs;
+
+  // This map saves the mapping between outgoing column and incoming column.
+  private Map<String, String> columnMapping;
+  private final HashAggMemoryManager hashAggMemoryManager;
+
   private final GeneratorMapping UPDATE_AGGR_INSIDE =
       GeneratorMapping.create("setupInterior" /* setup method */, "updateAggrValuesInternal" /* eval method */,
           "resetValues" /* reset */, "cleanup" /* cleanup */);
@@ -84,6 +97,67 @@
           "htRowIdx" /* workspace index */, "incoming" /* read container */, "outgoing" /* write container */,
           "aggrValuesContainer" /* workspace container */, UPDATE_AGGR_INSIDE, UPDATE_AGGR_OUTSIDE, UPDATE_AGGR_INSIDE);
 
+  public int getOutputRowCount() {
+    return hashAggMemoryManager.getOutputRowCount();
+  }
+
+  public RecordBatchMemoryManager getRecordBatchMemoryManager() {
+    return hashAggMemoryManager;
+  }
+
+  private class HashAggMemoryManager extends RecordBatchMemoryManager {
+    private int valuesRowWidth = 0;
+
+    HashAggMemoryManager(int outputBatchSize) {
+      super(outputBatchSize);
+    }
+
+    @Override
+    public void update() {
+      // Get sizing information for the batch.
+      setRecordBatchSizer(new RecordBatchSizer(incoming));
+
+      int fieldId = 0;
+      int newOutgoingRowWidth = 0;
+      for (VectorWrapper<?> w : container) {
+        if (w.getValueVector() instanceof FixedWidthVector) {
+          newOutgoingRowWidth += ((FixedWidthVector) w.getValueVector()).getValueWidth();
+          if (fieldId >= numGroupByExprs) {
+            valuesRowWidth += ((FixedWidthVector) w.getValueVector()).getValueWidth();
+          }
+        } else {
+          int columnWidth;
+          if (columnMapping.get(w.getValueVector().getField().getName()) == null) {
+             columnWidth = TypeHelper.getSize(w.getField().getType());
+          } else {
+            RecordBatchSizer.ColumnSize columnSize = getRecordBatchSizer().getColumn(columnMapping.get(w.getValueVector().getField().getName()));
+            if (columnSize == null) {
+              columnWidth = TypeHelper.getSize(w.getField().getType());
+            } else {
+              columnWidth = columnSize.getAllocSizePerEntry();
+            }
+          }
+          newOutgoingRowWidth += columnWidth;
+          if (fieldId >= numGroupByExprs) {
+            valuesRowWidth += columnWidth;
+          }
+        }
+        fieldId++;
+      }
+
+      if (updateIfNeeded(newOutgoingRowWidth)) {
+        // There is an update to outgoing row width.
+        // Uncomment this if we want to adjust the batch row count of in-flight batches.
+        // To keep things simple, we are not doing this adjustment for now.
+        // aggregator.adjustOutputCount(getOutputBatchSize(), getOutgoingRowWidth(), newOutgoingRowWidth);
+      }
+
+      updateIncomingStats();
+      if (logger.isDebugEnabled()) {
+        logger.debug("BATCH_STATS, incoming: {}", getRecordBatchSizer());
+      }
+    }
+  }
 
   public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentContext context) {
     super(popConfig, context);
@@ -103,6 +177,13 @@ public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentConte
 
     boolean allowed = oContext.getAllocator().setLenient();
     logger.debug("Config: Is allocator lenient? {}", allowed);
+
+    // get the output batch size from config.
+    int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
+    hashAggMemoryManager = new HashAggMemoryManager(configuredBatchSize);
+    logger.debug("BATCH_STATS, configured output batch size: {}", configuredBatchSize);
+
+    columnMapping = CaseInsensitiveMap.newHashMap();
   }
 
   @Override
@@ -136,6 +217,9 @@ public void buildSchema() throws SchemaChangeException {
     for (VectorWrapper<?> w : container) {
       AllocationHelper.allocatePrecomputedChildCount(w.getValueVector(), 0, 0, 0);
     }
+    if (incoming.getRecordCount() > 0) {
+      hashAggMemoryManager.update();
+    }
   }
 
   @Override
@@ -239,8 +323,8 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
     // top.saveCodeForDebugging(true);
     container.clear();
 
-    int numGroupByExprs = (popConfig.getGroupByExprs() != null) ? popConfig.getGroupByExprs().size() : 0;
-    int numAggrExprs = (popConfig.getAggrExprs() != null) ? popConfig.getAggrExprs().size() : 0;
+    numGroupByExprs = (popConfig.getGroupByExprs() != null) ? popConfig.getGroupByExprs().size() : 0;
+    numAggrExprs = (popConfig.getAggrExprs() != null) ? popConfig.getAggrExprs().size() : 0;
     aggrExprs = new LogicalExpression[numAggrExprs];
     groupByOutFieldIds = new TypedFieldId[numGroupByExprs];
     aggrOutFieldIds = new TypedFieldId[numAggrExprs];
@@ -263,13 +347,13 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
 
       // add this group-by vector to the output container
       groupByOutFieldIds[i] = container.add(vv);
+      columnMapping.put(outputField.getName(), ne.getExpr().toString().replace('`',' ').trim());
     }
 
     int extraNonNullColumns = 0; // each of SUM, MAX and MIN gets an extra bigint column
     for (i = 0; i < numAggrExprs; i++) {
       NamedExpression ne = popConfig.getAggrExprs().get(i);
-      final LogicalExpression expr =
-          ExpressionTreeMaterializer.materialize(ne.getExpr(), incoming, collector, context.getFunctionRegistry());
+      final LogicalExpression expr = ExpressionTreeMaterializer.materialize(ne.getExpr(), incoming, collector, context.getFunctionRegistry());
 
       if (expr instanceof IfExpression) {
         throw UserException.unsupportedError(new UnsupportedOperationException("Union type not supported in aggregate functions")).build(logger);
@@ -283,16 +367,28 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
         continue;
       }
 
-      if ( expr instanceof FunctionHolderExpression ) {
-         String funcName = ((FunctionHolderExpression) expr).getName();
-         if ( funcName.equals("sum") || funcName.equals("max") || funcName.equals("min") ) {extraNonNullColumns++;}
-      }
       final MaterializedField outputField = MaterializedField.create(ne.getRef().getAsNamePart().getName(), expr.getMajorType());
-      @SuppressWarnings("resource")
-      ValueVector vv = TypeHelper.getNewVector(outputField, oContext.getAllocator());
+      @SuppressWarnings("resource") ValueVector vv = TypeHelper.getNewVector(outputField, oContext.getAllocator());
       aggrOutFieldIds[i] = container.add(vv);
 
       aggrExprs[i] = new ValueVectorWriteExpression(aggrOutFieldIds[i], expr, true);
+
+      if (expr instanceof FunctionHolderExpression) {
+        String funcName = ((FunctionHolderExpression) expr).getName();
+        if (funcName.equals("sum") || funcName.equals("max") || funcName.equals("min")) {
+          extraNonNullColumns++;
+        }
+        if (((FunctionCall) ne.getExpr()).args.get(0) instanceof SchemaPath) {
+          columnMapping.put(outputField.getName(), ((SchemaPath) ((FunctionCall) ne.getExpr()).args.get(0)).getAsNamePart().getName());
+        }  else if (((FunctionCall) ne.getExpr()).args.get(0) instanceof FunctionCall) {
+          FunctionCall functionCall = (FunctionCall) ((FunctionCall) ne.getExpr()).args.get(0);
+          if (functionCall.args.get(0) instanceof SchemaPath) {
+            columnMapping.put(outputField.getName(), ((SchemaPath) functionCall.args.get(0)).getAsNamePart().getName());
+          }
+        }
+      } else {
+        columnMapping.put(outputField.getName(), ne.getRef().getAsNamePart().getName());
+      }
     }
 
     setupUpdateAggrValues(cgInner);
@@ -345,11 +441,32 @@ private void setupGetIndex(ClassGenerator<HashAggregator> cg) {
     }
   }
 
+  private void updateStats() {
+    stats.setLongStat(HashAggTemplate.Metric.INPUT_BATCH_COUNT, hashAggMemoryManager.getNumIncomingBatches());
+    stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_BATCH_BYTES, hashAggMemoryManager.getAvgInputBatchSize());
+    stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_ROW_BYTES, hashAggMemoryManager.getAvgInputRowWidth());
+    stats.setLongStat(HashAggTemplate.Metric.INPUT_RECORD_COUNT, hashAggMemoryManager.getTotalInputRecords());
+    stats.setLongStat(HashAggTemplate.Metric.OUTPUT_BATCH_COUNT, hashAggMemoryManager.getNumOutgoingBatches());
+    stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_BATCH_BYTES, hashAggMemoryManager.getAvgOutputBatchSize());
+    stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_ROW_BYTES, hashAggMemoryManager.getAvgOutputRowWidth());
+    stats.setLongStat(HashAggTemplate.Metric.OUTPUT_RECORD_COUNT, hashAggMemoryManager.getTotalOutputRecords());
+
+    if (logger.isDebugEnabled()) {
+      logger.debug("BATCH_STATS, incoming aggregate: count : {}, avg bytes : {},  avg row bytes : {}, record count : {}",
+        hashAggMemoryManager.getNumIncomingBatches(), hashAggMemoryManager.getAvgInputBatchSize(),
+        hashAggMemoryManager.getAvgInputRowWidth(), hashAggMemoryManager.getTotalInputRecords());
+
+      logger.debug("BATCH_STATS, outgoing aggregate: count : {}, avg bytes : {},  avg row bytes : {}, record count : {}",
+        hashAggMemoryManager.getNumOutgoingBatches(), hashAggMemoryManager.getAvgOutputBatchSize(),
+        hashAggMemoryManager.getAvgOutputRowWidth(), hashAggMemoryManager.getTotalOutputRecords());
+    }
+  }
   @Override
   public void close() {
     if (aggregator != null) {
       aggregator.cleanup();
     }
+    updateStats();
     super.close();
   }
 
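A minimal sketch (illustrative only, not part of the patch) of what the new columnMapping in HashAggBatch captures. For a query shaped like "SELECT a, SUM(b) AS b_sum ... GROUP BY a", each outgoing column name is mapped back to the incoming column it is derived from, so HashAggMemoryManager.update() can size variable-width output columns from the RecordBatchSizer measurements of the matching input column. The column names below are hypothetical; a one-level nested call such as sum(cast(b)) resolves to "b" in the same way.

    // Output column name -> incoming column name (case-insensitive, as in the patch).
    Map<String, String> columnMapping = CaseInsensitiveMap.newHashMap();
    columnMapping.put("a", "a");        // group-by key maps to itself
    columnMapping.put("b_sum", "b");    // SUM(b) maps to its argument column

    // In HashAggMemoryManager.update(), a variable-width outgoing column is then sized roughly as:
    //   RecordBatchSizer.ColumnSize col = getRecordBatchSizer().getColumn(columnMapping.get(name));
    //   int columnWidth = (col != null) ? col.getAllocSizePerEntry() : TypeHelper.getSize(type);
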
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggTemplate.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggTemplate.java
index 3b50471db1..2f3bc23da3 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggTemplate.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggTemplate.java
@@ -79,6 +79,7 @@
 
 import org.apache.drill.exec.vector.VariableWidthVector;
 
+import static org.apache.drill.exec.physical.impl.common.HashTable.BATCH_MASK;
 import static org.apache.drill.exec.record.RecordBatch.MAX_BATCH_SIZE;
 
 public abstract class HashAggTemplate implements HashAggregator {
@@ -159,8 +160,6 @@
   private int operatorId; // for the spill file name
 
   private IndexPointer htIdxHolder; // holder for the Hashtable's internal index returned by put()
-  private IndexPointer outStartIdxHolder;
-  private IndexPointer outNumRecordsHolder;
   private int numGroupByOutFields = 0; // Note: this should be <= number of group-by fields
   private TypedFieldId[] groupByOutFieldIds;
 
@@ -185,7 +184,15 @@
                       // then later re-read. So, disk I/O is twice this amount.
                       // For first phase aggr -- this is an estimate of the amount of data
                       // returned early (analogous to a spill in the 2nd phase).
-    SPILL_CYCLE       // 0 - no spill, 1 - spill, 2 - SECONDARY, 3 - TERTIARY
+    SPILL_CYCLE,       // 0 - no spill, 1 - spill, 2 - SECONDARY, 3 - TERTIARY
+    INPUT_BATCH_COUNT,
+    AVG_INPUT_BATCH_BYTES,
+    AVG_INPUT_ROW_BYTES,
+    INPUT_RECORD_COUNT,
+    OUTPUT_BATCH_COUNT,
+    AVG_OUTPUT_BATCH_BYTES,
+    AVG_OUTPUT_ROW_BYTES,
+    OUTPUT_RECORD_COUNT;
     ;
 
     @Override
@@ -195,16 +202,29 @@ public int metricId() {
   }
 
   public class BatchHolder {
-
     private VectorContainer aggrValuesContainer; // container for aggr values (workspace variables)
     private int maxOccupiedIdx = -1;
-    private int batchOutputCount = 0;
+    private int targetBatchRowCount = 0;
+
+    public int getTargetBatchRowCount() {
+      return targetBatchRowCount;
+    }
+
+    public void setTargetBatchRowCount(int batchRowCount) {
+      this.targetBatchRowCount = batchRowCount;
+    }
+
+    public int getCurrentRowCount() {
+      return (maxOccupiedIdx + 1);
+    }
 
     @SuppressWarnings("resource")
-    public BatchHolder() {
+    public BatchHolder(int batchRowCount) {
 
       aggrValuesContainer = new VectorContainer();
       boolean success = false;
+      this.targetBatchRowCount = batchRowCount;
+
       try {
         ValueVector vector;
 
@@ -220,12 +240,12 @@ public BatchHolder() {
           // BatchHolder in HashTable, causing the HashTable to be space inefficient. So it is better to allocate space
           // to fit as close to as BATCH_SIZE records.
           if (vector instanceof FixedWidthVector) {
-            ((FixedWidthVector) vector).allocateNew(HashTable.BATCH_SIZE);
+            ((FixedWidthVector) vector).allocateNew(batchRowCount);
           } else if (vector instanceof VariableWidthVector) {
             // This case is never used .... a varchar falls under ObjectVector which is allocated on the heap !
-            ((VariableWidthVector) vector).allocateNew(maxColumnWidth, HashTable.BATCH_SIZE);
+            ((VariableWidthVector) vector).allocateNew(maxColumnWidth, batchRowCount);
           } else if (vector instanceof ObjectVector) {
-            ((ObjectVector) vector).allocateNew(HashTable.BATCH_SIZE);
+            ((ObjectVector) vector).allocateNew(batchRowCount);
           } else {
             vector.allocateNew();
           }
@@ -252,17 +272,12 @@ private void setup() {
       catch (SchemaChangeException sc) { throw new UnsupportedOperationException(sc);}
     }
 
-    private void outputValues(IndexPointer outStartIdxHolder, IndexPointer outNumRecordsHolder) {
-      outStartIdxHolder.value = batchOutputCount;
-      outNumRecordsHolder.value = 0;
-      for (int i = batchOutputCount; i <= maxOccupiedIdx; i++) {
-        try { outputRecordValues(i, batchOutputCount); }
-        catch (SchemaChangeException sc) { throw new UnsupportedOperationException(sc);}
-        if (EXTRA_DEBUG_2) {
-          logger.debug("Outputting values to output index: {}", batchOutputCount);
+    private void outputValues() {
+      for (int i = 0; i <= maxOccupiedIdx; i++) {
+        try {
+          outputRecordValues(i, i);
         }
-        batchOutputCount++;
-        outNumRecordsHolder.value++;
+        catch (SchemaChangeException sc) { throw new UnsupportedOperationException(sc);}
       }
     }
 
@@ -275,7 +290,7 @@ private int getNumGroups() {
     }
 
     private int getNumPendingOutput() {
-      return getNumGroups() - batchOutputCount;
+      return getNumGroups();
     }
 
     // Code-generated methods (implemented in HashAggBatch)
@@ -349,9 +364,6 @@ public void setup(HashAggregate hashAggrConfig, HashTableConfig htConfig, Fragme
     }
 
     this.htIdxHolder = new IndexPointer();
-    this.outStartIdxHolder = new IndexPointer();
-    this.outNumRecordsHolder = new IndexPointer();
-
     materializedValueFields = new MaterializedField[valueFieldIds.size()];
 
     if (valueFieldIds.size() > 0) {
@@ -513,7 +525,7 @@ private void initializeSetup(RecordBatch newIncoming) throws SchemaChangeExcepti
   private void updateEstMaxBatchSize(RecordBatch incoming) {
     if ( estMaxBatchSize > 0 ) { return; }  // no handling of a schema (or varchar) change
     // Use the sizer to get the input row width and the length of the longest varchar column
-    RecordBatchSizer sizer = new RecordBatchSizer(incoming);
+    RecordBatchSizer sizer = outgoing.getRecordBatchMemoryManager().getRecordBatchSizer();
     logger.trace("Incoming sizer: {}",sizer);
     // An empty batch only has the schema, can not tell actual length of varchars
     // else use the actual varchars length, each capped at 50 (to match the space allocation)
@@ -654,6 +666,8 @@ public AggOutcome doWork() {
           // remember EMIT, but continue like handling OK
 
         case OK:
+          outgoing.getRecordBatchMemoryManager().update();
+
           currentBatchRecordCount = incoming.getRecordCount(); // size of next batch
 
           resetIndex(); // initialize index (a new batch needs to be processed)
@@ -789,6 +803,22 @@ public int getOutputCount() {
     return lastBatchOutputCount;
   }
 
+  @Override
+  public void adjustOutputCount(int outputBatchSize, int oldRowWidth, int newRowWidth) {
+    for (int i = 0; i < numPartitions; i++ ) {
+      if (batchHolders[i] == null || batchHolders[i].size() == 0) {
+        continue;
+      }
+      BatchHolder bh = batchHolders[i].get(batchHolders[i].size()-1);
+      // Divide remaining memory by new row width.
+      final int remainingRows = RecordBatchSizer.safeDivide(Math.max((outputBatchSize - (bh.getCurrentRowCount() * oldRowWidth)), 0), newRowWidth);
+      // Do not go beyond the current target row count as this might cause reallocs for fixed width vectors.
+      final int newRowCount = Math.min(bh.getTargetBatchRowCount(), bh.getCurrentRowCount() + remainingRows);
+      bh.setTargetBatchRowCount(newRowCount);
+      htables[i].setTargetBatchRowCount(newRowCount);
+    }
+  }
+
   @Override
   public void cleanup() {
     if ( schema == null ) { return; } // not set up; nothing to clean
@@ -836,8 +866,6 @@ public void cleanup() {
     spillSet.close(); // delete the spill directory(ies)
     htIdxHolder = null;
     materializedValueFields = null;
-    outStartIdxHolder = null;
-    outNumRecordsHolder = null;
   }
 
   // First free the memory used by the given (spilled) partition (i.e., hash table plus batches)
@@ -853,6 +881,7 @@ private void reinitPartition(int part) /* throws SchemaChangeException /*, IOExc
     }
     batchHolders[part] = new ArrayList<BatchHolder>(); // First BatchHolder is created when the first put request is received.
 
+    outBatchIndex[part] = 0;
     // in case the reserve memory was used, try to restore
     restoreReservedMemory();
   }
@@ -962,17 +991,14 @@ private void spillAPartition(int part) {
     for (int currOutBatchIndex = 0; currOutBatchIndex < currPartition.size(); currOutBatchIndex++ ) {
 
       // get the number of records in the batch holder that are pending output
-      int numPendingOutput = currPartition.get(currOutBatchIndex).getNumPendingOutput();
+      int numOutputRecords = currPartition.get(currOutBatchIndex).getNumPendingOutput();
 
-      rowsInPartition += numPendingOutput;  // for logging
-      rowsSpilled += numPendingOutput;
+      rowsInPartition += numOutputRecords;  // for logging
+      rowsSpilled += numOutputRecords;
 
-      allocateOutgoing(numPendingOutput);
-
-      currPartition.get(currOutBatchIndex).outputValues(outStartIdxHolder, outNumRecordsHolder);
-      int numOutputRecords = outNumRecordsHolder.value;
-
-      this.htables[part].outputKeys(currOutBatchIndex, this.outContainer, outStartIdxHolder.value, outNumRecordsHolder.value, numPendingOutput);
+      allocateOutgoing(numOutputRecords);
+      currPartition.get(currOutBatchIndex).outputValues();
+      this.htables[part].outputKeys(currOutBatchIndex, this.outContainer, numOutputRecords);
 
       // set the value count for outgoing batch value vectors
       /* int i = 0; */
@@ -992,8 +1018,8 @@ private void spillAPartition(int part) {
         */
       }
 
-      outContainer.setRecordCount(numPendingOutput);
-      WritableBatch batch = WritableBatch.getBatchNoHVWrap(numPendingOutput, outContainer, false);
+      outContainer.setRecordCount(numOutputRecords);
+      WritableBatch batch = WritableBatch.getBatchNoHVWrap(numOutputRecords, outContainer, false);
       try {
         writers[part].write(batch, null);
       } catch (IOException ioe) {
@@ -1004,7 +1030,7 @@ private void spillAPartition(int part) {
         batch.clear();
       }
       outContainer.zeroVectors();
-      logger.trace("HASH AGG: Took {} us to spill {} records", writers[part].time(TimeUnit.MICROSECONDS), numPendingOutput);
+      logger.trace("HASH AGG: Took {} us to spill {} records", writers[part].time(TimeUnit.MICROSECONDS), numOutputRecords);
     }
 
     spilledBatchesCount[part] += currPartition.size(); // update count of spilled batches
@@ -1012,9 +1038,9 @@ private void spillAPartition(int part) {
     logger.trace("HASH AGG: Spilled {} rows from {} batches of partition {}", rowsInPartition, currPartition.size(), part);
   }
 
-  private void addBatchHolder(int part) {
+  private void addBatchHolder(int part, int batchRowCount) {
 
-    BatchHolder bh = newBatchHolder();
+    BatchHolder bh = newBatchHolder(batchRowCount);
     batchHolders[part].add(bh);
     if (EXTRA_DEBUG_1) {
       logger.debug("HashAggregate: Added new batch; num batches = {}.", batchHolders[part].size());
@@ -1024,8 +1050,8 @@ private void addBatchHolder(int part) {
   }
 
   // These methods are overridden in the generated class when created as plain Java code.
-  protected BatchHolder newBatchHolder() {
-    return new BatchHolder();
+  protected BatchHolder newBatchHolder(int batchRowCount) {
+    return new BatchHolder(batchRowCount);
   }
 
   /**
@@ -1161,20 +1187,21 @@ public AggIterOutcome outputCurrentBatch() {
 
     allocateOutgoing(numPendingOutput);
 
-    currPartition.get(currOutBatchIndex).outputValues(outStartIdxHolder, outNumRecordsHolder);
-    int numOutputRecords = outNumRecordsHolder.value;
-
-    if (EXTRA_DEBUG_1) {
-      logger.debug("After output values: outStartIdx = {}, outNumRecords = {}", outStartIdxHolder.value, outNumRecordsHolder.value);
-    }
-
-    this.htables[partitionToReturn].outputKeys(currOutBatchIndex, this.outContainer, outStartIdxHolder.value, outNumRecordsHolder.value, numPendingOutput);
+    currPartition.get(currOutBatchIndex).outputValues();
+    int numOutputRecords = numPendingOutput;
+    this.htables[partitionToReturn].outputKeys(currOutBatchIndex, this.outContainer, numPendingOutput);
 
     // set the value count for outgoing batch value vectors
     for (VectorWrapper<?> v : outgoing) {
       v.getValueVector().getMutator().setValueCount(numOutputRecords);
     }
 
+    outgoing.getRecordBatchMemoryManager().updateOutgoingStats(numOutputRecords);
+
+    if (logger.isDebugEnabled()) {
+      logger.debug("BATCH_STATS, outgoing: {}", new RecordBatchSizer(outgoing));
+    }
+
     this.outcome = IterOutcome.OK;
 
     if ( EXTRA_DEBUG_SPILL && is2ndPhase ) {
@@ -1271,6 +1298,10 @@ private String getOOMErrorMsg(String prefix) {
     return errmsg;
   }
 
+  private int getTargetBatchCount() {
+    return outgoing.getOutputRowCount();
+  }
+
   // Check if a group is present in the hash table; if not, insert it in the hash table.
   // The htIdxHolder contains the index of the group in the hash table container; this same
   // index is also used for the aggregation values maintained by the hash aggregate.
@@ -1338,7 +1369,7 @@ private void checkGroupAndAggrValues(int incomingRowIdx) {
     // ==========================================
     try {
 
-      putStatus = htables[currentPartition].put(incomingRowIdx, htIdxHolder, hashCode);
+      putStatus = htables[currentPartition].put(incomingRowIdx, htIdxHolder, hashCode, getTargetBatchCount());
 
     } catch (RetryAfterSpillException re) {
       if ( ! canSpill ) { throw new OutOfMemoryException(getOOMErrorMsg("Can not spill")); }
@@ -1372,7 +1403,7 @@ private void checkGroupAndAggrValues(int incomingRowIdx) {
 
         useReservedValuesMemory(); // try to preempt an OOM by using the reserve
 
-        addBatchHolder(currentPartition);  // allocate a new (internal) values batch
+        addBatchHolder(currentPartition, getTargetBatchCount());  // allocate a new (internal) values batch
 
         restoreReservedMemory(); // restore the reserve, if possible
         // A reason to check for a spill - In case restore-reserve failed
@@ -1408,8 +1439,8 @@ private void checkGroupAndAggrValues(int incomingRowIdx) {
     // Locate the matching aggregate columns and perform the aggregation
     // =================================================================
     int currentIdx = htIdxHolder.value;
-    BatchHolder bh = batchHolders[currentPartition].get((currentIdx >>> 16) & HashTable.BATCH_MASK);
-    int idxWithinBatch = currentIdx & HashTable.BATCH_MASK;
+    BatchHolder bh = batchHolders[currentPartition].get((currentIdx >>> 16) & BATCH_MASK);
+    int idxWithinBatch = currentIdx & BATCH_MASK;
     if (bh.updateAggrValues(incomingRowIdx, idxWithinBatch)) {
       numGroupedRecords++;
     }
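For readers following the index arithmetic above: the hash aggregate addresses a row with a single composite int whose upper 16 bits select the BatchHolder and whose lower 16 bits select the row within that holder, hence the recurring "(currentIdx >>> 16) & BATCH_MASK" and "currentIdx & BATCH_MASK". The sketch below is illustrative only; the BATCH_MASK and BATCH_SIZE values are assumptions consistent with that decoding, not quoted from the patch.

    // Composite index sketch: batch index in the upper 16 bits, row index in the lower 16 bits.
    static final int BATCH_MASK = 0xFFFF;    // assumed value
    static final int BATCH_SIZE = 1 << 16;   // assumed value

    static int encode(int batchIndex, int rowWithinBatch) {
      return (batchIndex << 16) | (rowWithinBatch & BATCH_MASK);
    }
    static int batchIndex(int idx)     { return (idx >>> 16) & BATCH_MASK; }
    static int rowWithinBatch(int idx) { return idx & BATCH_MASK; }

    // encode(3, 1025) -> batchIndex(..) == 3, rowWithinBatch(..) == 1025

The 16-bit row index also suggests why target batch row counts stay within a 64K cap: a BatchHolder can never usefully hold more rows than the lower half of the composite index can address, regardless of the configured output batch size.
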
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggregator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggregator.java
index 35e6d538aa..f58be89291 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggregator.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/aggregate/HashAggregator.java
@@ -68,4 +68,6 @@ void setup(HashAggregate hashAggrConfig, HashTableConfig htConfig, FragmentConte
   boolean earlyOutput();
 
   RecordBatch getNewIncoming();
+
+  void adjustOutputCount(int outputBatchSize, int oldRowWidth, int newRowWidth);
 }
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
index d80237c1e4..eaccd33552 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashPartition.java
@@ -53,6 +53,8 @@
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 
+import static org.apache.drill.exec.physical.impl.common.HashTable.BATCH_SIZE;
+
 /**
  * <h2>Overview</h2>
  * <p>
@@ -498,7 +500,7 @@ public void buildContainersHashTableAndHelper() throws SchemaChangeException {
       for (int recInd = 0; recInd < currentRecordCount; recInd++) {
         int hashCode = HV_vector.getAccessor().get(recInd);
         try {
-          hashTable.put(recInd, htIndex, hashCode);
+          hashTable.put(recInd, htIndex, hashCode, BATCH_SIZE);
         } catch (RetryAfterSpillException RE) {
           throw new OutOfMemoryException("HT put");
         } // Hash Join does not retry
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTable.java
index 194c865ff6..3bf4b86ecf 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTable.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTable.java
@@ -82,7 +82,7 @@ void setup(HashTableConfig htConfig, BufferAllocator allocator, VectorContainer
    */
   int getProbeHashCode(int incomingRowIdx) throws SchemaChangeException;
 
-  PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode) throws SchemaChangeException, RetryAfterSpillException;
+  PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode, int batchSize) throws SchemaChangeException, RetryAfterSpillException;
 
   /**
    * @param incomingRowIdx The index of the key in the probe batch.
@@ -130,12 +130,10 @@ void setup(HashTableConfig htConfig, BufferAllocator allocator, VectorContainer
    * Retrieves the key columns and transfers them to the output container. Note this operation removes the key columns from the {@link HashTable}.
    * @param batchIdx The index of a {@link HashTableTemplate.BatchHolder} in the HashTable.
    * @param outContainer The destination container for the key columns.
-   * @param outStartIndex The start index of the key records to transfer.
   * @param numRecords The number of key records to transfer.
-   * @param numExpectedRecords
    * @return
    */
-  boolean outputKeys(int batchIdx, VectorContainer outContainer, int outStartIndex, int numRecords, int numExpectedRecords);
+  boolean outputKeys(int batchIdx, VectorContainer outContainer, int numRecords);
 
   /**
    * Returns a message containing memory usage statistics. Intended to be used for printing debugging or error messages.
@@ -148,6 +146,10 @@ void setup(HashTableConfig htConfig, BufferAllocator allocator, VectorContainer
    * @return
    */
   long getActualSize();
+
+  void setTargetBatchRowCount(int batchRowCount);
+
+  int getTargetBatchRowCount();
 }
 
 
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTracker.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTracker.java
index d72278d81c..7f38ee624f 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTracker.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTracker.java
@@ -32,42 +32,35 @@
   }
 
   private final HashTableConfig config;
-  private final int maxBatchHolderSize;
-
   private State state = State.NO_ALLOCATION_IN_PROGRESS;
   private int remainingCapacity;
 
-  protected HashTableAllocationTracker(final HashTableConfig config,
-                                       final int maxBatchHolderSize)
+  protected HashTableAllocationTracker(final HashTableConfig config)
   {
     this.config = Preconditions.checkNotNull(config);
-    this.maxBatchHolderSize = maxBatchHolderSize;
-
     remainingCapacity = config.getInitialCapacity();
   }
 
-  public int getNextBatchHolderSize() {
+  public int getNextBatchHolderSize(int batchSize) {
     state = State.ALLOCATION_IN_PROGRESS;
 
     if (!config.getInitialSizeIsFinal()) {
-      // We don't know the final size of the hash table, so return the default max batch holder size
-      return maxBatchHolderSize;
+      // We don't know the final size of the hash table, so just return the batch size.
+      return batchSize;
     } else {
       // We know the final size of the hash table so we need to compute the next batch holder size.
-
       Preconditions.checkState(remainingCapacity > 0);
-      return computeNextBatchHolderSize();
+      return computeNextBatchHolderSize(batchSize);
     }
   }
 
-  private int computeNextBatchHolderSize() {
-    return Math.min(remainingCapacity, maxBatchHolderSize);
+  private int computeNextBatchHolderSize(int batchSize) {
+    return Math.min(batchSize, remainingCapacity);
   }
 
-  public void commit() {
+  public void commit(int batchSize) {
     Preconditions.checkState(state.equals(State.ALLOCATION_IN_PROGRESS));
-
-    remainingCapacity -= computeNextBatchHolderSize();
+    remainingCapacity -= batchSize;
     state = State.NO_ALLOCATION_IN_PROGRESS;
   }
 }
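A usage sketch of the reworked tracker, mirroring testLifecycle() in the test changes further below (the in-package unit test; this is not an excerpt from it): when the hash table's final size is known, the tracker hands out the caller-requested batch-holder size until only a remainder is left, and commit() now subtracts whatever was actually allocated.

    // Known-final capacity of 100 rows, requested holder size of 30 rows.
    HashTableConfig config = new HashTableConfig(100, true, .5f,
        Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
    HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);

    int size = tracker.getNextBatchHolderSize(30);   // 30
    tracker.commit(size);
    size = tracker.getNextBatchHolderSize(30);       // 30
    tracker.commit(size);
    size = tracker.getNextBatchHolderSize(30);       // 30
    tracker.commit(size);
    size = tracker.getNextBatchHolderSize(30);       // 10 -- the remainder
    tracker.commit(size);
    // A further call would fail the remaining-capacity precondition.
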
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableTemplate.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableTemplate.java
index da916f3ba7..756b3f3a20 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableTemplate.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/common/HashTableTemplate.java
@@ -65,7 +65,9 @@
   // Array of batch holders..each batch holder can hold up to BATCH_SIZE entries
   private ArrayList<BatchHolder> batchHolders;
 
-  private int totalBatchHoldersSize; // the size of all batchHolders
+  private int totalIndexSize; // index size of all batchHolders including current batch
+  private int prevIndexSize; // index size of all batchHolders not including current batch
+  private int currentIndexSize; // prevIndexSize + current batch count.
 
   // Current size of the hash table in terms of number of buckets
   private int tableSize = 0;
@@ -127,13 +129,21 @@
     private IntVector hashValues;
 
     private int maxOccupiedIdx = -1;
-//    private int batchOutputCount = 0;
-
+    private int targetBatchRowCount;
     private int batchIndex = 0;
 
+    public void setTargetBatchRowCount(int targetBatchRowCount) {
+      this.targetBatchRowCount = targetBatchRowCount;
+    }
+
+    public int getTargetBatchRowCount() {
+      return targetBatchRowCount;
+    }
+
     public BatchHolder(int idx, int newBatchHolderSize) {
 
       this.batchIndex = idx;
+      this.targetBatchRowCount = newBatchHolderSize;
 
       htContainer = new VectorContainer();
       boolean success = false;
@@ -152,7 +162,7 @@ public BatchHolder(int idx, int newBatchHolderSize) {
           } else if (vv instanceof VariableWidthVector) {
             long beforeMem = allocator.getAllocatedMemory();
             ((VariableWidthVector) vv).allocateNew(MAX_VARCHAR_SIZE * newBatchHolderSize, newBatchHolderSize);
-            logger.trace("HT allocated {} for varchar of max width {}",allocator.getAllocatedMemory() - beforeMem, MAX_VARCHAR_SIZE);
+            logger.trace("HT allocated {} for varchar of max width {}", allocator.getAllocatedMemory() - beforeMem, MAX_VARCHAR_SIZE);
           } else {
             vv.allocateNew();
           }
@@ -164,7 +174,9 @@ public BatchHolder(int idx, int newBatchHolderSize) {
       } finally {
         if (!success) {
           htContainer.clear();
-          if (links != null) { links.clear();}
+          if (links != null) {
+            links.clear();
+          }
         }
       }
     }
@@ -190,15 +202,14 @@ protected void setup() throws SchemaChangeException {
     private boolean isKeyMatch(int incomingRowIdx,
         IndexPointer currentIdxHolder,
         boolean isProbe) throws SchemaChangeException {
-
       int currentIdxWithinBatch = currentIdxHolder.value & BATCH_MASK;
       boolean match;
 
-      if (currentIdxWithinBatch >= HashTable.BATCH_SIZE) {
-        logger.debug("Batch size = {}, incomingRowIdx = {}, currentIdxWithinBatch = {}.", HashTable.BATCH_SIZE,
-            incomingRowIdx, currentIdxWithinBatch);
+      if (currentIdxWithinBatch >= batchHolders.get((currentIdxHolder.value >>> 16) & BATCH_MASK).getTargetBatchRowCount()) {
+        logger.debug("Batch size = {}, incomingRowIdx = {}, currentIdxWithinBatch = {}.",
+          batchHolders.get((currentIdxHolder.value >>> 16) & BATCH_MASK).getTargetBatchRowCount(), incomingRowIdx, currentIdxWithinBatch);
       }
-      assert (currentIdxWithinBatch < HashTable.BATCH_SIZE);
+      assert (currentIdxWithinBatch < batchHolders.get((currentIdxHolder.value >>> 16) & BATCH_MASK).getTargetBatchRowCount());
       assert (incomingRowIdx < HashTable.BATCH_SIZE);
 
       if (isProbe) {
@@ -217,7 +228,6 @@ private boolean isKeyMatch(int incomingRowIdx,
     // container at the specified index
     private void insertEntry(int incomingRowIdx, int currentIdx, int hashValue, BatchHolder lastEntryBatch, int lastEntryIdxWithinBatch) throws SchemaChangeException {
       int currentIdxWithinBatch = currentIdx & BATCH_MASK;
-
       setValue(incomingRowIdx, currentIdxWithinBatch);
       // setValue may OOM when doubling of one of the VarChar Key Value Vectors
       // This would be caught and retried later (setValue() is idempotent)
@@ -280,8 +290,7 @@ private void rehash(int numbuckets, IntVector newStartIndices, int batchStartIdx
           while (true) {
             if (idx != EMPTY_SLOT) {
               idxWithinBatch = idx & BATCH_MASK;
-              int batchIdx = ((idx >>> 16) & BATCH_MASK);
-              bh = batchHolders.get(batchIdx);
+              bh = batchHolders.get((idx >>> 16) & BATCH_MASK);
             }
 
             if (bh == this && newLinks.getAccessor().get(idxWithinBatch) == EMPTY_SLOT) {
@@ -332,7 +341,7 @@ private void rehash(int numbuckets, IntVector newStartIndices, int batchStartIdx
       hashValues = newHashValues;
     }
 
-    private boolean outputKeys(VectorContainer outContainer, int outStartIndex, int numRecords, int numExpectedRecords) {
+    private boolean outputKeys(VectorContainer outContainer, int numRecords) {
       // set the value count for htContainer's value vectors before the transfer ..
       setValueCount();
 
@@ -344,18 +353,9 @@ private boolean outputKeys(VectorContainer outContainer, int outStartIndex, int
         @SuppressWarnings("resource")
         ValueVector targetVV = outgoingIter.next().getValueVector();
         TransferPair tp = sourceVV.makeTransferPair(targetVV);
-        if ( outStartIndex == 0 && numRecords == numExpectedRecords ) {
-          // The normal case: The whole column key(s) are transfered as is
-          tp.transfer();
-        } else {
-          // Transfer just the required section (does this ever happen ?)
-          // Requires an expensive allocation and copy
-          logger.debug("Performing partial output of keys, from index {}, num {} (out of {})",
-              outStartIndex,numRecords,numExpectedRecords);
-          tp.splitAndTransfer(outStartIndex, numRecords);
-        }
+        // The normal case: The whole column key(s) are transfered as is
+        tp.transfer();
       }
-
       return true;
     }
 
@@ -469,7 +469,7 @@ public void setup(HashTableConfig htConfig, BufferAllocator allocator, VectorCon
     this.incomingProbe = incomingProbe;
     this.outgoing = outgoing;
     this.htContainerOrig = htContainerOrig;
-    this.allocationTracker = new HashTableAllocationTracker(htConfig, BATCH_SIZE);
+    this.allocationTracker = new HashTableAllocationTracker(htConfig);
 
     // round up the initial capacity to nearest highest power of 2
     tableSize = roundUpToPowerOf2(initialCap);
@@ -486,9 +486,12 @@ public void setup(HashTableConfig htConfig, BufferAllocator allocator, VectorCon
 
     // Create the first batch holder
     batchHolders = new ArrayList<BatchHolder>();
-    totalBatchHoldersSize = 0;
     // First BatchHolder is created when the first put request is received.
 
+    prevIndexSize = 0;
+    currentIndexSize = 0;
+    totalIndexSize = 0;
+
     try {
       doSetup(incomingBuild, incomingProbe);
     } catch (SchemaChangeException e) {
@@ -501,7 +504,7 @@ public void setup(HashTableConfig htConfig, BufferAllocator allocator, VectorCon
   @Override
   public void updateInitialCapacity(int initialCapacity) {
     htConfig = htConfig.withInitialCapacity(initialCapacity);
-    allocationTracker = new HashTableAllocationTracker(htConfig, BATCH_SIZE);
+    allocationTracker = new HashTableAllocationTracker(htConfig);
     enlargeEmptyHashTableIfNeeded(initialCapacity);
   }
 
@@ -548,7 +551,9 @@ public void clear() {
       }
       batchHolders.clear();
       batchHolders = null;
-      totalBatchHoldersSize = 0;
+      prevIndexSize = 0;
+      currentIndexSize = 0;
+      totalIndexSize = 0;
     }
     startIndices.clear();
     // currentIdxHolder = null; // keep IndexPointer in case HT is reused
@@ -574,10 +579,15 @@ private void retryAfterOOM(boolean batchAdded) throws RetryAfterSpillException {
     if ( batchAdded ) {
       logger.trace("OOM - Removing index {} from the batch holders list",batchHolders.size() - 1);
       BatchHolder bh = batchHolders.remove(batchHolders.size() - 1);
-      totalBatchHoldersSize -= BATCH_SIZE;
+      prevIndexSize = batchHolders.size() > 1 ? (batchHolders.size()-1) * BATCH_SIZE : 0;
+      currentIndexSize = prevIndexSize + (batchHolders.size() > 0 ? batchHolders.get(batchHolders.size()-1).getTargetBatchRowCount() : 0);
+      totalIndexSize = batchHolders.size() * BATCH_SIZE;
+      // update freeIndex to point to end of last batch + 1
+      freeIndex = totalIndexSize + 1;
       bh.clear();
+    } else {
+      freeIndex--;
     }
-    freeIndex--;
     throw new RetryAfterSpillException();
   }
 
@@ -619,7 +629,7 @@ public int getProbeHashCode(int incomingRowIdx) throws SchemaChangeException {
    * @return Status - the key(s) was ADDED or was already PRESENT
    */
   @Override
-  public PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode) throws SchemaChangeException, RetryAfterSpillException {
+  public PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode, int targetBatchRowCount) throws SchemaChangeException, RetryAfterSpillException {
 
     int bucketIndex = getBucketIndex(hashCode, numBuckets());
     int startIdx = startIndices.getAccessor().get(bucketIndex);
@@ -634,7 +644,7 @@ public PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode)
           /* isKeyMatch() below also advances the currentIdxHolder to the next link */) {
 
       // remember the current link, which would be the last when the next link is empty
-      lastEntryBatch = batchHolders.get((currentIdxHolder.value >>> 16) & HashTable.BATCH_MASK);
+      lastEntryBatch = batchHolders.get((currentIdxHolder.value >>> 16) & BATCH_MASK);
       lastEntryIdxWithinBatch = currentIdxHolder.value & BATCH_MASK;
 
       if (lastEntryBatch.isKeyMatch(incomingRowIdx, currentIdxHolder, false)) {
@@ -647,14 +657,18 @@ public PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode)
     currentIdx = freeIndex++;
     boolean addedBatch = false;
     try {  // ADD A BATCH
-      addedBatch = addBatchIfNeeded(currentIdx);
+      addedBatch = addBatchIfNeeded(currentIdx, targetBatchRowCount);
+      if (addedBatch) {
+        // If we just added the batch, update the current index to point to beginning of new batch.
+        currentIdx = (batchHolders.size() - 1) * BATCH_SIZE;
+        freeIndex = currentIdx + 1;
+      }
     } catch (OutOfMemoryException OOME) {
-      retryAfterOOM( currentIdx < batchHolders.size() * BATCH_SIZE );
+      retryAfterOOM( currentIdx < totalIndexSize);
     }
 
     try { // INSERT ENTRY
       BatchHolder bh = batchHolders.get((currentIdx >>> 16) & BATCH_MASK);
-
       bh.insertEntry(incomingRowIdx, currentIdx, hashCode, lastEntryBatch, lastEntryIdxWithinBatch);
       numEntries++;
     } catch (OutOfMemoryException OOME) { retryAfterOOM( addedBatch ); }
@@ -684,7 +698,7 @@ public PutStatus put(int incomingRowIdx, IndexPointer htIdxHolder, int hashCode)
     }
     htIdxHolder.value = currentIdx;
     return  addedBatch ? PutStatus.NEW_BATCH_ADDED :
-        ( freeIndex + 1 > totalBatchHoldersSize /* batchHolders.size() * BATCH_SIZE */ ) ?
+        (freeIndex + 1 > currentIndexSize) ?
         PutStatus.KEY_ADDED_LAST : // the last key in the batch
         PutStatus.KEY_ADDED;     // otherwise
   }
@@ -716,20 +730,22 @@ public int probeForKey(int incomingRowIdx, int hashCode) throws SchemaChangeExce
   // Add a new BatchHolder to the list of batch holders if needed. This is based on the supplied
   // currentIdx; since each BatchHolder can hold up to BATCH_SIZE entries, if the currentIdx exceeds
   // the capacity, we will add a new BatchHolder. Return true if a new batch was added.
-  private boolean addBatchIfNeeded(int currentIdx) throws SchemaChangeException {
-    // int totalBatchSize = batchHolders.size() * BATCH_SIZE;
-
-    if (currentIdx >= totalBatchHoldersSize) {
-      BatchHolder bh = newBatchHolder(batchHolders.size(), allocationTracker.getNextBatchHolderSize());
+  private boolean addBatchIfNeeded(int currentIdx, int batchRowCount) throws SchemaChangeException {
+     // Add a new batch if this is the first batch or
+     // index is greater than current batch target count i.e. we reached the limit of current batch.
+     if (batchHolders.size() == 0 || (currentIdx >= currentIndexSize)) {
+      final int allocationSize = allocationTracker.getNextBatchHolderSize(batchRowCount);
+      final BatchHolder bh = newBatchHolder(batchHolders.size(), allocationSize);
       batchHolders.add(bh);
+      prevIndexSize = batchHolders.size() > 1 ? (batchHolders.size()-1)*BATCH_SIZE : 0;
+      currentIndexSize = prevIndexSize + batchHolders.get(batchHolders.size()-1).getTargetBatchRowCount();
+      totalIndexSize = batchHolders.size() * BATCH_SIZE;
       bh.setup();
       if (EXTRA_DEBUG) {
         logger.debug("HashTable: Added new batch. Num batches = {}.", batchHolders.size());
       }
 
-      allocationTracker.commit();
-
-      totalBatchHoldersSize += BATCH_SIZE; // total increased by 1 batch
+      allocationTracker.commit(allocationSize);
       return true;
     }
     return false;
@@ -782,10 +798,12 @@ private void resizeAndRehashIfNeeded() {
 
     IntVector newStartIndices = allocMetadataVector(tableSize, EMPTY_SLOT);
 
+    int idx = 0;
     for (int i = 0; i < batchHolders.size(); i++) {
       BatchHolder bh = batchHolders.get(i);
-      int batchStartIdx = i * BATCH_SIZE;
+      int batchStartIdx = idx;
       bh.rehash(tableSize, newStartIndices, batchStartIdx);
+      idx += bh.getTargetBatchRowCount();
     }
 
     startIndices.clear();
@@ -796,8 +814,8 @@ private void resizeAndRehashIfNeeded() {
       logger.debug("Number of buckets = {}.", startIndices.getAccessor().getValueCount());
       for (int i = 0; i < startIndices.getAccessor().getValueCount(); i++) {
         logger.debug("Bucket: {}, startIdx[ {} ] = {}.", i, i, startIndices.getAccessor().get(i));
-        int idx = startIndices.getAccessor().get(i);
-        BatchHolder bh = batchHolders.get((idx >>> 16) & BATCH_MASK);
+        int startIdx = startIndices.getAccessor().get(i);
+        BatchHolder bh = batchHolders.get((startIdx >>> 16) & BATCH_MASK);
         bh.dump(idx);
       }
     }
@@ -831,7 +849,9 @@ public void reset() {
     freeIndex = 0; // all batch holders are gone
     // reallocate batch holders, and the hash table to the original size
     batchHolders = new ArrayList<BatchHolder>();
-    totalBatchHoldersSize = 0;
+    prevIndexSize = 0;
+    currentIndexSize = 0;
+    totalIndexSize = 0;
     startIndices = allocMetadataVector(originalTableSize, EMPTY_SLOT);
   }
   public void updateIncoming(VectorContainer newIncoming, RecordBatch newIncomingProbe) {
@@ -846,9 +866,9 @@ public void updateIncoming(VectorContainer newIncoming, RecordBatch newIncomingP
   }
 
   @Override
-  public boolean outputKeys(int batchIdx, VectorContainer outContainer, int outStartIndex, int numRecords, int numExpectedRecords) {
+  public boolean outputKeys(int batchIdx, VectorContainer outContainer, int numRecords) {
     assert batchIdx < batchHolders.size();
-    return batchHolders.get(batchIdx).outputKeys(outContainer, outStartIndex, numRecords, numExpectedRecords);
+    return batchHolders.get(batchIdx).outputKeys(outContainer, numRecords);
   }
 
   private IntVector allocMetadataVector(int size, int initialValue) {
@@ -891,4 +911,14 @@ public String makeDebugString() {
     return String.format("[numBuckets = %d, numEntries = %d, numBatchHolders = %d, actualSize = %s]",
       numBuckets(), numEntries, batchHolders.size(), HashJoinMemoryCalculator.PartitionStatSet.prettyPrintBytes(getActualSize()));
   }
+
+  @Override
+  public void setTargetBatchRowCount(int batchRowCount) {
+    batchHolders.get(batchHolders.size()-1).targetBatchRowCount = batchRowCount;
+  }
+
+  @Override
+  public int getTargetBatchRowCount() {
+    return batchHolders.get(batchHolders.size()-1).targetBatchRowCount;
+  }
 }
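The three counters introduced above replace the old totalBatchHoldersSize and are worth spelling out (a sketch of the invariant, not code from the patch): every BatchHolder still owns a fixed BATCH_SIZE-wide slice of the composite index space, so the ">>> 16" decoding keeps working, but only the first targetBatchRowCount slots of the last slice are usable. Assuming BATCH_SIZE = 1 << 16:

    // n = batchHolders.size(), lastTarget = target row count of the last holder
    int prevIndexSize    = (n > 1) ? (n - 1) * BATCH_SIZE : 0;        // slices before the last holder
    int currentIndexSize = prevIndexSize + (n > 0 ? lastTarget : 0);  // last usable index + 1
    int totalIndexSize   = n * BATCH_SIZE;                            // full span of all slices

    // put() reports KEY_ADDED_LAST once freeIndex + 1 > currentIndexSize, i.e. the last holder
    // has reached its target row count; addBatchIfNeeded() then starts the new holder's rows at
    // batchHolders.size() * BATCH_SIZE (counted before the add), keeping the upper 16 bits aligned.
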
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatchMemoryManager.java b/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatchMemoryManager.java
index a270ced48c..79b28db243 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatchMemoryManager.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/record/RecordBatchMemoryManager.java
@@ -94,19 +94,19 @@ public long getAvgOutputRowWidth() {
   }
 
   public long getNumIncomingBatches() {
-    return inputBatchStats[DEFAULT_INPUT_INDEX].getNumBatches();
+    return inputBatchStats[DEFAULT_INPUT_INDEX] == null ? 0 : inputBatchStats[DEFAULT_INPUT_INDEX].getNumBatches();
   }
 
   public long getAvgInputBatchSize() {
-    return inputBatchStats[DEFAULT_INPUT_INDEX].getAvgBatchSize();
+    return inputBatchStats[DEFAULT_INPUT_INDEX] == null ? 0 : inputBatchStats[DEFAULT_INPUT_INDEX].getAvgBatchSize();
   }
 
   public long getAvgInputRowWidth() {
-    return inputBatchStats[DEFAULT_INPUT_INDEX].getAvgRowWidth();
+    return inputBatchStats[DEFAULT_INPUT_INDEX] == null ? 0 : inputBatchStats[DEFAULT_INPUT_INDEX].getAvgRowWidth();
   }
 
   public long getTotalInputRecords() {
-    return inputBatchStats[DEFAULT_INPUT_INDEX].getTotalRecords();
+    return inputBatchStats[DEFAULT_INPUT_INDEX] == null ? 0 : inputBatchStats[DEFAULT_INPUT_INDEX].getTotalRecords();
   }
 
   public long getNumIncomingBatches(int index) {
@@ -176,6 +176,22 @@ public int update(RecordBatch batch, int inputIndex, int outputPosition, boolean
     return getOutputRowCount();
   }
 
+  public boolean updateIfNeeded(int newOutgoingRowWidth) {
+    // We do not want to keep adjusting batch holders target row count
+    // for small variations in row width.
+    // If row width changes, calculate actual adjusted row count i.e. row count
+    // rounded down to nearest power of two and do nothing if that does not change.
+    if (newOutgoingRowWidth == outgoingRowWidth ||
+      computeOutputRowCount(outputBatchSize, newOutgoingRowWidth) == computeOutputRowCount(outputBatchSize, outgoingRowWidth)) {
+      return false;
+    }
+
+    // Set number of rows in outgoing batch. This number will be used for new batch creation.
+    setOutputRowCount(outputBatchSize, newOutgoingRowWidth);
+    setOutgoingRowWidth(newOutgoingRowWidth);
+    return true;
+  }
+
   public int getOutputRowCount() {
     return outputRowCount;
   }
@@ -201,7 +217,7 @@ public static int adjustOutputRowCount(int rowCount) {
     return (Math.min(MAX_NUM_ROWS, Math.max(Integer.highestOneBit(rowCount) - 1, MIN_NUM_ROWS)));
   }
 
-  public static int computeRowCount(int batchSize, int rowWidth) {
+  public static int computeOutputRowCount(int batchSize, int rowWidth) {
     return adjustOutputRowCount(RecordBatchSizer.safeDivide(batchSize, rowWidth));
   }
 
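A worked example (assumptions: a 512 KB output batch size; plain integer division standing in for RecordBatchSizer.safeDivide()) of how computeOutputRowCount() and updateIfNeeded() interact. The target row count is the batch size divided by the row width, rounded down to one less than a power of two and clamped, so a small change in row width that does not cross a power-of-two boundary leaves the target unchanged and updateIfNeeded() returns false.

    int batchSize = 512 * 1024;
    // 100-byte rows: 524288 / 100 = 5242 -> highestOneBit(5242) - 1 = 4095
    int oldCount  = RecordBatchMemoryManager.computeOutputRowCount(batchSize, 100);  // 4095
    // 110-byte rows: ~4766 raw rows -> still 4095, so updateIfNeeded(110) would return false
    int sameCount = RecordBatchMemoryManager.computeOutputRowCount(batchSize, 110);  // 4095
    // 200-byte rows: ~2621 raw rows -> 2047, so the outgoing row count is re-set
    int newCount  = RecordBatchMemoryManager.computeOutputRowCount(batchSize, 200);  // 2047
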
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTrackerTest.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTrackerTest.java
index 131d82f911..0140943341 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTrackerTest.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/common/HashTableAllocationTrackerTest.java
@@ -21,40 +21,42 @@
 import org.junit.Assert;
 import org.junit.Test;
 
+import static org.apache.drill.exec.physical.impl.common.HashTable.BATCH_SIZE;
+
 public class HashTableAllocationTrackerTest {
   @Test
   public void testDoubleGetNextCall() {
     final HashTableConfig config = new HashTableConfig(100, true, .5f, Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
-    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config, 30);
+    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);
 
     for (int counter = 0; counter < 100; counter++) {
-      Assert.assertEquals(30, tracker.getNextBatchHolderSize());
+      Assert.assertEquals(100, tracker.getNextBatchHolderSize(BATCH_SIZE));
     }
   }
 
   @Test(expected = IllegalStateException.class)
   public void testPrematureCommit() {
     final HashTableConfig config = new HashTableConfig(100, .5f, Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
-    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config, 30);
+    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);
 
-    tracker.commit();
+    tracker.commit(30);
   }
 
   @Test(expected = IllegalStateException.class)
   public void testDoubleCommit() {
     final HashTableConfig config = new HashTableConfig(100, .5f, Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
-    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config, 30);
+    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);
 
-    tracker.commit();
-    tracker.commit();
+    tracker.commit(30);
+    tracker.commit(30);
   }
 
   @Test
   public void testOverAsking() {
     final HashTableConfig config = new HashTableConfig(100, .5f, Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
-    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config, 30);
+    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);
 
-    tracker.getNextBatchHolderSize();
+    tracker.getNextBatchHolderSize(30);
   }
 
   /**
@@ -63,11 +65,11 @@ public void testOverAsking() {
   @Test
   public void testLifecycle1() {
     final HashTableConfig config = new HashTableConfig(100, .5f, Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
-    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config, 30);
+    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);
 
     for (int counter = 0; counter < 100; counter++) {
-      Assert.assertEquals(30, tracker.getNextBatchHolderSize());
-      tracker.commit();
+      Assert.assertEquals(30, tracker.getNextBatchHolderSize(30));
+      tracker.commit(30);
     }
   }
 
@@ -77,21 +79,21 @@ public void testLifecycle1() {
   @Test
   public void testLifecycle() {
     final HashTableConfig config = new HashTableConfig(100, true, .5f, Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
-    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config, 30);
+    final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);
 
-    Assert.assertEquals(30, tracker.getNextBatchHolderSize());
-    tracker.commit();
-    Assert.assertEquals(30, tracker.getNextBatchHolderSize());
-    tracker.commit();
-    Assert.assertEquals(30, tracker.getNextBatchHolderSize());
-    tracker.commit();
-    Assert.assertEquals(10, tracker.getNextBatchHolderSize());
-    tracker.commit();
+    Assert.assertEquals(30, tracker.getNextBatchHolderSize(30));
+    tracker.commit(30);
+    Assert.assertEquals(30, tracker.getNextBatchHolderSize(30));
+    tracker.commit(30);
+    Assert.assertEquals(30, tracker.getNextBatchHolderSize(30));
+    tracker.commit(30);
+    Assert.assertEquals(10, tracker.getNextBatchHolderSize(30));
+    tracker.commit(30);
 
     boolean caughtException = false;
 
     try {
-      tracker.getNextBatchHolderSize();
+      tracker.getNextBatchHolderSize(BATCH_SIZE);
     } catch (IllegalStateException ex) {
       caughtException = true;
     }
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/unit/TestOutputBatchSize.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/unit/TestOutputBatchSize.java
index fd0b49485a..471f1b85db 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/unit/TestOutputBatchSize.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/unit/TestOutputBatchSize.java
@@ -26,11 +26,13 @@
 import org.apache.drill.exec.physical.base.AbstractBase;
 import org.apache.drill.exec.physical.base.PhysicalOperator;
 import org.apache.drill.exec.physical.config.FlattenPOP;
+import org.apache.drill.exec.physical.config.HashAggregate;
 import org.apache.drill.exec.physical.config.HashJoinPOP;
 import org.apache.drill.exec.physical.config.MergeJoinPOP;
 import org.apache.drill.exec.physical.config.Project;
 import org.apache.drill.exec.physical.config.UnionAll;
 import org.apache.drill.exec.physical.impl.ScanBatch;
+import org.apache.drill.exec.planner.physical.AggPrelBase;
 import org.apache.drill.exec.record.RecordBatchSizer;
 import org.apache.drill.exec.record.RecordBatch;
 import org.apache.drill.exec.record.VectorAccessible;
@@ -2088,6 +2090,203 @@ public void testLeftOuterHashJoin() throws Exception {
 
   }
 
+  @Test
+  public void testSimpleHashAgg() {
+    HashAggregate aggConf = new HashAggregate(null, AggPrelBase.OperatorPhase.PHASE_1of1, parseExprs("a", "a"), parseExprs("sum(b)", "b_sum"), 1.0f);
+    List<String> inputJsonBatches = Lists.newArrayList(
+       "[{\"a\": 5, \"b\" : 1 }]",
+         "[{\"a\": 5, \"b\" : 5},{\"a\": 3, \"b\" : 8}]");
+
+    opTestBuilder()
+      .physicalOperator(aggConf)
+      .inputDataStreamJson(inputJsonBatches)
+      .baselineColumns("b_sum", "a")
+      .baselineValues(6l, 5l)
+      .baselineValues(8l, 3l)
+      .go();
+  }
+
+  @Test
+  public void testHashAggSum() throws ExecutionSetupException {
+    HashAggregate hashAgg = new HashAggregate(null, AggPrelBase.OperatorPhase.PHASE_1of1, parseExprs("a", "a"), parseExprs("sum(b)", "b_sum"), 1.0f);
+
+    // create input rows like this.
+    // "a" : 1, "b" : 1
+    // "a" : 1, "b" : 1
+    // "a" : 1, "b" : 1
+    List<String> inputJsonBatches = Lists.newArrayList();
+    StringBuilder batchString = new StringBuilder();
+    batchString.append("[");
+    for (int i = 0; i < numRows; i++) {
+        batchString.append("{\"a\": " + i + ", \"b\": " + i + "},");
+        batchString.append("{\"a\": " + i + ", \"b\": " + i + "},");
+        batchString.append("{\"a\": " + i + ", \"b\": " + i + "},");
+    }
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}," );
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}," );
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}" );
+
+    batchString.append("]");
+    inputJsonBatches.add(batchString.toString());
+
+    // Figure out the approximate total output size out of the hash agg for the input above.
+    // We will use this sizing information to set the output batch size so we can produce the
+    // desired number of batches for verification.
+
+    // output rows will be like this.
+    // "a" : 1, "b" : 3
+    List<String> expectedJsonBatches = Lists.newArrayList();
+    StringBuilder expectedBatchString = new StringBuilder();
+    expectedBatchString.append("[");
+
+    for (int i = 0; i < numRows; i++) {
+      expectedBatchString.append("{\"a\": " + i + ", \"b\": " + (3*i) + "},");
+    }
+    expectedBatchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}" );
+    expectedBatchString.append("]");
+    expectedJsonBatches.add(expectedBatchString.toString());
+
+    long totalSize = getExpectedSize(expectedJsonBatches);
+
+    // set the output batch size to 1/2 of total size expected.
+    // We will get approximately 2 batches, with a max of 4.
+    fragContext.getOptions().setLocalOption("drill.exec.memory.operator.output_batch_size", totalSize / 2);
+
+    OperatorTestBuilder opTestBuilder = opTestBuilder()
+      .physicalOperator(hashAgg)
+      .inputDataStreamJson(inputJsonBatches)
+      .baselineColumns("a", "b_sum")
+      .expectedNumBatches(4)  // verify number of batches
+      .expectedBatchSize(totalSize/2); // verify batch size.
+
+
+    for (int i = 0; i < numRows + 1; i++) {
+      opTestBuilder.baselineValues((long)i, (long)3*i);
+    }
+
+    opTestBuilder.go();
+  }
+
+  @Test
+  public void testHashAggAvg() throws ExecutionSetupException {
+    HashAggregate hashAgg = new HashAggregate(null, AggPrelBase.OperatorPhase.PHASE_1of1, parseExprs("a", "a"), parseExprs("avg(b)", "b_avg"), 1.0f);
+
+    // create input rows like this.
+    // "a" : 1, "b" : 1
+    // "a" : 1, "b" : 1
+    // "a" : 1, "b" : 1
+    List<String> inputJsonBatches = Lists.newArrayList();
+    StringBuilder batchString = new StringBuilder();
+    batchString.append("[");
+    for (int i = 0; i < numRows; i++) {
+      batchString.append("{\"a\": " + i + ", \"b\": " + i + "},");
+      batchString.append("{\"a\": " + i + ", \"b\": " + i + "},");
+      batchString.append("{\"a\": " + i + ", \"b\": " + i + "},");
+    }
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}," );
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}," );
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}" );
+
+    batchString.append("]");
+    inputJsonBatches.add(batchString.toString());
+
+    // Figure out the approximate total output size out of the hash agg for the input above.
+    // We will use this sizing information to set the output batch size so we can produce the
+    // desired number of batches for verification.
+
+    // output rows will be like this.
+    // "a" : 1, "b" : 3
+    List<String> expectedJsonBatches = Lists.newArrayList();
+    StringBuilder expectedBatchString = new StringBuilder();
+    expectedBatchString.append("[");
+
+    for (int i = 0; i < numRows; i++) {
+      expectedBatchString.append("{\"a\": " + i + ", \"b\": " + (3*i) + "},");
+    }
+    expectedBatchString.append("{\"a\": " + numRows + ", \"b\": " + numRows + "}" );
+    expectedBatchString.append("]");
+    expectedJsonBatches.add(expectedBatchString.toString());
+
+    long totalSize = getExpectedSize(expectedJsonBatches);
+
+    // set the output batch size to 1/2 of total size expected.
+    // We will get approximately 2 batches, with a max of 4.
+    fragContext.getOptions().setLocalOption("drill.exec.memory.operator.output_batch_size", totalSize / 2);
+
+    OperatorTestBuilder opTestBuilder = opTestBuilder()
+      .physicalOperator(hashAgg)
+      .inputDataStreamJson(inputJsonBatches)
+      .baselineColumns("a", "b_avg")
+      .expectedNumBatches(4)  // verify number of batches
+      .expectedBatchSize(totalSize/2); // verify batch size.
+
+    for (int i = 0; i < numRows + 1; i++) {
+      opTestBuilder.baselineValues((long)i, (double)i);
+    }
+
+    opTestBuilder.go();
+  }
+
+  @Test
+  public void testHashAggMax() throws ExecutionSetupException {
+    HashAggregate hashAgg = new HashAggregate(null, AggPrelBase.OperatorPhase.PHASE_1of1, parseExprs("a", "a"), parseExprs("max(b)", "b_max"), 1.0f);
+
+    // create input rows like this.
+    // "a" : 1, "b" : "a"
+    // "a" : 2, "b" : "aa"
+    // "a" : 3, "b" : "aaa"
+    List<String> inputJsonBatches = Lists.newArrayList();
+    StringBuilder batchString = new StringBuilder();
+    batchString.append("[");
+    for (int i = 0; i < numRows; i++) {
+      batchString.append("{\"a\": " + i + ", \"b\": " + "\"a\"" + "},");
+      batchString.append("{\"a\": " + i + ", \"b\": " + "\"aa\"" + "},");
+      batchString.append("{\"a\": " + i + ", \"b\": " + "\"aaa\"" + "},");
+    }
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + "\"a\"" + "}," );
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + "\"aa\"" + "}," );
+    batchString.append("{\"a\": " + numRows + ", \"b\": " + "\"aaa\"" + "}" );
+
+    batchString.append("]");
+    inputJsonBatches.add(batchString.toString());
+
+    // Figure out the approximate total output size out of the hash agg for the input above.
+    // We will use this sizing information to set the output batch size so we can produce the
+    // desired number of batches for verification.
+
+    // output rows will be like this.
+    // "a" : 1, "b" : "aaa"
+    List<String> expectedJsonBatches = Lists.newArrayList();
+    StringBuilder expectedBatchString = new StringBuilder();
+    expectedBatchString.append("[");
+
+    for (int i = 0; i < numRows; i++) {
+      expectedBatchString.append("{\"a\": " + i + ", \"b\": " + "\"aaa\"" + "},");
+    }
+    expectedBatchString.append("{\"a\": " + numRows + ", \"b\": " + "\"aaa\"" + "}" );
+    expectedBatchString.append("]");
+    expectedJsonBatches.add(expectedBatchString.toString());
+
+    long totalSize = getExpectedSize(expectedJsonBatches);
+
+    // set the output batch size to 1/2 of total size expected.
+    // We will get approximately 2 batches, with a max of 4.
+    fragContext.getOptions().setLocalOption("drill.exec.memory.operator.output_batch_size", totalSize / 2);
+
+    OperatorTestBuilder opTestBuilder = opTestBuilder()
+      .physicalOperator(hashAgg)
+      .inputDataStreamJson(inputJsonBatches)
+      .baselineColumns("a", "b_max")
+      .expectedNumBatches(2)  // verify number of batches
+      .expectedBatchSize(totalSize); // verify batch size.
+
+    for (int i = 0; i < numRows + 1; i++) {
+      opTestBuilder.baselineValues((long)i, "aaa");
+    }
+
+    opTestBuilder.go();
+  }
+
   @Test
   public void testSizerRepeatedList() throws Exception {
     List<String> inputJsonBatches = Lists.newArrayList();
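
A note on the HashTableAllocationTracker hunk near the top of this diff: the batch-holder size is no longer fixed at construction time. Callers now pass the desired number of entries to getNextBatchHolderSize(), which returns that number capped by whatever remains of a fixed initial capacity, and then report the allocation through commit(). Below is a minimal sketch of that call pattern as inferred from the tests above; the wrapper class name, the package of HashTableAllocationTracker, and the reading of the boolean constructor argument as "initial capacity is exact" are assumptions, not taken from the diff.

    import com.google.common.collect.Lists;
    import org.apache.drill.exec.physical.impl.common.HashTableAllocationTracker;
    import org.apache.drill.exec.physical.impl.common.HashTableConfig;

    public class AllocationTrackerSketch {
      public static void main(String[] args) {
        // A table expected to hold 100 entries; the boolean (an assumption here) marks that
        // estimate as exact, which is what lets the tracker cap the final increment.
        final HashTableConfig config = new HashTableConfig(100, true, .5f,
            Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList());
        final HashTableAllocationTracker tracker = new HashTableAllocationTracker(config);

        final int batchSize = 30;  // entries per batch holder, e.g. derived from the output batch size
        while (true) {
          // Returns min(batchSize, remaining capacity): 30, 30, 30, then 10 for a capacity of 100.
          final int size = tracker.getNextBatchHolderSize(batchSize);
          // ... allocate a batch holder sized for 'size' entries here ...
          tracker.commit(batchSize);  // mirrors the tests, which pass the requested size to commit();
                                      // the tracker presumably clamps it to what was handed out
          if (size < batchSize) {
            break;                    // capacity exhausted; another call would throw IllegalStateException
          }
        }
      }
    }

When the config is built without the boolean flag, the tests show no such cap: testLifecycle1 asks for 30 entries one hundred times and gets 30 every time, so the limit only applies when the initial capacity is declared exact.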


 
