Posted to commits@spark.apache.org by an...@apache.org on 2015/10/15 23:51:01 UTC

spark git commit: [SPARK-10412] [SQL] report memory usage for tungsten sql physical operator

Repository: spark
Updated Branches:
  refs/heads/master 3b364ff0a -> 6a2359ff1


[SPARK-10412] [SQL] report memory usage for tungsten sql physical operator

https://issues.apache.org/jira/browse/SPARK-10412

some screenshots:
### aggregate:
![screen shot 2015-10-12 at 2 23 11 pm](https://cloud.githubusercontent.com/assets/3182036/10439534/618320a4-70ef-11e5-94d8-62ea7f2d1531.png)

### join
![screen shot 2015-10-12 at 2 23 29 pm](https://cloud.githubusercontent.com/assets/3182036/10439537/6724797c-70ef-11e5-8f75-0cf5cbd42048.png)

Author: Wenchen Fan <we...@databricks.com>
Author: Wenchen Fan <cl...@163.com>

Closes #8931 from cloud-fan/viz.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a2359ff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a2359ff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a2359ff

Branch: refs/heads/master
Commit: 6a2359ff1f7ad2233af2c530313d6ec2ecf70d19
Parents: 3b364ff
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Oct 15 14:50:58 2015 -0700
Committer: Andrew Or <an...@databricks.com>
Committed: Thu Oct 15 14:50:58 2015 -0700

----------------------------------------------------------------------
 .../execution/aggregate/TungstenAggregate.scala | 10 ++-
 .../aggregate/TungstenAggregationIterator.scala | 10 ++-
 .../spark/sql/execution/metric/SQLMetrics.scala | 72 ++++++++++++++------
 .../org/apache/spark/sql/execution/sort.scala   | 16 +++++
 .../spark/sql/execution/ui/ExecutionPage.scala  |  2 +-
 .../spark/sql/execution/ui/SQLListener.scala    |  9 ++-
 .../spark/sql/execution/ui/SparkPlanGraph.scala |  4 +-
 .../TungstenAggregationIteratorSuite.scala      |  3 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  | 13 +++-
 .../sql/execution/ui/SQLListenerSuite.scala     | 20 +++---
 10 files changed, 116 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index c342940..0d3a4b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -49,7 +49,9 @@ case class TungstenAggregate(
 
   override private[sql] lazy val metrics = Map(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
   override def outputsUnsafeRows: Boolean = true
 
@@ -79,6 +81,8 @@ case class TungstenAggregate(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
+    val dataSize = longMetric("dataSize")
+    val spillSize = longMetric("spillSize")
 
     /**
      * Set up the underlying unsafe data structures used before computing the parent partition.
@@ -97,7 +101,9 @@ case class TungstenAggregate(
         child.output,
         testFallbackStartsAt,
         numInputRows,
-        numOutputRows)
+        numOutputRows,
+        dataSize,
+        spillSize)
     }
 
     /** Compute a partition using the iterator already set up previously. */

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index fe708a5..7cd0f7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -87,7 +87,9 @@ class TungstenAggregationIterator(
     originalInputAttributes: Seq[Attribute],
     testFallbackStartsAt: Option[Int],
     numInputRows: LongSQLMetric,
-    numOutputRows: LongSQLMetric)
+    numOutputRows: LongSQLMetric,
+    dataSize: LongSQLMetric,
+    spillSize: LongSQLMetric)
   extends Iterator[UnsafeRow] with Logging {
 
   // The parent partition iterator, to be initialized later in `start`
@@ -110,6 +112,10 @@ class TungstenAggregationIterator(
       s"$allAggregateExpressions should have no more than 2 kinds of modes.")
   }
 
+  // Remember the task's spill size before executing this operator, so that we can
+  // figure out how many bytes this operator spilled.
+  private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
+
   //
   // The modes of AggregateExpressions. Right now, we can handle the following mode:
   //  - Partial-only:
@@ -842,6 +848,8 @@ class TungstenAggregationIterator(
         val mapMemory = hashMap.getPeakMemoryUsedBytes
         val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
         val peakMemory = Math.max(mapMemory, sorterMemory)
+        dataSize += peakMemory
+        spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
         TaskContext.get().internalMetricsToAccumulators(
           InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
       }
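
For context, the hunk above captures the "snapshot before, diff after" pattern this patch uses to attribute spill to a single operator: read the task-level spill counter before the operator runs, then add the difference to the operator's spillSize metric when it finishes. The following is a minimal standalone Scala sketch of that pattern, not the actual Spark code; SimpleMetric and readSpilledBytes are hypothetical stand-ins for LongSQLMetric and TaskContext.get().taskMetrics().memoryBytesSpilled.

object SpillDeltaSketch {
  // Hypothetical stand-in for LongSQLMetric: a mutable long with `+=`.
  final class SimpleMetric(var value: Long = 0L) {
    def +=(delta: Long): Unit = value += delta
  }

  // Hypothetical stand-in for the task-level spill counter, which only grows.
  private var spilledSoFar = 0L
  private def readSpilledBytes(): Long = spilledSoFar

  // Run an operator body and charge only the spill it caused to `spillSize`.
  def runOperator(spillSize: SimpleMetric)(body: => Unit): Unit = {
    val spillSizeBefore = readSpilledBytes()           // snapshot before the operator
    body                                               // operator work; may spill
    spillSize += readSpilledBytes() - spillSizeBefore  // delta = this operator's spill
  }

  def main(args: Array[String]): Unit = {
    val spillSize = new SimpleMetric
    runOperator(spillSize) { spilledSoFar += 1024L * 1024L }  // simulate a 1 MB spill
    println(s"spill size for this operator: ${spillSize.value} bytes")  // prints 1048576
  }
}

The sort.scala change later in this diff wraps its sorter the same way, additionally recording the sorter's peak memory as the dataSize metric.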

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 7a2a98e..075b7ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.metric
 
+import org.apache.spark.util.Utils
 import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
 
 /**
@@ -35,6 +36,12 @@ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
  */
 private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
 
+  /**
+   * A function that defines how we aggregate the final accumulator results among all tasks,
+   * and render them as a string for a SQL physical operator.
+   */
+  val stringValue: Seq[T] => String
+
   def zero: R
 }
 
@@ -64,25 +71,11 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr
 }
 
 /**
- * A wrapper of Int to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {
-
-  def add(term: Int): IntSQLMetricValue = {
-    _value += term
-    this
-  }
-
-  // Although there is a boxing here, it's fine because it's only called in SQLListener
-  override def value: Int = _value
-}
-
-/**
  * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
  * `+=` and `add`.
  */
-private[sql] class LongSQLMetric private[metric](name: String)
-  extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {
+private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
+  extends SQLMetric[LongSQLMetricValue, Long](name, param) {
 
   override def +=(term: Long): Unit = {
     localValue.add(term)
@@ -93,7 +86,8 @@ private[sql] class LongSQLMetric private[metric](name: String)
   }
 }
 
-private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {
+private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
+  extends SQLMetricParam[LongSQLMetricValue, Long] {
 
   override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
 
@@ -102,20 +96,56 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon
 
   override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
 
-  override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
+  override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
 }
 
 private[sql] object SQLMetrics {
 
-  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
-    val acc = new LongSQLMetric(name)
+  private def createLongMetric(
+      sc: SparkContext,
+      name: String,
+      stringValue: Seq[Long] => String,
+      initialValue: Long): LongSQLMetric = {
+    val param = new LongSQLMetricParam(stringValue, initialValue)
+    val acc = new LongSQLMetric(name, param)
     sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
     acc
   }
 
+  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
+    createLongMetric(sc, name, _.sum.toString, 0L)
+  }
+
+  /**
+   * Create a metric to report the size information (including total, min, med, max) like data size,
+   * spill size, etc.
+   */
+  def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+    val stringValue = (values: Seq[Long]) => {
+      // This is a workaround for SPARK-11013.
+      // We use -1 as the initial value of the accumulator; if the accumulator is valid, we update
+      // it at the end of the task and the value will be at least 0.
+      val validValues = values.filter(_ >= 0)
+      val Seq(sum, min, med, max) = {
+        val metric = if (validValues.length == 0) {
+          Seq.fill(4)(0L)
+        } else {
+          val sorted = validValues.sorted
+          Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        }
+        metric.map(Utils.bytesToString)
+      }
+      s"\n$sum ($min, $med, $max)"
+    }
+    // The final result of this metric in the physical operator UI may look like:
+    // data size total (min, med, max):
+    // 100GB (100MB, 1GB, 10GB)
+    createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L)
+  }
+
   /**
    * A metric whose value will be ignored. Use this one when we need a metric parameter but don't
    * care about the value.
    */
-  val nullLongMetric = new LongSQLMetric("null")
+  val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L))
 }
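
As a companion to the createSizeMetric change above, here is a standalone Scala sketch of the aggregation its stringValue performs: drop the -1 placeholders left by tasks that never updated the metric, then report sum (min, median, max), each rendered as a human-readable byte count. This is an illustration only; bytesToString below is a simplified stand-in for org.apache.spark.util.Utils.bytesToString, and the object name is invented.

object SizeMetricStringSketch {
  // Simplified stand-in for Utils.bytesToString.
  private def bytesToString(size: Long): String = {
    val units = Seq("B", "KB", "MB", "GB", "TB")
    var value = size.toDouble
    var i = 0
    while (value >= 1024 && i < units.length - 1) { value /= 1024; i += 1 }
    f"$value%.1f ${units(i)}"
  }

  def stringValue(values: Seq[Long]): String = {
    val validValues = values.filter(_ >= 0)  // -1 means the task never updated the metric
    val stats =
      if (validValues.isEmpty) Seq.fill(4)(0L)
      else {
        val sorted = validValues.sorted
        Seq(sorted.sum, sorted.head, sorted(sorted.length / 2), sorted.last)
      }
    val Seq(sum, min, med, max) = stats.map(bytesToString)
    s"\n$sum ($min, $med, $max)"
  }

  def main(args: Array[String]): Unit = {
    // Three tasks reported peak sizes of 1 MB, 4 MB, 16 MB; one task never touched the metric.
    println("data size total (min, med, max):" +
      stringValue(Seq(-1L, 1L << 20, 1L << 22, 1L << 24)))
    // prints the label, then "21.0 MB (1.0 MB, 4.0 MB, 16.0 MB)" on the next line
  }
}

With this in place, a metric that no task ever touched renders as zero bytes everywhere instead of surfacing the -1 sentinel.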

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 27f2624..9385e57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
@@ -93,10 +94,17 @@ case class TungstenSort(
   override def requiredChildDistribution: Seq[Distribution] =
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
 
+  override private[sql] lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
+
   protected override def doExecute(): RDD[InternalRow] = {
     val schema = child.schema
     val childOutput = child.output
 
+    val dataSize = longMetric("dataSize")
+    val spillSize = longMetric("spillSize")
+
     /**
      * Set up the sorter in each partition before computing the parent partition.
      * This makes sure our sorter is not starved by other sorters used in the same task.
@@ -131,7 +139,15 @@ case class TungstenSort(
         partitionIndex: Int,
         sorter: UnsafeExternalRowSorter,
         parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+      // Remember the task's spill size before executing this operator, so that we can
+      // figure out how many bytes this operator spilled.
+      val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
+
       val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
+
+      dataSize += sorter.getPeakMemoryUsage
+      spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
+
       taskContext.internalMetricsToAccumulators(
         InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
       sortedIterator

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
index a4dbd2e..e74d6fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
@@ -100,7 +100,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution")
     // scalastyle:on
   }
 
-  private def planVisualization(metrics: Map[Long, Any], graph: SparkPlanGraph): Seq[Node] = {
+  private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = {
     val metadata = graph.nodes.flatMap { node =>
       val nodeId = s"plan-meta-data-${node.id}"
       <div id={nodeId}>{node.desc}</div>

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index d647240..b302b51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -252,7 +252,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
   /**
    * Get all accumulator updates from all tasks which belong to this execution and merge them.
    */
-  def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized {
+  def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized {
     _executionIdToData.get(executionId) match {
       case Some(executionUIData) =>
         val accumulatorUpdates = {
@@ -264,8 +264,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
           }
         }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
-          executionUIData.accumulatorMetrics(accumulatorId).metricParam).
-          mapValues(_.asInstanceOf[SQLMetricValue[_]].value)
+          executionUIData.accumulatorMetrics(accumulatorId).metricParam)
       case None =>
         // This execution has been dropped
         Map.empty
@@ -274,11 +273,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
 
   private def mergeAccumulatorUpdates(
       accumulatorUpdates: Seq[(Long, Any)],
-      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = {
+      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
     accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
       val param = paramFunc(accumulatorId)
       (accumulatorId,
-        values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace))
+        param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
     }
   }
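
To make the listener change above concrete, here is a minimal Scala sketch of what the reworked mergeAccumulatorUpdates now does: group the per-task (accumulatorId, value) pairs by id and let each metric's stringValue turn the collected raw values into the display string, instead of folding them with addInPlace. The stringValueOf lookup below is a hypothetical stand-in for resolving the SQLMetricParam through accumulatorMetrics.

object MergeUpdatesSketch {
  // Merge raw per-task updates into one display string per accumulator.
  def mergeAccumulatorUpdates(
      updates: Seq[(Long, Long)],
      stringValueOf: Long => Seq[Long] => String): Map[Long, String] = {
    updates.groupBy(_._1).map { case (accumulatorId, pairs) =>
      accumulatorId -> stringValueOf(accumulatorId)(pairs.map(_._2))
    }
  }

  def main(args: Array[String]): Unit = {
    val sumAsString: Seq[Long] => String = _.sum.toString  // plain long metric: just the sum
    val updates = Seq((1L, 10L), (1L, 32L), (2L, 7L))      // two tasks updated accumulator 1
    val merged = mergeAccumulatorUpdates(updates, _ => sumAsString)
    println(merged(1L))  // "42" (10 + 32 merged across tasks)
    println(merged(2L))  // "7"
  }
}

This is also why getExecutionMetrics and the plan graph methods below now carry Map[Long, String] rather than Map[Long, Any].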
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index ae3d752..f1fce54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue}
 private[ui] case class SparkPlanGraph(
     nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) {
 
-  def makeDotFile(metrics: Map[Long, Any]): String = {
+  def makeDotFile(metrics: Map[Long, String]): String = {
     val dotFile = new StringBuilder
     dotFile.append("digraph G {\n")
     nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n"))
@@ -87,7 +87,7 @@ private[sql] object SparkPlanGraph {
 private[ui] case class SparkPlanGraphNode(
     id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) {
 
-  def makeDotNode(metricsValue: Map[Long, Any]): String = {
+  def makeDotNode(metricsValue: Map[Long, String]): String = {
     val values = {
       for (metric <- metrics;
            value <- metricsValue.get(metric.accumulatorId)) yield {

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 0cc4988..cc0ac1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,8 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
       }
       val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
       iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
-        0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+        0, Seq.empty, newMutableProjection, Seq.empty, None,
+        dummyAccum, dummyAccum, dummyAccum, dummyAccum)
       val numPages = iter.getHashMap.getNumDataPages
       assert(numPages === 1)
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 6afffae..cdd885b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -93,7 +93,16 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
         }.toMap
         (node.id, node.name -> nodeMetrics)
       }.toMap
-      assert(expectedMetrics === actualMetrics)
+
+      assert(expectedMetrics.keySet === actualMetrics.keySet)
+      for (nodeId <- expectedMetrics.keySet) {
+        val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
+        val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
+        assert(expectedNodeName === actualNodeName)
+        for (metricName <- expectedMetricsMap.keySet) {
+          assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
+        }
+      }
     } else {
       // TODO Remove this "else" once we fix the race condition that missing the JobStarted event.
       // Since we cannot track all jobs, the metric values could be wrong and we should not check
@@ -489,7 +498,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
       // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
       // However, we still can check the value.
-      assert(metricValues.values.toSeq === Seq(2L))
+      assert(metricValues.values.toSeq === Seq("2"))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a2359ff/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 727cf36..cc1c1e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -74,6 +74,10 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("basic") {
+    def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = {
+      assert(actual === expected.mapValues(_.toString))
+    }
+
     val listener = new SQLListener(sqlContext.sparkContext.conf)
     val executionId = 0
     val df = createTestDataFrame
@@ -114,7 +118,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       (1L, 0, 0, createTaskMetrics(accumulatorUpdates))
     )))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, metrics)
@@ -122,7 +126,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))
     )))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
 
     // Retrying a stage should reset the metrics
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1)))
@@ -133,7 +137,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       (1L, 0, 1, createTaskMetrics(accumulatorUpdates))
     )))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     // Ignore the task end for the first attempt
     listener.onTaskEnd(SparkListenerTaskEnd(
@@ -144,7 +148,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       createTaskInfo(0, 0),
       createTaskMetrics(accumulatorUpdates.mapValues(_ * 100))))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     // Finish two tasks
     listener.onTaskEnd(SparkListenerTaskEnd(
@@ -162,7 +166,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       createTaskInfo(1, 0),
       createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5))
 
     // Submit a new stage
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0)))
@@ -173,7 +177,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       (1L, 1, 0, createTaskMetrics(accumulatorUpdates))
     )))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
 
     // Finish two tasks
     listener.onTaskEnd(SparkListenerTaskEnd(
@@ -191,7 +195,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
       createTaskInfo(1, 0),
       createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11))
 
     assert(executionUIData.runningJobs === Seq(0))
     assert(executionUIData.succeededJobs.isEmpty)
@@ -208,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     assert(executionUIData.succeededJobs === Seq(0))
     assert(executionUIData.failedJobs.isEmpty)
 
-    assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11))
+    checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11))
   }
 
   test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {

