You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/30 10:29:19 UTC
spark git commit: [SPARK-22652][SQL] remove set methods in ColumnarRow
Repository: spark
Updated Branches:
refs/heads/master 92cfbeeb5 -> 444a2bbb6
[SPARK-22652][SQL] remove set methods in ColumnarRow
## What changes were proposed in this pull request?
As a step to make `ColumnVector` public, the `ColumnarRow` returned by `ColumnVector#getStruct` should be immutable.
However, we do need the mutability of `ColumnarRow` for the fast vectorized hashmap in hash aggregate. To solve this, this PR introduces a `MutableColumnarRow` for this use case.
## How was this patch tested?
existing test.
Author: Wenchen Fan <we...@databricks.com>
Closes #19847 from cloud-fan/mutable-row.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/444a2bbb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/444a2bbb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/444a2bbb
Branch: refs/heads/master
Commit: 444a2bbb67c2548d121152bc922b4c3337ddc8e8
Parents: 92cfbee
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Nov 30 18:28:58 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Nov 30 18:28:58 2017 +0800
----------------------------------------------------------------------
.../sql/execution/vectorized/ColumnarRow.java | 102 +------
.../vectorized/MutableColumnarRow.java | 278 +++++++++++++++++++
.../execution/aggregate/HashAggregateExec.scala | 3 +-
.../aggregate/VectorizedHashMapGenerator.scala | 82 +++---
.../vectorized/ColumnVectorSuite.scala | 12 +
.../vectorized/ColumnarBatchSuite.scala | 23 --
6 files changed, 336 insertions(+), 164 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
index 98a9073..cabb747 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
@@ -16,8 +16,6 @@
*/
package org.apache.spark.sql.execution.vectorized;
-import java.math.BigDecimal;
-
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.MapData;
@@ -32,17 +30,10 @@ import org.apache.spark.unsafe.types.UTF8String;
public final class ColumnarRow extends InternalRow {
protected int rowId;
private final ColumnVector[] columns;
- private final WritableColumnVector[] writableColumns;
// Ctor used if this is a struct.
ColumnarRow(ColumnVector[] columns) {
this.columns = columns;
- this.writableColumns = new WritableColumnVector[this.columns.length];
- for (int i = 0; i < this.columns.length; i++) {
- if (this.columns[i] instanceof WritableColumnVector) {
- this.writableColumns[i] = (WritableColumnVector) this.columns[i];
- }
- }
}
public ColumnVector[] columns() { return columns; }
@@ -205,97 +196,8 @@ public final class ColumnarRow extends InternalRow {
}
@Override
- public void update(int ordinal, Object value) {
- if (value == null) {
- setNullAt(ordinal);
- } else {
- DataType dt = columns[ordinal].dataType();
- if (dt instanceof BooleanType) {
- setBoolean(ordinal, (boolean) value);
- } else if (dt instanceof IntegerType) {
- setInt(ordinal, (int) value);
- } else if (dt instanceof ShortType) {
- setShort(ordinal, (short) value);
- } else if (dt instanceof LongType) {
- setLong(ordinal, (long) value);
- } else if (dt instanceof FloatType) {
- setFloat(ordinal, (float) value);
- } else if (dt instanceof DoubleType) {
- setDouble(ordinal, (double) value);
- } else if (dt instanceof DecimalType) {
- DecimalType t = (DecimalType) dt;
- setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
- t.precision());
- } else {
- throw new UnsupportedOperationException("Datatype not supported " + dt);
- }
- }
- }
-
- @Override
- public void setNullAt(int ordinal) {
- getWritableColumn(ordinal).putNull(rowId);
- }
-
- @Override
- public void setBoolean(int ordinal, boolean value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putBoolean(rowId, value);
- }
+ public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
@Override
- public void setByte(int ordinal, byte value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putByte(rowId, value);
- }
-
- @Override
- public void setShort(int ordinal, short value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putShort(rowId, value);
- }
-
- @Override
- public void setInt(int ordinal, int value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putInt(rowId, value);
- }
-
- @Override
- public void setLong(int ordinal, long value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putLong(rowId, value);
- }
-
- @Override
- public void setFloat(int ordinal, float value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putFloat(rowId, value);
- }
-
- @Override
- public void setDouble(int ordinal, double value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putDouble(rowId, value);
- }
-
- @Override
- public void setDecimal(int ordinal, Decimal value, int precision) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putDecimal(rowId, value, precision);
- }
-
- private WritableColumnVector getWritableColumn(int ordinal) {
- WritableColumnVector column = writableColumns[ordinal];
- assert (!column.isConstant);
- return column;
- }
+ public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
new file mode 100644
index 0000000..f272cc1
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized;
+
+import java.math.BigDecimal;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
+ * aggregate.
+ *
+ * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
+ * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
+ */
+public final class MutableColumnarRow extends InternalRow {
+ public int rowId;
+ private final WritableColumnVector[] columns;
+
+ public MutableColumnarRow(WritableColumnVector[] columns) {
+ this.columns = columns;
+ }
+
+ @Override
+ public int numFields() { return columns.length; }
+
+ @Override
+ public InternalRow copy() {
+ GenericInternalRow row = new GenericInternalRow(columns.length);
+ for (int i = 0; i < numFields(); i++) {
+ if (isNullAt(i)) {
+ row.setNullAt(i);
+ } else {
+ DataType dt = columns[i].dataType();
+ if (dt instanceof BooleanType) {
+ row.setBoolean(i, getBoolean(i));
+ } else if (dt instanceof ByteType) {
+ row.setByte(i, getByte(i));
+ } else if (dt instanceof ShortType) {
+ row.setShort(i, getShort(i));
+ } else if (dt instanceof IntegerType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof LongType) {
+ row.setLong(i, getLong(i));
+ } else if (dt instanceof FloatType) {
+ row.setFloat(i, getFloat(i));
+ } else if (dt instanceof DoubleType) {
+ row.setDouble(i, getDouble(i));
+ } else if (dt instanceof StringType) {
+ row.update(i, getUTF8String(i).copy());
+ } else if (dt instanceof BinaryType) {
+ row.update(i, getBinary(i));
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType)dt;
+ row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+ } else if (dt instanceof DateType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof TimestampType) {
+ row.setLong(i, getLong(i));
+ } else {
+ throw new RuntimeException("Not implemented. " + dt);
+ }
+ }
+ }
+ return row;
+ }
+
+ @Override
+ public boolean anyNull() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+
+ @Override
+ public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+
+ @Override
+ public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+
+ @Override
+ public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+
+ @Override
+ public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+
+ @Override
+ public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+
+ @Override
+ public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+
+ @Override
+ public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
+ return columns[ordinal].getDecimal(rowId, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
+ return columns[ordinal].getUTF8String(rowId);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
+ return columns[ordinal].getBinary(rowId);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
+ final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+ final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+ return new CalendarInterval(months, microseconds);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
+ return columns[ordinal].getStruct(rowId);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
+ return columns[ordinal].getArray(rowId);
+ }
+
+ @Override
+ public MapData getMap(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ if (dataType instanceof BooleanType) {
+ return getBoolean(ordinal);
+ } else if (dataType instanceof ByteType) {
+ return getByte(ordinal);
+ } else if (dataType instanceof ShortType) {
+ return getShort(ordinal);
+ } else if (dataType instanceof IntegerType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof LongType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof FloatType) {
+ return getFloat(ordinal);
+ } else if (dataType instanceof DoubleType) {
+ return getDouble(ordinal);
+ } else if (dataType instanceof StringType) {
+ return getUTF8String(ordinal);
+ } else if (dataType instanceof BinaryType) {
+ return getBinary(ordinal);
+ } else if (dataType instanceof DecimalType) {
+ DecimalType t = (DecimalType) dataType;
+ return getDecimal(ordinal, t.precision(), t.scale());
+ } else if (dataType instanceof DateType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof TimestampType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof ArrayType) {
+ return getArray(ordinal);
+ } else if (dataType instanceof StructType) {
+ return getStruct(ordinal, ((StructType)dataType).fields().length);
+ } else if (dataType instanceof MapType) {
+ return getMap(ordinal);
+ } else {
+ throw new UnsupportedOperationException("Datatype not supported " + dataType);
+ }
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
+ DataType dt = columns[ordinal].dataType();
+ if (dt instanceof BooleanType) {
+ setBoolean(ordinal, (boolean) value);
+ } else if (dt instanceof IntegerType) {
+ setInt(ordinal, (int) value);
+ } else if (dt instanceof ShortType) {
+ setShort(ordinal, (short) value);
+ } else if (dt instanceof LongType) {
+ setLong(ordinal, (long) value);
+ } else if (dt instanceof FloatType) {
+ setFloat(ordinal, (float) value);
+ } else if (dt instanceof DoubleType) {
+ setDouble(ordinal, (double) value);
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType) dt;
+ Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
+ setDecimal(ordinal, d, t.precision());
+ } else {
+ throw new UnsupportedOperationException("Datatype not supported " + dt);
+ }
+ }
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ columns[ordinal].putNull(rowId);
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putBoolean(rowId, value);
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putByte(rowId, value);
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putShort(rowId, value);
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putInt(rowId, value);
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putLong(rowId, value);
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putFloat(rowId, value);
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDouble(rowId, value);
+ }
+
+ @Override
+ public void setDecimal(int ordinal, Decimal value, int precision) {
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDecimal(rowId, value, precision);
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index dc8aecf..9139788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.util.Utils
@@ -894,7 +895,7 @@ case class HashAggregateExec(
${
if (isVectorizedHashMapEnabled) {
s"""
- | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null;
+ | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null;
""".stripMargin
} else {
s"""
http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index fd783d9..44ba539 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.types._
/**
@@ -76,10 +77,9 @@ class VectorizedHashMapGenerator(
}.mkString("\n").concat(";")
s"""
- | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors;
- | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors;
- | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
- | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
+ | private ${classOf[OnHeapColumnVector].getName}[] vectors;
+ | private ${classOf[ColumnarBatch].getName} batch;
+ | private ${classOf[MutableColumnarRow].getName} aggBufferRow;
| private int[] buckets;
| private int capacity = 1 << 16;
| private double loadFactor = 0.5;
@@ -91,19 +91,16 @@ class VectorizedHashMapGenerator(
| $generatedAggBufferSchema
|
| public $generatedClassName() {
- | batchVectors = org.apache.spark.sql.execution.vectorized
- | .OnHeapColumnVector.allocateColumns(capacity, schema);
- | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
- | schema, batchVectors, capacity);
+ | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+ | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity);
|
- | bufferVectors = new org.apache.spark.sql.execution.vectorized
- | .OnHeapColumnVector[aggregateBufferSchema.fields().length];
+ | // Generates a projection to return the aggregate buffer only.
+ | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
+ | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
| for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
- | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}];
+ | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
| }
- | // TODO: Possibly generate this projection in HashAggregate directly
- | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
- | aggregateBufferSchema, bufferVectors, capacity);
+ | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
|
| buckets = new int[numBuckets];
| java.util.Arrays.fill(buckets, -1);
@@ -114,13 +111,13 @@ class VectorizedHashMapGenerator(
/**
* Generates a method that returns true if the group-by keys exist at a given index in the
- * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
- * have 2 long group-by keys, the generated function would be of the form:
+ * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance,
+ * if we have 2 long group-by keys, the generated function would be of the form:
*
* {{{
* private boolean equals(int idx, long agg_key, long agg_key1) {
- * return batchVectors[0].getLong(buckets[idx]) == agg_key &&
- * batchVectors[1].getLong(buckets[idx]) == agg_key1;
+ * return vectors[0].getLong(buckets[idx]) == agg_key &&
+ * vectors[1].getLong(buckets[idx]) == agg_key1;
* }
* }}}
*/
@@ -128,7 +125,7 @@ class VectorizedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]",
+ s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]",
key.dataType), key.name)})"""
}.mkString(" && ")
}
@@ -141,29 +138,35 @@ class VectorizedHashMapGenerator(
}
/**
- * Generates a method that returns a mutable
- * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the
+ * Generates a method that returns a
+ * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the
* aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
* generated method adds the corresponding row in the associated
- * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+ * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, if we
* have 2 long group-by keys, the generated function would be of the form:
*
* {{{
- * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(
- * long agg_key, long agg_key1) {
+ * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) {
* long h = hash(agg_key, agg_key1);
* int step = 0;
* int idx = (int) h & (numBuckets - 1);
* while (step < maxSteps) {
* // Return bucket index if it's either an empty slot or already contains the key
* if (buckets[idx] == -1) {
- * batchVectors[0].putLong(numRows, agg_key);
- * batchVectors[1].putLong(numRows, agg_key1);
- * batchVectors[2].putLong(numRows, 0);
- * buckets[idx] = numRows++;
- * return batch.getRow(buckets[idx]);
+ * if (numRows < capacity) {
+ * vectors[0].putLong(numRows, agg_key);
+ * vectors[1].putLong(numRows, agg_key1);
+ * vectors[2].putLong(numRows, 0);
+ * buckets[idx] = numRows++;
+ * aggBufferRow.rowId = buckets[idx];
+ * return aggBufferRow;
+ * } else {
+ * // No more space
+ * return null;
+ * }
* } else if (equals(idx, agg_key, agg_key1)) {
- * return batch.getRow(buckets[idx]);
+ * aggBufferRow.rowId = buckets[idx];
+ * return aggBufferRow;
* }
* idx = (idx + 1) & (numBuckets - 1);
* step++;
@@ -177,20 +180,19 @@ class VectorizedHashMapGenerator(
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name)
+ ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
}
}
def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
+ ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
buffVars(ordinal), nullable = true)
}
}
s"""
- |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${
- groupingKeySignature}) {
+ |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) {
| long h = hash(${groupingKeys.map(_.name).mkString(", ")});
| int step = 0;
| int idx = (int) h & (numBuckets - 1);
@@ -208,15 +210,15 @@ class VectorizedHashMapGenerator(
| ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
|
| buckets[idx] = numRows++;
- | batch.setNumRows(numRows);
- | aggregateBufferBatch.setNumRows(numRows);
- | return aggregateBufferBatch.getRow(buckets[idx]);
+ | aggBufferRow.rowId = buckets[idx];
+ | return aggBufferRow;
| } else {
| // No more space
| return null;
| }
| } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
- | return aggregateBufferBatch.getRow(buckets[idx]);
+ | aggBufferRow.rowId = buckets[idx];
+ | return aggBufferRow;
| }
| idx = (idx + 1) & (numBuckets - 1);
| step++;
@@ -229,8 +231,8 @@ class VectorizedHashMapGenerator(
protected def generateRowIterator(): String = {
s"""
- |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>
- | rowIterator() {
+ |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+ | batch.setNumRows(numRows);
| return batch.rowIterator();
|}
""".stripMargin
http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 3c76ca7..e28ab71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -163,6 +163,18 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
}
}
+ testVectors("mutable ColumnarRow", 10, IntegerType) { testVector =>
+ val mutableRow = new MutableColumnarRow(Array(testVector))
+ (0 until 10).foreach { i =>
+ mutableRow.rowId = i
+ mutableRow.setInt(0, 10 - i)
+ }
+ (0 until 10).foreach { i =>
+ mutableRow.rowId = i
+ assert(mutableRow.getInt(0) === (10 - i))
+ }
+ }
+
val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true)
testVectors("array", 10, arrayType) { testVector =>
http://git-wip-us.apache.org/repos/asf/spark/blob/444a2bbb/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 80a5086..1b4e2ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -1129,29 +1129,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
testRandomRows(false, 30)
}
- test("mutable ColumnarBatch rows") {
- val NUM_ITERS = 10
- val types = Array(
- BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
- DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
- DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
- new DecimalType(12, 2), new DecimalType(30, 10))
- for (i <- 0 to NUM_ITERS) {
- val random = new Random(System.nanoTime())
- val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
- val oldRow = RandomDataGenerator.randomRow(random, schema)
- val newRow = RandomDataGenerator.randomRow(random, schema)
-
- (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
- val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
- val columnarBatchRow = batch.getRow(0)
- newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
- compareStruct(schema, columnarBatchRow, newRow, 0)
- batch.close()
- }
- }
- }
-
test("exceeding maximum capacity should throw an error") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
val column = allocate(1, ByteType, memMode)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org