BryanCutler commented on a change in pull request #24070: [SPARK-23961][PYTHON] Fix error when toLocalIterator goes out of scope
URL: https://github.com/apache/spark/pull/24070#discussion_r279589039
 
 

 ##########
 File path: core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
 ##########
 @@ -163,8 +164,63 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  /**
+   * A helper function to create a local RDD iterator and serve it via socket. Partitions are
+   * collected as separate jobs, by order of index. Partition data is first requested by a
+   * non-zero integer to start a collection job. The response is prefaced by an integer with 1
+   * meaning partition data will be served, 0 meaning the local iterator has been consumed,
+   * and -1 meaning an error occurred during collection. This function is used by
+   * pyspark.rdd._local_iterator_from_socket().
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from these jobs, and the secret for authentication.
+   */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
+        authHelper, "serve toLocalIterator") { s =>
+      val out = new DataOutputStream(s.getOutputStream)
+      val in = new DataInputStream(s.getInputStream)
+      Utils.tryWithSafeFinally {
+
+        // Collects a partition on each iteration
+        val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
+          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+        }
+
+        // Read request for data and send next partition if nonzero
+        var complete = false
+        while (!complete && in.readInt() != 0) {
+          if (collectPartitionIter.hasNext) {
+            try {
+              // Attempt to collect the next partition
+              val partitionArray = collectPartitionIter.next()
+
+              // Send response that there is a partition to read
+              out.writeInt(1)
+
+              // Write the next object and signal end of data for this iteration
+              writeIteratorToStream(partitionArray.toIterator, out)
+              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+              out.flush()
+            } catch {
+              case e: SparkException =>
 
 Review comment:
   We want to catch any errors during the collection job, so I believe catching `SparkException` should be all that is needed here?
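   For reference, here is a self-contained sketch (not part of the patch) of the request/response protocol the new doc comment describes, including the -1 error signal this `catch` is about. Plain sockets and a `RuntimeException` stand in for the Spark pieces (`runJob`, `SparkException`, `writeIteratorToStream`), and the length-prefixed framing and the object name `LocalIteratorProtocolSketch` are made up for illustration:

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.net.{ServerSocket, Socket}

object LocalIteratorProtocolSketch {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0)

    // Server side: mirrors the loop in the diff above, with plain arrays standing in
    // for the per-partition collection jobs and RuntimeException for SparkException.
    val serverThread = new Thread(() => {
      val s = server.accept()
      val out = new DataOutputStream(s.getOutputStream)
      val in = new DataInputStream(s.getInputStream)
      val partitions = Iterator(Array(1, 2), Array(3, 4), Array.empty[Int])
      var served = 0
      var complete = false
      while (!complete && in.readInt() != 0) {
        if (partitions.hasNext) {
          try {
            val part = partitions.next()
            served += 1
            if (served == 3) throw new RuntimeException("collection job failed")
            out.writeInt(1)            // 1 = partition data follows
            out.writeInt(part.length)  // simplified framing; Spark uses writeIteratorToStream
            part.foreach(out.writeInt)
            out.flush()
          } catch {
            case _: RuntimeException =>
              out.writeInt(-1)         // -1 = error occurred during collection
              out.flush()
              complete = true
          }
        } else {
          out.writeInt(0)              // 0 = local iterator has been consumed
          out.flush()
          complete = true
        }
      }
      s.close()
    })
    serverThread.start()

    // Client side: send a non-zero int to request each partition, stop on 0 or -1.
    val sock = new Socket("localhost", server.getLocalPort)
    val cout = new DataOutputStream(sock.getOutputStream)
    val cin = new DataInputStream(sock.getInputStream)
    var status = 1
    while (status == 1) {
      cout.writeInt(1)
      cout.flush()
      status = cin.readInt()
      if (status == 1) {
        val n = cin.readInt()
        println((0 until n).map(_ => cin.readInt()).mkString("partition: [", ", ", "]"))
      }
    }
    println(s"final status: $status")  // prints -1 here, since the third job "fails"
    sock.close()
    serverThread.join()
    server.close()
  }
}
```

   The point of the -1 response is that the Python side can raise promptly instead of blocking on a half-written stream, which is presumably why the collection failure should be caught and signalled here rather than allowed to propagate.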
