You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/06/12 05:45:31 UTC

git commit: [SPARK-2044] Pluggable interface for shuffles

Repository: spark
Updated Branches:
  refs/heads/master d9203350b -> 508fd371d


[SPARK-2044] Pluggable interface for shuffles

This is a first cut at moving shuffle logic behind a pluggable interface, as described at https://issues.apache.org/jira/browse/SPARK-2044, to let us more easily experiment with new shuffle implementations. It moves the existing shuffle code to a class HashShuffleManager behind a general ShuffleManager interface.

Two things are still missing to make this complete:
* MapOutputTracker needs to be hidden behind the ShuffleManager interface; this will also require adding methods to ShuffleManager that will let the DAGScheduler interact with it as it does with the MapOutputTracker today
* The code to do map-side and reduce-side combine in ShuffledRDD, PairRDDFunctions, etc. needs to be moved into the ShuffleManager's readers and writers

However, some of these may also be done later after we merge the current interface.

Author: Matei Zaharia <ma...@databricks.com>

Closes #1009 from mateiz/pluggable-shuffle and squashes the following commits:

7a09862 [Matei Zaharia] review comments
be33d3f [Matei Zaharia] review comments
1513d4e [Matei Zaharia] Add ASF header
ac56831 [Matei Zaharia] Bug fix and better error message
4f681ba [Matei Zaharia] Move write part of ShuffleMapTask to ShuffleManager
f6f011d [Matei Zaharia] Move hash shuffle reader behind ShuffleManager interface
55c7717 [Matei Zaharia] Changed RDD code to use ShuffleReader
75cc044 [Matei Zaharia] Partial work to move hash shuffle in


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/508fd371
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/508fd371
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/508fd371

Branch: refs/heads/master
Commit: 508fd371d6dbb826fd8a00787d347235b549e189
Parents: d920335
Author: Matei Zaharia <ma...@databricks.com>
Authored: Wed Jun 11 20:45:29 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Wed Jun 11 20:45:29 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/BlockStoreShuffleFetcher.scala |  92 ---------------
 .../scala/org/apache/spark/ContextCleaner.scala |   2 +-
 .../scala/org/apache/spark/Dependency.scala     |  12 +-
 .../scala/org/apache/spark/ShuffleFetcher.scala |  36 ------
 .../main/scala/org/apache/spark/SparkEnv.scala  |  28 +++--
 .../org/apache/spark/rdd/CoGroupedRDD.scala     |  22 ++--
 .../org/apache/spark/rdd/ShuffledRDD.scala      |  12 +-
 .../org/apache/spark/rdd/SubtractedRDD.scala    |  17 +--
 .../apache/spark/scheduler/DAGScheduler.scala   |  12 +-
 .../apache/spark/scheduler/ShuffleMapTask.scala |  73 +++---------
 .../org/apache/spark/scheduler/Stage.scala      |   2 +-
 .../spark/scheduler/TaskResultGetter.scala      |   3 +-
 .../apache/spark/serializer/Serializer.scala    |   4 +
 .../spark/shuffle/BaseShuffleHandle.scala       |  30 +++++
 .../apache/spark/shuffle/ShuffleHandle.scala    |  25 +++++
 .../apache/spark/shuffle/ShuffleManager.scala   |  57 ++++++++++
 .../apache/spark/shuffle/ShuffleReader.scala    |  29 +++++
 .../apache/spark/shuffle/ShuffleWriter.scala    |  31 ++++++
 .../shuffle/hash/BlockStoreShuffleFetcher.scala |  91 +++++++++++++++
 .../spark/shuffle/hash/HashShuffleManager.scala |  60 ++++++++++
 .../spark/shuffle/hash/HashShuffleReader.scala  |  42 +++++++
 .../spark/shuffle/hash/HashShuffleWriter.scala  | 111 +++++++++++++++++++
 .../org/apache/spark/ContextCleanerSuite.scala  |   6 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   |   6 +-
 24 files changed, 566 insertions(+), 237 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
deleted file mode 100644
index a673924..0000000
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.executor.ShuffleReadMetrics
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
-
-private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
-
-  override def fetch[T](
-      shuffleId: Int,
-      reduceId: Int,
-      context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
-  {
-
-    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
-
-    val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
-    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
-      shuffleId, reduceId, System.currentTimeMillis - startTime))
-
-    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
-    for (((address, size), index) <- statuses.zipWithIndex) {
-      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
-    }
-
-    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
-      case (address, splits) =>
-        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
-    }
-
-    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
-      val blockId = blockPair._1
-      val blockOption = blockPair._2
-      blockOption match {
-        case Some(block) => {
-          block.asInstanceOf[Iterator[T]]
-        }
-        case None => {
-          blockId match {
-            case ShuffleBlockId(shufId, mapId, _) =>
-              val address = statuses(mapId.toInt)._1
-              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
-            case _ =>
-              throw new SparkException(
-                "Failed to get block " + blockId + ", which is not a shuffle block")
-          }
-        }
-      }
-    }
-
-    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
-    val itr = blockFetcherItr.flatMap(unpackBlock)
-
-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      val shuffleMetrics = new ShuffleReadMetrics
-      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
-      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
-      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
-      shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
-      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
-      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
-    })
-
-    new InterruptibleIterator[T](context, completionIter)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/ContextCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index e2d2250..bf3c3a6 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /** Register a ShuffleDependency for cleanup when it is garbage collected. */
-  def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
+  def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
     registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/Dependency.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 2c31cc2..c8c194a 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
 
 /**
  * :: DeveloperApi ::
@@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
  * Represents a dependency on the output of a shuffle stage.
  * @param rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
- * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
  *                   the default serializer, as specified by `spark.serializer` config option, will
  *                   be used.
  */
 @DeveloperApi
-class ShuffleDependency[K, V](
+class ShuffleDependency[K, V, C](
     @transient rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
-    val serializer: Serializer = null)
+    val serializer: Option[Serializer] = None,
+    val keyOrdering: Option[Ordering[K]] = None,
+    val aggregator: Option[Aggregator[K, V, C]] = None)
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
   val shuffleId: Int = rdd.context.newShuffleId()
 
+  val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+    shuffleId, rdd.partitions.size, this)
+
   rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
deleted file mode 100644
index a4f69b6..0000000
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.apache.spark.serializer.Serializer
-
-private[spark] abstract class ShuffleFetcher {
-
-  /**
-   * Fetch the shuffle outputs for a given ShuffleDependency.
-   * @return An iterator over the elements of the fetched shuffle outputs.
-   */
-  def fetch[T](
-      shuffleId: Int,
-      reduceId: Int,
-      context: TaskContext,
-      serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
-
-  /** Stop the fetcher */
-  def stop() {}
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 720151a..8dfa8cc 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.network.ConnectionManager
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage._
 import org.apache.spark.util.{AkkaUtils, Utils}
 
@@ -56,7 +57,7 @@ class SparkEnv (
     val closureSerializer: Serializer,
     val cacheManager: CacheManager,
     val mapOutputTracker: MapOutputTracker,
-    val shuffleFetcher: ShuffleFetcher,
+    val shuffleManager: ShuffleManager,
     val broadcastManager: BroadcastManager,
     val blockManager: BlockManager,
     val connectionManager: ConnectionManager,
@@ -80,7 +81,7 @@ class SparkEnv (
     pythonWorkers.foreach { case(key, worker) => worker.stop() }
     httpFileServer.stop()
     mapOutputTracker.stop()
-    shuffleFetcher.stop()
+    shuffleManager.stop()
     broadcastManager.stop()
     blockManager.stop()
     blockManager.master.stop()
@@ -163,13 +164,20 @@ object SparkEnv extends Logging {
     def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
       val name = conf.get(propertyName,  defaultClassName)
       val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
-      // First try with the constructor that takes SparkConf. If we can't find one,
-      // use a no-arg constructor instead.
+      // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
+      // SparkConf, then one taking no arguments
       try {
-        cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+        cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
+          .newInstance(conf, new java.lang.Boolean(isDriver))
+          .asInstanceOf[T]
       } catch {
         case _: NoSuchMethodException =>
-            cls.getConstructor().newInstance().asInstanceOf[T]
+          try {
+            cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+          } catch {
+            case _: NoSuchMethodException =>
+              cls.getConstructor().newInstance().asInstanceOf[T]
+          }
       }
     }
 
@@ -219,9 +227,6 @@ object SparkEnv extends Logging {
 
     val cacheManager = new CacheManager(blockManager)
 
-    val shuffleFetcher = instantiateClass[ShuffleFetcher](
-      "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
-
     val httpFileServer = new HttpFileServer(securityManager)
     httpFileServer.initialize()
     conf.set("spark.fileserver.uri",  httpFileServer.serverUri)
@@ -242,6 +247,9 @@ object SparkEnv extends Logging {
       "."
     }
 
+    val shuffleManager = instantiateClass[ShuffleManager](
+      "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+
     // Warn about deprecated spark.cache.class property
     if (conf.contains("spark.cache.class")) {
       logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -255,7 +263,7 @@ object SparkEnv extends Logging {
       closureSerializer,
       cacheManager,
       mapOutputTracker,
-      shuffleFetcher,
+      shuffleManager,
       broadcastManager,
       blockManager,
       connectionManager,

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 9ff7689..5951865 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
 
 private[spark] sealed trait CoGroupSplitDep extends Serializable
 
@@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
   }
 }
 
-private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
 
 private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
   extends Partition with Serializable {
@@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
   private type CoGroupValue = (Any, Int)  // Int is dependency number
   private type CoGroupCombiner = Seq[CoGroup]
 
-  private var serializer: Serializer = null
+  private var serializer: Option[Serializer] = None
 
+  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
   def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
-    this.serializer = serializer
+    this.serializer = Option(serializer)
     this
   }
 
@@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         new OneToOneDependency(rdd)
       } else {
         logDebug("Adding shuffle dependency with " + rdd)
-        new ShuffleDependency[Any, Any](rdd, part, serializer)
+        new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
       }
     }
   }
@@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
         // Assume each RDD contributed a single dependency, and get it
         dependencies(j) match {
-          case s: ShuffleDependency[_, _] =>
-            new ShuffleCoGroupSplitDep(s.shuffleId)
+          case s: ShuffleDependency[_, _, _] =>
+            new ShuffleCoGroupSplitDep(s.shuffleHandle)
           case _ =>
             new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
         }
@@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
         rddIterators += ((it, depNum))
 
-      case ShuffleCoGroupSplitDep(shuffleId) =>
+      case ShuffleCoGroupSplitDep(handle) =>
         // Read map outputs of shuffle
-        val fetcher = SparkEnv.get.shuffleFetcher
-        val ser = Serializer.getSerializer(serializer)
-        val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+        val it = SparkEnv.get.shuffleManager
+          .getReader(handle, split.index, split.index + 1, context)
+          .read()
         rddIterators += ((it, depNum))
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 802b0bd..bb108ef 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
     part: Partitioner)
   extends RDD[P](prev.context, Nil) {
 
-  private var serializer: Serializer = null
+  private var serializer: Option[Serializer] = None
 
+  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
   def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
-    this.serializer = serializer
+    this.serializer = Option(serializer)
     this
   }
 
@@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[P] = {
-    val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
-    val ser = Serializer.getSerializer(serializer)
-    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
+    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+      .read()
+      .asInstanceOf[Iterator[P]]
   }
 
   override def clearDependencies() {

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 9a09c05..ed24ea2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
     part: Partitioner)
   extends RDD[(K, V)](rdd1.context, Nil) {
 
-  private var serializer: Serializer = null
+  private var serializer: Option[Serializer] = None
 
+  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
   def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
-    this.serializer = serializer
+    this.serializer = Option(serializer)
     this
   }
 
@@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
       // Each CoGroupPartition will depend on rdd1 and rdd2
       array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
         dependencies(j) match {
-          case s: ShuffleDependency[_, _] =>
-            new ShuffleCoGroupSplitDep(s.shuffleId)
+          case s: ShuffleDependency[_, _, _] =>
+            new ShuffleCoGroupSplitDep(s.shuffleHandle)
           case _ =>
             new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
         }
@@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
 
   override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
     val partition = p.asInstanceOf[CoGroupPartition]
-    val ser = Serializer.getSerializer(serializer)
     val map = new JHashMap[K, ArrayBuffer[V]]
     def getSeq(k: K): ArrayBuffer[V] = {
       val seq = map.get(k)
@@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
       case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
         rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
 
-      case ShuffleCoGroupSplitDep(shuffleId) =>
-        val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
-          context, ser)
+      case ShuffleCoGroupSplitDep(handle) =>
+        val iter = SparkEnv.get.shuffleManager
+          .getReader(handle, partition.index, partition.index + 1, context)
+          .read()
         iter.foreach(op)
     }
     // the first dep is rdd1; add all values to the map

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e09a422..3c85b5a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -190,7 +190,7 @@ class DAGScheduler(
    * The jobId value passed in will be used if the stage doesn't already exist with
    * a lower jobId (jobId always increases across jobs.)
    */
-  private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
+  private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
@@ -210,7 +210,7 @@ class DAGScheduler(
   private def newStage(
       rdd: RDD[_],
       numTasks: Int,
-      shuffleDep: Option[ShuffleDependency[_,_]],
+      shuffleDep: Option[ShuffleDependency[_, _, _]],
       jobId: Int,
       callSite: Option[String] = None)
     : Stage =
@@ -233,7 +233,7 @@ class DAGScheduler(
   private def newOrUsedStage(
       rdd: RDD[_],
       numTasks: Int,
-      shuffleDep: ShuffleDependency[_,_],
+      shuffleDep: ShuffleDependency[_, _, _],
       jobId: Int,
       callSite: Option[String] = None)
     : Stage =
@@ -269,7 +269,7 @@ class DAGScheduler(
         // we can't do it in its constructor because # of partitions is unknown
         for (dep <- r.dependencies) {
           dep match {
-            case shufDep: ShuffleDependency[_,_] =>
+            case shufDep: ShuffleDependency[_, _, _] =>
               parents += getShuffleMapStage(shufDep, jobId)
             case _ =>
               visit(dep.rdd)
@@ -290,7 +290,7 @@ class DAGScheduler(
         if (getCacheLocs(rdd).contains(Nil)) {
           for (dep <- rdd.dependencies) {
             dep match {
-              case shufDep: ShuffleDependency[_,_] =>
+              case shufDep: ShuffleDependency[_, _, _] =>
                 val mapStage = getShuffleMapStage(shufDep, stage.jobId)
                 if (!mapStage.isAvailable) {
                   missing += mapStage
@@ -1088,7 +1088,7 @@ class DAGScheduler(
         visitedRdds += rdd
         for (dep <- rdd.dependencies) {
           dep match {
-            case shufDep: ShuffleDependency[_,_] =>
+            case shufDep: ShuffleDependency[_, _, _] =>
               val mapStage = getShuffleMapStage(shufDep, stage.jobId)
               if (!mapStage.isAvailable) {
                 visitedStages += mapStage

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index ed0f56f..0098b5a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.rdd.{RDD, RDDCheckpointData}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage._
+import org.apache.spark.shuffle.ShuffleWriter
 
 private[spark] object ShuffleMapTask {
 
@@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask {
   // expensive on the master node if it needs to launch thousands of tasks.
   private val serializedInfoCache = new HashMap[Int, Array[Byte]]
 
-  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
+  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
     synchronized {
       val old = serializedInfoCache.get(stageId).orNull
       if (old != null) {
@@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask {
     }
   }
 
-  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
     val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
     val ser = SparkEnv.get.closureSerializer.newInstance()
     val objIn = ser.deserializeStream(in)
     val rdd = objIn.readObject().asInstanceOf[RDD[_]]
-    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
+    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
     (rdd, dep)
   }
 
@@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask {
 private[spark] class ShuffleMapTask(
     stageId: Int,
     var rdd: RDD[_],
-    var dep: ShuffleDependency[_,_],
+    var dep: ShuffleDependency[_, _, _],
     _partitionId: Int,
     @transient private var locs: Seq[TaskLocation])
   extends Task[MapStatus](stageId, _partitionId)
@@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask(
   }
 
   override def runTask(context: TaskContext): MapStatus = {
-    val numOutputSplits = dep.partitioner.numPartitions
     metrics = Some(context.taskMetrics)
-
-    val blockManager = SparkEnv.get.blockManager
-    val shuffleBlockManager = blockManager.shuffleBlockManager
-    var shuffle: ShuffleWriterGroup = null
-    var success = false
-
+    var writer: ShuffleWriter[Any, Any] = null
     try {
-      // Obtain all the block writers for shuffle blocks.
-      val ser = Serializer.getSerializer(dep.serializer)
-      shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
-
-      // Write the map output to its associated buckets.
+      val manager = SparkEnv.get.shuffleManager
+      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
       for (elem <- rdd.iterator(split, context)) {
-        val pair = elem.asInstanceOf[Product2[Any, Any]]
-        val bucketId = dep.partitioner.getPartition(pair._1)
-        shuffle.writers(bucketId).write(pair)
-      }
-
-      // Commit the writes. Get the size of each bucket block (total block size).
-      var totalBytes = 0L
-      var totalTime = 0L
-      val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
-        writer.commit()
-        writer.close()
-        val size = writer.fileSegment().length
-        totalBytes += size
-        totalTime += writer.timeWriting()
-        MapOutputTracker.compressSize(size)
+        writer.write(elem.asInstanceOf[Product2[Any, Any]])
       }
-
-      // Update shuffle metrics.
-      val shuffleMetrics = new ShuffleWriteMetrics
-      shuffleMetrics.shuffleBytesWritten = totalBytes
-      shuffleMetrics.shuffleWriteTime = totalTime
-      metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
-
-      success = true
-      new MapStatus(blockManager.blockManagerId, compressedSizes)
-    } catch { case e: Exception =>
-      // If there is an exception from running the task, revert the partial writes
-      // and throw the exception upstream to Spark.
-      if (shuffle != null && shuffle.writers != null) {
-        for (writer <- shuffle.writers) {
-          writer.revertPartialWrites()
-          writer.close()
+      return writer.stop(success = true).get
+    } catch {
+      case e: Exception =>
+        if (writer != null) {
+          writer.stop(success = false)
         }
-      }
-      throw e
+        throw e
     } finally {
-      // Release the writers back to the shuffle block manager.
-      if (shuffle != null && shuffle.writers != null) {
-        try {
-          shuffle.releaseWriters(success)
-        } catch {
-          case e: Exception => logError("Failed to release shuffle writers", e)
-        }
-      }
-      // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 5c1fc30..3bf9713 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -40,7 +40,7 @@ private[spark] class Stage(
     val id: Int,
     val rdd: RDD[_],
     val numTasks: Int,
-    val shuffleDep: Option[ShuffleDependency[_,_]],  // Output shuffle if stage is a map stage
+    val shuffleDep: Option[ShuffleDependency[_, _, _]],  // Output shuffle if stage is a map stage
     val parents: List[Stage],
     val jobId: Int,
     callSite: Option[String])

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 99d305b..df59f44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
             val loader = Thread.currentThread.getContextClassLoader
             taskSetManager.abort("ClassNotFound with classloader: " + loader)
           case ex: Exception =>
-            taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+            logError("Exception while getting task result", ex)
+            taskSetManager.abort("Exception while getting task result: %s".format(ex))
         }
       }
     })

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ee26970..f2f5cea 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -52,6 +52,10 @@ object Serializer {
   def getSerializer(serializer: Serializer): Serializer = {
     if (serializer == null) SparkEnv.get.serializer else serializer
   }
+
+  def getSerializer(serializer: Option[Serializer]): Serializer = {
+    serializer.getOrElse(SparkEnv.get.serializer)
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
new file mode 100644
index 0000000..b36c457
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner}
+import org.apache.spark.serializer.Serializer
+
+/**
+ * A basic ShuffleHandle implementation that just captures registerShuffle's parameters.
+ */
+private[spark] class BaseShuffleHandle[K, V, C](
+    shuffleId: Int,
+    val numMaps: Int,
+    val dependency: ShuffleDependency[K, V, C])
+  extends ShuffleHandle(shuffleId)

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
new file mode 100644
index 0000000..13c7115
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks.
+ *
+ * @param shuffleId ID of the shuffle
+ */
+private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
new file mode 100644
index 0000000..9c859b8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.{TaskContext, ShuffleDependency}
+
+/**
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
+ * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * with it, and executors (or tasks running locally in the driver) can ask to read and write data.
+ *
+ * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
+ * boolean isDriver as parameters.
+ */
+private[spark] trait ShuffleManager {
+  /**
+   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+   */
+  def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C]
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  def unregisterShuffle(shuffleId: Int): Unit
+
+  /** Shut down this ShuffleManager. */
+  def stop(): Unit
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
new file mode 100644
index 0000000..b30e366
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * Obtained inside a reduce task to read combined records from the mappers.
+ */
+private[spark] trait ShuffleReader[K, C] {
+  /** Read the combined key-values for this reduce task */
+  def read(): Iterator[Product2[K, C]]
+
+  /** Close this reader */
+  def stop(): Unit
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
new file mode 100644
index 0000000..ead3ebd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+
+/**
+ * Obtained inside a map task to write out records to the shuffle system.
+ */
+private[spark] trait ShuffleWriter[K, V] {
+  /** Write a record to this task's output */
+  def write(record: Product2[K, V]): Unit
+
+  /** Close this writer, passing along whether the map completed */
+  def stop(success: Boolean): Option[MapStatus]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
new file mode 100644
index 0000000..b05b6ea
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.executor.ShuffleReadMetrics
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark._
+
+private[hash] object BlockStoreShuffleFetcher extends Logging {
+  /**
+   * Fetch all map outputs for one reduce partition of a shuffle as a single iterator.
+   * A shuffle block whose fetch result is None raises a FetchFailedException.
+   */
+  def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
+      serializer: Serializer)
+    : Iterator[T] =
+  {
+    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val blockManager = SparkEnv.get.blockManager
+
+    // Look up the (address, size) status of every map output for this reduce, timing the call.
+    val lookupStart = System.currentTimeMillis
+    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
+      shuffleId, reduceId, System.currentTimeMillis - lookupStart))
+
+    // Group (map index, size) pairs by the block manager address that reported them.
+    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
+    statuses.zipWithIndex.foreach { case ((address, size), mapIndex) =>
+      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((mapIndex, size))
+    }
+
+    // Name each grouped split with the ShuffleBlockId to request from that address.
+    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
+      case (address, splits) =>
+        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
+    }
+
+    // Turn one fetch result into its records, or fail if the block could not be fetched.
+    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])): Iterator[T] = blockPair match {
+      case (_, Some(block)) =>
+        block.asInstanceOf[Iterator[T]]
+      case (ShuffleBlockId(shufId, mapId, _), None) =>
+        val address = statuses(mapId.toInt)._1
+        throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
+      case (blockId, None) =>
+        throw new SparkException(
+          "Failed to get block " + blockId + ", which is not a shuffle block")
+    }
+
+    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+    val records = blockFetcherItr.flatMap(unpackBlock)
+
+    // Block handed to CompletionIterator: copies the fetcher's counters into the task metrics.
+    val completionIter = CompletionIterator[T, Iterator[T]](records, {
+      val readMetrics = new ShuffleReadMetrics
+      readMetrics.shuffleFinishTime = System.currentTimeMillis
+      readMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
+      readMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
+      readMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
+      readMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
+      readMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
+      context.taskMetrics.shuffleReadMetrics = Some(readMetrics)
+    })
+
+    new InterruptibleIterator[T](context, completionIter)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
new file mode 100644
index 0000000..5b0940e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark._
+import org.apache.spark.shuffle._
+
+/**
+ * A ShuffleManager using hashing, that creates one output file per reduce partition on each
+ * mapper (possibly reusing these across waves of tasks).
+ */
+class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+  /** Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    new BaseShuffleHandle(shuffleId, numMaps, dependency)
+  }
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C] = {
+    new HashShuffleReader(
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+  }
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+      : ShuffleWriter[K, V] = {
+    new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+  }
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {}
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
new file mode 100644
index 0000000..f6a7903
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.TaskContext
+
+/** ShuffleReader for the hash-based shuffle; reads a single reduce partition per instance. */
+class HashShuffleReader[K, C](
+    handle: BaseShuffleHandle[K, _, C],
+    startPartition: Int,
+    endPartition: Int,
+    context: TaskContext)
+  extends ShuffleReader[K, C] {
+  require(endPartition == startPartition + 1,
+    "Hash shuffle currently only supports fetching one partition")
+
+  /** Read the combined key-values for this reduce task */
+  override def read(): Iterator[Product2[K, C]] = {
+    BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
+      Serializer.getSerializer(handle.dependency.serializer))
+  }
+
+  /** Close this reader. A no-op: this class holds no resources of its own to release. */
+  override def stop(): Unit = {}
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
new file mode 100644
index 0000000..4c67490
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.storage.{BlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+
+/** ShuffleWriter for the hash shuffle: each record goes to the bucket its partitioner picks. */
+class HashShuffleWriter[K, V](
+    handle: BaseShuffleHandle[K, V, _],
+    mapId: Int,
+    context: TaskContext)
+  extends ShuffleWriter[K, V] with Logging {
+
+  private val dep = handle.dependency
+  private val numOutputSplits = dep.partitioner.numPartitions
+  private val metrics = context.taskMetrics
+  private var stopping = false
+  private val blockManager = SparkEnv.get.blockManager
+  private val shuffleBlockManager = blockManager.shuffleBlockManager
+  private val ser = Serializer.getSerializer(dep.serializer)
+  private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+
+  /** Write a record to this task's output */
+  override def write(record: Product2[K, V]): Unit = {
+    val pair = record.asInstanceOf[Product2[Any, Any]]
+    val bucketId = dep.partitioner.getPartition(pair._1)
+    shuffle.writers(bucketId).write(pair)
+  }
+
+  /** Close this writer, passing along whether the map completed. Later calls return None. */
+  override def stop(success: Boolean): Option[MapStatus] = {
+    var succeeded = success
+    try {
+      if (stopping) return None
+      stopping = true
+      if (success) {
+        try {
+          Some(commitWritesAndBuildStatus())
+        } catch {
+          case e: Exception =>
+            succeeded = false
+            revertWrites()
+            throw e
+        }
+      } else {
+        revertWrites()
+        None
+      }
+    } finally {
+      // Release the writers to the shuffle block manager, flagging a failed commit as failure.
+      if (shuffle != null && shuffle.writers != null) {
+        try {
+          shuffle.releaseWriters(succeeded)
+        } catch {
+          case e: Exception => logError("Failed to release shuffle writers", e)
+        }
+      }
+    }
+  }
+
+  private def commitWritesAndBuildStatus(): MapStatus = {
+    // Commit the writes. Get the size of each bucket block (total block size).
+    var totalBytes = 0L
+    var totalTime = 0L
+    val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
+      writer.commit()
+      writer.close()
+      val size = writer.fileSegment().length
+      totalBytes += size
+      totalTime += writer.timeWriting()
+      MapOutputTracker.compressSize(size)
+    }
+
+    // Update shuffle metrics.
+    val shuffleMetrics = new ShuffleWriteMetrics
+    shuffleMetrics.shuffleBytesWritten = totalBytes
+    shuffleMetrics.shuffleWriteTime = totalTime
+    metrics.shuffleWriteMetrics = Some(shuffleMetrics)
+
+    new MapStatus(blockManager.blockManagerId, compressedSizes)
+  }
+
+  private def revertWrites(): Unit = {
+    if (shuffle != null && shuffle.writers != null) {
+      for (writer <- shuffle.writers) {
+        writer.revertPartialWrites()
+        writer.close()
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index dc2db66..13b415c 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   def newPairRDD = newRDD.map(_ -> 1)
   def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
   def newBroadcast = sc.broadcast(1 to 100)
-  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
     def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
       rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
         getAllDependencies(dep.rdd)
@@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
     // Get all the shuffle dependencies
     val shuffleDeps = getAllDependencies(rdd)
-      .filter(_.isInstanceOf[ShuffleDependency[_, _]])
-      .map(_.asInstanceOf[ShuffleDependency[_, _]])
+      .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+      .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
     (rdd, shuffleDeps)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/508fd371/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 7b0607d..47112ce 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // default Java serializer cannot handle the non serializable class.
     val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
       b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
-    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
 
     assert(c.count === 10)
 
@@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
       .setSerializer(new KryoSerializer(conf))
 
-    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
     assert(c.count === 4)
 
     val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
@@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // NOTE: The default Java serializer should create zero-sized blocks
     val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
 
-    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
     assert(c.count === 4)
 
     val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>