Posted to commits@spark.apache.org by jo...@apache.org on 2016/01/27 20:15:55 UTC

[2/4] spark git commit: [SPARK-12895][SPARK-12896] Migrate TaskMetrics to accumulators

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
new file mode 100644
index 0000000..630b46f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.storage.{BlockId, BlockStatus}
+
+
+class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
+  import InternalAccumulator._
+  import AccumulatorParam._
+
+  test("get param") {
+    assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam)
+    assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam)
+    assert(getParam(RESULT_SIZE) === LongAccumulatorParam)
+    assert(getParam(JVM_GC_TIME) === LongAccumulatorParam)
+    assert(getParam(RESULT_SERIALIZATION_TIME) === LongAccumulatorParam)
+    assert(getParam(MEMORY_BYTES_SPILLED) === LongAccumulatorParam)
+    assert(getParam(DISK_BYTES_SPILLED) === LongAccumulatorParam)
+    assert(getParam(PEAK_EXECUTION_MEMORY) === LongAccumulatorParam)
+    assert(getParam(UPDATED_BLOCK_STATUSES) === UpdatedBlockStatusesAccumulatorParam)
+    assert(getParam(TEST_ACCUM) === LongAccumulatorParam)
+    // shuffle read
+    assert(getParam(shuffleRead.REMOTE_BLOCKS_FETCHED) === IntAccumulatorParam)
+    assert(getParam(shuffleRead.LOCAL_BLOCKS_FETCHED) === IntAccumulatorParam)
+    assert(getParam(shuffleRead.REMOTE_BYTES_READ) === LongAccumulatorParam)
+    assert(getParam(shuffleRead.LOCAL_BYTES_READ) === LongAccumulatorParam)
+    assert(getParam(shuffleRead.FETCH_WAIT_TIME) === LongAccumulatorParam)
+    assert(getParam(shuffleRead.RECORDS_READ) === LongAccumulatorParam)
+    // shuffle write
+    assert(getParam(shuffleWrite.BYTES_WRITTEN) === LongAccumulatorParam)
+    assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam)
+    assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam)
+    // input
+    assert(getParam(input.READ_METHOD) === StringAccumulatorParam)
+    assert(getParam(input.RECORDS_READ) === LongAccumulatorParam)
+    assert(getParam(input.BYTES_READ) === LongAccumulatorParam)
+    // output
+    assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam)
+    assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam)
+    assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam)
+    // default to Long
+    assert(getParam(METRICS_PREFIX + "anything") === LongAccumulatorParam)
+    intercept[IllegalArgumentException] {
+      getParam("something that does not start with the right prefix")
+    }
+  }
+
+  test("create by name") {
+    val executorRunTime = create(EXECUTOR_RUN_TIME)
+    val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES)
+    val shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED)
+    val inputReadMethod = create(input.READ_METHOD)
+    assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME))
+    assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES))
+    assert(shuffleRemoteBlocksRead.name === Some(shuffleRead.REMOTE_BLOCKS_FETCHED))
+    assert(inputReadMethod.name === Some(input.READ_METHOD))
+    assert(executorRunTime.value.isInstanceOf[Long])
+    assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]])
+    // We cannot assert the type of the value directly since the type parameter is erased.
+    // Instead, set a `Seq` of the expected type and see if it fails at runtime.
+    updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)])
+    assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int])
+    assert(inputReadMethod.value.isInstanceOf[String])
+    // default to Long
+    val anything = create(METRICS_PREFIX + "anything")
+    assert(anything.value.isInstanceOf[Long])
+  }
+
+  test("create") {
+    val accums = create()
+    val shuffleReadAccums = createShuffleReadAccums()
+    val shuffleWriteAccums = createShuffleWriteAccums()
+    val inputAccums = createInputAccums()
+    val outputAccums = createOutputAccums()
+    // assert they're all internal
+    assert(accums.forall(_.isInternal))
+    assert(shuffleReadAccums.forall(_.isInternal))
+    assert(shuffleWriteAccums.forall(_.isInternal))
+    assert(inputAccums.forall(_.isInternal))
+    assert(outputAccums.forall(_.isInternal))
+    // assert they all count failed values (countFailedValues = true)
+    assert(accums.forall(_.countFailedValues))
+    assert(shuffleReadAccums.forall(_.countFailedValues))
+    assert(shuffleWriteAccums.forall(_.countFailedValues))
+    assert(inputAccums.forall(_.countFailedValues))
+    assert(outputAccums.forall(_.countFailedValues))
+    // assert they all have names
+    assert(accums.forall(_.name.isDefined))
+    assert(shuffleReadAccums.forall(_.name.isDefined))
+    assert(shuffleWriteAccums.forall(_.name.isDefined))
+    assert(inputAccums.forall(_.name.isDefined))
+    assert(outputAccums.forall(_.name.isDefined))
+    // assert `accums` is a strict superset of the others
+    val accumNames = accums.map(_.name.get).toSet
+    val shuffleReadAccumNames = shuffleReadAccums.map(_.name.get).toSet
+    val shuffleWriteAccumNames = shuffleWriteAccums.map(_.name.get).toSet
+    val inputAccumNames = inputAccums.map(_.name.get).toSet
+    val outputAccumNames = outputAccums.map(_.name.get).toSet
+    assert(shuffleReadAccumNames.subsetOf(accumNames))
+    assert(shuffleWriteAccumNames.subsetOf(accumNames))
+    assert(inputAccumNames.subsetOf(accumNames))
+    assert(outputAccumNames.subsetOf(accumNames))
+  }
+
+  test("naming") {
+    val accums = create()
+    val shuffleReadAccums = createShuffleReadAccums()
+    val shuffleWriteAccums = createShuffleWriteAccums()
+    val inputAccums = createInputAccums()
+    val outputAccums = createOutputAccums()
+    // assert that prefixes are properly namespaced
+    assert(SHUFFLE_READ_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(SHUFFLE_WRITE_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(INPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(OUTPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(accums.forall(_.name.get.startsWith(METRICS_PREFIX)))
+    // assert they all start with the expected prefixes
+    assert(shuffleReadAccums.forall(_.name.get.startsWith(SHUFFLE_READ_METRICS_PREFIX)))
+    assert(shuffleWriteAccums.forall(_.name.get.startsWith(SHUFFLE_WRITE_METRICS_PREFIX)))
+    assert(inputAccums.forall(_.name.get.startsWith(INPUT_METRICS_PREFIX)))
+    assert(outputAccums.forall(_.name.get.startsWith(OUTPUT_METRICS_PREFIX)))
+  }
+
+  test("internal accumulators in TaskContext") {
+    val taskContext = TaskContext.empty()
+    val accumUpdates = taskContext.taskMetrics.accumulatorUpdates()
+    assert(accumUpdates.size > 0)
+    assert(accumUpdates.forall(_.internal))
+    val testAccum = taskContext.taskMetrics.getAccum(TEST_ACCUM)
+    assert(accumUpdates.exists(_.id == testAccum.id))
+  }
+
+  test("internal accumulators in a stage") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Have each task add 1 to the internal accumulator
+    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
+      TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1
+      iter
+    }
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { _ =>
+      val stageInfos = listener.getCompletedStageInfos
+      val taskInfos = listener.getCompletedTaskInfos
+      assert(stageInfos.size === 1)
+      assert(taskInfos.size === numPartitions)
+      // The accumulator values should be merged in the stage
+      val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
+      assert(stageAccum.value.get.asInstanceOf[Long] === numPartitions)
+      // The accumulator should be updated locally on each task
+      val taskAccumValues = taskInfos.map { taskInfo =>
+        val taskAccum = findTestAccum(taskInfo.accumulables)
+        assert(taskAccum.update.isDefined)
+        assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
+        taskAccum.value.get.asInstanceOf[Long]
+      }
+      // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
+      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+    }
+    rdd.count()
+  }
+
+  test("internal accumulators in multiple stages") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Each stage creates its own set of internal accumulators so the
+    // values for the same metric should not be mixed up across stages
+    val rdd = sc.parallelize(1 to 100, numPartitions)
+      .map { i => (i, i) }
+      .mapPartitions { iter =>
+        TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1
+        iter
+      }
+      .reduceByKey { case (x, y) => x + y }
+      .mapPartitions { iter =>
+        TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 10
+        iter
+      }
+      .repartition(numPartitions * 2)
+      .mapPartitions { iter =>
+        TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 100
+        iter
+      }
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { _ =>
+      // We ran 3 stages, and the accumulator values should be distinct
+      val stageInfos = listener.getCompletedStageInfos
+      assert(stageInfos.size === 3)
+      val (firstStageAccum, secondStageAccum, thirdStageAccum) =
+        (findTestAccum(stageInfos(0).accumulables.values),
+        findTestAccum(stageInfos(1).accumulables.values),
+        findTestAccum(stageInfos(2).accumulables.values))
+      assert(firstStageAccum.value.get.asInstanceOf[Long] === numPartitions)
+      assert(secondStageAccum.value.get.asInstanceOf[Long] === numPartitions * 10)
+      assert(thirdStageAccum.value.get.asInstanceOf[Long] === numPartitions * 2 * 100)
+    }
+    rdd.count()
+  }
+
+  // TODO: these two tests are incorrect; they don't actually trigger stage retries.
+  ignore("internal accumulators in fully resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+  }
+
+  ignore("internal accumulators in partially resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+  }
+
+  test("internal accumulators are registered for cleanups") {
+    sc = new SparkContext("local", "test") {
+      private val myCleaner = new SaveAccumContextCleaner(this)
+      override def cleaner: Option[ContextCleaner] = Some(myCleaner)
+    }
+    assert(Accumulators.originals.isEmpty)
+    sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
+    val internalAccums = InternalAccumulator.create()
+    // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
+    assert(Accumulators.originals.size === internalAccums.size * 2)
+    val accumsRegistered = sc.cleaner match {
+      case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
+      case _ => Seq.empty[Long]
+    }
+    // Make sure the same set of accumulators is registered for cleanup
+    assert(accumsRegistered.size === internalAccums.size * 2)
+    assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet)
+  }
+
+  /**
+   * Return the accumulable info that matches the specified name.
+   */
+  private def findTestAccum(accums: Iterable[AccumulableInfo]): AccumulableInfo = {
+    accums.find { a => a.name == Some(TEST_ACCUM) }.getOrElse {
+      fail(s"unable to find internal accumulator called $TEST_ACCUM")
+    }
+  }
+
+  /**
+   * Test whether internal accumulators are merged properly if some tasks fail.
+   * TODO: make this actually retry the stage.
+   */
+  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    val numFailedPartitions = (0 until numPartitions).count(failCondition)
+    // "local[1, 2]" means use 1 core and allow up to 2 attempts per task (i.e. one retry)
+    sc = new SparkContext("local[1, 2]", "test")
+    sc.addSparkListener(listener)
+    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
+      val taskContext = TaskContext.get()
+      taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1
+      // Fail the first attempts of a subset of the tasks
+      if (failCondition(i) && taskContext.attemptNumber() == 0) {
+        throw new Exception("Failing a task intentionally.")
+      }
+      iter
+    }
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { _ =>
+      val stageInfos = listener.getCompletedStageInfos
+      val taskInfos = listener.getCompletedTaskInfos
+      assert(stageInfos.size === 1)
+      assert(taskInfos.size === numPartitions + numFailedPartitions)
+      val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
+      // If all partitions failed, then we would resubmit the whole stage again and create a
+      // fresh set of internal accumulators. Otherwise, these internal accumulators do count
+      // failed values, so we must include the failed values.
+      val expectedAccumValue =
+        if (numPartitions == numFailedPartitions) {
+          numPartitions
+        } else {
+          numPartitions + numFailedPartitions
+        }
+      assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue)
+      val taskAccumValues = taskInfos.flatMap { taskInfo =>
+        if (!taskInfo.failed) {
+          // If a task succeeded, its update value should always be 1
+          val taskAccum = findTestAccum(taskInfo.accumulables)
+          assert(taskAccum.update.isDefined)
+          assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
+          assert(taskAccum.value.isDefined)
+          Some(taskAccum.value.get.asInstanceOf[Long])
+        } else {
+          // If a task failed, we should not get its accumulator values
+          assert(taskInfo.accumulables.isEmpty)
+          None
+        }
+      }
+      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+    }
+    rdd.count()
+    listener.maybeThrowException()
+  }
+
+  /**
+   * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
+   */
+  private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
+    private val accumsRegistered = new ArrayBuffer[Long]
+
+    override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+      accumsRegistered += a.id
+      super.registerAccumulatorForCleanup(a)
+    }
+
+    def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray
+  }
+
+}
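
Note on the SaveInfoListener used throughout this suite: it is defined elsewhere in Spark's test sources, not in this diff. The simplified stand-in below is only a sketch of the pattern the suite depends on (the class name, callback signature, and fields are illustrative, not the actual implementation): buffer completed stage and task infos, run the registered assertions only after the job has ended, and remember any assertion failure so maybeThrowException() can rethrow it on the test thread, since exceptions thrown inside listener callbacks do not propagate to the test thread on their own.

package org.apache.spark

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted,
  SparkListenerTaskEnd, StageInfo, TaskInfo}

// Hypothetical, simplified stand-in for the SaveInfoListener referenced above.
class SimplifiedSaveInfoListener extends SparkListener {
  private val completedStageInfos = new ArrayBuffer[StageInfo]
  private val completedTaskInfos = new ArrayBuffer[TaskInfo]
  // Assertions registered by the test; run once the job has fully completed.
  @volatile private var jobCompletionCallback: Int => Unit = null
  // Failures raised inside listener callbacks are remembered here so the test
  // thread can rethrow them explicitly.
  @volatile private var callbackFailure: Throwable = null

  def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toSeq
  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toSeq

  def registerJobCompletionCallback(callback: Int => Unit): Unit = {
    jobCompletionCallback = callback
  }

  def maybeThrowException(): Unit = {
    if (callbackFailure != null) throw callbackFailure
  }

  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    completedStageInfos += stageCompleted.stageInfo
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    completedTaskInfos += taskEnd.taskInfo
  }

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    if (jobCompletionCallback != null) {
      try jobCompletionCallback(jobEnd.jobId) catch {
        case t: Throwable => callbackFailure = t
      }
    }
  }
}

This deferred-assertion pattern is why the failed-task helper above calls listener.maybeThrowException() after rdd.count().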

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 9be9db0..d3359c7 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -42,6 +42,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
       test()
     } finally {
       logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
+      // Avoid leaking map entries in tests that use accumulators without a SparkContext
+      Accumulators.clear()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index e5ec2aa..15be0b1 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -17,12 +17,542 @@
 
 package org.apache.spark.executor
 
-import org.apache.spark.SparkFunSuite
+import org.scalatest.Assertions
+
+import org.apache.spark._
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId}
+
 
 class TaskMetricsSuite extends SparkFunSuite {
-  test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") {
-    val taskMetrics = new TaskMetrics()
-    taskMetrics.mergeShuffleReadMetrics()
-    assert(taskMetrics.shuffleReadMetrics.isEmpty)
+  import AccumulatorParam._
+  import InternalAccumulator._
+  import StorageLevel._
+  import TaskMetricsSuite._
+
+  test("create") {
+    val internalAccums = InternalAccumulator.create()
+    val tm1 = new TaskMetrics
+    val tm2 = new TaskMetrics(internalAccums)
+    assert(tm1.accumulatorUpdates().size === internalAccums.size)
+    assert(tm1.shuffleReadMetrics.isEmpty)
+    assert(tm1.shuffleWriteMetrics.isEmpty)
+    assert(tm1.inputMetrics.isEmpty)
+    assert(tm1.outputMetrics.isEmpty)
+    assert(tm2.accumulatorUpdates().size === internalAccums.size)
+    assert(tm2.shuffleReadMetrics.isEmpty)
+    assert(tm2.shuffleWriteMetrics.isEmpty)
+    assert(tm2.inputMetrics.isEmpty)
+    assert(tm2.outputMetrics.isEmpty)
+    // The TaskMetrics constructor expects a minimal set of initial accumulators
+    intercept[IllegalArgumentException] { new TaskMetrics(Seq.empty[Accumulator[_]]) }
+  }
+
+  test("create with unnamed accum") {
+    intercept[IllegalArgumentException] {
+      new TaskMetrics(
+        InternalAccumulator.create() ++ Seq(
+          new Accumulator(0, IntAccumulatorParam, None, internal = true)))
+    }
+  }
+
+  test("create with duplicate name accum") {
+    intercept[IllegalArgumentException] {
+      new TaskMetrics(
+        InternalAccumulator.create() ++ Seq(
+          new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true)))
+    }
+  }
+
+  test("create with external accum") {
+    intercept[IllegalArgumentException] {
+      new TaskMetrics(
+        InternalAccumulator.create() ++ Seq(
+          new Accumulator(0, IntAccumulatorParam, Some("x"))))
+    }
+  }
+
+  test("create shuffle read metrics") {
+    import shuffleRead._
+    val accums = InternalAccumulator.createShuffleReadAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(REMOTE_BLOCKS_FETCHED).setValueAny(1)
+    accums(LOCAL_BLOCKS_FETCHED).setValueAny(2)
+    accums(REMOTE_BYTES_READ).setValueAny(3L)
+    accums(LOCAL_BYTES_READ).setValueAny(4L)
+    accums(FETCH_WAIT_TIME).setValueAny(5L)
+    accums(RECORDS_READ).setValueAny(6L)
+    val sr = new ShuffleReadMetrics(accums)
+    assert(sr.remoteBlocksFetched === 1)
+    assert(sr.localBlocksFetched === 2)
+    assert(sr.remoteBytesRead === 3L)
+    assert(sr.localBytesRead === 4L)
+    assert(sr.fetchWaitTime === 5L)
+    assert(sr.recordsRead === 6L)
+  }
+
+  test("create shuffle write metrics") {
+    import shuffleWrite._
+    val accums = InternalAccumulator.createShuffleWriteAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(BYTES_WRITTEN).setValueAny(1L)
+    accums(RECORDS_WRITTEN).setValueAny(2L)
+    accums(WRITE_TIME).setValueAny(3L)
+    val sw = new ShuffleWriteMetrics(accums)
+    assert(sw.bytesWritten === 1L)
+    assert(sw.recordsWritten === 2L)
+    assert(sw.writeTime === 3L)
+  }
+
+  test("create input metrics") {
+    import input._
+    val accums = InternalAccumulator.createInputAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(BYTES_READ).setValueAny(1L)
+    accums(RECORDS_READ).setValueAny(2L)
+    accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString)
+    val im = new InputMetrics(accums)
+    assert(im.bytesRead === 1L)
+    assert(im.recordsRead === 2L)
+    assert(im.readMethod === DataReadMethod.Hadoop)
+  }
+
+  test("create output metrics") {
+    import output._
+    val accums = InternalAccumulator.createOutputAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(BYTES_WRITTEN).setValueAny(1L)
+    accums(RECORDS_WRITTEN).setValueAny(2L)
+    accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString)
+    val om = new OutputMetrics(accums)
+    assert(om.bytesWritten === 1L)
+    assert(om.recordsWritten === 2L)
+    assert(om.writeMethod === DataWriteMethod.Hadoop)
+  }
+
+  test("mutating values") {
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    // initial values
+    assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L)
+    assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 0L)
+    assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 0L)
+    assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 0L)
+    assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 0L)
+    assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 0L)
+    assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 0L)
+    assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 0L)
+    assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES,
+      Seq.empty[(BlockId, BlockStatus)])
+    // set or increment values
+    tm.setExecutorDeserializeTime(100L)
+    tm.setExecutorDeserializeTime(1L) // overwrite
+    tm.setExecutorRunTime(200L)
+    tm.setExecutorRunTime(2L)
+    tm.setResultSize(300L)
+    tm.setResultSize(3L)
+    tm.setJvmGCTime(400L)
+    tm.setJvmGCTime(4L)
+    tm.setResultSerializationTime(500L)
+    tm.setResultSerializationTime(5L)
+    tm.incMemoryBytesSpilled(600L)
+    tm.incMemoryBytesSpilled(6L) // add
+    tm.incDiskBytesSpilled(700L)
+    tm.incDiskBytesSpilled(7L)
+    tm.incPeakExecutionMemory(800L)
+    tm.incPeakExecutionMemory(8L)
+    val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L))
+    val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L))
+    tm.incUpdatedBlockStatuses(Seq(block1))
+    tm.incUpdatedBlockStatuses(Seq(block2))
+    // assert new values exist
+    assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 1L)
+    assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 2L)
+    assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 3L)
+    assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 4L)
+    assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 5L)
+    assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 606L)
+    assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 707L)
+    assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 808L)
+    assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES,
+      Seq(block1, block2))
+  }
+
+  test("mutating shuffle read metrics values") {
+    import shuffleRead._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value)
+    }
+    // create shuffle read metrics
+    assert(tm.shuffleReadMetrics.isEmpty)
+    tm.registerTempShuffleReadMetrics()
+    tm.mergeShuffleReadMetrics()
+    assert(tm.shuffleReadMetrics.isDefined)
+    val sr = tm.shuffleReadMetrics.get
+    // initial values
+    assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0)
+    assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0)
+    assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 0L)
+    assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 0L)
+    assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 0L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 0L)
+    // set and increment values
+    sr.setRemoteBlocksFetched(100)
+    sr.setRemoteBlocksFetched(10)
+    sr.incRemoteBlocksFetched(1) // 10 + 1
+    sr.incRemoteBlocksFetched(1) // 10 + 1 + 1
+    sr.setLocalBlocksFetched(200)
+    sr.setLocalBlocksFetched(20)
+    sr.incLocalBlocksFetched(2)
+    sr.incLocalBlocksFetched(2)
+    sr.setRemoteBytesRead(300L)
+    sr.setRemoteBytesRead(30L)
+    sr.incRemoteBytesRead(3L)
+    sr.incRemoteBytesRead(3L)
+    sr.setLocalBytesRead(400L)
+    sr.setLocalBytesRead(40L)
+    sr.incLocalBytesRead(4L)
+    sr.incLocalBytesRead(4L)
+    sr.setFetchWaitTime(500L)
+    sr.setFetchWaitTime(50L)
+    sr.incFetchWaitTime(5L)
+    sr.incFetchWaitTime(5L)
+    sr.setRecordsRead(600L)
+    sr.setRecordsRead(60L)
+    sr.incRecordsRead(6L)
+    sr.incRecordsRead(6L)
+    // assert new values exist
+    assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 12)
+    assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 24)
+    assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 36L)
+    assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 48L)
+    assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 60L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 72L)
+  }
+
+  test("mutating shuffle write metrics values") {
+    import shuffleWrite._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value)
+    }
+    // create shuffle write metrics
+    assert(tm.shuffleWriteMetrics.isEmpty)
+    tm.registerShuffleWriteMetrics()
+    assert(tm.shuffleWriteMetrics.isDefined)
+    val sw = tm.shuffleWriteMetrics.get
+    // initial values
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L)
+    assertValEquals(_.writeTime, WRITE_TIME, 0L)
+    // increment and decrement values
+    sw.incBytesWritten(100L)
+    sw.incBytesWritten(10L) // 100 + 10
+    sw.decBytesWritten(1L) // 100 + 10 - 1
+    sw.decBytesWritten(1L) // 100 + 10 - 1 - 1
+    sw.incRecordsWritten(200L)
+    sw.incRecordsWritten(20L)
+    sw.decRecordsWritten(2L)
+    sw.decRecordsWritten(2L)
+    sw.incWriteTime(300L)
+    sw.incWriteTime(30L)
+    // assert new values exist
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 108L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 216L)
+    assertValEquals(_.writeTime, WRITE_TIME, 330L)
+  }
+
+  test("mutating input metrics values") {
+    import input._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value,
+        (x: Any, y: Any) => assert(x.toString === y.toString))
+    }
+    // create input metrics
+    assert(tm.inputMetrics.isEmpty)
+    tm.registerInputMetrics(DataReadMethod.Memory)
+    assert(tm.inputMetrics.isDefined)
+    val in = tm.inputMetrics.get
+    // initial values
+    assertValEquals(_.bytesRead, BYTES_READ, 0L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 0L)
+    assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory)
+    // set and increment values
+    in.setBytesRead(1L)
+    in.setBytesRead(2L)
+    in.incRecordsRead(1L)
+    in.incRecordsRead(2L)
+    in.setReadMethod(DataReadMethod.Disk)
+    // assert new values exist
+    assertValEquals(_.bytesRead, BYTES_READ, 2L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 3L)
+    assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk)
+  }
+
+  test("mutating output metrics values") {
+    import output._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value,
+        (x: Any, y: Any) => assert(x.toString === y.toString))
+    }
+    // create output metrics
+    assert(tm.outputMetrics.isEmpty)
+    tm.registerOutputMetrics(DataWriteMethod.Hadoop)
+    assert(tm.outputMetrics.isDefined)
+    val out = tm.outputMetrics.get
+    // initial values
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L)
+    assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop)
+    // set values
+    out.setBytesWritten(1L)
+    out.setBytesWritten(2L)
+    out.setRecordsWritten(3L)
+    out.setRecordsWritten(4L)
+    out.setWriteMethod(DataWriteMethod.Hadoop)
+    // assert new values exist
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L)
+    // Note: this doesn't actually test anything, but there's only one DataWriteMethod
+    // so we can't set it to anything else
+    assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop)
+  }
+
+  test("merging multiple shuffle read metrics") {
+    val tm = new TaskMetrics
+    assert(tm.shuffleReadMetrics.isEmpty)
+    val sr1 = tm.registerTempShuffleReadMetrics()
+    val sr2 = tm.registerTempShuffleReadMetrics()
+    val sr3 = tm.registerTempShuffleReadMetrics()
+    assert(tm.shuffleReadMetrics.isEmpty)
+    sr1.setRecordsRead(10L)
+    sr2.setRecordsRead(10L)
+    sr1.setFetchWaitTime(1L)
+    sr2.setFetchWaitTime(2L)
+    sr3.setFetchWaitTime(3L)
+    tm.mergeShuffleReadMetrics()
+    assert(tm.shuffleReadMetrics.isDefined)
+    val sr = tm.shuffleReadMetrics.get
+    assert(sr.remoteBlocksFetched === 0L)
+    assert(sr.recordsRead === 20L)
+    assert(sr.fetchWaitTime === 6L)
+
+    // SPARK-5701: calling merge without any shuffle deps does nothing
+    val tm2 = new TaskMetrics
+    tm2.mergeShuffleReadMetrics()
+    assert(tm2.shuffleReadMetrics.isEmpty)
+  }
+
+  test("register multiple shuffle write metrics") {
+    val tm = new TaskMetrics
+    val sw1 = tm.registerShuffleWriteMetrics()
+    val sw2 = tm.registerShuffleWriteMetrics()
+    assert(sw1 === sw2)
+    assert(tm.shuffleWriteMetrics === Some(sw1))
+  }
+
+  test("register multiple input metrics") {
+    val tm = new TaskMetrics
+    val im1 = tm.registerInputMetrics(DataReadMethod.Memory)
+    val im2 = tm.registerInputMetrics(DataReadMethod.Memory)
+    // input metrics with a different read method than the one already registered are ignored
+    val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop)
+    assert(im1 === im2)
+    assert(im1 !== im3)
+    assert(tm.inputMetrics === Some(im1))
+    im2.setBytesRead(50L)
+    im3.setBytesRead(100L)
+    assert(tm.inputMetrics.get.bytesRead === 50L)
+  }
+
+  test("register multiple output metrics") {
+    val tm = new TaskMetrics
+    val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop)
+    val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop)
+    assert(om1 === om2)
+    assert(tm.outputMetrics === Some(om1))
+  }
+
+  test("additional accumulables") {
+    val internalAccums = InternalAccumulator.create()
+    val tm = new TaskMetrics(internalAccums)
+    assert(tm.accumulatorUpdates().size === internalAccums.size)
+    val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a"))
+    val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b"))
+    val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c"))
+    val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"),
+      internal = true, countFailedValues = true)
+    tm.registerAccumulator(acc1)
+    tm.registerAccumulator(acc2)
+    tm.registerAccumulator(acc3)
+    tm.registerAccumulator(acc4)
+    acc1 += 1
+    acc2 += 2
+    val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap
+    assert(newUpdates.contains(acc1.id))
+    assert(newUpdates.contains(acc2.id))
+    assert(newUpdates.contains(acc3.id))
+    assert(newUpdates.contains(acc4.id))
+    assert(newUpdates(acc1.id).name === Some("a"))
+    assert(newUpdates(acc2.id).name === Some("b"))
+    assert(newUpdates(acc3.id).name === Some("c"))
+    assert(newUpdates(acc4.id).name === Some("d"))
+    assert(newUpdates(acc1.id).update === Some(1))
+    assert(newUpdates(acc2.id).update === Some(2))
+    assert(newUpdates(acc3.id).update === Some(0))
+    assert(newUpdates(acc4.id).update === Some(0))
+    assert(!newUpdates(acc3.id).internal)
+    assert(!newUpdates(acc3.id).countFailedValues)
+    assert(newUpdates(acc4.id).internal)
+    assert(newUpdates(acc4.id).countFailedValues)
+    assert(newUpdates.values.map(_.update).forall(_.isDefined))
+    assert(newUpdates.values.map(_.value).forall(_.isEmpty))
+    assert(newUpdates.size === internalAccums.size + 4)
+  }
+
+  test("existing values in shuffle read accums") {
+    // set shuffle read accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME))
+    assert(srAccum.isDefined)
+    srAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm = new TaskMetrics(accums)
+    assert(tm.shuffleReadMetrics.isDefined)
+    assert(tm.shuffleWriteMetrics.isEmpty)
+    assert(tm.inputMetrics.isEmpty)
+    assert(tm.outputMetrics.isEmpty)
+  }
+
+  test("existing values in shuffle write accums") {
+    // set shuffle write accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN))
+    assert(swAccum.isDefined)
+    swAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm = new TaskMetrics(accums)
+    assert(tm.shuffleReadMetrics.isEmpty)
+    assert(tm.shuffleWriteMetrics.isDefined)
+    assert(tm.inputMetrics.isEmpty)
+    assert(tm.outputMetrics.isEmpty)
+  }
+
+  test("existing values in input accums") {
+    // set input accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val inAccum = accums.find(_.name === Some(input.RECORDS_READ))
+    assert(inAccum.isDefined)
+    inAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm = new TaskMetrics(accums)
+    assert(tm.shuffleReadMetrics.isEmpty)
+    assert(tm.shuffleWriteMetrics.isEmpty)
+    assert(tm.inputMetrics.isDefined)
+    assert(tm.outputMetrics.isEmpty)
   }
+
+  test("existing values in output accums") {
+    // set output accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN))
+    assert(outAccum.isDefined)
+    outAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm4 = new TaskMetrics(accums)
+    assert(tm4.shuffleReadMetrics.isEmpty)
+    assert(tm4.shuffleWriteMetrics.isEmpty)
+    assert(tm4.inputMetrics.isEmpty)
+    assert(tm4.outputMetrics.isDefined)
+  }
+
+  test("from accumulator updates") {
+    val accumUpdates1 = InternalAccumulator.create().map { a =>
+      AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues)
+    }
+    val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1)
+    assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1)
+    // Test this with additional accumulators. Only the ones registered with `Accumulators`
+    // will show up in the reconstructed TaskMetrics. In practice, all accumulators created
+    // on the driver, internal or not, should be registered with `Accumulators` at some point.
+    // Here we show that reconstruction will succeed even if there are unregistered accumulators.
+    val param = IntAccumulatorParam
+    val registeredAccums = Seq(
+      new Accumulator(0, param, Some("a"), internal = true, countFailedValues = true),
+      new Accumulator(0, param, Some("b"), internal = true, countFailedValues = false),
+      new Accumulator(0, param, Some("c"), internal = false, countFailedValues = true),
+      new Accumulator(0, param, Some("d"), internal = false, countFailedValues = false))
+    val unregisteredAccums = Seq(
+      new Accumulator(0, param, Some("e"), internal = true, countFailedValues = true),
+      new Accumulator(0, param, Some("f"), internal = true, countFailedValues = false))
+    registeredAccums.foreach(Accumulators.register)
+    registeredAccums.foreach { a => assert(Accumulators.originals.contains(a.id)) }
+    unregisteredAccums.foreach { a => assert(!Accumulators.originals.contains(a.id)) }
+    // set some values in these accums
+    registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
+    unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
+    val registeredAccumInfos = registeredAccums.map(makeInfo)
+    val unregisteredAccumInfos = unregisteredAccums.map(makeInfo)
+    val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos
+    val metrics2 = TaskMetrics.fromAccumulatorUpdates(accumUpdates2)
+    // accumulators that were not registered with `Accumulators` will not show up
+    assertUpdatesEquals(metrics2.accumulatorUpdates(), accumUpdates1 ++ registeredAccumInfos)
+  }
+}
+
+
+private[spark] object TaskMetricsSuite extends Assertions {
+
+  /**
+   * Assert that the following three things are equal to `value`:
+   *   (1) TaskMetrics value
+   *   (2) TaskMetrics accumulator update value
+   *   (3) Original accumulator value
+   */
+  def assertValueEquals(
+      tm: TaskMetrics,
+      tmValue: TaskMetrics => Any,
+      accums: Seq[Accumulator[_]],
+      metricName: String,
+      value: Any,
+      assertEquals: (Any, Any) => Unit = (x: Any, y: Any) => assert(x === y)): Unit = {
+    assertEquals(tmValue(tm), value)
+    val accum = accums.find(_.name == Some(metricName))
+    assert(accum.isDefined)
+    assertEquals(accum.get.value, value)
+    val accumUpdate = tm.accumulatorUpdates().find(_.name == Some(metricName))
+    assert(accumUpdate.isDefined)
+    assert(accumUpdate.get.value === None)
+    assertEquals(accumUpdate.get.update, Some(value))
+  }
+
+  /**
+   * Assert that two lists of accumulator updates are equal.
+   * Note: this does NOT check accumulator ID equality.
+   */
+  def assertUpdatesEquals(
+      updates1: Seq[AccumulableInfo],
+      updates2: Seq[AccumulableInfo]): Unit = {
+    assert(updates1.size === updates2.size)
+    updates1.zip(updates2).foreach { case (info1, info2) =>
+      // do not assert ID equals here
+      assert(info1.name === info2.name)
+      assert(info1.update === info2.update)
+      assert(info1.value === info2.value)
+      assert(info1.internal === info2.internal)
+      assert(info1.countFailedValues === info2.countFailedValues)
+    }
+  }
+
+  /**
+   * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
+   * info as an accumulator update.
+   */
+  def makeInfo(a: Accumulable[_, _]): AccumulableInfo = {
+    new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues)
+  }
+
 }
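
The private[spark] TaskMetricsSuite companion object above exposes assertValueEquals, assertUpdatesEquals and makeInfo for reuse by other suites. The sketch below is a minimal, hypothetical usage example (the suite name is invented; everything else uses only APIs introduced in this diff) of the three-way agreement assertValueEquals enforces: the TaskMetrics getter, the backing named accumulator, and the pending update reported by accumulatorUpdates().

package org.apache.spark.executor

import org.apache.spark.{InternalAccumulator, SparkFunSuite}
import org.apache.spark.InternalAccumulator.EXECUTOR_RUN_TIME

// Hypothetical example suite, not part of this commit.
class TaskMetricsUsageExampleSuite extends SparkFunSuite {
  test("getter, accumulator, and reported update agree") {
    // Build a TaskMetrics backed by the standard set of internal accumulators.
    val accums = InternalAccumulator.create()
    val tm = new TaskMetrics(accums)
    // The setter writes through to the accumulator named EXECUTOR_RUN_TIME.
    tm.setExecutorRunTime(42L)
    // Checks 42L against the TaskMetrics getter, the backing accumulator,
    // and the update reported by tm.accumulatorUpdates().
    TaskMetricsSuite.assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 42L)
  }
}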

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 0e60cc8..2b5e4b8 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -31,7 +31,6 @@ object MemoryTestingUtils {
       taskAttemptId = 0,
       attemptNumber = 0,
       taskMemoryManager = taskMemoryManager,
-      metricsSystem = env.metricsSystem,
-      internalAccumulators = Seq.empty)
+      metricsSystem = env.metricsSystem)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 370a284..d9c71ec 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
 
-import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
@@ -96,8 +95,7 @@ class MyRDD(
 
 class DAGSchedulerSuiteDummyException extends Exception
 
-class DAGSchedulerSuite
-  extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts {
+class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts {
 
   val conf = new SparkConf
   /** Set of TaskSets the DAGScheduler has requested executed. */
@@ -111,8 +109,10 @@ class DAGSchedulerSuite
     override def schedulingMode: SchedulingMode = SchedulingMode.NONE
     override def start() = {}
     override def stop() = {}
-    override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
-      blockManagerId: BlockManagerId): Boolean = true
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+        blockManagerId: BlockManagerId): Boolean = true
     override def submitTasks(taskSet: TaskSet) = {
       // normally done by TaskSetManager
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
@@ -189,7 +189,8 @@ class DAGSchedulerSuite
     override def jobFailed(exception: Exception): Unit = { failure = exception }
   }
 
-  before {
+  override def beforeEach(): Unit = {
+    super.beforeEach()
     sc = new SparkContext("local", "DAGSchedulerSuite")
     sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
@@ -202,17 +203,21 @@ class DAGSchedulerSuite
     results.clear()
     mapOutputTracker = new MapOutputTrackerMaster(conf)
     scheduler = new DAGScheduler(
-        sc,
-        taskScheduler,
-        sc.listenerBus,
-        mapOutputTracker,
-        blockManagerMaster,
-        sc.env)
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env)
     dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
   }
 
-  after {
-    scheduler.stop()
+  override def afterEach(): Unit = {
+    try {
+      scheduler.stop()
+    } finally {
+      super.afterEach()
+    }
   }
 
   override def afterAll() {
@@ -242,26 +247,31 @@ class DAGSchedulerSuite
    * directly through CompletionEvents.
    */
   private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
-     it.next.asInstanceOf[Tuple2[_, _]]._1
+    it.next.asInstanceOf[Tuple2[_, _]]._1
 
   /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
   private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(
-          taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null))
+        runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2))
       }
     }
   }
 
-  private def completeWithAccumulator(accumId: Long, taskSet: TaskSet,
-                                      results: Seq[(TaskEndReason, Any)]) {
+  private def completeWithAccumulator(
+      accumId: Long,
+      taskSet: TaskSet,
+      results: Seq[(TaskEndReason, Any)]) {
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2,
-          Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null))
+        runEvent(makeCompletionEvent(
+          taskSet.tasks(i),
+          result._1,
+          result._2,
+          Seq(new AccumulableInfo(
+            accumId, Some(""), Some(1), None, internal = false, countFailedValues = false))))
       }
     }
   }
@@ -338,9 +348,12 @@ class DAGSchedulerSuite
   }
 
   test("equals and hashCode AccumulableInfo") {
-    val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true)
-    val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false)
-    val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false)
+    val accInfo1 = new AccumulableInfo(
+      1, Some("a1"), Some("delta1"), Some("val1"), internal = true, countFailedValues = false)
+    val accInfo2 = new AccumulableInfo(
+      1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false)
+    val accInfo3 = new AccumulableInfo(
+      1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false)
     assert(accInfo1 !== accInfo2)
     assert(accInfo2 === accInfo3)
     assert(accInfo2.hashCode() === accInfo3.hashCode())
@@ -464,7 +477,7 @@ class DAGSchedulerSuite
       override def defaultParallelism(): Int = 2
       override def executorHeartbeatReceived(
           execId: String,
-          taskMetrics: Array[(Long, TaskMetrics)],
+          accumUpdates: Array[(Long, Seq[AccumulableInfo])],
           blockManagerId: BlockManagerId): Boolean = true
       override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
       override def applicationAttemptId(): Option[String] = None
@@ -499,8 +512,8 @@ class DAGSchedulerSuite
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
       HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
     complete(taskSets(1), Seq((Success, 42)))
@@ -515,12 +528,12 @@ class DAGSchedulerSuite
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
-        (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
     // the 2nd ResultTask failed
     complete(taskSets(1), Seq(
-        (Success, 42),
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
+      (Success, 42),
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
     // this will get called
     // blockManagerMaster.removeExecutor("exec-hostA")
     // ask the scheduler to try it again
@@ -829,23 +842,17 @@ class DAGSchedulerSuite
       HashSet("hostA", "hostB"))
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.failedStages.contains(1))
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
     // The SparkListener should not receive redundant failure events.
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
@@ -882,12 +889,9 @@ class DAGSchedulerSuite
       HashSet("hostA", "hostB"))
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.failedStages.contains(1))
@@ -900,12 +904,9 @@ class DAGSchedulerSuite
     assert(countSubmittedMapStageAttempts() === 2)
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(1),
       FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
 
     // Another ResubmitFailedStages event should not result in another attempt for the map
@@ -920,11 +921,11 @@ class DAGSchedulerSuite
   }
 
   /**
-    * This tests the case where a late FetchFailed comes in after the map stage has finished getting
-    * retried and a new reduce stage starts running.
-    */
+   * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+   * retried and a new reduce stage starts running.
+   */
   test("extremely late fetch failures don't cause multiple concurrent attempts for " +
-      "the same stage") {
+    "the same stage") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
@@ -952,12 +953,9 @@ class DAGSchedulerSuite
     assert(countSubmittedReduceStageAttempts() === 1)
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
 
     // Trigger resubmission of the failed map stage and finish the re-started map task.
@@ -971,12 +969,9 @@ class DAGSchedulerSuite
     assert(countSubmittedReduceStageAttempts() === 2)
 
     // A late FetchFailed arrives from the second task in the original reduce stage.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(1),
       FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
 
     // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
@@ -1007,48 +1002,36 @@ class DAGSchedulerSuite
     assert(shuffleStage.numAvailableOutputs === 0)
 
     // should be ignored for being too old
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(0),
       Success,
-      makeMapStatus("hostA", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostA", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 0)
 
     // should work because it's a non-failed host (so the available map outputs will increase)
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(0),
       Success,
-      makeMapStatus("hostB", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostB", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 1)
 
     // should be ignored for being too old
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(0),
       Success,
-      makeMapStatus("hostA", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostA", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 1)
 
     // should work because it's a new epoch, which will increase the number of available map
     // outputs, and also finish the stage
     taskSet.tasks(1).epoch = newEpoch
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(1),
       Success,
-      makeMapStatus("hostA", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostA", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 2)
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
-           HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+      HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
 
     // finish the next stage normally, which completes the job
     complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -1140,12 +1123,9 @@ class DAGSchedulerSuite
 
     // then one executor dies, and a task fails in stage 1
     runEvent(ExecutorLost("exec-hostA"))
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"),
-      null,
-      null,
-      createFakeTaskInfo(),
       null))
 
     // so we resubmit stage 0, which completes happily
@@ -1155,13 +1135,10 @@ class DAGSchedulerSuite
     assert(stage0Resubmit.stageAttemptId === 1)
     val task = stage0Resubmit.tasks(0)
     assert(task.partitionId === 2)
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       task,
       Success,
-      makeMapStatus("hostC", shuffleMapRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostC", shuffleMapRdd.partitions.length)))
 
     // now here is where things get tricky : we will now have a task set representing
     // the second attempt for stage 1, but we *also* have some tasks for the first attempt for
@@ -1174,28 +1151,19 @@ class DAGSchedulerSuite
     // we'll have some tasks finish from the first attempt, and some finish from the second attempt,
     // so that we actually have all stage outputs, though no attempt has completed all its
     // tasks
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(3).tasks(0),
       Success,
-      makeMapStatus("hostC", reduceRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
-    runEvent(CompletionEvent(
+      makeMapStatus("hostC", reduceRdd.partitions.length)))
+    runEvent(makeCompletionEvent(
       taskSets(3).tasks(1),
       Success,
-      makeMapStatus("hostC", reduceRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostC", reduceRdd.partitions.length)))
     // late task finish from the first attempt
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(2),
       Success,
-      makeMapStatus("hostB", reduceRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostB", reduceRdd.partitions.length)))
 
     // What should happen now is that we submit stage 2.  However, we might not see an error
     // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them).  But
@@ -1242,21 +1210,21 @@ class DAGSchedulerSuite
     submit(reduceRdd, Array(0))
 
     // complete some of the tasks from the first stage, on one host
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(0), Success,
-      makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null))
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(1), Success,
-      makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0),
+      Success,
+      makeMapStatus("hostA", reduceRdd.partitions.length)))
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1),
+      Success,
+      makeMapStatus("hostA", reduceRdd.partitions.length)))
 
     // now that host goes down
     runEvent(ExecutorLost("exec-hostA"))
 
     // so we resubmit those tasks
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null))
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), Resubmitted, null))
 
     // now complete everything on a different host
     complete(taskSets(0), Seq(
@@ -1449,12 +1417,12 @@ class DAGSchedulerSuite
     // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
     // rather than marking it as failed and waiting.
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-       (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     // have hostC complete the resubmitted task
     complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
-           HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+      HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
     complete(taskSets(2), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
     assertDataStructuresEmpty()
@@ -1469,15 +1437,15 @@ class DAGSchedulerSuite
     submit(finalRdd, Array(0))
     // have the first stage complete normally
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))))
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
     // have the second stage complete normally
     complete(taskSets(1), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostC", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostC", 1))))
     // fail the third stage because hostA went down
     complete(taskSets(2), Seq(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
     // TODO assert this:
     // blockManagerMaster.removeExecutor("exec-hostA")
     // have DAGScheduler try again
@@ -1500,15 +1468,15 @@ class DAGSchedulerSuite
     cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
     // complete stage 0
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))))
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
     // complete stage 1
     complete(taskSets(1), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     // pretend stage 2 failed because hostA went down
     complete(taskSets(2), Seq(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
     // TODO assert this:
     // blockManagerMaster.removeExecutor("exec-hostA")
     // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
@@ -1606,6 +1574,28 @@ class DAGSchedulerSuite
     assertDataStructuresEmpty()
   }
 
+  test("accumulators are updated on exception failures") {
+    val acc1 = sc.accumulator(0L, "ingenieur")
+    val acc2 = sc.accumulator(0L, "boulanger")
+    val acc3 = sc.accumulator(0L, "agriculteur")
+    assert(Accumulators.get(acc1.id).isDefined)
+    assert(Accumulators.get(acc2.id).isDefined)
+    assert(Accumulators.get(acc3.id).isDefined)
+    val accInfo1 = new AccumulableInfo(
+      acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false)
+    val accInfo2 = new AccumulableInfo(
+      acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false)
+    val accInfo3 = new AccumulableInfo(
+      acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false)
+    val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
+    val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
+    submit(new MyRDD(sc, 1, Nil), Array(0))
+    runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
+    assert(Accumulators.get(acc1.id).get.value === 15L)
+    assert(Accumulators.get(acc2.id).get.value === 13L)
+    assert(Accumulators.get(acc3.id).get.value === 18L)
+  }
+
   test("reduce tasks should be placed locally with map output") {
     // Create a shuffleMapRdd with 1 partition
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
@@ -1614,9 +1604,9 @@ class DAGSchedulerSuite
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 1))))
+      (Success, makeMapStatus("hostA", 1))))
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
-           HashSet(makeBlockManagerId("hostA")))
+      HashSet(makeBlockManagerId("hostA")))
 
     // Reducer should run on the same host where the map task ran
     val reduceTaskSet = taskSets(1)
@@ -1884,8 +1874,7 @@ class DAGSchedulerSuite
     submitMapStage(shuffleDep)
 
     val oldTaskSet = taskSets(0)
-    runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
 
     // Pretend host A was lost
@@ -1895,23 +1884,19 @@ class DAGSchedulerSuite
     assert(newEpoch > oldEpoch)
 
     // Suppose we also get a completed event from task 1 on the same host; this should be ignored
-    runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
 
     // A completion from another task should work because it's a non-failed host
-    runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
 
     // Now complete tasks in the second task set
     val newTaskSet = taskSets(1)
     assert(newTaskSet.tasks.size === 2)     // Both tasks 0 and 1 were on hostA
-    runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
-    runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2)))
     assert(results.size === 1)    // Map stage job should now finally be complete
     assertDataStructuresEmpty()
 
@@ -1962,5 +1947,21 @@ class DAGSchedulerSuite
     info
   }
 
-}
+  private def makeCompletionEvent(
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
+      taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
+    val accumUpdates = reason match {
+      case Success =>
+        task.initialAccumulators.map { a =>
+          new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues)
+        }
+      case ef: ExceptionFailure => ef.accumUpdates
+      case _ => Seq.empty[AccumulableInfo]
+    }
+    CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
+  }
 
+}
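
The makeCompletionEvent helper added at the bottom of this file is what lets every former six-argument CompletionEvent(...) call above shrink to three arguments: on Success it derives the accumulator updates from the task's initialAccumulators, on ExceptionFailure it forwards the failure's own accumUpdates, and it supplies a fake TaskInfo by default. A minimal sketch of a call and its rough expansion, reusing only names that appear in this diff:

  // Sketch only -- what a typical test call now looks like ...
  runEvent(makeCompletionEvent(
    taskSets(0).tasks(0),
    Success,
    makeMapStatus("hostA", 2)))
  // ... and roughly what the helper builds for the Success case:
  //   CompletionEvent(task, Success, makeMapStatus("hostA", 2),
  //     task.initialAccumulators.map(a => new AccumulableInfo(
  //       a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues)),
  //     createFakeTaskInfo())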

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 761e82e..35215c1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.io.CompressionCodec
-import org.apache.spark.util.{JsonProtocol, Utils}
+import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils}
 
 /**
  * Test whether ReplayListenerBus replays events from logs correctly.
@@ -131,7 +131,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter {
     assert(sc.eventLogger.isDefined)
     val originalEvents = sc.eventLogger.get.loggedEvents
     val replayedEvents = eventMonster.loggedEvents
-    originalEvents.zip(replayedEvents).foreach { case (e1, e2) => assert(e1 === e2) }
+    originalEvents.zip(replayedEvents).foreach { case (e1, e2) =>
+      // Don't compare the JSON here because accumulators in StageInfo may be out of order
+      JsonProtocolSuite.assertEquals(
+        JsonProtocol.sparkEventFromJson(e1), JsonProtocol.sparkEventFromJson(e2))
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index e5ec44a..b3bb86d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -22,6 +22,8 @@ import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
+import org.apache.spark.executor.TaskMetricsSuite
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
@@ -57,8 +59,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
     val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
-    val task = new ResultTask[String, String](
-      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
+    val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
     intercept[RuntimeException] {
       task.run(0, 0, null)
     }
@@ -97,6 +98,57 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }.collect()
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }
+
+  test("accumulators are updated on exception failures") {
+    // This means use 1 core and 4 max task failures
+    sc = new SparkContext("local[1,4]", "test")
+    val param = AccumulatorParam.LongAccumulatorParam
+    // Create 2 accumulators, one that counts failed values and another that doesn't
+    val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true)
+    val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false)
+    // Fail first 3 attempts of every task. This means each task should be run 4 times.
+    sc.parallelize(1 to 10, 10).map { i =>
+      acc1 += 1
+      acc2 += 1
+      if (TaskContext.get.attemptNumber() <= 2) {
+        throw new Exception("you did something wrong")
+      } else {
+        0
+      }
+    }.count()
+    // The one that counts failed values should be 4x the one that didn't,
+    // since we ran each task 4 times
+    assert(Accumulators.get(acc1.id).get.value === 40L)
+    assert(Accumulators.get(acc2.id).get.value === 10L)
+  }
+
+  test("failed tasks collect only accumulators whose values count during failures") {
+    sc = new SparkContext("local", "test")
+    val param = AccumulatorParam.LongAccumulatorParam
+    val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true)
+    val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false)
+    val initialAccums = InternalAccumulator.create()
+    // Create a dummy task. We won't end up running this; we just want to collect
+    // accumulator updates from it.
+    val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) {
+      context = new TaskContextImpl(0, 0, 0L, 0,
+        new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
+        SparkEnv.get.metricsSystem,
+        initialAccums)
+      context.taskMetrics.registerAccumulator(acc1)
+      context.taskMetrics.registerAccumulator(acc2)
+      override def runTask(tc: TaskContext): Int = 0
+    }
+    // First, simulate task success. This should give us all the accumulators.
+    val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false)
+    val accumUpdates2 = (initialAccums ++ Seq(acc1, acc2)).map(TaskMetricsSuite.makeInfo)
+    TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2)
+    // Now, simulate task failures. This should give us only the accums that count failed values.
+    val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true)
+    val accumUpdates4 = (initialAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo)
+    TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
+  }
+
 }
 
 private object TaskContextSuite {

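Both new TaskContextSuite tests above turn on the countFailedValues flag introduced by this migration: an accumulator created with countFailedValues = true keeps contributions from failed task attempts, while one created with countFailedValues = false is only applied for successful attempts. A hedged sketch of the arithmetic behind the 40L / 10L assertions (the constructor signature is copied from this diff; the attempt counts come from the test's own local[1,4] setup):

  // Sketch only: with the first three attempts of every task throwing,
  // each of the 10 tasks runs 4 times before succeeding.
  val param = AccumulatorParam.LongAccumulatorParam
  val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true)
  val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false)
  // acc1 is applied for every attempt:         10 tasks * 4 attempts = 40
  // acc2 only for the successful attempts:     10 tasks * 1 attempt  = 10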
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index cc2557c..b5385c1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -21,10 +21,15 @@ import java.io.File
 import java.net.URL
 import java.nio.ByteBuffer
 
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.util.control.NonFatal
 
+import com.google.common.util.concurrent.MoreExecutors
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, anyLong}
+import org.mockito.Mockito.{spy, times, verify}
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually._
 
@@ -33,13 +38,14 @@ import org.apache.spark.storage.TaskResultBlockId
 import org.apache.spark.TestUtils.JavaSourceFromString
 import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}
 
+
 /**
  * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
  *
  * Used to test the case where a BlockManager evicts the task result (or dies) before the
  * TaskResult is retrieved.
  */
-class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
+private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
   extends TaskResultGetter(sparkEnv, scheduler) {
   var removedResult = false
 
@@ -72,6 +78,31 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
   }
 }
 
+
+/**
+ * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
+ * _before_ modifying the results in any way.
+ */
+private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl)
+  extends TaskResultGetter(env, scheduler) {
+
+  // Use the current thread so we can access its results synchronously
+  protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor()
+
+  // DirectTaskResults that we receive from the executors
+  private val _taskResults = new ArrayBuffer[DirectTaskResult[_]]
+
+  def taskResults: Seq[DirectTaskResult[_]] = _taskResults
+
+  override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = {
+    // work on a copy since the super class still needs to use the buffer
+    val newBuffer = data.duplicate()
+    _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer)
+    super.enqueueSuccessfulTask(tsm, tid, data)
+  }
+}
+
+
 /**
  * Tests related to handling task results (both direct and indirect).
  */
@@ -182,5 +213,39 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
       Thread.currentThread.setContextClassLoader(originalClassLoader)
     }
   }
+
+  test("task result size is set on the driver, not the executors") {
+    import InternalAccumulator._
+
+    // Set up custom TaskResultGetter and TaskSchedulerImpl spy
+    sc = new SparkContext("local", "test", conf)
+    val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+    val spyScheduler = spy(scheduler)
+    val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler)
+    val newDAGScheduler = new DAGScheduler(sc, spyScheduler)
+    scheduler.taskResultGetter = resultGetter
+    sc.dagScheduler = newDAGScheduler
+    sc.taskScheduler = spyScheduler
+    sc.taskScheduler.setDAGScheduler(newDAGScheduler)
+
+    // Just run 1 task and capture the corresponding DirectTaskResult
+    sc.parallelize(1 to 1, 1).count()
+    val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]])
+    verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture())
+
+    // When a task finishes, the executor sends a serialized DirectTaskResult to the driver
+    // without setting the result size so as to avoid serializing the result again. Instead,
+    // the result size is set later in TaskResultGetter on the driver before passing the
+    // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult
+    // before and after the result size is set.
+    assert(resultGetter.taskResults.size === 1)
+    val resBefore = resultGetter.taskResults.head
+    val resAfter = captor.getValue
+    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+    assert(resSizeBefore.exists(_ == 0L))
+    assert(resSizeAfter.exists(_.toString.toLong > 0L))
+  }
+
 }
 

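The comment inside the new "task result size is set on the driver" test above is the key to reading it: executors serialize a DirectTaskResult whose RESULT_SIZE accumulator update is still zero (so the result is not serialized a second time just to record its size), and TaskResultGetter fills the size in on the driver before handing the result to TaskSchedulerImpl. A small sketch of how the test pulls that update out of a captured result, using only names from this diff:

  // Sketch only: locate the RESULT_SIZE entry in a DirectTaskResult's accumulator updates.
  def resultSize(res: DirectTaskResult[_]): Option[Any] =
    res.accumUpdates.find(_.name == Some(InternalAccumulator.RESULT_SIZE)).flatMap(_.update)
  // resultGetter.taskResults.head  -> Some(0L)       (executor-side copy, size not yet set)
  // captor.getValue                -> a positive size, set by the driver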
http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecc18fc..a2e7436 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -24,7 +24,6 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.util.ManualClock
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
@@ -38,9 +37,8 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Map[Long, Any],
-      taskInfo: TaskInfo,
-      taskMetrics: TaskMetrics) {
+      accumUpdates: Seq[AccumulableInfo],
+      taskInfo: TaskInfo) {
     taskScheduler.endedTasks(taskInfo.index) = reason
   }
 
@@ -167,14 +165,17 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val taskSet = FakeTask.createTaskSet(1)
     val clock = new ManualClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a =>
+      new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
+    }
 
     // Offer a host with NO_PREF as the constraint,
     // we should get a nopref task immediately since that's all we have
-    var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
     assert(taskOption.isDefined)
 
     // Tell it the task has finished
-    manager.handleSuccessfulTask(0, createTaskResult(0))
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
     assert(sched.endedTasks(0) === Success)
     assert(sched.finishedManagers.contains(manager))
   }
@@ -184,10 +185,15 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = FakeTask.createTaskSet(3)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+    val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
+      task.initialAccumulators.map { a =>
+        new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
+      }
+    }
 
     // First three offers should all find tasks
     for (i <- 0 until 3) {
-      var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
+      val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
       assert(taskOption.isDefined)
       val task = taskOption.get
       assert(task.executorId === "exec1")
@@ -198,14 +204,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None)
 
     // Finish the first two tasks
-    manager.handleSuccessfulTask(0, createTaskResult(0))
-    manager.handleSuccessfulTask(1, createTaskResult(1))
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdatesByTask(0)))
+    manager.handleSuccessfulTask(1, createTaskResult(1, accumUpdatesByTask(1)))
     assert(sched.endedTasks(0) === Success)
     assert(sched.endedTasks(1) === Success)
     assert(!sched.finishedManagers.contains(manager))
 
     // Finish the last task
-    manager.handleSuccessfulTask(2, createTaskResult(2))
+    manager.handleSuccessfulTask(2, createTaskResult(2, accumUpdatesByTask(2)))
     assert(sched.endedTasks(2) === Success)
     assert(sched.finishedManagers.contains(manager))
   }
@@ -620,7 +626,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
     // multiple 1k result
     val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect()
-    assert(10 === r.size )
+    assert(10 === r.size)
 
     // single 10M result
     val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()}
@@ -761,7 +767,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     // Regression test for SPARK-2931
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc,
-        ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+      ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
     val taskSet = FakeTask.createTaskSet(3,
       Seq(TaskLocation("host1")),
       Seq(TaskLocation("host2")),
@@ -786,8 +792,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3"))
   }
 
-  def createTaskResult(id: Int): DirectTaskResult[Int] = {
+  private def createTaskResult(
+      id: Int,
+      accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
-    new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
+    new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 86699e7..b83ffa3 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.ui.scope.RDDOperationGraphListener
 
 class StagePageSuite extends SparkFunSuite with LocalSparkContext {
 
+  private val peakExecutionMemory = 10
+
   test("peak execution memory only displayed if unsafe is enabled") {
     val unsafeConf = "spark.sql.unsafe.enabled"
     val conf = new SparkConf(false).set(unsafeConf, "true")
@@ -52,7 +54,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
     val conf = new SparkConf(false).set(unsafeConf, "true")
     val html = renderStagePage(conf).toString().toLowerCase
     // verify min/25/50/75/max show the task value, not cumulative values
-    assert(html.contains("<td>10.0 b</td>" * 5))
+    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
   }
 
   /**
@@ -79,14 +81,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
     (1 to 2).foreach {
       taskId =>
         val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
-        val peakExecutionMemory = 10
-        taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
-          Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
         jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
         jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
         taskInfo.markSuccessful()
+        val taskMetrics = TaskMetrics.empty
+        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
         jobListener.onTaskEnd(
-          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
+          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
     }
     jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
     page.render(request)

http://git-wip-us.apache.org/repos/asf/spark/blob/87abcf7d/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 607617c..18a16a2 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val taskFailedReasons = Seq(
       Resubmitted,
       new FetchFailed(null, 0, 0, 0, "ignored"),
-      ExceptionFailure("Exception", "description", null, null, None, None),
+      ExceptionFailure("Exception", "description", null, null, None),
       TaskResultLost,
       TaskKilled,
       ExecutorLostFailure("0", true, Some("Induced failure")),
@@ -269,20 +269,22 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = new TaskMetrics()
-      taskMetrics.setExecutorRunTime(base + 4)
-      taskMetrics.incDiskBytesSpilled(base + 5)
-      taskMetrics.incMemoryBytesSpilled(base + 6)
+      val accums = InternalAccumulator.create()
+      accums.foreach(Accumulators.register)
+      val taskMetrics = new TaskMetrics(accums)
       val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics()
+      val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics()
+      val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop)
+      val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop)
       shuffleReadMetrics.incRemoteBytesRead(base + 1)
       shuffleReadMetrics.incLocalBytesRead(base + 9)
       shuffleReadMetrics.incRemoteBlocksFetched(base + 2)
       taskMetrics.mergeShuffleReadMetrics()
-      val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics()
       shuffleWriteMetrics.incBytesWritten(base + 3)
-      val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop)
-      inputMetrics.incBytesRead(base + 7)
-      val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop)
+      taskMetrics.setExecutorRunTime(base + 4)
+      taskMetrics.incDiskBytesSpilled(base + 5)
+      taskMetrics.incMemoryBytesSpilled(base + 6)
+      inputMetrics.setBytesRead(base + 7)
       outputMetrics.setBytesWritten(base + 8)
       taskMetrics
     }
@@ -300,9 +302,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
-      (1234L, 0, 0, makeTaskMetrics(0)),
-      (1235L, 0, 0, makeTaskMetrics(100)),
-      (1236L, 1, 0, makeTaskMetrics(200)))))
+      (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()),
+      (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()),
+      (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates()))))
 
     var stage0Data = listener.stageIdToData.get((0, 0)).get
     var stage1Data = listener.stageIdToData.get((1, 0)).get

