Posted to commits@spark.apache.org by an...@apache.org on 2016/01/19 19:58:56 UTC

spark git commit: [SPARK-12887] Do not expose var's in TaskMetrics

Repository: spark
Updated Branches:
  refs/heads/master e14817b52 -> b122c861c


[SPARK-12887] Do not expose var's in TaskMetrics

This is a step in implementing SPARK-10620, which migrates TaskMetrics to accumulators.

TaskMetrics has a bunch of vars: some are fully public, others are `private[spark]`. This is bad coding style because it makes it easy to accidentally overwrite previously set metrics. That has happened a few times in the past and caused bugs that were difficult to debug.

Instead, we should have get-or-create semantics, which are more readily understandable. This makes sense in the case of TaskMetrics because these are just aggregated metrics that we want to collect throughout the task, so it doesn't matter who's incrementing them.
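
To make the intent concrete, here is a minimal sketch of the get-or-create pattern this patch adopts. It is illustrative only: the accessor mirrors the real change in the TaskMetrics.scala hunk below, while the ShuffleWriteMetrics stub is a simplified stand-in so the snippet compiles on its own.

    // Minimal stand-in for the real metrics class, included only so the sketch is self-contained.
    class ShuffleWriteMetrics {
      var bytesWritten: Long = 0L
      def incBytesWritten(v: Long): Unit = bytesWritten += v
    }

    // Get-or-create accessor: callers can no longer assign the field directly,
    // so previously set metrics cannot be silently overwritten.
    class TaskMetrics {
      private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None

      def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics

      def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized {
        _shuffleWriteMetrics.getOrElse {
          val metrics = new ShuffleWriteMetrics
          _shuffleWriteMetrics = Some(metrics)
          metrics
        }
      }
    }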

Parent PR: #10717

Author: Andrew Or <an...@databricks.com>
Author: Josh Rosen <jo...@databricks.com>
Author: andrewor14 <an...@databricks.com>

Closes #10815 from andrewor14/get-or-create-metrics.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b122c861
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b122c861
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b122c861

Branch: refs/heads/master
Commit: b122c861cd72b580334a7532f0a52c0439552bdf
Parents: e14817b
Author: Andrew Or <an...@databricks.com>
Authored: Tue Jan 19 10:58:51 2016 -0800
Committer: Andrew Or <an...@databricks.com>
Committed: Tue Jan 19 10:58:51 2016 -0800

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      |   3 +-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |   3 +-
 .../unsafe/sort/UnsafeExternalSorter.java       |   4 +-
 .../scala/org/apache/spark/CacheManager.scala   |   6 +-
 .../org/apache/spark/executor/Executor.scala    |   2 +-
 .../org/apache/spark/executor/TaskMetrics.scala | 180 ++++++++++++-------
 .../scala/org/apache/spark/rdd/HadoopRDD.scala  |   2 +-
 .../org/apache/spark/rdd/NewHadoopRDD.scala     |   3 +-
 .../org/apache/spark/rdd/PairRDDFunctions.scala |  46 +++--
 .../spark/shuffle/BlockStoreShuffleReader.scala |   4 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   3 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala  |   6 +-
 .../org/apache/spark/storage/BlockManager.scala |  12 +-
 .../storage/ShuffleBlockFetcherIterator.scala   |   2 +-
 .../spark/storage/StorageStatusListener.scala   |   2 +-
 .../apache/spark/ui/storage/StorageTab.scala    |   4 +-
 .../org/apache/spark/util/JsonProtocol.scala    | 165 ++++++++---------
 .../spark/util/collection/ExternalSorter.scala  |  10 +-
 .../org/apache/spark/CacheManagerSuite.scala    |   2 +-
 .../spark/executor/TaskMetricsSuite.scala       |   2 +-
 .../spark/storage/BlockManagerSuite.scala       |   2 +-
 .../storage/StorageStatusListenerSuite.scala    |  12 +-
 .../ui/jobs/JobProgressListenerSuite.scala      |  23 ++-
 .../spark/ui/storage/StorageTabSuite.scala      |   8 +-
 .../apache/spark/util/JsonProtocolSuite.scala   |  17 +-
 .../execution/datasources/SqlNewHadoopRDD.scala |   3 +-
 .../execution/UnsafeRowSerializerSuite.scala    |   1 -
 27 files changed, 281 insertions(+), 246 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index a06dc1c..dc4f289 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -114,8 +114,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
-    this.writeMetrics = new ShuffleWriteMetrics();
-    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
     this.serializer = Serializer.getSerializer(dep.serializer());
     this.shuffleBlockResolver = shuffleBlockResolver;
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index c8cc705..d3d79a2 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -119,8 +119,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
     this.partitioner = dep.partitioner();
-    this.writeMetrics = new ShuffleWriteMetrics();
-    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 68dc0c6..a6edc1a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -122,9 +122,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
     // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.fileBufferSizeBytes = 32 * 1024;
-    // TODO: metrics tracking + integration with shuffle write metrics
-    // need to connect the write metrics to task metrics so we count the spill IO somewhere.
-    this.writeMetrics = new ShuffleWriteMetrics();
+    this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
       this.inMemSorter = new UnsafeInMemorySorter(

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/CacheManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index d92d8b0..fa8e2b9 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -43,8 +43,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
     blockManager.get(key) match {
       case Some(blockResult) =>
         // Partition is already materialized, so just return its values
-        val existingMetrics = context.taskMetrics
-          .getInputMetricsForReadMethod(blockResult.readMethod)
+        val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
         existingMetrics.incBytesRead(blockResult.bytes)
 
         val iter = blockResult.data.asInstanceOf[Iterator[T]]
@@ -66,11 +65,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         try {
           logInfo(s"Partition $key not found, computing it")
           val computedValues = rdd.computeOrReadCheckpoint(partition, context)
-
-          // Otherwise, cache the values
           val cachedValues = putInBlockManager(key, computedValues, storageLevel)
           new InterruptibleIterator(context, cachedValues)
-
         } finally {
           loading.synchronized {
             loading.remove(key)

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9b14184..75d7e34 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -425,7 +425,7 @@ private[spark] class Executor(
     for (taskRunner <- runningTasks.values().asScala) {
       if (taskRunner.task != null) {
         taskRunner.task.metrics.foreach { metrics =>
-          metrics.updateShuffleReadMetrics()
+          metrics.mergeShuffleReadMetrics()
           metrics.updateInputMetrics()
           metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
           metrics.updateAccumulators()

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index ce1fcbf..32ef5a9 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -102,15 +102,38 @@ class TaskMetrics extends Serializable {
   private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
   private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value
 
-  /**
-   * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read
-   * are stored here.
-   */
   private var _inputMetrics: Option[InputMetrics] = None
 
+  /**
+   * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
+   * data, defined only in tasks with input.
+   */
   def inputMetrics: Option[InputMetrics] = _inputMetrics
 
   /**
+   * Get or create a new [[InputMetrics]] associated with this task.
+   */
+  private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = {
+    synchronized {
+      val metrics = _inputMetrics.getOrElse {
+        val metrics = new InputMetrics(readMethod)
+        _inputMetrics = Some(metrics)
+        metrics
+      }
+      // If there already exists an InputMetric with the same read method, we can just return
+      // that one. Otherwise, if the read method is different from the one previously seen by
+      // this task, we return a new dummy one to avoid clobbering the values of the old metrics.
+      // In the future we should try to store input metrics from all different read methods at
+      // the same time (SPARK-5225).
+      if (metrics.readMethod == readMethod) {
+        metrics
+      } else {
+        new InputMetrics(readMethod)
+      }
+    }
+  }
+
+  /**
    * This should only be used when recreating TaskMetrics, not when updating input metrics in
    * executors
    */
@@ -118,18 +141,37 @@ class TaskMetrics extends Serializable {
     _inputMetrics = inputMetrics
   }
 
+  private var _outputMetrics: Option[OutputMetrics] = None
+
   /**
-   * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
-   * data was written are stored here.
+   * Metrics related to writing data externally (e.g. to a distributed filesystem),
+   * defined only in tasks with output.
    */
-  var outputMetrics: Option[OutputMetrics] = None
+  def outputMetrics: Option[OutputMetrics] = _outputMetrics
+
+  @deprecated("setting OutputMetrics is for internal use only", "2.0.0")
+  def outputMetrics_=(om: Option[OutputMetrics]): Unit = {
+    _outputMetrics = om
+  }
 
   /**
-   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
-   * This includes read metrics aggregated over all the task's shuffle dependencies.
+   * Get or create a new [[OutputMetrics]] associated with this task.
    */
+  private[spark] def registerOutputMetrics(
+      writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized {
+    _outputMetrics.getOrElse {
+      val metrics = new OutputMetrics(writeMethod)
+      _outputMetrics = Some(metrics)
+      metrics
+    }
+  }
+
   private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
 
+  /**
+   * Metrics related to shuffle read aggregated across all shuffle dependencies.
+   * This is defined only if there are shuffle dependencies in this task.
+   */
   def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics
 
   /**
@@ -141,66 +183,35 @@ class TaskMetrics extends Serializable {
   }
 
   /**
-   * ShuffleReadMetrics per dependency for collecting independently while task is in progress.
-   */
-  @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
-    new ArrayBuffer[ShuffleReadMetrics]()
-
-  /**
-   * If this task writes to shuffle output, metrics on the written shuffle data will be collected
-   * here
-   */
-  var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
-
-  /**
-   * Storage statuses of any blocks that have been updated as a result of this task.
-   */
-  var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
-
-  /**
+   * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency.
+   *
    * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
-   * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
-   * dependency, and merge these metrics before reporting them to the driver. This method returns
-   * a ShuffleReadMetrics for a dependency and registers it for merging later.
-   */
-  private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
-    val readMetrics = new ShuffleReadMetrics()
-    depsShuffleReadMetrics += readMetrics
-    readMetrics
-  }
+   * issues from readers in different threads, in-progress tasks use a [[ShuffleReadMetrics]] for
+   * each dependency and merge these metrics before reporting them to the driver.
+  */
+  @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[ShuffleReadMetrics]
 
   /**
-   * Returns the input metrics object that the task should use. Currently, if
-   * there exists an input metric with the same readMethod, we return that one
-   * so the caller can accumulate bytes read. If the readMethod is different
-   * than previously seen by this task, we return a new InputMetric but don't
-   * record it.
+   * Create a temporary [[ShuffleReadMetrics]] for a particular shuffle dependency.
    *
-   * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
-   * we can store all the different inputMetrics (one per readMethod).
+   * All usages are expected to be followed by a call to [[mergeShuffleReadMetrics]], which
+   * merges the temporary values synchronously. Otherwise, all temporary data collected will
+   * be lost.
    */
-  private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = {
-    synchronized {
-      _inputMetrics match {
-        case None =>
-          val metrics = new InputMetrics(readMethod)
-          _inputMetrics = Some(metrics)
-          metrics
-        case Some(metrics @ InputMetrics(method)) if method == readMethod =>
-          metrics
-        case Some(InputMetrics(method)) =>
-          new InputMetrics(readMethod)
-      }
-    }
+  private[spark] def registerTempShuffleReadMetrics(): ShuffleReadMetrics = synchronized {
+    val readMetrics = new ShuffleReadMetrics
+    tempShuffleReadMetrics += readMetrics
+    readMetrics
   }
 
   /**
-   * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
+   * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`.
+   * This is expected to be called on executor heartbeat and at the end of a task.
    */
-  private[spark] def updateShuffleReadMetrics(): Unit = synchronized {
-    if (!depsShuffleReadMetrics.isEmpty) {
-      val merged = new ShuffleReadMetrics()
-      for (depMetrics <- depsShuffleReadMetrics) {
+  private[spark] def mergeShuffleReadMetrics(): Unit = synchronized {
+    if (tempShuffleReadMetrics.nonEmpty) {
+      val merged = new ShuffleReadMetrics
+      for (depMetrics <- tempShuffleReadMetrics) {
         merged.incFetchWaitTime(depMetrics.fetchWaitTime)
         merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
         merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
@@ -212,6 +223,55 @@ class TaskMetrics extends Serializable {
     }
   }
 
+  private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+
+  /**
+   * Metrics related to shuffle write, defined only in shuffle map stages.
+   */
+  def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics
+
+  @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0")
+  def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = {
+    _shuffleWriteMetrics = swm
+  }
+
+  /**
+   * Get or create a new [[ShuffleWriteMetrics]] associated with this task.
+   */
+  private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized {
+    _shuffleWriteMetrics.getOrElse {
+      val metrics = new ShuffleWriteMetrics
+      _shuffleWriteMetrics = Some(metrics)
+      metrics
+    }
+  }
+
+  private var _updatedBlockStatuses: Seq[(BlockId, BlockStatus)] =
+    Seq.empty[(BlockId, BlockStatus)]
+
+  /**
+   * Storage statuses of any blocks that have been updated as a result of this task.
+   */
+  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses
+
+  @deprecated("setting updated blocks is for internal use only", "2.0.0")
+  def updatedBlocks_=(ub: Option[Seq[(BlockId, BlockStatus)]]): Unit = {
+    _updatedBlockStatuses = ub.getOrElse(Seq.empty[(BlockId, BlockStatus)])
+  }
+
+  private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = {
+    _updatedBlockStatuses ++= v
+  }
+
+  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = {
+    _updatedBlockStatuses = v
+  }
+
+  @deprecated("use updatedBlockStatuses instead", "2.0.0")
+  def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = {
+    if (_updatedBlockStatuses.nonEmpty) Some(_updatedBlockStatuses) else None
+  }
+
   private[spark] def updateInputMetrics(): Unit = synchronized {
     inputMetrics.foreach(_.updateBytesRead())
   }

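For context, here is a hedged sketch of how Spark-internal callers interact with the reworked API after this patch. The method names (registerShuffleWriteMetrics, registerTempShuffleReadMetrics, mergeShuffleReadMetrics, registerInputMetrics) come from the hunk above; the surrounding values and setup are illustrative only, and these accessors are private[spark], so this is internal usage rather than user-facing API.

    // Internal usage sketch: assumes a running task (so TaskContext.get() is non-null)
    // and code living inside an org.apache.spark package, since the accessors are private[spark].
    import org.apache.spark.TaskContext
    import org.apache.spark.executor.DataReadMethod

    val taskMetrics = TaskContext.get().taskMetrics()

    // Shuffle writers: get or create the single ShuffleWriteMetrics for this task.
    val writeMetrics = taskMetrics.registerShuffleWriteMetrics()
    writeMetrics.incBytesWritten(1024L)

    // Shuffle readers: one temporary ShuffleReadMetrics per dependency,
    // merged into the task-level value on executor heartbeat or at task end.
    val readMetrics = taskMetrics.registerTempShuffleReadMetrics()
    readMetrics.incRecordsRead(1L)
    taskMetrics.mergeShuffleReadMetrics()

    // Input metrics: keyed by read method, returned get-or-create style.
    val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop)
    inputMetrics.incBytesRead(2048L)
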
http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a7a6e0b..a79ab86 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -212,7 +212,7 @@ class HadoopRDD[K, V](
       logInfo("Input split: " + split.inputSplit)
       val jobConf = getJobConf()
 
-      val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+      val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
 
       // Sets the thread local variable for the file's name
       split.inputSplit.value match {

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7a11978..5cc9c81 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -129,8 +129,7 @@ class NewHadoopRDD[K, V](
       logInfo("Input split: " + split.serializableHadoopSplit)
       val conf = getConf
 
-      val inputMetrics = context.taskMetrics
-        .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+      val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
 
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 16a856f..33f2f0b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1092,7 +1092,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       val committer = format.getOutputCommitter(hadoopContext)
       committer.setupTask(hadoopContext)
 
-      val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
+      val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] =
+        initHadoopOutputMetrics(context)
 
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
       require(writer != null, "Unable to obtain RecordWriter")
@@ -1103,15 +1104,17 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
           writer.write(pair._1, pair._2)
 
           // Update bytes written metric every few records
-          maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+          maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
           recordsWritten += 1
         }
       } {
         writer.close(hadoopContext)
       }
       committer.commitTask(hadoopContext)
-      bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
-      outputMetrics.setRecordsWritten(recordsWritten)
+      outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
+        om.setBytesWritten(callback())
+        om.setRecordsWritten(recordsWritten)
+      }
       1
     } : Int
 
@@ -1177,7 +1180,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       // around by taking a mod. We expect that no task will be attempted 2 billion times.
       val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
 
-      val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
+      val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] =
+        initHadoopOutputMetrics(context)
 
       writer.setup(context.stageId, context.partitionId, taskAttemptId)
       writer.open()
@@ -1189,35 +1193,43 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
           writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
 
           // Update bytes written metric every few records
-          maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+          maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
           recordsWritten += 1
         }
       } {
         writer.close()
       }
       writer.commit()
-      bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
-      outputMetrics.setRecordsWritten(recordsWritten)
+      outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
+        om.setBytesWritten(callback())
+        om.setRecordsWritten(recordsWritten)
+      }
     }
 
     self.context.runJob(self, writeToFile)
     writer.commitJob()
   }
 
-  private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
+  // TODO: these don't seem like the right abstractions.
+  // We should abstract the duplicate code in a less awkward way.
+
+  // return type: (output metrics, bytes written callback), defined only if the latter is defined
+  private def initHadoopOutputMetrics(
+      context: TaskContext): Option[(OutputMetrics, () => Long)] = {
     val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
-    val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
-    if (bytesWrittenCallback.isDefined) {
-      context.taskMetrics.outputMetrics = Some(outputMetrics)
+    bytesWrittenCallback.map { b =>
+      (context.taskMetrics().registerOutputMetrics(DataWriteMethod.Hadoop), b)
     }
-    (outputMetrics, bytesWrittenCallback)
   }
 
-  private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
-      outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
+  private def maybeUpdateOutputMetrics(
+      outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)],
+      recordsWritten: Long): Unit = {
     if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
-      bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
-      outputMetrics.setRecordsWritten(recordsWritten)
+      outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
+        om.setBytesWritten(callback())
+        om.setRecordsWritten(recordsWritten)
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index b0abda4..a57e5b0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -65,13 +65,13 @@ private[spark] class BlockStoreShuffleReader[K, C](
     }
 
     // Update the context task metrics for each record read.
-    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+    val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics()
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
       recordIter.map(record => {
         readMetrics.incRecordsRead(1)
         record
       }),
-      context.taskMetrics().updateShuffleReadMetrics())
+      context.taskMetrics().mergeShuffleReadMetrics())
 
     // An interruptible iterator must be used here in order to support task cancellation
     val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 412bf70..28bcced 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -42,8 +42,7 @@ private[spark] class HashShuffleWriter[K, V](
   // we don't try deleting files, etc twice.
   private var stopping = false
 
-  private val writeMetrics = new ShuffleWriteMetrics()
-  metrics.shuffleWriteMetrics = Some(writeMetrics)
+  private val writeMetrics = metrics.registerShuffleWriteMetrics()
 
   private val blockManager = SparkEnv.get.blockManager
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5c5a5f5..7eb3d96 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -45,8 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var mapStatus: MapStatus = null
 
-  private val writeMetrics = new ShuffleWriteMetrics()
-  context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
+  private val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics()
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
@@ -93,8 +92,7 @@ private[spark] class SortShuffleWriter[K, V, C](
       if (sorter != null) {
         val startTime = System.nanoTime()
         sorter.stop()
-        context.taskMetrics.shuffleWriteMetrics.foreach(
-          _.incWriteTime(System.nanoTime - startTime))
+        writeMetrics.incWriteTime(System.nanoTime - startTime)
         sorter = null
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e0a8e88..77fd03a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -800,10 +800,8 @@ private[spark] class BlockManager(
           if (tellMaster) {
             reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
           }
-          Option(TaskContext.get()).foreach { taskContext =>
-            val metrics = taskContext.taskMetrics()
-            val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
-            metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, putBlockStatus)))
+          Option(TaskContext.get()).foreach { c =>
+            c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
           }
         }
       } finally {
@@ -1046,10 +1044,8 @@ private[spark] class BlockManager(
           blockInfo.remove(blockId)
         }
         if (blockIsUpdated) {
-          Option(TaskContext.get()).foreach { taskContext =>
-            val metrics = taskContext.taskMetrics()
-            val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
-            metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, status)))
+          Option(TaskContext.get()).foreach { c =>
+            c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status)))
           }
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 037bec1..c6065df 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -101,7 +101,7 @@ final class ShuffleBlockFetcherIterator(
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
+  private[this] val shuffleMetrics = context.taskMetrics().registerTempShuffleReadMetrics()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index ec71148..d98aae8 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -63,7 +63,7 @@ class StorageStatusListener extends SparkListener {
     val info = taskEnd.taskInfo
     val metrics = taskEnd.taskMetrics
     if (info != null && metrics != null) {
-      val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+      val updatedBlocks = metrics.updatedBlockStatuses
       if (updatedBlocks.length > 0) {
         updateStorageStatus(info.executorId, updatedBlocks)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 2d9b885..f1e28b4 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -63,8 +63,8 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
    */
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
     val metrics = taskEnd.taskMetrics
-    if (metrics != null && metrics.updatedBlocks.isDefined) {
-      updateRDDInfo(metrics.updatedBlocks.get)
+    if (metrics != null && metrics.updatedBlockStatuses.nonEmpty) {
+      updateRDDInfo(metrics.updatedBlockStatuses)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index b88221a..efa22b9 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -292,21 +292,38 @@ private[spark] object JsonProtocol {
   }
 
   def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {
-    val shuffleReadMetrics =
-      taskMetrics.shuffleReadMetrics.map(shuffleReadMetricsToJson).getOrElse(JNothing)
-    val shuffleWriteMetrics =
-      taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
-    val inputMetrics =
-      taskMetrics.inputMetrics.map(inputMetricsToJson).getOrElse(JNothing)
-    val outputMetrics =
-      taskMetrics.outputMetrics.map(outputMetricsToJson).getOrElse(JNothing)
-    val updatedBlocks =
-      taskMetrics.updatedBlocks.map { blocks =>
-        JArray(blocks.toList.map { case (id, status) =>
-          ("Block ID" -> id.toString) ~
-          ("Status" -> blockStatusToJson(status))
-        })
+    val shuffleReadMetrics: JValue =
+      taskMetrics.shuffleReadMetrics.map { rm =>
+        ("Remote Blocks Fetched" -> rm.remoteBlocksFetched) ~
+        ("Local Blocks Fetched" -> rm.localBlocksFetched) ~
+        ("Fetch Wait Time" -> rm.fetchWaitTime) ~
+        ("Remote Bytes Read" -> rm.remoteBytesRead) ~
+        ("Local Bytes Read" -> rm.localBytesRead) ~
+        ("Total Records Read" -> rm.recordsRead)
+      }.getOrElse(JNothing)
+    val shuffleWriteMetrics: JValue =
+      taskMetrics.shuffleWriteMetrics.map { wm =>
+        ("Shuffle Bytes Written" -> wm.shuffleBytesWritten) ~
+        ("Shuffle Write Time" -> wm.shuffleWriteTime) ~
+        ("Shuffle Records Written" -> wm.shuffleRecordsWritten)
+      }.getOrElse(JNothing)
+    val inputMetrics: JValue =
+      taskMetrics.inputMetrics.map { im =>
+        ("Data Read Method" -> im.readMethod.toString) ~
+        ("Bytes Read" -> im.bytesRead) ~
+        ("Records Read" -> im.recordsRead)
+      }.getOrElse(JNothing)
+    val outputMetrics: JValue =
+      taskMetrics.outputMetrics.map { om =>
+        ("Data Write Method" -> om.writeMethod.toString) ~
+        ("Bytes Written" -> om.bytesWritten) ~
+        ("Records Written" -> om.recordsWritten)
       }.getOrElse(JNothing)
+    val updatedBlocks =
+      JArray(taskMetrics.updatedBlockStatuses.toList.map { case (id, status) =>
+        ("Block ID" -> id.toString) ~
+        ("Status" -> blockStatusToJson(status))
+      })
     ("Host Name" -> taskMetrics.hostname) ~
     ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~
     ("Executor Run Time" -> taskMetrics.executorRunTime) ~
@@ -322,34 +339,6 @@ private[spark] object JsonProtocol {
     ("Updated Blocks" -> updatedBlocks)
   }
 
-  def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = {
-    ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~
-    ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~
-    ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~
-    ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~
-    ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~
-    ("Total Records Read" -> shuffleReadMetrics.recordsRead)
-  }
-
-  // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes.
-  def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = {
-    ("Shuffle Bytes Written" -> shuffleWriteMetrics.bytesWritten) ~
-    ("Shuffle Write Time" -> shuffleWriteMetrics.writeTime) ~
-    ("Shuffle Records Written" -> shuffleWriteMetrics.recordsWritten)
-  }
-
-  def inputMetricsToJson(inputMetrics: InputMetrics): JValue = {
-    ("Data Read Method" -> inputMetrics.readMethod.toString) ~
-    ("Bytes Read" -> inputMetrics.bytesRead) ~
-    ("Records Read" -> inputMetrics.recordsRead)
-  }
-
-  def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = {
-    ("Data Write Method" -> outputMetrics.writeMethod.toString) ~
-    ("Bytes Written" -> outputMetrics.bytesWritten) ~
-    ("Records Written" -> outputMetrics.recordsWritten)
-  }
-
   def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = {
     val reason = Utils.getFormattedClassName(taskEndReason)
     val json: JObject = taskEndReason match {
@@ -721,58 +710,54 @@ private[spark] object JsonProtocol {
     metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long])
     metrics.incMemoryBytesSpilled((json \ "Memory Bytes Spilled").extract[Long])
     metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long])
-    metrics.setShuffleReadMetrics(
-      Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
-    metrics.shuffleWriteMetrics =
-      Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
-    metrics.setInputMetrics(
-      Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson))
-    metrics.outputMetrics =
-      Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson)
-    metrics.updatedBlocks =
-      Utils.jsonOption(json \ "Updated Blocks").map { value =>
-        value.extract[List[JValue]].map { block =>
-          val id = BlockId((block \ "Block ID").extract[String])
-          val status = blockStatusFromJson(block \ "Status")
-          (id, status)
-        }
-      }
-    metrics
-  }
 
-  def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = {
-    val metrics = new ShuffleReadMetrics
-    metrics.incRemoteBlocksFetched((json \ "Remote Blocks Fetched").extract[Int])
-    metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int])
-    metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long])
-    metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long])
-    metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0))
-    metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0))
-    metrics
-  }
+    // Shuffle read metrics
+    Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson =>
+      val readMetrics = metrics.registerTempShuffleReadMetrics()
+      readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int])
+      readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int])
+      readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long])
+      readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L))
+      readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long])
+      readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L))
+      metrics.mergeShuffleReadMetrics()
+    }
 
-  def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = {
-    val metrics = new ShuffleWriteMetrics
-    metrics.incBytesWritten((json \ "Shuffle Bytes Written").extract[Long])
-    metrics.incWriteTime((json \ "Shuffle Write Time").extract[Long])
-    metrics.setRecordsWritten((json \ "Shuffle Records Written")
-      .extractOpt[Long].getOrElse(0))
-    metrics
-  }
+    // Shuffle write metrics
+    // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes.
+    Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson =>
+      val writeMetrics = metrics.registerShuffleWriteMetrics()
+      writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long])
+      writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written")
+        .extractOpt[Long].getOrElse(0L))
+      writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long])
+    }
 
-  def inputMetricsFromJson(json: JValue): InputMetrics = {
-    val metrics = new InputMetrics(
-      DataReadMethod.withName((json \ "Data Read Method").extract[String]))
-    metrics.incBytesRead((json \ "Bytes Read").extract[Long])
-    metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0))
-    metrics
-  }
+    // Output metrics
+    Utils.jsonOption(json \ "Output Metrics").foreach { outJson =>
+      val writeMethod = DataWriteMethod.withName((outJson \ "Data Write Method").extract[String])
+      val outputMetrics = metrics.registerOutputMetrics(writeMethod)
+      outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long])
+      outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L))
+    }
+
+    // Input metrics
+    Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
+      val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String])
+      val inputMetrics = metrics.registerInputMetrics(readMethod)
+      inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
+      inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L))
+    }
+
+    // Updated blocks
+    Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson =>
+      metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson =>
+        val id = BlockId((blockJson \ "Block ID").extract[String])
+        val status = blockStatusFromJson(blockJson \ "Status")
+        (id, status)
+      })
+    }
 
-  def outputMetricsFromJson(json: JValue): OutputMetrics = {
-    val metrics = new OutputMetrics(
-      DataWriteMethod.withName((json \ "Data Write Method").extract[String]))
-    metrics.setBytesWritten((json \ "Bytes Written").extract[Long])
-    metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0))
     metrics
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 4c7416e..df9e050 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -644,6 +644,8 @@ private[spark] class ExternalSorter[K, V, C](
       blockId: BlockId,
       outputFile: File): Array[Long] = {
 
+    val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics()
+
     // Track location of each range in the output file
     val lengths = new Array[Long](numPartitions)
 
@@ -652,8 +654,8 @@ private[spark] class ExternalSorter[K, V, C](
       val collection = if (aggregator.isDefined) map else buffer
       val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
       while (it.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
-          context.taskMetrics.shuffleWriteMetrics.get)
+        val writer = blockManager.getDiskWriter(
+          blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
         val partitionId = it.nextPartition()
         while (it.hasNext && it.nextPartition() == partitionId) {
           it.writeNext(writer)
@@ -666,8 +668,8 @@ private[spark] class ExternalSorter[K, V, C](
       // We must perform merge-sort; get an iterator by partition and write everything directly.
       for ((id, elements) <- this.partitionedIterator) {
         if (elements.hasNext) {
-          val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
-            context.taskMetrics.shuffleWriteMetrics.get)
+          val writer = blockManager.getDiskWriter(
+            blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
           for (elem <- elements) {
             writer.write(elem._1, elem._2)
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 3865c20..48a0282 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -88,7 +88,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     try {
       TaskContext.setTaskContext(context)
       cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
-      assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
+      assert(context.taskMetrics.updatedBlockStatuses.size === 2)
     } finally {
       TaskContext.unset()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index 8275fd8..e5ec2aa 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite
 class TaskMetricsSuite extends SparkFunSuite {
   test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") {
     val taskMetrics = new TaskMetrics()
-    taskMetrics.updateShuffleReadMetrics()
+    taskMetrics.mergeShuffleReadMetrics()
     assert(taskMetrics.shuffleReadMetrics.isEmpty)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 6e6cf63..e1b2c96 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -855,7 +855,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       } finally {
         TaskContext.unset()
       }
-      context.taskMetrics.updatedBlocks.getOrElse(Seq.empty)
+      context.taskMetrics.updatedBlockStatuses
     }
 
     // 1 updated block (i.e. list1)

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
index 355d80d..9de4341 100644
--- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
@@ -85,8 +85,8 @@ class StorageStatusListenerSuite extends SparkFunSuite {
     val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L))
     val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L))
     val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L))
-    taskMetrics1.updatedBlocks = Some(Seq(block1, block2))
-    taskMetrics2.updatedBlocks = Some(Seq(block3))
+    taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2))
+    taskMetrics2.setUpdatedBlockStatuses(Seq(block3))
 
     // Task end with new blocks
     assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
@@ -108,8 +108,8 @@ class StorageStatusListenerSuite extends SparkFunSuite {
     val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L))
     val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L))
     val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L))
-    taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3))
-    taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3))
+    taskMetrics1.setUpdatedBlockStatuses(Seq(droppedBlock1, droppedBlock3))
+    taskMetrics2.setUpdatedBlockStatuses(Seq(droppedBlock2, droppedBlock3))
 
     listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
     assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
@@ -133,8 +133,8 @@ class StorageStatusListenerSuite extends SparkFunSuite {
     val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L))
     val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L))
     val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L))
-    taskMetrics1.updatedBlocks = Some(Seq(block1, block2))
-    taskMetrics2.updatedBlocks = Some(Seq(block3))
+    taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2))
+    taskMetrics2.setUpdatedBlockStatuses(Seq(block3))
     listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
     listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2))
     assert(listener.executorIdToStorageStatus("big").numBlocks === 3)

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index ee2d56a..607617c 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -184,12 +184,12 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val conf = new SparkConf()
     val listener = new JobProgressListener(conf)
     val taskMetrics = new TaskMetrics()
-    val shuffleReadMetrics = new ShuffleReadMetrics()
+    val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics()
     assert(listener.stageIdToData.size === 0)
 
     // finish this task, should get updated shuffleRead
     shuffleReadMetrics.incRemoteBytesRead(1000)
-    taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
+    taskMetrics.mergeShuffleReadMetrics()
     var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
     var task = new ShuffleMapTask(0)
@@ -270,22 +270,19 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
       val taskMetrics = new TaskMetrics()
-      val shuffleReadMetrics = new ShuffleReadMetrics()
-      val shuffleWriteMetrics = new ShuffleWriteMetrics()
-      taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
-      taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
+      taskMetrics.setExecutorRunTime(base + 4)
+      taskMetrics.incDiskBytesSpilled(base + 5)
+      taskMetrics.incMemoryBytesSpilled(base + 6)
+      val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics()
       shuffleReadMetrics.incRemoteBytesRead(base + 1)
       shuffleReadMetrics.incLocalBytesRead(base + 9)
       shuffleReadMetrics.incRemoteBlocksFetched(base + 2)
+      taskMetrics.mergeShuffleReadMetrics()
+      val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics()
       shuffleWriteMetrics.incBytesWritten(base + 3)
-      taskMetrics.setExecutorRunTime(base + 4)
-      taskMetrics.incDiskBytesSpilled(base + 5)
-      taskMetrics.incMemoryBytesSpilled(base + 6)
-      val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
-      taskMetrics.setInputMetrics(Some(inputMetrics))
+      val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop)
       inputMetrics.incBytesRead(base + 7)
-      val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
-      taskMetrics.outputMetrics = Some(outputMetrics)
+      val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop)
       outputMetrics.setBytesWritten(base + 8)
       taskMetrics
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index 5ac922c..d1dbf7c 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -127,7 +127,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
 
     // Task end with a few new persisted blocks, some from the same RDD
     val metrics1 = new TaskMetrics
-    metrics1.updatedBlocks = Some(Seq(
+    metrics1.setUpdatedBlockStatuses(Seq(
       (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L)),
       (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L)),
       (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L))
@@ -146,7 +146,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
 
     // Task end with a few dropped blocks
     val metrics2 = new TaskMetrics
-    metrics2.updatedBlocks = Some(Seq(
+    metrics2.setUpdatedBlockStatuses(Seq(
       (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L)),
       (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L)),
       (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L)), // doesn't actually exist
@@ -173,8 +173,8 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
     val taskMetrics1 = new TaskMetrics
     val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L))
     val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L))
-    taskMetrics0.updatedBlocks = Some(Seq(block0))
-    taskMetrics1.updatedBlocks = Some(Seq(block1))
+    taskMetrics0.setUpdatedBlockStatuses(Seq(block0))
+    taskMetrics1.setUpdatedBlockStatuses(Seq(block1))
     bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
     bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
     assert(storageListener.rddInfoList.size === 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 9dd400f..e5ca2de 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -557,7 +557,7 @@ class JsonProtocolSuite extends SparkFunSuite {
       metrics1.shuffleWriteMetrics, metrics2.shuffleWriteMetrics, assertShuffleWriteEquals)
     assertOptionEquals(
       metrics1.inputMetrics, metrics2.inputMetrics, assertInputMetricsEquals)
-    assertOptionEquals(metrics1.updatedBlocks, metrics2.updatedBlocks, assertBlocksEquals)
+    assertBlocksEquals(metrics1.updatedBlockStatuses, metrics2.updatedBlockStatuses)
   }
 
   private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
@@ -773,34 +773,31 @@ class JsonProtocolSuite extends SparkFunSuite {
     t.incMemoryBytesSpilled(a + c)
 
     if (hasHadoopInput) {
-      val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+      val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop)
       inputMetrics.incBytesRead(d + e + f)
       inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1)
-      t.setInputMetrics(Some(inputMetrics))
     } else {
-      val sr = new ShuffleReadMetrics
+      val sr = t.registerTempShuffleReadMetrics()
       sr.incRemoteBytesRead(b + d)
       sr.incLocalBlocksFetched(e)
       sr.incFetchWaitTime(a + d)
       sr.incRemoteBlocksFetched(f)
       sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1)
       sr.incLocalBytesRead(a + f)
-      t.setShuffleReadMetrics(Some(sr))
+      t.mergeShuffleReadMetrics()
     }
     if (hasOutput) {
-      val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
+      val outputMetrics = t.registerOutputMetrics(DataWriteMethod.Hadoop)
       outputMetrics.setBytesWritten(a + b + c)
       outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1)
-      t.outputMetrics = Some(outputMetrics)
     } else {
-      val sw = new ShuffleWriteMetrics
+      val sw = t.registerShuffleWriteMetrics()
       sw.incBytesWritten(a + b + c)
       sw.incWriteTime(b + c + d)
       sw.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1)
-      t.shuffleWriteMetrics = Some(sw)
     }
     // Make at most 6 blocks
-    t.updatedBlocks = Some((1 to (e % 5 + 1)).map { i =>
+    t.setUpdatedBlockStatuses((1 to (e % 5 + 1)).map { i =>
       (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i))
     }.toSeq)
     t

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index d45d2db..8222b84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -126,8 +126,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
       logInfo("Input split: " + split.serializableHadoopSplit)
       val conf = getConf(isDriverSide = false)
 
-      val inputMetrics = context.taskMetrics
-        .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+      val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
 
       // Sets the thread local variable for the file's name
       split.serializableHadoopSplit.value match {

http://git-wip-us.apache.org/repos/asf/spark/blob/b122c861/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 9f09eb4..7438e11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -127,7 +127,6 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
       assert(sorter.numSpills > 0)
 
       // Merging spilled files should not throw assertion error
-      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
       sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
     } {
       // Clean up

