Posted to commits@spark.apache.org by jo...@apache.org on 2022/08/11 00:13:37 UTC

[spark] branch master updated: [SPARK-39983][CORE][SQL] Do not cache unserialized broadcast relations on the driver

This is an automated email from the ASF dual-hosted git repository.

joshrosen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e17d8ecabca [SPARK-39983][CORE][SQL] Do not cache unserialized broadcast relations on the driver
e17d8ecabca is described below

commit e17d8ecabcad6e84428752b977120ff355a4007a
Author: Alex Balikov <al...@databricks.com>
AuthorDate: Wed Aug 10 17:13:03 2022 -0700

    [SPARK-39983][CORE][SQL] Do not cache unserialized broadcast relations on the driver
    
    ### What changes were proposed in this pull request?
    
    This PR addresses the issue raised in https://issues.apache.org/jira/browse/SPARK-39983 - broadcast relations should not be cached on the driver, where they are not needed and can cause significant memory pressure (in one case the relation was 60 MB).
    
    The PR adds a new SparkContext.broadcastInternal method with a serializedOnly parameter that lets the caller specify that the broadcasted object should be stored only in serialized form. The current behavior is to also cache an unserialized form of the object.
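    
    For illustration, a minimal sketch of the new internal API (the signature matches the diff below; the SparkContext setup around it is assumed, and since broadcastInternal is private[spark] this only compiles from code inside the org.apache.spark package):
    
        import org.apache.spark.{SparkConf, SparkContext}
    
        // Illustrative setup; any master and app name behave the same way.
        val sc = new SparkContext(
          new SparkConf().setMaster("local-cluster[2,1,1024]").setAppName("demo"))
    
        // Public path: also caches an unserialized copy on the driver.
        val cached = sc.broadcast(Seq(1, 2, 3))
    
        // Internal path: the driver keeps only the serialized blocks.
        val serializedOnly = sc.broadcastInternal(Seq(1, 2, 3), serializedOnly = true)
    
        // Executors still read the value as usual.
        sc.parallelize(0 until 4).map(_ => serializedOnly.value.sum).collect()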
    
    The PR changes the broadcast implementation in TorrentBroadcast to honor the serializedOnly flag and not store the unserialized value, unless Spark is running in local mode (single process). In that case the broadcast cache is effectively shared between the driver and executors, so the unserialized value must be cached to satisfy the executor side of the functionality.
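    
    Condensed from the TorrentBroadcast.writeBlocks changes in the diff below, the driver-side write path becomes (a sketch only; the surrounding declarations are elided):
    
        if (serializedOnly && !isLocalMaster) {
          // SPARK-39983: keep only a weak reference on the driver; the
          // unserialized value can be GCed once the blocks are written.
          _value = new WeakReference[T](value)
        } else {
          // Ordinary broadcasts, or local mode where tasks run on the
          // driver: cache the unserialized value in the block manager.
          if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
            throw new SparkException(s"Failed to store $broadcastId in BlockManager")
          }
        }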
    
    ### Why are the changes needed?
    
    The broadcast relations can be fairly large (one observed at 60 MB) and are not needed in unserialized form on the driver.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new unit test to BroadcastSuite verifying the low-level broadcast functionality with respect to the serializedOnly flag.
    Added a new unit test to BroadcastExchangeSuite verifying that broadcasted relations are not cached on the driver.
    
    Closes #37413 from alex-balikov/SPARK-39983-broadcast-no-cache.
    
    Lead-authored-by: Alex Balikov <al...@databricks.com>
    Co-authored-by: Josh Rosen <jo...@databricks.com>
    Signed-off-by: Josh Rosen <jo...@databricks.com>
---
 .../main/scala/org/apache/spark/SparkContext.scala | 19 +++++-
 .../apache/spark/broadcast/BroadcastFactory.scala  |  8 ++-
 .../apache/spark/broadcast/BroadcastManager.scala  |  7 ++-
 .../apache/spark/broadcast/TorrentBroadcast.scala  | 67 +++++++++++++++++-----
 .../spark/broadcast/TorrentBroadcastFactory.scala  |  8 ++-
 .../apache/spark/broadcast/BroadcastSuite.scala    | 19 ++++++
 .../execution/exchange/BroadcastExchangeExec.scala |  4 +-
 .../sql/execution/BroadcastExchangeSuite.scala     | 29 +++++++++-
 8 files changed, 136 insertions(+), 25 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6cb4f04ac7f..f101dc8e083 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1511,16 +1511,31 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Broadcast a read-only variable to the cluster, returning a
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
-   * The variable will be sent to each cluster only once.
+   * The variable will be sent to each executor only once.
    *
    * @param value value to broadcast to the Spark nodes
    * @return `Broadcast` object, a read-only variable cached on each machine
    */
   def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+    broadcastInternal(value, serializedOnly = false)
+  }
+
+  /**
+   * Internal version of broadcast - broadcast a read-only variable to the cluster, returning a
+   * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+   * The variable will be sent to each executor only once.
+   *
+   * @param value value to broadcast to the Spark nodes
+   * @param serializedOnly if true, do not cache the unserialized value on the driver
+   * @return `Broadcast` object, a read-only variable cached on each machine
+   */
+  private[spark] def broadcastInternal[T: ClassTag](
+      value: T,
+      serializedOnly: Boolean): Broadcast[T] = {
     assertNotStopped()
     require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
       "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
-    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+    val bc = env.broadcastManager.newBroadcast[T](value, isLocal, serializedOnly)
     val callSite = getCallSite
     logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
     cleaner.foreach(_.registerBroadcastForCleanup(bc))
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 9891582501b..38d642753ad 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -36,8 +36,14 @@ private[spark] trait BroadcastFactory {
    * @param value value to broadcast
    * @param isLocal whether we are in local mode (single JVM process)
    * @param id unique id representing this broadcast variable
+   * @param serializedOnly if true, do not cache the unserialized value on the driver
+   * @return `Broadcast` object, a read-only variable cached on each machine
    */
-  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+  def newBroadcast[T: ClassTag](
+      value: T,
+      isLocal: Boolean,
+      id: Long,
+      serializedOnly: Boolean = false): Broadcast[T]
 
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
 
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index b6f59c36081..cd152709a1f 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -60,7 +60,10 @@ private[spark] class BroadcastManager(
         .asInstanceOf[java.util.Map[Any, Any]]
     )
 
-  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
+  def newBroadcast[T: ClassTag](
+      value_ : T,
+      isLocal: Boolean,
+      serializedOnly: Boolean = false): Broadcast[T] = {
     val bid = nextBroadcastId.getAndIncrement()
     value_ match {
       case pb: PythonBroadcast =>
@@ -72,7 +75,7 @@ private[spark] class BroadcastManager(
 
       case _ => // do nothing
     }
-    broadcastFactory.newBroadcast[T](value_, isLocal, bid)
+    broadcastFactory.newBroadcast[T](value_, isLocal, bid, serializedOnly)
   }
 
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index e35a079746a..8f91f673aa9 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.broadcast
 
 import java.io._
-import java.lang.ref.SoftReference
+import java.lang.ref.{Reference, SoftReference, WeakReference}
 import java.nio.ByteBuffer
 import java.util.zip.Adler32
 
@@ -54,8 +54,9 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
  *
  * @param obj object to broadcast
  * @param id A unique identifier for the broadcast variable.
+ * @param serializedOnly if true, do not cache the unserialized value on the driver
  */
-private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
   extends Broadcast[T](id) with Logging with Serializable {
 
   /**
@@ -64,15 +65,17 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
    *
    * On the driver, if the value is required, it is read lazily from the block manager. We hold
    * a soft reference so that it can be garbage collected if required, as we can always reconstruct
-   * in the future.
+   * in the future. For internal broadcast variables where `serializedOnly = true`, we hold a
+   * WeakReference to allow the value to be reclaimed more aggressively.
    */
-  @transient private var _value: SoftReference[T] = _
+  @transient private var _value: Reference[T] = _
 
   /** The compression codec to use, or None if compression is disabled */
   @transient private var compressionCodec: Option[CompressionCodec] = _
   /** Size of each block. Default value is 4MB.  This value is only read by the broadcaster. */
   @transient private var blockSize: Int = _
-
+  /** Is the execution in local mode. */
+  @transient private var isLocalMaster: Boolean = _
 
   /** Whether to generate checksum for blocks or not. */
   private var checksumEnabled: Boolean = false
@@ -86,6 +89,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
     blockSize = conf.get(config.BROADCAST_BLOCKSIZE).toInt * 1024
     checksumEnabled = conf.get(config.BROADCAST_CHECKSUM)
+    isLocalMaster = Utils.isLocalMaster(conf)
   }
   setConf(SparkEnv.get.conf)
 
@@ -103,7 +107,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
       memoized
     } else {
       val newlyRead = readBroadcastBlock()
-      _value = new SoftReference[T](newlyRead)
+      _value = if (serializedOnly) {
+        new WeakReference[T](newlyRead)
+      } else {
+        new SoftReference[T](newlyRead)
+      }
       newlyRead
     }
   }
@@ -129,11 +137,23 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
    */
   private def writeBlocks(value: T): Int = {
     import StorageLevel._
-    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
-    // do not create a duplicate copy of the broadcast variable's value.
     val blockManager = SparkEnv.get.blockManager
-    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
-      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+    if (serializedOnly && !isLocalMaster) {
+      // SPARK-39983: When creating a broadcast variable internal to Spark (such as a broadcasted
+      // hashed relation), don't store the broadcasted value in the driver's block manager:
+      // we do not expect internal broadcast variables' values to be read on the driver, so
+      // skipping the store reduces driver memory pressure because we don't add a long-lived
+      // reference to the broadcasted object. However, this optimization cannot be applied for
+      // local mode (since tasks might run on the driver). To guard against performance
+      // regressions if an internal broadcast is accessed on the driver, we store a weak
+      // reference to the broadcasted value:
+      _value = new WeakReference[T](value)
+    } else {
+      // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+      // do not create a duplicate copy of the broadcast variable's value.
+      if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
+        throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+      }
     }
     try {
       val blocks =
@@ -258,11 +278,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
             try {
               val obj = TorrentBroadcast.unBlockifyObject[T](
                 blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
-              // Store the merged copy in BlockManager so other tasks on this executor don't
-              // need to re-fetch it.
-              val storageLevel = StorageLevel.MEMORY_AND_DISK
-              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
-                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+
+              if (!serializedOnly || isLocalMaster || Utils.isInRunningSparkTask) {
+                // Store the merged copy in BlockManager so other tasks on this executor don't
+                // need to re-fetch it.
+                val storageLevel = StorageLevel.MEMORY_AND_DISK
+                if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
+                  throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+                }
               }
 
               if (obj != null) {
@@ -297,6 +320,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     }
   }
 
+  // Is the unserialized value cached. Exposed for testing.
+  private[spark] def hasCachedValue: Boolean = {
+    TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) {
+      setConf(SparkEnv.get.conf)
+      val blockManager = SparkEnv.get.blockManager
+      blockManager.getLocalValues(broadcastId) match {
+        case Some(blockResult) if (blockResult.data.hasNext) =>
+          val x = blockResult.data.next().asInstanceOf[T]
+          releaseBlockManagerLock(broadcastId)
+          x != null
+        case _ => false
+      }
+    }
+  }
 }
 
 
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index 6846e1967c4..4ff39ba4074 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -30,8 +30,12 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
 
   override def initialize(isDriver: Boolean, conf: SparkConf): Unit = { }
 
-  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
-    new TorrentBroadcast[T](value_, id)
+  override def newBroadcast[T: ClassTag](
+      value_ : T,
+      isLocal: Boolean,
+      id: Long,
+      serializedOnly: Boolean = false): Broadcast[T] = {
+    new TorrentBroadcast[T](value_, id, serializedOnly)
   }
 
   override def stop(): Unit = { }
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 5e8b25f4251..41452076f88 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -187,6 +187,25 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
     assert(instances.size === 1)
   }
 
+  test("SPARK-39983 - Broadcasted value not cached on driver") {
+    // Use a distributed cluster, as in local mode the broadcast value is actually cached.
+    val conf = new SparkConf()
+      .setMaster("local-cluster[2,1,1024]")
+      .setAppName("test")
+    sc = new SparkContext(conf)
+
+    sc.broadcastInternal(value = 1234, serializedOnly = false) match {
+      case tb: TorrentBroadcast[Int] =>
+        assert(tb.hasCachedValue)
+        assert(1234 === tb.value)
+    }
+    sc.broadcastInternal(value = 1234, serializedOnly = true) match {
+      case tb: TorrentBroadcast[Int] =>
+        assert(!tb.hasCachedValue)
+        assert(1234 === tb.value)
+    }
+  }
+
   /**
    * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index accd0a064ea..548a8628ba4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -166,8 +166,8 @@ case class BroadcastExchangeExec(
             val beforeBroadcast = System.nanoTime()
             longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
 
-            // Broadcast the relation
-            val broadcasted = sparkContext.broadcast(relation)
+            // SPARK-39983 - Broadcast the relation without caching the unserialized object.
+            val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
             longMetric("broadcastTime") += NANOSECONDS.toMillis(
               System.nanoTime() - beforeBroadcast)
             val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
index 7d6306b65ff..129f76d7be3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql.execution
 
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 
-import org.apache.spark.SparkException
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.broadcast.TorrentBroadcast
 import org.apache.spark.scheduler._
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.HashedRelation
@@ -94,3 +96,28 @@ class BroadcastExchangeSuite extends SparkPlanTest
     }
   }
 }
+
+// Additional tests run in 'local-cluster' mode.
+class BroadcastExchangeExecSparkSuite
+  extends SparkFunSuite with LocalSparkContext with AdaptiveSparkPlanHelper {
+
+  test("SPARK-39983 - Broadcasted relation is not cached on the driver") {
+    // Use a distributed cluster, as in local mode the broadcast value is actually cached.
+    val conf = new SparkConf()
+      .setMaster("local-cluster[2,1,1024]")
+      .setAppName("test")
+    sc = new SparkContext(conf)
+    val spark = new SparkSession(sc)
+
+    val df = spark.range(1).toDF()
+    val joinDF = df.join(broadcast(df), "id")
+    val broadcastExchangeExec = collect(
+      joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p }
+    assert(broadcastExchangeExec.size == 1, "there should be exactly one BroadcastExchangeExec")
+
+    // The broadcasted relation should not be cached on the driver.
+    val broadcasted =
+      broadcastExchangeExec(0).relationFuture.get().asInstanceOf[TorrentBroadcast[Any]]
+    assert(!broadcasted.hasCachedValue)
+  }
+}

