HyukjinKwon commented on a change in pull request #23977: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
URL: https://github.com/apache/spark/pull/23977#discussion_r263332058
 
 

 ##########
 File path: core/src/main/scala/org/apache/spark/api/r/RRunner.scala
 ##########
 @@ -44,380 +35,149 @@ private[spark] class RRunner[U](
     isDataFrame: Boolean = false,
     colNames: Array[String] = null,
     mode: Int = RRunnerModes.RDD)
-  extends Logging {
-  protected var bootTime: Double = _
-  private var dataStream: DataInputStream = _
-  val readData = numPartitions match {
-    case -1 =>
-      serializer match {
-        case SerializationFormats.STRING => readStringData _
-        case _ => readByteArrayData _
-      }
-    case _ => readShuffledData _
-  }
-
-  def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int): Iterator[U] = {
-    // Timing start
-    bootTime = System.currentTimeMillis / 1000.0
-
-    // we expect two connections
-    val serverSocket = new ServerSocket(0, 2, 
InetAddress.getByName("localhost"))
-    val listenPort = serverSocket.getLocalPort()
-
-    // The stdout/stderr is shared by multiple tasks, because we use one daemon
-    // to launch child process as worker.
-    val errThread = RRunner.createRWorker(listenPort)
-
-    // We use two sockets to separate input and output, then it's easy to 
manage
-    // the lifecycle of them to avoid deadlock.
-    // TODO: optimize it to use one socket
-
-    // the socket used to send out the input of task
-    serverSocket.setSoTimeout(10000)
-    dataStream = try {
-      val inSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(inSocket)
-      startStdinThread(inSocket.getOutputStream(), inputIterator, 
partitionIndex)
-
-      // the socket used to receive the output of task
-      val outSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(outSocket)
-      val inputStream = new BufferedInputStream(outSocket.getInputStream)
-      new DataInputStream(inputStream)
-    } finally {
-      serverSocket.close()
-    }
-
-    try {
-      newReaderIterator(dataStream, errThread)
-    } catch {
-      case e: Exception =>
-        throw new SparkException("R computation failed with\n " + 
errThread.getLines(), e)
-    }
-  }
+  extends BaseRRunner[IN, OUT](
+    func,
+    deserializer,
+    serializer,
+    packageNames,
+    broadcastVars,
+    numPartitions,
+    isDataFrame,
+    colNames,
+    mode) {
 
   protected def newReaderIterator(
-      dataStream: DataInputStream, errThread: BufferedStreamThread): 
Iterator[U] = {
-    new Iterator[U] {
-      def next(): U = {
-        val obj = _nextObj
-        if (hasNext()) {
-          _nextObj = read()
-        }
-        obj
+      dataStream: DataInputStream, errThread: BufferedStreamThread): 
ReaderIterator = {
+    new ReaderIterator(dataStream, errThread) {
+      private val readData = numPartitions match {
+        case -1 =>
+          serializer match {
+            case SerializationFormats.STRING => readStringData _
+            case _ => readByteArrayData _
+          }
+        case _ => readShuffledData _
       }
 
-      private var _nextObj = read()
-
-      def hasNext(): Boolean = {
-        val hasMore = _nextObj != null
-        if (!hasMore) {
-          dataStream.close()
+      private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+        length match {
+          case length if length == 2 =>
+            val hashedKey = dataStream.readInt()
+            val contentPairsLength = dataStream.readInt()
+            val contentPairs = new Array[Byte](contentPairsLength)
+            dataStream.readFully(contentPairs)
+            (hashedKey, contentPairs)
+          case _ => null
         }
-        hasMore
       }
-    }
-  }
 
-  protected def writeData(
-      dataOut: DataOutputStream,
-      printOut: PrintStream,
-      iter: Iterator[_]): Unit = {
-    def writeElem(elem: Any): Unit = {
-      if (deserializer == SerializationFormats.BYTE) {
-        val elemArr = elem.asInstanceOf[Array[Byte]]
-        dataOut.writeInt(elemArr.length)
-        dataOut.write(elemArr)
-      } else if (deserializer == SerializationFormats.ROW) {
-        dataOut.write(elem.asInstanceOf[Array[Byte]])
-      } else if (deserializer == SerializationFormats.STRING) {
-        // write string(for StringRRDD)
-        // scalastyle:off println
-        printOut.println(elem)
-        // scalastyle:on println
+      private def readByteArrayData(length: Int): Array[Byte] = {
+        length match {
+          case length if length > 0 =>
+            val obj = new Array[Byte](length)
+            dataStream.readFully(obj)
+            obj
+          case _ => null
+        }
       }
-    }
 
-    for (elem <- iter) {
-      elem match {
-        case (key, innerIter: Iterator[_]) =>
-          for (innerElem <- innerIter) {
-            writeElem(innerElem)
-          }
-          // Writes key which can be used as a boundary in group-aggregate
-          dataOut.writeByte('r')
-          writeElem(key)
-        case (key, value) =>
-          writeElem(key)
-          writeElem(value)
-        case _ =>
-          writeElem(elem)
+      private def readStringData(length: Int): String = {
+        length match {
+          case length if length > 0 =>
+            SerDe.readStringBytes(dataStream, length)
+          case _ => null
+        }
       }
-    }
-  }
-
-  /**
-   * Start a thread to write RDD data to the R process.
-   */
-  private def startStdinThread(
-      output: OutputStream,
-      iter: Iterator[_],
-      partitionIndex: Int): Unit = {
-    val env = SparkEnv.get
-    val taskContext = TaskContext.get()
-    val bufferSize = System.getProperty(BUFFER_SIZE.key,
-      BUFFER_SIZE.defaultValueString).toInt
-    val stream = new BufferedOutputStream(output, bufferSize)
 
-    new Thread("writer for R") {
-      override def run(): Unit = {
+      /**
+       * Reads next object from the stream.
+       * When the stream reaches end of data, needs to process the following 
sections,
+       * and then returns null.
+       */
+      override protected def read(): OUT = {
         try {
-          SparkEnv.set(env)
-          TaskContext.setTaskContext(taskContext)
-          val dataOut = new DataOutputStream(stream)
-          dataOut.writeInt(partitionIndex)
-
-          SerDe.writeString(dataOut, deserializer)
-          SerDe.writeString(dataOut, serializer)
-
-          dataOut.writeInt(packageNames.length)
-          dataOut.write(packageNames)
-
-          dataOut.writeInt(func.length)
-          dataOut.write(func)
-
-          dataOut.writeInt(broadcastVars.length)
-          broadcastVars.foreach { broadcast =>
-            // TODO(shivaram): Read a Long in R to avoid this cast
-            dataOut.writeInt(broadcast.id.toInt)
-            // TODO: Pass a byte array from R to avoid this cast ?
-            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(broadcastByteArr.length)
-            dataOut.write(broadcastByteArr)
-          }
-
-          dataOut.writeInt(numPartitions)
-          dataOut.writeInt(mode)
-
-          if (isDataFrame) {
-            SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
+          val length = dataStream.readInt()
+
+          length match {
+            case SpecialLengths.TIMING_DATA =>
+              // Timing data from R worker
+              val boot = dataStream.readDouble - bootTime
+              val init = dataStream.readDouble
+              val broadcast = dataStream.readDouble
+              val input = dataStream.readDouble
+              val compute = dataStream.readDouble
+              val output = dataStream.readDouble
+              logInfo(
+                ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+                  "read-input = %.3f s, compute = %.3f s, write-output = %.3f 
s, " +
+                  "total = %.3f s").format(
+                  boot,
+                  init,
+                  broadcast,
+                  input,
+                  compute,
+                  output,
+                  boot + init + broadcast + input + compute + output))
+              read()
+            case length if length > 0 =>
+              readData(length).asInstanceOf[OUT]
+            case length if length == 0 =>
 
 Review comment:
   In the case of ArrowRRunner, it behaves the same as before this change.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to