This is an automated email from the ASF dual-hosted git repository.

meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new af63971  [SPARK-30667][CORE] Add allGather method to BarrierTaskContext
af63971 is described below

commit af63971cb7a5e7c7cb23ff1f87e5838d54c59a7d
Author: sarthfrey-db <sarth.f...@databricks.com>
AuthorDate: Thu Feb 13 16:15:00 2020 -0800

    [SPARK-30667][CORE] Add allGather method to BarrierTaskContext
    
    ### What changes were proposed in this pull request?
    
    The `allGather` method is added to `BarrierTaskContext`. This method provides
    the same functionality as the `BarrierTaskContext.barrier` method: it blocks
    the task until all tasks have made the call, at which time they may continue
    execution. In addition, the `allGather` method takes an input message. Upon
    returning from `allGather`, the task receives a list of all the messages sent
    by all the tasks that made the `allGather` call.
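    
    To illustrate the intended usage through the Scala API (a sketch for
    illustration only, not code taken from this patch):
    ```scala
    rdd.barrier().mapPartitions { iter =>
      val context = BarrierTaskContext.get()
      // Blocks until every task in the stage has called allGather(), then
      // returns the messages contributed by all of those tasks.
      val messages = context.allGather(context.partitionId().toString)
      Iterator.single(messages.mkString(","))
    }
    ```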
    
    ### Why are the changes needed?
    
    There are many situations where having the tasks communicate in a
    synchronized way is useful. One simple example is when each task needs to
    start a server to serve requests from the others: first each task must find a
    free port (which cannot be determined beforehand), and before making requests
    each task must know the ports chosen by the other tasks. An `allGather` method
    would allow them to inform each other of the port they will run on, as
    sketched below.
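    
    For instance, each task could bind a server socket to an ephemeral port and
    share the resulting address with its peers (a hypothetical sketch; actually
    serving requests on the socket is omitted):
    ```scala
    rdd.barrier().mapPartitions { iter =>
      val context = BarrierTaskContext.get()
      // Bind to port 0 so the OS picks a free port for this task's server.
      val server = new java.net.ServerSocket(0)
      val hostPort =
        s"${java.net.InetAddress.getLocalHost.getHostName}:${server.getLocalPort}"
      // Block until every task has reported its address, then receive all of them.
      val addresses = context.allGather(hostPort)
      // Each task now knows every "host:port" and can connect to the others.
      addresses.iterator
    }
    ```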
    
    ### Does this PR introduce any user-facing change?
    
    Yes, a `BarrierTaskContext.allGather` method will be available through the
    Scala, Java, and Python APIs.
    
    ### How was this patch tested?
    
    Most of the code path is already covered by tests of the `barrier` method,
    since this PR includes a refactor so that much of the code is shared by the
    `barrier` and `allGather` methods. However, a test is added to assert that an
    `allGather` of each task's partition ID returns a list of every partition ID.
    
    An example through the Python API:
    ```python
    >>> from pyspark import BarrierTaskContext
    >>>
    >>> def f(iterator):
    ...     context = BarrierTaskContext.get()
    ...     return [context.allGather('{}'.format(context.partitionId()))]
    ...
    >>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0]
    [u'3', u'1', u'0', u'2']
    ```
    
    Closes #27395 from sarthfrey/master.
    
    Lead-authored-by: sarthfrey-db <sarth.f...@databricks.com>
    Co-authored-by: sarthfrey <sarth.f...@gmail.com>
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
    (cherry picked from commit 57254c9719f9af9ad985596ed7fbbaafa4052002)
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
---
 .../org/apache/spark/BarrierCoordinator.scala      | 113 +++++++++++++--
 .../org/apache/spark/BarrierTaskContext.scala      | 153 ++++++++++++++-------
 .../org/apache/spark/api/python/PythonRunner.scala |  51 +++++--
 .../spark/scheduler/BarrierTaskContextSuite.scala  |  74 ++++++++++
 python/pyspark/taskcontext.py                      |  49 ++++++-
 python/pyspark/tests/test_taskcontext.py           |  20 +++
 6 files changed, 381 insertions(+), 79 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala 
b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 4e41767..042a266 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,12 +17,17 @@
 
 package org.apache.spark
 
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.function.Consumer
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, 
SparkListenerStageCompleted}
@@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator(
     // reset when a barrier() call fails due to timeout.
     private var barrierEpoch: Int = 0
 
-    // An array of RPCCallContexts for barrier tasks that are waiting for 
reply of a barrier()
-    // call.
+    // An Array of RPCCallContexts for barrier tasks that have made a blocking 
runBarrier() call
     private val requesters: ArrayBuffer[RpcCallContext] = new 
ArrayBuffer[RpcCallContext](numTasks)
 
+    // An Array of allGather messages for barrier tasks that have made a 
blocking runBarrier() call
+    private val allGatherMessages: ArrayBuffer[String] = new 
Array[String](numTasks).to[ArrayBuffer]
+
+    // The blocking requestMethod called by tasks to sync up for this stage 
attempt
+    private var requestMethodToSync: RequestMethod.Value = 
RequestMethod.BARRIER
+
     // A timer task that ensures we may timeout for a barrier() call.
     private var timerTask: TimerTask = null
 
@@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator(
 
     // Process the global sync request. The barrier() call succeed if 
collected enough requests
     // within a configured time, otherwise fail all the pending requests.
-    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit 
= synchronized {
+    def handleRequest(
+      requester: RpcCallContext,
+      request: RequestToSync
+    ): Unit = synchronized {
       val taskId = request.taskAttemptId
       val epoch = request.barrierEpoch
+      val requestMethod = request.requestMethod
+      val partitionId = request.partitionId
+      val allGatherMessage = request match {
+        case ag: AllGatherRequestToSync => ag.allGatherMessage
+        case _ => ""
+      }
+
+      if (requesters.size == 0) {
+        requestMethodToSync = requestMethod
+      }
+
+      if (requestMethodToSync != requestMethod) {
+        requesters.foreach(
+          _.sendFailure(new SparkException(s"$barrierId tried to use 
requestMethod " +
+            s"`$requestMethod` during barrier epoch $barrierEpoch, which does 
not match " +
+            s"the current synchronized requestMethod `$requestMethodToSync`"
+          ))
+        )
+        cleanupBarrierStage(barrierId)
+      }
 
       // Require the number of tasks is correctly set from the 
BarrierTaskContext.
       require(request.numTasks == numTasks, s"Number of tasks of $barrierId is 
" +
@@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator(
         }
         // Add the requester to array of RPCCallContexts pending for reply.
         requesters += requester
+        allGatherMessages(partitionId) = allGatherMessage
         logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received 
update from Task " +
           s"$taskId, current progress: ${requesters.size}/$numTasks.")
         if (maybeFinishAllRequesters(requesters, numTasks)) {
@@ -162,6 +196,7 @@ private[spark] class BarrierCoordinator(
             s"tasks, finished successfully.")
           barrierEpoch += 1
           requesters.clear()
+          allGatherMessages.clear()
           cancelTimerTask()
         }
       }
@@ -173,7 +208,13 @@ private[spark] class BarrierCoordinator(
         requesters: ArrayBuffer[RpcCallContext],
         numTasks: Int): Boolean = {
       if (requesters.size == numTasks) {
-        requesters.foreach(_.reply(()))
+        requestMethodToSync match {
+          case RequestMethod.BARRIER =>
+            requesters.foreach(_.reply(""))
+          case RequestMethod.ALL_GATHER =>
+            val json: String = compact(render(allGatherMessages))
+            requesters.foreach(_.reply(json))
+        }
         true
       } else {
         false
@@ -186,6 +227,7 @@ private[spark] class BarrierCoordinator(
       // messages come from current stage attempt shall fail.
       barrierEpoch = -1
       requesters.clear()
+      allGatherMessages.clear()
       cancelTimerTask()
     }
   }
@@ -199,11 +241,11 @@ private[spark] class BarrierCoordinator(
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
-    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
+    case request: RequestToSync =>
       // Get or init the ContextBarrierState correspond to the stage attempt.
-      val barrierId = ContextBarrierId(stageId, stageAttemptId)
+      val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
       states.computeIfAbsent(barrierId,
-        (key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
+        (key: ContextBarrierId) => new ContextBarrierState(key, 
request.numTasks))
       val barrierState = states.get(barrierId)
 
       barrierState.handleRequest(context, request)
@@ -216,6 +258,16 @@ private[spark] class BarrierCoordinator(
 
 private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
 
+private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
+  def numTasks: Int
+  def stageId: Int
+  def stageAttemptId: Int
+  def taskAttemptId: Long
+  def barrierEpoch: Int
+  def partitionId: Int
+  def requestMethod: RequestMethod.Value
+}
+
 /**
  * A global sync request message from BarrierTaskContext, by `barrier()` call. 
Each request is
  * identified by stageId + stageAttemptId + barrierEpoch.
@@ -224,11 +276,44 @@ private[spark] sealed trait BarrierCoordinatorMessage 
extends Serializable
  * @param stageId ID of current stage
  * @param stageAttemptId ID of current stage attempt
  * @param taskAttemptId Unique ID of current task
- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple 
`barrier()` calls.
+ * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple 
`barrier()` calls
+ * @param partitionId ID of the current partition the task is assigned to
+ * @param requestMethod The BarrierTaskContext method that was called to 
trigger BarrierCoordinator
  */
-private[spark] case class RequestToSync(
-    numTasks: Int,
-    stageId: Int,
-    stageAttemptId: Int,
-    taskAttemptId: Long,
-    barrierEpoch: Int) extends BarrierCoordinatorMessage
+private[spark] case class BarrierRequestToSync(
+  numTasks: Int,
+  stageId: Int,
+  stageAttemptId: Int,
+  taskAttemptId: Long,
+  barrierEpoch: Int,
+  partitionId: Int,
+  requestMethod: RequestMethod.Value
+) extends RequestToSync
+
+/**
+ * A global sync request message from BarrierTaskContext, by `allGather()` 
call. Each request is
+ * identified by stageId + stageAttemptId + barrierEpoch.
+ *
+ * @param numTasks The number of global sync requests the BarrierCoordinator 
shall receive
+ * @param stageId ID of current stage
+ * @param stageAttemptId ID of current stage attempt
+ * @param taskAttemptId Unique ID of current task
+ * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple 
`barrier()` calls
+ * @param partitionId ID of the current partition the task is assigned to
+ * @param requestMethod The BarrierTaskContext method that was called to 
trigger BarrierCoordinator
+ * @param allGatherMessage Message sent from the BarrierTaskContext if 
requestMethod is ALL_GATHER
+ */
+private[spark] case class AllGatherRequestToSync(
+  numTasks: Int,
+  stageId: Int,
+  stageAttemptId: Int,
+  taskAttemptId: Long,
+  barrierEpoch: Int,
+  partitionId: Int,
+  requestMethod: RequestMethod.Value,
+  allGatherMessage: String
+) extends RequestToSync
+
+private[spark] object RequestMethod extends Enumeration {
+  val BARRIER, ALL_GATHER = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala 
b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 3d36980..2263538 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,11 +17,19 @@
 
 package org.apache.spark
 
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Properties, Timer, TimerTask}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.TimeoutException
 import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.executor.TaskMetrics
@@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] (
   // from different tasks within the same barrier stage attempt to succeed.
   private lazy val numTasks = getTaskInfos().size
 
-  /**
-   * :: Experimental ::
-   * Sets a global barrier and waits until all tasks in this stage hit this 
barrier. Similar to
-   * MPI_Barrier function in MPI, the barrier() function call blocks until all 
tasks in the same
-   * stage have reached this routine.
-   *
-   * CAUTION! In a barrier stage, each task must have the same number of 
barrier() calls, in all
-   * possible code branches. Otherwise, you may get the job hanging or a 
SparkException after
-   * timeout. Some examples of '''misuses''' are listed below:
-   * 1. Only call barrier() function on a subset of all the tasks in the same 
barrier stage, it
-   * shall lead to timeout of the function call.
-   * {{{
-   *   rdd.barrier().mapPartitions { iter =>
-   *       val context = BarrierTaskContext.get()
-   *       if (context.partitionId() == 0) {
-   *           // Do nothing.
-   *       } else {
-   *           context.barrier()
-   *       }
-   *       iter
-   *   }
-   * }}}
-   *
-   * 2. Include barrier() function in a try-catch code block, this may lead to 
timeout of the
-   * second function call.
-   * {{{
-   *   rdd.barrier().mapPartitions { iter =>
-   *       val context = BarrierTaskContext.get()
-   *       try {
-   *           // Do something that might throw an Exception.
-   *           doSomething()
-   *           context.barrier()
-   *       } catch {
-   *           case e: Exception => logWarning("...", e)
-   *       }
-   *       context.barrier()
-   *       iter
-   *   }
-   * }}}
-   */
-  @Experimental
-  @Since("2.4.0")
-  def barrier(): Unit = {
+  private def getRequestToSync(
+    numTasks: Int,
+    stageId: Int,
+    stageAttemptNumber: Int,
+    taskAttemptId: Long,
+    barrierEpoch: Int,
+    partitionId: Int,
+    requestMethod: RequestMethod.Value,
+    allGatherMessage: String
+  ): RequestToSync = {
+    requestMethod match {
+      case RequestMethod.BARRIER =>
+        BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, 
taskAttemptId,
+          barrierEpoch, partitionId, requestMethod)
+      case RequestMethod.ALL_GATHER =>
+        AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, 
taskAttemptId,
+          barrierEpoch, partitionId, requestMethod, allGatherMessage)
+    }
+  }
+
+  private def runBarrier(
+    requestMethod: RequestMethod.Value,
+    allGatherMessage: String = ""
+  ): String = {
+
     logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt 
$stageAttemptNumber) has entered " +
       s"the global sync, current barrier epoch is $barrierEpoch.")
     logTrace("Current callSite: " + Utils.getCallSite())
@@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] (
     // Log the update of global sync every 60 seconds.
     timer.schedule(timerTask, 60000, 60000)
 
+    var json: String = ""
+
     try {
-      val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
-        message = RequestToSync(numTasks, stageId, stageAttemptNumber, 
taskAttemptId,
-          barrierEpoch),
+      val abortableRpcFuture = barrierCoordinator.askAbortable[String](
+        message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
+          taskAttemptId, barrierEpoch, partitionId, requestMethod, 
allGatherMessage),
         // Set a fixed timeout for RPC here, so users shall get a 
SparkException thrown by
         // BarrierCoordinator on timeout, instead of RPCTimeoutException from 
the RPC framework.
         timeout = new RpcTimeout(365.days, "barrierTimeout"))
@@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] (
         while (!abortableRpcFuture.toFuture.isCompleted) {
           // wait RPC future for at most 1 second
           try {
-            ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
+            json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 
1.second)
           } catch {
             case _: TimeoutException | _: InterruptedException =>
               // If `TimeoutException` thrown, waiting RPC future reach 1 
second.
@@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] (
       timerTask.cancel()
       timer.purge()
     }
+    json
+  }
+
+  /**
+   * :: Experimental ::
+   * Sets a global barrier and waits until all tasks in this stage hit this 
barrier. Similar to
+   * MPI_Barrier function in MPI, the barrier() function call blocks until all 
tasks in the same
+   * stage have reached this routine.
+   *
+   * CAUTION! In a barrier stage, each task must have the same number of 
barrier() calls, in all
+   * possible code branches. Otherwise, you may get the job hanging or a 
SparkException after
+   * timeout. Some examples of '''misuses''' are listed below:
+   * 1. Only call barrier() function on a subset of all the tasks in the same 
barrier stage, it
+   * shall lead to timeout of the function call.
+   * {{{
+   *   rdd.barrier().mapPartitions { iter =>
+   *       val context = BarrierTaskContext.get()
+   *       if (context.partitionId() == 0) {
+   *           // Do nothing.
+   *       } else {
+   *           context.barrier()
+   *       }
+   *       iter
+   *   }
+   * }}}
+   *
+   * 2. Include barrier() function in a try-catch code block, this may lead to 
timeout of the
+   * second function call.
+   * {{{
+   *   rdd.barrier().mapPartitions { iter =>
+   *       val context = BarrierTaskContext.get()
+   *       try {
+   *           // Do something that might throw an Exception.
+   *           doSomething()
+   *           context.barrier()
+   *       } catch {
+   *           case e: Exception => logWarning("...", e)
+   *       }
+   *       context.barrier()
+   *       iter
+   *   }
+   * }}}
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): Unit = {
+    runBarrier(RequestMethod.BARRIER)
+    ()
+  }
+
+  /**
+   * :: Experimental ::
+   * Blocks until all tasks in the same stage have reached this routine. Each 
task passes in
+   * a message and returns with a list of all the messages passed in by each 
of those tasks.
+   *
+   * CAUTION! The allGather method requires the same precautions as the 
barrier method
+   *
+   * The message is of type String rather than Array[Byte] because it is more 
convenient for
+   * the user at the cost of worse performance.
+   */
+  @Experimental
+  @Since("3.0.0")
+  def allGather(message: String): ArrayBuffer[String] = {
+    val json = runBarrier(RequestMethod.ALL_GATHER, message)
+    val jsonArray = parse(json)
+    implicit val formats = DefaultFormats
+    ArrayBuffer(jsonArray.extract[Array[String]]: _*)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 658e0d5..fa8bf0f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -24,8 +24,13 @@ import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@@ -238,13 +243,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
                   sock.setSoTimeout(10000)
                   authHelper.authClient(sock)
                   val input = new DataInputStream(sock.getInputStream())
-                  input.readInt() match {
+                  val requestMethod = input.readInt()
+                  // The BarrierTaskContext function may wait infinitely, 
socket shall not timeout
+                  // before the function finishes.
+                  sock.setSoTimeout(0)
+                  requestMethod match {
                     case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-                      // The barrier() function may wait infinitely, socket 
shall not timeout
-                      // before the function finishes.
-                      sock.setSoTimeout(0)
-                      barrierAndServe(sock)
-
+                      barrierAndServe(requestMethod, sock)
+                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION 
=>
+                      val length = input.readInt()
+                      val message = new Array[Byte](length)
+                      input.readFully(message)
+                      barrierAndServe(requestMethod, sock, new String(message, 
UTF_8))
                     case _ =>
                       val out = new DataOutputStream(new BufferedOutputStream(
                         sock.getOutputStream))
@@ -395,15 +405,31 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
 
     /**
-     * Gateway to call BarrierTaskContext.barrier().
+     * Gateway to call BarrierTaskContext methods.
      */
-    def barrierAndServe(sock: Socket): Unit = {
-      require(serverSocket.isDefined, "No available ServerSocket to redirect 
the barrier() call.")
-
+    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = 
""): Unit = {
+      require(
+        serverSocket.isDefined,
+        "No available ServerSocket to redirect the BarrierTaskContext method 
call."
+      )
       val out = new DataOutputStream(new 
BufferedOutputStream(sock.getOutputStream))
       try {
-        context.asInstanceOf[BarrierTaskContext].barrier()
-        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
+        var result: String = ""
+        requestMethod match {
+          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+            context.asInstanceOf[BarrierTaskContext].barrier()
+            result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
+          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+            val messages: ArrayBuffer[String] = 
context.asInstanceOf[BarrierTaskContext].allGather(
+              message
+            )
+            result = compact(render(JArray(
+              messages.map(
+                (message) => JString(message)
+              ).toList
+            )))
+        }
+        writeUTF(result, out)
       } catch {
         case e: SparkException =>
           writeUTF(e.getMessage, out)
@@ -638,6 +664,7 @@ private[spark] object SpecialLengths {
 
 private[spark] object BarrierTaskContextMessageProtocol {
   val BARRIER_FUNCTION = 1
+  val ALL_GATHER_FUNCTION = 2
   val BARRIER_RESULT_SUCCESS = "success"
   val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python 
side."
 }
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index fc8ac38..ed38b7f 100644
--- 
a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.File
 
+import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark._
@@ -52,6 +53,79 @@ class BarrierTaskContextSuite extends SparkFunSuite with 
LocalSparkContext {
     assert(times.max - times.min <= 1000)
   }
 
+  test("share messages with allGather() call") {
+    val conf = new SparkConf()
+      .setMaster("local-cluster[4, 1, 1024]")
+      .setAppName("test-cluster")
+    sc = new SparkContext(conf)
+    val rdd = sc.makeRDD(1 to 10, 4)
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
+      // Sleep for a random time before global sync.
+      Thread.sleep(Random.nextInt(1000))
+      // Pass partitionId message in
+      val message = context.partitionId().toString
+      val messages = context.allGather(message)
+      messages.toList.iterator
+    }
+    // Take a sorted list of all the partitionId messages
+    val messages = rdd2.collect().head
+    // All the task partitionIds are shared
+    for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString)
+  }
+
+  test("throw exception if we attempt to synchronize with different blocking 
calls") {
+    val conf = new SparkConf()
+      .setMaster("local-cluster[4, 1, 1024]")
+      .setAppName("test-cluster")
+    sc = new SparkContext(conf)
+    val rdd = sc.makeRDD(1 to 10, 4)
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
+      val partitionId = context.partitionId
+      if (partitionId == 0) {
+        context.barrier()
+      } else {
+        context.allGather(partitionId.toString)
+      }
+      Seq(null).iterator
+    }
+    val error = intercept[SparkException] {
+      rdd2.collect()
+    }.getMessage
+    assert(error.contains("does not match the current synchronized 
requestMethod"))
+  }
+
+  test("successively sync with allGather and barrier") {
+    val conf = new SparkConf()
+      .setMaster("local-cluster[4, 1, 1024]")
+      .setAppName("test-cluster")
+    sc = new SparkContext(conf)
+    val rdd = sc.makeRDD(1 to 10, 4)
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
+      // Sleep for a random time before global sync.
+      Thread.sleep(Random.nextInt(1000))
+      context.barrier()
+      val time1 = System.currentTimeMillis()
+      // Sleep for a random time before global sync.
+      Thread.sleep(Random.nextInt(1000))
+      // Pass partitionId message in
+      val message = context.partitionId().toString
+      val messages = context.allGather(message)
+      val time2 = System.currentTimeMillis()
+      Seq((time1, time2)).iterator
+    }
+    val times = rdd2.collect()
+    // All the tasks shall finish the first round of global sync within a 
short time slot.
+    val times1 = times.map(_._1)
+    assert(times1.max - times1.min <= 1000)
+
+    // All the tasks shall finish the second round of global sync within a 
short time slot.
+    val times2 = times.map(_._2)
+    assert(times2.max - times2.min <= 1000)
+  }
+
   test("support multiple barrier() call within a single task") {
     initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index d648f63..90bd234 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,9 +16,10 @@
 #
 
 from __future__ import print_function
+import json
 
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import write_int, UTF8Deserializer
+from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
 
 
 class TaskContext(object):
@@ -107,18 +108,28 @@ class TaskContext(object):
 
 
 BARRIER_FUNCTION = 1
+ALL_GATHER_FUNCTION = 2
 
 
-def _load_from_socket(port, auth_secret):
+def _load_from_socket(port, auth_secret, function, all_gather_message=None):
     """
     Load data from a given socket, this is a blocking method thus only return 
when the socket
     connection has been closed.
     """
     (sockfile, sock) = local_connect_and_auth(port, auth_secret)
-    # The barrier() call may block forever, so no timeout
+
+    # The call may block forever, so no timeout
     sock.settimeout(None)
-    # Make a barrier() function call.
-    write_int(BARRIER_FUNCTION, sockfile)
+
+    if function == BARRIER_FUNCTION:
+        # Make a barrier() function call.
+        write_int(function, sockfile)
+    elif function == ALL_GATHER_FUNCTION:
+        # Make an allGather() function call.
+        write_int(function, sockfile)
+        write_with_length(all_gather_message.encode("utf-8"), sockfile)
+    else:
+        raise ValueError("Unrecognized function type")
     sockfile.flush()
 
     # Collect result.
@@ -199,7 +210,33 @@ class BarrierTaskContext(TaskContext):
             raise Exception("Not supported to call barrier() before initialize 
" +
                             "BarrierTaskContext.")
         else:
-            _load_from_socket(self._port, self._secret)
+            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
+
+    def allGather(self, message=""):
+        """
+        .. note:: Experimental
+
+        This function blocks until all tasks in the same stage have reached 
this routine.
+        Each task passes in a message and returns with a list of all the 
messages passed in
+        by each of those tasks.
+
+        .. warning:: In a barrier stage, each task must have the same number 
of `allGather()`
+            calls, in all possible code branches.
+            Otherwise, you may get the job hanging or a SparkException after 
timeout.
+        """
+        if not isinstance(message, str):
+            raise ValueError("Argument `message` must be of type `str`")
+        elif self._port is None or self._secret is None:
+            raise Exception("Not supported to call barrier() before initialize 
" +
+                            "BarrierTaskContext.")
+        else:
+            gathered_items = _load_from_socket(
+                self._port,
+                self._secret,
+                ALL_GATHER_FUNCTION,
+                message,
+            )
+            return [e for e in json.loads(gathered_items)]
 
     def getTaskInfos(self):
         """
diff --git a/python/pyspark/tests/test_taskcontext.py 
b/python/pyspark/tests/test_taskcontext.py
index 68cfe81..0053aad 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -135,6 +135,26 @@ class TaskContextTests(PySparkTestCase):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
+    def test_all_gather(self):
+        """
+        Verify that BarrierTaskContext.allGather() performs global sync among 
all barrier tasks
+        within a stage and passes messages properly.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            out = tc.allGather(str(tc.partitionId()))
+            pids = [int(e) for e in out]
+            return [pids]
+
+        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
+        self.assertTrue(pids == [0, 1, 2, 3])
+
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all 
task infos in the


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
