Posted to commits@spark.apache.org by mr...@apache.org on 2021/10/05 17:06:25 UTC

[spark] branch master updated: [SPARK-36705][FOLLOW-UP] Support the case when user's classes need to register for Kryo serialization

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e5b01cd  [SPARK-36705][FOLLOW-UP] Support the case when user's classes need to register for Kryo serialization
e5b01cd is described below

commit e5b01cd823990d71c3ff32061c2998a076166ba8
Author: Minchu Yang <mi...@minyang-mn3.linkedin.biz>
AuthorDate: Tue Oct 5 12:05:43 2021 -0500

    [SPARK-36705][FOLLOW-UP] Support the case when user's classes need to register for Kryo serialization
    
    ### What changes were proposed in this pull request?
    
    - Make the val lazy wherever `isPushBasedShuffleEnabled` is assigned to a class instance variable, so that the check runs only after user-defined jars/classes in `spark.kryo.classesToRegister` have been downloaded and are available on the executor side (see the sketch after this list). This is part of the fix for the exception mentioned below.
    
    - Add a flag `checkSerializer` to `isPushBasedShuffleEnabled` to control whether the serializer's `supportsRelocationOfSerializedObjects` needs to be checked, also as part of the fix for the exception mentioned below. Specifically, the check is skipped in `registerWithExternalShuffleServer()` in `BlockManager` and in `createLocalDirsForMergedShuffleBlocks()` in `DiskBlockManager.scala`, as the same issue would otherwise arise there.
    
    - Move `instantiateClassFromConf` and `instantiateClass` from `SparkEnv` into `Utils`, so that `isPushBasedShuffleEnabled` can leverage them for instantiating serializer instances.
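    
    As a side note, a minimal self-contained sketch (hypothetical names, not Spark's actual classes) of why the lazy val matters: an eager field evaluates the push-based shuffle check while the enclosing object is being constructed, i.e. at executor startup before `--packages` jars are available, whereas a lazy field defers the check until first use, i.e. when the first task runs:
    
    ```scala
    object LazyInitSketch {
      // Stand-in for Utils.isPushBasedShuffleEnabled: instantiating the serializer
      // here can throw if user classes are not yet on the executor's classpath.
      def isPushBasedShuffleEnabled(): Boolean = {
        println("instantiating serializer / registering Kryo classes ...")
        true
      }
    
      class EagerTracker {
        // Evaluated as soon as the tracker is constructed (executor startup).
        private val fetchMergeResult = isPushBasedShuffleEnabled()
        def fetchMerged: Boolean = fetchMergeResult
      }
    
      class LazyTracker {
        // Evaluated only on first access, i.e. once the first task runs and
        // user-added jars have already been downloaded.
        private lazy val fetchMergeResult = isPushBasedShuffleEnabled()
        def fetchMerged: Boolean = fetchMergeResult
      }
    
      def main(args: Array[String]): Unit = {
        val tracker = new LazyTracker() // nothing evaluated yet
        println(tracker.fetchMerged)    // the check runs only now
      }
    }
    ```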
    
    ### Why are the changes needed?
    
    When a user sets classes for Kryo serialization via `spark.kryo.classesToRegister`, the exception below (or a similar one) is thrown from `isPushBasedShuffleEnabled`.
    The issue was reproduced in our internal branch by launching spark-shell as:
    ```
    spark-shell --spark-version 3.1.1 --packages ml.dmlc:xgboost4j_2.12:1.3.1 --conf spark.kryo.classesToRegister=ml.dmlc.xgboost4j.scala.Booster
    ```
    
    ```
    Exception in thread "main" java.lang.reflect.UndeclaredThrowableException
    	at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1911)
    	at org.apache.spark.deploy.SparkHadoopUtil.runAsSparkUser(SparkHadoopUtil.scala:61)
    	at org.apache.spark.executor.CoarseGrainedExecutorBackend$.run(CoarseGrainedExecutorBackend.scala:393)
    	at org.apache.spark.executor.YarnCoarseGrainedExecutorBackend$.main(YarnCoarseGrainedExecutorBackend.scala:83)
    	at org.apache.spark.executor.YarnCoarseGrainedExecutorBackend.main(YarnCoarseGrainedExecutorBackend.scala)
    Caused by: org.apache.spark.SparkException: Failed to register classes with Kryo
    	at org.apache.spark.serializer.KryoSerializer.$anonfun$newKryo$5(KryoSerializer.scala:183)
    	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    	at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:230)
    	at org.apache.spark.serializer.KryoSerializer.newKryo(KryoSerializer.scala:171)
    	at org.apache.spark.serializer.KryoSerializer$$anon$1.create(KryoSerializer.scala:102)
    	at com.esotericsoftware.kryo.pool.KryoPoolQueueImpl.borrow(KryoPoolQueueImpl.java:48)
    	at org.apache.spark.serializer.KryoSerializer$PoolWrapper.borrow(KryoSerializer.scala:109)
    	at org.apache.spark.serializer.KryoSerializerInstance.borrowKryo(KryoSerializer.scala:346)
    	at org.apache.spark.serializer.KryoSerializerInstance.getAutoReset(KryoSerializer.scala:446)
    	at org.apache.spark.serializer.KryoSerializer.supportsRelocationOfSerializedObjects$lzycompute(KryoSerializer.scala:253)
    	at org.apache.spark.serializer.KryoSerializer.supportsRelocationOfSerializedObjects(KryoSerializer.scala:249)
    	at org.apache.spark.util.Utils$.isPushBasedShuffleEnabled(Utils.scala:2584)
    	at org.apache.spark.MapOutputTrackerWorker.<init>(MapOutputTracker.scala:1109)
    	at org.apache.spark.SparkEnv$.create(SparkEnv.scala:322)
    	at org.apache.spark.SparkEnv$.createExecutorEnv(SparkEnv.scala:205)
    	at org.apache.spark.executor.CoarseGrainedExecutorBackend$.$anonfun$run$7(CoarseGrainedExecutorBackend.scala:442)
    	at org.apache.spark.deploy.SparkHadoopUtil$$anon$1.run(SparkHadoopUtil.scala:62)
    	at org.apache.spark.deploy.SparkHadoopUtil$$anon$1.run(SparkHadoopUtil.scala:61)
    	at java.security.AccessController.doPrivileged(Native Method)
    	at javax.security.auth.Subject.doAs(Subject.java:422)
    	at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1893)
    	... 4 more
    Caused by: java.lang.ClassNotFoundException: ml.dmlc.xgboost4j.scala.Booster
    	at java.net.URLClassLoader.findClass(URLClassLoader.java:381)
    	at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
    	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:349)
    	at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
    	at java.lang.Class.forName0(Native Method)
    	at java.lang.Class.forName(Class.java:348)
    	at org.apache.spark.util.Utils$.classForName(Utils.scala:217)
    	at org.apache.spark.serializer.KryoSerializer.$anonfun$newKryo$6(KryoSerializer.scala:174)
    	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    	at org.apache.spark.serializer.KryoSerializer.$anonfun$newKryo$5(KryoSerializer.scala:173)
    	... 24 more
    ```
    Registering user classes for Kryo serialization happens after serializer creation in SparkEnv, but serializer creation can also happen inside `isPushBasedShuffleEnabled`, which is called in some places before SparkEnv is created. Also, per analysis by JoshRosen, this is probably because Kryo instantiation was failing: the added packages had not been downloaded to the executor yet (this code runs during executor startup, not task startup). The proposed change helps fix th [...]
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Passed existing tests.
    Tested this patch in our internal branch where the user reported the issue; the issue is no longer reproducible with this patch.
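    
    As a rough sanity check (illustrative only, not part of the committed tests), a spark-shell session along the lines of the repro above should now start its executors cleanly and run a simple shuffle; `sc` is the shell's predefined SparkContext and the package/class names are the ones from the report:
    
    ```scala
    // Launched roughly like the repro above (flags and versions are illustrative):
    //   spark-shell --packages ml.dmlc:xgboost4j_2.12:1.3.1 \
    //     --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
    //     --conf spark.kryo.classesToRegister=ml.dmlc.xgboost4j.scala.Booster \
    //     --conf spark.shuffle.push.enabled=true
    // With the patch, executor startup no longer fails with ClassNotFoundException,
    // and a basic shuffle job completes:
    sc.parallelize(1 to 1000, 10)
      .map(i => (i % 10, i.toLong))
      .reduceByKey(_ + _)
      .count()
    ```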
    
    Closes #34158 from rmcyang/SPARK-33781-bugFix.
    
    Lead-authored-by: Minchu Yang <mi...@minyang-mn3.linkedin.biz>
    Co-authored-by: Minchu Yang <31...@users.noreply.github.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../main/scala/org/apache/spark/Dependency.scala   |  4 +-
 .../scala/org/apache/spark/MapOutputTracker.scala  |  8 ++-
 .../src/main/scala/org/apache/spark/SparkEnv.scala | 33 ++---------
 .../org/apache/spark/scheduler/DAGScheduler.scala  |  2 +-
 .../apache/spark/shuffle/ShuffleBlockPusher.scala  |  5 +-
 .../org/apache/spark/storage/BlockManager.scala    |  9 +--
 .../spark/storage/BlockManagerMasterEndpoint.scala |  5 +-
 .../apache/spark/storage/DiskBlockManager.scala    |  7 ++-
 .../spark/storage/PushBasedFetchHelper.scala       |  2 +-
 .../main/scala/org/apache/spark/util/Utils.scala   | 65 ++++++++++++++++++----
 .../org/apache/spark/MapOutputTrackerSuite.scala   |  2 +
 .../apache/spark/scheduler/DAGSchedulerSuite.scala |  4 ++
 .../shuffle/HostLocalShuffleReadingSuite.scala     |  1 +
 .../storage/BlockManagerReplicationSuite.scala     |  2 +-
 .../apache/spark/storage/BlockManagerSuite.scala   |  3 +-
 .../spark/storage/DiskBlockManagerSuite.scala      |  8 +--
 .../org/apache/spark/storage/DiskStoreSuite.scala  |  8 +--
 .../spark/storage/FallbackStorageSuite.scala       |  4 +-
 .../scala/org/apache/spark/util/UtilsSuite.scala   | 17 +++---
 .../streaming/ReceivedBlockHandlerSuite.scala      |  2 +-
 20 files changed, 117 insertions(+), 74 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 81e4c8f..1b4e7ba 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -163,7 +163,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   }
 
   private def canShuffleMergeBeEnabled(): Boolean = {
-    val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf)
+    val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf,
+      // invoked at driver
+      isDriver = true)
     if (isPushShuffleEnabled && rdd.isBarrier()) {
       logWarning("Push-based shuffle is currently not supported for barrier stages")
     }
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 24954e7..ca1229a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -617,7 +617,7 @@ private[spark] class MapOutputTrackerMaster(
   private val mapOutputTrackerMasterMessages =
     new LinkedBlockingQueue[MapOutputTrackerMasterMessage]
 
-  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver = true)
 
   // Thread pool used for handling map output status requests. This is a separate thread pool
   // to ensure we don't block the normal dispatcher threads.
@@ -1126,7 +1126,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   val mergeStatuses: Map[Int, Array[MergeStatus]] =
     new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala
 
-  private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf)
+  // This must be lazy to ensure that it is initialized when the first task is run and not at
+  // executor startup time. At startup time, user-added libraries may not have been
+  // downloaded to the executor, causing `isPushBasedShuffleEnabled` to fail when it tries to
+  // instantiate a serializer. See the followup to SPARK-36705 for more details.
+  private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, isDriver = false)
 
   /**
    * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ee50a8f..0388c7b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -272,33 +272,7 @@ object SparkEnv extends Logging {
       conf.set(DRIVER_PORT, rpcEnv.address.port)
     }
 
-    // Create an instance of the class with the given name, possibly initializing it with our conf
-    def instantiateClass[T](className: String): T = {
-      val cls = Utils.classForName(className)
-      // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
-      // SparkConf, then one taking no arguments
-      try {
-        cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
-          .newInstance(conf, java.lang.Boolean.valueOf(isDriver))
-          .asInstanceOf[T]
-      } catch {
-        case _: NoSuchMethodException =>
-          try {
-            cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
-          } catch {
-            case _: NoSuchMethodException =>
-              cls.getConstructor().newInstance().asInstanceOf[T]
-          }
-      }
-    }
-
-    // Create an instance of the class named by the given SparkConf property
-    // if the property is not set, possibly initializing it with our conf
-    def instantiateClassFromConf[T](propertyName: ConfigEntry[String]): T = {
-      instantiateClass[T](conf.get(propertyName))
-    }
-
-    val serializer = instantiateClassFromConf[Serializer](SERIALIZER)
+    val serializer = Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver)
     logDebug(s"Using serializer: ${serializer.getClass}")
 
     val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
@@ -337,7 +311,8 @@ object SparkEnv extends Logging {
     val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER)
     val shuffleMgrClass =
       shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName)
-    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
+    val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager](
+      shuffleMgrClass, conf, isDriver)
 
     val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)
 
@@ -370,7 +345,7 @@ object SparkEnv extends Logging {
           } else {
             None
           }, blockManagerInfo,
-          mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
+          mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], isDriver)),
       registerOrLookupEndpoint(
         BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
         new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a3df49a..442edc7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -254,7 +254,7 @@ private[spark] class DAGScheduler(
   private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
-  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true)
 
   private val blockManagerMasterDriverHeartbeatTimeout =
     sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index bb260f8..50f9c8c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.ExecutorService
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 
-import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -463,7 +463,8 @@ private[spark] object ShuffleBlockPusher {
 
   private val BLOCK_PUSHER_POOL: ExecutorService = {
     val conf = SparkEnv.get.conf
-    if (Utils.isPushBasedShuffleEnabled(conf)) {
+    if (Utils.isPushBasedShuffleEnabled(conf,
+        isDriver = SparkContext.DRIVER_IDENTIFIER == SparkEnv.get.executorId)) {
       val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
         .getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
       ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index cbb4e9c..9ebf26b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -185,6 +185,7 @@ private[spark] class BlockManager(
 
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
   private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
+  private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
 
   private val remoteReadNioBufferConversion =
     conf.get(Network.NETWORK_REMOTE_READ_NIO_BUFFER_CONVERSION)
@@ -194,8 +195,8 @@ private[spark] class BlockManager(
   val diskBlockManager = {
     // Only perform cleanup if an external service is not serving our shuffle files.
     val deleteFilesOnStop =
-      !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER
-    new DiskBlockManager(conf, deleteFilesOnStop)
+      !externalShuffleServiceEnabled || isDriver
+    new DiskBlockManager(conf, deleteFilesOnStop = deleteFilesOnStop, isDriver = isDriver)
   }
 
   // Visible for testing
@@ -535,7 +536,7 @@ private[spark] class BlockManager(
     hostLocalDirManager = {
       if ((conf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) &&
           !conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) ||
-          Utils.isPushBasedShuffleEnabled(conf)) {
+          Utils.isPushBasedShuffleEnabled(conf, isDriver)) {
         Some(new HostLocalDirManager(
           futureExecutionContext,
           conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE),
@@ -561,7 +562,7 @@ private[spark] class BlockManager(
   private def registerWithExternalShuffleServer(): Unit = {
     logInfo("Registering executor with local external shuffle service.")
     val shuffleManagerMeta =
-      if (Utils.isPushBasedShuffleEnabled(conf)) {
+      if (Utils.isPushBasedShuffleEnabled(conf, isDriver = isDriver, checkSerializer = false)) {
         s"${shuffleManager.getClass.getName}:" +
           s"${diskBlockManager.getMergeDirectoryAndAttemptIDJsonString()}}}"
       } else {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 6f043da..b96befc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -51,7 +51,8 @@ class BlockManagerMasterEndpoint(
     listenerBus: LiveListenerBus,
     externalBlockStoreClient: Option[ExternalBlockStoreClient],
     blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo],
-    mapOutputTracker: MapOutputTrackerMaster)
+    mapOutputTracker: MapOutputTrackerMaster,
+    isDriver: Boolean)
   extends IsolatedRpcEndpoint with Logging {
 
   // Mapping from executor id to the block manager's local disk directories.
@@ -100,7 +101,7 @@ class BlockManagerMasterEndpoint(
 
   val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf)
 
-  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver)
 
   logInfo("BlockManagerMasterEndpoint up")
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index ee11e0e..bebe32b 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -45,7 +45,10 @@ import org.apache.spark.util.{ShutdownHookManager, Utils}
  *
  * ShuffleDataIO also can change the behavior of deleteFilesOnStop.
  */
-private[spark] class DiskBlockManager(conf: SparkConf, var deleteFilesOnStop: Boolean)
+private[spark] class DiskBlockManager(
+    conf: SparkConf,
+    var deleteFilesOnStop: Boolean,
+    isDriver: Boolean)
   extends Logging {
 
   private[spark] val subDirsPerLocalDir = conf.get(config.DISKSTORE_SUB_DIRECTORIES)
@@ -208,7 +211,7 @@ private[spark] class DiskBlockManager(conf: SparkConf, var deleteFilesOnStop: Bo
    * permission to create directories under application local directories.
    */
   private def createLocalDirsForMergedShuffleBlocks(): Unit = {
-    if (Utils.isPushBasedShuffleEnabled(conf)) {
+    if (Utils.isPushBasedShuffleEnabled(conf, isDriver = isDriver, checkSerializer = false)) {
       // Will create the merge_manager directory only if it doesn't exist under the local dir.
       Utils.getConfiguredLocalDirs(conf).foreach { rootDir =>
         try {
diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
index 99138b6..d83d901 100644
--- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -142,7 +142,7 @@ private class PushBasedFetchHelper(
     val mergedBlocksMetaListener = new MergedBlocksMetaListener {
       override def onSuccess(shuffleId: Int, shuffleMergeId: Int, reduceId: Int,
           meta: MergedBlockMeta): Unit = {
-        logInfo(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," +
+        logDebug(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," +
           s" $reduceId) from ${req.address.host}:${req.address.port}")
         try {
           iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, shuffleMergeId,
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f3fc90d..0029bbd 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2603,18 +2603,31 @@ private[spark] object Utils extends Logging {
    *   - IO encryption disabled
    *   - serializer(such as KryoSerializer) supports relocation of serialized objects
    */
-  def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = {
+  def isPushBasedShuffleEnabled(conf: SparkConf,
+      isDriver: Boolean,
+      checkSerializer: Boolean = true): Boolean = {
     val pushBasedShuffleEnabled = conf.get(PUSH_BASED_SHUFFLE_ENABLED)
     if (pushBasedShuffleEnabled) {
-      val serializer = Utils.classForName(conf.get(SERIALIZER)).getConstructor(classOf[SparkConf])
-        .newInstance(conf).asInstanceOf[Serializer]
-      val canDoPushBasedShuffle = conf.get(IS_TESTING).getOrElse(false) ||
-        (conf.get(SHUFFLE_SERVICE_ENABLED) &&
-          conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn" &&
-          // TODO: [SPARK-36744] needs to support IO encryption for push-based shuffle
-          !conf.get(IO_ENCRYPTION_ENABLED) &&
-          serializer.supportsRelocationOfSerializedObjects)
-
+      val canDoPushBasedShuffle = {
+        val isTesting = conf.get(IS_TESTING).getOrElse(false)
+        val isShuffleServiceAndYarn = conf.get(SHUFFLE_SERVICE_ENABLED) &&
+            conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn"
+        lazy val serializerIsSupported = {
+          if (checkSerializer) {
+            Option(SparkEnv.get)
+              .map(_.serializer)
+              .filter(_ != null)
+              .getOrElse(instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver))
+              .supportsRelocationOfSerializedObjects
+          } else {
+            // if no need to check Serializer, always set serializerIsSupported as true
+            true
+          }
+        }
+        // TODO: [SPARK-36744] needs to support IO encryption for push-based shuffle
+        val ioEncryptionDisabled = !conf.get(IO_ENCRYPTION_ENABLED)
+        (isShuffleServiceAndYarn || isTesting) && ioEncryptionDisabled && serializerIsSupported
+      }
       if (!canDoPushBasedShuffle) {
         logWarning("Push-based shuffle can only be enabled when the application is submitted " +
           "to run in YARN mode, with external shuffle service enabled, IO encryption disabled, " +
@@ -2627,6 +2640,38 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  // Create an instance of Serializer or ShuffleManager with the given name,
+  // possibly initializing it with our conf
+  def instantiateSerializerOrShuffleManager[T](className: String,
+      conf: SparkConf,
+      isDriver: Boolean): T = {
+    val cls = Utils.classForName(className)
+    // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
+    // SparkConf, then one taking no arguments
+    try {
+      cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
+        .newInstance(conf, java.lang.Boolean.valueOf(isDriver))
+        .asInstanceOf[T]
+    } catch {
+      case _: NoSuchMethodException =>
+        try {
+          cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+        } catch {
+          case _: NoSuchMethodException =>
+            cls.getConstructor().newInstance().asInstanceOf[T]
+        }
+    }
+  }
+
+  // Create an instance of Serializer named by the given SparkConf property
+  // if the property is not set, possibly initializing it with our conf
+  def instantiateSerializerFromConf[T](propertyName: ConfigEntry[String],
+      conf: SparkConf,
+      isDriver: Boolean): T = {
+    instantiateSerializerOrShuffleManager[T](
+      conf.get(propertyName), conf, isDriver)
+  }
+
   /**
    * Return whether dynamic allocation is enabled in the given conf.
    */
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index e81196f..4051118 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -337,6 +337,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
   test("SPARK-32921: master register and unregister merge result") {
     conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
     conf.set(IS_TESTING, true)
+    conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
     val rpcEnv = createRpcEnv("test")
     val tracker = newTrackerMaster()
     tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
@@ -596,6 +597,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize
     newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
     newConf.set(IS_TESTING, true)
+    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
 
     // needs TorrentBroadcast so need a SparkContext
     withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index deddaea..4cb64ed 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3431,6 +3431,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     conf.set("spark.master", "pushbasedshuffleclustermanager")
     // Needed to run push-based shuffle tests in ad-hoc manner through IDE
     conf.set(Tests.IS_TESTING, true)
+    // [SPARK-36705] Push-based shuffle does not work with Spark's default
+    // JavaSerializer and will be disabled with it, as it does not support
+    // object relocation
+    conf.set(config.SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
   }
 
   test("SPARK-32920: shuffle merge finalization") {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
index 33f544a..4e74036 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
@@ -139,6 +139,7 @@ class HostLocalShuffleReadingSuite extends SparkFunSuite with Matchers with Loca
       .set(SHUFFLE_SERVICE_ENABLED, true)
       .set("spark.yarn.maxAttempts", "1")
       .set(PUSH_BASED_SHUFFLE_ENABLED, true)
+      .set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
     sc = new SparkContext("local-cluster[2, 1, 1024]", "test-host-local-shuffle-reading", conf)
     sc.env.blockManager.hostLocalDirManager.isDefined should equal(true)
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 495747b..fc7b7a4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -102,7 +102,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
     val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]()
     master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker)),
+        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker, isDriver = true)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)
     allStores.clear()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 173b839..2cb281d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -98,6 +98,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       .set(IS_TESTING, true)
       .set(MEMORY_FRACTION, 1.0)
       .set(MEMORY_STORAGE_FRACTION, 0.999)
+      .set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
       .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
       .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L)
       .set(Network.RPC_ASK_TIMEOUT, "5s")
@@ -185,7 +186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     liveListenerBus = spy(new LiveListenerBus(conf))
     master = spy(new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        liveListenerBus, None, blockManagerInfo, mapOutputTracker)),
+        liveListenerBus, None, blockManagerInfo, mapOutputTracker, isDriver = true)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true))
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 0443c40..b36eeb7 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -60,7 +60,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
     super.beforeEach()
     val conf = testConf.clone
     conf.set("spark.local.dir", rootDirs)
-    diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
   }
 
   override def afterEach(): Unit = {
@@ -105,7 +105,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
     testConf.set("spark.local.dir", rootDirs)
     testConf.set("spark.shuffle.push.enabled", "true")
     testConf.set(config.Tests.IS_TESTING, true)
-    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true, isDriver = false)
     assert(Utils.getConfiguredLocalDirs(testConf).map(
       rootDir => new File(rootDir, DiskBlockManager.MERGE_DIRECTORY))
       .filter(mergeDir => mergeDir.exists()).length === 2)
@@ -118,7 +118,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
   test("Test dir creation with permission 770") {
     val testDir = new File("target/testDir");
     FileUtils.deleteQuietly(testDir)
-    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true, isDriver = false)
     diskBlockManager.createDirWithPermission770(testDir)
     assert(testDir.exists && testDir.isDirectory)
     val permission = PosixFilePermissions.toString(
@@ -129,7 +129,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
 
   test("Encode merged directory name and attemptId in shuffleManager field") {
     testConf.set(config.APP_ATTEMPT_ID, "1");
-    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true, isDriver = false)
     val mergedShuffleMeta = diskBlockManager.getMergeDirectoryAndAttemptIDJsonString();
     val mapper: ObjectMapper = new ObjectMapper
     val typeRef: TypeReference[HashMap[String, String]] =
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
index 97b9c97..be1b9be 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -46,7 +46,7 @@ class DiskStoreSuite extends SparkFunSuite {
     val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes))
 
     val blockId = BlockId("rdd_1_2")
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
 
     val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager,
       securityManager)
@@ -77,7 +77,7 @@ class DiskStoreSuite extends SparkFunSuite {
 
   test("block size tracking") {
     val conf = new SparkConf()
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
     val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))
 
     val blockId = BlockId("rdd_1_2")
@@ -96,7 +96,7 @@ class DiskStoreSuite extends SparkFunSuite {
   test("blocks larger than 2gb") {
     val conf = new SparkConf()
       .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k")
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
     val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))
 
     val blockId = BlockId("rdd_1_2")
@@ -137,7 +137,7 @@ class DiskStoreSuite extends SparkFunSuite {
 
     val conf = new SparkConf()
     val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf)))
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
     val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
 
     val blockId = BlockId("rdd_1_2")
diff --git a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala
index f58d8ce..88197b6 100644
--- a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala
@@ -68,7 +68,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
     val bmm = new BlockManagerMaster(new NoopRpcEndpointRef(conf), null, conf, false)
 
     val bm = mock(classOf[BlockManager])
-    val dbm = new DiskBlockManager(conf, false)
+    val dbm = new DiskBlockManager(conf, deleteFilesOnStop = false, isDriver = false)
     when(bm.diskBlockManager).thenReturn(dbm)
     when(bm.master).thenReturn(bmm)
     val resolver = new IndexShuffleBlockResolver(conf, bm)
@@ -134,7 +134,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
 
     val ids = Set((1, 1L, 1))
     val bm = mock(classOf[BlockManager])
-    val dbm = new DiskBlockManager(conf, false)
+    val dbm = new DiskBlockManager(conf, deleteFilesOnStop = false, isDriver = false)
     when(bm.diskBlockManager).thenReturn(dbm)
     val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf, bm)
     val indexFile = indexShuffleBlockResolver.getIndexFile(1, 1L)
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index f8607f1..05b24ec 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -1503,23 +1503,26 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
   test("isPushBasedShuffleEnabled when PUSH_BASED_SHUFFLE_ENABLED " +
     "and SHUFFLE_SERVICE_ENABLED are both set to true in YARN mode with maxAttempts set to 1") {
     val conf = new SparkConf()
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false)
     conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
     conf.set(IS_TESTING, false)
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(
+      conf, isDriver = false, checkSerializer = false) === false)
     conf.set(SHUFFLE_SERVICE_ENABLED, true)
     conf.set(SparkLauncher.SPARK_MASTER, "yarn")
     conf.set("spark.yarn.maxAppAttempts", "1")
     conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
-    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === true)
     conf.set("spark.yarn.maxAppAttempts", "2")
-    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    assert(Utils.isPushBasedShuffleEnabled(
+      conf, isDriver = false, checkSerializer = false) === true)
     conf.set(IO_ENCRYPTION_ENABLED, true)
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false)
     conf.set(IO_ENCRYPTION_ENABLED, false)
-    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    assert(Utils.isPushBasedShuffleEnabled(
+      conf, isDriver = false, checkSerializer = false) === true)
     conf.set(SERIALIZER, "org.apache.spark.serializer.JavaSerializer")
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false)
   }
 }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 425e39c..3bcea1a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -93,7 +93,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
     val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]()
     blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker)),
+        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker, isDriver = true)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)
 
