You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2018/04/02 19:48:49 UTC
spark git commit: [SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes
Repository: spark
Updated Branches:
refs/heads/master fe2b7a456 -> a7c19d9c2
[SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes
## What changes were proposed in this pull request?
This PR implements the following cleanups related to the `UnsafeWriter` class:
- Remove code duplication between `UnsafeRowWriter` and `UnsafeArrayWriter`
- Make the `BufferHolder` class internal by delegating its accessor methods to `UnsafeWriter`
- Replace `UnsafeRow.setTotalSize(...)` with `UnsafeRowWriter.setTotalSize()`
## How was this patch tested?
Tested by existing unit tests.
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Closes #20850 from kiszk/SPARK-23713.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7c19d9c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7c19d9c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7c19d9c
Branch: refs/heads/master
Commit: a7c19d9c21d59fd0109a7078c80b33d3da03fafd
Parents: fe2b7a4
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Mon Apr 2 21:48:44 2018 +0200
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Mon Apr 2 21:48:44 2018 +0200
----------------------------------------------------------------------
.../sql/kafka010/KafkaContinuousReader.scala | 3 -
.../KafkaRecordToUnsafeRowConverter.scala | 11 +-
.../expressions/codegen/BufferHolder.java | 32 ++--
.../expressions/codegen/UnsafeArrayWriter.java | 133 +++----------
.../expressions/codegen/UnsafeRowWriter.java | 189 +++++++------------
.../expressions/codegen/UnsafeWriter.java | 157 ++++++++++++++-
.../InterpretedUnsafeProjection.scala | 90 ++++-----
.../codegen/GenerateUnsafeProjection.scala | 124 ++++++------
.../expressions/RowBasedKeyValueBatchSuite.java | 28 +--
.../aggregate/RowBasedHashMapGenerator.scala | 12 +-
.../columnar/GenerateColumnAccessor.scala | 9 +-
.../datasources/text/TextFileFormat.scala | 11 +-
12 files changed, 391 insertions(+), 408 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index e7e2787..f26c134 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -27,13 +27,10 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.types.UTF8String
/**
* A [[ContinuousReader]] for data from kafka.
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
index 1acdd56..f35a143 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
@@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
private[kafka010] class KafkaRecordToUnsafeRowConverter {
- private val sharedRow = new UnsafeRow(7)
- private val bufferHolder = new BufferHolder(sharedRow)
- private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
+ private val rowWriter = new UnsafeRowWriter(7)
def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
- bufferHolder.reset()
+ rowWriter.reset()
if (record.key == null) {
rowWriter.setNullAt(0)
@@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter {
5,
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
rowWriter.write(6, record.timestampType.id)
- sharedRow.setTotalSize(bufferHolder.totalSize)
- sharedRow
+ rowWriter.getRow()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 2599761..537ef24 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -30,25 +30,21 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
* this class per writing program, so that the memory segment/data buffer can be reused. Note that
* for each incoming record, we should call `reset` of BufferHolder instance before write the record
* and reuse the data buffer.
- *
- * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
- * the size of the result row, after writing a record to the buffer. However, we can skip this step
- * if the fields of row are all fixed-length, as the size of result row is also fixed.
*/
-public class BufferHolder {
+final class BufferHolder {
private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
- public byte[] buffer;
- public int cursor = Platform.BYTE_ARRAY_OFFSET;
+ private byte[] buffer;
+ private int cursor = Platform.BYTE_ARRAY_OFFSET;
private final UnsafeRow row;
private final int fixedSize;
- public BufferHolder(UnsafeRow row) {
+ BufferHolder(UnsafeRow row) {
this(row, 64);
}
- public BufferHolder(UnsafeRow row, int initialSize) {
+ BufferHolder(UnsafeRow row, int initialSize) {
int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) {
throw new UnsupportedOperationException(
@@ -64,7 +60,7 @@ public class BufferHolder {
/**
* Grows the buffer by at least neededSize and points the row to the buffer.
*/
- public void grow(int neededSize) {
+ void grow(int neededSize) {
if (neededSize > ARRAY_MAX - totalSize()) {
throw new UnsupportedOperationException(
"Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
@@ -86,11 +82,23 @@ public class BufferHolder {
}
}
- public void reset() {
+ byte[] getBuffer() {
+ return buffer;
+ }
+
+ int getCursor() {
+ return cursor;
+ }
+
+ void increaseCursor(int val) {
+ cursor += val;
+ }
+
+ void reset() {
cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
}
- public int totalSize() {
+ int totalSize() {
return cursor - Platform.BYTE_ARRAY_OFFSET;
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 82cd1b2..a78dd97 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -21,8 +21,6 @@ import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
@@ -32,14 +30,12 @@ import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculat
*/
public final class UnsafeArrayWriter extends UnsafeWriter {
- private BufferHolder holder;
-
- // The offset of the global buffer where we start to write this array.
- private int startingOffset;
-
// The number of elements in this array
private int numElements;
+ // The element size in this array
+ private int elementSize;
+
private int headerInBytes;
private void assertIndexIsValid(int index) {
@@ -47,13 +43,17 @@ public final class UnsafeArrayWriter extends UnsafeWriter {
assert index < numElements : "index (" + index + ") should < " + numElements;
}
- public void initialize(BufferHolder holder, int numElements, int elementSize) {
+ public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) {
+ super(writer.getBufferHolder());
+ this.elementSize = elementSize;
+ }
+
+ public void initialize(int numElements) {
// We need 8 bytes to store numElements in header
this.numElements = numElements;
this.headerInBytes = calculateHeaderPortionInBytes(numElements);
- this.holder = holder;
- this.startingOffset = holder.cursor;
+ this.startingOffset = cursor();
// Grows the global buffer ahead for header and fixed size data.
int fixedPartInBytes =
@@ -61,112 +61,92 @@ public final class UnsafeArrayWriter extends UnsafeWriter {
holder.grow(headerInBytes + fixedPartInBytes);
// Write numElements and clear out null bits to header
- Platform.putLong(holder.buffer, startingOffset, numElements);
+ Platform.putLong(getBuffer(), startingOffset, numElements);
for (int i = 8; i < headerInBytes; i += 8) {
- Platform.putLong(holder.buffer, startingOffset + i, 0L);
+ Platform.putLong(getBuffer(), startingOffset + i, 0L);
}
// fill 0 into reminder part of 8-bytes alignment in unsafe array
for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
- Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
+ Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0);
}
- holder.cursor += (headerInBytes + fixedPartInBytes);
+ increaseCursor(headerInBytes + fixedPartInBytes);
}
- private void zeroOutPaddingBytes(int numBytes) {
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
- }
- }
-
- private long getElementOffset(int ordinal, int elementSize) {
+ private long getElementOffset(int ordinal) {
return startingOffset + headerInBytes + ordinal * elementSize;
}
- public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
- assertIndexIsValid(ordinal);
- final long relativeOffset = currentCursor - startingOffset;
- final long offsetAndSize = (relativeOffset << 32) | (long)size;
-
- write(ordinal, offsetAndSize);
- }
-
private void setNullBit(int ordinal) {
assertIndexIsValid(ordinal);
- BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
+ BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal);
}
public void setNull1Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
+ writeByte(getElementOffset(ordinal), (byte)0);
}
public void setNull2Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
+ writeShort(getElementOffset(ordinal), (short)0);
}
public void setNull4Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0);
+ writeInt(getElementOffset(ordinal), 0);
}
public void setNull8Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
+ writeLong(getElementOffset(ordinal), 0);
}
public void setNull(int ordinal) { setNull8Bytes(ordinal); }
public void write(int ordinal, boolean value) {
assertIndexIsValid(ordinal);
- Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
+ writeBoolean(getElementOffset(ordinal), value);
}
public void write(int ordinal, byte value) {
assertIndexIsValid(ordinal);
- Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
+ writeByte(getElementOffset(ordinal), value);
}
public void write(int ordinal, short value) {
assertIndexIsValid(ordinal);
- Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
+ writeShort(getElementOffset(ordinal), value);
}
public void write(int ordinal, int value) {
assertIndexIsValid(ordinal);
- Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
+ writeInt(getElementOffset(ordinal), value);
}
public void write(int ordinal, long value) {
assertIndexIsValid(ordinal);
- Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
+ writeLong(getElementOffset(ordinal), value);
}
public void write(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
assertIndexIsValid(ordinal);
- Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
+ writeFloat(getElementOffset(ordinal), value);
}
public void write(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
assertIndexIsValid(ordinal);
- Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
+ writeDouble(getElementOffset(ordinal), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
// make sure Decimal object has the same scale as DecimalType
assertIndexIsValid(ordinal);
- if (input.changePrecision(precision, scale)) {
+ if (input != null && input.changePrecision(precision, scale)) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
write(ordinal, input.toUnscaledLong());
} else {
@@ -180,65 +160,14 @@ public final class UnsafeArrayWriter extends UnsafeWriter {
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
+ bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
+ setOffsetAndSize(ordinal, numBytes);
// move the cursor forward with 8-bytes boundary
- holder.cursor += roundedSize;
+ increaseCursor(roundedSize);
}
} else {
setNull(ordinal);
}
}
-
- public void write(int ordinal, UTF8String input) {
- final int numBytes = input.numBytes();
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, byte[] input) {
- final int numBytes = input.length;
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(
- input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
-
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, CalendarInterval input) {
- // grow the global buffer before writing data.
- holder.grow(16);
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(holder.buffer, holder.cursor, input.months);
- Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
- setOffsetAndSize(ordinal, holder.cursor, 16);
-
- // move the cursor forward.
- holder.cursor += 16;
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 2620bbc..71c49d8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -20,10 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
/**
* A helper class to write data into global row buffer using `UnsafeRow` format.
@@ -31,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String;
* It will remember the offset of row buffer which it starts to write, and move the cursor of row
* buffer while writing. If new data(can be the input record if this is the outermost writer, or
* nested struct if this is an inner writer) comes, the starting cursor of row buffer may be
- * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the
+ * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the
* `startingOffset` and clear out null bits.
*
* Note that if this is the outermost writer, which means we will always write from the very
@@ -40,29 +37,58 @@ import org.apache.spark.unsafe.types.UTF8String;
*/
public final class UnsafeRowWriter extends UnsafeWriter {
- private final BufferHolder holder;
- // The offset of the global buffer where we start to write this row.
- private int startingOffset;
+ private final UnsafeRow row;
+
private final int nullBitsSize;
private final int fixedSize;
- public UnsafeRowWriter(BufferHolder holder, int numFields) {
- this.holder = holder;
+ public UnsafeRowWriter(int numFields) {
+ this(new UnsafeRow(numFields));
+ }
+
+ public UnsafeRowWriter(int numFields, int initialBufferSize) {
+ this(new UnsafeRow(numFields), initialBufferSize);
+ }
+
+ public UnsafeRowWriter(UnsafeWriter writer, int numFields) {
+ this(null, writer.getBufferHolder(), numFields);
+ }
+
+ private UnsafeRowWriter(UnsafeRow row) {
+ this(row, new BufferHolder(row), row.numFields());
+ }
+
+ private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) {
+ this(row, new BufferHolder(row, initialBufferSize), row.numFields());
+ }
+
+ private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) {
+ super(holder);
+ this.row = row;
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
this.fixedSize = nullBitsSize + 8 * numFields;
- this.startingOffset = holder.cursor;
+ this.startingOffset = cursor();
+ }
+
+ /**
+ * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns
+ * the UnsafeRow created at a constructor
+ */
+ public UnsafeRow getRow() {
+ row.setTotalSize(totalSize());
+ return row;
}
/**
* Resets the `startingOffset` according to the current cursor of row buffer, and clear out null
* bits. This should be called before we write a new nested struct to the row buffer.
*/
- public void reset() {
- this.startingOffset = holder.cursor;
+ public void resetRowWriter() {
+ this.startingOffset = cursor();
// grow the global buffer to make sure it has enough space to write fixed-length data.
- holder.grow(fixedSize);
- holder.cursor += fixedSize;
+ grow(fixedSize);
+ increaseCursor(fixedSize);
zeroOutNullBytes();
}
@@ -72,25 +98,17 @@ public final class UnsafeRowWriter extends UnsafeWriter {
*/
public void zeroOutNullBytes() {
for (int i = 0; i < nullBitsSize; i += 8) {
- Platform.putLong(holder.buffer, startingOffset + i, 0L);
- }
- }
-
- private void zeroOutPaddingBytes(int numBytes) {
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+ Platform.putLong(getBuffer(), startingOffset + i, 0L);
}
}
- public BufferHolder holder() { return holder; }
-
public boolean isNullAt(int ordinal) {
- return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+ return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal);
}
public void setNullAt(int ordinal) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+ BitSetMethods.set(getBuffer(), startingOffset, ordinal);
+ write(ordinal, 0L);
}
@Override
@@ -117,67 +135,49 @@ public final class UnsafeRowWriter extends UnsafeWriter {
return startingOffset + nullBitsSize + 8 * ordinal;
}
- public void setOffsetAndSize(int ordinal, int size) {
- setOffsetAndSize(ordinal, holder.cursor, size);
- }
-
- public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
- final long relativeOffset = currentCursor - startingOffset;
- final long fieldOffset = getFieldOffset(ordinal);
- final long offsetAndSize = (relativeOffset << 32) | (long) size;
-
- Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
- }
-
public void write(int ordinal, boolean value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putBoolean(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeBoolean(offset, value);
}
public void write(int ordinal, byte value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putByte(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeByte(offset, value);
}
public void write(int ordinal, short value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putShort(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeShort(offset, value);
}
public void write(int ordinal, int value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putInt(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeInt(offset, value);
}
public void write(int ordinal, long value) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+ writeLong(getFieldOffset(ordinal), value);
}
public void write(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putFloat(holder.buffer, offset, value);
+ writeLong(offset, 0);
+ writeFloat(offset, value);
}
public void write(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
- Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+ writeDouble(getFieldOffset(ordinal), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
// make sure Decimal object has the same scale as DecimalType
- if (input.changePrecision(precision, scale)) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+ if (input != null && input.changePrecision(precision, scale)) {
+ write(ordinal, input.toUnscaledLong());
} else {
setNullAt(ordinal);
}
@@ -185,82 +185,31 @@ public final class UnsafeRowWriter extends UnsafeWriter {
// grow the global buffer before writing data.
holder.grow(16);
- // zero-out the bytes
- Platform.putLong(holder.buffer, holder.cursor, 0L);
- Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
-
// Make sure Decimal object has the same scale as DecimalType.
// Note that we may pass in null Decimal object to set null for it.
if (input == null || !input.changePrecision(precision, scale)) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+ // zero-out the bytes
+ Platform.putLong(getBuffer(), cursor(), 0L);
+ Platform.putLong(getBuffer(), cursor() + 8, 0L);
+
+ BitSetMethods.set(getBuffer(), startingOffset, ordinal);
// keep the offset for future update
setOffsetAndSize(ordinal, 0);
} else {
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- assert bytes.length <= 16;
+ final int numBytes = bytes.length;
+ assert numBytes <= 16;
+
+ zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
setOffsetAndSize(ordinal, bytes.length);
}
// move the cursor forward.
- holder.cursor += 16;
+ increaseCursor(16);
}
}
-
- public void write(int ordinal, UTF8String input) {
- final int numBytes = input.numBytes();
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- setOffsetAndSize(ordinal, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, byte[] input) {
- write(ordinal, input, 0, input.length);
- }
-
- public void write(int ordinal, byte[] input, int offset, int numBytes) {
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
- holder.buffer, holder.cursor, numBytes);
-
- setOffsetAndSize(ordinal, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, CalendarInterval input) {
- // grow the global buffer before writing data.
- holder.grow(16);
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(holder.buffer, holder.cursor, input.months);
- Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
- setOffsetAndSize(ordinal, 16);
-
- // move the cursor forward.
- holder.cursor += 16;
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index c94b5c7..de0eb6d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.codegen;
import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -24,10 +26,73 @@ import org.apache.spark.unsafe.types.UTF8String;
* Base class for writing Unsafe* structures.
*/
public abstract class UnsafeWriter {
+ // Keep internal buffer holder
+ protected final BufferHolder holder;
+
+ // The offset of the global buffer where we start to write this structure.
+ protected int startingOffset;
+
+ protected UnsafeWriter(BufferHolder holder) {
+ this.holder = holder;
+ }
+
+ /**
+ * Accessor methods are delegated from BufferHolder class
+ */
+ public final BufferHolder getBufferHolder() {
+ return holder;
+ }
+
+ public final byte[] getBuffer() {
+ return holder.getBuffer();
+ }
+
+ public final void reset() {
+ holder.reset();
+ }
+
+ public final int totalSize() {
+ return holder.totalSize();
+ }
+
+ public final void grow(int neededSize) {
+ holder.grow(neededSize);
+ }
+
+ public final int cursor() {
+ return holder.getCursor();
+ }
+
+ public final void increaseCursor(int val) {
+ holder.increaseCursor(val);
+ }
+
+ public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) {
+ setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor);
+ }
+
+ protected void setOffsetAndSize(int ordinal, int size) {
+ setOffsetAndSize(ordinal, cursor(), size);
+ }
+
+ protected void setOffsetAndSize(int ordinal, int currentCursor, int size) {
+ final long relativeOffset = currentCursor - startingOffset;
+ final long offsetAndSize = (relativeOffset << 32) | (long)size;
+
+ write(ordinal, offsetAndSize);
+ }
+
+ protected final void zeroOutPaddingBytes(int numBytes) {
+ if ((numBytes & 0x07) > 0) {
+ Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L);
+ }
+ }
+
public abstract void setNull1Bytes(int ordinal);
public abstract void setNull2Bytes(int ordinal);
public abstract void setNull4Bytes(int ordinal);
public abstract void setNull8Bytes(int ordinal);
+
public abstract void write(int ordinal, boolean value);
public abstract void write(int ordinal, byte value);
public abstract void write(int ordinal, short value);
@@ -36,8 +101,92 @@ public abstract class UnsafeWriter {
public abstract void write(int ordinal, float value);
public abstract void write(int ordinal, double value);
public abstract void write(int ordinal, Decimal input, int precision, int scale);
- public abstract void write(int ordinal, UTF8String input);
- public abstract void write(int ordinal, byte[] input);
- public abstract void write(int ordinal, CalendarInterval input);
- public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size);
+
+ public final void write(int ordinal, UTF8String input) {
+ final int numBytes = input.numBytes();
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+ // grow the global buffer before writing data.
+ grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
+
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(getBuffer(), cursor());
+
+ setOffsetAndSize(ordinal, numBytes);
+
+ // move the cursor forward.
+ increaseCursor(roundedSize);
+ }
+
+ public final void write(int ordinal, byte[] input) {
+ write(ordinal, input, 0, input.length);
+ }
+
+ public final void write(int ordinal, byte[] input, int offset, int numBytes) {
+ // Size the slot by the bytes actually copied (numBytes), not by the whole backing array.
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ // grow the global buffer before writing data.
+ grow(roundedSize);
+
+ // Zero the last word first so bytes past numBytes in the rounded slot are deterministic.
+ zeroOutPaddingBytes(numBytes);
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes);
+
+ setOffsetAndSize(ordinal, numBytes);
+
+ // move the cursor forward.
+ increaseCursor(roundedSize);
+ }
+
+ public final void write(int ordinal, CalendarInterval input) {
+ // grow the global buffer before writing data.
+ grow(16);
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ Platform.putLong(getBuffer(), cursor(), input.months);
+ Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
+
+ setOffsetAndSize(ordinal, 16);
+
+ // move the cursor forward.
+ increaseCursor(16);
+ }
+
+ protected final void writeBoolean(long offset, boolean value) {
+ Platform.putBoolean(getBuffer(), offset, value);
+ }
+
+ protected final void writeByte(long offset, byte value) {
+ Platform.putByte(getBuffer(), offset, value);
+ }
+
+ protected final void writeShort(long offset, short value) {
+ Platform.putShort(getBuffer(), offset, value);
+ }
+
+ protected final void writeInt(long offset, int value) {
+ Platform.putInt(getBuffer(), offset, value);
+ }
+
+ protected final void writeLong(long offset, long value) {
+ Platform.putLong(getBuffer(), offset, value);
+ }
+
+ protected final void writeFloat(long offset, float value) {
+ if (Float.isNaN(value)) { // any NaN input is replaced ...
+ value = Float.NaN; // ... by the Float.NaN constant before it is stored
+ }
+ Platform.putFloat(getBuffer(), offset, value);
+ }
+
+ protected final void writeDouble(long offset, double value) {
+ if (Double.isNaN(value)) { // any NaN input is replaced ...
+ value = Double.NaN; // ... by the Double.NaN constant before it is stored
+ }
+ Platform.putDouble(getBuffer(), offset, value);
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 0da5ece..b31466f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{UserDefinedType, _}
import org.apache.spark.unsafe.Platform
@@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
/** The row representing the expression results. */
private[this] val intermediate = new GenericInternalRow(values)
- /** The row returned by the projection. */
- private[this] val result = new UnsafeRow(numFields)
-
- /** The buffer which holds the resulting row's backing data. */
- private[this] val holder = new BufferHolder(result, numFields * 32)
+ /** The row writer that builds the resulting UnsafeRow. */
+ private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32)
/** The writer that writes the intermediate result to the result row. */
private[this] val writer: InternalRow => Unit = {
- val rowWriter = new UnsafeRowWriter(holder, numFields)
val baseWriter = generateStructWriter(
- holder,
rowWriter,
expressions.map(e => StructField("", e.dataType, e.nullable)))
if (!expressions.exists(_.nullable)) {
@@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
}
// Write the intermediate row to an unsafe row.
- holder.reset()
+ rowWriter.reset()
writer(intermediate)
- result.setTotalSize(holder.totalSize())
- result
+ rowWriter.getRow()
}
}
@@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* given buffer using the given [[UnsafeRowWriter]].
*/
private def generateStructWriter(
- bufferHolder: BufferHolder,
rowWriter: UnsafeRowWriter,
fields: Array[StructField]): InternalRow => Unit = {
val numFields = fields.length
// Create field writers.
val fieldWriters = fields.map { field =>
- generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable)
+ generateFieldWriter(rowWriter, field.dataType, field.nullable)
}
// Create basic writer.
row => {
@@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* or array) to the given buffer using the given [[UnsafeWriter]].
*/
private def generateFieldWriter(
- bufferHolder: BufferHolder,
writer: UnsafeWriter,
dt: DataType,
nullable: Boolean): (SpecializedGetters, Int) => Unit = {
@@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
case StructType(fields) =>
val numFields = fields.length
- val rowWriter = new UnsafeRowWriter(bufferHolder, numFields)
- val structWriter = generateStructWriter(bufferHolder, rowWriter, fields)
+ val rowWriter = new UnsafeRowWriter(writer, numFields)
+ val structWriter = generateStructWriter(rowWriter, fields)
(v, i) => {
- val tmpCursor = bufferHolder.cursor
+ val previousCursor = writer.cursor()
v.getStruct(i, fields.length) match {
case row: UnsafeRow =>
writeUnsafeData(
- bufferHolder,
+ rowWriter,
row.getBaseObject,
row.getBaseOffset,
row.getSizeInBytes)
case row =>
// Nested struct. We don't know where this will start because a row can be
// variable length, so we need to update the offsets and zero out the bit mask.
- rowWriter.reset()
+ rowWriter.resetRowWriter()
structWriter.apply(row)
}
- writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
}
case ArrayType(elementType, containsNull) =>
- val arrayWriter = new UnsafeArrayWriter
- val elementSize = getElementSize(elementType)
+ val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType))
val elementWriter = generateFieldWriter(
- bufferHolder,
arrayWriter,
elementType,
containsNull)
(v, i) => {
- val tmpCursor = bufferHolder.cursor
- writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize)
- writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ val previousCursor = writer.cursor()
+ writeArray(arrayWriter, elementWriter, v.getArray(i))
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
}
case MapType(keyType, valueType, valueContainsNull) =>
- val keyArrayWriter = new UnsafeArrayWriter
- val keySize = getElementSize(keyType)
+ val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType))
val keyWriter = generateFieldWriter(
- bufferHolder,
keyArrayWriter,
keyType,
nullable = false)
- val valueArrayWriter = new UnsafeArrayWriter
- val valueSize = getElementSize(valueType)
+ val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType))
val valueWriter = generateFieldWriter(
- bufferHolder,
valueArrayWriter,
valueType,
valueContainsNull)
(v, i) => {
- val tmpCursor = bufferHolder.cursor
+ val previousCursor = writer.cursor()
v.getMap(i) match {
case map: UnsafeMapData =>
writeUnsafeData(
- bufferHolder,
+ valueArrayWriter,
map.getBaseObject,
map.getBaseOffset,
map.getSizeInBytes)
case map =>
// preserve 8 bytes to write the key array numBytes later.
- bufferHolder.grow(8)
- bufferHolder.cursor += 8
+ valueArrayWriter.grow(8)
+ valueArrayWriter.increaseCursor(8)
// Write the keys and write the numBytes of key array into the first 8 bytes.
- writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize)
- Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8)
+ writeArray(keyArrayWriter, keyWriter, map.keyArray())
+ Platform.putLong(
+ valueArrayWriter.getBuffer,
+ previousCursor,
+ valueArrayWriter.cursor - previousCursor - 8
+ )
// Write the values.
- writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize)
+ writeArray(valueArrayWriter, valueWriter, map.valueArray())
}
- writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
}
case udt: UserDefinedType[_] =>
- generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable)
+ generateFieldWriter(writer, udt.sqlType, nullable)
case NullType =>
(_, _) => {}
@@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* copy.
*/
private def writeArray(
- bufferHolder: BufferHolder,
arrayWriter: UnsafeArrayWriter,
elementWriter: (SpecializedGetters, Int) => Unit,
- array: ArrayData,
- elementSize: Int): Unit = array match {
+ array: ArrayData): Unit = array match {
case unsafe: UnsafeArrayData =>
writeUnsafeData(
- bufferHolder,
+ arrayWriter,
unsafe.getBaseObject,
unsafe.getBaseOffset,
unsafe.getSizeInBytes)
case _ =>
val numElements = array.numElements()
- arrayWriter.initialize(bufferHolder, numElements, elementSize)
+ arrayWriter.initialize(numElements)
var i = 0
while (i < numElements) {
elementWriter.apply(array, i)
@@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects.
*/
private def writeUnsafeData(
- bufferHolder: BufferHolder,
+ writer: UnsafeWriter,
baseObject: AnyRef,
baseOffset: Long,
sizeInBytes: Int) : Unit = {
- bufferHolder.grow(sizeInBytes)
+ writer.grow(sizeInBytes)
Platform.copyMemory(
baseObject,
baseOffset,
- bufferHolder.buffer,
- bufferHolder.cursor,
+ writer.getBuffer,
+ writer.cursor,
sizeInBytes)
- bufferHolder.cursor += sizeInBytes
+ writer.increaseCursor(sizeInBytes)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 6682ba5..ab2254c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
input: String,
fieldTypes: Seq[DataType],
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
}
+ val rowWriterClass = classOf[UnsafeRowWriter].getName
+ val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});")
+
s"""
final InternalRow $tmpInput = $input;
if ($tmpInput instanceof UnsafeRow) {
- ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)}
+ ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)}
} else {
- ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)}
+ ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
}
"""
}
@@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
row: String,
inputs: Seq[ExprCode],
inputTypes: Seq[DataType],
- bufferHolder: String,
+ rowWriter: String,
isTopLevel: Boolean = false): String = {
- val rowWriterClass = classOf[UnsafeRowWriter].getName
- val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
- v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")
-
val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
// which means its fixed-size region always in the same position, so we don't need to call
@@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$rowWriter.zeroOutNullBytes();"
}
} else {
- s"$rowWriter.reset();"
+ s"$rowWriter.resetRowWriter();"
}
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
@@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case udt: UserDefinedType[_] => udt.sqlType
case other => other
}
- val tmpCursor = ctx.freshName("tmpCursor")
val setNull = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
@@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
case _ => s"$rowWriter.setNullAt($index);"
}
+ val previousCursor = ctx.freshName("previousCursor")
val writeField = dt match {
case t: StructType =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $rowWriter.cursor();
+ ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)}
+ $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case a @ ArrayType(et, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $rowWriter.cursor();
+ ${writeArrayToBuffer(ctx, input.value, et, rowWriter)}
+ $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case m @ MapType(kt, vt, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $rowWriter.cursor();
+ ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)}
+ $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case t: DecimalType =>
@@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
input: String,
elementType: DataType,
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
- val arrayWriterClass = classOf[UnsafeArrayWriter].getName
- val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
- v => s"$v = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
@@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => 8 // we need 8 bytes to store offset and length
}
- val tmpCursor = ctx.freshName("tmpCursor")
+ val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+ val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
+ v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);")
+ val previousCursor = ctx.freshName("previousCursor")
+
val element = CodeGenerator.getValue(tmpInput, et, index)
val writeElement = et match {
case t: StructType =>
s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $arrayWriter.cursor();
+ ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)}
+ $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case a @ ArrayType(et, _) =>
s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $arrayWriter.cursor();
+ ${writeArrayToBuffer(ctx, element, et, arrayWriter)}
+ $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case m @ MapType(kt, vt, _) =>
s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $arrayWriter.cursor();
+ ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)}
+ $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case t: DecimalType =>
@@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
final ArrayData $tmpInput = $input;
if ($tmpInput instanceof UnsafeArrayData) {
- ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)}
+ ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)}
} else {
final int $numElements = $tmpInput.numElements();
- $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
+ $arrayWriter.initialize($numElements);
for (int $index = 0; $index < $numElements; $index++) {
if ($tmpInput.isNullAt($index)) {
@@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
input: String,
keyType: DataType,
valueType: DataType,
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val tmpCursor = ctx.freshName("tmpCursor")
@@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
final MapData $tmpInput = $input;
if ($tmpInput instanceof UnsafeMapData) {
- ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)}
+ ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)}
} else {
// preserve 8 bytes to write the key array numBytes later.
- $bufferHolder.grow(8);
- $bufferHolder.cursor += 8;
+ $rowWriter.grow(8);
+ $rowWriter.increaseCursor(8);
// Remember the current cursor so that we can write numBytes of key array later.
- final int $tmpCursor = $bufferHolder.cursor;
+ final int $tmpCursor = $rowWriter.cursor();
- ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)}
+ ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
// Write the numBytes of key array into the first 8 bytes.
- Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
+ Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor);
- ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)}
+ ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
}
"""
}
@@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
* If the input is already in unsafe format, we don't need to go through all elements/fields,
* we can directly write it.
*/
- private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = {
+ private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = {
val sizeInBytes = ctx.freshName("sizeInBytes")
s"""
final int $sizeInBytes = $input.getSizeInBytes();
// grow the global buffer before writing data.
- $bufferHolder.grow($sizeInBytes);
- $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
- $bufferHolder.cursor += $sizeInBytes;
+ $rowWriter.grow($sizeInBytes);
+ $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor());
+ $rowWriter.increaseCursor($sizeInBytes);
"""
}
@@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => true
}
- val result = ctx.addMutableState("UnsafeRow", "result",
- v => s"$v = new UnsafeRow(${expressions.length});")
-
- val holderClass = classOf[BufferHolder].getName
- val holder = ctx.addMutableState(holderClass, "holder",
- v => s"$v = new $holderClass($result, ${numVarLenFields * 32});")
-
- val resetBufferHolder = if (numVarLenFields == 0) {
- ""
- } else {
- s"$holder.reset();"
- }
- val updateRowSize = if (numVarLenFields == 0) {
- ""
- } else {
- s"$result.setTotalSize($holder.totalSize());"
- }
+ val rowWriterClass = classOf[UnsafeRowWriter].getName
+ val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});")
// Evaluate all the subexpression.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
- val writeExpressions =
- writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
+ val writeExpressions = writeExpressionsToBuffer(
+ ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
val code =
s"""
- $resetBufferHolder
+ $rowWriter.reset();
$evalSubexpr
$writeExpressions
- $updateRowSize
"""
- ExprCode(code, "false", result)
+ ExprCode(code, "false", s"$rowWriter.getRow()")
}
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
index fb3dbe8..2da8711 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
@@ -27,7 +27,6 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.unsafe.types.UTF8String;
@@ -55,36 +54,27 @@ public class RowBasedKeyValueBatchSuite {
}
private UnsafeRow makeKeyRow(long k1, String k2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 32);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, k1);
writer.write(1, UTF8String.fromString(k2));
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow makeKeyRow(long k1, long k2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 0);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, k1);
writer.write(1, k2);
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow makeValueRow(long v1, long v2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 0);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, v1);
writer.write(1, v2);
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 8617be8..d550827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -165,18 +165,14 @@ class RowBasedHashMapGenerator(
| if (buckets[idx] == -1) {
| if (numRows < capacity && !isBatchFull) {
| // creating the unsafe for new entry
- | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length});
- | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
- | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result,
- | ${numVarLenFields * 32});
| org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
| = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
- | agg_holder,
- | ${groupingKeySchema.length});
- | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
+ | ${groupingKeySchema.length}, ${numVarLenFields * 32});
+ | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed
| agg_rowWriter.zeroOutNullBytes();
| ${createUnsafeRowForKey};
- | agg_result.setTotalSize(agg_holder.totalSize());
+ | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result
+ | = agg_rowWriter.getRow();
| Object kbase = agg_result.getBaseObject();
| long koff = agg_result.getBaseOffset();
| int klen = agg_result.getSizeInBytes();
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 3b5655b..2d699e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private ByteOrder nativeOrder = null;
private byte[][] buffers = null;
- private UnsafeRow unsafeRow = new UnsafeRow($numFields);
- private BufferHolder bufferHolder = new BufferHolder(unsafeRow);
- private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields);
+ private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields);
private MutableUnsafeRow mutableRow = null;
private int currentRow = 0;
@@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
public InternalRow next() {
currentRow += 1;
- bufferHolder.reset();
+ rowWriter.reset();
rowWriter.zeroOutNullBytes();
${extractorCalls}
- unsafeRow.setTotalSize(bufferHolder.totalSize());
- return unsafeRow;
+ return rowWriter.getRow();
}
${ctx.declareAddedFunctions()}
http://git-wip-us.apache.org/repos/asf/spark/blob/a7c19d9c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 9647f09..e93908d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
@@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
val emptyUnsafeRow = new UnsafeRow(0)
reader.map(_ => emptyUnsafeRow)
} else {
- val unsafeRow = new UnsafeRow(1)
- val bufferHolder = new BufferHolder(unsafeRow)
- val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
+ val unsafeRowWriter = new UnsafeRowWriter(1)
reader.map { line =>
// Writes to an UnsafeRow directly
- bufferHolder.reset()
+ unsafeRowWriter.reset()
unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
- unsafeRow.setTotalSize(bufferHolder.totalSize())
- unsafeRow
+ unsafeRowWriter.getRow()
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org