Repository: spark
Updated Branches:
  refs/heads/master 8ae2da0b2 -> 2881a2d1d


[SPARK-17919] Make timeout to RBackend configurable in SparkR

## What changes were proposed in this pull request?

This patch makes the RBackend connection timeout configurable by the user through a new `spark.r.backendConnectionTimeout` property. To keep long-running jobs from tripping that timeout, the backend also sends periodic keep-alive responses to the R client, at an interval controlled by `spark.r.heartBeatInterval`.
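
For example, both settings could be raised when starting a SparkR session. The sketch below is illustrative only and not part of the patch: the property names come from this change, while the app name and values are placeholders, and it assumes the properties are picked up like any other Spark configuration at launch.

```r
library(SparkR)

# Hypothetical values; both properties are interpreted in seconds.
sparkR.session(
  appName = "long-running-job",                   # placeholder name
  sparkConfig = list(
    spark.r.backendConnectionTimeout = "12000",   # R <-> RBackend socket timeout
    spark.r.heartBeatInterval = "60"              # keep-alive interval from the backend
  )
)
```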

## How was this patch tested?
N/A

Author: Hossein <hoss...@databricks.com>

Closes #15471 from falaki/SPARK-17919.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2881a2d1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2881a2d1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2881a2d1

Branch: refs/heads/master
Commit: 2881a2d1d1a650a91df2c6a01275eba14a43b42a
Parents: 8ae2da0
Author: Hossein <hoss...@databricks.com>
Authored: Sun Oct 30 16:17:23 2016 -0700
Committer: Felix Cheung <felixche...@apache.org>
Committed: Sun Oct 30 16:17:23 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/backend.R                               | 20 ++++++++--
 R/pkg/R/client.R                                |  2 +-
 R/pkg/R/sparkR.R                                |  8 +++-
 R/pkg/inst/worker/daemon.R                      |  4 +-
 R/pkg/inst/worker/worker.R                      |  7 +++-
 .../scala/org/apache/spark/api/r/RBackend.scala | 15 +++++++-
 .../apache/spark/api/r/RBackendHandler.scala    | 39 ++++++++++++++++++--
 .../scala/org/apache/spark/api/r/RRunner.scala  |  3 ++
 .../org/apache/spark/api/r/SparkRDefaults.scala | 30 +++++++++++++++
 .../scala/org/apache/spark/deploy/RRunner.scala |  7 +++-
 docs/configuration.md                           | 15 ++++++++
 11 files changed, 134 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/R/pkg/R/backend.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R
index 03e70bb..0a789e6 100644
--- a/R/pkg/R/backend.R
+++ b/R/pkg/R/backend.R
@@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) {
   conn <- get(".sparkRCon", .sparkREnv)
   writeBin(requestMessage, conn)
 
-  # TODO: check the status code to output error information
   returnStatus <- readInt(conn)
+  handleErrors(returnStatus, conn)
+
+  # Backend will send +1 as keep alive value to prevent various connection timeouts
+  # on very long running jobs. See spark.r.heartBeatInterval
+  while (returnStatus == 1) {
+    returnStatus <- readInt(conn)
+    handleErrors(returnStatus, conn)
+  }
+
+  readObject(conn)
+}
+
+# Helper function to check for returned errors and print appropriate error message to user
+handleErrors <- function(returnStatus, conn) {
   if (length(returnStatus) == 0) {
     stop("No status is returned. Java SparkR backend might have failed.")
   }
-  if (returnStatus != 0) {
+
+  # 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors.
+  if (returnStatus < 0) {
     stop(readString(conn))
   }
-  readObject(conn)
 }
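
For orientation, the driver-side protocol after this change can be summarized as the sketch below. It is illustrative only (it mirrors the patched invokeJava above, not code added by the commit); `conn` is assumed to be an already-open socket connection to the backend, and the SparkR:::read* helpers are SparkR-internal functions.

```r
status <- SparkR:::readInt(conn)
# +1 is a keep-alive heartbeat: keep waiting for the real reply.
while (length(status) == 1 && status == 1) {
  status <- SparkR:::readInt(conn)
}
if (length(status) == 0) {
  stop("No status is returned. Java SparkR backend might have failed.")
}
# Negative status: an error message string follows.
if (status < 0) {
  stop(SparkR:::readString(conn))
}
# Status 0: success, the result object follows.
result <- SparkR:::readObject(conn)
```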

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/R/pkg/R/client.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 2d341d8..9d82814 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -19,7 +19,7 @@
 
 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout = 6000) {
+connectBackend <- function(hostname, port, timeout) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     if (isOpen(.sparkREnv[[".sparkRCon"]])) {
       cat("SparkRBackend client connection already exists\n")

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/R/pkg/R/sparkR.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index cc6d591..6b4a2f2 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -154,6 +154,7 @@ sparkR.sparkContext <- function(
   packages <- processSparkPackages(sparkPackages)
 
   existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")
+  connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
   if (existingPort != "") {
     if (length(packages) != 0) {
      warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell",
@@ -187,6 +188,7 @@ sparkR.sparkContext <- function(
     backendPort <- readInt(f)
     monitorPort <- readInt(f)
     rLibPath <- readString(f)
+    connectionTimeout <- readInt(f)
     close(f)
     file.remove(path)
     if (length(backendPort) == 0 || backendPort == 0 ||
@@ -194,7 +196,9 @@ sparkR.sparkContext <- function(
         length(rLibPath) != 1) {
       stop("JVM failed to launch")
     }
-    assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv)
+    assign(".monitorConn",
+           socketConnection(port = monitorPort, timeout = connectionTimeout),
+           envir = .sparkREnv)
     assign(".backendLaunched", 1, envir = .sparkREnv)
     if (rLibPath != "") {
       assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -204,7 +208,7 @@ sparkR.sparkContext <- function(
 
   .sparkREnv$backendPort <- backendPort
   tryCatch({
-    connectBackend("localhost", backendPort)
+    connectBackend("localhost", backendPort, timeout = connectionTimeout)
   },
   error = function(err) {
     stop("Failed to connect JVM\n")

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/R/pkg/inst/worker/daemon.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index b92e6be..3a318b7 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -18,6 +18,7 @@
 # Worker daemon
 
 rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
 dirs <- strsplit(rLibDir, ",")[[1]]
 script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
 
@@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
 suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
-inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600)
+inputCon <- socketConnection(
+    port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
 
 while (TRUE) {
   ready <- socketSelect(list(inputCon))

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/R/pkg/inst/worker/worker.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index cfe41de..03e7450 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -90,6 +90,7 @@ bootTime <- currentTimeSecs()
 bootElap <- elapsedSecs()
 
 rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
 dirs <- strsplit(rLibDir, ",")[[1]]
 # Set libPaths to include SparkR package as loadNamespace needs this
 # TODO: Figure out if we can avoid this by not loading any objects that require
@@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]]
 suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
-inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
-outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
+inputCon <- socketConnection(
+    port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+outputCon <- socketConnection(
+    port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
 
 # read the index of the current partition inside the RDD
 partition <- SparkR:::readInt(inputCon)

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 41d0a85..550746c 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
 import io.netty.channel.nio.NioEventLoopGroup
 import io.netty.channel.socket.SocketChannel
 import io.netty.channel.socket.nio.NioServerSocketChannel
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder
 import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import io.netty.handler.timeout.ReadTimeoutHandler
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
@@ -43,7 +44,10 @@ private[spark] class RBackend {
 
   def init(): Int = {
     val conf = new SparkConf()
-    bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
+    val backendConnectionTimeout = conf.getInt(
+      "spark.r.backendConnectionTimeout", 
SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+    bossGroup = new NioEventLoopGroup(
+      conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
     val workerGroup = bossGroup
     val handler = new RBackendHandler(this)
 
@@ -63,6 +67,7 @@ private[spark] class RBackend {
             // initialBytesToStrip = 4, i.e. strip out the length field itself
             new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
           .addLast("decoder", new ByteArrayDecoder())
+          .addLast("readTimeoutHandler", new 
ReadTimeoutHandler(backendConnectionTimeout))
           .addLast("handler", handler)
       }
     })
@@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging {
       val boundPort = sparkRBackend.init()
       val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
       val listenPort = serverSocket.getLocalPort()
+      // Connection timeout is set by socket client. To make it configurable we will pass the
+      // timeout value to client inside the temp file
+      val conf = new SparkConf()
+      val backendConnectionTimeout = conf.getInt(
+        "spark.r.backendConnectionTimeout", 
SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
 
       // tell the R process via temporary file
       val path = args(0)
@@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging {
       dos.writeInt(boundPort)
       dos.writeInt(listenPort)
       SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
+      dos.writeInt(backendConnectionTimeout)
       dos.close()
       f.renameTo(new File(path))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 1422ef8..9f5afa2 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -18,16 +18,19 @@
 package org.apache.spark.api.r
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable.HashMap
 import scala.language.existentials
 
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
 import io.netty.channel.ChannelHandler.Sharable
+import io.netty.handler.timeout.ReadTimeoutException
 
 import org.apache.spark.api.r.SerDe._
 import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * Handler for RBackend
@@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend)
           writeString(dos, s"Error: unknown method $methodName")
       }
     } else {
+      // To avoid timeouts when reading results in SparkR driver, we will be regularly sending
+      // heartbeat responses. We use special code +1 to signal the client that backend is
+      // alive and it should continue blocking for result.
+      val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
+      val pingRunner = new Runnable {
+        override def run(): Unit = {
+          val pingBaos = new ByteArrayOutputStream()
+          val pingDaos = new DataOutputStream(pingBaos)
+          writeInt(pingDaos, +1)
+          ctx.write(pingBaos.toByteArray)
+        }
+      }
+      val conf = new SparkConf()
+      val heartBeatInterval = conf.getInt(
+        "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
+      val backendConnectionTimeout = conf.getInt(
+        "spark.r.backendConnectionTimeout", 
SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+      val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)
+
+      execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
       handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+      execService.shutdown()
+      execService.awaitTermination(1, TimeUnit.SECONDS)
     }
 
     val reply = bos.toByteArray
@@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend)
   }
 
   override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
-    // Close the connection when an exception is raised.
-    cause.printStackTrace()
-    ctx.close()
+    cause match {
+      case timeout: ReadTimeoutException =>
+        // Do nothing. We don't want to timeout on read
+        logWarning("Ignoring read timeout in RBackendHandler")
+      case _ =>
+        // Close the connection when an exception is raised.
+        cause.printStackTrace()
+        ctx.close()
+    }
   }
 
   def handleMethodCall(

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 496fdf8..7ef6472 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -333,6 +333,8 @@ private[r] object RRunner {
     var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
     rCommand = sparkConf.get("spark.r.command", rCommand)
 
+    val rConnectionTimeout = sparkConf.getInt(
+      "spark.r.backendConnectionTimeout", 
SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
     val rOptions = "--vanilla"
     val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
     val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
@@ -344,6 +346,7 @@ private[r] object RRunner {
     pb.environment().put("R_TESTS", "")
     pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
     pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+    pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", 
rConnectionTimeout.toString)
     pb.redirectErrorStream(true)  // redirect stderr into stdout
     val proc = pb.start()
     val errThread = startStdoutThread(proc)

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
new file mode 100644
index 0000000..af67cbb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+private[spark] object SparkRDefaults {
+
+  // Default value for spark.r.backendConnectionTimeout config
+  val DEFAULT_CONNECTION_TIMEOUT: Int = 6000
+
+  // Default value for spark.r.heartBeatInterval config
+  val DEFAULT_HEARTBEAT_INTERVAL: Int = 100
+
+  // Default value for spark.r.numRBackendThreads config
+  val DEFAULT_NUM_RBACKEND_THREADS = 2
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index d046683..6eb53a8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{SparkException, SparkUserAppException}
-import org.apache.spark.api.r.{RBackend, RUtils}
+import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults}
 import org.apache.spark.util.RedirectThread
 
 /**
@@ -51,6 +51,10 @@ object RRunner {
       cmd
     }
 
+    //  Connection timeout set by R process on its connection to RBackend in seconds.
+    val backendConnectionTimeout = sys.props.getOrElse(
+      "spark.r.backendConnectionTimeout", 
SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString)
+
     // Check if the file path exists.
     // If not, change directory to current working directory for YARN cluster mode
     val rF = new File(rFile)
@@ -81,6 +85,7 @@ object RRunner {
         val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava)
         val env = builder.environment()
         env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+        env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout)
         val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
         // Put the R package directories into an env variable of comma-separated paths
         env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))

http://git-wip-us.apache.org/repos/asf/spark/blob/2881a2d1/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 6600cb6..780fc94 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1890,6 +1890,21 @@ showDF(properties, numRows = 200, truncate = FALSE)
    <code>spark.r.shell.command</code> is used for sparkR shell while <code>spark.r.driver.command</code> is used for running R script.
   </td>
 </tr>
+<tr>
+  <td><code>spark.r.backendConnectionTimeout</code></td>
+  <td>6000</td>
+  <td>
+    Connection timeout set by R process on its connection to RBackend in seconds.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.r.heartBeatInterval</code></td>
+  <td>100</td>
+  <td>
+    Interval for heartbeats sent from SparkR backend to R process to prevent connection timeout.
+  </td>
+</tr>
+
 </table>
 
 #### Deploy

