You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/04/21 08:32:06 UTC

[spark] branch branch-3.2 updated: [SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 9dd64d40c91 [SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()
9dd64d40c91 is described below

commit 9dd64d40c91253c275fef2313c6a326ef72112cb
Author: Maryann Xue <ma...@gmail.com>
AuthorDate: Thu Apr 21 16:30:54 2022 +0800

    [SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the race conditions between the killTask() call and the launchTask() call that sometimes cause tasks not to be killed properly. If killTask() probes the runningTasks map before launchTask() has had a chance to put the corresponding task into that map, the kill flag is lost, and the subsequent launchTask() call simply proceeds and runs the task without knowing that it should have been killed instead. The fix records a kill mark during the killTask() call, so that a subsequent launchTask() call for the same task finds the mark and kills the task right away. killTask() removes the mark once it has reached the task's TaskRunner, and a periodic cleanup drops marks older than 10 seconds.
    
    ### Why are the changes needed?
    
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added UTs.
    
    Closes #36238 from maryannxue/spark-38916.
    
    Authored-by: Maryann Xue <ma...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit bb5092b9af60afdceeccb239d14be660f77ae0ea)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../scala/org/apache/spark/executor/Executor.scala |  51 +++++-
 .../CoarseGrainedExecutorBackendSuite.scala        | 185 ++++++++++++++++++++-
 .../org/apache/spark/executor/ExecutorSuite.scala  |  10 +-
 3 files changed, 230 insertions(+), 16 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3f1023e3491..4c84224dd05 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -83,7 +83,7 @@ private[spark] class Executor(
 
   private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
 
-  private val conf = env.conf
+  private[executor] val conf = env.conf
 
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname)
@@ -104,7 +104,7 @@ private[spark] class Executor(
   // Use UninterruptibleThread to run tasks so that we can allow running codes without being
   // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
   // will hang forever if some methods are interrupted.
-  private val threadPool = {
+  private[executor] val threadPool = {
     val threadFactory = new ThreadFactoryBuilder()
       .setDaemon(true)
       .setNameFormat("Executor task launch worker-%d")
@@ -174,7 +174,33 @@ private[spark] class Executor(
   private val maxResultSize = conf.get(MAX_RESULT_SIZE)
 
   // Maintains the list of running tasks.
-  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+  private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
+  // Kill mark TTL in milliseconds - 10 seconds.
+  private val KILL_MARK_TTL_MS = 10000L
+
+  // Kill marks with interruptThread flag, kill reason and timestamp.
+  // This is to avoid dropping the kill event when killTask() is called before launchTask().
+  private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
+
+  private val killMarkCleanupTask = new Runnable {
+    override def run(): Unit = {
+      val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
+      val iter = killMarks.entrySet().iterator()
+      while (iter.hasNext) {
+        if (iter.next().getValue._3 < oldest) {
+          iter.remove()
+        }
+      }
+    }
+  }
+
+  // Kill mark cleanup thread executor.
+  private val killMarkCleanupService =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("executor-kill-mark-cleanup")
+
+  killMarkCleanupService.scheduleAtFixedRate(
+    killMarkCleanupTask, KILL_MARK_TTL_MS, KILL_MARK_TTL_MS, TimeUnit.MILLISECONDS)
 
   /**
    * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
@@ -264,9 +290,18 @@ private[spark] class Executor(
     decommissioned = true
   }
 
+  private[executor] def createTaskRunner(context: ExecutorBackend,
+    taskDescription: TaskDescription) = new TaskRunner(context, taskDescription, plugins)
+
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription, plugins)
-    runningTasks.put(taskDescription.taskId, tr)
+    val taskId = taskDescription.taskId
+    val tr = createTaskRunner(context, taskDescription)
+    runningTasks.put(taskId, tr)
+    val killMark = killMarks.get(taskId)
+    if (killMark != null) {
+      tr.kill(killMark._1, killMark._2)
+      killMarks.remove(taskId)
+    }
     threadPool.execute(tr)
     if (decommissioned) {
       log.error(s"Launching a task while in decommissioned state.")
@@ -274,6 +309,7 @@ private[spark] class Executor(
   }
 
   def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
+    killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
     val taskRunner = runningTasks.get(taskId)
     if (taskRunner != null) {
       if (taskReaperEnabled) {
@@ -296,6 +332,8 @@ private[spark] class Executor(
       } else {
         taskRunner.kill(interruptThread = interruptThread, reason = reason)
       }
+      // Safe to remove kill mark as we got a chance with the TaskRunner.
+      killMarks.remove(taskId)
     }
   }
 
@@ -334,6 +372,9 @@ private[spark] class Executor(
       if (threadPool != null) {
         threadPool.shutdown()
       }
+      if (killMarkCleanupService != null) {
+        killMarkCleanupService.shutdown()
+      }
       if (replClassLoader != null && plugins != null) {
         // Notify plugins that executor is shutting down so they can terminate cleanly
         Utils.withContextClassLoader(replClassLoader) {
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 4909a586d31..5210990f3b9 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -21,14 +21,17 @@ import java.io.File
 import java.net.URL
 import java.nio.ByteBuffer
 import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
 import scala.concurrent.duration._
 
 import org.json4s.{DefaultFormats, Extraction}
 import org.json4s.JsonAST.{JArray, JObject}
 import org.json4s.JsonDSL._
-import org.mockito.Mockito.when
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
 import org.scalatest.concurrent.Eventually.{eventually, timeout}
 import org.scalatestplus.mockito.MockitoSugar
 
@@ -39,9 +42,9 @@ import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{KillTask, LaunchTask}
 import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{SerializableBuffer, Utils}
+import org.apache.spark.util.{SerializableBuffer, ThreadUtils, Utils}
 
 class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
     with LocalSparkContext with MockitoSugar {
@@ -357,6 +360,194 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
     assert(arg.bindAddress == "bindaddress1")
   }
 
+  /**
+   * This testcase is to verify that [[Executor.killTask()]] will always cancel a task that is
+   * being executed in [[Executor.TaskRunner]].
+   */
+  test("Tasks launched should always be cancelled.") {
+    val conf = new SparkConf
+    val securityMgr = new SecurityManager(conf)
+    val serializer = new JavaSerializer(conf)
+    val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
+    var backend: CoarseGrainedExecutorBackend = null
+
+    try {
+      val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
+      val env = createMockEnv(conf, serializer, Some(rpcEnv))
+      backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
+        "host1", "host1", 4, env, None,
+        resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
+
+      backend.rpcEnv.setupEndpoint("Executor 1", backend)
+      backend.executor = mock[Executor](CALLS_REAL_METHODS)
+      val executor = backend.executor
+      // Mock the executor. A Mockito mock skips field initializers, so the fields that
+      // launchTask()/killTask() dereference are stubbed, killMarks included: a null killMarks
+      // makes both calls throw in the RPC handler and the assertions below pass vacuously.
+      when(executor.threadPool).thenReturn(threadPool)
+      val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
+      when(executor.runningTasks).thenAnswer(_ => runningTasks)
+      when(executor.killMarks).thenReturn(new ConcurrentHashMap[Long, (Boolean, String, Long)])
+      when(executor.conf).thenReturn(conf)
+
+      // We don't really verify the data, just pass it around.
+      val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
+
+      val numTasks = 1000
+      val tasksKilled = new TrieMap[Long, Boolean]()
+      val tasksExecuted = new TrieMap[Long, Boolean]()
+
+      // Fake tasks with different taskIds.
+      val taskDescriptions = (1 to numTasks).map {
+        taskId => new TaskDescription(taskId, 2, "1", s"TASK ${taskId}", 19,
+          1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
+          Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
+      }
+      assert(taskDescriptions.length == numTasks)
+
+      def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
+        new executor.TaskRunner(backend, taskDescription, None) {
+          override def run(): Unit = {
+            tasksExecuted.put(taskDescription.taskId, true)
+            logInfo(s"task ${taskDescription.taskId} runs.")
+          }
+
+          override def kill(interruptThread: Boolean, reason: String): Unit = {
+            logInfo(s"task ${taskDescription.taskId} killed.")
+            tasksKilled.put(taskDescription.taskId, true)
+          }
+        }
+      }
+
+      // Feed the fake task-runners to be executed by the executor, in launch order: the first
+      // runner belongs to taskDescriptions(0), the others to taskDescriptions(1) onwards.
+      val firstLaunchTask = getFakeTaskRunner(taskDescriptions(0))
+      val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
+      assert (otherTasks.length == numTasks - 1)
+      // Workaround for compilation issue around Mockito.doReturn
+      doReturn(firstLaunchTask, otherTasks: _*).when(executor).
+        createTaskRunner(any(), any())
+
+      // Launch tasks and quickly kill them so that TaskRunner.killTask will be triggered.
+      taskDescriptions.foreach { taskDescription =>
+        val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
+        backend.self.send(LaunchTask(buffer))
+        Thread.sleep(1)
+        backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
+      }
+
+      // Runners execute asynchronously and the last KillTask messages may still be queued
+      // once every put() is observed, so wait for the final counts as well.
+      eventually(timeout(10.seconds)) {
+        verify(runningTasks, times(numTasks)).put(any(), any())
+        assert(tasksExecuted.size == numTasks, s"Tasks executed ${tasksExecuted.size}")
+        assert(tasksKilled.size == numTasks, s"Tasks killed ${tasksKilled.size}")
+      }
+
+      assert(tasksExecuted.keySet == tasksKilled.keySet)
+      logInfo(s"Task executed ${tasksExecuted.size}, task killed ${tasksKilled.size}")
+    } finally {
+      if (backend != null) {
+        backend.rpcEnv.shutdown()
+      }
+      threadPool.shutdownNow()
+    }
+  }
+
+  /**
+   * This testcase is to verify that [[Executor.killTask()]] will always cancel a task even if
+   * it has not been launched yet.
+   */
+  test("Tasks not launched should always be cancelled.") {
+    val conf = new SparkConf
+    val securityMgr = new SecurityManager(conf)
+    val serializer = new JavaSerializer(conf)
+    val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
+    var backend: CoarseGrainedExecutorBackend = null
+
+    try {
+      val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
+      val env = createMockEnv(conf, serializer, Some(rpcEnv))
+      backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
+        "host1", "host1", 4, env, None,
+        resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
+
+      backend.rpcEnv.setupEndpoint("Executor 1", backend)
+      backend.executor = mock[Executor](CALLS_REAL_METHODS)
+      val executor = backend.executor
+      // Mock the executor. A Mockito mock skips field initializers, so the fields that
+      // launchTask()/killTask() dereference are stubbed, killMarks included: a null killMarks
+      // makes both calls throw in the RPC handler and the assertions below pass vacuously.
+      when(executor.threadPool).thenReturn(threadPool)
+      val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
+      when(executor.runningTasks).thenAnswer(_ => runningTasks)
+      when(executor.killMarks).thenReturn(new ConcurrentHashMap[Long, (Boolean, String, Long)])
+      when(executor.conf).thenReturn(conf)
+
+      // We don't really verify the data, just pass it around.
+      val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
+
+      val numTasks = 1000
+      val tasksKilled = new TrieMap[Long, Boolean]()
+      val tasksExecuted = new TrieMap[Long, Boolean]()
+
+      // Fake tasks with different taskIds.
+      val taskDescriptions = (1 to numTasks).map {
+        taskId => new TaskDescription(taskId, 2, "1", s"TASK ${taskId}", 19,
+          1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
+          Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
+      }
+      assert(taskDescriptions.length == numTasks)
+
+      def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
+        new executor.TaskRunner(backend, taskDescription, None) {
+          override def run(): Unit = {
+            tasksExecuted.put(taskDescription.taskId, true)
+            logInfo(s"task ${taskDescription.taskId} runs.")
+          }
+
+          override def kill(interruptThread: Boolean, reason: String): Unit = {
+            logInfo(s"task ${taskDescription.taskId} killed.")
+            tasksKilled.put(taskDescription.taskId, true)
+          }
+        }
+      }
+
+      // Feed the fake task-runners to be executed by the executor, in launch order: the first
+      // runner belongs to taskDescriptions(0), the others to taskDescriptions(1) onwards.
+      val firstLaunchTask = getFakeTaskRunner(taskDescriptions(0))
+      val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
+      assert (otherTasks.length == numTasks - 1)
+      // Workaround for compilation issue around Mockito.doReturn
+      doReturn(firstLaunchTask, otherTasks: _*).when(executor).
+        createTaskRunner(any(), any())
+
+      // The reverse order of events can happen when the scheduler tries to cancel a task right
+      // after launching it.
+      taskDescriptions.foreach { taskDescription =>
+        val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
+        backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
+        backend.self.send(LaunchTask(buffer))
+      }
+
+      // Runners execute asynchronously and launchTask() applies the kill mark only after its
+      // put(), so wait for the final counts as well.
+      eventually(timeout(10.seconds)) {
+        verify(runningTasks, times(numTasks)).put(any(), any())
+        assert(tasksExecuted.size == numTasks, s"Tasks executed ${tasksExecuted.size}")
+        assert(tasksKilled.size == numTasks, s"Tasks killed ${tasksKilled.size}")
+      }
+
+      assert(tasksExecuted.keySet == tasksKilled.keySet)
+      logInfo(s"Task executed ${tasksExecuted.size}, task killed ${tasksKilled.size}")
+    } finally {
+      if (backend != null) {
+        backend.rpcEnv.shutdown()
+      }
+      threadPool.shutdownNow()
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer,
       rpcEnv: Option[RpcEnv] = None): SparkEnv = {
     val mockEnv = mock[SparkEnv]
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index a237447b0fa..7f7b10c8c33 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.net.URL
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.immutable
@@ -321,13 +321,7 @@ class ExecutorSuite extends SparkFunSuite
       nonZeroAccumulator.add(1)
       metrics.registerAccumulator(nonZeroAccumulator)
 
-      val executorClass = classOf[Executor]
-      val tasksMap = {
-        val field =
-          executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
-        field.setAccessible(true)
-        field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]]
-      }
+      val tasksMap = executor.runningTasks
       val mockTaskRunner = mock[executor.TaskRunner]
       val mockTask = mock[Task[Any]]
       when(mockTask.metrics).thenReturn(metrics)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org