Posted to commits@spark.apache.org by jo...@apache.org on 2016/01/27 20:15:56 UTC

[3/4] spark git commit: [SPARK-12895][SPARK-12896] Migrate TaskMetrics to accumulators

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a79ab86..3204e6a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -212,6 +212,8 @@ class HadoopRDD[K, V](
       logInfo("Input split: " + split.inputSplit)
       val jobConf = getJobConf()
 
+      // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD
+
       val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
 
       // Sets the thread local variable for the file's name
@@ -222,14 +224,17 @@ class HadoopRDD[K, V](
 
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
-      val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
-        split.inputSplit.value match {
-          case _: FileSplit | _: CombineFileSplit =>
-            SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
-          case _ => None
+      val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
+        case _: FileSplit | _: CombineFileSplit =>
+          SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+        case _ => None
+      }
+
+      def updateBytesRead(): Unit = {
+        getBytesReadCallback.foreach { getBytesRead =>
+          inputMetrics.setBytesRead(getBytesRead())
         }
       }
-      inputMetrics.setBytesReadCallback(bytesReadCallback)
 
       var reader: RecordReader[K, V] = null
       val inputFormat = getInputFormat(jobConf)
@@ -252,6 +257,9 @@ class HadoopRDD[K, V](
         if (!finished) {
           inputMetrics.incRecordsRead(1)
         }
+        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+          updateBytesRead()
+        }
         (key, value)
       }
 
@@ -272,8 +280,8 @@ class HadoopRDD[K, V](
           } finally {
             reader = null
           }
-          if (bytesReadCallback.isDefined) {
-            inputMetrics.updateBytesRead()
+          if (getBytesReadCallback.isDefined) {
+            updateBytesRead()
           } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                      split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
             // If we can't get the bytes read from the FS stats, fall back to the split size,

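[Note on the change above] The patch replaces the old bytesReadCallback plumbing with a local Option[() => Long] plus an updateBytesRead() helper that is invoked every UPDATE_INPUT_METRICS_INTERVAL_RECORDS records and once more at the end of iteration. The following is a minimal, self-contained sketch of that refresh pattern only; all names here are stand-ins (not Spark APIs), and the interval value is an assumption.

object BytesReadRefreshSketch {
  // Stand-in for SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS (assumed value).
  val UpdateIntervalRecords = 1000L
  // Stand-in for inputMetrics.setBytesRead(...).
  @volatile private var bytesRead = 0L

  def readAll(records: Iterator[Array[Byte]], getBytesRead: Option[() => Long]): Long = {
    var recordsRead = 0L
    // Only refresh when a filesystem-level callback is actually available.
    def refresh(): Unit = getBytesRead.foreach(cb => bytesRead = cb())
    records.foreach { _ =>
      recordsRead += 1
      // Amortize the cost of querying FS statistics over many records.
      if (recordsRead % UpdateIntervalRecords == 0) refresh()
    }
    refresh()  // final update once iteration completes
    bytesRead
  }
}
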
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 5cc9c81..4d2816e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -133,14 +133,17 @@ class NewHadoopRDD[K, V](
 
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
-      val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
-        split.serializableHadoopSplit.value match {
-          case _: FileSplit | _: CombineFileSplit =>
-            SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
-          case _ => None
+      val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
+        case _: FileSplit | _: CombineFileSplit =>
+          SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+        case _ => None
+      }
+
+      def updateBytesRead(): Unit = {
+        getBytesReadCallback.foreach { getBytesRead =>
+          inputMetrics.setBytesRead(getBytesRead())
         }
       }
-      inputMetrics.setBytesReadCallback(bytesReadCallback)
 
       val format = inputFormatClass.newInstance
       format match {
@@ -182,6 +185,9 @@ class NewHadoopRDD[K, V](
         if (!finished) {
           inputMetrics.incRecordsRead(1)
         }
+        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+          updateBytesRead()
+        }
         (reader.getCurrentKey, reader.getCurrentValue)
       }
 
@@ -201,8 +207,8 @@ class NewHadoopRDD[K, V](
           } finally {
             reader = null
           }
-          if (bytesReadCallback.isDefined) {
-            inputMetrics.updateBytesRead()
+          if (getBytesReadCallback.isDefined) {
+            updateBytesRead()
           } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
                      split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
             // If we can't get the bytes read from the FS stats, fall back to the split size,

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index 146cfb9..9d45fff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -19,47 +19,58 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.annotation.DeveloperApi
 
+
 /**
  * :: DeveloperApi ::
  * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
+ *
+ * Note: once this is JSON serialized the types of `update` and `value` will be lost and be
+ * cast to strings. This is because the user can define an accumulator of any type and it will
+ * be difficult to preserve the type in consumers of the event log. This does not apply to
+ * internal accumulators that represent task level metrics.
+ *
+ * @param id accumulator ID
+ * @param name accumulator name
+ * @param update partial value from a task, may be None if used on driver to describe a stage
+ * @param value total accumulated value so far, maybe None if used on executors to describe a task
+ * @param internal whether this accumulator was internal
+ * @param countFailedValues whether to count this accumulator's partial value if the task failed
  */
 @DeveloperApi
-class AccumulableInfo private[spark] (
-    val id: Long,
-    val name: String,
-    val update: Option[String], // represents a partial update within a task
-    val value: String,
-    val internal: Boolean) {
-
-  override def equals(other: Any): Boolean = other match {
-    case acc: AccumulableInfo =>
-      this.id == acc.id && this.name == acc.name &&
-        this.update == acc.update && this.value == acc.value &&
-        this.internal == acc.internal
-    case _ => false
-  }
+case class AccumulableInfo private[spark] (
+    id: Long,
+    name: Option[String],
+    update: Option[Any], // represents a partial update within a task
+    value: Option[Any],
+    private[spark] val internal: Boolean,
+    private[spark] val countFailedValues: Boolean)
 
-  override def hashCode(): Int = {
-    val state = Seq(id, name, update, value, internal)
-    state.map(_.hashCode).reduceLeft(31 * _ + _)
-  }
-}
 
+/**
+ * A collection of deprecated constructors. This will be removed soon.
+ */
 object AccumulableInfo {
+
+  @deprecated("do not create AccumulableInfo", "2.0.0")
   def apply(
       id: Long,
       name: String,
       update: Option[String],
       value: String,
       internal: Boolean): AccumulableInfo = {
-    new AccumulableInfo(id, name, update, value, internal)
+    new AccumulableInfo(
+      id, Option(name), update, Option(value), internal, countFailedValues = false)
   }
 
+  @deprecated("do not create AccumulableInfo", "2.0.0")
   def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, update, value, internal = false)
+    new AccumulableInfo(
+      id, Option(name), update, Option(value), internal = false, countFailedValues = false)
   }
 
+  @deprecated("do not create AccumulableInfo", "2.0.0")
   def apply(id: Long, name: String, value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, None, value, internal = false)
+    new AccumulableInfo(
+      id, Option(name), None, Option(value), internal = false, countFailedValues = false)
   }
 }

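[Note on the change above] After this change AccumulableInfo is a case class whose update/value fields are Options of Any. A rough sketch of the two "directions" an instance can describe is below; it assumes code compiled inside the org.apache.spark package (the constructor is private[spark]), and the accumulator name shown is only illustrative.

// A task-side report carries a partial `update` but no total yet.
val fromTask = new AccumulableInfo(
  id = 1L, name = Some("internal.metrics.resultSize"),
  update = Some(0L), value = None, internal = true, countFailedValues = true)

// A driver/stage-side summary carries the accumulated `value` instead.
val onDriver = new AccumulableInfo(
  id = 1L, name = Some("internal.metrics.resultSize"),
  update = None, value = Some(2048L), internal = true, countFailedValues = true)
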
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6b01a10..897479b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -208,11 +208,10 @@ class DAGScheduler(
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Map[Long, Any],
-      taskInfo: TaskInfo,
-      taskMetrics: TaskMetrics): Unit = {
+      accumUpdates: Seq[AccumulableInfo],
+      taskInfo: TaskInfo): Unit = {
     eventProcessLoop.post(
-      CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+      CompletionEvent(task, reason, result, accumUpdates, taskInfo))
   }
 
   /**
@@ -222,9 +221,10 @@ class DAGScheduler(
    */
   def executorHeartbeatReceived(
       execId: String,
-      taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
+      // (taskId, stageId, stageAttemptId, accumUpdates)
+      accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])],
       blockManagerId: BlockManagerId): Boolean = {
-    listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
+    listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates))
     blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
       BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
   }
@@ -1074,39 +1074,43 @@ class DAGScheduler(
     }
   }
 
-  /** Merge updates from a task to our local accumulator values */
+  /**
+   * Merge local values from a task into the corresponding accumulators previously registered
+   * here on the driver.
+   *
+   * Although accumulators themselves are not thread-safe, this method is called only from one
+   * thread, the one that runs the scheduling loop. This means we only handle one task
+   * completion event at a time so we don't need to worry about locking the accumulators.
+   * This still doesn't stop the caller from updating the accumulator outside the scheduler,
+   * but that's not our problem since there's nothing we can do about that.
+   */
   private def updateAccumulators(event: CompletionEvent): Unit = {
     val task = event.task
     val stage = stageIdToStage(task.stageId)
-    if (event.accumUpdates != null) {
-      try {
-        Accumulators.add(event.accumUpdates)
-
-        event.accumUpdates.foreach { case (id, partialValue) =>
-          // In this instance, although the reference in Accumulators.originals is a WeakRef,
-          // it's guaranteed to exist since the event.accumUpdates Map exists
-
-          val acc = Accumulators.originals(id).get match {
-            case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
-            case None => throw new NullPointerException("Non-existent reference to Accumulator")
-          }
-
-          // To avoid UI cruft, ignore cases where value wasn't updated
-          if (acc.name.isDefined && partialValue != acc.zero) {
-            val name = acc.name.get
-            val value = s"${acc.value}"
-            stage.latestInfo.accumulables(id) =
-              new AccumulableInfo(id, name, None, value, acc.isInternal)
-            event.taskInfo.accumulables +=
-              new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
-          }
+    try {
+      event.accumUpdates.foreach { ainfo =>
+        assert(ainfo.update.isDefined, "accumulator from task should have a partial value")
+        val id = ainfo.id
+        val partialValue = ainfo.update.get
+        // Find the corresponding accumulator on the driver and update it
+        val acc: Accumulable[Any, Any] = Accumulators.get(id) match {
+          case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+          case None =>
+            throw new SparkException(s"attempted to access non-existent accumulator $id")
+        }
+        acc ++= partialValue
+        // To avoid UI cruft, ignore cases where value wasn't updated
+        if (acc.name.isDefined && partialValue != acc.zero) {
+          val name = acc.name
+          stage.latestInfo.accumulables(id) = new AccumulableInfo(
+            id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues)
+          event.taskInfo.accumulables += new AccumulableInfo(
+            id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues)
         }
-      } catch {
-        // If we see an exception during accumulator update, just log the
-        // error and move on.
-        case e: Exception =>
-          logError(s"Failed to update accumulators for $task", e)
       }
+    } catch {
+      case NonFatal(e) =>
+        logError(s"Failed to update accumulators for task ${task.partitionId}", e)
     }
   }
 
@@ -1116,6 +1120,7 @@ class DAGScheduler(
    */
   private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
+    val taskId = event.taskInfo.id
     val stageId = task.stageId
     val taskType = Utils.getFormattedClassName(task)
 
@@ -1125,12 +1130,26 @@ class DAGScheduler(
       event.taskInfo.attemptNumber, // this is a task attempt number
       event.reason)
 
-    // The success case is dealt with separately below, since we need to compute accumulator
-    // updates before posting.
+    // Reconstruct task metrics. Note: this may be null if the task has failed.
+    val taskMetrics: TaskMetrics =
+      if (event.accumUpdates.nonEmpty) {
+        try {
+          TaskMetrics.fromAccumulatorUpdates(event.accumUpdates)
+        } catch {
+          case NonFatal(e) =>
+            logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
+            null
+        }
+      } else {
+        null
+      }
+
+    // The success case is dealt with separately below.
+    // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here.
     if (event.reason != Success) {
       val attemptId = task.stageAttemptId
-      listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason,
-        event.taskInfo, event.taskMetrics))
+      listenerBus.post(SparkListenerTaskEnd(
+        stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics))
     }
 
     if (!stageIdToStage.contains(task.stageId)) {
@@ -1142,7 +1161,7 @@ class DAGScheduler(
     event.reason match {
       case Success =>
         listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
-          event.reason, event.taskInfo, event.taskMetrics))
+          event.reason, event.taskInfo, taskMetrics))
         stage.pendingPartitions -= task.partitionId
         task match {
           case rt: ResultTask[_, _] =>
@@ -1291,7 +1310,8 @@ class DAGScheduler(
         // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
 
       case exceptionFailure: ExceptionFailure =>
-        // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+        // Tasks failed with exceptions might still have accumulator updates.
+        updateAccumulators(event)
 
       case TaskResultLost =>
         // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
@@ -1637,7 +1657,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case GettingResultEvent(taskInfo) =>
       dagScheduler.handleGetTaskResult(taskInfo)
 
-    case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+    case completion: CompletionEvent =>
       dagScheduler.handleTaskCompletion(completion)
 
     case TaskSetFailed(taskSet, reason, exception) =>

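[Note on the change above] updateAccumulators now walks a Seq[AccumulableInfo], requires each entry to carry a partial update, looks the accumulator up by id on the driver, and folds the partial value in. A minimal sketch of that merge loop, using a plain map and a tiny counter class in place of the Accumulators registry (all names hypothetical):

import scala.collection.mutable

final class DriverAcc(var value: Long) { def add(v: Long): Unit = value += v }

// Stand-in for the driver-side registry of long-lived accumulators.
val registry = mutable.Map[Long, DriverAcc](1L -> new DriverAcc(0L))

def merge(updates: Seq[(Long, Option[Long])]): Unit =  // (accumulator id, partial update)
  updates.foreach { case (id, update) =>
    val partial = update.getOrElse(sys.error(s"accumulator $id arrived without a partial value"))
    val acc = registry.getOrElse(id, sys.error(s"attempted to access non-existent accumulator $id"))
    acc.add(partial)  // corresponds to `acc ++= partialValue` in the real code
  }
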
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index dda3b6c..d5cd2da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -73,9 +73,8 @@ private[scheduler] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
     result: Any,
-    accumUpdates: Map[Long, Any],
-    taskInfo: TaskInfo,
-    taskMetrics: TaskMetrics)
+    accumUpdates: Seq[AccumulableInfo],
+    taskInfo: TaskInfo)
   extends DAGSchedulerEvent
 
 private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 6590cf6..885f70e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
  * See [[Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
  * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
  *                   partition of the given RDD. Once deserialized, the type should be
  *                   (RDD[T], (TaskContext, Iterator[T]) => U).
@@ -37,6 +38,9 @@ import org.apache.spark.rdd.RDD
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
+ * @param _initialAccums initial set of accumulators to be used in this task for tracking
+ *                       internal metrics. Other accumulators will be registered later when
+ *                       they are deserialized on the executors.
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
@@ -45,8 +49,8 @@ private[spark] class ResultTask[T, U](
     partition: Partition,
     locs: Seq[TaskLocation],
     val outputId: Int,
-    internalAccumulators: Seq[Accumulator[Long]])
-  extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+    _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.create())
+  extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
   with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index ea97ef0..89207dd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -33,10 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
  * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
  *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
  * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
+ * @param _initialAccums initial set of accumulators to be used in this task for tracking
+ *                       internal metrics. Other accumulators will be registered later when
+ *                       they are deserialized on the executors.
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
@@ -44,8 +48,8 @@ private[spark] class ShuffleMapTask(
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation],
-    internalAccumulators: Seq[Accumulator[Long]])
-  extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+    _initialAccums: Seq[Accumulator[_]])
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
   with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 6c6883d..ed3adbd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
+import javax.annotation.Nullable
 
 import scala.collection.Map
 import scala.collection.mutable
@@ -60,7 +61,7 @@ case class SparkListenerTaskEnd(
     taskType: String,
     reason: TaskEndReason,
     taskInfo: TaskInfo,
-    taskMetrics: TaskMetrics)
+    @Nullable taskMetrics: TaskMetrics)
   extends SparkListenerEvent
 
 @DeveloperApi
@@ -111,12 +112,12 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends
 /**
  * Periodic updates from executors.
  * @param execId executor id
- * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics)
+ * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates)
  */
 @DeveloperApi
 case class SparkListenerExecutorMetricsUpdate(
     execId: String,
-    taskMetrics: Seq[(Long, Int, Int, TaskMetrics)])
+    accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])])
   extends SparkListenerEvent
 
 @DeveloperApi

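[Note on the change above] Executor heartbeats now carry raw accumulator updates instead of TaskMetrics. A sketch of a listener consuming the new payload shape, assuming a Spark build that contains this patch (field names taken from the case class in this diff):

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate}

class AccumUpdateLogger extends SparkListener {
  override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
    event.accumUpdates.foreach { case (taskId, stageId, stageAttemptId, updates) =>
      // Render only named accumulators; `update` is an Option[Any] after this change.
      val rendered = updates.flatMap { a =>
        a.name.map(n => s"$n=${a.update.getOrElse("?")}")
      }.mkString(", ")
      println(s"task $taskId (stage $stageId, attempt $stageAttemptId): $rendered")
    }
  }
}

Such a listener would be registered with SparkContext.addSparkListener in the usual way.
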
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 7ea24a2..c1c8b47 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -74,10 +74,10 @@ private[scheduler] abstract class Stage(
   val name: String = callSite.shortForm
   val details: String = callSite.longForm
 
-  private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+  private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty
 
   /** Internal accumulators shared across all tasks in this stage. */
-  def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+  def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators
 
   /**
    * Re-initialize the internal accumulators associated with this stage.

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index fca5792..a49f371 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 
 import scala.collection.mutable.HashMap
@@ -41,32 +41,29 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
  * and divides the task output to multiple buckets (based on the task's partitioner).
  *
  * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
+ * @param initialAccumulators initial set of accumulators to be used in this task for tracking
+ *                            internal metrics. Other accumulators will be registered later when
+ *                            they are deserialized on the executors.
  */
 private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
     val partitionId: Int,
-    internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
+    val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
 
   /**
-   * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
-   * local value.
-   */
-  type AccumulatorUpdates = Map[Long, Any]
-
-  /**
-   * Called by [[Executor]] to run this task.
+   * Called by [[org.apache.spark.executor.Executor]] to run this task.
    *
    * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
    * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
    * @return the result of the task along with updates of Accumulators.
    */
   final def run(
-    taskAttemptId: Long,
-    attemptNumber: Int,
-    metricsSystem: MetricsSystem)
-  : (T, AccumulatorUpdates) = {
+      taskAttemptId: Long,
+      attemptNumber: Int,
+      metricsSystem: MetricsSystem): T = {
     context = new TaskContextImpl(
       stageId,
       partitionId,
@@ -74,16 +71,14 @@ private[spark] abstract class Task[T](
       attemptNumber,
       taskMemoryManager,
       metricsSystem,
-      internalAccumulators)
+      initialAccumulators)
     TaskContext.setTaskContext(context)
-    context.taskMetrics.setHostname(Utils.localHostName())
-    context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
     taskThread = Thread.currentThread()
     if (_killed) {
       kill(interruptThread = false)
     }
     try {
-      (runTask(context), context.collectAccumulators())
+      runTask(context)
     } finally {
       context.markTaskCompleted()
       try {
@@ -141,6 +136,18 @@ private[spark] abstract class Task[T](
   def executorDeserializeTime: Long = _executorDeserializeTime
 
   /**
+   * Collect the latest values of accumulators used in this task. If the task failed,
+   * filter out the accumulators whose values should not be included on failures.
+   */
+  def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = {
+    if (context != null) {
+      context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues }
+    } else {
+      Seq.empty[AccumulableInfo]
+    }
+  }
+
+  /**
    * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
    * code and user code to properly handle the flag. This function should be idempotent so it can
    * be called multiple times.

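[Note on the change above] collectAccumulatorUpdates keeps every update on success, but on failure keeps only accumulators that opted in via countFailedValues. A tiny stand-alone illustration of that filtering rule, with a stand-in type instead of AccumulableInfo:

case class Upd(name: String, countFailedValues: Boolean)

def collect(all: Seq[Upd], taskFailed: Boolean): Seq[Upd] =
  all.filter { a => !taskFailed || a.countFailedValues }

val updates = Seq(
  Upd("internal.metrics.resultSize", countFailedValues = true),
  Upd("user.counter", countFailedValues = false))

assert(collect(updates, taskFailed = false).size == 2)  // success: everything is reported
assert(collect(updates, taskFailed = true).size == 1)   // failure: only opted-in metrics survive
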
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index b82c7f3..03135e6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -20,11 +20,9 @@ package org.apache.spark.scheduler
 import java.io._
 import java.nio.ByteBuffer
 
-import scala.collection.Map
-import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkEnv
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.Utils
 
@@ -36,31 +34,24 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
   extends TaskResult[T] with Serializable
 
 /** A TaskResult that contains the task's return value and accumulator updates. */
-private[spark]
-class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any],
-    var metrics: TaskMetrics)
+private[spark] class DirectTaskResult[T](
+    var valueBytes: ByteBuffer,
+    var accumUpdates: Seq[AccumulableInfo])
   extends TaskResult[T] with Externalizable {
 
   private var valueObjectDeserialized = false
   private var valueObject: T = _
 
-  def this() = this(null.asInstanceOf[ByteBuffer], null, null)
+  def this() = this(null.asInstanceOf[ByteBuffer], null)
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-
-    out.writeInt(valueBytes.remaining);
+    out.writeInt(valueBytes.remaining)
     Utils.writeByteBuffer(valueBytes, out)
-
     out.writeInt(accumUpdates.size)
-    for ((key, value) <- accumUpdates) {
-      out.writeLong(key)
-      out.writeObject(value)
-    }
-    out.writeObject(metrics)
+    accumUpdates.foreach(out.writeObject)
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-
     val blen = in.readInt()
     val byteVal = new Array[Byte](blen)
     in.readFully(byteVal)
@@ -70,13 +61,12 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
     if (numUpdates == 0) {
       accumUpdates = null
     } else {
-      val _accumUpdates = mutable.Map[Long, Any]()
+      val _accumUpdates = new ArrayBuffer[AccumulableInfo]
       for (i <- 0 until numUpdates) {
-        _accumUpdates(in.readLong()) = in.readObject()
+        _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo]
       }
       accumUpdates = _accumUpdates
     }
-    metrics = in.readObject().asInstanceOf[TaskMetrics]
     valueObjectDeserialized = false
   }
 

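[Note on the change above] DirectTaskResult now serializes its accumulator updates as a length prefix followed by one object per element, and drops the separate TaskMetrics object. A minimal round-trip sketch of that length-prefixed pattern, with String elements standing in for AccumulableInfo:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer

def write(updates: Seq[String]): Array[Byte] = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeInt(updates.size)                 // length prefix
  updates.foreach(u => out.writeObject(u))   // one object per element, as in the patch
  out.close()
  bytes.toByteArray
}

def read(data: Array[Byte]): Seq[String] = {
  val in = new ObjectInputStream(new ByteArrayInputStream(data))
  val n = in.readInt()
  val buf = new ArrayBuffer[String]
  for (_ <- 0 until n) buf += in.readObject().asInstanceOf[String]
  buf
}

assert(read(write(Seq("a", "b"))) == Seq("a", "b"))
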
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index f496599..c94c4f5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
-import java.util.concurrent.RejectedExecutionException
+import java.util.concurrent.{ExecutorService, RejectedExecutionException}
 
 import scala.language.existentials
 import scala.util.control.NonFatal
@@ -35,9 +35,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
   extends Logging {
 
   private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
-  private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool(
-    THREADS, "task-result-getter")
 
+  // Exposed for testing.
+  protected val getTaskResultExecutor: ExecutorService =
+    ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter")
+
+  // Exposed for testing.
   protected val serializer = new ThreadLocal[SerializerInstance] {
     override def initialValue(): SerializerInstance = {
       sparkEnv.closureSerializer.newInstance()
@@ -45,7 +48,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
   }
 
   def enqueueSuccessfulTask(
-    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
+      taskSetManager: TaskSetManager,
+      tid: Long,
+      serializedData: ByteBuffer): Unit = {
     getTaskResultExecutor.execute(new Runnable {
       override def run(): Unit = Utils.logUncaughtExceptions {
         try {
@@ -82,7 +87,19 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
               (deserializedResult, size)
           }
 
-          result.metrics.setResultSize(size)
+          // Set the task result size in the accumulator updates received from the executors.
+          // We need to do this here on the driver because if we did this on the executors then
+          // we would have to serialize the result again after updating the size.
+          result.accumUpdates = result.accumUpdates.map { a =>
+            if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
+              assert(a.update == Some(0L),
+                "task result size should not have been set on the executors")
+              a.copy(update = Some(size.toLong))
+            } else {
+              a
+            }
+          }
+
           scheduler.handleSuccessfulTask(taskSetManager, tid, result)
         } catch {
           case cnf: ClassNotFoundException =>

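[Note on the change above] Because the result size is only known once the serialized result reaches the driver, the getter rewrites the single result-size entry in the immutable update sequence with case-class copy. A stand-alone sketch of that fix-up (types and the metric name are stand-ins, not the actual Spark classes):

case class Info(name: Option[String], update: Option[Any])

val resultSizeName = "internal.metrics.resultSize"  // stand-in for InternalAccumulator.RESULT_SIZE

def setResultSize(updates: Seq[Info], size: Long): Seq[Info] =
  updates.map { a =>
    if (a.name.contains(resultSizeName)) a.copy(update = Some(size)) else a
  }
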
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 7c0b007..fccd6e0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -65,8 +65,10 @@ private[spark] trait TaskScheduler {
    * alive. Return true if the driver knows about the given block manager. Otherwise, return false,
    * indicating that the block manager should re-register.
    */
-  def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
-    blockManagerId: BlockManagerId): Boolean
+  def executorHeartbeatReceived(
+      execId: String,
+      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+      blockManagerId: BlockManagerId): Boolean
 
   /**
    * Get an application ID associated with the job.

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 6e3ef0e..29341df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -30,7 +30,6 @@ import scala.util.Random
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.scheduler.TaskLocality.TaskLocality
 import org.apache.spark.storage.BlockManagerId
@@ -380,17 +379,17 @@ private[spark] class TaskSchedulerImpl(
    */
   override def executorHeartbeatReceived(
       execId: String,
-      taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
+      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
       blockManagerId: BlockManagerId): Boolean = {
-
-    val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
-      taskMetrics.flatMap { case (id, metrics) =>
+    // (taskId, stageId, stageAttemptId, accumUpdates)
+    val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
+      accumUpdates.flatMap { case (id, updates) =>
         taskIdToTaskSetManager.get(id).map { taskSetMgr =>
-          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates)
         }
       }
     }
-    dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
+    dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId)
   }
 
   def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index aa39b59..cf97877 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -621,8 +621,7 @@ private[spark] class TaskSetManager(
     // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
     // Note: "result.value()" only deserializes the value when it's called at the first time, so
     // here "result.value()" just returns the value and won't block other threads.
-    sched.dagScheduler.taskEnded(
-      tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics)
+    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info)
     if (!successful(index)) {
       tasksSuccessful += 1
       logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format(
@@ -653,8 +652,7 @@ private[spark] class TaskSetManager(
     info.markFailed()
     val index = info.index
     copiesRunning(index) -= 1
-    var taskMetrics : TaskMetrics = null
-
+    var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]
     val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
       reason.asInstanceOf[TaskFailedReason].toErrorString
     val failureException: Option[Throwable] = reason match {
@@ -669,7 +667,8 @@ private[spark] class TaskSetManager(
         None
 
       case ef: ExceptionFailure =>
-        taskMetrics = ef.metrics.orNull
+        // ExceptionFailure's might have accumulator updates
+        accumUpdates = ef.accumUpdates
         if (ef.className == classOf[NotSerializableException].getName) {
           // If the task result wasn't serializable, there's no point in trying to re-execute it.
           logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying"
@@ -721,7 +720,7 @@ private[spark] class TaskSetManager(
     // always add to failed executors
     failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
       put(info.executorId, clock.getTimeMillis())
-    sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
+    sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)
     addPendingTask(index)
     if (!isZombie && state != TaskState.KILLED
         && reason.isInstanceOf[TaskFailedReason]
@@ -793,7 +792,8 @@ private[spark] class TaskSetManager(
           addPendingTask(index)
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
-          sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+          sched.dagScheduler.taskEnded(
+            tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info)
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index a57e5b0..acbe160 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -103,8 +103,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
         sorter.insertAll(aggregatedIter)
         context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
-        context.internalMetricsToAccumulators(
-          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
+        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
         CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
       case None =>
         aggregatedIter

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
index 078718b..9c92a50 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
@@ -237,7 +237,8 @@ private[v1] object AllStagesResource {
   }
 
   def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = {
-    new AccumulableInfo(acc.id, acc.name, acc.update, acc.value)
+    new AccumulableInfo(
+      acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull)
   }
 
   def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = {

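[Note on the change above] The REST API keeps its old nullable-string representation, so the conversion collapses the new Option-typed fields with orNull/toString. A sketch of that mapping with stand-in types (not the actual v1 classes):

case class Internal(id: Long, name: Option[String], update: Option[Any], value: Option[Any])
case class Public(id: Long, name: String, update: Option[String], value: String)

def toPublic(acc: Internal): Public =
  Public(acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull)
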
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 4a9f8b3..b2aa8bf 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -325,12 +325,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
     val taskInfo = taskStart.taskInfo
     if (taskInfo != null) {
+      val metrics = new TaskMetrics
       val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), {
         logWarning("Task start for unknown stage " + taskStart.stageId)
         new StageUIData
       })
       stageData.numActiveTasks += 1
-      stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo))
+      stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo, Some(metrics)))
     }
     for (
       activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
@@ -387,9 +388,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
             (Some(e.toErrorString), None)
         }
 
-      if (!metrics.isEmpty) {
+      metrics.foreach { m =>
         val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics)
-        updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics)
+        updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
       }
 
       val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info))
@@ -489,19 +490,18 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   }
 
   override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) {
-    for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
+    for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) {
       val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), {
         logWarning("Metrics update for task in unknown stage " + sid)
         new StageUIData
       })
       val taskData = stageData.taskData.get(taskId)
-      taskData.map { t =>
+      val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates)
+      taskData.foreach { t =>
         if (!t.taskInfo.finished) {
-          updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics,
-            t.taskMetrics)
-
+          updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.taskMetrics)
           // Overwrite task metrics
-          t.taskMetrics = Some(taskMetrics)
+          t.taskMetrics = Some(metrics)
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 914f618..29c5ff0 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -271,8 +271,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
         }
 
       val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
-      def accumulableRow(acc: AccumulableInfo): Elem =
-        <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
+      def accumulableRow(acc: AccumulableInfo): Seq[Node] = {
+        (acc.name, acc.value) match {
+          case (Some(name), Some(value)) => <tr><td>{name}</td><td>{value}</td></tr>
+          case _ => Seq.empty[Node]
+        }
+      }
       val accumulableTable = UIUtils.listingTable(
         accumulableHeaders,
         accumulableRow,
@@ -404,13 +408,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
             </td> +:
             getFormattedTimeQuantiles(gettingResultTimes)
 
-          val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
-            info.accumulables
-              .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
-              .map { acc => acc.update.getOrElse("0").toLong }
-              .getOrElse(0L)
-              .toDouble
-          }
+            val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) =>
+              metrics.get.peakExecutionMemory.toDouble
+            }
           val peakExecutionMemoryQuantiles = {
             <td>
               <span data-toggle="tooltip"
@@ -891,15 +891,15 @@ private[ui] class TaskDataSource(
     val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
     val gettingResultTime = getGettingResultTime(info, currentTime)
 
-    val (taskInternalAccumulables, taskExternalAccumulables) =
-      info.accumulables.partition(_.internal)
-    val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
-      StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
-    }
-    val peakExecutionMemoryUsed = taskInternalAccumulables
-      .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
-      .map { acc => acc.update.getOrElse("0").toLong }
-      .getOrElse(0L)
+    val externalAccumulableReadable = info.accumulables
+      .filterNot(_.internal)
+      .flatMap { a =>
+        (a.name, a.update) match {
+          case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update"))
+          case _ => None
+        }
+      }
+    val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L)
 
     val maybeInput = metrics.flatMap(_.inputMetrics)
     val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index efa22b9..dc8070c 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -233,14 +233,14 @@ private[spark] object JsonProtocol {
 
   def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = {
     val execId = metricsUpdate.execId
-    val taskMetrics = metricsUpdate.taskMetrics
+    val accumUpdates = metricsUpdate.accumUpdates
     ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~
     ("Executor ID" -> execId) ~
-    ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) =>
+      ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) =>
       ("Task ID" -> taskId) ~
       ("Stage ID" -> stageId) ~
       ("Stage Attempt ID" -> stageAttemptId) ~
-      ("Task Metrics" -> taskMetricsToJson(metrics))
+      ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList))
     })
   }
 
@@ -265,7 +265,7 @@ private[spark] object JsonProtocol {
     ("Completion Time" -> completionTime) ~
     ("Failure Reason" -> failureReason) ~
     ("Accumulables" -> JArray(
-        stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
+      stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
   }
 
   def taskInfoToJson(taskInfo: TaskInfo): JValue = {
@@ -284,11 +284,44 @@ private[spark] object JsonProtocol {
   }
 
   def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = {
+    val name = accumulableInfo.name
     ("ID" -> accumulableInfo.id) ~
-    ("Name" -> accumulableInfo.name) ~
-    ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~
-    ("Value" -> accumulableInfo.value) ~
-    ("Internal" -> accumulableInfo.internal)
+    ("Name" -> name) ~
+    ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~
+    ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~
+    ("Internal" -> accumulableInfo.internal) ~
+    ("Count Failed Values" -> accumulableInfo.countFailedValues)
+  }
+
+  /**
+   * Serialize the value of an accumulator to JSON.
+   *
+   * For accumulators representing internal task metrics, this looks up the relevant
+   * [[AccumulatorParam]] to serialize the value accordingly. For all other accumulators,
+   * this will simply serialize the value as a string.
+   *
+   * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing.
+   */
+  private[util] def accumValueToJson(name: Option[String], value: Any): JValue = {
+    import AccumulatorParam._
+    if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) {
+      (value, InternalAccumulator.getParam(name.get)) match {
+        case (v: Int, IntAccumulatorParam) => JInt(v)
+        case (v: Long, LongAccumulatorParam) => JInt(v)
+        case (v: String, StringAccumulatorParam) => JString(v)
+        case (v, UpdatedBlockStatusesAccumulatorParam) =>
+          JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) =>
+            ("Block ID" -> id.toString) ~
+            ("Status" -> blockStatusToJson(status))
+          })
+        case (v, p) =>
+          throw new IllegalArgumentException(s"unexpected combination of accumulator value " +
+            s"type (${v.getClass.getName}) and param (${p.getClass.getName}) in '${name.get}'")
+      }
+    } else {
+      // For all external accumulators, just use strings
+      JString(value.toString)
+    }
   }
 
   def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {
@@ -303,9 +336,9 @@ private[spark] object JsonProtocol {
       }.getOrElse(JNothing)
     val shuffleWriteMetrics: JValue =
       taskMetrics.shuffleWriteMetrics.map { wm =>
-        ("Shuffle Bytes Written" -> wm.shuffleBytesWritten) ~
-        ("Shuffle Write Time" -> wm.shuffleWriteTime) ~
-        ("Shuffle Records Written" -> wm.shuffleRecordsWritten)
+        ("Shuffle Bytes Written" -> wm.bytesWritten) ~
+        ("Shuffle Write Time" -> wm.writeTime) ~
+        ("Shuffle Records Written" -> wm.recordsWritten)
       }.getOrElse(JNothing)
     val inputMetrics: JValue =
       taskMetrics.inputMetrics.map { im =>
@@ -324,7 +357,6 @@ private[spark] object JsonProtocol {
         ("Block ID" -> id.toString) ~
         ("Status" -> blockStatusToJson(status))
       })
-    ("Host Name" -> taskMetrics.hostname) ~
     ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~
     ("Executor Run Time" -> taskMetrics.executorRunTime) ~
     ("Result Size" -> taskMetrics.resultSize) ~
@@ -352,12 +384,12 @@ private[spark] object JsonProtocol {
         ("Message" -> fetchFailed.message)
       case exceptionFailure: ExceptionFailure =>
         val stackTrace = stackTraceToJson(exceptionFailure.stackTrace)
-        val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing)
+        val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList)
         ("Class Name" -> exceptionFailure.className) ~
         ("Description" -> exceptionFailure.description) ~
         ("Stack Trace" -> stackTrace) ~
         ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~
-        ("Metrics" -> metrics)
+        ("Accumulator Updates" -> accumUpdates)
       case taskCommitDenied: TaskCommitDenied =>
         ("Job ID" -> taskCommitDenied.jobID) ~
         ("Partition ID" -> taskCommitDenied.partitionID) ~
@@ -619,14 +651,15 @@ private[spark] object JsonProtocol {
 
   def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = {
     val execInfo = (json \ "Executor ID").extract[String]
-    val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
+    val accumUpdates = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
       val taskId = (json \ "Task ID").extract[Long]
       val stageId = (json \ "Stage ID").extract[Int]
       val stageAttemptId = (json \ "Stage Attempt ID").extract[Int]
-      val metrics = taskMetricsFromJson(json \ "Task Metrics")
-      (taskId, stageId, stageAttemptId, metrics)
+      val updates =
+        (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson)
+      (taskId, stageId, stageAttemptId, updates)
     }
-    SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics)
+    SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates)
   }
 
   /** --------------------------------------------------------------------- *
@@ -647,7 +680,7 @@ private[spark] object JsonProtocol {
     val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
     val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
     val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match {
-      case Some(values) => values.map(accumulableInfoFromJson(_))
+      case Some(values) => values.map(accumulableInfoFromJson)
       case None => Seq[AccumulableInfo]()
     }
 
@@ -675,7 +708,7 @@ private[spark] object JsonProtocol {
     val finishTime = (json \ "Finish Time").extract[Long]
     val failed = (json \ "Failed").extract[Boolean]
     val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
-      case Some(values) => values.map(accumulableInfoFromJson(_))
+      case Some(values) => values.map(accumulableInfoFromJson)
       case None => Seq[AccumulableInfo]()
     }
 
@@ -690,11 +723,43 @@ private[spark] object JsonProtocol {
 
   def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
     val id = (json \ "ID").extract[Long]
-    val name = (json \ "Name").extract[String]
-    val update = Utils.jsonOption(json \ "Update").map(_.extract[String])
-    val value = (json \ "Value").extract[String]
+    val name = (json \ "Name").extractOpt[String]
+    val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) }
+    val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
     val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false)
-    AccumulableInfo(id, name, update, value, internal)
+    val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false)
+    new AccumulableInfo(id, name, update, value, internal, countFailedValues)
+  }
+
+  /**
+   * Deserialize the value of an accumulator from JSON.
+   *
+   * For accumulators representing internal task metrics, this looks up the relevant
+   * [[AccumulatorParam]] to deserialize the value accordingly. For all other
+   * accumulators, this will simply deserialize the value as a string.
+   *
+   * The behavior here must match that of [[accumValueToJson]]. Exposed for testing.
+   */
+  private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = {
+    import AccumulatorParam._
+    if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) {
+      (value, InternalAccumulator.getParam(name.get)) match {
+        case (JInt(v), IntAccumulatorParam) => v.toInt
+        case (JInt(v), LongAccumulatorParam) => v.toLong
+        case (JString(v), StringAccumulatorParam) => v
+        case (JArray(v), UpdatedBlockStatusesAccumulatorParam) =>
+          v.map { blockJson =>
+            val id = BlockId((blockJson \ "Block ID").extract[String])
+            val status = blockStatusFromJson(blockJson \ "Status")
+            (id, status)
+          }
+        case (v, p) =>
+          throw new IllegalArgumentException(s"unexpected combination of accumulator " +
+            s"value in JSON ($v) and accumulator param (${p.getClass.getName}) in '${name.get}'")
+      }
+    } else {
+      value.extract[String]
+    }
   }
 
   def taskMetricsFromJson(json: JValue): TaskMetrics = {
@@ -702,7 +767,6 @@ private[spark] object JsonProtocol {
       return TaskMetrics.empty
     }
     val metrics = new TaskMetrics
-    metrics.setHostname((json \ "Host Name").extract[String])
     metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long])
     metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long])
     metrics.setResultSize((json \ "Result Size").extract[Long])
@@ -787,10 +851,12 @@ private[spark] object JsonProtocol {
         val className = (json \ "Class Name").extract[String]
         val description = (json \ "Description").extract[String]
         val stackTrace = stackTraceFromJson(json \ "Stack Trace")
-        val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
-          map(_.extract[String]).orNull
-        val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
-        ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
+        val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull
+        // Fall back to getting accumulator updates from TaskMetrics, which is what Spark 1.x logged
+        val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
+          .map(_.extract[List[JValue]].map(accumulableInfoFromJson))
+          .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates())
+        ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
       case `taskResultLost` => TaskResultLost
       case `taskKilled` => TaskKilled
       case `taskCommitDenied` =>
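
For readers skimming the JSON changes: below is a minimal, self-contained sketch of the dispatch that accumValueFromJson documents, assuming json4s types and an illustrative metrics prefix. The real code consults InternalAccumulator.getParam and also handles updated block statuses, which this sketch omits.

    import org.json4s._

    object AccumValueSketch {
      // Assumed name prefix for internal metric accumulators; illustrative only.
      val MetricsPrefix = "internal.metrics."

      def fromJson(name: Option[String], value: JValue): Any = {
        if (name.exists(_.startsWith(MetricsPrefix))) {
          // Internal metrics are decoded by value type, mirroring the AccumulatorParam match above.
          value match {
            case JInt(v)    => v.toLong   // counters and timings
            case JString(v) => v          // string-valued metrics
            case other => throw new IllegalArgumentException(s"unexpected accumulator value: $other")
          }
        } else {
          // All other accumulators keep the legacy behavior: values are treated as strings.
          value.values.toString
        }
      }
    }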

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index df9e050..5afd6d6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -682,8 +682,7 @@ private[spark] class ExternalSorter[K, V, C](
 
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
-    context.internalMetricsToAccumulators(
-      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
 
     lengths
   }
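
As a side note on where this value ends up: after this change, peak execution memory flows through the accumulator machinery like any other internal metric, so a driver-side listener can read it back from stage accumulables. A rough sketch, assuming the internal metric key is "internal.metrics.peakExecutionMemory" (the exact constant lives in InternalAccumulator):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

    class PeakMemoryListener extends SparkListener {
      @volatile var peak: Long = 0L
      override def onStageCompleted(stage: SparkListenerStageCompleted): Unit = {
        stage.stageInfo.accumulables.values
          .find(_.name == Some("internal.metrics.peakExecutionMemory")) // assumed key
          .flatMap(_.value)
          .foreach(v => peak = math.max(peak, v.asInstanceOf[Long]))
      }
    }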

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 625fdd5..876c3a2 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -191,8 +191,6 @@ public class UnsafeShuffleWriterSuite {
       });
 
     when(taskContext.taskMetrics()).thenReturn(taskMetrics);
-    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
-
     when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 5b84acf..11c97d7 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,18 +17,22 @@
 
 package org.apache.spark
 
+import javax.annotation.concurrent.GuardedBy
+
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.ref.WeakReference
+import scala.util.control.NonFatal
 
 import org.scalatest.Matchers
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.scheduler._
+import org.apache.spark.serializer.JavaSerializer
 
 
 class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
-  import InternalAccumulator._
+  import AccumulatorParam._
 
   implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
     new AccumulableParam[mutable.Set[A], A] {
@@ -59,7 +63,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     longAcc.value should be (210L + maxInt * 20)
   }
 
-  test ("value not assignable from tasks") {
+  test("value not assignable from tasks") {
     sc = new SparkContext("local", "test")
     val acc : Accumulator[Int] = sc.accumulator(0)
 
@@ -84,7 +88,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     }
   }
 
-  test ("value not readable in tasks") {
+  test("value not readable in tasks") {
     val maxI = 1000
     for (nThreads <- List(1, 10)) { // test single & multi-threaded
       sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -159,193 +163,157 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     assert(!Accumulators.originals.get(accId).isDefined)
   }
 
-  test("internal accumulators in TaskContext") {
+  test("get accum") {
     sc = new SparkContext("local", "test")
-    val accums = InternalAccumulator.create(sc)
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
-    val internalMetricsToAccums = taskContext.internalMetricsToAccumulators
-    val collectedInternalAccums = taskContext.collectInternalAccumulators()
-    val collectedAccums = taskContext.collectAccumulators()
-    assert(internalMetricsToAccums.size > 0)
-    assert(internalMetricsToAccums.values.forall(_.isInternal))
-    assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR))
-    val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR)
-    assert(collectedInternalAccums.size === internalMetricsToAccums.size)
-    assert(collectedInternalAccums.size === collectedAccums.size)
-    assert(collectedInternalAccums.contains(testAccum.id))
-    assert(collectedAccums.contains(testAccum.id))
-  }
+    // Don't register with SparkContext for cleanup
+    var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true)
+    val accId = acc.id
+    val ref = WeakReference(acc)
+    assert(ref.get.isDefined)
+    Accumulators.register(ref.get.get)
 
-  test("internal accumulators in a stage") {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    sc = new SparkContext("local", "test")
-    sc.addSparkListener(listener)
-    // Have each task add 1 to the internal accumulator
-    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
-      TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
-      iter
-    }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      val stageInfos = listener.getCompletedStageInfos
-      val taskInfos = listener.getCompletedTaskInfos
-      assert(stageInfos.size === 1)
-      assert(taskInfos.size === numPartitions)
-      // The accumulator values should be merged in the stage
-      val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
-      assert(stageAccum.value.toLong === numPartitions)
-      // The accumulator should be updated locally on each task
-      val taskAccumValues = taskInfos.map { taskInfo =>
-        val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
-        assert(taskAccum.update.isDefined)
-        assert(taskAccum.update.get.toLong === 1)
-        taskAccum.value.toLong
-      }
-      // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
-      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+    // Remove the explicit reference and let the underlying accumulator be garbage collected
+    acc = null
+    System.gc()
+    assert(ref.get.isEmpty)
+
+    // Getting a garbage collected accum should throw an error
+    intercept[IllegalAccessError] {
+      Accumulators.get(accId)
     }
-    rdd.count()
+
+    // Getting a normal accumulator. Note: this has to be separate because referencing an
+    // accumulator above in an `assert` would keep it from being garbage collected.
+    val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true, true)
+    Accumulators.register(acc2)
+    assert(Accumulators.get(acc2.id) === Some(acc2))
+
+    // Getting an accumulator that does not exist should return None
+    assert(Accumulators.get(100000).isEmpty)
   }
 
-  test("internal accumulators in multiple stages") {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    sc = new SparkContext("local", "test")
-    sc.addSparkListener(listener)
-    // Each stage creates its own set of internal accumulators so the
-    // values for the same metric should not be mixed up across stages
-    val rdd = sc.parallelize(1 to 100, numPartitions)
-      .map { i => (i, i) }
-      .mapPartitions { iter =>
-        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
-        iter
-      }
-      .reduceByKey { case (x, y) => x + y }
-      .mapPartitions { iter =>
-        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10
-        iter
-      }
-      .repartition(numPartitions * 2)
-      .mapPartitions { iter =>
-        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100
-        iter
-      }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      // We ran 3 stages, and the accumulator values should be distinct
-      val stageInfos = listener.getCompletedStageInfos
-      assert(stageInfos.size === 3)
-      val (firstStageAccum, secondStageAccum, thirdStageAccum) =
-        (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR),
-        findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR),
-        findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR))
-      assert(firstStageAccum.value.toLong === numPartitions)
-      assert(secondStageAccum.value.toLong === numPartitions * 10)
-      assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100)
-    }
-    rdd.count()
+  test("only external accums are automatically registered") {
+    val accEx = new Accumulator(0, IntAccumulatorParam, Some("external"), internal = false)
+    val accIn = new Accumulator(0, IntAccumulatorParam, Some("internal"), internal = true)
+    assert(!accEx.isInternal)
+    assert(accIn.isInternal)
+    assert(Accumulators.get(accEx.id).isDefined)
+    assert(Accumulators.get(accIn.id).isEmpty)
   }
 
-  test("internal accumulators in fully resubmitted stages") {
-    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+  test("copy") {
+    val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), true, false)
+    val acc2 = acc1.copy()
+    assert(acc1.id === acc2.id)
+    assert(acc1.value === acc2.value)
+    assert(acc1.name === acc2.name)
+    assert(acc1.isInternal === acc2.isInternal)
+    assert(acc1.countFailedValues === acc2.countFailedValues)
+    assert(acc1 !== acc2)
+    // Modifying one does not affect the other
+    acc1.add(44L)
+    assert(acc1.value === 500L)
+    assert(acc2.value === 456L)
+    acc2.add(144L)
+    assert(acc1.value === 500L)
+    assert(acc2.value === 600L)
   }
 
-  test("internal accumulators in partially resubmitted stages") {
-    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+  test("register multiple accums with same ID") {
+    // Use internal accums here so they are not automatically registered on creation
+    val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true)
+    val acc2 = acc1.copy()
+    assert(acc1 !== acc2)
+    assert(acc1.id === acc2.id)
+    assert(Accumulators.originals.isEmpty)
+    assert(Accumulators.get(acc1.id).isEmpty)
+    Accumulators.register(acc1)
+    Accumulators.register(acc2)
+    // The second one does not override the first one
+    assert(Accumulators.originals.size === 1)
+    assert(Accumulators.get(acc1.id) === Some(acc1))
   }
 
-  /**
-   * Return the accumulable info that matches the specified name.
-   */
-  private def findAccumulableInfo(
-      accums: Iterable[AccumulableInfo],
-      name: String): AccumulableInfo = {
-    accums.find { a => a.name == name }.getOrElse {
-      throw new TestFailedException(s"internal accumulator '$name' not found", 0)
-    }
+  test("string accumulator param") {
+    val acc = new Accumulator("", StringAccumulatorParam, Some("darkness"))
+    assert(acc.value === "")
+    acc.setValue("feeds")
+    assert(acc.value === "feeds")
+    acc.add("your")
+    assert(acc.value === "your") // value is overwritten, not concatenated
+    acc += "soul"
+    assert(acc.value === "soul")
+    acc ++= "with"
+    assert(acc.value === "with")
+    acc.merge("kindness")
+    assert(acc.value === "kindness")
   }
 
-  /**
-   * Test whether internal accumulators are merged properly if some tasks fail.
-   */
-  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    val numFailedPartitions = (0 until numPartitions).count(failCondition)
-    // This says use 1 core and retry tasks up to 2 times
-    sc = new SparkContext("local[1, 2]", "test")
-    sc.addSparkListener(listener)
-    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
-      val taskContext = TaskContext.get()
-      taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
-      // Fail the first attempts of a subset of the tasks
-      if (failCondition(i) && taskContext.attemptNumber() == 0) {
-        throw new Exception("Failing a task intentionally.")
-      }
-      iter
-    }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      val stageInfos = listener.getCompletedStageInfos
-      val taskInfos = listener.getCompletedTaskInfos
-      assert(stageInfos.size === 1)
-      assert(taskInfos.size === numPartitions + numFailedPartitions)
-      val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
-      // We should not double count values in the merged accumulator
-      assert(stageAccum.value.toLong === numPartitions)
-      val taskAccumValues = taskInfos.flatMap { taskInfo =>
-        if (!taskInfo.failed) {
-          // If a task succeeded, its update value should always be 1
-          val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
-          assert(taskAccum.update.isDefined)
-          assert(taskAccum.update.get.toLong === 1)
-          Some(taskAccum.value.toLong)
-        } else {
-          // If a task failed, we should not get its accumulator values
-          assert(taskInfo.accumulables.isEmpty)
-          None
-        }
-      }
-      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
-    }
-    rdd.count()
+  test("list accumulator param") {
+    val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers"))
+    assert(acc.value === Seq.empty[Int])
+    acc.add(Seq(1, 2))
+    assert(acc.value === Seq(1, 2))
+    acc += Seq(3, 4)
+    assert(acc.value === Seq(1, 2, 3, 4))
+    acc ++= Seq(5, 6)
+    assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
+    acc.merge(Seq(7, 8))
+    assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
+    acc.setValue(Seq(9, 10))
+    assert(acc.value === Seq(9, 10))
+  }
+
+  test("value is reset on the executors") {
+    val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false)
+    val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false)
+    val externalAccums = Seq(acc1, acc2)
+    val internalAccums = InternalAccumulator.create()
+    // Set some values; these should not be observed later on the "executors"
+    acc1.setValue(10)
+    acc2.setValue(20L)
+    internalAccums
+      .find(_.name == Some(InternalAccumulator.TEST_ACCUM))
+      .get.asInstanceOf[Accumulator[Long]]
+      .setValue(30L)
+    // Simulate the task being serialized and sent to the executors.
+    val dummyTask = new DummyTask(internalAccums, externalAccums)
+    val serInstance = new JavaSerializer(new SparkConf).newInstance()
+    val taskSer = Task.serializeWithDependencies(
+      dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
+    // Now we're on the executors.
+    // Deserialize the task and assert that its accumulators are zeroed out.
+    val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
+    val taskDeser = serInstance.deserialize[DummyTask](
+      taskBytes, Thread.currentThread.getContextClassLoader)
+    // Assert that executors see only zeros
+    taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) }
+    taskDeser.internalAccums.foreach { a => assert(a.localValue == a.zero) }
   }
 
 }
 
 private[spark] object AccumulatorSuite {
 
+  import InternalAccumulator._
+
   /**
-   * Run one or more Spark jobs and verify that the peak execution memory accumulator
-   * is updated afterwards.
+   * Run one or more Spark jobs and verify that the peak execution memory accumulator is
+   * updated in at least one of them.
    */
   def verifyPeakExecutionMemorySet(
       sc: SparkContext,
       testName: String)(testBody: => Unit): Unit = {
     val listener = new SaveInfoListener
     sc.addSparkListener(listener)
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { jobId =>
-      if (jobId == 0) {
-        // The first job is a dummy one to verify that the accumulator does not already exist
-        val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
-        assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY))
-      } else {
-        // In the subsequent jobs, verify that peak execution memory is updated
-        val accum = listener.getCompletedStageInfos
-          .flatMap(_.accumulables.values)
-          .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
-          .getOrElse {
-          throw new TestFailedException(
-            s"peak execution memory accumulator not set in '$testName'", 0)
-        }
-        assert(accum.value.toLong > 0)
-      }
-    }
-    // Run the jobs
-    sc.parallelize(1 to 10).count()
     testBody
+    val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
+    val isSet = accums.exists { a =>
+      a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
+    }
+    if (!isSet) {
+      throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0)
+    }
   }
 }
 
@@ -357,6 +325,10 @@ private class SaveInfoListener extends SparkListener {
   private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
   private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID
 
+  // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated
+  @GuardedBy("this")
+  private var exception: Throwable = null
+
   def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
   def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
 
@@ -365,9 +337,20 @@ private class SaveInfoListener extends SparkListener {
     jobCompletionCallback = callback
   }
 
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+  /** Throw a stored exception, if any. */
+  def maybeThrowException(): Unit = synchronized {
+    if (exception != null) { throw exception }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
     if (jobCompletionCallback != null) {
-      jobCompletionCallback(jobEnd.jobId)
+      try {
+        jobCompletionCallback(jobEnd.jobId)
+      } catch {
+        // Store any exception thrown here so we can rethrow it later in the main thread.
+        // Otherwise, an exception thrown by `jobCompletionCallback` would not fail the test.
+        case NonFatal(e) => exception = e
+      }
     }
   }
 
@@ -379,3 +362,14 @@ private class SaveInfoListener extends SparkListener {
     completedTaskInfos += taskEnd.taskInfo
   }
 }
+
+
+/**
+ * A dummy [[Task]] that contains internal and external [[Accumulator]]s.
+ */
+private[spark] class DummyTask(
+    val internalAccums: Seq[Accumulator[_]],
+    val externalAccums: Seq[Accumulator[_]])
+  extends Task[Int](0, 0, 0, internalAccums) {
+  override def runTask(c: TaskContext): Int = 1
+}
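
The "value is reset on the executors" test above leans on accumulators restoring their zero value when deserialized on the executor side. A self-contained illustration of that serialization pattern, with made-up names (this is not the actual Accumulable implementation):

    import java.io.ObjectInputStream

    class ResettableCounter(private var current: Long) extends Serializable {
      private val zero: Long = 0L
      def value: Long = current
      def add(delta: Long): Unit = current += delta

      // Invoked by Java serialization when the object is rebuilt on the executor;
      // the driver's partial value is discarded and counting restarts from zero.
      private def readObject(in: ObjectInputStream): Unit = {
        in.defaultReadObject()
        current = zero
      }
    }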

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 4e678fb..80a1de6 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -801,7 +801,7 @@ class ExecutorAllocationManagerSuite
     assert(maxNumExecutorsNeeded(manager) === 1)
 
     // If the task is failed, we expect it to be resubmitted later.
-    val taskEndReason = ExceptionFailure(null, null, null, null, null, None)
+    val taskEndReason = ExceptionFailure(null, null, null, null, None)
     sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null))
     assert(maxNumExecutorsNeeded(manager) === 1)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index c7f629a..3777d77 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -215,14 +215,16 @@ class HeartbeatReceiverSuite
     val metrics = new TaskMetrics
     val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
     val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
-      Heartbeat(executorId, Array(1L -> metrics), blockManagerId))
+      Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId))
     if (executorShouldReregister) {
       assert(response.reregisterBlockManager)
     } else {
       assert(!response.reregisterBlockManager)
       // Additionally verify that the scheduler callback is called with the correct parameters
       verify(scheduler).executorHeartbeatReceived(
-        Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+        Matchers.eq(executorId),
+        Matchers.eq(Array(1L -> metrics.accumulatorUpdates())),
+        Matchers.eq(blockManagerId))
     }
   }
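
As the updated test shows, heartbeats now ship per-task accumulator updates rather than whole TaskMetrics objects. A small sketch of the payload shape, assuming test/driver-side access to TaskMetrics as in the suite above:

    import org.apache.spark.executor.TaskMetrics

    val metrics = new TaskMetrics
    // One entry per running task: (task ID, that task's latest accumulator updates).
    val updates = Array(1L -> metrics.accumulatorUpdates())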
 

