You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2020/02/14 00:15:34 UTC
[spark] branch master updated: [SPARK-30667][CORE] Add allGather
method to BarrierTaskContext
This is an automated email from the ASF dual-hosted git repository.
meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 57254c9 [SPARK-30667][CORE] Add allGather method to BarrierTaskContext
57254c9 is described below
commit 57254c9719f9af9ad985596ed7fbbaafa4052002
Author: sarthfrey-db <sa...@databricks.com>
AuthorDate: Thu Feb 13 16:15:00 2020 -0800
[SPARK-30667][CORE] Add allGather method to BarrierTaskContext
### What changes were proposed in this pull request?
The `allGather` method is added to the `BarrierTaskContext`. This method provides the same functionality as the `BarrierTaskContext.barrier` method: it blocks the task until all tasks have made the call, at which point they may continue execution. In addition, the `allGather` method takes an input message. Upon returning from `allGather`, the task receives a list of all the messages sent by all the tasks that made the `allGather` call.
### Why are the changes needed?
There are many situations where having the tasks communicate in a synchronized way is useful. One simple example is when each task needs to start a server to serve requests from the other tasks: first each task must find a free port (which cannot be known in advance) and then start making requests, but to do so every task must know the ports chosen by the other tasks. An `allGather` method would allow the tasks to inform each other of the port each one will run on.
### Does this PR introduce any user-facing change?
Yes, a `BarrierTaskContext.allGather` method will be available through the Scala, Java, and Python APIs.
### How was this patch tested?
Most of the code path is already covered by the tests for the `barrier` method, since this PR includes a refactor so that much of the code is shared by the `barrier` and `allGather` methods. In addition, a test is added to assert that an all-gather of each task's partition ID returns a list of every partition ID.
An example through the Python API:
```python
>>> from pyspark import BarrierTaskContext
>>>
>>> def f(iterator):
... context = BarrierTaskContext.get()
... return [context.allGather('{}'.format(context.partitionId()))]
...
>>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0]
[u'3', u'1', u'0', u'2']
```
Closes #27395 from sarthfrey/master.
Lead-authored-by: sarthfrey-db <sa...@databricks.com>
Co-authored-by: sarthfrey <sa...@gmail.com>
Signed-off-by: Xiangrui Meng <me...@databricks.com>
---
.../org/apache/spark/BarrierCoordinator.scala | 113 +++++++++++++--
.../org/apache/spark/BarrierTaskContext.scala | 153 ++++++++++++++-------
.../org/apache/spark/api/python/PythonRunner.scala | 51 +++++--
.../spark/scheduler/BarrierTaskContextSuite.scala | 74 ++++++++++
python/pyspark/taskcontext.py | 49 ++++++-
python/pyspark/tests/test_taskcontext.py | 20 +++
6 files changed, 381 insertions(+), 79 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 4e41767..042a266 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,12 +17,17 @@
package org.apache.spark
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer
import scala.collection.mutable.ArrayBuffer
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
@@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator(
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0
- // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
- // call.
+ // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)
+ // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
+ private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]
+
+ // The blocking requestMethod called by tasks to sync up for this stage attempt
+ private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER
+
// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null
@@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator(
// Process the global sync request. The barrier() call succeed if collected enough requests
// within a configured time, otherwise fail all the pending requests.
- def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
+ def handleRequest(
+ requester: RpcCallContext,
+ request: RequestToSync
+ ): Unit = synchronized {
val taskId = request.taskAttemptId
val epoch = request.barrierEpoch
+ val requestMethod = request.requestMethod
+ val partitionId = request.partitionId
+ val allGatherMessage = request match {
+ case ag: AllGatherRequestToSync => ag.allGatherMessage
+ case _ => ""
+ }
+
+ if (requesters.size == 0) {
+ requestMethodToSync = requestMethod
+ }
+
+ if (requestMethodToSync != requestMethod) {
+ requesters.foreach(
+ _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
+ s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
+ s"the current synchronized requestMethod `$requestMethodToSync`"
+ ))
+ )
+ cleanupBarrierStage(barrierId)
+ }
// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
@@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator(
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
+ allGatherMessages(partitionId) = allGatherMessage
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
@@ -162,6 +196,7 @@ private[spark] class BarrierCoordinator(
s"tasks, finished successfully.")
barrierEpoch += 1
requesters.clear()
+ allGatherMessages.clear()
cancelTimerTask()
}
}
@@ -173,7 +208,13 @@ private[spark] class BarrierCoordinator(
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
- requesters.foreach(_.reply(()))
+ requestMethodToSync match {
+ case RequestMethod.BARRIER =>
+ requesters.foreach(_.reply(""))
+ case RequestMethod.ALL_GATHER =>
+ val json: String = compact(render(allGatherMessages))
+ requesters.foreach(_.reply(json))
+ }
true
} else {
false
@@ -186,6 +227,7 @@ private[spark] class BarrierCoordinator(
// messages come from current stage attempt shall fail.
barrierEpoch = -1
requesters.clear()
+ allGatherMessages.clear()
cancelTimerTask()
}
}
@@ -199,11 +241,11 @@ private[spark] class BarrierCoordinator(
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
+ case request: RequestToSync =>
// Get or init the ContextBarrierState correspond to the stage attempt.
- val barrierId = ContextBarrierId(stageId, stageAttemptId)
+ val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
states.computeIfAbsent(barrierId,
- (key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
+ (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
val barrierState = states.get(barrierId)
barrierState.handleRequest(context, request)
@@ -216,6 +258,16 @@ private[spark] class BarrierCoordinator(
private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
+private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
+ def numTasks: Int
+ def stageId: Int
+ def stageAttemptId: Int
+ def taskAttemptId: Long
+ def barrierEpoch: Int
+ def partitionId: Int
+ def requestMethod: RequestMethod.Value
+}
+
/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
@@ -224,11 +276,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
+ * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
+ * @param partitionId ID of the current partition the task is assigned to
+ * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
*/
-private[spark] case class RequestToSync(
- numTasks: Int,
- stageId: Int,
- stageAttemptId: Int,
- taskAttemptId: Long,
- barrierEpoch: Int) extends BarrierCoordinatorMessage
+private[spark] case class BarrierRequestToSync(
+ numTasks: Int,
+ stageId: Int,
+ stageAttemptId: Int,
+ taskAttemptId: Long,
+ barrierEpoch: Int,
+ partitionId: Int,
+ requestMethod: RequestMethod.Value
+) extends RequestToSync
+
+/**
+ * A global sync request message sent by a BarrierTaskContext `allGather()` call. Each request is
+ * identified by stageId + stageAttemptId + barrierEpoch.
+ *
+ * @param numTasks The number of global sync requests the BarrierCoordinator shall receive
+ * @param stageId ID of current stage
+ * @param stageAttemptId ID of current stage attempt
+ * @param taskAttemptId Unique ID of current task
+ * @param barrierEpoch ID of the sync call; a task may make multiple `barrier()`/`allGather()` calls
+ * @param partitionId ID of the current partition the task is assigned to
+ * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
+ * @param allGatherMessage Message this task contributes; stored by the coordinator at partitionId
+ */
+private[spark] case class AllGatherRequestToSync(
+ numTasks: Int,
+ stageId: Int,
+ stageAttemptId: Int,
+ taskAttemptId: Long,
+ barrierEpoch: Int,
+ partitionId: Int,
+ requestMethod: RequestMethod.Value,
+ allGatherMessage: String
+) extends RequestToSync
+
+private[spark] object RequestMethod extends Enumeration {
+ val BARRIER, ALL_GATHER = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 3d36980..2263538 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,11 +17,19 @@
package org.apache.spark
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.parse
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
@@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] (
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size
- /**
- * :: Experimental ::
- * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
- * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
- * stage have reached this routine.
- *
- * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
- * possible code branches. Otherwise, you may get the job hanging or a SparkException after
- * timeout. Some examples of '''misuses''' are listed below:
- * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
- * shall lead to timeout of the function call.
- * {{{
- * rdd.barrier().mapPartitions { iter =>
- * val context = BarrierTaskContext.get()
- * if (context.partitionId() == 0) {
- * // Do nothing.
- * } else {
- * context.barrier()
- * }
- * iter
- * }
- * }}}
- *
- * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
- * second function call.
- * {{{
- * rdd.barrier().mapPartitions { iter =>
- * val context = BarrierTaskContext.get()
- * try {
- * // Do something that might throw an Exception.
- * doSomething()
- * context.barrier()
- * } catch {
- * case e: Exception => logWarning("...", e)
- * }
- * context.barrier()
- * iter
- * }
- * }}}
- */
- @Experimental
- @Since("2.4.0")
- def barrier(): Unit = {
+ private def getRequestToSync(
+ numTasks: Int,
+ stageId: Int,
+ stageAttemptNumber: Int,
+ taskAttemptId: Long,
+ barrierEpoch: Int,
+ partitionId: Int,
+ requestMethod: RequestMethod.Value,
+ allGatherMessage: String
+ ): RequestToSync = {
+ requestMethod match {
+ case RequestMethod.BARRIER =>
+ BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
+ barrierEpoch, partitionId, requestMethod)
+ case RequestMethod.ALL_GATHER =>
+ AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
+ barrierEpoch, partitionId, requestMethod, allGatherMessage)
+ }
+ }
+
+ private def runBarrier(
+ requestMethod: RequestMethod.Value,
+ allGatherMessage: String = ""
+ ): String = {
+
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
@@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] (
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)
+ var json: String = ""
+
try {
- val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
- message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
- barrierEpoch),
+ val abortableRpcFuture = barrierCoordinator.askAbortable[String](
+ message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
+ taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
@@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] (
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
- ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
+ json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
@@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] (
timerTask.cancel()
timer.purge()
}
+ json
+ }
+
+ /**
+ * :: Experimental ::
+ * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
+ * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
+ * stage have reached this routine.
+ *
+ * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
+ * possible code branches. Otherwise, you may get the job hanging or a SparkException after
+ * timeout. Some examples of '''misuses''' are listed below:
+ * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
+ * shall lead to timeout of the function call.
+ * {{{
+ * rdd.barrier().mapPartitions { iter =>
+ * val context = BarrierTaskContext.get()
+ * if (context.partitionId() == 0) {
+ * // Do nothing.
+ * } else {
+ * context.barrier()
+ * }
+ * iter
+ * }
+ * }}}
+ *
+ * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
+ * second function call.
+ * {{{
+ * rdd.barrier().mapPartitions { iter =>
+ * val context = BarrierTaskContext.get()
+ * try {
+ * // Do something that might throw an Exception.
+ * doSomething()
+ * context.barrier()
+ * } catch {
+ * case e: Exception => logWarning("...", e)
+ * }
+ * context.barrier()
+ * iter
+ * }
+ * }}}
+ */
+ @Experimental
+ @Since("2.4.0")
+ def barrier(): Unit = {
+ runBarrier(RequestMethod.BARRIER)
+ ()
+ }
+
+ /**
+ * :: Experimental ::
+ * Blocks until all tasks in the same stage have reached this routine. Each task passes in
+ * a message and returns with a list of all the messages passed in by each of those tasks.
+ *
+ * CAUTION! The allGather method requires the same precautions as the barrier method
+ *
+ * The message is type String rather than Array[Byte] because it is more convenient for
+ * the user at the cost of worse performance.
+ */
+ @Experimental
+ @Since("3.0.0")
+ def allGather(message: String): ArrayBuffer[String] = {
+ val json = runBarrier(RequestMethod.ALL_GATHER, message)
+ val jsonArray = parse(json)
+ implicit val formats = DefaultFormats
+ ArrayBuffer(jsonArray.extract[Array[String]]: _*)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 658e0d5..fa8bf0f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -24,8 +24,13 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@@ -238,13 +243,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
sock.setSoTimeout(10000)
authHelper.authClient(sock)
val input = new DataInputStream(sock.getInputStream())
- input.readInt() match {
+ val requestMethod = input.readInt()
+ // The BarrierTaskContext function may wait infinitely, socket shall not timeout
+ // before the function finishes.
+ sock.setSoTimeout(0)
+ requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
- // The barrier() function may wait infinitely, socket shall not timeout
- // before the function finishes.
- sock.setSoTimeout(0)
- barrierAndServe(sock)
-
+ barrierAndServe(requestMethod, sock)
+ case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+ val length = input.readInt()
+ val message = new Array[Byte](length)
+ input.readFully(message)
+ barrierAndServe(requestMethod, sock, new String(message, UTF_8))
case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
sock.getOutputStream))
@@ -395,15 +405,31 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
/**
- * Gateway to call BarrierTaskContext.barrier().
+ * Gateway to call BarrierTaskContext methods.
*/
- def barrierAndServe(sock: Socket): Unit = {
- require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
-
+ def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
+ require(
+ serverSocket.isDefined,
+ "No available ServerSocket to redirect the BarrierTaskContext method call."
+ )
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
- context.asInstanceOf[BarrierTaskContext].barrier()
- writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
+ var result: String = ""
+ requestMethod match {
+ case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+ context.asInstanceOf[BarrierTaskContext].barrier()
+ result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
+ case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+ val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather(
+ message
+ )
+ result = compact(render(JArray(
+ messages.map(
+ (message) => JString(message)
+ ).toList
+ )))
+ }
+ writeUTF(result, out)
} catch {
case e: SparkException =>
writeUTF(e.getMessage, out)
@@ -638,6 +664,7 @@ private[spark] object SpecialLengths {
private[spark] object BarrierTaskContextMessageProtocol {
val BARRIER_FUNCTION = 1
+ val ALL_GATHER_FUNCTION = 2
val BARRIER_RESULT_SUCCESS = "success"
val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index fc8ac38..ed38b7f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.File
+import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark._
@@ -52,6 +53,79 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
assert(times.max - times.min <= 1000)
}
+ test("share messages with allGather() call") {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[4, 1, 1024]")
+ .setAppName("test-cluster")
+ sc = new SparkContext(conf)
+ val rdd = sc.makeRDD(1 to 10, 4)
+ val rdd2 = rdd.barrier().mapPartitions { it =>
+ val context = BarrierTaskContext.get()
+ // Sleep for a random time before global sync.
+ Thread.sleep(Random.nextInt(1000))
+ // Pass partitionId message in
+ val message = context.partitionId().toString
+ val messages = context.allGather(message)
+ messages.toList.iterator
+ }
+ // Take a sorted list of all the partitionId messages
+ val messages = rdd2.collect().head
+ // All the task partitionIds are shared
+ for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString)
+ }
+
+ test("throw exception if we attempt to synchronize with different blocking calls") {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[4, 1, 1024]")
+ .setAppName("test-cluster")
+ sc = new SparkContext(conf)
+ val rdd = sc.makeRDD(1 to 10, 4)
+ val rdd2 = rdd.barrier().mapPartitions { it =>
+ val context = BarrierTaskContext.get()
+ val partitionId = context.partitionId
+ if (partitionId == 0) {
+ context.barrier()
+ } else {
+ context.allGather(partitionId.toString)
+ }
+ Seq(null).iterator
+ }
+ val error = intercept[SparkException] {
+ rdd2.collect()
+ }.getMessage
+ assert(error.contains("does not match the current synchronized requestMethod"))
+ }
+
+ test("successively sync with allGather and barrier") {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[4, 1, 1024]")
+ .setAppName("test-cluster")
+ sc = new SparkContext(conf)
+ val rdd = sc.makeRDD(1 to 10, 4)
+ val rdd2 = rdd.barrier().mapPartitions { it =>
+ val context = BarrierTaskContext.get()
+ // Sleep for a random time before global sync.
+ Thread.sleep(Random.nextInt(1000))
+ context.barrier()
+ val time1 = System.currentTimeMillis()
+ // Sleep for a random time before global sync.
+ Thread.sleep(Random.nextInt(1000))
+ // Pass partitionId message in
+ val message = context.partitionId().toString
+ val messages = context.allGather(message)
+ val time2 = System.currentTimeMillis()
+ Seq((time1, time2)).iterator
+ }
+ val times = rdd2.collect()
+ // All the tasks shall finish the first round of global sync within a short time slot.
+ val times1 = times.map(_._1)
+ assert(times1.max - times1.min <= 1000)
+
+ // All the tasks shall finish the second round of global sync within a short time slot.
+ val times2 = times.map(_._2)
+ assert(times2.max - times2.min <= 1000)
+ }
+
test("support multiple barrier() call within a single task") {
initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index d648f63..90bd234 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,9 +16,10 @@
#
from __future__ import print_function
+import json
from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import write_int, UTF8Deserializer
+from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
class TaskContext(object):
@@ -107,18 +108,28 @@ class TaskContext(object):
BARRIER_FUNCTION = 1
+ALL_GATHER_FUNCTION = 2
-def _load_from_socket(port, auth_secret):
+def _load_from_socket(port, auth_secret, function, all_gather_message=None):
"""
Load data from a given socket, this is a blocking method thus only return when the socket
connection has been closed.
"""
(sockfile, sock) = local_connect_and_auth(port, auth_secret)
- # The barrier() call may block forever, so no timeout
+
+ # The call may block forever, so no timeout
sock.settimeout(None)
- # Make a barrier() function call.
- write_int(BARRIER_FUNCTION, sockfile)
+
+ if function == BARRIER_FUNCTION:
+ # Make a barrier() function call.
+ write_int(function, sockfile)
+ elif function == ALL_GATHER_FUNCTION:
+ # Make a all_gather() function call.
+ write_int(function, sockfile)
+ write_with_length(all_gather_message.encode("utf-8"), sockfile)
+ else:
+ raise ValueError("Unrecognized function type")
sockfile.flush()
# Collect result.
@@ -199,7 +210,33 @@ class BarrierTaskContext(TaskContext):
raise Exception("Not supported to call barrier() before initialize " +
"BarrierTaskContext.")
else:
- _load_from_socket(self._port, self._secret)
+ _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
+
+ def allGather(self, message=""):
+ """
+ .. note:: Experimental
+
+ Blocks until all tasks in the same stage have reached this routine. Each task passes in
+ a `str` message (anything else raises ValueError) and returns with a list of all the
+ messages passed in by each of those tasks.
+
+ .. warning:: In a barrier stage, each task must have the same number of `allGather()`
+ calls, in all possible code branches.
+ Otherwise, you may get the job hanging or a SparkException after timeout.
+ """
+ if not isinstance(message, str):
+ raise ValueError("Argument `message` must be of type `str`")
+ elif self._port is None or self._secret is None:
+ raise Exception("Not supported to call barrier() before initialize " +
+ "BarrierTaskContext.")
+ else:
+ gathered_items = _load_from_socket(
+ self._port,
+ self._secret,
+ ALL_GATHER_FUNCTION,
+ message,
+ )
+ return [e for e in json.loads(gathered_items)]
def getTaskInfos(self):
"""
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index 68cfe81..0053aad 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -135,6 +135,26 @@ class TaskContextTests(PySparkTestCase):
times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
self.assertTrue(max(times) - min(times) < 1)
+ def test_all_gather(self):
+ """
+ Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
+ within a stage and passes messages properly.
+ """
+ rdd = self.sc.parallelize(range(10), 4)
+
+ def f(iterator):
+ yield sum(iterator)
+
+ def context_barrier(x):
+ tc = BarrierTaskContext.get()
+ time.sleep(random.randint(1, 10))
+ out = tc.allGather(str(context.partitionId()))
+ pids = [int(e) for e in out]
+ return [pids]
+
+ pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
+ self.assertTrue(pids == [0, 1, 2, 3])
+
def test_barrier_infos(self):
"""
Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org