You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/30 10:29:19 UTC

spark git commit: [SPARK-22652][SQL] remove set methods in ColumnarRow

Repository: spark
Updated Branches:
  refs/heads/master 92cfbeeb5 -> 444a2bbb6


[SPARK-22652][SQL] remove set methods in ColumnarRow

## What changes were proposed in this pull request?

As a step to make `ColumnVector` public, the `ColumnarRow` returned by `ColumnVector#getStruct` should be immutable.

However, the fast vectorized hash map used by hash aggregate still needs a mutable `ColumnarRow`. To support this use case, this PR introduces a separate `MutableColumnarRow` class.

## How was this patch tested?

Existing tests, plus a new `mutable ColumnarRow` test in `ColumnVectorSuite`.

Author: Wenchen Fan <we...@databricks.com>

Closes #19847 from cloud-fan/mutable-row.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/444a2bbb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/444a2bbb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/444a2bbb

Branch: refs/heads/master
Commit: 444a2bbb67c2548d121152bc922b4c3337ddc8e8
Parents: 92cfbee
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Nov 30 18:28:58 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Nov 30 18:28:58 2017 +0800

----------------------------------------------------------------------
 .../sql/execution/vectorized/ColumnarRow.java   | 102 +------
 .../vectorized/MutableColumnarRow.java          | 278 +++++++++++++++++++
 .../execution/aggregate/HashAggregateExec.scala |   3 +-
 .../aggregate/VectorizedHashMapGenerator.scala  |  82 +++---
 .../vectorized/ColumnVectorSuite.scala          |  12 +
 .../vectorized/ColumnarBatchSuite.scala         |  23 --
 6 files changed, 336 insertions(+), 164 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
index 98a9073..cabb747 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
@@ -16,8 +16,6 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
-import java.math.BigDecimal;
-
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
 import org.apache.spark.sql.catalyst.util.MapData;
@@ -32,17 +30,10 @@ import org.apache.spark.unsafe.types.UTF8String;
 public final class ColumnarRow extends InternalRow {
   protected int rowId;
   private final ColumnVector[] columns;
-  private final WritableColumnVector[] writableColumns;
 
   // Ctor used if this is a struct.
   ColumnarRow(ColumnVector[] columns) {
     this.columns = columns;
-    this.writableColumns = new WritableColumnVector[this.columns.length];
-    for (int i = 0; i < this.columns.length; i++) {
-      if (this.columns[i] instanceof WritableColumnVector) {
-        this.writableColumns[i] = (WritableColumnVector) this.columns[i];
-      }
-    }
   }
 
   public ColumnVector[] columns() { return columns; }
@@ -205,97 +196,8 @@ public final class ColumnarRow extends InternalRow {
   }
 
   @Override
-  public void update(int ordinal, Object value) {
-    if (value == null) {
-      setNullAt(ordinal);
-    } else {
-      DataType dt = columns[ordinal].dataType();
-      if (dt instanceof BooleanType) {
-        setBoolean(ordinal, (boolean) value);
-      } else if (dt instanceof IntegerType) {
-        setInt(ordinal, (int) value);
-      } else if (dt instanceof ShortType) {
-        setShort(ordinal, (short) value);
-      } else if (dt instanceof LongType) {
-        setLong(ordinal, (long) value);
-      } else if (dt instanceof FloatType) {
-        setFloat(ordinal, (float) value);
-      } else if (dt instanceof DoubleType) {
-        setDouble(ordinal, (double) value);
-      } else if (dt instanceof DecimalType) {
-        DecimalType t = (DecimalType) dt;
-        setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
-                t.precision());
-      } else {
-        throw new UnsupportedOperationException("Datatype not supported " + dt);
-      }
-    }
-  }
-
-  @Override
-  public void setNullAt(int ordinal) {
-    getWritableColumn(ordinal).putNull(rowId);
-  }
-
-  @Override
-  public void setBoolean(int ordinal, boolean value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putBoolean(rowId, value);
-  }
+  public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
 
   @Override
-  public void setByte(int ordinal, byte value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putByte(rowId, value);
-  }
-
-  @Override
-  public void setShort(int ordinal, short value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putShort(rowId, value);
-  }
-
-  @Override
-  public void setInt(int ordinal, int value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putInt(rowId, value);
-  }
-
-  @Override
-  public void setLong(int ordinal, long value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putLong(rowId, value);
-  }
-
-  @Override
-  public void setFloat(int ordinal, float value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putFloat(rowId, value);
-  }
-
-  @Override
-  public void setDouble(int ordinal, double value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putDouble(rowId, value);
-  }
-
-  @Override
-  public void setDecimal(int ordinal, Decimal value, int precision) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putDecimal(rowId, value, precision);
-  }
-
-  private WritableColumnVector getWritableColumn(int ordinal) {
-    WritableColumnVector column = writableColumns[ordinal];
-    assert (!column.isConstant);
-    return column;
-  }
+  public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
new file mode 100644
index 0000000..f272cc1
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized;
+
+import java.math.BigDecimal;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
+ * aggregate.
+ *
+ * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
+ * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
+ */
+public final class MutableColumnarRow extends InternalRow {
+  public int rowId;
+  private final WritableColumnVector[] columns;
+
+  public MutableColumnarRow(WritableColumnVector[] columns) {
+    this.columns = columns;
+  }
+
+  @Override
+  public int numFields() { return columns.length; }
+
+  @Override
+  public InternalRow copy() {
+    GenericInternalRow row = new GenericInternalRow(columns.length);
+    for (int i = 0; i < numFields(); i++) {
+      if (isNullAt(i)) {
+        row.setNullAt(i);
+      } else {
+        DataType dt = columns[i].dataType();
+        if (dt instanceof BooleanType) {
+          row.setBoolean(i, getBoolean(i));
+        } else if (dt instanceof ByteType) {
+          row.setByte(i, getByte(i));
+        } else if (dt instanceof ShortType) {
+          row.setShort(i, getShort(i));
+        } else if (dt instanceof IntegerType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof LongType) {
+          row.setLong(i, getLong(i));
+        } else if (dt instanceof FloatType) {
+          row.setFloat(i, getFloat(i));
+        } else if (dt instanceof DoubleType) {
+          row.setDouble(i, getDouble(i));
+        } else if (dt instanceof StringType) {
+          row.update(i, getUTF8String(i).copy());
+        } else if (dt instanceof BinaryType) {
+          row.update(i, getBinary(i));
+        } else if (dt instanceof DecimalType) {
+          DecimalType t = (DecimalType)dt;
+          row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+        } else if (dt instanceof DateType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof TimestampType) {
+          row.setLong(i, getLong(i));
+        } else {
+          throw new RuntimeException("Not implemented. " + dt);
+        }
+      }
+    }
+    return row;
+  }
+
+  @Override
+  public boolean anyNull() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+
+  @Override
+  public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+
+  @Override
+  public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+
+  @Override
+  public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+
+  @Override
+  public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+
+  @Override
+  public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+
+  @Override
+  public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+
+  @Override
+  public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getDecimal(rowId, precision, scale);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getUTF8String(rowId);
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getBinary(rowId);
+  }
+
+  @Override
+  public CalendarInterval getInterval(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    return new CalendarInterval(months, microseconds);
+  }
+
+  @Override
+  public ColumnarRow getStruct(int ordinal, int numFields) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getStruct(rowId);
+  }
+
+  @Override
+  public ColumnarArray getArray(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getArray(rowId);
+  }
+
+  @Override
+  public MapData getMap(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType t = (DecimalType) dataType;
+      return getDecimal(ordinal, t.precision(), t.scale());
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType)dataType).fields().length);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
+    } else {
+      throw new UnsupportedOperationException("Datatype not supported " + dataType);
+    }
+  }
+
+  @Override
+  public void update(int ordinal, Object value) {
+    if (value == null) {
+      setNullAt(ordinal);
+    } else {
+      DataType dt = columns[ordinal].dataType();
+      if (dt instanceof BooleanType) {
+        setBoolean(ordinal, (boolean) value);
+      } else if (dt instanceof IntegerType) {
+        setInt(ordinal, (int) value);
+      } else if (dt instanceof ShortType) {
+        setShort(ordinal, (short) value);
+      } else if (dt instanceof LongType) {
+        setLong(ordinal, (long) value);
+      } else if (dt instanceof FloatType) {
+        setFloat(ordinal, (float) value);
+      } else if (dt instanceof DoubleType) {
+        setDouble(ordinal, (double) value);
+      } else if (dt instanceof DecimalType) {
+        DecimalType t = (DecimalType) dt;
+        Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
+        setDecimal(ordinal, d, t.precision());
+      } else {
+        throw new UnsupportedOperationException("Datatype not supported " + dt);
+      }
+    }
+  }
+
+  @Override
+  public void setNullAt(int ordinal) {
+    columns[ordinal].putNull(rowId);
+  }
+
+  @Override
+  public void setBoolean(int ordinal, boolean value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putBoolean(rowId, value);
+  }
+
+  @Override
+  public void setByte(int ordinal, byte value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putByte(rowId, value);
+  }
+
+  @Override
+  public void setShort(int ordinal, short value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putShort(rowId, value);
+  }
+
+  @Override
+  public void setInt(int ordinal, int value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putInt(rowId, value);
+  }
+
+  @Override
+  public void setLong(int ordinal, long value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putLong(rowId, value);
+  }
+
+  @Override
+  public void setFloat(int ordinal, float value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putFloat(rowId, value);
+  }
+
+  @Override
+  public void setDouble(int ordinal, double value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putDouble(rowId, value);
+  }
+
+  @Override
+  public void setDecimal(int ordinal, Decimal value, int precision) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putDecimal(rowId, value, precision);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index dc8aecf..9139788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.Utils
@@ -894,7 +895,7 @@ case class HashAggregateExec(
      ${
         if (isVectorizedHashMapEnabled) {
           s"""
-             | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null;
+             | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null;
            """.stripMargin
         } else {
           s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index fd783d9..44ba539 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
 import org.apache.spark.sql.types._
 
 /**
@@ -76,10 +77,9 @@ class VectorizedHashMapGenerator(
         }.mkString("\n").concat(";")
 
     s"""
-       |  private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors;
-       |  private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors;
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
+       |  private ${classOf[OnHeapColumnVector].getName}[] vectors;
+       |  private ${classOf[ColumnarBatch].getName} batch;
+       |  private ${classOf[MutableColumnarRow].getName} aggBufferRow;
        |  private int[] buckets;
        |  private int capacity = 1 << 16;
        |  private double loadFactor = 0.5;
@@ -91,19 +91,16 @@ class VectorizedHashMapGenerator(
        |    $generatedAggBufferSchema
        |
        |  public $generatedClassName() {
-       |    batchVectors = org.apache.spark.sql.execution.vectorized
-       |      .OnHeapColumnVector.allocateColumns(capacity, schema);
-       |    batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
-       |      schema, batchVectors, capacity);
+       |    vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+       |    batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity);
        |
-       |    bufferVectors = new org.apache.spark.sql.execution.vectorized
-       |      .OnHeapColumnVector[aggregateBufferSchema.fields().length];
+       |    // Generates a projection to return the aggregate buffer only.
+       |    ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
+       |      new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
        |    for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
-       |      bufferVectors[i] = batchVectors[i + ${groupingKeys.length}];
+       |      aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
        |    }
-       |    // TODO: Possibly generate this projection in HashAggregate directly
-       |    aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
-       |      aggregateBufferSchema, bufferVectors, capacity);
+       |    aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
        |
        |    buckets = new int[numBuckets];
        |    java.util.Arrays.fill(buckets, -1);
@@ -114,13 +111,13 @@ class VectorizedHashMapGenerator(
 
   /**
    * Generates a method that returns true if the group-by keys exist at a given index in the
-   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
-   * have 2 long group-by keys, the generated function would be of the form:
+   * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance,
+   * if we have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
    * private boolean equals(int idx, long agg_key, long agg_key1) {
-   *   return batchVectors[0].getLong(buckets[idx]) == agg_key &&
-   *     batchVectors[1].getLong(buckets[idx]) == agg_key1;
+   *   return vectors[0].getLong(buckets[idx]) == agg_key &&
+   *     vectors[1].getLong(buckets[idx]) == agg_key1;
    * }
    * }}}
    */
@@ -128,7 +125,7 @@ class VectorizedHashMapGenerator(
 
     def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
       groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-        s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]",
+        s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]",
           key.dataType), key.name)})"""
       }.mkString(" && ")
     }
@@ -141,29 +138,35 @@ class VectorizedHashMapGenerator(
   }
 
   /**
-   * Generates a method that returns a mutable
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the
+   * Generates a method that returns a
+   * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the
    * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
    * generated method adds the corresponding row in the associated
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, if we
    * have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
-   * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(
-   *     long agg_key, long agg_key1) {
+   * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) {
    *   long h = hash(agg_key, agg_key1);
    *   int step = 0;
    *   int idx = (int) h & (numBuckets - 1);
    *   while (step < maxSteps) {
    *     // Return bucket index if it's either an empty slot or already contains the key
    *     if (buckets[idx] == -1) {
-   *       batchVectors[0].putLong(numRows, agg_key);
-   *       batchVectors[1].putLong(numRows, agg_key1);
-   *       batchVectors[2].putLong(numRows, 0);
-   *       buckets[idx] = numRows++;
-   *       return batch.getRow(buckets[idx]);
+   *       if (numRows < capacity) {
+   *         vectors[0].putLong(numRows, agg_key);
+   *         vectors[1].putLong(numRows, agg_key1);
+   *         vectors[2].putLong(numRows, 0);
+   *         buckets[idx] = numRows++;
+   *         aggBufferRow.rowId = buckets[idx];
+   *         return aggBufferRow;
+   *       } else {
+   *         // No more space
+   *         return null;
+   *       }
    *     } else if (equals(idx, agg_key, agg_key1)) {
-   *       return batch.getRow(buckets[idx]);
+   *       aggBufferRow.rowId = buckets[idx];
+   *       return aggBufferRow;
    *     }
    *     idx = (idx + 1) & (numBuckets - 1);
    *     step++;
@@ -177,20 +180,19 @@ class VectorizedHashMapGenerator(
 
     def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
       groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-        ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name)
+        ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
       }
     }
 
     def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
       bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-        ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
+        ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
           buffVars(ordinal), nullable = true)
       }
     }
 
     s"""
-       |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${
-            groupingKeySignature}) {
+       |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) {
        |  long h = hash(${groupingKeys.map(_.name).mkString(", ")});
        |  int step = 0;
        |  int idx = (int) h & (numBuckets - 1);
@@ -208,15 +210,15 @@ class VectorizedHashMapGenerator(
        |        ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
        |
        |        buckets[idx] = numRows++;
-       |        batch.setNumRows(numRows);
-       |        aggregateBufferBatch.setNumRows(numRows);
-       |        return aggregateBufferBatch.getRow(buckets[idx]);
+       |        aggBufferRow.rowId = buckets[idx];
+       |        return aggBufferRow;
        |      } else {
        |        // No more space
        |        return null;
        |      }
        |    } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
-       |      return aggregateBufferBatch.getRow(buckets[idx]);
+       |      aggBufferRow.rowId = buckets[idx];
+       |      return aggBufferRow;
        |    }
        |    idx = (idx + 1) & (numBuckets - 1);
        |    step++;
@@ -229,8 +231,8 @@ class VectorizedHashMapGenerator(
 
   protected def generateRowIterator(): String = {
     s"""
-       |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>
-       |    rowIterator() {
+       |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+       |  batch.setNumRows(numRows);
        |  return batch.rowIterator();
        |}
      """.stripMargin

http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 3c76ca7..e28ab71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -163,6 +163,18 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
+  testVectors("mutable ColumnarRow", 10, IntegerType) { testVector =>
+    val mutableRow = new MutableColumnarRow(Array(testVector))
+    (0 until 10).foreach { i =>
+      mutableRow.rowId = i
+      mutableRow.setInt(0, 10 - i)
+    }
+    (0 until 10).foreach { i =>
+      mutableRow.rowId = i
+      assert(mutableRow.getInt(0) === (10 - i))
+    }
+  }
+
   val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true)
   testVectors("array", 10, arrayType) { testVector =>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 80a5086..1b4e2ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -1129,29 +1129,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
     testRandomRows(false, 30)
   }
 
-  test("mutable ColumnarBatch rows") {
-    val NUM_ITERS = 10
-    val types = Array(
-      BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
-      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
-      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
-      new DecimalType(12, 2), new DecimalType(30, 10))
-    for (i <- 0 to NUM_ITERS) {
-      val random = new Random(System.nanoTime())
-      val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
-      val oldRow = RandomDataGenerator.randomRow(random, schema)
-      val newRow = RandomDataGenerator.randomRow(random, schema)
-
-      (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
-        val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
-        val columnarBatchRow = batch.getRow(0)
-        newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
-        compareStruct(schema, columnarBatchRow, newRow, 0)
-        batch.close()
-      }
-    }
-  }
-
   test("exceeding maximum capacity should throw an error") {
     (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
       val column = allocate(1, ByteType, memMode)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org