Repository: spark
Updated Branches:
  refs/heads/master 24db35826 -> 5854f77ce


[SPARK-20244][CORE] Handle incorrect bytesRead metrics when using PySpark

## What changes were proposed in this pull request?

Hadoop FileSystem's statistics are based on thread-local variables. This is
fine when the whole RDD computation chain runs in the same thread, but if a
child RDD creates another thread to consume the iterator obtained from a
Hadoop RDD, the bytesRead computation will be wrong, because the iterator's
`next()` and `close()` may then run in different threads. This can happen
when using PySpark with PythonRDD.
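
To illustrate the issue outside of Spark, here is a minimal sketch (not part
of this patch) that uses a plain `ThreadLocal` counter as a stand-in for
Hadoop FileSystem's per-thread statistics; the names `ThreadLocalCounterDemo`,
`bytesRead` and `read` are hypothetical:

```scala
object ThreadLocalCounterDemo {
  // Stand-in for Hadoop FileSystem's per-thread Statistics counter.
  private val bytesRead = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }

  // Simulate reading n bytes on the calling thread.
  private def read(n: Long): Unit = bytesRead.set(bytesRead.get() + n)

  def main(args: Array[String]): Unit = {
    // A child thread consumes the data, similar to the separate thread that
    // drains the iterator when PythonRDD is involved.
    val consumer = new Thread(new Runnable {
      override def run(): Unit = read(100L)
    })
    consumer.start()
    consumer.join()
    // The parent thread's view of the thread-local counter is still 0, so a
    // metrics callback running here misses the bytes read on the child thread.
    println(s"bytes seen by parent thread: ${bytesRead.get()}") // prints 0
  }
}
```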

So this patch builds a map to track the `bytesRead` for each thread and adds
the values together. This approach is relevant to three RDDs, `HadoopRDD`,
`NewHadoopRDD` and `FileScanRDD`. I assume `FileScanRDD` cannot be called
directly, so I only fixed `HadoopRDD` and `NewHadoopRDD`.
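
The core of the change can be sketched as a small standalone function
(simplified from the actual `SparkHadoopUtil` change in the diff below;
`currentBytesRead` is a stand-in for summing `FileSystem.getAllStatistics`):

```scala
import scala.collection.mutable

// Returns a callback that sums bytesRead across every thread that has invoked
// it, subtracting the baseline captured on the thread that created the callback.
def bytesReadCallback(currentBytesRead: () => Long): () => Long = {
  val baseline = (Thread.currentThread().getId, currentBytesRead())
  val bytesReadMap = new mutable.HashMap[Long, Long]()

  () => bytesReadMap.synchronized {
    // Record the latest reading under the calling thread's id ...
    bytesReadMap.put(Thread.currentThread().getId, currentBytesRead())
    // ... then sum across threads, excluding bytes read before the baseline.
    bytesReadMap.map { case (threadId, bytes) =>
      bytes - (if (threadId == baseline._1) baseline._2 else 0L)
    }.sum
  }
}
```

Each thread that touches the input (the task thread and any spawned consumer
thread) records its own thread-local reading in the map, so the task-level
total no longer depends on `next()` and `close()` running on the same thread.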

## How was this patch tested?

Unit test and local cluster verification.

Author: jerryshao <[email protected]>

Closes #17617 from jerryshao/SPARK-20244.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5854f77c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5854f77c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5854f77c

Branch: refs/heads/master
Commit: 5854f77ce1d3b9491e2a6bd1f352459da294e369
Parents: 24db358
Author: jerryshao <[email protected]>
Authored: Wed May 31 22:34:53 2017 -0700
Committer: Wenchen Fan <[email protected]>
Committed: Wed May 31 22:34:53 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/deploy/SparkHadoopUtil.scala   | 28 ++++++++++++++----
 .../scala/org/apache/spark/rdd/HadoopRDD.scala  |  8 ++++-
 .../org/apache/spark/rdd/NewHadoopRDD.scala     |  8 ++++-
 .../spark/metrics/InputOutputMetricsSuite.scala | 31 +++++++++++++++++++-
 4 files changed, 66 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5854f77c/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 9cc321a..6afe58b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -23,6 +23,7 @@ import java.text.DateFormat
 import java.util.{Arrays, Comparator, Date, Locale}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -143,14 +144,29 @@ class SparkHadoopUtil extends Logging {
   * Returns a function that can be called to find Hadoop FileSystem bytes read. If
   * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
   * return the bytes read on r since t.
-   *
-   * @return None if the required method can't be found.
    */
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
-    val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
-    val f = () => threadStats.map(_.getBytesRead).sum
-    val baselineBytesRead = f()
-    () => f() - baselineBytesRead
+    val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
+    val baseline = (Thread.currentThread().getId, f())
+
+    /**
+     * This function may be called in both spawned child threads and parent task thread (in
+     * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics.
+     * So we need a map to track the bytes read from the child threads and parent thread,
+     * summing them together to get the bytes read of this task.
+     */
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/5854f77c/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 4bf8ecc..76ea8b8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
             null
         }
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener{ context => closeIfNeeded() }
+      context.addTaskCompletionListener { context =>
+        // Update the bytes read before closing is to make sure lingering bytesRead statistics in
+        // this thread get correctly added.
+        updateBytesRead()
+        closeIfNeeded()
+      }
+
      private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
      private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5854f77c/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index ce3a9a2..482875e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
         }
 
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener(context => close())
+      context.addTaskCompletionListener { context =>
+        // Update the bytesRead before closing is to make sure lingering bytesRead statistics in
+        // this thread get correctly added.
+        updateBytesRead()
+        close()
+      }
+
       private var havePair = false
       private var recordsSinceMetricsUpdate = 0
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5854f77c/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 5d52218..6f4203d 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
   with BeforeAndAfter {
@@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     }
     assert(bytesRead >= tmpFile.length())
   }
+
+  test("input metrics with old Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  test("input metrics with new Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
 }
 
 /**

