You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/04/28 06:23:44 UTC

spark git commit: [SPARK-14961] Build HashedRelation larger than 1G

Repository: spark
Updated Branches:
  refs/heads/master f5da592fc -> ae4e3def5


[SPARK-14961] Build HashedRelation larger than 1G

## What changes were proposed in this pull request?

Currently, LongToUnsafeRowMap uses a byte array as the underlying page, which can't be larger than 1G.

This PR improves LongToUnsafeRowMap to scale up to 8G bytes by using an array of Long instead of an array of bytes.

## How was this patch tested?

Manually ran a test to confirm that both UnsafeHashedRelation and LongHashedRelation could build a map that is larger than 2G.

Author: Davies Liu <da...@databricks.com>

Closes #12740 from davies/larger_broadcast.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ae4e3def
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ae4e3def
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ae4e3def

Branch: refs/heads/master
Commit: ae4e3def5eacb8e383a3535e6c685897fd1aaf4c
Parents: f5da592
Author: Davies Liu <da...@databricks.com>
Authored: Wed Apr 27 21:23:40 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Apr 27 21:23:40 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/joins/HashedRelation.scala    | 134 +++++++++++--------
 .../execution/joins/HashedRelationSuite.scala   |  30 ++++-
 2 files changed, 107 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ae4e3def/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0427db4..b280c76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -173,8 +173,8 @@ private[joins] class UnsafeHashedRelation(
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     out.writeInt(numFields)
     // TODO: move these into BytesToBytesMap
-    out.writeInt(binaryMap.numKeys())
-    out.writeInt(binaryMap.numValues())
+    out.writeLong(binaryMap.numKeys())
+    out.writeLong(binaryMap.numValues())
 
     var buffer = new Array[Byte](64)
     def write(base: Object, offset: Long, length: Int): Unit = {
@@ -199,8 +199,8 @@ private[joins] class UnsafeHashedRelation(
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     numFields = in.readInt()
     resultRow = new UnsafeRow(numFields)
-    val nKeys = in.readInt()
-    val nValues = in.readInt()
+    val nKeys = in.readLong()
+    val nValues = in.readLong()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
@@ -345,16 +345,20 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
 
   // The page to store all bytes of UnsafeRow and the pointer to next rows.
   // [row1][pointer1] [row2][pointer2]
-  private var page: Array[Byte] = null
+  private var page: Array[Long] = null
 
   // Current write cursor in the page.
-  private var cursor = Platform.BYTE_ARRAY_OFFSET
+  private var cursor: Long = Platform.LONG_ARRAY_OFFSET
+
+  // The number of bits for size in address
+  private val SIZE_BITS = 28
+  private val SIZE_MASK = 0xfffffff
 
   // The total number of values of all keys.
-  private var numValues = 0
+  private var numValues = 0L
 
   // The number of unique keys.
-  private var numKeys = 0
+  private var numKeys = 0L
 
   // needed by serializer
   def this() = {
@@ -390,7 +394,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
       acquireMemory(n * 2 * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
-      page = new Array[Byte](1 << 20)  // 1M bytes
+      page = new Array[Long](1 << 17)  // 1M bytes
     }
   }
 
@@ -406,7 +410,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   /**
    * Returns total memory consumption.
    */
-  def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+  def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L
 
   /**
    * Returns the first slot of array that store the keys (sparse mode).
@@ -422,8 +426,8 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   private def nextSlot(pos: Int): Int = (pos + 2) & mask
 
   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    val offset = address >>> 32
-    val size = address & 0xffffffffL
+    val offset = address >>> SIZE_BITS
+    val size = address & SIZE_MASK
     resultRow.pointTo(page, offset, size.toInt)
     resultRow
   }
@@ -450,15 +454,15 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   }
 
   /**
-   * Returns an interator of UnsafeRow for multiple linked values.
+   * Returns an iterator of UnsafeRow for multiple linked values.
    */
   private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     new Iterator[UnsafeRow] {
       var addr = address
       override def hasNext: Boolean = addr != 0
       override def next(): UnsafeRow = {
-        val offset = addr >>> 32
-        val size = addr & 0xffffffffL
+        val offset = addr >>> SIZE_BITS
+        val size = addr & SIZE_MASK
         resultRow.pointTo(page, offset, size.toInt)
         addr = Platform.getLong(page, offset + size)
         resultRow
@@ -491,6 +495,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
    * Appends the key and row into this map.
    */
   def append(key: Long, row: UnsafeRow): Unit = {
+    val sizeInBytes = row.getSizeInBytes
+    if (sizeInBytes >= (1 << SIZE_BITS)) {
+      sys.error("Does not support row that is larger than 256M")
+    }
+
     if (key < minKey) {
       minKey = key
     }
@@ -499,16 +508,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }
 
     // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
       val used = page.length
-      if (used * 2L > (1L << 31)) {
-        sys.error("Can't allocate a page that is larger than 2G")
+      if (used >= (1 << 30)) {
+        sys.error("Can not build a HashedRelation that is larger than 8G")
       }
-      acquireMemory(used * 2)
-      val newPage = new Array[Byte](used * 2)
-      System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+      acquireMemory(used * 8L * 2)
+      val newPage = new Array[Long](used * 2)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
       page = newPage
-      freeMemory(used)
+      freeMemory(used * 8)
     }
 
     // copy the bytes of UnsafeRow
@@ -518,7 +528,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     Platform.putLong(page, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+    updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
   }
 
   /**
@@ -536,11 +546,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
       numKeys += 1
       if (numKeys * 4 > array.length) {
         // reach half of the capacity
-        growArray()
+        if (array.length < (1 << 30)) {
+          // Cannot allocate an array with 2G elements
+          growArray()
+        } else if (numKeys > array.length / 2 * 0.75) {
+          // The fill ratio should be less than 0.75
+          sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+        }
       }
     } else {
       // there are some values for this key, put the address in the front of them.
-      val pointer = (address >>> 32) + (address & 0xffffffffL)
+      val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
       Platform.putLong(page, pointer, array(pos + 1))
       array(pos + 1) = address
     }
@@ -550,7 +566,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     var old_array = array
     val n = array.length
     numKeys = 0
-    acquireMemory(n * 2 * 8)
+    acquireMemory(n * 2 * 8L)
     array = new Array[Long](n * 2)
     mask = n * 2 - 2
     var i = 0
@@ -599,7 +615,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length)
+      freeMemory(page.length * 8)
       page = null
     }
     if (array != null) {
@@ -608,52 +624,58 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }
   }
 
+  private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = {
+    val buffer = new Array[Byte](4 << 10)
+    var offset: Long = Platform.LONG_ARRAY_OFFSET
+    val end = len * 8L + Platform.LONG_ARRAY_OFFSET
+    while (offset < end) {
+      val size = Math.min(buffer.length, (end - offset).toInt)
+      Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+      out.write(buffer, 0, size)
+      offset += size
+    }
+  }
+
   override def writeExternal(out: ObjectOutput): Unit = {
     out.writeBoolean(isDense)
     out.writeLong(minKey)
     out.writeLong(maxKey)
-    out.writeInt(numKeys)
-    out.writeInt(numValues)
+    out.writeLong(numKeys)
+    out.writeLong(numValues)
+
+    out.writeLong(array.length)
+    writeLongArray(out, array, array.length)
+    val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
+    out.writeLong(used)
+    writeLongArray(out, page, used)
+  }
 
-    out.writeInt(array.length)
+  private def readLongArray(in: ObjectInput, length: Int): Array[Long] = {
+    val array = new Array[Long](length)
     val buffer = new Array[Byte](4 << 10)
-    var offset = Platform.LONG_ARRAY_OFFSET
-    val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+    var offset: Long = Platform.LONG_ARRAY_OFFSET
+    val end = length * 8L + Platform.LONG_ARRAY_OFFSET
     while (offset < end) {
-      val size = Math.min(buffer.length, end - offset)
-      Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
-      out.write(buffer, 0, size)
+      val size = Math.min(buffer.length, (end - offset).toInt)
+      in.readFully(buffer, 0, size)
+      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
       offset += size
     }
-
-    val used = cursor - Platform.BYTE_ARRAY_OFFSET
-    out.writeInt(used)
-    out.write(page, 0, used)
+    array
   }
 
   override def readExternal(in: ObjectInput): Unit = {
     isDense = in.readBoolean()
     minKey = in.readLong()
     maxKey = in.readLong()
-    numKeys = in.readInt()
-    numValues = in.readInt()
+    numKeys = in.readLong
+    numValues = in.readLong()
 
-    val length = in.readInt()
-    array = new Array[Long](length)
+    val length = in.readLong().toInt
     mask = length - 2
-    val buffer = new Array[Byte](4 << 10)
-    var offset = Platform.LONG_ARRAY_OFFSET
-    val end = length * 8 + Platform.LONG_ARRAY_OFFSET
-    while (offset < end) {
-      val size = Math.min(buffer.length, end - offset)
-      in.readFully(buffer, 0, size)
-      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
-      offset += size
-    }
-
-    val numBytes = in.readInt()
-    page = new Array[Byte](numBytes)
-    in.readFully(page)
+    array = readLongArray(in, length)
+    val pageLength = in.readLong().toInt
+    page = readLongArray(in, pageLength)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ae4e3def/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 371a9ed..3ee25c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer
 
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
@@ -149,4 +150,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
       assert(rows(1).getInt(1) === i + 1)
     }
   }
+
+  // This test require 4G heap to run, should run it manually
+  ignore("build HashedRelation that is larger than 1G") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false),
+        BoundReference(1, StringType, true)))
+    val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val rows = (0 until (1 << 24)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+
+    val unsafeRelation = UnsafeHashedRelation(rows, key, 1000, mm)
+    assert(unsafeRelation.estimatedSize > (2L << 30))
+    unsafeRelation.close()
+
+    val rows2 = (0 until (1 << 24)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+    val longRelation = LongHashedRelation(rows2, key, 1000, mm)
+    assert(longRelation.estimatedSize > (2L << 30))
+    longRelation.close()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org