Posted to commits@spark.apache.org by mr...@apache.org on 2020/10/16 03:13:54 UTC

[spark] branch master updated: [SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f4fc22  [SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events
8f4fc22 is described below

commit 8f4fc22dc460eb05c47e0d61facf116c60b1be37
Author: Samuel Souza <ss...@palantir.com>
AuthorDate: Thu Oct 15 22:12:41 2020 -0500

    [SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events
    
    ### What changes were proposed in this pull request?
    This PR proposes a new set of APIs for ExecutorPlugins, providing callbacks that are invoked at the start and end of each task of a job. The API is deliberately kept as minimal as possible for now, and its exact shape is open to discussion.
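    
    For illustration only, a plugin wiring up the proposed callbacks might look roughly like the sketch below (the `TaskEventsPlugin` name, the `com.example` package in the config value, and the `println` bodies are made up for the example):
    
    ```scala
    import org.apache.spark.TaskFailedReason
    import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
    
    // Hypothetical plugin, registered via --conf spark.plugins=com.example.TaskEventsPlugin
    class TaskEventsPlugin extends SparkPlugin {
      override def driverPlugin(): DriverPlugin = null  // no driver-side component needed here
      override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {
        // The proposed callbacks; all of them run on the task's own thread.
        override def onTaskStart(): Unit = println("task starting")
        override def onTaskSucceeded(): Unit = println("task succeeded")
        override def onTaskFailed(reason: TaskFailedReason): Unit = println(s"task failed: $reason")
      }
    }
    ```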
    
    ### Why are the changes needed?
    The changes are described in detail in [SPARK-33088](https://issues.apache.org/jira/browse/SPARK-33088), but they mostly boil down to:
    
    1. This feature was considered when the ExecutorPlugin API was initially introduced in #21923, but was never implemented.
    2. The use case that **requires** this feature is propagating tracing information from the driver to the executors, so that all calls from the same job can be traced.
      a. Tracing frameworks are usually set up in thread locals, so it is important for the setup to happen in the same thread that runs the tasks.
      b. Executors can be shared across multiple jobs, so it is not sufficient to set tracing information at executor startup time -- it needs to happen every time a task starts or ends (see the sketch below).
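    
    As a rough sketch of the tracing pattern above (illustrative only: the `spark.job.traceId` property name and the use of SLF4J's `MDC` as the thread-local carrier are assumptions, not part of this change), the driver attaches a trace id as a job-local property and the executor plugin installs and clears it around each task:
    
    ```scala
    import org.apache.spark.{TaskContext, TaskFailedReason}
    import org.apache.spark.api.plugin.ExecutorPlugin
    import org.slf4j.MDC
    
    // Hypothetical tracing plugin. The driver calls
    //   sc.setLocalProperty("spark.job.traceId", traceId)
    // before submitting a job; local properties are shipped to that job's tasks.
    class TracingExecutorPlugin extends ExecutorPlugin {
      override def onTaskStart(): Unit = {
        // Runs on the task's thread, so a thread local (here, the logging MDC) is safe to set.
        val traceId = Option(TaskContext.get())
          .flatMap(tc => Option(tc.getLocalProperty("spark.job.traceId")))
        traceId.foreach(MDC.put("traceId", _))
      }
      override def onTaskSucceeded(): Unit = MDC.remove("traceId")
      override def onTaskFailed(failureReason: TaskFailedReason): Unit = MDC.remove("traceId")
    }
    ```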
    
    ### Does this PR introduce _any_ user-facing change?
    No. This PR only introduces new APIs for plugin developers to use.
    
    ### How was this patch tested?
    Unit tests in `PluginContainerSuite`.
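    They can be run locally with, for example, `build/sbt "core/testOnly *PluginContainerSuite"`.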
    
    Closes #29977 from fsamuel-bs/SPARK-33088.
    
    Authored-by: Samuel Souza <ss...@palantir.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../apache/spark/api/plugin/ExecutorPlugin.java    | 42 +++++++++++++++++++
 .../scala/org/apache/spark/executor/Executor.scala | 32 ++++++++------
 .../spark/internal/plugin/PluginContainer.scala    | 49 +++++++++++++++++++++-
 .../scala/org/apache/spark/scheduler/Task.scala    |  6 ++-
 .../internal/plugin/PluginContainerSuite.scala     | 47 +++++++++++++++++++++
 .../apache/spark/scheduler/TaskContextSuite.scala  |  4 +-
 6 files changed, 163 insertions(+), 17 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java
index 4961308..481bf98 100644
--- a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java
+++ b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java
@@ -19,6 +19,7 @@ package org.apache.spark.api.plugin;
 
 import java.util.Map;
 
+import org.apache.spark.TaskFailedReason;
 import org.apache.spark.annotation.DeveloperApi;
 
 /**
@@ -54,4 +55,45 @@ public interface ExecutorPlugin {
    */
   default void shutdown() {}
 
+  /**
+   * Perform any action before the task is run.
+   * <p>
+   * This method is invoked from the same thread in which the task will be executed.
+   * Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}.
+   * <p>
+   * Plugin authors should avoid expensive operations here, as this method will be called
+   * on every task, and doing something expensive can significantly slow down a job.
+   * Calling a remote service from this method, for example, is not recommended.
+   * <p>
+   * Exceptions thrown from this method do not propagate: they are caught,
+   * logged, and suppressed. Therefore, exceptions thrown while executing this
+   * method will not cause the job to fail.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskStart() {}
+
+  /**
+   * Perform an action after a task completes without exceptions.
+   * <p>
+   * As {@link #onTaskStart() onTaskStart} exceptions are suppressed, this method
+   * will still be invoked even if the corresponding {@link #onTaskStart} call for this
+   * task failed.
+   * <p>
+   * The same warnings as for {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskSucceeded() {}
+
+  /**
+   * Perform an action after a task completes with an exception.
+   * <p>
+   * The same warnings as for {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @param failureReason the reason the task failed.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskFailed(TaskFailedReason failureReason) {}
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 27addd8..6653650 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -253,7 +253,7 @@ private[spark] class Executor(
   }
 
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription)
+    val tr = new TaskRunner(context, taskDescription, plugins)
     runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
     if (decommissioned) {
@@ -332,7 +332,8 @@ private[spark] class Executor(
 
   class TaskRunner(
       execBackend: ExecutorBackend,
-      private val taskDescription: TaskDescription)
+      private val taskDescription: TaskDescription,
+      private val plugins: Option[PluginContainer])
     extends Runnable {
 
     val taskId = taskDescription.taskId
@@ -479,7 +480,8 @@ private[spark] class Executor(
             taskAttemptId = taskId,
             attemptNumber = taskDescription.attemptNumber,
             metricsSystem = env.metricsSystem,
-            resources = taskDescription.resources)
+            resources = taskDescription.resources,
+            plugins = plugins)
           threwException = false
           res
         } {
@@ -614,6 +616,7 @@ private[spark] class Executor(
 
         executorSource.SUCCEEDED_TASKS.inc(1L)
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskSucceeded())
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
       } catch {
         case t: TaskKilledException =>
@@ -623,9 +626,9 @@ private[spark] class Executor(
           // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
           // without requiring a copy.
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(
-            TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case _: InterruptedException | NonFatal(_) if
             task != null && task.reasonIfKilled.isDefined =>
@@ -634,9 +637,9 @@ private[spark] class Executor(
 
           val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(
-            TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -650,11 +653,13 @@ private[spark] class Executor(
               s"other exception: $t")
           }
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskCommitDeniedReason
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if env.isStopped =>
@@ -677,21 +682,22 @@ private[spark] class Executor(
             val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
             val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
 
-            val serializedTaskEndReason = {
+            val (taskFailureReason, serializedTaskFailureReason) = {
               try {
                 val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
                   .withMetricPeaks(metricPeaks.toSeq)
-                ser.serialize(ef)
+                (ef, ser.serialize(ef))
               } catch {
                 case _: NotSerializableException =>
                   // t is not serializable so just send the stacktrace
                   val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
                     .withMetricPeaks(metricPeaks.toSeq)
-                  ser.serialize(ef)
+                  (ef, ser.serialize(ef))
               }
             }
             setTaskFinishedAndClearInterruptStatus()
-            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
+            plugins.foreach(_.onTaskFailed(taskFailureReason))
+            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
           } else {
             logInfo("Not reporting error to driver during JVM shutdown.")
           }
diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
index 4eda476..f78ec25 100644
--- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
+++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin
 import scala.collection.JavaConverters._
 import scala.util.{Either, Left, Right}
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason}
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -31,6 +31,9 @@ sealed abstract class PluginContainer {
 
   def shutdown(): Unit
   def registerMetrics(appId: String): Unit
+  def onTaskStart(): Unit
+  def onTaskSucceeded(): Unit
+  def onTaskFailed(failureReason: TaskFailedReason): Unit
 
 }
 
@@ -85,6 +88,17 @@ private class DriverPluginContainer(
     }
   }
 
+  override def onTaskStart(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
 }
 
 private class ExecutorPluginContainer(
@@ -134,6 +148,39 @@ private class ExecutorPluginContainer(
       }
     }
   }
+
+  override def onTaskStart(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskStart()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskStart on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskSucceeded()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskFailed(failureReason)
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskFailed on plugin $name.", t)
+      }
+    }
+  }
 }
 
 object PluginContainer {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index ebc1c05..81f984b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,6 +23,7 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.internal.plugin.PluginContainer
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.InputFileBlockHolder
@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
       taskAttemptId: Long,
       attemptNumber: Int,
       metricsSystem: MetricsSystem,
-      resources: Map[String, ResourceInformation]): T = {
+      resources: Map[String, ResourceInformation],
+      plugins: Option[PluginContainer]): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
@@ -123,6 +125,8 @@ private[spark] abstract class Task[T](
       Option(taskAttemptId),
       Option(attemptNumber)).setCurrentContext()
 
+    plugins.foreach(_.onTaskStart())
+
     try {
       runTask(context)
     } catch {
diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
index 7888796..e7fbe5b 100644
--- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -129,6 +129,38 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     assert(TestSparkPlugin.driverPlugin != null)
   }
 
+  test("SPARK-33088: executor tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    sc.parallelize(1 to 10, 2).count()
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
+  }
+
+  test("SPARK-33088: executor failed tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    try {
+      sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException)
+    } catch {
+      case t: Throwable => // ignore exception
+    }
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
+  }
+
   test("plugin initialization in non-local mode") {
     val path = Utils.createTempDir()
 
@@ -309,6 +341,10 @@ private class TestDriverPlugin extends DriverPlugin {
 
 private class TestExecutorPlugin extends ExecutorPlugin {
 
+  var numOnTaskStart: Int = 0
+  var numOnTaskSucceeded: Int = 0
+  var numOnTaskFailed: Int = 0
+
   override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
     ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
       override def getValue(): Int = 84
@@ -316,6 +352,17 @@ private class TestExecutorPlugin extends ExecutorPlugin {
     TestSparkPlugin.executorContext = ctx
   }
 
+  override def onTaskStart(): Unit = {
+    numOnTaskStart += 1
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    numOnTaskSucceeded += 1
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    numOnTaskFailed += 1
+  }
 }
 
 private object TestSparkPlugin {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 394a2a9..8a7ff9e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.completed)
   }
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.lastError.getMessage == "damn error")
   }

