This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 320694c87314 [SPARK-45245][PYTHON][CONNECT] PythonWorkerFactory: 
Timeout if worker does not connect back
320694c87314 is described below

commit 320694c87314efad52682f9af64e1a2278186ff9
Author: Raghu Angadi <[email protected]>
AuthorDate: Tue Oct 31 08:49:51 2023 +0900

    [SPARK-45245][PYTHON][CONNECT] PythonWorkerFactory: Timeout if worker does 
not connect back
    
    ### What changes were proposed in this pull request?
    
    `createSimpleWorker()` method in `PythonWorkerFactory` waits forever if the 
worker fails to connect back to the server.
    
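    This is because it calls accept() without a timeout. If the worker does not 
connect back, accept() waits forever. There was supposed to be a 10-second 
timeout, but it was not implemented correctly: it was set with setSoTimeout() 
on the channel's socket, which does not apply to ServerSocketChannel.accept().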
    
    This PR adds a 10 second timeout.
    
    ### Why are the changes needed?
    
    Otherwise the create method could be stuck forever.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
     - Unit test
     - Manual
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: ChatGPT 4.0
    Asked ChatGPT to generate sample code to do non-blocking accept() on a 
socket channel in Java.
    
    Closes #43023 from rangadi/fix-py-accept.
    
    Authored-by: Raghu Angadi <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/api/python/PythonWorkerFactory.scala     | 15 ++++--
 .../api/python/PythonWorkerFactorySuite.scala      | 61 ++++++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 14265f03795f..f6dbeadd96f4 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
 
 import java.io.{DataInputStream, DataOutputStream, EOFException, File, 
InputStream}
 import java.net.{InetAddress, InetSocketAddress, SocketException}
+import java.net.SocketTimeoutException
 import java.nio.channels._
 import java.util.Arrays
 import java.util.concurrent.TimeUnit
@@ -184,10 +185,18 @@ private[spark] class PythonWorkerFactory(
       redirectStreamsToStderr(workerProcess.getInputStream, 
workerProcess.getErrorStream)
 
       // Wait for it to connect to our socket, and validate the auth secret.
-      serverSocketChannel.socket().setSoTimeout(10000)
-
       try {
-        val socketChannel = serverSocketChannel.accept()
+        // Wait up to 10 seconds for client to connect.
+        serverSocketChannel.configureBlocking(false)
+        val serverSelector = Selector.open()
+        serverSocketChannel.register(serverSelector, SelectionKey.OP_ACCEPT)
+        val socketChannel =
+          if (serverSelector.select(10 * 1000) > 0) { // Wait up to 10 seconds.
+            serverSocketChannel.accept()
+          } else {
+            throw new SocketTimeoutException(
+              "Timed out while waiting for the Python worker to connect back")
+          }
         authHelper.authClient(socketChannel.socket())
         val pid = workerProcess.toHandle.pid()
         if (pid < 0) {
diff --git 
a/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala
 
b/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala
new file mode 100644
index 000000000000..34c10bd95ed7
--- /dev/null
+++ 
b/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.net.SocketTimeoutException
+
+// scalastyle:off executioncontextglobal
+import scala.concurrent.ExecutionContext.Implicits.global
+// scalastyle:on executioncontextglobal
+import scala.concurrent.Future
+import scala.concurrent.duration._
+
+import org.scalatest.matchers.must.Matchers
+
+import org.apache.spark.SharedSparkContext
+import org.apache.spark.SparkException
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.util.ThreadUtils
+
+// Tests for PythonWorkerFactory.
+class PythonWorkerFactorySuite extends SparkFunSuite with Matchers with 
SharedSparkContext {
+
+  test("createSimpleWorker() fails with a timeout error if worker does not 
connect back") {
+    // It verifies that server side times out in accept(), if the worker does 
not connect back.
+    // E.g. the worker might fail at the beginning before it tries to connect 
back.
+
+    val workerFactory = new PythonWorkerFactory(
+      "python3", "pyspark.testing.non_existing_worker_module", Map.empty
+    )
+
+    // Create the worker in a separate thread so that if there is a bug where 
it does not
+    // return (accept() used to be blocking), the test doesn't hang for a long 
time.
+    val createFuture = Future {
+      val ex = intercept[SparkException] {
+        workerFactory.createSimpleWorker(blockingMode = true) // blockingMode 
doesn't matter.
+        // NOTE: This takes 10 seconds (which is the accept timeout in 
PythonWorkerFactory).
+        // That makes this a somewhat long test.
+      }
+      assert(ex.getMessage.contains("Python worker failed to connect back"))
+      assert(ex.getCause.isInstanceOf[SocketTimeoutException])
+    }
+
+    // Timeout ensures that the test fails in 5 minutes if 
createSimpleWorker() doesn't return.
+    ThreadUtils.awaitReady(createFuture, 5.minutes)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to