You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/09/29 04:04:27 UTC

git commit: SPARK-2761 refactor #maybeSpill into Spillable

Repository: spark
Updated Branches:
  refs/heads/master 8e874185e -> 25164a89d


SPARK-2761 refactor #maybeSpill into Spillable

Moved `#maybeSpill` from ExternalSorter and ExternalAppendOnlyMap (EAOM) into the new `Spillable` trait.

Author: Jim Lim <ji...@quixey.com>

Closes #2416 from jimjh/SPARK-2761 and squashes the following commits:

cf8be9a [Jim Lim] SPARK-2761 fix documentation, reorder code
f94d522 [Jim Lim] SPARK-2761 refactor Spillable to simplify sig
e75a24e [Jim Lim] SPARK-2761 use protected over protected[this]
7270e0d [Jim Lim] SPARK-2761 refactor #maybeSpill into Spillable


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25164a89
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25164a89
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25164a89

Branch: refs/heads/master
Commit: 25164a89dd32eef58d9b6823ae259439f796e81a
Parents: 8e87418
Author: Jim Lim <ji...@quixey.com>
Authored: Sun Sep 28 19:04:24 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Sun Sep 28 19:04:24 2014 -0700

----------------------------------------------------------------------
 .../util/collection/ExternalAppendOnlyMap.scala |  46 ++------
 .../spark/util/collection/ExternalSorter.scala  |  68 +++---------
 .../spark/util/collection/Spillable.scala       | 111 +++++++++++++++++++
 3 files changed, 133 insertions(+), 92 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/25164a89/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 8a015c1..0c088da 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -66,23 +66,19 @@ class ExternalAppendOnlyMap[K, V, C](
     mergeCombiners: (C, C) => C,
     serializer: Serializer = SparkEnv.get.serializer,
     blockManager: BlockManager = SparkEnv.get.blockManager)
-  extends Iterable[(K, C)] with Serializable with Logging {
+  extends Iterable[(K, C)]
+  with Serializable
+  with Logging
+  with Spillable[SizeTracker] {
 
   private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
   private val sparkConf = SparkEnv.get.conf
   private val diskBlockManager = blockManager.diskBlockManager
-  private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
   // Number of pairs inserted since last spill; note that we count them even if a value is merged
   // with a previous key in case we're doing something like groupBy where the result grows
-  private var elementsRead = 0L
-
-  // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
-  private val trackMemoryThreshold = 1000
-
-  // How much of the shared memory pool this collection has claimed
-  private var myMemoryThreshold = 0L
+  protected[this] var elementsRead = 0L
 
   /**
    * Size of object batches when reading/writing from serializers.
@@ -95,11 +91,7 @@ class ExternalAppendOnlyMap[K, V, C](
    */
   private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
 
-  // How many times we have spilled so far
-  private var spillCount = 0
-
   // Number of bytes spilled in total
-  private var _memoryBytesSpilled = 0L
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
@@ -136,19 +128,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
     while (entries.hasNext) {
       curEntry = entries.next()
-      if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
-          currentMap.estimateSize() >= myMemoryThreshold)
-      {
-        // Claim up to double our current memory from the shuffle memory pool
-        val currentMemory = currentMap.estimateSize()
-        val amountToRequest = 2 * currentMemory - myMemoryThreshold
-        val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
-        myMemoryThreshold += granted
-        if (myMemoryThreshold <= currentMemory) {
-          // We were granted too little memory to grow further (either tryToAcquire returned 0,
-          // or we already had more memory than myMemoryThreshold); spill the current collection
-          spill(currentMemory)  // Will also release memory back to ShuffleMemoryManager
-        }
+      if (maybeSpill(currentMap, currentMap.estimateSize())) {
+        currentMap = new SizeTrackingAppendOnlyMap[K, C]
       }
       currentMap.changeValue(curEntry._1, update)
       elementsRead += 1
@@ -171,11 +152,7 @@ class ExternalAppendOnlyMap[K, V, C](
   /**
    * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
    */
-  private def spill(mapSize: Long): Unit = {
-    spillCount += 1
-    val threadId = Thread.currentThread().getId
-    logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
-      .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+  override protected[this] def spill(collection: SizeTracker): Unit = {
     val (blockId, file) = diskBlockManager.createTempBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
@@ -231,18 +208,11 @@ class ExternalAppendOnlyMap[K, V, C](
       }
     }
 
-    currentMap = new SizeTrackingAppendOnlyMap[K, C]
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
 
-    // Release our memory back to the shuffle pool so that other threads can grab it
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0L
-
     elementsRead = 0
-    _memoryBytesSpilled += mapSize
   }
 
-  def memoryBytesSpilled: Long = _memoryBytesSpilled
   def diskBytesSpilled: Long = _diskBytesSpilled
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/25164a89/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 782b979..0a152cb 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -79,14 +79,14 @@ private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
     partitioner: Option[Partitioner] = None,
     ordering: Option[Ordering[K]] = None,
-    serializer: Option[Serializer] = None) extends Logging {
+    serializer: Option[Serializer] = None)
+  extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] {
 
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
   private val shouldPartition = numPartitions > 1
 
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
-  private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
   private val ser = Serializer.getSerializer(serializer)
   private val serInstance = ser.newInstance()
 
@@ -115,22 +115,14 @@ private[spark] class ExternalSorter[K, V, C](
 
   // Number of pairs read from input since last spill; note that we count them even if a value is
   // merged with a previous key in case we're doing something like groupBy where the result grows
-  private var elementsRead = 0L
-
-  // What threshold of elementsRead we start estimating map size at.
-  private val trackMemoryThreshold = 1000
+  protected[this] var elementsRead = 0L
 
   // Total spilling statistics
-  private var spillCount = 0
-  private var _memoryBytesSpilled = 0L
   private var _diskBytesSpilled = 0L
 
   // Write metrics for current spill
   private var curWriteMetrics: ShuffleWriteMetrics = _
 
-  // How much of the shared memory pool this collection has claimed
-  private var myMemoryThreshold = 0L
-
   // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
   // local aggregation and sorting, write numPartitions files directly and just concatenate them
   // at the end. This avoids doing serialization and deserialization twice to merge together the
@@ -209,7 +201,7 @@ private[spark] class ExternalSorter[K, V, C](
         elementsRead += 1
         kv = records.next()
         map.changeValue((getPartition(kv._1), kv._1), update)
-        maybeSpill(usingMap = true)
+        maybeSpillCollection(usingMap = true)
       }
     } else {
       // Stick values into our buffer
@@ -217,7 +209,7 @@ private[spark] class ExternalSorter[K, V, C](
         elementsRead += 1
         val kv = records.next()
         buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
-        maybeSpill(usingMap = false)
+        maybeSpillCollection(usingMap = false)
       }
     }
   }
@@ -227,61 +219,31 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param usingMap whether we're using a map or buffer as our current in-memory collection
    */
-  private def maybeSpill(usingMap: Boolean): Unit = {
+  private def maybeSpillCollection(usingMap: Boolean): Unit = {
     if (!spillingEnabled) {
       return
     }
 
-    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
-
-    // TODO: factor this out of both here and ExternalAppendOnlyMap
-    if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
-        collection.estimateSize() >= myMemoryThreshold)
-    {
-      // Claim up to double our current memory from the shuffle memory pool
-      val currentMemory = collection.estimateSize()
-      val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
-      myMemoryThreshold += granted
-      if (myMemoryThreshold <= currentMemory) {
-        // We were granted too little memory to grow further (either tryToAcquire returned 0,
-        // or we already had more memory than myMemoryThreshold); spill the current collection
-        spill(currentMemory, usingMap)  // Will also release memory back to ShuffleMemoryManager
+    if (usingMap) {
+      if (maybeSpill(map, map.estimateSize())) {
+        map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+      }
+    } else {
+      if (maybeSpill(buffer, buffer.estimateSize())) {
+        buffer = new SizeTrackingPairBuffer[(Int, K), C]
       }
     }
   }
 
   /**
    * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
-   *
-   * @param usingMap whether we're using a map or buffer as our current in-memory collection
    */
-  private def spill(memorySize: Long, usingMap: Boolean): Unit = {
-    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
-    val memorySize = collection.estimateSize()
-
-    spillCount += 1
-    val threadId = Thread.currentThread().getId
-    logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
-      .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
-
+  override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
     if (bypassMergeSort) {
       spillToPartitionFiles(collection)
     } else {
       spillToMergeableFile(collection)
     }
-
-    if (usingMap) {
-      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
-    } else {
-      buffer = new SizeTrackingPairBuffer[(Int, K), C]
-    }
-
-    // Release our memory back to the shuffle pool so that other threads can grab it
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0
-
-    _memoryBytesSpilled += memorySize
   }
 
   /**
@@ -804,8 +766,6 @@ private[spark] class ExternalSorter[K, V, C](
     }
   }
 
-  def memoryBytesSpilled: Long = _memoryBytesSpilled
-
   def diskBytesSpilled: Long = _diskBytesSpilled
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/25164a89/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
new file mode 100644
index 0000000..d7dccd4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+
+/**
+ * Spills contents of an in-memory collection to disk when the memory threshold
+ * has been exceeded.
+ */
+private[spark] trait Spillable[C] {
+
+  this: Logging =>
+
+  /**
+   * Spills the current in-memory collection to disk. Releasing the claimed shuffle memory
+   * and updating spill statistics are done by [[maybeSpill]], not by implementations.
+   *
+   * @param collection collection to spill to disk
+   */
+  protected def spill(collection: C): Unit
+
+  // Number of elements read from input since last spill (maybeSpill resets it after a spill)
+  protected var elementsRead: Long
+
+  // Memory manager that can be used to acquire/release memory
+  private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+
+  // Value of elementsRead above which we start estimating the collection's size
+  private[this] val trackMemoryThreshold = 1000
+
+  // How much of the shared memory pool this collection has claimed
+  private[this] var myMemoryThreshold = 0L
+
+  // Estimated in-memory size, in bytes, of everything spilled so far
+  private[this] var _memoryBytesSpilled = 0L
+
+  // Number of spills so far
+  private[this] var _spillCount = 0
+
+  /**
+   * Spills the current in-memory collection to disk if needed. Attempts to acquire more
+   * memory before spilling; memory is only checked on every 32nd element.
+   *
+   * @param collection collection to spill to disk
+   * @param currentMemory estimated size of the collection in bytes
+   * @return true if `collection` was spilled to disk; false otherwise
+   */
+  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
+    val shouldTrack = elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+      currentMemory >= myMemoryThreshold
+    if (!shouldTrack) {
+      false
+    } else {
+      // Claim up to double our current memory from the shuffle memory pool
+      val amountToRequest = 2 * currentMemory - myMemoryThreshold
+      myMemoryThreshold += shuffleMemoryManager.tryToAcquire(amountToRequest)
+      // If our claim still does not exceed the collection's size, we could not get enough
+      // memory to grow further (e.g. tryToAcquire returned 0); spill the current collection
+      val shouldSpill = myMemoryThreshold <= currentMemory
+      if (shouldSpill) {
+        _spillCount += 1
+        logSpillage(currentMemory)
+        spill(collection)
+        // Keep track of spills, restart the element count, and release memory
+        elementsRead = 0
+        _memoryBytesSpilled += currentMemory
+        releaseMemoryForThisThread()
+      }
+      shouldSpill
+    }
+  }
+
+  /** @return estimated number of in-memory bytes spilled in total */
+  def memoryBytesSpilled: Long = _memoryBytesSpilled
+
+  /**
+   * Release our memory back to the shuffle pool so that other threads can grab it.
+   */
+  private def releaseMemoryForThisThread(): Unit = {
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0L
+  }
+
+  /**
+   * Logs a standard message detailing spillage.
+   *
+   * @param size estimated number of in-memory bytes being spilled
+   */
+  @inline private def logSpillage(size: Long): Unit = {
+    val threadId = Thread.currentThread().getId
+    logInfo("Thread %d spilling in-memory collection of %d MB to disk (%d time%s so far)"
+        .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else ""))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org