Repository: spark
Updated Branches:
  refs/heads/master eb5bdcaf6 -> baf9ce1a4


[SPARK-2490] Change recursive visiting on RDD dependencies to iterative approach

When transformations are applied to RDDs over many iterations, the dependency 
chain of the resulting RDD can become very long. Recursively visiting these 
dependencies in Spark core can then easily cause a StackOverflowError. For example:

    var rdd = sc.makeRDD(Array(1))
    for (i <- 1 to 1000) {
      rdd = rdd.coalesce(1).cache()
      rdd.collect()
    }

This PR replaces the recursive traversal of an RDD's dependencies with an 
iterative one to avoid the StackOverflowError.

Besides the recursive visiting, the Java serializer has a known 
[bug](http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4152790) that also 
causes a StackOverflowError when serializing/deserializing a large graph of 
objects, so this PR solves only part of the problem. Replacing the Java 
serializer with KryoSerializer might help. However, since KryoSerializer is not 
currently supported for `spark.closure.serializer`, I cannot test whether 
KryoSerializer completely avoids the Java serializer's problem.

Author: Liang-Chi Hsieh <[email protected]>

Closes #1418 from viirya/remove_recursive_visit and squashes the following 
commits:

6b2c615 [Liang-Chi Hsieh] change function name; comply with code style.
5f072a7 [Liang-Chi Hsieh] add comments to explain Stack usage.
8742dbb [Liang-Chi Hsieh] comply with code style.
900538b [Liang-Chi Hsieh] change recursive visiting on rdd's dependencies to 
iterative approach to avoid stackoverflowerror.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/baf9ce1a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/baf9ce1a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/baf9ce1a

Branch: refs/heads/master
Commit: baf9ce1a4ecb7acf5accf7a7029f29604b4360c2
Parents: eb5bdca
Author: Liang-Chi Hsieh <[email protected]>
Authored: Fri Aug 1 12:12:30 2014 -0700
Committer: Matei Zaharia <[email protected]>
Committed: Fri Aug 1 12:12:30 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/scheduler/DAGScheduler.scala   | 83 ++++++++++++++++++--
 1 file changed, 75 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/baf9ce1a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5110785..d87c304 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -21,7 +21,7 @@ import java.io.NotSerializableException
 import java.util.Properties
 import java.util.concurrent.atomic.AtomicInteger
 
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
 import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.postfixOps
@@ -211,11 +211,15 @@ class DAGScheduler(
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
+        // We are going to register ancestor shuffle dependencies
+        registerShuffleDependencies(shuffleDep, jobId)
+        // Then register current shuffleDep
         val stage =
           newOrUsedStage(
             shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId,
             shuffleDep.rdd.creationSite)
         shuffleToMapStage(shuffleDep.shuffleId) = stage
+ 
         stage
     }
   }
@@ -280,6 +284,9 @@ class DAGScheduler(
   private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
     val parents = new HashSet[Stage]
     val visited = new HashSet[RDD[_]]
+    // We are manually maintaining a stack here to prevent StackOverflowError
+    // caused by recursively visiting
+    val waitingForVisit = new Stack[RDD[_]]
     def visit(r: RDD[_]) {
       if (!visited(r)) {
         visited += r
@@ -290,18 +297,69 @@ class DAGScheduler(
             case shufDep: ShuffleDependency[_, _, _] =>
               parents += getShuffleMapStage(shufDep, jobId)
             case _ =>
-              visit(dep.rdd)
+              waitingForVisit.push(dep.rdd)
           }
         }
       }
     }
-    visit(rdd)
+    waitingForVisit.push(rdd)
+    while (!waitingForVisit.isEmpty) {
+      visit(waitingForVisit.pop())
+    }
     parents.toList
   }
 
+  // Find ancestor missing shuffle dependencies and register into 
shuffleToMapStage
+  private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, 
_], jobId: Int) = {
+    val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd)
+    while (!parentsWithNoMapStage.isEmpty) {
+      val currentShufDep = parentsWithNoMapStage.pop()
+      val stage =
+        newOrUsedStage(
+          currentShufDep.rdd, currentShufDep.rdd.partitions.size, 
currentShufDep, jobId,
+          currentShufDep.rdd.creationSite)
+      shuffleToMapStage(currentShufDep.shuffleId) = stage
+    }
+  }
+
+  // Find ancestor shuffle dependencies that are not registered in 
shuffleToMapStage yet
+  private def getAncestorShuffleDependencies(rdd: RDD[_]): 
Stack[ShuffleDependency[_, _, _]] = {
+    val parents = new Stack[ShuffleDependency[_, _, _]]
+    val visited = new HashSet[RDD[_]]
+    // We are manually maintaining a stack here to prevent StackOverflowError
+    // caused by recursively visiting
+    val waitingForVisit = new Stack[RDD[_]]
+    def visit(r: RDD[_]) {
+      if (!visited(r)) {
+        visited += r
+        for (dep <- r.dependencies) {
+          dep match {
+            case shufDep: ShuffleDependency[_, _, _] =>
+              if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
+                parents.push(shufDep)
+              }
+
+              waitingForVisit.push(shufDep.rdd)
+            case _ =>
+              waitingForVisit.push(dep.rdd)
+          }
+        }
+      }
+    }
+
+    waitingForVisit.push(rdd)
+    while (!waitingForVisit.isEmpty) {
+      visit(waitingForVisit.pop())
+    }
+    parents
+  }
+
   private def getMissingParentStages(stage: Stage): List[Stage] = {
     val missing = new HashSet[Stage]
     val visited = new HashSet[RDD[_]]
+    // We are manually maintaining a stack here to prevent StackOverflowError
+    // caused by recursively visiting
+    val waitingForVisit = new Stack[RDD[_]]
     def visit(rdd: RDD[_]) {
       if (!visited(rdd)) {
         visited += rdd
@@ -314,13 +372,16 @@ class DAGScheduler(
                   missing += mapStage
                 }
               case narrowDep: NarrowDependency[_] =>
-                visit(narrowDep.rdd)
+                waitingForVisit.push(narrowDep.rdd)
             }
           }
         }
       }
     }
-    visit(stage.rdd)
+    waitingForVisit.push(stage.rdd)
+    while (!waitingForVisit.isEmpty) {
+      visit(waitingForVisit.pop())
+    }
     missing.toList
   }
 
@@ -1119,6 +1180,9 @@ class DAGScheduler(
     }
     val visitedRdds = new HashSet[RDD[_]]
     val visitedStages = new HashSet[Stage]
+    // We are manually maintaining a stack here to prevent StackOverflowError
+    // caused by recursively visiting
+    val waitingForVisit = new Stack[RDD[_]]
     def visit(rdd: RDD[_]) {
       if (!visitedRdds(rdd)) {
         visitedRdds += rdd
@@ -1128,15 +1192,18 @@ class DAGScheduler(
               val mapStage = getShuffleMapStage(shufDep, stage.jobId)
               if (!mapStage.isAvailable) {
                 visitedStages += mapStage
-                visit(mapStage.rdd)
+                waitingForVisit.push(mapStage.rdd)
               }  // Otherwise there's no need to follow the dependency back
             case narrowDep: NarrowDependency[_] =>
-              visit(narrowDep.rdd)
+              waitingForVisit.push(narrowDep.rdd)
           }
         }
       }
     }
-    visit(stage.rdd)
+    waitingForVisit.push(stage.rdd)
+    while (!waitingForVisit.isEmpty) {
+      visit(waitingForVisit.pop())
+    }
     visitedRdds.contains(target.rdd)
   }
 

Reply via email to