Posted to commits@spark.apache.org by tg...@apache.org on 2020/01/31 14:25:59 UTC

[spark] branch master updated: [SPARK-30638][CORE] Add resources allocated to PluginContext

This is an automated email from the ASF dual-hosted git repository.

tgraves pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d2b8d8  [SPARK-30638][CORE] Add resources allocated to PluginContext
3d2b8d8 is described below

commit 3d2b8d8b13eff0faa02316542a343e7a64873b8a
Author: Thomas Graves <tg...@nvidia.com>
AuthorDate: Fri Jan 31 08:25:32 2020 -0600

    [SPARK-30638][CORE] Add resources allocated to PluginContext
    
    ### What changes were proposed in this pull request?
    
    Add the allocated resources as a parameter to the PluginContext so that any plugin running in the driver or an executor can use this information to initialize devices or otherwise make use of it.
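    
    For plugin authors, a minimal sketch of reading the new `resources()` value (the plugin class and log output are hypothetical; only the PluginContext and ResourceInformation calls come from this change):
    
    ```scala
    import java.util.{Map => JMap}
    
    import scala.collection.JavaConverters._
    
    import org.apache.spark.SparkContext
    import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
    
    class GpuInitPlugin extends SparkPlugin {
      override def driverPlugin(): DriverPlugin = new DriverPlugin {
        override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
          // The driver-side component now also sees its allocated resources.
          ctx.resources().asScala.foreach { case (name, info) =>
            println(s"driver resource $name -> ${info.addresses.mkString(",")}")
          }
          Map.empty[String, String].asJava
        }
      }
    
      override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {
        override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
          // Runs once per executor, before any task: a natural place to set up devices.
          Option(ctx.resources().get("gpu")).foreach { gpu =>
            println(s"executor ${ctx.executorID()} owns GPU addresses ${gpu.addresses.mkString(",")}")
          }
        }
      }
    }
    ```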
    
    ### Why are the changes needed?
    
    To allow users to initialize and track devices once at the executor level, before any task that uses them runs.
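    
    For example, with the standard resource scheduling configs (key names assumed to be the usual spark.*.resource.gpu.* settings; the discovery script path is a placeholder), an executor plugin's init() runs once per executor with the GPU addresses available through ctx.resources(), before any task is scheduled:
    
    ```scala
    import org.apache.spark.SparkConf
    
    // Sketch: enable the plugin and request one GPU per executor.
    val conf = new SparkConf()
      .set("spark.plugins", "com.example.GpuInitPlugin")
      .set("spark.executor.resource.gpu.amount", "1")
      .set("spark.executor.resource.gpu.discoveryScript", "/opt/spark/scripts/getGpus.sh")
    ```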
    
    ### Does this PR introduce any user-facing change?
    
    Yes, for users of the Executor/Driver plugin interface.
    
    ### How was this patch tested?
    
    Unit tests, and manually by writing a plugin that initialized GPUs using this interface.
    
    Closes #27367 from tgravescs/pluginWithResources.
    
    Lead-authored-by: Thomas Graves <tg...@nvidia.com>
    Co-authored-by: Thomas Graves <tg...@apache.org>
    Signed-off-by: Thomas Graves <tg...@apache.org>
---
 .../org/apache/spark/api/plugin/PluginContext.java |   5 +
 .../main/scala/org/apache/spark/SparkContext.scala |   2 +-
 .../executor/CoarseGrainedExecutorBackend.scala    |  10 +-
 .../scala/org/apache/spark/executor/Executor.scala |   7 +-
 .../spark/internal/plugin/PluginContainer.scala    |  36 +++++--
 .../spark/internal/plugin/PluginContextImpl.scala  |   6 +-
 .../scheduler/local/LocalSchedulerBackend.scala    |   5 +-
 .../org/apache/spark/executor/ExecutorSuite.scala  |  12 ++-
 .../internal/plugin/PluginContainerSuite.scala     | 109 +++++++++++++++++++--
 .../spark/executor/MesosExecutorBackend.scala      |   4 +-
 10 files changed, 167 insertions(+), 29 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java b/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java
index b9413cf..36d8275 100644
--- a/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java
+++ b/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java
@@ -18,11 +18,13 @@
 package org.apache.spark.api.plugin;
 
 import java.io.IOException;
+import java.util.Map;
 
 import com.codahale.metrics.MetricRegistry;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.resource.ResourceInformation;
 
 /**
  * :: DeveloperApi ::
@@ -54,6 +56,9 @@ public interface PluginContext {
   /** The host name which is being used by the Spark process for communication. */
   String hostname();
 
+  /** The custom resources (GPUs, FPGAs, etc) allocated to driver or executor. */
+  Map<String, ResourceInformation> resources();
+
   /**
    * Send a message to the plugin's driver-side component.
    * <p>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3262631..6e0c7ac 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -542,7 +542,7 @@ class SparkContext(config: SparkConf) extends Logging {
       HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this))
 
     // Initialize any plugins before the task scheduler is initialized.
-    _plugins = PluginContainer(this)
+    _plugins = PluginContainer(this, _resources.asJava)
 
     // Create and start the scheduler
     val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 511c63a..ce211ce 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -69,6 +69,8 @@ private[spark] class CoarseGrainedExecutorBackend(
   // to be changed so that we don't share the serializer instance across threads
   private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()
 
+  private var _resources = Map.empty[String, ResourceInformation]
+
   /**
    * Map each taskId to the information about the resource allocated to it, Please refer to
    * [[ResourceInformation]] for specifics.
@@ -78,9 +80,8 @@ private[spark] class CoarseGrainedExecutorBackend(
 
   override def onStart(): Unit = {
     logInfo("Connecting to driver: " + driverUrl)
-    var resources = Map.empty[String, ResourceInformation]
     try {
-      resources = parseOrFindResources(resourcesFileOpt)
+      _resources = parseOrFindResources(resourcesFileOpt)
     } catch {
       case NonFatal(e) =>
         exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
@@ -89,7 +90,7 @@ private[spark] class CoarseGrainedExecutorBackend(
       // This is a very fast action so we can use "ThreadUtils.sameThread"
       driver = Some(ref)
       ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls,
-        extractAttributes, resources, resourceProfile.id))
+        extractAttributes, _resources, resourceProfile.id))
     }(ThreadUtils.sameThread).onComplete {
       case Success(_) =>
         self.send(RegisteredExecutor)
@@ -125,7 +126,8 @@ private[spark] class CoarseGrainedExecutorBackend(
     case RegisteredExecutor =>
       logInfo("Successfully registered with driver")
       try {
-        executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
+        executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false,
+          resources = _resources)
         driver.get.send(LaunchedExecutor(executorId))
       } catch {
         case NonFatal(e) =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 0ea16d0..8aeb16f 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -28,6 +28,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
+import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, HashMap, Map, WrappedArray}
 import scala.concurrent.duration._
 import scala.util.control.NonFatal
@@ -41,6 +42,7 @@ import org.apache.spark.internal.config._
 import org.apache.spark.internal.plugin.PluginContainer
 import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
 import org.apache.spark.metrics.source.JVMCPUSource
+import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.scheduler._
 import org.apache.spark.shuffle.FetchFailedException
@@ -61,7 +63,8 @@ private[spark] class Executor(
     env: SparkEnv,
     userClassPath: Seq[URL] = Nil,
     isLocal: Boolean = false,
-    uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler)
+    uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler,
+    resources: immutable.Map[String, ResourceInformation])
   extends Logging {
 
   logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -152,7 +155,7 @@ private[spark] class Executor(
 
   // Plugins need to load using a class loader that includes the executor's user classpath
   private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) {
-    PluginContainer(env)
+    PluginContainer(env, resources.asJava)
   }
 
   // Max size of direct result. If task result is bigger than this, we use the block manager
diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
index fc7a9d8..4eda476 100644
--- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
+++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
+import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.util.Utils
 
 sealed abstract class PluginContainer {
@@ -33,7 +34,10 @@ sealed abstract class PluginContainer {
 
 }
 
-private class DriverPluginContainer(sc: SparkContext, plugins: Seq[SparkPlugin])
+private class DriverPluginContainer(
+    sc: SparkContext,
+    resources: java.util.Map[String, ResourceInformation],
+    plugins: Seq[SparkPlugin])
   extends PluginContainer with Logging {
 
   private val driverPlugins: Seq[(String, DriverPlugin, PluginContextImpl)] = plugins.flatMap { p =>
@@ -41,7 +45,7 @@ private class DriverPluginContainer(sc: SparkContext, plugins: Seq[SparkPlugin])
     if (driverPlugin != null) {
       val name = p.getClass().getName()
       val ctx = new PluginContextImpl(name, sc.env.rpcEnv, sc.env.metricsSystem, sc.conf,
-        sc.env.executorId)
+        sc.env.executorId, resources)
 
       val extraConf = driverPlugin.init(sc, ctx)
       if (extraConf != null) {
@@ -83,7 +87,10 @@ private class DriverPluginContainer(sc: SparkContext, plugins: Seq[SparkPlugin])
 
 }
 
-private class ExecutorPluginContainer(env: SparkEnv, plugins: Seq[SparkPlugin])
+private class ExecutorPluginContainer(
+    env: SparkEnv,
+    resources: java.util.Map[String, ResourceInformation],
+    plugins: Seq[SparkPlugin])
   extends PluginContainer with Logging {
 
   private val executorPlugins: Seq[(String, ExecutorPlugin)] = {
@@ -100,7 +107,7 @@ private class ExecutorPluginContainer(env: SparkEnv, plugins: Seq[SparkPlugin])
           .toMap
           .asJava
         val ctx = new PluginContextImpl(name, env.rpcEnv, env.metricsSystem, env.conf,
-          env.executorId)
+          env.executorId, resources)
         executorPlugin.init(ctx, extraConf)
         ctx.registerMetrics()
 
@@ -133,17 +140,28 @@ object PluginContainer {
 
   val EXTRA_CONF_PREFIX = "spark.plugins.internal.conf."
 
-  def apply(sc: SparkContext): Option[PluginContainer] = PluginContainer(Left(sc))
+  def apply(
+      sc: SparkContext,
+      resources: java.util.Map[String, ResourceInformation]): Option[PluginContainer] = {
+    PluginContainer(Left(sc), resources)
+  }
+
+  def apply(
+      env: SparkEnv,
+      resources: java.util.Map[String, ResourceInformation]): Option[PluginContainer] = {
+    PluginContainer(Right(env), resources)
+  }
 
-  def apply(env: SparkEnv): Option[PluginContainer] = PluginContainer(Right(env))
 
-  private def apply(ctx: Either[SparkContext, SparkEnv]): Option[PluginContainer] = {
+  private def apply(
+      ctx: Either[SparkContext, SparkEnv],
+      resources: java.util.Map[String, ResourceInformation]): Option[PluginContainer] = {
     val conf = ctx.fold(_.conf, _.conf)
     val plugins = Utils.loadExtensions(classOf[SparkPlugin], conf.get(PLUGINS).distinct, conf)
     if (plugins.nonEmpty) {
       ctx match {
-        case Left(sc) => Some(new DriverPluginContainer(sc, plugins))
-        case Right(env) => Some(new ExecutorPluginContainer(env, plugins))
+        case Left(sc) => Some(new DriverPluginContainer(sc, resources, plugins))
+        case Right(env) => Some(new ExecutorPluginContainer(env, resources, plugins))
       }
     } else {
       None
diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala
index 279f3d3..ca91194 100644
--- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.internal.plugin
 
+import java.util
+
 import com.codahale.metrics.MetricRegistry
 
 import org.apache.spark.{SparkConf, SparkException}
@@ -24,6 +26,7 @@ import org.apache.spark.api.plugin.PluginContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.util.RpcUtils
 
@@ -32,7 +35,8 @@ private class PluginContextImpl(
     rpcEnv: RpcEnv,
     metricsSystem: MetricsSystem,
     override val conf: SparkConf,
-    override val executorID: String)
+    override val executorID: String,
+    override val resources: util.Map[String, ResourceInformation])
   extends PluginContext with Logging {
 
   override def hostname(): String = rpcEnv.address.hostPort.split(":")(0)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
index d2c0dc8..42a5afe 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
@@ -26,6 +26,7 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.{Executor, ExecutorBackend}
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
+import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.ExecutorInfo
@@ -57,8 +58,10 @@ private[spark] class LocalEndpoint(
   val localExecutorId = SparkContext.DRIVER_IDENTIFIER
   val localExecutorHostname = Utils.localCanonicalHostName()
 
+  // local mode doesn't support extra resources like GPUs right now
   private val executor = new Executor(
-    localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
+    localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true,
+    resources = Map.empty[String, ResourceInformation])
 
   override def receive: PartialFunction[Any, Unit] = {
     case ReviveOffers =>
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 7272a98..31049d1 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -117,7 +117,8 @@ class ExecutorSuite extends SparkFunSuite
 
     var executor: Executor = null
     try {
-      executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
+      executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true,
+        resources = immutable.Map.empty[String, ResourceInformation])
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockExecutorBackend, taskDescription)
 
@@ -254,7 +255,8 @@ class ExecutorSuite extends SparkFunSuite
     val serializer = new JavaSerializer(conf)
     val env = createMockEnv(conf, serializer)
     val executor =
-      new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
+        resources = immutable.Map.empty[String, ResourceInformation])
     val executorClass = classOf[Executor]
 
     // Save all heartbeats sent into an ArrayBuffer for verification
@@ -353,7 +355,8 @@ class ExecutorSuite extends SparkFunSuite
     val mockBackend = mock[ExecutorBackend]
     var executor: Executor = null
     try {
-      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
+        resources = immutable.Map.empty[String, ResourceInformation])
       executor.launchTask(mockBackend, taskDescription)
 
       // Ensure that the executor's metricsPoller is polled so that values are recorded for
@@ -466,7 +469,8 @@ class ExecutorSuite extends SparkFunSuite
     val timedOut = new AtomicBoolean(false)
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
-        uncaughtExceptionHandler = mockUncaughtExceptionHandler)
+        uncaughtExceptionHandler = mockUncaughtExceptionHandler,
+        resources = immutable.Map.empty[String, ResourceInformation])
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
       if (killTask) {
diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
index b432253..ac57e29 100644
--- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -32,15 +32,20 @@ import org.scalatest.BeforeAndAfterEach
 import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
 
 import org.apache.spark._
+import org.apache.spark.TestUtils._
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.config._
 import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.resource.ResourceInformation
+import org.apache.spark.resource.ResourceUtils.GPU
+import org.apache.spark.resource.TestResourceIDs.{DRIVER_GPU_ID, EXECUTOR_GPU_ID, WORKER_GPU_ID}
 import org.apache.spark.util.Utils
 
 class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with LocalSparkContext {
 
   override def afterEach(): Unit = {
     TestSparkPlugin.reset()
+    NonLocalModeSparkPlugin.reset()
     super.afterEach()
   }
 
@@ -61,6 +66,7 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     verify(TestSparkPlugin.executorPlugin).init(any(), meq(TestSparkPlugin.extraConf))
 
     assert(TestSparkPlugin.executorContext != null)
+    assert(TestSparkPlugin.executorContext.resources.isEmpty)
 
     // One way messages don't block, so need to loop checking whether it arrives.
     TestSparkPlugin.executorContext.send("oneway")
@@ -105,7 +111,8 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     val conf = new SparkConf()
     val env = mock(classOf[SparkEnv])
     when(env.conf).thenReturn(conf)
-    assert(PluginContainer(env) === None)
+    val container = PluginContainer(env, Map.empty[String, ResourceInformation].asJava)
+    assert(container === None)
   }
 
   test("merging of config options") {
@@ -140,6 +147,53 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
       assert(children.length >= 3)
     }
   }
+
+  test("plugin initialization in non-local mode with resources") {
+    withTempDir { dir =>
+      val scriptPath = createTempScriptWithExpectedOutput(dir, "gpuDiscoveryScript",
+        """{"name": "gpu","addresses":["5", "6"]}""")
+
+      val workerScript = createTempScriptWithExpectedOutput(dir, "resourceDiscoveryScript",
+        """{"name": "gpu","addresses":["3", "4"]}""")
+
+      val conf = new SparkConf()
+        .setAppName(getClass().getName())
+        .set(SparkLauncher.SPARK_MASTER, "local-cluster[1,1,1024]")
+        .set(PLUGINS, Seq(classOf[NonLocalModeSparkPlugin].getName()))
+        .set(NonLocalModeSparkPlugin.TEST_PATH_CONF, dir.getAbsolutePath())
+        .set(DRIVER_GPU_ID.amountConf, "2")
+        .set(DRIVER_GPU_ID.discoveryScriptConf, scriptPath)
+        .set(WORKER_GPU_ID.amountConf, "2")
+        .set(WORKER_GPU_ID.discoveryScriptConf, workerScript)
+        .set(EXECUTOR_GPU_ID.amountConf, "2")
+      sc = new SparkContext(conf)
+
+      // Ensure all executors has started
+      TestUtils.waitUntilExecutorsUp(sc, 1, 10000)
+
+      var children = Array.empty[File]
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        children = dir.listFiles()
+        assert(children != null)
+        // we have 2 discovery scripts and then expect 1 driver and 1 executor file
+        assert(children.length >= 4)
+      }
+      val execFiles =
+        children.filter(_.getName.startsWith(NonLocalModeSparkPlugin.executorFileStr))
+      assert(execFiles.size === 1)
+      val allLines = Files.readLines(execFiles(0), StandardCharsets.US_ASCII)
+      assert(allLines.size === 1)
+      val addrs = NonLocalModeSparkPlugin.extractGpuAddrs(allLines.get(0))
+      assert(addrs.size === 2)
+      assert(addrs.sorted === Array("3", "4"))
+
+      assert(NonLocalModeSparkPlugin.driverContext != null)
+      val driverResources = NonLocalModeSparkPlugin.driverContext.resources()
+      assert(driverResources.size === 1)
+      assert(driverResources.get(GPU).addresses === Array("5", "6"))
+      assert(driverResources.get(GPU).name === GPU)
+    }
+  }
 }
 
 class NonLocalModeSparkPlugin extends SparkPlugin {
@@ -147,8 +201,10 @@ class NonLocalModeSparkPlugin extends SparkPlugin {
   override def driverPlugin(): DriverPlugin = {
     new DriverPlugin() {
       override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
-        NonLocalModeSparkPlugin.writeFile(ctx.conf(), ctx.executorID())
-        Map.empty.asJava
+        NonLocalModeSparkPlugin.writeDriverFile(NonLocalModeSparkPlugin.driverFileStr, ctx.conf(),
+          ctx.executorID())
+        NonLocalModeSparkPlugin.driverContext = ctx
+        Map.empty[String, String].asJava
       }
     }
   }
@@ -156,7 +212,8 @@ class NonLocalModeSparkPlugin extends SparkPlugin {
   override def executorPlugin(): ExecutorPlugin = {
     new ExecutorPlugin() {
       override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
-        NonLocalModeSparkPlugin.writeFile(ctx.conf(), ctx.executorID())
+        NonLocalModeSparkPlugin.writeFile(NonLocalModeSparkPlugin.executorFileStr, ctx.conf(),
+        ctx.executorID(), ctx.resources().asScala.toMap)
       }
     }
   }
@@ -164,10 +221,50 @@ class NonLocalModeSparkPlugin extends SparkPlugin {
 
 object NonLocalModeSparkPlugin {
   val TEST_PATH_CONF = "spark.nonLocalPlugin.path"
+  var driverContext: PluginContext = _
+  val executorFileStr = "EXECUTOR_FILE_"
+  val driverFileStr = "DRIVER_FILE_"
+
+  private def createFileStringWithGpuAddrs(
+      id: String,
+      resources: Map[String, ResourceInformation]): String = {
+    // try to keep this simple and only write the gpus addresses, if we add more resources need to
+    // make more complex
+    val resourcesString = resources.filterKeys(_.equals(GPU)).map {
+      case (_, ri) =>
+        s"${ri.addresses.mkString(",")}"
+    }.mkString(",")
+    s"$id&$resourcesString"
+  }
 
-  def writeFile(conf: SparkConf, id: String): Unit = {
+  def extractGpuAddrs(str: String): Array[String] = {
+    val idAndAddrs = str.split("&")
+    if (idAndAddrs.size > 1) {
+      idAndAddrs(1).split(",")
+    } else {
+      Array.empty[String]
+    }
+  }
+
+  def writeDriverFile(
+      filePrefix: String,
+      conf: SparkConf,
+      id: String): Unit = {
+    writeFile(filePrefix, conf, id, Map.empty)
+  }
+
+  def writeFile(
+      filePrefix: String,
+      conf: SparkConf,
+      id: String,
+      resources: Map[String, ResourceInformation]): Unit = {
     val path = conf.get(TEST_PATH_CONF)
-    Files.write(id, new File(path, id), StandardCharsets.UTF_8)
+    val strToWrite = createFileStringWithGpuAddrs(id, resources)
+    Files.write(strToWrite, new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8)
+  }
+
+  def reset(): Unit = {
+    driverContext = null
   }
 }
 
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index 213d33c..47243e8 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -30,6 +30,7 @@ import org.apache.spark.TaskState
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.EXECUTOR_ID
+import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.scheduler.TaskDescription
 import org.apache.spark.scheduler.cluster.mesos.MesosSchedulerUtils
 import org.apache.spark.util.Utils
@@ -82,7 +83,8 @@ private[spark] class MesosExecutorBackend
     executor = new Executor(
       executorId,
       slaveInfo.getHostname,
-      env)
+      env,
+      resources = Map.empty[String, ResourceInformation])
   }
 
   override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org