Posted to commits@spark.apache.org by li...@apache.org on 2017/10/09 22:22:45 UTC

spark git commit: [SPARK-22170][SQL] Reduce memory consumption in broadcast joins.

Repository: spark
Updated Branches:
  refs/heads/master dadd13f36 -> 155ab6347


[SPARK-22170][SQL] Reduce memory consumption in broadcast joins.

## What changes were proposed in this pull request?

This updates the broadcast join code path to lazily decompress pages and
iterate through UnsafeRows to prevent all rows from being held in memory
while the broadcast table is being built.
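
In rough outline (a stand-in sketch, not the Spark code itself), the change swaps an eager collect-then-decode step for a lazy one; `decode` below plays the role of decodeUnsafeRows and `blocks` are the compressed per-partition byte arrays collected to the driver:

    import scala.reflect.ClassTag

    // Minimal stand-in sketch of the old vs. new driver-side code path.
    def eagerCollect[T: ClassTag](blocks: Seq[Array[Byte]],
        decode: Array[Byte] => Iterator[T]): Array[T] =
      blocks.flatMap(decode).toArray          // old: every row materialized up front

    def lazyCollect[T](blocks: Seq[Array[Byte]],
        decode: Array[Byte] => Iterator[T]): Iterator[T] =
      blocks.iterator.flatMap(decode)         // new: at most one block decoded at a time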

## How was this patch tested?

Existing tests.

Author: Ryan Blue <bl...@apache.org>

Closes #19394 from rdblue/broadcast-driver-memory.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/155ab634
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/155ab634
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/155ab634

Branch: refs/heads/master
Commit: 155ab6347ec7be06c937372a51e8013fdd371d93
Parents: dadd13f
Author: Ryan Blue <bl...@apache.org>
Authored: Mon Oct 9 15:22:41 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Mon Oct 9 15:22:41 2017 -0700

----------------------------------------------------------------------
 .../catalyst/plans/physical/broadcastMode.scala |  6 ++++
 .../apache/spark/sql/execution/SparkPlan.scala  | 19 +++++++++----
 .../exchange/BroadcastExchangeExec.scala        | 29 ++++++++++++++------
 .../sql/execution/joins/HashedRelation.scala    | 13 ++++++++-
 .../apache/spark/sql/ConfigBehaviorSuite.scala  |  2 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  |  3 +-
 6 files changed, 54 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/155ab634/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index 2ab46dc..9fac95a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 trait BroadcastMode {
   def transform(rows: Array[InternalRow]): Any
 
+  def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any
+
   def canonicalized: BroadcastMode
 }
 
@@ -36,5 +38,9 @@ case object IdentityBroadcastMode extends BroadcastMode {
   // TODO: pack the UnsafeRows into single bytes array.
   override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
 
+  override def transform(
+      rows: Iterator[InternalRow],
+      sizeHint: Option[Long]): Array[InternalRow] = rows.toArray
+
   override def canonicalized: BroadcastMode = this
 }
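
For illustration only, a stand-in trait (not the Spark BroadcastMode, which keeps both overloads abstract) showing how the new two-argument transform is meant to compose with the existing one: the Array-based overload can delegate to the Iterator-based one, passing the array length as the size hint, which is what HashedRelationBroadcastMode does further below.

    // Stand-in types; names are illustrative.
    trait BuildMode[T] {
      // Existing entry point: rows already materialized as an array.
      def transform(rows: Array[String]): T =
        transform(rows.iterator, Some(rows.length.toLong))

      // New entry point: rows arrive as an iterator plus an optional row count.
      def transform(rows: Iterator[String], sizeHint: Option[Long]): T
    }

    // Identity-style mode: materialize the iterator and ignore the hint,
    // mirroring IdentityBroadcastMode above.
    object PassThroughMode extends BuildMode[Array[String]] {
      override def transform(rows: Iterator[String], sizeHint: Option[Long]): Array[String] =
        rows.toArray
    }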

http://git-wip-us.apache.org/repos/asf/spark/blob/155ab634/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b263f10..2ffd948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -223,7 +223,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
    * compressed.
    */
-  private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = {
+  private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = {
     execute().mapPartitionsInternal { iter =>
       var count = 0
       val buffer = new Array[Byte](4 << 10)  // 4K
@@ -239,7 +239,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       out.writeInt(-1)
       out.flush()
       out.close()
-      Iterator(bos.toByteArray)
+      Iterator((count, bos.toByteArray))
     }
   }
 
@@ -274,19 +274,26 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     val byteArrayRdd = getByteArrayRdd()
 
     val results = ArrayBuffer[InternalRow]()
-    byteArrayRdd.collect().foreach { bytes =>
-      decodeUnsafeRows(bytes).foreach(results.+=)
+    byteArrayRdd.collect().foreach { countAndBytes =>
+      decodeUnsafeRows(countAndBytes._2).foreach(results.+=)
     }
     results.toArray
   }
 
+  private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = {
+    val countsAndBytes = getByteArrayRdd().collect()
+    val total = countsAndBytes.map(_._1).sum
+    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
+    (total, rows)
+  }
+
   /**
    * Runs this query returning the result as an iterator of InternalRow.
    *
    * @note Triggers multiple jobs (one for each partition).
    */
   def executeToIterator(): Iterator[InternalRow] = {
-    getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+    getByteArrayRdd().map(_._2).toLocalIterator.flatMap(decodeUnsafeRows)
   }
 
   /**
@@ -307,7 +314,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       return new Array[InternalRow](0)
     }
 
-    val childRDD = getByteArrayRdd(n)
+    val childRDD = getByteArrayRdd(n).map(_._2)
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
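
As a simplified illustration of the new (rowCount, bytes) shape that getByteArrayRdd produces and executeCollectIterator consumes, the sketch below uses plain strings in place of UnsafeRows and gzip in place of Spark's compression codec; all names are illustrative.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    object CountAndBytesSketch {

      // Serialize one "partition" of rows as length-prefixed UTF-8, compressed,
      // returning the row count alongside the bytes (note the -1 end marker).
      def encode(rows: Seq[String]): (Long, Array[Byte]) = {
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(new GZIPOutputStream(bos))
        rows.foreach { r =>
          val b = r.getBytes("UTF-8")
          out.writeInt(b.length)
          out.write(b)
        }
        out.writeInt(-1)   // end-of-partition marker
        out.close()
        (rows.length.toLong, bos.toByteArray)
      }

      // Decode one blob lazily: rows are read off the stream as the iterator advances.
      def decode(bytes: Array[Byte]): Iterator[String] = {
        val in = new DataInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
        Iterator.continually(in.readInt())
          .takeWhile(_ >= 0)
          .map { len =>
            val buf = new Array[Byte](len)
            in.readFully(buf)
            new String(buf, "UTF-8")
          }
      }

      // The shape of executeCollectIterator: total row count up front, rows on demand.
      def collectIterator(parts: Seq[(Long, Array[Byte])]): (Long, Iterator[String]) =
        (parts.map(_._1).sum, parts.iterator.flatMap { case (_, bytes) => decode(bytes) })
    }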

http://git-wip-us.apache.org/repos/asf/spark/blob/155ab634/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 9c859e4..880e18c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.joins.HashedRelation
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
@@ -72,26 +72,39 @@ case class BroadcastExchangeExec(
       SQLExecution.withExecutionId(sparkContext, executionId) {
         try {
           val beforeCollect = System.nanoTime()
-          // Note that we use .executeCollect() because we don't want to convert data to Scala types
-          val input: Array[InternalRow] = child.executeCollect()
-          if (input.length >= 512000000) {
+          // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+          val (numRows, input) = child.executeCollectIterator()
+          if (numRows >= 512000000) {
             throw new SparkException(
-              s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+              s"Cannot broadcast the table with more than 512 millions rows: $numRows rows")
           }
+
           val beforeBuild = System.nanoTime()
           longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-          val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+
+          // Construct the relation.
+          val relation = mode.transform(input, Some(numRows))
+
+          val dataSize = relation match {
+            case map: HashedRelation =>
+              map.estimatedSize
+            case arr: Array[InternalRow] =>
+              arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+            case _ =>
+              throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " +
+                  relation.getClass.getName)
+          }
+
           longMetric("dataSize") += dataSize
           if (dataSize >= (8L << 30)) {
             throw new SparkException(
               s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
           }
 
-          // Construct and broadcast the relation.
-          val relation = mode.transform(input)
           val beforeBroadcast = System.nanoTime()
           longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
 
+          // Broadcast the relation
           val broadcasted = sparkContext.broadcast(relation)
           longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
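
In outline, with stand-in types (the real match above is on HashedRelation.estimatedSize and on Array[InternalRow] row sizes): once the driver no longer holds the rows as an array, the dataSize estimate has to come from the relation that was built, not from the collected rows.

    // Stand-in types for the sizing logic; names are illustrative.
    sealed trait BuiltRelation
    final case class HashedMap(estimatedSize: Long) extends BuiltRelation
    final case class RowArray(rowSizesInBytes: Array[Long]) extends BuiltRelation

    def broadcastDataSize(relation: BuiltRelation): Long = relation match {
      case HashedMap(estimated) => estimated       // map.estimatedSize
      case RowArray(sizes)      => sizes.sum       // sum of UnsafeRow.getSizeInBytes
    }

    // The 8 GB guard: reject anything at or above 8L << 30 bytes.
    def checkBroadcastSize(dataSize: Long): Unit =
      require(dataSize < (8L << 30),
        s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")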
 

http://git-wip-us.apache.org/repos/asf/spark/blob/155ab634/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f8058b2..b2dcbe5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -866,7 +866,18 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
   extends BroadcastMode {
 
   override def transform(rows: Array[InternalRow]): HashedRelation = {
-    HashedRelation(rows.iterator, canonicalized.key, rows.length)
+    transform(rows.iterator, Some(rows.length))
+  }
+
+  override def transform(
+      rows: Iterator[InternalRow],
+      sizeHint: Option[Long]): HashedRelation = {
+    sizeHint match {
+      case Some(numRows) =>
+        HashedRelation(rows, canonicalized.key, numRows.toInt)
+      case None =>
+        HashedRelation(rows, canonicalized.key)
+    }
   }
 
   override lazy val canonicalized: HashedRelationBroadcastMode = {
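
The sizeHint pattern above, shown on a plain java.util.HashMap instead of HashedRelation (illustrative names only): a known row count lets the table be pre-sized, reducing rehashing while the build iterator is consumed; with no hint, the default capacity is used.

    import java.util.{HashMap => JHashMap}

    // Illustrative only: a plain hash map standing in for HashedRelation.
    def buildMap(rows: Iterator[(String, String)],
        sizeHint: Option[Long]): JHashMap[String, String] = {
      val map = sizeHint match {
        case Some(numRows) => new JHashMap[String, String](numRows.toInt)  // pre-sized
        case None          => new JHashMap[String, String]()               // default capacity
      }
      rows.foreach { case (k, v) => map.put(k, v) }
      map
    }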

http://git-wip-us.apache.org/repos/asf/spark/blob/155ab634/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
index 2c1e5db..cee85ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -58,7 +58,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {
       withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") {
         // If we only sample one point, the range boundaries will be pretty bad and the
         // chi-sq value would be very high.
-        assert(computeChiSquareTest() > 1000)
+        assert(computeChiSquareTest() > 300)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/155ab634/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 0dc612e..58a194b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -227,8 +227,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     val df = df1.join(broadcast(df2), "key")
     testSparkPlanMetrics(df, 2, Map(
       1L -> (("BroadcastHashJoin", Map(
-        "number of output rows" -> 2L,
-        "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))))
+        "number of output rows" -> 2L))))
     )
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org