You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/30 18:27:56 UTC

git commit: [SPARK-2521] Broadcast RDD object (instead of sending it along with every task)

Repository: spark
Updated Branches:
  refs/heads/master ee07541e9 -> 774142f55


[SPARK-2521] Broadcast RDD object (instead of sending it along with every task)

This is a resubmission of #1452. The original patch was reverted because it broke the build.

Currently (as of Spark 1.0.1), Spark sends the RDD object (which contains closures) using Akka along with the task itself to the executors. This is inefficient because all tasks in the same stage use the same RDD object, but we have to send the RDD object multiple times to the executors. This is especially bad when a closure references some variable that is very large. The current design led to users having to explicitly broadcast large variables.

The patch uses broadcast to send RDD objects and the closures to executors, and uses Akka to send only a reference to the broadcast RDD/closure along with the partition-specific information for the task. For those of you who know more about the internals, Spark already relies on broadcast to send the Hadoop JobConf every time it uses the Hadoop input, because the JobConf is large.

The user-facing impacts of the change include:

1. Users won't need to decide what to broadcast anymore, unless they want to use a large object multiple times in different operations.
2. Task size will get smaller, resulting in faster scheduling and higher task dispatch throughput.

In addition, the change will simplify some internals of Spark, eliminating the need to maintain task caches and the complex logic to broadcast JobConf (which also led to a deadlock recently).

A simple way to test this:
```scala
val a = new Array[Byte](1000*1000); scala.util.Random.nextBytes(a);
sc.parallelize(1 to 1000, 1000).map { x => a; x }.groupBy { x => a; x }.count
```

Numbers on 3 r3.8xlarge instances on EC2
```
master branch: 5.648436068 s, 4.715361895 s, 5.360161877 s
with this change: 3.416348793 s, 1.477846558 s, 1.553432156 s
```

Author: Reynold Xin <rx...@apache.org>

Closes #1498 from rxin/broadcast-task and squashes the following commits:

f7364db [Reynold Xin] Code review feedback.
f8535dc [Reynold Xin] Fixed the style violation.
252238d [Reynold Xin] Serialize the final task closure as well as ShuffleDependency in taskBinary.
111007d [Reynold Xin] Fix broadcast tests.
797c247 [Reynold Xin] Properly send SparkListenerStageSubmitted and SparkListenerStageCompleted.
bab1d8b [Reynold Xin] Check for NotSerializableException in submitMissingTasks.
cf38450 [Reynold Xin] Use TorrentBroadcastFactory.
991c002 [Reynold Xin] Use HttpBroadcast.
de779f8 [Reynold Xin] Fix TaskContextSuite.
cc152fc [Reynold Xin] Don't cache the RDD broadcast variable.
d256b45 [Reynold Xin] Fixed unit test failures. One more to go.
cae0af3 [Reynold Xin] [SPARK-2521] Broadcast RDD object (instead of sending it along with every task).


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/774142f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/774142f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/774142f5

Branch: refs/heads/master
Commit: 774142f5556ac37fddf03cfa46eb23ca1bde2492
Parents: ee07541
Author: Reynold Xin <rx...@apache.org>
Authored: Wed Jul 30 09:27:43 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Wed Jul 30 09:27:43 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/Dependency.scala     |  28 ++--
 .../scala/org/apache/spark/SparkContext.scala   |   2 -
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  11 +-
 .../apache/spark/rdd/RDDCheckpointData.scala    |   9 +-
 .../apache/spark/scheduler/DAGScheduler.scala   |  87 +++++++++----
 .../org/apache/spark/scheduler/ResultTask.scala | 118 +++--------------
 .../apache/spark/scheduler/ShuffleMapTask.scala | 129 ++++---------------
 .../scala/org/apache/spark/util/Utils.scala     |   2 +-
 .../org/apache/spark/ContextCleanerSuite.scala  |  71 ++++++----
 .../scala/org/apache/spark/rdd/RDDSuite.scala   |   8 +-
 .../spark/scheduler/TaskContextSuite.scala      |  24 ++--
 .../ui/jobs/JobProgressListenerSuite.scala      |  11 +-
 12 files changed, 198 insertions(+), 302 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/Dependency.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 09a6057..3935c87 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
  * Base class for dependencies.
  */
 @DeveloperApi
-abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+abstract class Dependency[T] extends Serializable {
+  def rdd: RDD[T]
+}
 
 
 /**
@@ -36,20 +38,24 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
  * partition of the child RDD.  Narrow dependencies allow for pipelined execution.
  */
 @DeveloperApi
-abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
   /**
    * Get the parent partitions for a child partition.
    * @param partitionId a partition of the child RDD
    * @return the partitions of the parent RDD that the child partition depends upon
    */
   def getParents(partitionId: Int): Seq[Int]
+
+  override def rdd: RDD[T] = _rdd
 }
 
 
 /**
  * :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage.
- * @param rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
+ * the RDD is transient since we don't need it on the executor side.
+ *
+ * @param _rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
  * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
  *                   the default serializer, as specified by `spark.serializer` config option, will
@@ -57,20 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
  */
 @DeveloperApi
 class ShuffleDependency[K, V, C](
-    @transient rdd: RDD[_ <: Product2[K, V]],
+    @transient _rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false)
-  extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
+  extends Dependency[Product2[K, V]] {
+
+  override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]
 
-  val shuffleId: Int = rdd.context.newShuffleId()
+  val shuffleId: Int = _rdd.context.newShuffleId()
 
-  val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
-    shuffleId, rdd.partitions.size, this)
+  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
+    shuffleId, _rdd.partitions.size, this)
 
-  rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3e6adde..fb4c867 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
       // TODO: Cache.stop()?
       env.stop()
       SparkEnv.set(null)
-      ShuffleMapTask.clearCache()
-      ResultTask.clearCache()
       listenerBus.stop()
       eventLogger.foreach(_.stop())
       logInfo("Successfully stopped SparkContext")

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a6abc49..726b3f2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
@@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD has been checkpointed or not
    */
-  def isCheckpointed: Boolean = {
-    checkpointData.map(_.isCheckpointed).getOrElse(false)
-  }
+  def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
 
   /**
    * Gets the name of the file to which this RDD was checkpointed
    */
-  def getCheckpointFile: Option[String] = {
-    checkpointData.flatMap(_.getCheckpointFile)
-  }
+  def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
 
   // =======================================================================
   // Other internal methods and fields

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index c3b2a33..f67e5f1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       cpRDD = Some(newRDD)
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
-      RDDCheckpointData.clearTaskCaches()
     }
     logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
   }
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   }
 }
 
-private[spark] object RDDCheckpointData {
-  def clearTaskCaches() {
-    ShuffleMapTask.clearCache()
-    ResultTask.clearCache()
-  }
-}
+// Used for synchronization
+private[spark] object RDDCheckpointData

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dc6142a..50186d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{NotSerializableException, PrintWriter, StringWriter}
+import java.io.NotSerializableException
 import java.util.Properties
 import java.util.concurrent.atomic.AtomicInteger
 
@@ -35,6 +35,7 @@ import akka.pattern.ask
 import akka.util.Timeout
 
 import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
@@ -114,6 +115,10 @@ class DAGScheduler(
   private val dagSchedulerActorSupervisor =
     env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
 
+  // A closure serializer that we reuse.
+  // This is only safe because DAGScheduler runs in a single thread.
+  private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+
   private[scheduler] var eventProcessActor: ActorRef = _
 
   private def initializeEventProcessActor() {
@@ -361,9 +366,6 @@ class DAGScheduler(
               // data structures based on StageId
               stageIdToStage -= stageId
 
-              ShuffleMapTask.removeStage(stageId)
-              ResultTask.removeStage(stageId)
-
               logDebug("After removal of stage %d, remaining stages = %d"
                 .format(stageId, stageIdToStage.size))
             }
@@ -691,49 +693,83 @@ class DAGScheduler(
     }
   }
 
-
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
     var tasks = ArrayBuffer[Task[_]]()
+
+    val properties = if (jobIdToActiveJob.contains(jobId)) {
+      jobIdToActiveJob(stage.jobId).properties
+    } else {
+      // this stage will be assigned to "default" pool
+      null
+    }
+
+    runningStages += stage
+    // SparkListenerStageSubmitted should be posted before testing whether tasks are
+    // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
+    // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
+    // event.
+    listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
+
+    // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
+    // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
+    // the serialized copy of the RDD and for each task we will deserialize it, which means each
+    // task gets a different copy of the RDD. This provides stronger isolation between tasks that
+    // might modify state of objects referenced in their closures. This is necessary in Hadoop
+    // where the JobConf/Configuration object is not thread-safe.
+    var taskBinary: Broadcast[Array[Byte]] = null
+    try {
+      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
+      // For ResultTask, serialize and broadcast (rdd, func).
+      val taskBinaryBytes: Array[Byte] =
+        if (stage.isShuffleMap) {
+          closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array()
+        } else {
+          closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array()
+        }
+      taskBinary = sc.broadcast(taskBinaryBytes)
+    } catch {
+      // In the case of a failure during serialization, abort the stage.
+      case e: NotSerializableException =>
+        abortStage(stage, "Task not serializable: " + e.toString)
+        runningStages -= stage
+        return
+      case NonFatal(e) =>
+        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+        runningStages -= stage
+        return
+    }
+
     if (stage.isShuffleMap) {
       for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
         val locs = getPreferredLocs(stage.rdd, p)
-        tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+        val part = stage.rdd.partitions(p)
+        tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
       }
     } else {
       // This is a final stage; figure out its job's missing partitions
       val job = stage.resultOfJob.get
       for (id <- 0 until job.numPartitions if !job.finished(id)) {
-        val partition = job.partitions(id)
-        val locs = getPreferredLocs(stage.rdd, partition)
-        tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+        val p: Int = job.partitions(id)
+        val part = stage.rdd.partitions(p)
+        val locs = getPreferredLocs(stage.rdd, p)
+        tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
       }
     }
 
-    val properties = if (jobIdToActiveJob.contains(jobId)) {
-      jobIdToActiveJob(stage.jobId).properties
-    } else {
-      // this stage will be assigned to "default" pool
-      null
-    }
-
     if (tasks.size > 0) {
-      runningStages += stage
-      // SparkListenerStageSubmitted should be posted before testing whether tasks are
-      // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
-      // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
-      // event.
-      listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
-
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
       // down the road, where we have several different implementations for local scheduler and
       // cluster schedulers.
+      //
+      // We've already serialized RDDs and closures in taskBinary, but here we check for all other
+      // objects such as Partition.
       try {
-        SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
+        closureSerializer.serialize(tasks.head)
       } catch {
         case e: NotSerializableException =>
           abortStage(stage, "Task not serializable: " + e.toString)
@@ -752,6 +788,9 @@ class DAGScheduler(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
       stage.info.submissionTime = Some(clock.getTime())
     } else {
+      // Because we posted SparkListenerStageSubmitted earlier, we should post
+      // SparkListenerStageCompleted here in case there are no tasks to run.
+      listenerBus.post(SparkListenerStageCompleted(stage.info))
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
       runningStages -= stage

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index bbf9f73..d09fd7a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,134 +17,56 @@
 
 package org.apache.spark.scheduler
 
-import scala.language.existentials
+import java.nio.ByteBuffer
 
 import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-
-import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
-
-private[spark] object ResultTask {
-
-  // A simple map between the stage id to the serialized byte array of a task.
-  // Served as a cache for task serialization because serialization can be
-  // expensive on the master node if it needs to launch thousands of tasks.
-  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
-  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
-  {
-    synchronized {
-      val old = serializedInfoCache.get(stageId).orNull
-      if (old != null) {
-        old
-      } else {
-        val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance()
-        val objOut = ser.serializeStream(new GZIPOutputStream(out))
-        objOut.writeObject(rdd)
-        objOut.writeObject(func)
-        objOut.close()
-        val bytes = out.toByteArray
-        serializedInfoCache.put(stageId, bytes)
-        bytes
-      }
-    }
-  }
-
-  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
-  {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objIn = ser.deserializeStream(in)
-    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
-    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
-    (rdd, func)
-  }
-
-  def removeStage(stageId: Int) {
-    serializedInfoCache.remove(stageId)
-  }
-
-  def clearCache() {
-    synchronized {
-      serializedInfoCache.clear()
-    }
-  }
-}
-
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
 
 /**
  * A task that sends back the output to the driver application.
  *
- * See [[org.apache.spark.scheduler.Task]] for more information.
+ * See [[Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rdd input to func
- * @param func a function to apply on a partition of the RDD
- * @param _partitionId index of the number in the RDD
+ * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
+ *                   partition of the given RDD. Once deserialized, the type should be
+ *                   (RDD[T], (TaskContext, Iterator[T]) => U).
+ * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    var rdd: RDD[T],
-    var func: (TaskContext, Iterator[T]) => U,
-    _partitionId: Int,
+    taskBinary: Broadcast[Array[Byte]],
+    partition: Partition,
     @transient locs: Seq[TaskLocation],
-    var outputId: Int)
-  extends Task[U](stageId, _partitionId) with Externalizable {
+    val outputId: Int)
+  extends Task[U](stageId, partition.index) with Serializable {
 
-  def this() = this(0, null, null, 0, null, 0)
-
-  var split = if (rdd == null) null else rdd.partitions(partitionId)
-
-  @transient private val preferredLocs: Seq[TaskLocation] = {
+  @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
   override def runTask(context: TaskContext): U = {
+    // Deserialize the RDD and the func using the broadcast variables.
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
+      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
+
     metrics = Some(context.taskMetrics)
     try {
-      func(context, rdd.iterator(split, context))
+      func(context, rdd.iterator(partition, context))
     } finally {
       context.executeOnCompleteCallbacks()
     }
   }
 
+  // This is only callable on the driver side.
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
   override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
-
-  override def writeExternal(out: ObjectOutput) {
-    RDDCheckpointData.synchronized {
-      split = rdd.partitions(partitionId)
-      out.writeInt(stageId)
-      val bytes = ResultTask.serializeInfo(
-        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
-      out.writeInt(bytes.length)
-      out.write(bytes)
-      out.writeInt(partitionId)
-      out.writeInt(outputId)
-      out.writeLong(epoch)
-      out.writeObject(split)
-    }
-  }
-
-  override def readExternal(in: ObjectInput) {
-    val stageId = in.readInt()
-    val numBytes = in.readInt()
-    val bytes = new Array[Byte](numBytes)
-    in.readFully(bytes)
-    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
-    rdd = rdd_.asInstanceOf[RDD[T]]
-    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
-    partitionId = in.readInt()
-    outputId = in.readInt()
-    epoch = in.readLong()
-    split = in.readObject().asInstanceOf[Partition]
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index fdaf1de..11255c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,134 +17,55 @@
 
 package org.apache.spark.scheduler
 
-import scala.language.existentials
-
-import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+import java.nio.ByteBuffer
 
-import scala.collection.mutable.HashMap
+import scala.language.existentials
 
 import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
 import org.apache.spark.shuffle.ShuffleWriter
 
-private[spark] object ShuffleMapTask {
-
-  // A simple map between the stage id to the serialized byte array of a task.
-  // Served as a cache for task serialization because serialization can be
-  // expensive on the master node if it needs to launch thousands of tasks.
-  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
-  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
-    synchronized {
-      val old = serializedInfoCache.get(stageId).orNull
-      if (old != null) {
-        return old
-      } else {
-        val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance()
-        val objOut = ser.serializeStream(new GZIPOutputStream(out))
-        objOut.writeObject(rdd)
-        objOut.writeObject(dep)
-        objOut.close()
-        val bytes = out.toByteArray
-        serializedInfoCache.put(stageId, bytes)
-        bytes
-      }
-    }
-  }
-
-  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objIn = ser.deserializeStream(in)
-    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
-    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
-    (rdd, dep)
-  }
-
-  // Since both the JarSet and FileSet have the same format this is used for both.
-  def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val objIn = new ObjectInputStream(in)
-    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
-    HashMap(set.toSeq: _*)
-  }
-
-  def removeStage(stageId: Int) {
-    serializedInfoCache.remove(stageId)
-  }
-
-  def clearCache() {
-    synchronized {
-      serializedInfoCache.clear()
-    }
-  }
-}
-
 /**
- * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
- * specified in the ShuffleDependency).
- *
- * See [[org.apache.spark.scheduler.Task]] for more information.
- *
+* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+* specified in the ShuffleDependency).
+*
+* See [[org.apache.spark.scheduler.Task]] for more information.
+*
  * @param stageId id of the stage this task belongs to
- * @param rdd the final RDD in this stage
- * @param dep the ShuffleDependency
- * @param _partitionId index of the number in the RDD
+ * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
+ *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
+ * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    var rdd: RDD[_],
-    var dep: ShuffleDependency[_, _, _],
-    _partitionId: Int,
+    taskBinary: Broadcast[Array[Byte]],
+    partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, _partitionId)
-  with Externalizable
-  with Logging {
+  extends Task[MapStatus](stageId, partition.index) with Logging {
 
-  protected def this() = this(0, null, null, 0, null)
+  /** A constructor used only in test suites. This does not require passing in an RDD. */
+  def this(partitionId: Int) {
+    this(0, null, new Partition { override def index = 0 }, null)
+  }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
-  var split = if (rdd == null) null else rdd.partitions(partitionId)
-
-  override def writeExternal(out: ObjectOutput) {
-    RDDCheckpointData.synchronized {
-      split = rdd.partitions(partitionId)
-      out.writeInt(stageId)
-      val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
-      out.writeInt(bytes.length)
-      out.write(bytes)
-      out.writeInt(partitionId)
-      out.writeLong(epoch)
-      out.writeObject(split)
-    }
-  }
-
-  override def readExternal(in: ObjectInput) {
-    val stageId = in.readInt()
-    val numBytes = in.readInt()
-    val bytes = new Array[Byte](numBytes)
-    in.readFully(bytes)
-    val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
-    rdd = rdd_
-    dep = dep_
-    partitionId = in.readInt()
-    epoch = in.readLong()
-    split = in.readObject().asInstanceOf[Partition]
-  }
-
   override def runTask(context: TaskContext): MapStatus = {
+    // Deserialize the RDD using the broadcast variable.
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
+      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
+
     metrics = Some(context.taskMetrics)
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
-      writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       return writer.stop(success = true).get
     } catch {
       case e: Exception =>

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 69f65b4..f8fbb3a 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
 import org.json4s._
 import tachyon.client.{TachyonFile,TachyonFS}
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
 import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 13b415c..ad20f9b 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark
 
 import java.lang.ref.WeakReference
 
+import org.apache.spark.broadcast.Broadcast
+
+import scala.collection.mutable
 import scala.collection.mutable.{HashSet, SynchronizedSet}
 import scala.language.existentials
 import scala.language.postfixOps
@@ -52,9 +55,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     }
   }
 
-
   test("cleanup RDD") {
-    val rdd = newRDD.persist()
+    val rdd = newRDD().persist()
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
 
@@ -67,7 +69,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup shuffle") {
-    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
 
@@ -80,7 +82,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup broadcast") {
-    val broadcast = newBroadcast
+    val broadcast = newBroadcast()
     val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
 
     // Explicit cleanup
@@ -89,7 +91,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup RDD") {
-    var rdd = newRDD.persist()
+    var rdd = newRDD().persist()
     rdd.count()
 
     // Test that GC does not cause RDD cleanup due to a strong reference
@@ -107,7 +109,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup shuffle") {
-    var rdd = newShuffleRDD
+    var rdd = newShuffleRDD()
     rdd.count()
 
     // Test that GC does not cause shuffle cleanup due to a strong reference
@@ -125,7 +127,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup broadcast") {
-    var broadcast = newBroadcast
+    var broadcast = newBroadcast()
 
     // Test that GC does not cause broadcast cleanup due to a strong reference
     val preGCTester =  new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
@@ -144,11 +146,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   test("automatically cleanup RDD + shuffle + broadcast") {
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
-    val broadcastIds = 0L until numBroadcasts
+    val broadcastIds = broadcastBuffer.map(_.id)
 
     val preGCTester =  new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
     runGC()
@@ -162,6 +164,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     rddBuffer.clear()
     runGC()
     postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC.
+    val taskClosureBroadcastId = broadcastIds.max + 1
+    assert(sc.env.blockManager.master.getMatchingBlockIds({
+      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true).isEmpty)
   }
 
   test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
@@ -175,11 +184,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
     val numRdds = 10
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
-    val broadcastIds = 0L until numBroadcasts
+    val broadcastIds = broadcastBuffer.map(_.id)
 
     val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
     runGC()
@@ -193,21 +202,29 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     rddBuffer.clear()
     runGC()
     postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC.
+    val taskClosureBroadcastId = broadcastIds.max + 1
+    assert(sc.env.blockManager.master.getMatchingBlockIds({
+      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true).isEmpty)
   }
 
   //------ Helper functions ------
 
-  def newRDD = sc.makeRDD(1 to 10)
-  def newPairRDD = newRDD.map(_ -> 1)
-  def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
-  def newBroadcast = sc.broadcast(1 to 100)
-  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+  private def newRDD() = sc.makeRDD(1 to 10)
+  private def newPairRDD() = newRDD().map(_ -> 1)
+  private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+  private def newBroadcast() = sc.broadcast(1 to 100)
+
+  private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
     def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
       rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
         getAllDependencies(dep.rdd)
       }
     }
-    val rdd = newShuffleRDD
+    val rdd = newShuffleRDD()
 
     // Get all the shuffle dependencies
     val shuffleDeps = getAllDependencies(rdd)
@@ -216,34 +233,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     (rdd, shuffleDeps)
   }
 
-  def randomRdd = {
+  private def randomRdd() = {
     val rdd: RDD[_] = Random.nextInt(3) match {
-      case 0 => newRDD
-      case 1 => newShuffleRDD
-      case 2 => newPairRDD.join(newPairRDD)
+      case 0 => newRDD()
+      case 1 => newShuffleRDD()
+      case 2 => newPairRDD().join(newPairRDD())
     }
     if (Random.nextBoolean()) rdd.persist()
     rdd.count()
     rdd
   }
 
-  def randomBroadcast = {
+  private def randomBroadcast() = {
     sc.broadcast(Random.nextInt(Int.MaxValue))
   }
 
   /** Run GC and make sure it actually has run */
-  def runGC() {
+  private def runGC() {
     val weakRef = new WeakReference(new Object())
     val startTime = System.currentTimeMillis
     System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
     // Wait until a weak reference object has been GCed
-    while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
       System.gc()
       Thread.sleep(200)
     }
   }
 
-  def cleaner = sc.cleaner.get
+  private def cleaner = sc.cleaner.get
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index fdc83bc..4953d56 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -155,19 +155,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       override def getPartitions: Array[Partition] = Array(onlySplit)
       override val getDependencies = List[Dependency[_]]()
       override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
-        if (shouldFail) {
-          throw new Exception("injected failure")
-        } else {
-          Array(1, 2, 3, 4).iterator
-        }
+        throw new Exception("injected failure")
       }
     }.cache()
     val thrown = intercept[Exception]{
       rdd.collect()
     }
     assert(thrown.getMessage.contains("injected failure"))
-    shouldFail = false
-    assert(rdd.collect().toList === List(1, 2, 3, 4))
   }
 
   test("empty RDD") {

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 8bb5317..270f7e6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -20,31 +20,35 @@ package org.apache.spark.scheduler
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.LocalSparkContext
-import org.apache.spark.Partition
-import org.apache.spark.SparkContext
-import org.apache.spark.TaskContext
+import org.apache.spark._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
 
   test("Calls executeOnCompleteCallbacks after failure") {
-    var completed = false
+    TaskContextSuite.completed = false
     sc = new SparkContext("local", "test")
     val rdd = new RDD[String](sc, List()) {
       override def getPartitions = Array[Partition](StubPartition(0))
       override def compute(split: Partition, context: TaskContext) = {
-        context.addOnCompleteCallback(() => completed = true)
+        context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
         sys.error("failed")
       }
     }
-    val func = (c: TaskContext, i: Iterator[String]) => i.next
-    val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+    val func = (c: TaskContext, i: Iterator[String]) => i.next()
+    val task = new ResultTask[String, String](
+      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0)
     }
-    assert(completed === true)
+    assert(TaskContextSuite.completed === true)
   }
+}
 
-  case class StubPartition(val index: Int) extends Partition
+private object TaskContextSuite {
+  @volatile var completed = false
 }
+
+private case class StubPartition(index: Int) extends Partition

http://git-wip-us.apache.org/repos/asf/spark/blob/774142f5/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index b52f818..86a271e 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.scheduler._
 import org.apache.spark.util.Utils
 
 class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers {
+
   test("test LRU eviction of stages") {
     val conf = new SparkConf()
     conf.set("spark.ui.retainedStages", 5.toString)
@@ -66,7 +67,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
     taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
     var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
-    var task = new ShuffleMapTask(0, null, null, 0, null)
+    var task = new ShuffleMapTask(0)
     val taskType = Utils.getFormattedClassName(task)
     listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
     assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail())
@@ -76,14 +77,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
     taskInfo =
       new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true)
     taskInfo.finishTime = 1
-    task = new ShuffleMapTask(0, null, null, 0, null)
+    task = new ShuffleMapTask(0)
     listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
     assert(listener.stageIdToData.size === 1)
 
     // finish this task, should get updated duration
     taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
-    task = new ShuffleMapTask(0, null, null, 0, null)
+    task = new ShuffleMapTask(0)
     listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
     assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail())
       .shuffleRead === 2000)
@@ -91,7 +92,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
     // finish this task, should get updated duration
     taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
-    task = new ShuffleMapTask(0, null, null, 0, null)
+    task = new ShuffleMapTask(0)
     listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
     assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail())
       .shuffleRead === 1000)
@@ -103,7 +104,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
     val metrics = new TaskMetrics()
     val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
-    val task = new ShuffleMapTask(0, null, null, 0, null)
+    val task = new ShuffleMapTask(0)
     val taskType = Utils.getFormattedClassName(task)
 
     // Go through all the failure cases to make sure we are counting them as failures.