You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/09 05:28:12 UTC

spark git commit: [SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools

Repository: spark
Updated Branches:
  refs/heads/master a29081487 -> b55499a44


[SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools

We call Row.copy() in many places throughout SQL but UnsafeRow currently throws UnsupportedOperationException when copy() is called.

Supporting copying when ObjectPool is used may be difficult, since we may need to handle deep-copying of objects in the pool. In addition, this copy() method needs to produce a self-contained row object which may be passed around / buffered by downstream code which does not understand the UnsafeRow format.

In the long run, we'll need to figure out how to handle the ObjectPool corner cases, but this may be unnecessary if other changes are made. Therefore, in order to unblock my sort patch (#6444), I propose that we support copy() for the cases where UnsafeRow does not use an ObjectPool and continue to throw UnsupportedOperationException when an ObjectPool is used.

This patch accomplishes this by modifying UnsafeRow so that it knows the size of the row's backing data in order to be able to copy it into a byte array.

Author: Josh Rosen <jo...@databricks.com>

Closes #7306 from JoshRosen/SPARK-8932 and squashes the following commits:

338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b55499a4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b55499a4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b55499a4

Branch: refs/heads/master
Commit: b55499a44ab74e33378211fb0d6940905d7c6318
Parents: a290814
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Jul 8 20:28:05 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jul 8 20:28:05 2015 -0700

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         | 12 +++--
 .../sql/catalyst/expressions/UnsafeRow.java     | 32 +++++++++++-
 .../expressions/UnsafeRowConverter.scala        | 10 +++-
 .../expressions/UnsafeRowConverterSuite.scala   | 52 +++++++++++++++-----
 4 files changed, 87 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 1e79f4b..79d55b3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -120,9 +120,11 @@ public final class UnsafeFixedWidthAggregationMap {
     this.bufferPool = new ObjectPool(initialCapacity);
 
     InternalRow initRow = initProjection.apply(emptyRow);
-    this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+    int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
+    this.emptyBuffer = new byte[emptyBufferSize];
     int writtenLength = bufferConverter.writeRow(
-      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
+      bufferPool);
     assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
     // re-use the empty buffer only when there is no object saved in pool.
     reuseEmptyBuffer = bufferPool.size() == 0;
@@ -142,6 +144,7 @@ public final class UnsafeFixedWidthAggregationMap {
       groupingKey,
       groupingKeyConversionScratchSpace,
       PlatformDependent.BYTE_ARRAY_OFFSET,
+      groupingKeySize,
       keyPool);
     assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
 
@@ -157,7 +160,7 @@ public final class UnsafeFixedWidthAggregationMap {
         // There is some objects referenced by emptyBuffer, so generate a new one
         InternalRow initRow = initProjection.apply(emptyRow);
         bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
-          bufferPool);
+          groupingKeySize, bufferPool);
       }
       loc.putNewKey(
         groupingKeyConversionScratchSpace,
@@ -175,6 +178,7 @@ public final class UnsafeFixedWidthAggregationMap {
       address.getBaseObject(),
       address.getBaseOffset(),
       bufferConverter.numFields(),
+      loc.getValueLength(),
       bufferPool
     );
     return currentBuffer;
@@ -214,12 +218,14 @@ public final class UnsafeFixedWidthAggregationMap {
           keyAddress.getBaseObject(),
           keyAddress.getBaseOffset(),
           keyConverter.numFields(),
+          loc.getKeyLength(),
           keyPool
         );
         entry.value.pointTo(
           valueAddress.getBaseObject(),
           valueAddress.getBaseOffset(),
           bufferConverter.numFields(),
+          loc.getValueLength(),
           bufferPool
         );
         return entry;

http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index aeb64b0..edb7202 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow {
   /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
   private int numFields;
 
+  /** The size of this row's backing data, in bytes */
+  private int sizeInBytes;
+
   public int length() { return numFields; }
 
   /** The width of the null tracking bit set, in bytes */
@@ -95,14 +98,17 @@ public final class UnsafeRow extends MutableRow {
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
    * @param numFields the number of fields in this row
+   * @param sizeInBytes the size of this row's backing data, in bytes
    * @param pool the object pool to hold arbitrary objects
    */
-  public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
+  public void pointTo(
+      Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
     assert numFields >= 0 : "numFields should >= 0";
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
+    this.sizeInBytes = sizeInBytes;
     this.pool = pool;
   }
 
@@ -336,9 +342,31 @@ public final class UnsafeRow extends MutableRow {
     }
   }
 
+  /**
+   * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
+   * byte array rather than referencing data stored in a data page.
+   * <p>
+   * This method is only supported on UnsafeRows that do not use ObjectPools.
+   */
   @Override
   public InternalRow copy() {
-    throw new UnsupportedOperationException();
+    if (pool != null) {
+      throw new UnsupportedOperationException(
+        "Copy is not supported for UnsafeRows that use object pools");
+    } else {
+      UnsafeRow rowCopy = new UnsafeRow();
+      final byte[] rowDataCopy = new byte[sizeInBytes];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset,
+        rowDataCopy,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        sizeInBytes
+      );
+      rowCopy.pointTo(
+        rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
+      return rowCopy;
+    }
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 1f39549..6af5e62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    * @param row the row to convert
    * @param baseObject the base object of the destination address
    * @param baseOffset the base offset of the destination address
+   * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
    * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
    */
-  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
+  def writeRow(
+      row: InternalRow,
+      baseObject: Object,
+      baseOffset: Long,
+      rowLengthInBytes: Int,
+      pool: ObjectPool): Int = {
+    unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
 
     if (writers.length > 0) {
       // zero-out the bitset

http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 96d4e64..d00aeb4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
 
+    // We can copy UnsafeRows as long as they don't reference ObjectPools
+    val unsafeRowCopy = unsafeRow.copy()
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
+
     unsafeRow.setLong(1, 3)
     assert(unsafeRow.getLong(1) === 3)
     unsafeRow.setInt(2, 4)
     assert(unsafeRow.getInt(2) === 4)
+
+    // Mutating the original row should not have changed the copy
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten = converter.writeRow(
+      row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     val pool = new ObjectPool(10)
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     assert(unsafeRow.get(2) === "World".getBytes)
@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.update(2, "Hello World".getBytes)
     assert(unsafeRow.get(2) === "Hello World".getBytes)
     assert(pool.size === 2)
+
+    // We do not support copy() for UnsafeRows that reference ObjectPools
+    intercept[UnsupportedOperationException] {
+      unsafeRow.copy()
+    }
   }
 
   test("basic conversion with primitive, decimal and array") {
@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (8 * 3))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
     assert(numBytesWritten === sizeRequired)
     assert(pool.size === 2)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.get(1) === Decimal(1))
     assert(unsafeRow.get(2) === Array(2))
@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(sizeRequired === 8 + (8 * 4) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
     val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val createdFromNull = new UnsafeRow()
     createdFromNull.pointTo(
-      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, null)
     for (i <- 0 to fieldTypes.length - 1) {
       assert(createdFromNull.isNullAt(i))
     }
@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val pool = new ObjectPool(1)
     val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
     converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, pool)
     val setToNullAfterCreation = new UnsafeRow()
     setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, pool)
 
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org