Posted to commits@spark.apache.org by rx...@apache.org on 2013/11/05 02:54:44 UTC

git commit: Merge pull request #130 from aarondav/shuffle

Updated Branches:
  refs/heads/branch-0.8 1d11e4336 -> 7e00dee27


Merge pull request #130 from aarondav/shuffle

Memory-optimized shuffle file consolidation

Reduces the overhead of tracking each shuffle block for consolidation from >300 bytes to 8 bytes (one primitive Long). Verified via profiler testing with 1 million shuffle blocks; the net overhead was ~8,400,000 bytes.

Despite the extra CPU overhead incurred by the memory-optimized implementation, the shuffle phase in this test was only around 2% slower, while the reduce phase was 40% faster, compared to not using any shuffle file consolidation.

This is accomplished by replacing the map from ShuffleBlockId to FileSegment (i.e., from block id to where the block is located), which had high overhead because it was a gigantic, timestamped, concurrent map, with a more space-efficient structure. Namely, the following are introduced (I have omitted the word "Shuffle" from some names for clarity; a sketch of the two structures follows this list):
**ShuffleFile** - there is one ShuffleFile per consolidated shuffle file on disk. We store an array of offsets into the physical shuffle file, one for each ShuffleMapTask that wrote into the file. This is sufficient to reconstruct a FileSegment for any mapper whose output is in the file.
**FileGroup** - contains a set of ShuffleFiles, one per reducer, that a MapTask can use to write its output. There is one FileGroup created per _concurrent_ MapTask. The FileGroup contains an array of the mapIds that have been written to all files in the group. The positions of elements in this array map directly onto the positions in each ShuffleFile's offsets array.
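As a rough illustration only (a hypothetical simplification, not the committed code; the actual implementation is the ShuffleFileGroup class in the ShuffleBlockManager.scala diff further down), the two structures could be sketched as:

```scala
import java.io.File
import scala.collection.mutable.ArrayBuffer

// One consolidated shuffle file on disk; blockOffsets(i) is the offset at which
// the i-th map output written to this file begins.
class ShuffleFile(val file: File) {
  val blockOffsets = new ArrayBuffer[Long]()
}

// One set of ShuffleFiles, one per reducer, reused by consecutive map tasks.
// mapIds(i) is the map task whose output sits at position i of every file's offsets array.
class FileGroup(val files: Array[ShuffleFile]) {
  val mapIds = new ArrayBuffer[Int]()
}
```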

In order to locate the FileSegment associated with a BlockId, we have another structure that maps each reducer to the set of ShuffleFiles that were created for it. (There will be as many ShuffleFiles per reducer as there are FileGroups.) To look up a given ShuffleBlockId (shuffleId, reducerId, mapId), we thus search through all ShuffleFiles associated with that reducer.
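A minimal sketch of that lookup, assuming the simplified ShuffleFile/FileGroup classes above (the committed version is getBlockLocation/getFileSegmentFor in the ShuffleBlockManager diff below):

```scala
// Hypothetical lookup: given the (FileGroup, ShuffleFile) pairs created for one reducer,
// find the FileSegment written by mapId. A block's length is the distance to the next
// block's offset, or to the end of the file for the last block in the file.
def findSegment(filesForReducer: Seq[(FileGroup, ShuffleFile)], mapId: Int): Option[FileSegment] = {
  for ((group, shuffleFile) <- filesForReducer) {
    val index = group.mapIds.indexOf(mapId)  // linear scan here; see the binary search note below
    if (index >= 0) {
      val offsets = shuffleFile.blockOffsets
      val offset = offsets(index)
      val length =
        if (index + 1 < offsets.length) offsets(index + 1) - offset
        else shuffleFile.file.length() - offset
      return Some(new FileSegment(shuffleFile.file, offset, length))
    }
  }
  None
}
```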

As a time optimization, we ensure that FileGroups are only reused for MapTasks with monotonically increasing mapIds. This allows us to perform a binary search to locate a mapId inside a group, and it also enables potential future optimizations based on the usual monotonic access order.
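A sketch of that binary search, under the assumption that a group's mapIds array is kept sorted by the monotonic reuse rule above (a hypothetical helper, not part of this commit):

```scala
// Returns the index of mapId within a sorted mapIds array, or -1 if it was never
// written to this group. O(log n) thanks to the monotonically increasing mapIds.
def indexOfMapId(mapIds: IndexedSeq[Int], mapId: Int): Int = {
  var lo = 0
  var hi = mapIds.length - 1
  while (lo <= hi) {
    val mid = (lo + hi) >>> 1
    if (mapIds(mid) == mapId) return mid
    else if (mapIds(mid) < mapId) lo = mid + 1
    else hi = mid - 1
  }
  -1
}
```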

(cherry picked from commit 7a26104ab7cb492b347ba761ef1f17ca1b9078e4)
Signed-off-by: Reynold Xin <rx...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/7e00dee2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/7e00dee2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/7e00dee2

Branch: refs/heads/branch-0.8
Commit: 7e00dee27fb43530aee888f06d8829ac7f9bce6e
Parents: 1d11e43
Author: Reynold Xin <rx...@apache.org>
Authored: Mon Nov 4 17:54:06 2013 -0800
Committer: Reynold Xin <rx...@apache.org>
Committed: Mon Nov 4 17:54:35 2013 -0800

----------------------------------------------------------------------
 .../apache/spark/scheduler/ShuffleMapTask.scala |  23 +--
 .../org/apache/spark/storage/BlockManager.scala |  10 +-
 .../spark/storage/BlockObjectWriter.scala       |  15 +-
 .../apache/spark/storage/DiskBlockManager.scala |  49 +----
 .../org/apache/spark/storage/DiskStore.scala    |   4 +-
 .../spark/storage/ShuffleBlockManager.scala     | 189 +++++++++++++++----
 .../org/apache/spark/util/MetadataCleaner.scala |   2 +-
 .../collection/PrimitiveKeyOpenHashMap.scala    |   6 +
 .../spark/util/collection/PrimitiveVector.scala |  51 +++++
 .../scala/spark/storage/StoragePerfTester.scala |  10 +-
 .../spark/storage/DiskBlockManagerSuite.scala   |  84 +++++++++
 11 files changed, 333 insertions(+), 110 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 24d97da..1dc71a0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -146,26 +146,26 @@ private[spark] class ShuffleMapTask(
     metrics = Some(context.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
-    var shuffle: ShuffleBlocks = null
-    var buckets: ShuffleWriterGroup = null
+    val shuffleBlockManager = blockManager.shuffleBlockManager
+    var shuffle: ShuffleWriterGroup = null
+    var success = false
 
     try {
       // Obtain all the block writers for shuffle blocks.
       val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
-      shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
-      buckets = shuffle.acquireWriters(partitionId)
+      shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
 
       // Write the map output to its associated buckets.
       for (elem <- rdd.iterator(split, context)) {
         val pair = elem.asInstanceOf[Product2[Any, Any]]
         val bucketId = dep.partitioner.getPartition(pair._1)
-        buckets.writers(bucketId).write(pair)
+        shuffle.writers(bucketId).write(pair)
       }
 
       // Commit the writes. Get the size of each bucket block (total block size).
       var totalBytes = 0L
       var totalTime = 0L
-      val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+      val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
         writer.commit()
         val size = writer.fileSegment().length
         totalBytes += size
@@ -179,19 +179,20 @@ private[spark] class ShuffleMapTask(
       shuffleMetrics.shuffleWriteTime = totalTime
       metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
 
+      success = true
       new MapStatus(blockManager.blockManagerId, compressedSizes)
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
       // and throw the exception upstream to Spark.
-      if (buckets != null) {
-        buckets.writers.foreach(_.revertPartialWrites())
+      if (shuffle != null) {
+        shuffle.writers.foreach(_.revertPartialWrites())
       }
       throw e
     } finally {
       // Release the writers back to the shuffle block manager.
-      if (shuffle != null && buckets != null) {
-        buckets.writers.foreach(_.close())
-        shuffle.releaseWriters(buckets)
+      if (shuffle != null && shuffle.writers != null) {
+        shuffle.writers.foreach(_.close())
+        shuffle.releaseWriters(success)
       }
       // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 8f4d69d..ccc05f5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
@@ -47,7 +47,7 @@ private[spark] class BlockManager(
   extends Logging {
 
   val shuffleBlockManager = new ShuffleBlockManager(this)
-  val diskBlockManager = new DiskBlockManager(
+  val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
     System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -517,15 +517,11 @@ private[spark] class BlockManager(
    * This is currently used for writing shuffle files out. Callers should handle error
    * cases.
    */
-  def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
+  def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
     : BlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
-    val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending =  true)
     val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
     writer.registerCloseEventHandler(() => {
-      if (shuffleBlockManager.consolidateShuffleFiles) {
-        diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
-      }
       val myInfo = new ShuffleBlockInfo()
       blockInfo.put(blockId, myInfo)
       myInfo.markReady(writer.fileSegment().length)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 32d2dd0..e49c191 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -78,11 +78,11 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
 
 /** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
 class DiskBlockObjectWriter(
-                             blockId: BlockId,
-                             file: File,
-                             serializer: Serializer,
-                             bufferSize: Int,
-                             compressStream: OutputStream => OutputStream)
+    blockId: BlockId,
+    file: File,
+    serializer: Serializer,
+    bufferSize: Int,
+    compressStream: OutputStream => OutputStream)
   extends BlockObjectWriter(blockId)
   with Logging
 {
@@ -111,8 +111,8 @@ class DiskBlockObjectWriter(
   private var fos: FileOutputStream = null
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
-  private var initialPosition = 0L
-  private var lastValidPosition = 0L
+  private val initialPosition = file.length()
+  private var lastValidPosition = initialPosition
   private var initialized = false
   private var _timeWriting = 0L
 
@@ -120,7 +120,6 @@ class DiskBlockObjectWriter(
     fos = new FileOutputStream(file, true)
     ts = new TimeTrackingOutputStream(fos)
     channel = fos.getChannel()
-    initialPosition = channel.position
     lastValidPosition = initialPosition
     bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index bcb58ad..fcd2e97 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,12 +20,11 @@ package org.apache.spark.storage
 import java.io.File
 import java.text.SimpleDateFormat
 import java.util.{Date, Random}
-import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.Logging
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.Utils
 
 /**
  * Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -35,7 +34,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
  *
  * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
  */
-private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+  extends PathResolver with Logging {
 
   private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -47,54 +47,23 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
   private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
   private var shuffleSender : ShuffleSender = null
 
-  // Stores only Blocks which have been specifically mapped to segments of files
-  // (rather than the default, which maps a Block to a whole file).
-  // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks. 
-  private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
-
-  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
-
   addShutdownHook()
 
   /**
-   * Creates a logical mapping from the given BlockId to a segment of a file.
-   * This will cause any accesses of the logical BlockId to be directed to the specified
-   * physical location.
-   */
-  def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
-    blockToFileSegmentMap.put(blockId, fileSegment)
-  }
-
-  /**
    * Returns the phyiscal file segment in which the given BlockId is located.
    * If the BlockId has been mapped to a specific FileSegment, that will be returned.
    * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
    */
   def getBlockLocation(blockId: BlockId): FileSegment = {
-    if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
-      blockToFileSegmentMap.get(blockId).get
+    if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+      shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
     } else {
       val file = getFile(blockId.name)
       new FileSegment(file, 0, file.length())
     }
   }
 
-  /**
-   * Simply returns a File to place the given Block into. This does not physically create the file.
-   * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
-   * a unique filename.
-   */
-  def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
-    val actualFilename = if (filename == "") blockId.name else filename
-    val file = getFile(actualFilename)
-    if (!allowAppending && file.exists()) {
-      throw new IllegalStateException(
-        "Attempted to create file that already exists: " + actualFilename)
-    }
-    file
-  }
-
-  private def getFile(filename: String): File = {
+  def getFile(filename: String): File = {
     // Figure out which local directory it hashes to, and which subdirectory in that
     val hash = Utils.nonNegativeHash(filename)
     val dirId = hash % localDirs.length
@@ -119,6 +88,8 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
     new File(subDir, filename)
   }
 
+  def getFile(blockId: BlockId): File = getFile(blockId.name)
+
   private def createLocalDirs(): Array[File] = {
     logDebug("Creating local directories at root dirs '" + rootDirs + "'")
     val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
@@ -151,10 +122,6 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
     }
   }
 
-  private def cleanup(cleanupTime: Long) {
-    blockToFileSegmentMap.clearOldValues(cleanupTime)
-  }
-
   private def addShutdownHook() {
     localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
     Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index a3c496f..5a1e7b4 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -44,7 +44,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
     val bytes = _bytes.duplicate()
     logDebug("Attempting to put block " + blockId)
     val startTime = System.currentTimeMillis
-    val file = diskManager.createBlockFile(blockId, allowAppending = false)
+    val file = diskManager.getFile(blockId)
     val channel = new FileOutputStream(file).getChannel()
     while (bytes.remaining > 0) {
       channel.write(bytes)
@@ -64,7 +64,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
 
     logDebug("Attempting to write values for block " + blockId)
     val startTime = System.currentTimeMillis
-    val file = diskManager.createBlockFile(blockId, allowAppending = false)
+    val file = diskManager.getFile(blockId)
     val outputStream = new FileOutputStream(file)
     blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
     val length = file.length

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 066e45a..2f1b049 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,33 +17,45 @@
 
 package org.apache.spark.storage
 
+import java.io.File
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.collection.JavaConversions._
+
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 
-private[spark]
-class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+  val writers: Array[BlockObjectWriter]
 
-private[spark]
-trait ShuffleBlocks {
-  def acquireWriters(mapId: Int): ShuffleWriterGroup
-  def releaseWriters(group: ShuffleWriterGroup)
+  /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+  def releaseWriters(success: Boolean)
 }
 
 /**
- * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
- * per reducer.
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
  *
  * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
  * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
- * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
- * it releases them for another task.
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
  * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
  *   - shuffleId: The unique id given to the entire shuffle stage.
  *   - bucketId: The id of the output partition (i.e., reducer id)
  *   - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
  *       time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within a ShuffleFileGroups associated with the block's reducer.
  */
 private[spark]
 class ShuffleBlockManager(blockManager: BlockManager) {
@@ -52,45 +64,152 @@ class ShuffleBlockManager(blockManager: BlockManager) {
   val consolidateShuffleFiles =
     System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
 
-  var nextFileId = new AtomicInteger(0)
-  val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
+  private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+  /**
+   * Contains all the state related to a particular shuffle. This includes a pool of unused
+   * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+   */
+  private class ShuffleState() {
+    val nextFileId = new AtomicInteger(0)
+    val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+    val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+  }
+
+  type ShuffleId = Int
+  private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+  private
+  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
 
-  def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
-    new ShuffleBlocks {
-      // Get a group of writers for a map task.
-      override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
-        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val fileId = getUnusedFileId()
-        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+    new ShuffleWriterGroup {
+      shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+      private val shuffleState = shuffleStates(shuffleId)
+      private var fileGroup: ShuffleFileGroup = null
+
+      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+        fileGroup = getUnusedFileGroup()
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-          if (consolidateShuffleFiles) {
-            val filename = physicalFileName(shuffleId, bucketId, fileId)
-            blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
-          } else {
-            blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+        }
+      } else {
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+          val blockFile = blockManager.diskBlockManager.getFile(blockId)
+          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+        }
+      }
+
+      override def releaseWriters(success: Boolean) {
+        if (consolidateShuffleFiles) {
+          if (success) {
+            val offsets = writers.map(_.fileSegment().offset)
+            fileGroup.recordMapOutput(mapId, offsets)
           }
+          recycleFileGroup(fileGroup)
         }
-        new ShuffleWriterGroup(mapId, fileId, writers)
       }
 
-      override def releaseWriters(group: ShuffleWriterGroup) {
-        recycleFileId(group.fileId)
+      private def getUnusedFileGroup(): ShuffleFileGroup = {
+        val fileGroup = shuffleState.unusedFileGroups.poll()
+        if (fileGroup != null) fileGroup else newFileGroup()
+      }
+
+      private def newFileGroup(): ShuffleFileGroup = {
+        val fileId = shuffleState.nextFileId.getAndIncrement()
+        val files = Array.tabulate[File](numBuckets) { bucketId =>
+          val filename = physicalFileName(shuffleId, bucketId, fileId)
+          blockManager.diskBlockManager.getFile(filename)
+        }
+        val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+        shuffleState.allFileGroups.add(fileGroup)
+        fileGroup
       }
-    }
-  }
 
-  private def getUnusedFileId(): Int = {
-    val fileId = unusedFileIds.poll()
-    if (fileId == null) nextFileId.getAndIncrement() else fileId
+      private def recycleFileGroup(group: ShuffleFileGroup) {
+        shuffleState.unusedFileGroups.add(group)
+      }
+    }
   }
 
-  private def recycleFileId(fileId: Int) {
-    if (consolidateShuffleFiles) {
-      unusedFileIds.add(fileId)
+  /**
+   * Returns the physical file segment in which the given BlockId is located.
+   * This function should only be called if shuffle file consolidation is enabled, as it is
+   * an error condition if we don't find the expected block.
+   */
+  def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+    // Search all file groups associated with this shuffle.
+    val shuffleState = shuffleStates(id.shuffleId)
+    for (fileGroup <- shuffleState.allFileGroups) {
+      val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+      if (segment.isDefined) { return segment.get }
     }
+    throw new IllegalStateException("Failed to find shuffle block: " + id)
   }
 
   private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
     "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
   }
+
+  private def cleanup(cleanupTime: Long) {
+    shuffleStates.clearOldValues(cleanupTime)
+  }
+}
+
+private[spark]
+object ShuffleBlockManager {
+  /**
+   * A group of shuffle files, one per reducer.
+   * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+   */
+  private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+    /**
+     * Stores the absolute index of each mapId in the files of this group. For instance,
+     * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+     */
+    private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+
+    /**
+     * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+     * This ordering allows us to compute block lengths by examining the following block offset.
+     * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+     * reducer.
+     */
+    private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+      new PrimitiveVector[Long]()
+    }
+
+    def numBlocks = mapIdToIndex.size
+
+    def apply(bucketId: Int) = files(bucketId)
+
+    def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+      mapIdToIndex(mapId) = numBlocks
+      for (i <- 0 until offsets.length) {
+        blockOffsetsByReducer(i) += offsets(i)
+      }
+    }
+
+    /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+    def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+      val file = files(reducerId)
+      val blockOffsets = blockOffsetsByReducer(reducerId)
+      val index = mapIdToIndex.getOrElse(mapId, -1)
+      if (index >= 0) {
+        val offset = blockOffsets(index)
+        val length =
+          if (index + 1 < numBlocks) {
+            blockOffsets(index + 1) - offset
+          } else {
+            file.length() - offset
+          }
+        assert(length >= 0)
+        Some(new FileSegment(file, offset, length))
+      } else {
+        None
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 3f96372..67a7f87 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -59,7 +59,7 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext
   "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
 
   val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
-    SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
+    SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
 
   type MetadataCleanerType = Value
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 4adf9cf..d76143e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -53,6 +53,12 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
     _values(pos)
   }
 
+  /** Get the value for a given key, or returns elseValue if it doesn't exist. */
+  def getOrElse(k: K, elseValue: V): V = {
+    val pos = _keySet.getPos(k)
+    if (pos >= 0) _values(pos) else elseValue
+  }
+
   /** Set the value for a key */
   def update(k: K, v: V) {
     val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000..369519c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
+  private var numElements = 0
+  private var array: Array[V] = _
+
+  // NB: This must be separate from the declaration, otherwise the specialized parent class
+  // will get its own array with the same initial size. TODO: Figure out why...
+  array = new Array[V](initialSize)
+
+  def apply(index: Int): V = {
+    require(index < numElements)
+    array(index)
+  }
+
+  def +=(value: V) {
+    if (numElements == array.length) { resize(array.length * 2) }
+    array(numElements) = value
+    numElements += 1
+  }
+
+  def length = numElements
+
+  def getUnderlyingArray = array
+
+  /** Resizes the array, dropping elements if the total length decreases. */
+  def resize(newLength: Int) {
+    val newArray = new Array[V](newLength)
+    array.copyToArray(newArray)
+    array = newArray
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/main/scala/spark/storage/StoragePerfTester.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/spark/storage/StoragePerfTester.scala b/core/src/main/scala/spark/storage/StoragePerfTester.scala
index 1b074e5..68893a2 100644
--- a/core/src/main/scala/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/spark/storage/StoragePerfTester.scala
@@ -36,19 +36,19 @@ object StoragePerfTester {
     val blockManager = sc.env.blockManager
 
     def writeOutputBytes(mapId: Int, total: AtomicLong) = {
-      val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+      val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
         new KryoSerializer())
-      val buckets = shuffle.acquireWriters(mapId)
+      val writers = shuffle.writers
       for (i <- 1 to recordsPerMap) {
-        buckets.writers(i % numOutputSplits).write(writeData)
+        writers(i % numOutputSplits).write(writeData)
       }
-      buckets.writers.map {w =>
+      writers.map {w =>
         w.commit()
         total.addAndGet(w.fileSegment().length)
         w.close()
       }
 
-      shuffle.releaseWriters(buckets)
+      shuffle.releaseWriters(true)
     }
 
     val start = System.currentTimeMillis()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7e00dee2/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000..0b90563
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,84 @@
+package org.apache.spark.storage
+
+import java.io.{FileWriter, File}
+
+import scala.collection.mutable
+
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
+
+  val rootDir0 = Files.createTempDir()
+  rootDir0.deleteOnExit()
+  val rootDir1 = Files.createTempDir()
+  rootDir1.deleteOnExit()
+  val rootDirs = rootDir0.getName + "," + rootDir1.getName
+  println("Created root dirs: " + rootDirs)
+
+  val shuffleBlockManager = new ShuffleBlockManager(null) {
+    var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+    override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
+  }
+
+  var diskBlockManager: DiskBlockManager = _
+
+  override def beforeEach() {
+    diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+    shuffleBlockManager.idToSegmentMap.clear()
+  }
+
+  test("basic block creation") {
+    val blockId = new TestBlockId("test")
+    assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+    val newFile = diskBlockManager.getFile(blockId)
+    writeToFile(newFile, 10)
+    assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+    newFile.delete()
+  }
+
+  test("block appending") {
+    val blockId = new TestBlockId("test")
+    val newFile = diskBlockManager.getFile(blockId)
+    writeToFile(newFile, 15)
+    assertSegmentEquals(blockId, blockId.name, 0, 15)
+    val newFile2 = diskBlockManager.getFile(blockId)
+    assert(newFile === newFile2)
+    writeToFile(newFile2, 12)
+    assertSegmentEquals(blockId, blockId.name, 0, 27)
+    newFile.delete()
+  }
+
+  test("block remapping") {
+    val filename = "test"
+    val blockId0 = new ShuffleBlockId(1, 2, 3)
+    val newFile = diskBlockManager.getFile(filename)
+    writeToFile(newFile, 15)
+    shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+    assertSegmentEquals(blockId0, filename, 0, 15)
+
+    val blockId1 = new ShuffleBlockId(1, 2, 4)
+    val newFile2 = diskBlockManager.getFile(filename)
+    writeToFile(newFile2, 12)
+    shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+    assertSegmentEquals(blockId1, filename, 15, 12)
+
+    assert(newFile === newFile2)
+    newFile.delete()
+  }
+
+  def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+    val segment = diskBlockManager.getBlockLocation(blockId)
+    assert(segment.file.getName === filename)
+    assert(segment.offset === offset)
+    assert(segment.length === length)
+  }
+
+  def writeToFile(file: File, numBytes: Int) {
+    val writer = new FileWriter(file, true)
+    for (i <- 0 until numBytes) writer.write(i)
+    writer.close()
+  }
+}