Repository: spark
Updated Branches:
  refs/heads/master 42035a4fe -> ad45299d0


[SPARK-25095][PYSPARK] Python support for BarrierTaskContext

## What changes were proposed in this pull request?

Add the `barrier()` and `getTaskInfos()` methods to the Python `TaskContext`; these two
methods are only allowed for barrier tasks.
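
A minimal usage sketch of the new API (assuming an existing `SparkContext` named `sc`; it mirrors the new tests added below):

```python
from pyspark.taskcontext import BarrierTaskContext

def f(iterator):
    yield sum(iterator)

def with_barrier(x):
    # Only available inside a barrier task; a regular TaskContext does not expose these methods.
    tc = BarrierTaskContext.get()
    tc.barrier()  # global sync: waits until every task in this stage reaches the barrier
    return [info.address for info in tc.getTaskInfos()]  # task infos ordered by partition id

# `sc` is an existing SparkContext (assumption for this sketch)
rdd = sc.parallelize(range(10), 4)
print(rdd.barrier().mapPartitions(f).map(with_barrier).collect())
```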

## How was this patch tested?

Add new tests in `tests.py`

Closes #22085 from jiangxb1987/python.barrier.

Authored-by: Xingbo Jiang <xingbo.ji...@databricks.com>
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ad45299d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ad45299d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ad45299d

Branch: refs/heads/master
Commit: ad45299d047c10472fd3a86103930fe7c54a4cf1
Parents: 42035a4
Author: Xingbo Jiang <xingbo.ji...@databricks.com>
Authored: Tue Aug 21 15:54:30 2018 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Aug 21 15:54:30 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonRunner.scala  | 106 ++++++++++++++
 python/pyspark/serializers.py                   |   7 +
 python/pyspark/taskcontext.py                   | 144 +++++++++++++++++++
 python/pyspark/tests.py                         |  36 ++++-
 python/pyspark/worker.py                        |  16 ++-
 5 files changed, 305 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ad45299d/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 7b31857..f824191 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,12 +20,14 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
 
 
@@ -76,6 +78,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   // TODO: support accumulator in multiple UDF
   protected val accumulator = funcs.head.funcs.head.accumulator
 
+  // Expose a ServerSocket to support method calls via socket from Python side.
+  private[spark] var serverSocket: Option[ServerSocket] = None
+
+  // Authentication helper used when serving method calls via socket from Python side.
+  private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+
   def compute(
       inputIterator: Iterator[IN],
       partitionIndex: Int,
@@ -180,7 +188,73 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(partitionIndex)
         // Python version of driver
         PythonRDD.writeUTF(pythonVer, dataOut)
+        // Init a ServerSocket to accept method calls from Python side.
+        val isBarrier = context.isInstanceOf[BarrierTaskContext]
+        if (isBarrier) {
+          serverSocket = Some(new ServerSocket(/* port */ 0,
+            /* backlog */ 1,
+            InetAddress.getByName("localhost")))
+          // A call to accept() for ServerSocket shall block infinitely.
+          serverSocket.map(_.setSoTimeout(0))
+          new Thread("accept-connections") {
+            setDaemon(true)
+
+            override def run(): Unit = {
+              while (!serverSocket.get.isClosed()) {
+                var sock: Socket = null
+                try {
+                  sock = serverSocket.get.accept()
+                  // Wait for function call from python side.
+                  sock.setSoTimeout(10000)
+                  val input = new DataInputStream(sock.getInputStream())
+                  input.readInt() match {
+                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+                      // The barrier() function may wait infinitely, socket shall not timeout
+                      // before the function finishes.
+                      sock.setSoTimeout(0)
+                      barrierAndServe(sock)
+
+                    case _ =>
+                      val out = new DataOutputStream(new BufferedOutputStream(
+                        sock.getOutputStream))
+                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
+                  }
+                } catch {
+                  case e: SocketException if e.getMessage.contains("Socket closed") =>
+                    // It is possible that the ServerSocket is not closed, but the native socket
+                    // has already been closed, we shall catch and silently ignore this case.
+                } finally {
+                  if (sock != null) {
+                    sock.close()
+                  }
+                }
+              }
+            }
+          }.start()
+        }
+        val secret = if (isBarrier) {
+          authHelper.secret
+        } else {
+          ""
+        }
+        // Close ServerSocket on task completion.
+        serverSocket.foreach { server =>
+          context.addTaskCompletionListener(_ => server.close())
+        }
+        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
+        if (boundPort == -1) {
+          val message = "ServerSocket failed to bind to Java side."
+          logError(message)
+          throw new SparkException(message)
+        } else if (isBarrier) {
+          logDebug(s"Started ServerSocket on port $boundPort.")
+        }
         // Write out the TaskContextInfo
+        dataOut.writeBoolean(isBarrier)
+        dataOut.writeInt(boundPort)
+        val secretBytes = secret.getBytes(UTF_8)
+        dataOut.writeInt(secretBytes.length)
+        dataOut.write(secretBytes, 0, secretBytes.length)
         dataOut.writeInt(context.stageId())
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
@@ -243,6 +317,32 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           }
       }
     }
+
+    /**
+     * Gateway to call BarrierTaskContext.barrier().
+     */
+    def barrierAndServe(sock: Socket): Unit = {
+      require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
+
+      authHelper.authClient(sock)
+
+      val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+      try {
+        context.asInstanceOf[BarrierTaskContext].barrier()
+        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
+      } catch {
+        case e: SparkException =>
+          writeUTF(e.getMessage, out)
+      } finally {
+        out.close()
+      }
+    }
+
+    def writeUTF(str: String, dataOut: DataOutputStream) {
+      val bytes = str.getBytes(UTF_8)
+      dataOut.writeInt(bytes.length)
+      dataOut.write(bytes)
+    }
   }
 
   abstract class ReaderIterator(
@@ -465,3 +565,9 @@ private[spark] object SpecialLengths {
   val NULL = -5
   val START_ARROW_STREAM = -6
 }
+
+private[spark] object BarrierTaskContextMessageProtocol {
+  val BARRIER_FUNCTION = 1
+  val BARRIER_RESULT_SUCCESS = "success"
+  val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ad45299d/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 47c4c3e..1038558 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -715,6 +715,13 @@ def write_int(value, stream):
     stream.write(struct.pack("!i", value))
 
 
+def read_bool(stream):
+    length = stream.read(1)
+    if not length:
+        raise EOFError
+    return struct.unpack("!?", length)[0]
+
+
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)

http://git-wip-us.apache.org/repos/asf/spark/blob/ad45299d/python/pyspark/taskcontext.py
----------------------------------------------------------------------
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 63ae1f3..c0312e5 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,6 +16,10 @@
 #
 
 from __future__ import print_function
+import socket
+
+from pyspark.java_gateway import do_server_auth
+from pyspark.serializers import write_int, UTF8Deserializer
 
 
 class TaskContext(object):
@@ -95,3 +99,143 @@ class TaskContext(object):
         Get a local property set upstream in the driver, or None if it is missing.
         """
         return self._localProperties.get(key, None)
+
+
+BARRIER_FUNCTION = 1
+
+
+def _load_from_socket(port, auth_secret):
+    """
+    Load data from a given socket, this is a blocking method thus only return when the socket
+    connection has been closed.
+
+    This is copied from context.py, while modified the message protocol.
+    """
+    sock = None
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, canonname, sa = res
+        sock = socket.socket(af, socktype, proto)
+        try:
+            # Do not allow timeout for socket reading operation.
+            sock.settimeout(None)
+            sock.connect(sa)
+        except socket.error:
+            sock.close()
+            sock = None
+            continue
+        break
+    if not sock:
+        raise Exception("could not open socket")
+
+    # We don't really need a socket file here, it's just for convenience that we can reuse the
+    # do_server_auth() function and data serialization methods.
+    sockfile = sock.makefile("rwb", 65536)
+
+    # Make a barrier() function call.
+    write_int(BARRIER_FUNCTION, sockfile)
+    sockfile.flush()
+
+    # Do server auth.
+    do_server_auth(sockfile, auth_secret)
+
+    # Collect result.
+    res = UTF8Deserializer().loads(sockfile)
+
+    # Release resources.
+    sockfile.close()
+    sock.close()
+
+    return res
+
+
+class BarrierTaskContext(TaskContext):
+
+    """
+    .. note:: Experimental
+
+    A TaskContext with extra info and tooling for a barrier stage. To access the BarrierTaskContext
+    for a running task, use:
+    L{BarrierTaskContext.get()}.
+
+    .. versionadded:: 2.4.0
+    """
+
+    _port = None
+    _secret = None
+
+    def __init__(self):
+        """Construct a BarrierTaskContext, use get instead"""
+        pass
+
+    @classmethod
+    def _getOrCreate(cls):
+        """Internal function to get or create global BarrierTaskContext."""
+        if cls._taskContext is None:
+            cls._taskContext = BarrierTaskContext()
+        return cls._taskContext
+
+    @classmethod
+    def get(cls):
+        """
+        Return the currently active BarrierTaskContext. This can be called inside of user functions
+        to access contextual information about running tasks.
+
+        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
+        """
+        return cls._taskContext
+
+    @classmethod
+    def _initialize(cls, port, secret):
+        """
+        Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called
+        after BarrierTaskContext is initialized.
+        """
+        cls._port = port
+        cls._secret = secret
+
+    def barrier(self):
+        """
+        .. note:: Experimental
+
+        Sets a global barrier and waits until all tasks in this stage hit this barrier.
+        Note this method is only allowed for a BarrierTaskContext.
+
+        .. versionadded:: 2.4.0
+        """
+        if self._port is None or self._secret is None:
+            raise Exception("Not supported to call barrier() before initialize " +
+                            "BarrierTaskContext.")
+        else:
+            _load_from_socket(self._port, self._secret)
+
+    def getTaskInfos(self):
+        """
+        .. note:: Experimental
+
+        Returns the all task infos in this barrier stage, the task infos are ordered by
+        partitionId.
+        Note this method is only allowed for a BarrierTaskContext.
+
+        .. versionadded:: 2.4.0
+        """
+        if self._port is None or self._secret is None:
+            raise Exception("Not supported to call getTaskInfos() before initialize " +
+                            "BarrierTaskContext.")
+        else:
+            addresses = self._localProperties.get("addresses", "")
+            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
+
+
+class BarrierTaskInfo(object):
+    """
+    .. note:: Experimental
+
+    Carries all task infos of a barrier task.
+
+    .. versionadded:: 2.4.0
+    """
+
+    def __init__(self, address):
+        self.address = address

http://git-wip-us.apache.org/repos/asf/spark/blob/ad45299d/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a4c5fb1..8ac1df5 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -70,7 +70,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
 from pyspark import shuffle
 from pyspark.profiler import BasicProfiler
-from pyspark.taskcontext import TaskContext
+from pyspark.taskcontext import BarrierTaskContext, TaskContext
 
 _have_scipy = False
 _have_numpy = False
@@ -588,6 +588,40 @@ class TaskContextTests(PySparkTestCase):
         finally:
             self.sc.setLocalProperty(key, None)
 
+    def test_barrier(self):
+        """
+        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
+        within a stage.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return time.time()
+
+        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        self.assertTrue(max(times) - min(times) < 1)
+
+    def test_barrier_infos(self):
+        """
+        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
+        barrier stage.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
+                                                        .getTaskInfos()).collect()
+        self.assertTrue(len(taskInfos) == 4)
+        self.assertTrue(len(taskInfos[0]) == 4)
+
 
 class RDDTests(ReusedPySparkTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ad45299d/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index eaaae2b..d54a5b8 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -28,10 +28,10 @@ import traceback
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.java_gateway import do_server_auth
-from pyspark.taskcontext import TaskContext
+from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
-from pyspark.serializers import write_with_length, write_int, read_long, \
+from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
@@ -259,8 +259,18 @@ def main(infile, outfile):
                              "PYSPARK_DRIVER_PYTHON are correctly set.") %
                             ("%d.%d" % sys.version_info[:2], version))
 
+        # read inputs only for a barrier task
+        isBarrier = read_bool(infile)
+        boundPort = read_int(infile)
+        secret = UTF8Deserializer().loads(infile)
         # initialize global state
-        taskContext = TaskContext._getOrCreate()
+        taskContext = None
+        if isBarrier:
+            taskContext = BarrierTaskContext._getOrCreate()
+            BarrierTaskContext._initialize(boundPort, secret)
+        else:
+            taskContext = TaskContext._getOrCreate()
+        # read inputs for TaskContext info
         taskContext._stageId = read_int(infile)
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)

