You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/21 00:21:50 UTC

[spark] branch master updated: [SPARK-43952][CORE][CONNECT][SQL] Add SparkContext APIs for query cancellation by tag

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 607469b2fd2 [SPARK-43952][CORE][CONNECT][SQL] Add SparkContext APIs for query cancellation by tag
607469b2fd2 is described below

commit 607469b2fd2ee6d70739c5e8b3aca15f67a45cde
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Wed Jun 21 09:21:30 2023 +0900

    [SPARK-43952][CORE][CONNECT][SQL] Add SparkContext APIs for query cancellation by tag
    
    ### What changes were proposed in this pull request?
    
    Currently, the only way to cancel running Spark jobs is `SparkContext.cancelJobGroup`, using a job group name that was previously set with `SparkContext.setJobGroup`. This is problematic when multiple different parts of the system want to perform cancellation and each set their own ids.
    
    For example, BroadcastExchangeExec sets its own job group, which may override the job group set by the user. As a result, if the user cancels the job group they set in the "parent" execution, it will not cancel the broadcast jobs launched from within their jobs. It would also be useful in, e.g., Spark Connect to be able to cancel jobs without overriding jobGroupId, which may be used and needed for other purposes.
    
    As a solution, add APIs to set tags on jobs and to cancel jobs using tags:
    * `SparkContext.addJobTag(tag: String): Unit`
    * `SparkContext.removeJobTag(tag: String): Unit`
    * `SparkContext.getJobTags(): Set[String]`
    * `SparkContext.clearJobTags(): Unit`
    * `SparkContext.cancelJobsWithTag(tag: String): Unit`
    * `DAGScheduler.cancelJobsWithTag(tag: String): Unit`
    
    Also added `SparkContext.setInterruptOnCancel(interruptOnCancel: Boolean): Unit`, which previously could only be set in `setJobGroup`.
    
    The tags are also added to `JobData` and `AppStatusTracker`. A new API is added to `SparkStatusTracker`:
    * `SparkStatusTracker.getJobIdsForTag(jobTag: String): Array[Int]`
    
    Use the new API internally in BroadcastExchangeExec instead of cancellation by job group, to fix the issue of these jobs not being cancelled via a user-set job group id. Now, the user-set job group id should propagate into broadcast execution.
    
    Also, switch cancellation in Spark Connect to use a job tag instead of a job group.
    
    ### Why are the changes needed?
    
    Currently, there may be multiple places that want to cancel a set of jobs, with different scopes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    The APIs described above are added.
    
    ### How was this patch tested?
    
    Added test to JobCancellationSuite.
    
    Closes #41440 from juliuszsompolski/SPARK-43952.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../sql/connect/service/ExecutePlanHolder.scala    |   7 +-
 .../service/SparkConnectStreamHandler.scala        |   8 +-
 .../apache/spark/status/protobuf/store_types.proto |   1 +
 .../main/scala/org/apache/spark/SparkContext.scala |  77 ++++++++++++
 .../org/apache/spark/SparkStatusTracker.scala      |  11 ++
 .../org/apache/spark/scheduler/DAGScheduler.scala  |  25 ++++
 .../apache/spark/scheduler/DAGSchedulerEvent.scala |   2 +
 .../apache/spark/status/AppStatusListener.scala    |   7 ++
 .../scala/org/apache/spark/status/LiveEntity.scala |   2 +
 .../scala/org/apache/spark/status/api/v1/api.scala |   1 +
 .../status/protobuf/JobDataWrapperSerializer.scala |   2 +
 ...from_multi_attempt_app_json_1__expectation.json |   1 +
 ...from_multi_attempt_app_json_2__expectation.json |   1 +
 .../job_list_json_expectation.json                 |   3 +
 .../one_job_json_expectation.json                  |   1 +
 ...succeeded_failed_job_list_json_expectation.json |   3 +
 .../succeeded_job_list_json_expectation.json       |   2 +
 .../org/apache/spark/JobCancellationSuite.scala    | 129 ++++++++++++++++++++-
 .../org/apache/spark/StatusTrackerSuite.scala      |  41 +++++++
 .../protobuf/KVStoreProtobufSerializerSuite.scala  |   8 +-
 project/MimaExcludes.scala                         |   4 +-
 .../execution/exchange/BroadcastExchangeExec.scala |  11 +-
 .../sql/execution/BroadcastExchangeSuite.scala     |   4 +-
 23 files changed, 335 insertions(+), 16 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
index a3c17b9826e..9bf9df07e01 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
@@ -27,15 +27,16 @@ case class ExecutePlanHolder(
     sessionHolder: SessionHolder,
     request: proto.ExecutePlanRequest) {
 
-  val jobGroupId =
-    s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}"
+  val jobTag =
+    "SparkConnect_" +
+      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}"
 
   def interrupt(): Unit = {
     // TODO/WIP: This only interrupts active Spark jobs that are actively running.
     // This would then throw the error from ExecutePlan and terminate it.
     // But if the query is not running a Spark job, but executing code on Spark driver, this
     // would be a noop and the execution will keep running.
-    sessionHolder.session.sparkContext.cancelJobGroup(jobGroupId)
+    sessionHolder.session.sparkContext.cancelJobsWithTag(jobTag)
   }
 
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index d11f4dcc600..70204f2913d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -63,8 +63,12 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
         }
 
       val executeHolder = sessionHolder.createExecutePlanHolder(v)
+      session.sparkContext.addJobTag(executeHolder.jobTag)
+      session.sparkContext.setInterruptOnCancel(true)
+      // Also set the tag as the JobGroup for all the jobs in the query.
+      // TODO: In the long term, it should be encouraged to use job tag only.
       session.sparkContext.setJobGroup(
-        executeHolder.jobGroupId,
+        executeHolder.jobTag,
         s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}",
         interruptOnCancel = true)
 
@@ -89,6 +93,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
             throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
         }
       } finally {
+        session.sparkContext.removeJobTag(executeHolder.jobTag)
+        session.sparkContext.clearJobGroup()
         sessionHolder.removeExecutePlanHolder(executeHolder.operationId)
       }
     }
diff --git a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
index 94ce1b8b58a..93365add3a6 100644
--- a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
+++ b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
@@ -47,6 +47,7 @@ message JobData {
   optional int64 completion_time = 5;
   repeated int64 stage_ids = 6;
   optional string job_group = 7;
+  repeated string job_tags = 21;
   JobExecutionStatus status = 8;
   int32 num_tasks = 9;
   int32 num_active_tasks = 10;
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index cf7a405f1ba..c32c674d64e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -829,6 +829,55 @@ class SparkContext(config: SparkConf) extends Logging {
     setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null)
   }
 
+  /**
+   * Set the behavior of job cancellation from jobs started in this thread.
+   *
+   * @param interruptOnCancel If true, then job cancellation will result in `Thread.interrupt()`
+   * being called on the job's executor threads. This is useful to help ensure that the tasks
+   * are actually stopped in a timely manner, but is off by default due to HDFS-1208, where HDFS
+   * may respond to Thread.interrupt() by marking nodes as dead.
+   */
+  def setInterruptOnCancel(interruptOnCancel: Boolean): Unit = {
+    setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
+  }
+
+  /**
+   * Add a tag to be assigned to all the jobs started by this thread.
+   *
+   * @param tag The tag to be added. Cannot contain ',' (comma) character.
+   */
+  def addJobTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    val existingTags = getJobTags()
+    val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+    setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags)
+  }
+
+  /**
+   * Remove a tag previously added to be assigned to all the jobs started by this thread.
+   * Noop if such a tag was not added earlier.
+   *
+   * @param tag The tag to be removed. Cannot contain ',' (comma) character.
+   */
+  def removeJobTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    val existingTags = getJobTags()
+    val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+    setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags)
+  }
+
+  /** Get the tags that are currently set to be assigned to all the jobs started by this thread. */
+  def getJobTags(): Set[String] = {
+    Option(getLocalProperty(SparkContext.SPARK_JOB_TAGS))
+      .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet)
+      .getOrElse(Set())
+  }
+
+  /** Clear the current thread's job tags. */
+  def clearJobTags(): Unit = {
+    setLocalProperty(SparkContext.SPARK_JOB_TAGS, null)
+  }
+
   /**
    * Execute a block of code in a scope such that all new RDDs created in this body will
    * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
@@ -2471,6 +2520,17 @@ class SparkContext(config: SparkConf) extends Logging {
     dagScheduler.cancelJobGroup(groupId)
   }
 
+  /**
+   * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
+   *
+   * @param tag The tag of the jobs to be cancelled. Cannot contain ',' (comma) character.
+   */
+  def cancelJobsWithTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    assertNotStopped()
+    dagScheduler.cancelJobsWithTag(tag)
+  }
+
   /** Cancel all jobs that have been scheduled or are running.  */
   def cancelAllJobs(): Unit = {
     assertNotStopped()
@@ -2840,6 +2900,7 @@ object SparkContext extends Logging {
   private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
   private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
   private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"
+  private[spark] val SPARK_JOB_TAGS = "spark.job.tags"
   private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool"
   private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope"
   private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride"
@@ -2851,6 +2912,22 @@ object SparkContext extends Logging {
    */
   private[spark] val DRIVER_IDENTIFIER = "driver"
 
+  /** Separator of tags in SPARK_JOB_TAGS property */
+  private[spark] val SPARK_JOB_TAGS_SEP = ","
+
+  private[spark] def throwIfInvalidTag(tag: String) = {
+    if (tag == null) {
+      throw new IllegalArgumentException("Spark job tag cannot be null.")
+    }
+    if (tag.contains(SPARK_JOB_TAGS_SEP)) {
+      throw new IllegalArgumentException(
+        s"Spark job tag cannot contain '$SPARK_JOB_TAGS_SEP'.")
+    }
+    if (tag.isEmpty) {
+      throw new IllegalArgumentException(
+        "Spark job tag cannot be an empty string.")
+    }
+  }
 
   private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Iterable[T])
     : ArrayWritable = {
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
index 22dc1d056ec..a55a6a8b8eb 100644
--- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -52,6 +52,17 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore
     store.jobsList(null).filter(_.jobGroup == expected).map(_.jobId).toArray
   }
 
+  /**
+   * Return a list of all known jobs with a particular tag.
+   *
+   * The returned list may contain running, failed, and completed jobs, and may vary across
+   * invocations of this method.  This method does not guarantee the order of the elements in
+   * its result.
+   */
+  def getJobIdsForTag(jobTag: String): Array[Int] = {
+    store.jobsList(null).filter(_.jobTags.contains(jobTag)).map(_.jobId).toArray
+  }
+
   /**
    * Returns an array containing the ids of all active stages.
    *
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c78a26d91eb..64a8192f8e1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1085,6 +1085,15 @@ private[spark] class DAGScheduler(
     eventProcessLoop.post(JobGroupCancelled(groupId))
   }
 
+  /**
+   * Cancel all jobs with a given tag.
+   */
+  def cancelJobsWithTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    logInfo(s"Asked to cancel jobs with tag $tag")
+    eventProcessLoop.post(JobTagCancelled(tag))
+  }
+
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
@@ -1182,6 +1191,19 @@ private[spark] class DAGScheduler(
         Option("part of cancelled job group %s".format(groupId))))
   }
 
+  private[scheduler] def handleJobTagCancelled(tag: String): Unit = {
+    // Cancel all jobs that have this tag.
+    // First find all active jobs with this tag, and then kill stages for them.
+    val jobIds = activeJobs.filter { activeJob =>
+      Option(activeJob.properties).exists { properties =>
+        Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("")
+          .split(SparkContext.SPARK_JOB_TAGS_SEP).toSet.contains(tag)
+      }
+    }.map(_.jobId)
+    jobIds.foreach(handleJobCancellation(_,
+      Option(s"part of cancelled job tag $tag")))
+  }
+
   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
     listenerBus.post(SparkListenerTaskStart(task.stageId, task.stageAttemptId, taskInfo))
   }
@@ -2972,6 +2994,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case JobGroupCancelled(groupId) =>
       dagScheduler.handleJobGroupCancelled(groupId)
 
+    case JobTagCancelled(groupId) =>
+      dagScheduler.handleJobTagCancelled(groupId)
+
     case AllJobsCancelled =>
       dagScheduler.doCancelAllJobs()
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index c16e5ea03d7..6f2b778ca82 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -63,6 +63,8 @@ private[scheduler] case class JobCancelled(
 
 private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
 
+private[scheduler] case class JobTagCancelled(tagName: String) extends DAGSchedulerEvent
+
 private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 
 private[scheduler]
diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
index 5dee3cb6719..c1f52e86dd0 100644
--- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
@@ -438,6 +438,12 @@ private[spark] class AppStatusListener(
       .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) }
     val jobGroup = Option(event.properties)
       .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) }
+    val jobTags = Option(event.properties)
+      .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_TAGS)) }
+      .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet)
+      .getOrElse(Set())
+      .toSeq
+      .sorted
     val sqlExecutionId = Option(event.properties)
       .flatMap(p => Option(p.getProperty(SQL_EXECUTION_ID_KEY)).map(_.toLong))
 
@@ -448,6 +454,7 @@ private[spark] class AppStatusListener(
       if (event.time > 0) Some(new Date(event.time)) else None,
       event.stageIds,
       jobGroup,
+      jobTags,
       numTasks,
       sqlExecutionId)
     liveJobs.put(event.jobId, job)
diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
index 9910a0f07fc..ebea11fdca0 100644
--- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
+++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
@@ -66,6 +66,7 @@ private class LiveJob(
     val submissionTime: Option[Date],
     val stageIds: Seq[Int],
     jobGroup: Option[String],
+    jobTags: Seq[String],
     numTasks: Int,
     sqlExecutionId: Option[Long]) extends LiveEntity {
 
@@ -98,6 +99,7 @@ private class LiveJob(
       completionTime,
       stageIds,
       jobGroup,
+      jobTags,
       status,
       numTasks,
       activeTasks,
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index f436d16ca47..8d648b9df38 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -199,6 +199,7 @@ class JobData private[spark](
     val completionTime: Option[Date],
     val stageIds: collection.Seq[Int],
     val jobGroup: Option[String],
+    val jobTags: collection.Seq[String],
     val status: JobExecutionStatus,
     val numTasks: Int,
     val numActiveTasks: Int,
diff --git a/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala b/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala
index 97189e372f9..11f1b7070cc 100644
--- a/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala
@@ -71,6 +71,7 @@ private[protobuf] class JobDataWrapperSerializer extends ProtobufSerDe[JobDataWr
     }
     jobData.stageIds.foreach(id => jobDataBuilder.addStageIds(id.toLong))
     jobData.jobGroup.foreach(jobDataBuilder.setJobGroup)
+    jobData.jobTags.foreach(jobDataBuilder.addJobTags)
     jobData.killedTasksSummary.foreach { entry =>
       jobDataBuilder.putKillTasksSummary(entry._1, entry._2)
     }
@@ -93,6 +94,7 @@ private[protobuf] class JobDataWrapperSerializer extends ProtobufSerDe[JobDataWr
       completionTime = completionTime,
       stageIds = info.getStageIdsList.asScala.map(_.toInt),
       jobGroup = jobGroup,
+      jobTags = info.getJobTagsList.asScala,
       status = status,
       numTasks = info.getNumTasks,
       numActiveTasks = info.getNumActiveTasks,
diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json
index 2f275c7bfe2..b7271d89e02 100644
--- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json
@@ -2,6 +2,7 @@
   "jobId" : 0,
   "name" : "foreach at <console>:15",
   "stageIds" : [ 0 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json
index 2f275c7bfe2..b7271d89e02 100644
--- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json
@@ -2,6 +2,7 @@
   "jobId" : 0,
   "name" : "foreach at <console>:15",
   "stageIds" : [ 0 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json
index 71bf8706307..bb26bc47eac 100644
--- a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json
@@ -2,6 +2,7 @@
   "jobId" : 2,
   "name" : "count at <console>:17",
   "stageIds" : [ 3 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
@@ -19,6 +20,7 @@
   "jobId" : 1,
   "name" : "count at <console>:20",
   "stageIds" : [ 1, 2 ],
+  "jobTags" : [ ],
   "status" : "FAILED",
   "numTasks" : 16,
   "numActiveTasks" : 0,
@@ -36,6 +38,7 @@
   "jobId" : 0,
   "name" : "count at <console>:15",
   "stageIds" : [ 0 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json
index 1eae5f3d5be..3bf4845ed1e 100644
--- a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json
@@ -2,6 +2,7 @@
   "jobId" : 0,
   "name" : "count at <console>:15",
   "stageIds" : [ 0 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json
index 71bf8706307..bb26bc47eac 100644
--- a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json
@@ -2,6 +2,7 @@
   "jobId" : 2,
   "name" : "count at <console>:17",
   "stageIds" : [ 3 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
@@ -19,6 +20,7 @@
   "jobId" : 1,
   "name" : "count at <console>:20",
   "stageIds" : [ 1, 2 ],
+  "jobTags" : [ ],
   "status" : "FAILED",
   "numTasks" : 16,
   "numActiveTasks" : 0,
@@ -36,6 +38,7 @@
   "jobId" : 0,
   "name" : "count at <console>:15",
   "stageIds" : [ 0 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json
index b1ddd760c97..2b2c2fbe1f2 100644
--- a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json
@@ -2,6 +2,7 @@
   "jobId" : 2,
   "name" : "count at <console>:17",
   "stageIds" : [ 3 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
@@ -19,6 +20,7 @@
   "jobId" : 0,
   "name" : "count at <console>:15",
   "stageIds" : [ 0 ],
+  "jobTags" : [ ],
   "status" : "SUCCEEDED",
   "numTasks" : 8,
   "numActiveTasks" : 0,
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 77bdb882c50..f2ad33b0be7 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark
 import java.util.concurrent.{Semaphore, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.concurrent.{ExecutionContext, Future}
 // scalastyle:off executioncontextglobal
 import scala.concurrent.ExecutionContext.Implicits.global
 // scalastyle:on executioncontextglobal
-import scala.concurrent.Future
 import scala.concurrent.duration._
 
 import org.scalatest.BeforeAndAfter
@@ -31,7 +31,7 @@ import org.scalatest.matchers.must.Matchers
 
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.Deploy._
-import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -153,6 +153,131 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
     assert(jobB.get() === 100)
   }
 
+  test("job tags") {
+    sc = new SparkContext("local[2]", "test")
+
+    // global ExecutionContext has only 2 threads in Apache Spark CI
+    // create own thread pool for four Futures used in this test
+    val numThreads = 4
+    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
+    val executionContext = ExecutionContext.fromExecutorService(fpool)
+
+    try {
+      // Add a listener to release the semaphore once jobs are launched.
+      val sem = new Semaphore(0)
+      val jobEnded = new AtomicInteger(0)
+
+      sc.addSparkListener(new SparkListener {
+        override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+          sem.release()
+        }
+
+        override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+          sem.release()
+          jobEnded.incrementAndGet()
+        }
+      })
+
+      val eSep = intercept[IllegalArgumentException](sc.addJobTag("foo,bar"))
+      assert(eSep.getMessage.contains(
+        s"Spark job tag cannot contain '${SparkContext.SPARK_JOB_TAGS_SEP}'."))
+      val eEmpty = intercept[IllegalArgumentException](sc.addJobTag(""))
+      assert(eEmpty.getMessage.contains("Spark job tag cannot be an empty string."))
+      val eNull = intercept[IllegalArgumentException](sc.addJobTag(null))
+      assert(eNull.getMessage.contains("Spark job tag cannot be null."))
+
+      // Note: since tags are added in the Future threads, they don't need to be cleared in between.
+      val jobA = Future {
+        assert(sc.getJobTags() == Set())
+        sc.addJobTag("two")
+        assert(sc.getJobTags() == Set("two"))
+        sc.clearJobTags() // check that clearing all tags works
+        assert(sc.getJobTags() == Set())
+        sc.addJobTag("one")
+        assert(sc.getJobTags() == Set("one"))
+        try {
+          sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sc.clearJobTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobB = Future {
+        assert(sc.getJobTags() == Set())
+        sc.addJobTag("one")
+        sc.addJobTag("two")
+        sc.addJobTag("one")
+        sc.addJobTag("two") // duplicates shouldn't matter
+        assert(sc.getJobTags() == Set("one", "two"))
+        try {
+          sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sc.clearJobTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobC = Future {
+        assert(sc.getJobTags() == Set())
+        sc.addJobTag("two")
+        assert(sc.getJobTags() == Set("two"))
+        try {
+          sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sc.clearJobTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobD = Future {
+        assert(sc.getJobTags() == Set())
+        sc.addJobTag("one")
+        sc.addJobTag("two")
+        sc.addJobTag("two")
+        assert(sc.getJobTags() == Set("one", "two"))
+        sc.removeJobTag("two") // check that remove works, despite duplicate add
+        assert(sc.getJobTags() == Set("one"))
+        try {
+          sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sc.clearJobTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+
+      // Block until four jobs have started.
+      val acquired1 = sem.tryAcquire(4, 1, TimeUnit.MINUTES)
+      assert(acquired1 == true)
+
+      sc.cancelJobsWithTag("two")
+      val eB = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobB, 1.minute)
+      }.getCause
+      assert(eB.getMessage contains "cancel")
+      val eC = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobC, 1.minute)
+      }.getCause
+      assert(eC.getMessage contains "cancel")
+
+      // two jobs cancelled
+      val acquired2 = sem.tryAcquire(2, 1, TimeUnit.MINUTES)
+      assert(acquired2 == true)
+      assert(jobEnded.intValue == 2)
+
+      // this cancels the remaining two jobs
+      sc.cancelJobsWithTag("one")
+      val eA = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobA, 1.minute)
+      }.getCause
+      assert(eA.getMessage contains "cancel")
+      val eD = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobD, 1.minute)
+      }.getCause
+      assert(eD.getMessage contains "cancel")
+
+      // another two jobs cancelled
+      val acquired3 = sem.tryAcquire(2, 1, TimeUnit.MINUTES)
+      assert(acquired3 == true)
+      assert(jobEnded.intValue == 4)
+    } finally {
+      fpool.shutdownNow()
+    }
+  }
+
   test("inherited job group (SPARK-6629)") {
     sc = new SparkContext("local[2]", "test")
 
diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
index e6d3377120e..0817abbc6a3 100644
--- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
@@ -111,4 +111,45 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
       sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
     }
   }
+
+  test("getJobIdsForTag()") { // each job is listed under every tag set when it was submitted
+    sc = new SparkContext("local", "test", new SparkConf(false))
+
+    sc.addJobTag("tag1") // jobs submitted from here on carry tag1
+    sc.statusTracker.getJobIdsForTag("tag1") should be (Seq.empty) // no job run yet
+
+    // countAsync(): first job, submitted with tag1 only
+    val firstJobFuture = sc.parallelize(1 to 1000).countAsync()
+    val firstJobId = eventually(timeout(10.seconds)) {
+      firstJobFuture.jobIds.head // retried by eventually until a job id is available
+    }
+    eventually(timeout(10.seconds)) {
+      sc.statusTracker.getJobIdsForTag("tag1") should be (Seq(firstJobId))
+    }
+
+    sc.addJobTag("tag2") // tag1 and tag2 are now both set
+    // takeAsync(): second job, submitted with tag1 and tag2
+    val secondJobFuture = sc.parallelize(1 to 1000).takeAsync(1)
+    val secondJobId = eventually(timeout(10.seconds)) {
+      secondJobFuture.jobIds.head
+    }
+    eventually(timeout(10.seconds)) {
+      sc.statusTracker.getJobIdsForTag("tag1").toSet should be (
+        Set(firstJobId, secondJobId))
+      sc.statusTracker.getJobIdsForTag("tag2") should be (Seq(secondJobId))
+    }
+
+    sc.removeJobTag("tag1") // later jobs carry only tag2
+    // takeAsync() across multiple partitions: third job, submitted with tag2 only
+    val thirdJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
+    val thirdJobId = eventually(timeout(10.seconds)) {
+      thirdJobFuture.jobIds.head
+    }
+    eventually(timeout(10.seconds)) { // removing tag1 does not drop its earlier jobs
+      sc.statusTracker.getJobIdsForTag("tag1").toSet should be (
+        Set(firstJobId, secondJobId))
+      sc.statusTracker.getJobIdsForTag("tag2").toSet should be (
+        Set(secondJobId, thirdJobId))
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala b/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala
index 0849d63b03e..ac568fee1ad 100644
--- a/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala
@@ -65,9 +65,9 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
 
   test("Job data") {
     Seq(
-      ("test", Some("test description"), Some("group")),
-      (null, None, None)
-    ).foreach { case (name, description, jobGroup) =>
+      ("test", Some("test description"), Some("group"), Seq("tag1", "tag2")),
+      (null, None, None, Seq())
+    ).foreach { case (name, description, jobGroup, jobTags) =>
       val input = new JobDataWrapper(
         new JobData(
           jobId = 1,
@@ -77,6 +77,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
           completionTime = Some(new Date(654321L)),
           stageIds = Seq(1, 2, 3, 4),
           jobGroup = jobGroup,
+          jobTags = jobTags,
           status = JobExecutionStatus.UNKNOWN,
           numTasks = 2,
           numActiveTasks = 3,
@@ -102,6 +103,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
       assert(result.info.completionTime == input.info.completionTime)
       assert(result.info.stageIds == input.info.stageIds)
       assert(result.info.jobGroup == input.info.jobGroup)
+      assert(result.info.jobTags == input.info.jobTags)
       assert(result.info.status == input.info.status)
       assert(result.info.numTasks == input.info.numTasks)
       assert(result.info.numActiveTasks == input.info.numActiveTasks)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7cac416838d..543c6eb0961 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -56,7 +56,9 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Row.prettyJson"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.json"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.prettyJson"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue"),
+    // [SPARK-43952][CORE][CONNECT][SQL] Add SparkContext APIs for query cancellation by tag
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this")
   )
 
   // Defulat exclude rules
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 548a8628ba4..15141b09b6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -124,14 +124,17 @@ case class BroadcastExchangeExec(
     case _ => 512000000
   }
 
+  @transient // tags this exchange's jobs so they can be cancelled on broadcast timeout
+  private lazy val jobTag = s"broadcast exchange (runId ${runId.toString})"
+
   @transient
   override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       session, BroadcastExchangeExec.executionContext) {
           try {
-            // Setup a job group here so later it may get cancelled by groupId if necessary.
-            sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
-              interruptOnCancel = true)
+            // Setup a job tag here so later it may get cancelled by tag if necessary.
+            sparkContext.addJobTag(jobTag)
+            sparkContext.setInterruptOnCancel(true)
             val beforeCollect = System.nanoTime()
             // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
             val (numRows, input) = child.executeCollectIterator()
@@ -211,7 +214,7 @@ case class BroadcastExchangeExec(
       case ex: TimeoutException =>
         logError(s"Could not execute broadcast in $timeout secs.", ex)
         if (!relationFuture.isDone) {
-          sparkContext.cancelJobGroup(runId.toString)
+          sparkContext.cancelJobsWithTag(jobTag)
           relationFuture.cancel(true)
         }
         throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
index 31a8507cba0..0efb4180dbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
@@ -38,7 +38,7 @@ class BroadcastExchangeSuite extends SparkPlanTest
 
   import testImplicits._
 
-  test("BroadcastExchange should cancel the job group if timeout") {
+  test("BroadcastExchange should cancel the job tag if timeout") {
     val startLatch = new CountDownLatch(1)
     val endLatch = new CountDownLatch(1)
     var jobEvents: Seq[SparkListenerEvent] = Seq.empty[SparkListenerEvent]
@@ -82,7 +82,7 @@ class BroadcastExchangeSuite extends SparkPlanTest
       val events = jobEvents.toArray
       val hasStart = events(0).isInstanceOf[SparkListenerJobStart]
       val hasCancelled = events(1).asInstanceOf[SparkListenerJobEnd].jobResult
-        .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job group")
+        .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job tag")
       events.length == 2 && hasStart && hasCancelled
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org