You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by mr...@apache.org on 2021/08/05 02:14:27 UTC

[spark] branch master updated: [SPARK-36173][CORE] Support getting CPU number in TaskContext

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6e6d11  [SPARK-36173][CORE] Support getting CPU number in TaskContext
f6e6d11 is described below

commit f6e6d1157ac988d7c5809fcb08b577631bdea8eb
Author: Wu, Xiaochang <xi...@intel.com>
AuthorDate: Wed Aug 4 21:14:01 2021 -0500

    [SPARK-36173][CORE] Support getting CPU number in TaskContext
    
    In stage-level resource scheduling, the allocated 3rd-party resources can be obtained in TaskContext through the resources() interface; however, there is no API to get how many CPUs are allocated to the task. This PR adds a cpus() interface to TaskContext to complement resources(). Although the task CPU request can be obtained from the resource profile, it is more convenient to get it directly inside the task code, without having to pass the profile from the driver side to the executor side.
    
    ### What changes were proposed in this pull request?
    Add a cpus() interface to TaskContext and update the relevant code.
    
    ### Why are the changes needed?
    TaskContext has resources() to get the allocated 3rd-party resources, but there is no API to get the number of CPUs allocated to the task.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds a cpus() interface to TaskContext (and to the PySpark TaskContext).
    
    ### How was this patch tested?
    Unit tests
    
    Closes #33385 from xwu99/taskcontext-cpus.
    
    Lead-authored-by: Wu, Xiaochang <xi...@intel.com>
    Co-authored-by: Xiaochang Wu <xi...@intel.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../main/scala/org/apache/spark/BarrierTaskContext.scala   |  2 ++
 core/src/main/scala/org/apache/spark/TaskContext.scala     | 11 +++++++++--
 core/src/main/scala/org/apache/spark/TaskContextImpl.scala |  3 ++-
 .../scala/org/apache/spark/api/python/PythonRunner.scala   |  1 +
 .../main/scala/org/apache/spark/executor/Executor.scala    |  1 +
 core/src/main/scala/org/apache/spark/scheduler/Task.scala  |  5 +++++
 .../scala/org/apache/spark/scheduler/TaskDescription.scala | 11 ++++++++++-
 .../org/apache/spark/scheduler/TaskSchedulerImpl.scala     |  3 ++-
 .../scala/org/apache/spark/scheduler/TaskSetManager.scala  |  6 ++++++
 .../test/scala/org/apache/spark/SparkContextSuite.scala    | 14 +++++++++++---
 .../spark/executor/CoarseGrainedExecutorBackendSuite.scala |  2 +-
 .../scala/org/apache/spark/executor/ExecutorSuite.scala    |  1 +
 .../scheduler/CoarseGrainedSchedulerBackendSuite.scala     |  2 +-
 .../org/apache/spark/scheduler/TaskContextSuite.scala      |  4 ++--
 .../org/apache/spark/scheduler/TaskDescriptionSuite.scala  |  2 ++
 .../org/apache/spark/scheduler/TaskSetManagerSuite.scala   |  4 +++-
 .../org/apache/spark/storage/BlockInfoManagerSuite.scala   |  4 +++-
 project/MimaExcludes.scala                                 |  3 +++
 python/pyspark/taskcontext.py                              |  7 +++++++
 python/pyspark/tests/test_taskcontext.py                   |  9 ++++++++-
 python/pyspark/worker.py                                   |  1 +
 .../mesos/MesosFineGrainedSchedulerBackendSuite.scala      |  2 ++
 22 files changed, 83 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 09fa916..aa63e61 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -227,6 +227,8 @@ class BarrierTaskContext private[spark] (
     taskContext.getMetricsSources(sourceName)
   }
 
+  override def cpus(): Int = taskContext.cpus()
+
   override def resources(): Map[String, ResourceInformation] = taskContext.resources()
 
   override def resourcesJMap(): java.util.Map[String, ResourceInformation] = {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index fd41fac..fd115fd 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io.Serializable
 import java.util.Properties
 
-import org.apache.spark.annotation.{DeveloperApi, Evolving}
+import org.apache.spark.annotation.{DeveloperApi, Evolving, Since}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
@@ -67,7 +67,8 @@ object TaskContext {
    * An empty task context that does not represent an actual task.  This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
+    new TaskContextImpl(0, 0, 0, 0, 0,
+      null, new Properties, null, TaskMetrics.empty, 1)
   }
 }
 
@@ -178,6 +179,12 @@ abstract class TaskContext extends Serializable {
   def getLocalProperty(key: String): String
 
   /**
+   * CPUs allocated to the task.
+   */
+  @Since("3.3.0")
+  def cpus(): Int
+
+  /**
    * Resources allocated to the task. The key is the resource name and the value is information
    * about the resource. Please refer to [[org.apache.spark.resource.ResourceInformation]] for
    * specifics.
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index db4b74b..7d909a5 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
@@ -54,6 +54,7 @@ private[spark] class TaskContextImpl(
     @transient private val metricsSystem: MetricsSystem,
     // The default value is only used in tests.
     override val taskMetrics: TaskMetrics = TaskMetrics.empty,
+    override val cpus: Int = SparkEnv.get.conf.get(config.CPUS_PER_TASK),
     override val resources: Map[String, ResourceInformation] = Map.empty)
   extends TaskContext
   with Logging {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index db0e100..c50a8b9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -341,6 +341,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
+        dataOut.writeInt(context.cpus())
         val resources = context.resources()
         dataOut.writeInt(resources.size)
         resources.foreach { case (k, v) =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5fb1d7f..a505160 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -502,6 +502,7 @@ private[spark] class Executor(
             taskAttemptId = taskId,
             attemptNumber = taskDescription.attemptNumber,
             metricsSystem = env.metricsSystem,
+            cpus = taskDescription.cpus,
             resources = taskDescription.resources,
             plugins = plugins)
           threwException = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 81f984b..3ef8361 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -83,8 +83,12 @@ private[spark] abstract class Task[T](
       taskAttemptId: Long,
       attemptNumber: Int,
       metricsSystem: MetricsSystem,
+      cpus: Int,
       resources: Map[String, ResourceInformation],
       plugins: Option[PluginContainer]): T = {
+
+    require(cpus > 0, "CPUs per task should be > 0")
+
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
@@ -98,6 +102,7 @@ private[spark] abstract class Task[T](
       localProperties,
       metricsSystem,
       metrics,
+      cpus,
       resources)
 
     context = if (isBarrier) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 12b911d..8813851 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -57,9 +57,12 @@ private[spark] class TaskDescription(
     val addedJars: Map[String, Long],
     val addedArchives: Map[String, Long],
     val properties: Properties,
+    val cpus: Int,
     val resources: immutable.Map[String, ResourceInformation],
     val serializedTask: ByteBuffer) {
 
+  assert(cpus > 0, "CPUs per task should be > 0")
+
   override def toString: String = s"TaskDescription($name)"
 }
 
@@ -113,6 +116,9 @@ private[spark] object TaskDescription {
       dataOut.write(bytes)
     }
 
+    // Write cpus.
+    dataOut.writeInt(taskDescription.cpus)
+
     // Write resources.
     serializeResources(taskDescription.resources, dataOut)
 
@@ -185,6 +191,9 @@ private[spark] object TaskDescription {
       properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8))
     }
 
+    // Read cpus.
+    val cpus = dataIn.readInt()
+
     // Read resources.
     val resources = deserializeResources(dataIn)
 
@@ -192,6 +201,6 @@ private[spark] object TaskDescription {
     val serializedTask = byteBuffer.slice()
 
     new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles,
-      taskJars, taskArchives, properties, resources, serializedTask)
+      taskJars, taskArchives, properties, cpus, resources, serializedTask)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index b48ba12..7af47da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -392,7 +392,7 @@ private[spark] class TaskSchedulerImpl(
             val prof = sc.resourceProfileManager.resourceProfileFromId(taskSetRpID)
             val taskCpus = ResourceProfile.getTaskCpusOrDefaultForProfile(prof, conf)
             val (taskDescOption, didReject, index) =
-              taskSet.resourceOffer(execId, host, maxLocality, taskResAssignments)
+              taskSet.resourceOffer(execId, host, maxLocality, taskCpus, taskResAssignments)
             noDelayScheduleRejects &= !didReject
             for (task <- taskDescOption) {
               val (locality, resources) = if (task != null) {
@@ -714,6 +714,7 @@ private[spark] class TaskSchedulerImpl(
                 task.index,
                 task.taskLocality,
                 false,
+                task.assignedCores,
                 task.assignedResources,
                 launchTime)
               addRunningTask(taskDesc.taskId, taskDesc.executorId, taskSet)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index e2a0375..b7bd7bb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -418,6 +418,8 @@ private[spark] class TaskSetManager(
    * @param execId the executor Id of the offered resource
    * @param host  the host Id of the offered resource
    * @param maxLocality the maximum locality we want to schedule the tasks at
+   * @param taskCpus the number of CPUs for the task
+   * @param taskResourceAssignments the resource assignments for the task
    *
    * @return Triple containing:
    *         (TaskDescription of launched task if any,
@@ -429,6 +431,7 @@ private[spark] class TaskSetManager(
       execId: String,
       host: String,
       maxLocality: TaskLocality.TaskLocality,
+      taskCpus: Int = sched.CPUS_PER_TASK,
       taskResourceAssignments: Map[String, ResourceInformation] = Map.empty)
     : (Option[TaskDescription], Boolean, Int) =
   {
@@ -474,6 +477,7 @@ private[spark] class TaskSetManager(
                 index,
                 taskLocality,
                 speculative,
+                taskCpus,
                 taskResourceAssignments,
                 curTime)
             }
@@ -495,6 +499,7 @@ private[spark] class TaskSetManager(
       index: Int,
       taskLocality: TaskLocality.Value,
       speculative: Boolean,
+      taskCpus: Int,
       taskResourceAssignments: Map[String, ResourceInformation],
       launchTime: Long): TaskDescription = {
     // Found a task; do some bookkeeping and return a task description
@@ -548,6 +553,7 @@ private[spark] class TaskSetManager(
       addedJars,
       addedArchives,
       task.localProperties,
+      taskCpus,
       taskResourceAssignments,
       serializedTask)
   }
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 93677d3..fcac90b 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -990,8 +990,9 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
         """{"name": "gpu","addresses":["0", "1", "2"]}""")
 
       val conf = new SparkConf()
-        .setMaster("local-cluster[3, 1, 1024]")
+        .setMaster("local-cluster[3, 2, 1024]")
         .setAppName("test-cluster")
+        .set(CPUS_PER_TASK, 2)
         .set(WORKER_GPU_ID.amountConf, "3")
         .set(WORKER_GPU_ID.discoveryScriptConf, discoveryScript)
         .set(TASK_GPU_ID.amountConf, "3")
@@ -1002,11 +1003,18 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
       // Ensure all executors has started
       TestUtils.waitUntilExecutorsUp(sc, 3, 60000)
 
-      val rdd = sc.makeRDD(1 to 10, 3).mapPartitions { it =>
+      val rdd1 = sc.makeRDD(1 to 10, 3).mapPartitions { it =>
+        val context = TaskContext.get()
+        Iterator(context.cpus())
+      }
+      val cpus = rdd1.collect()
+      assert(cpus === Array(2, 2, 2))
+
+      val rdd2 = sc.makeRDD(1 to 10, 3).mapPartitions { it =>
         val context = TaskContext.get()
         context.resources().get(GPU).get.addresses.iterator
       }
-      val gpus = rdd.collect()
+      val gpus = rdd2.collect()
       assert(gpus.sorted === Seq("0", "0", "0", "1", "1", "1", "2", "2", "2"))
 
       eventually(timeout(10.seconds)) {
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 24182e4..9bbfdc7 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -301,7 +301,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
       // We don't really verify the data, just pass it around.
       val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
       val taskDescription = new TaskDescription(taskId, 2, "1", "TASK 1000000", 19,
-        1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties,
+        1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
         Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
       val serializedTaskDescription = TaskDescription.encode(taskDescription)
       backend.executor = mock[Executor]
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 8ec279a..3938ce3 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -570,6 +570,7 @@ class ExecutorSuite extends SparkFunSuite
       addedJars = Map[String, Long](),
       addedArchives = Map[String, Long](),
       properties = new Properties,
+      cpus = 1,
       resources = immutable.Map[String, ResourceInformation](),
       serializedTask)
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
index 3ce4ccf..4663717 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
@@ -246,7 +246,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo
     val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1",
       "t1", 0, 1, mutable.Map.empty[String, Long],
       mutable.Map.empty[String, Long], mutable.Map.empty[String, Long],
-      new Properties(), taskResources, bytebuffer)))
+      new Properties(), 1, taskResources, bytebuffer)))
     val ts = backend.getTaskSchedulerImpl()
     when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]], any[Boolean])).thenReturn(taskDescs)
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 8a7ff9e..2200b5b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null, Option.empty)
+      task.run(0, 0, null, 1, null, Option.empty)
     }
     assert(TaskContextSuite.completed)
   }
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null, Option.empty)
+      task.run(0, 0, null, 1, null, Option.empty)
     }
     assert(TaskContextSuite.lastError.getMessage == "damn error")
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
index 98b5bad..25d7ab8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
@@ -76,6 +76,7 @@ class TaskDescriptionSuite extends SparkFunSuite {
       originalJars,
       originalArchives,
       originalProperties,
+      cpus = 2,
       originalResources,
       taskBuffer
     )
@@ -94,6 +95,7 @@ class TaskDescriptionSuite extends SparkFunSuite {
     assert(decodedTaskDescription.addedJars.equals(originalJars))
     assert(decodedTaskDescription.addedArchives.equals(originalArchives))
     assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties))
+    assert(decodedTaskDescription.cpus.equals(originalTaskDescription.cpus))
     assert(equalResources(decodedTaskDescription.resources, originalTaskDescription.resources))
     assert(decodedTaskDescription.serializedTask.equals(taskBuffer))
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c7fc251..c72c180 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1777,9 +1777,11 @@ class TaskSetManagerSuite
 
     val taskResourceAssignments = Map(GPU -> new ResourceInformation(GPU, Array("0", "1")))
     val taskOption =
-      manager.resourceOffer("exec1", "host1", NO_PREF, taskResourceAssignments)._1
+      manager.resourceOffer("exec1", "host1", NO_PREF, 2, taskResourceAssignments)._1
     assert(taskOption.isDefined)
+    val allocatedCpus = taskOption.get.cpus
     val allocatedResources = taskOption.get.resources
+    assert(allocatedCpus == 2)
     assert(allocatedResources.size == 1)
     assert(allocatedResources(GPU).addresses sameElements Array("0", "1"))
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index d2bf385..c415652 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.BeforeAndAfterEach
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkException, SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.util.ThreadUtils
 
 
@@ -62,7 +63,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
     try {
       TaskContext.setTaskContext(
-        new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
+        new TaskContextImpl(0, 0, 0, taskAttemptId, 0,
+          null, new Properties, null, TaskMetrics.empty, 1))
       block
     } finally {
       TaskContext.unset()
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0469472..3427c4a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -61,6 +61,9 @@ object MimaExcludes {
     // in the REST API call for a specified stage
     ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"),
 
+    // [SPARK-36173][CORE] Support getting CPU number in TaskContext
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.cpus"),
+
     // [SPARK-35896] Include more granular metrics for stateful operators in StreamingQueryProgress
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"),
 
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index ac0eb3b..aa736d1 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -33,6 +33,7 @@ class TaskContext(object):
     _stageId = None
     _taskAttemptId = None
     _localProperties = None
+    _cpus = None
     _resources = None
 
     def __new__(cls):
@@ -97,6 +98,12 @@ class TaskContext(object):
         """
         return self._localProperties.get(key, None)
 
+    def cpus(self):
+        """
+        CPUs allocated to the task.
+        """
+        return self._cpus
+
     def resources(self):
         """
         Resources allocated to the task. The key is the resource name and the value is information
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index 8b39837..a526673 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -309,9 +309,16 @@ class TaskContextTestsWithResources(unittest.TestCase):
         conf = SparkConf().set("spark.test.home", SPARK_HOME)
         conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
         conf = conf.set("spark.worker.resource.gpu.amount", 1)
+        conf = conf.set("spark.task.cpus", 2)
         conf = conf.set("spark.task.resource.gpu.amount", "1")
         conf = conf.set("spark.executor.resource.gpu.amount", "1")
-        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
+        self.sc = SparkContext('local-cluster[2,2,1024]', class_name, conf=conf)
+
+    def test_cpus(self):
+        """Test the cpus are available."""
+        rdd = self.sc.parallelize(range(10))
+        cpus = rdd.map(lambda x: TaskContext.get().cpus()).take(1)[0]
+        self.assertEqual(cpus, 2)
 
     def test_resources(self):
         """Test the resources are available."""
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index a13717f..ad6c003 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -532,6 +532,7 @@ def main(infile, outfile):
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)
         taskContext._taskAttemptId = read_long(infile)
+        taskContext._cpus = read_int(infile)
         taskContext._resources = {}
         for r in range(read_int(infile)):
             key = utf8_deserializer.loads(infile)
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
index 10030a2..fa4e800 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
@@ -266,6 +266,7 @@ class MesosFineGrainedSchedulerBackendSuite
       addedJars = mutable.Map.empty[String, Long],
       addedArchives = mutable.Map.empty[String, Long],
       properties = new Properties(),
+      cpus = 1,
       resources = immutable.Map.empty[String, ResourceInformation],
       ByteBuffer.wrap(new Array[Byte](0)))
     when(taskScheduler.resourceOffers(
@@ -380,6 +381,7 @@ class MesosFineGrainedSchedulerBackendSuite
       addedJars = mutable.Map.empty[String, Long],
       addedArchives = mutable.Map.empty[String, Long],
       properties = new Properties(),
+      cpus = 1,
       resources = immutable.Map.empty[String, ResourceInformation],
       ByteBuffer.wrap(new Array[Byte](0)))
     when(taskScheduler.resourceOffers(

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org