You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/20 01:56:29 UTC

git commit: Revert "[SPARK-2521] Broadcast RDD object (instead of sending it along with every task)."

Repository: spark
Updated Branches:
  refs/heads/master 2a732110d -> 1efb3698b


Revert "[SPARK-2521] Broadcast RDD object (instead of sending it along with every task)."

This reverts commit 7b8cd175254d42c8e82f0aa8eb4b7f3508d8fde2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1efb3698
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1efb3698
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1efb3698

Branch: refs/heads/master
Commit: 1efb3698b6cf39a80683b37124d2736ebf3c9d9a
Parents: 2a73211
Author: Reynold Xin <rx...@apache.org>
Authored: Sat Jul 19 16:56:22 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Sat Jul 19 16:56:22 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/Dependency.scala     |  28 ++--
 .../scala/org/apache/spark/SparkContext.scala   |   2 +
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  30 +----
 .../apache/spark/rdd/RDDCheckpointData.scala    |   9 +-
 .../apache/spark/scheduler/DAGScheduler.scala   |   4 +
 .../org/apache/spark/scheduler/ResultTask.scala | 128 ++++++++++++++-----
 .../apache/spark/scheduler/ShuffleMapTask.scala | 125 ++++++++++++++----
 .../org/apache/spark/ContextCleanerSuite.scala  |  62 ++++-----
 8 files changed, 251 insertions(+), 137 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/Dependency.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 3935c87..09a6057 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,9 +27,7 @@ import org.apache.spark.shuffle.ShuffleHandle
  * Base class for dependencies.
  */
 @DeveloperApi
-abstract class Dependency[T] extends Serializable {
-  def rdd: RDD[T]
-}
+abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
 
 
 /**
@@ -38,24 +36,20 @@ abstract class Dependency[T] extends Serializable {
  * partition of the child RDD.  Narrow dependencies allow for pipelined execution.
  */
 @DeveloperApi
-abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
+abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
   /**
    * Get the parent partitions for a child partition.
    * @param partitionId a partition of the child RDD
    * @return the partitions of the parent RDD that the child partition depends upon
    */
   def getParents(partitionId: Int): Seq[Int]
-
-  override def rdd: RDD[T] = _rdd
 }
 
 
 /**
  * :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
- * the RDD is transient since we don't need it on the executor side.
- *
- * @param _rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage.
+ * @param rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
  * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
  *                   the default serializer, as specified by `spark.serializer` config option, will
@@ -63,22 +57,20 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  */
 @DeveloperApi
 class ShuffleDependency[K, V, C](
-    @transient _rdd: RDD[_ <: Product2[K, V]],
+    @transient rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false)
-  extends Dependency[Product2[K, V]] {
-
-  override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]
+  extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
-  val shuffleId: Int = _rdd.context.newShuffleId()
+  val shuffleId: Int = rdd.context.newShuffleId()
 
-  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
-    shuffleId, _rdd.partitions.size, this)
+  val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+    shuffleId, rdd.partitions.size, this)
 
-  _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 48a0965..8052499 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,6 +997,8 @@ class SparkContext(config: SparkConf) extends Logging {
       // TODO: Cache.stop()?
       env.stop()
       SparkEnv.set(null)
+      ShuffleMapTask.clearCache()
+      ResultTask.clearCache()
       listenerBus.stop()
       eventLogger.foreach(_.stop())
       logInfo("Successfully stopped SparkContext")

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2ee9a8f..88a918a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,13 +35,12 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
@@ -1196,36 +1195,21 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD has been checkpointed or not
    */
-  def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
+  def isCheckpointed: Boolean = {
+    checkpointData.map(_.isCheckpointed).getOrElse(false)
+  }
 
   /**
    * Gets the name of the file to which this RDD was checkpointed
    */
-  def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
+  def getCheckpointFile: Option[String] = {
+    checkpointData.flatMap(_.getCheckpointFile)
+  }
 
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
 
-  /**
-   * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
-   * the serialized copy of the RDD and for each task we will deserialize it, which means each
-   * task gets a different copy of the RDD. This provides stronger isolation between tasks that
-   * might modify state of objects referenced in their closures. This is necessary in Hadoop
-   * where the JobConf/Configuration object is not thread-safe.
-   */
-  @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val bytes = ser.serialize(this).array()
-    val size = Utils.bytesToString(bytes.length)
-    if (bytes.length > (1L << 20)) {
-      logWarning(s"Broadcasting RDD $id ($size), which contains large objects")
-    } else {
-      logDebug(s"Broadcasting RDD $id ($size)")
-    }
-    sc.broadcast(bytes)
-  }
-
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index f67e5f1..c3b2a33 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,6 +106,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       cpRDD = Some(newRDD)
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
+      RDDCheckpointData.clearTaskCaches()
     }
     logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
   }
@@ -130,5 +131,9 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   }
 }
 
-// Used for synchronization
-private[spark] object RDDCheckpointData
+private[spark] object RDDCheckpointData {
+  def clearTaskCaches() {
+    ShuffleMapTask.clearCache()
+    ResultTask.clearCache()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 88cb5fe..ede3c7d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -376,6 +376,9 @@ class DAGScheduler(
               stageIdToStage -= stageId
               stageIdToJobIds -= stageId
 
+              ShuffleMapTask.removeStage(stageId)
+              ResultTask.removeStage(stageId)
+
               logDebug("After removal of stage %d, remaining stages = %d"
                 .format(stageId, stageIdToStage.size))
             }
@@ -720,6 +723,7 @@ class DAGScheduler(
     }
   }
 
+
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 62beb0d..bbf9f73 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,68 +17,134 @@
 
 package org.apache.spark.scheduler
 
-import java.nio.ByteBuffer
+import scala.language.existentials
 
 import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+
+private[spark] object ResultTask {
+
+  // A simple map from the stage id to the serialized byte array of a task.
+  // Serves as a cache for task serialization because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
+  {
+    synchronized {
+      val old = serializedInfoCache.get(stageId).orNull
+      if (old != null) {
+        old
+      } else {
+        val out = new ByteArrayOutputStream
+        val ser = SparkEnv.get.closureSerializer.newInstance()
+        val objOut = ser.serializeStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(func)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
+  {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val objIn = ser.deserializeStream(in)
+    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+    (rdd, func)
+  }
+
+  def removeStage(stageId: Int) {
+    serializedInfoCache.remove(stageId)
+  }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+    }
+  }
+}
+
 
 /**
  * A task that sends back the output to the driver application.
  *
- * See [[Task]] for more information.
+ * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
+ * @param rdd input to func
  * @param func a function to apply on a partition of the RDD
- * @param partition partition of the RDD this task is associated with
+ * @param _partitionId index of the partition in the RDD
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    val rddBinary: Broadcast[Array[Byte]],
-    val func: (TaskContext, Iterator[T]) => U,
-    val partition: Partition,
+    var rdd: RDD[T],
+    var func: (TaskContext, Iterator[T]) => U,
+    _partitionId: Int,
     @transient locs: Seq[TaskLocation],
-    val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
-
-  // TODO: Should we also broadcast func? For that we would need a place to
-  // keep a reference to it (perhaps in DAGScheduler's job object).
-
-  def this(
-      stageId: Int,
-      rdd: RDD[T],
-      func: (TaskContext, Iterator[T]) => U,
-      partitionId: Int,
-      locs: Seq[TaskLocation],
-      outputId: Int) = {
-    this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
-  }
+    var outputId: Int)
+  extends Task[U](stageId, _partitionId) with Externalizable {
+
+  def this() = this(0, null, null, 0, null, 0)
+
+  var split = if (rdd == null) null else rdd.partitions(partitionId)
 
-  @transient private[this] val preferredLocs: Seq[TaskLocation] = {
+  @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
   override def runTask(context: TaskContext): U = {
-    // Deserialize the RDD using the broadcast variable.
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
-      Thread.currentThread.getContextClassLoader)
     metrics = Some(context.taskMetrics)
     try {
-      func(context, rdd.iterator(partition, context))
+      func(context, rdd.iterator(split, context))
     } finally {
       context.executeOnCompleteCallbacks()
     }
   }
 
-  // This is only callable on the driver side.
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
   override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
+
+  override def writeExternal(out: ObjectOutput) {
+    RDDCheckpointData.synchronized {
+      split = rdd.partitions(partitionId)
+      out.writeInt(stageId)
+      val bytes = ResultTask.serializeInfo(
+        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partitionId)
+      out.writeInt(outputId)
+      out.writeLong(epoch)
+      out.writeObject(split)
+    }
+  }
+
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_.asInstanceOf[RDD[T]]
+    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+    partitionId = in.readInt()
+    outputId = in.readInt()
+    epoch = in.readLong()
+    split = in.readObject().asInstanceOf[Partition]
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 033c6e5..fdaf1de 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,13 +17,71 @@
 
 package org.apache.spark.scheduler
 
-import java.nio.ByteBuffer
+import scala.language.existentials
+
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
 import org.apache.spark.shuffle.ShuffleWriter
 
+private[spark] object ShuffleMapTask {
+
+  // A simple map from the stage id to the serialized byte array of a task.
+  // Serves as a cache for task serialization because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
+    synchronized {
+      val old = serializedInfoCache.get(stageId).orNull
+      if (old != null) {
+        return old
+      } else {
+        val out = new ByteArrayOutputStream
+        val ser = SparkEnv.get.closureSerializer.newInstance()
+        val objOut = ser.serializeStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(dep)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val objIn = ser.deserializeStream(in)
+    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
+    (rdd, dep)
+  }
+
+  // Since both the JarSet and FileSet have the same format this is used for both.
+  def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val objIn = new ObjectInputStream(in)
+    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+    HashMap(set.toSeq: _*)
+  }
+
+  def removeStage(stageId: Int) {
+    serializedInfoCache.remove(stageId)
+  }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+    }
+  }
+}
+
 /**
  * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
  * specified in the ShuffleDependency).
@@ -31,47 +89,62 @@ import org.apache.spark.shuffle.ShuffleWriter
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
+ * @param rdd the final RDD in this stage
  * @param dep the ShuffleDependency
- * @param partition partition of the RDD this task is associated with
+ * @param _partitionId index of the partition in the RDD
  * @param locs preferred task execution locations for locality scheduling
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    var rddBinary: Broadcast[Array[Byte]],
+    var rdd: RDD[_],
     var dep: ShuffleDependency[_, _, _],
-    partition: Partition,
+    _partitionId: Int,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
-
-  // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
-  // keep a reference to it (perhaps in Stage).
-
-  def this(
-      stageId: Int,
-      rdd: RDD[_],
-      dep: ShuffleDependency[_, _, _],
-      partitionId: Int,
-      locs: Seq[TaskLocation]) = {
-    this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs)
-  }
+  extends Task[MapStatus](stageId, _partitionId)
+  with Externalizable
+  with Logging {
+
+  protected def this() = this(0, null, null, 0, null)
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
-  override def runTask(context: TaskContext): MapStatus = {
-    // Deserialize the RDD using the broadcast variable.
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value),
-      Thread.currentThread.getContextClassLoader)
+  var split = if (rdd == null) null else rdd.partitions(partitionId)
+
+  override def writeExternal(out: ObjectOutput) {
+    RDDCheckpointData.synchronized {
+      split = rdd.partitions(partitionId)
+      out.writeInt(stageId)
+      val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partitionId)
+      out.writeLong(epoch)
+      out.writeObject(split)
+    }
+  }
 
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_
+    dep = dep_
+    partitionId = in.readInt()
+    epoch = in.readLong()
+    split = in.readObject().asInstanceOf[Partition]
+  }
+
+  override def runTask(context: TaskContext): MapStatus = {
     metrics = Some(context.taskMetrics)
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
-      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       return writer.stop(success = true).get
     } catch {
       case e: Exception =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1efb3698/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 871f831..13b415c 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -52,8 +52,9 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     }
   }
 
+
   test("cleanup RDD") {
-    val rdd = newRDD().persist()
+    val rdd = newRDD.persist()
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
 
@@ -66,7 +67,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup shuffle") {
-    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
 
@@ -79,7 +80,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup broadcast") {
-    val broadcast = newBroadcast()
+    val broadcast = newBroadcast
     val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
 
     // Explicit cleanup
@@ -88,7 +89,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup RDD") {
-    var rdd = newRDD().persist()
+    var rdd = newRDD.persist()
     rdd.count()
 
     // Test that GC does not cause RDD cleanup due to a strong reference
@@ -106,7 +107,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup shuffle") {
-    var rdd = newShuffleRDD()
+    var rdd = newShuffleRDD
     rdd.count()
 
     // Test that GC does not cause shuffle cleanup due to a strong reference
@@ -124,7 +125,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup broadcast") {
-    var broadcast = newBroadcast()
+    var broadcast = newBroadcast
 
     // Test that GC does not cause broadcast cleanup due to a strong reference
     val preGCTester =  new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
@@ -140,23 +141,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     postGCTester.assertCleanup()
   }
 
-  test("automatically cleanup broadcast data for task dispatching") {
-    var rdd = newRDDWithShuffleDependencies()._1
-    rdd.count()  // This triggers an action that broadcasts the RDDs.
-
-    // Test that GC causes broadcast task data cleanup after dereferencing the RDD.
-    val postGCTester = new CleanerTester(sc,
-      broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id))
-    rdd = null
-    runGC()
-    postGCTester.assertCleanup()
-  }
-
   test("automatically cleanup RDD + shuffle + broadcast") {
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = 0L until numBroadcasts
@@ -186,8 +175,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
     val numRdds = 10
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = 0L until numBroadcasts
@@ -208,18 +197,17 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
   //------ Helper functions ------
 
-  private def newRDD() = sc.makeRDD(1 to 10)
-  private def newPairRDD() = newRDD().map(_ -> 1)
-  private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
-  private def newBroadcast() = sc.broadcast(1 to 100)
-
-  private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+  def newRDD = sc.makeRDD(1 to 10)
+  def newPairRDD = newRDD.map(_ -> 1)
+  def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
+  def newBroadcast = sc.broadcast(1 to 100)
+  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
     def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
       rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
         getAllDependencies(dep.rdd)
       }
     }
-    val rdd = newShuffleRDD()
+    val rdd = newShuffleRDD
 
     // Get all the shuffle dependencies
     val shuffleDeps = getAllDependencies(rdd)
@@ -228,34 +216,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     (rdd, shuffleDeps)
   }
 
-  private def randomRdd() = {
+  def randomRdd = {
     val rdd: RDD[_] = Random.nextInt(3) match {
-      case 0 => newRDD()
-      case 1 => newShuffleRDD()
-      case 2 => newPairRDD.join(newPairRDD())
+      case 0 => newRDD
+      case 1 => newShuffleRDD
+      case 2 => newPairRDD.join(newPairRDD)
     }
     if (Random.nextBoolean()) rdd.persist()
     rdd.count()
     rdd
   }
 
-  private def randomBroadcast() = {
+  def randomBroadcast = {
     sc.broadcast(Random.nextInt(Int.MaxValue))
   }
 
   /** Run GC and make sure it actually has run */
-  private def runGC() {
+  def runGC() {
     val weakRef = new WeakReference(new Object())
     val startTime = System.currentTimeMillis
     System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
     // Wait until a weak reference object has been GCed
-    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+    while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
       System.gc()
       Thread.sleep(200)
     }
   }
 
-  private def cleaner = sc.cleaner.get
+  def cleaner = sc.cleaner.get
 }