Repository: spark
Updated Branches:
  refs/heads/branch-1.3 4ff577160 -> bc92a2e40


[SPARK-5955][MLLIB] add checkpointInterval to ALS

Add checkpointInterval to ALS to prevent:

1. StackOverflow exceptions caused by long lineage,
2. large shuffle files generated during iterations,
3. slow recovery when some nodes fail.

srowen coderxiang

Author: Xiangrui Meng <[email protected]>

Closes #5076 from mengxr/SPARK-5955 and squashes the following commits:

df56791 [Xiangrui Meng] update impl to reuse code
29affcb [Xiangrui Meng] do not materialize factors in implicit
20d3f7f [Xiangrui Meng] add checkpointInterval to ALS

(cherry picked from commit 6b36470c66bd6140c45e45d3f1d51b0082c3fd97)
Signed-off-by: Xiangrui Meng <[email protected]>

Conflicts:
        mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc92a2e4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc92a2e4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc92a2e4

Branch: refs/heads/branch-1.3
Commit: bc92a2e405241542770a64adfef39dcb02e96461
Parents: 4ff5771
Author: Xiangrui Meng <[email protected]>
Authored: Fri Mar 20 15:02:57 2015 -0400
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Mar 24 11:32:18 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/param/sharedParams.scala    | 11 +++++
 .../apache/spark/ml/recommendation/ALS.scala    | 42 +++++++++++++++++---
 .../apache/spark/mllib/recommendation/ALS.scala | 17 ++++++++
 .../spark/ml/recommendation/ALSSuite.scala      | 17 ++++++++
 4 files changed, 82 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc92a2e4/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 1a70322..5d660d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params {
   /** @group getParam */
   def getOutputCol: String = get(outputCol)
 }
+
+private[ml] trait HasCheckpointInterval extends Params {
+  /**
+   * param for checkpoint interval
+   * @group param
+   */
+  val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", 
"checkpoint interval")
+
+  /** @group getParam */
+  def getCheckpointInterval: Int = get(checkpointInterval)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bc92a2e4/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 7bb69df..058076d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.recommendation
 
 import java.{util => ju}
+import java.io.IOException
 
 import scala.collection.mutable
 import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.jblas.DoubleMatrix
 import org.netlib.util.intW
 
@@ -47,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
  * Common params for ALS.
  */
 private[recommendation] trait ALSParams extends Params with HasMaxIter with 
HasRegParam
-  with HasPredictionCol {
+  with HasPredictionCol with HasCheckpointInterval {
 
   /**
    * Param for rank of the matrix factorization.
@@ -165,6 +167,7 @@ class ALSModel private[ml] (
     itemFactors: RDD[(Int, Array[Float])])
   extends Model[ALSModel] with ALSParams {
 
+  /** @group setParam */
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
@@ -263,6 +266,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   /** @group setParam */
   def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
 
+  /** @group setParam */
+  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, 
value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    * @group setParam
@@ -275,6 +281,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
 
   setMaxIter(20)
   setRegParam(1.0)
+  setCheckpointInterval(10)
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
@@ -286,7 +293,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
       numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
       maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = 
map(implicitPrefs),
-      alpha = map(alpha), nonnegative = map(nonnegative))
+      alpha = map(alpha), nonnegative = map(nonnegative),
+      checkpointInterval = map(checkpointInterval))
     val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
     Params.inheritValues(map, this, model)
     model
@@ -496,6 +504,7 @@ object ALS extends Logging {
       nonnegative: Boolean = false,
       intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
       finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+      checkpointInterval: Int = 10,
       seed: Long = 0L)(
       implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, 
Array[Float])]) = {
     require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -523,6 +532,18 @@ object ALS extends Logging {
     val seedGen = new XORShiftRandom(seed)
     var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
     var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
+    var previousCheckpointFile: Option[String] = None
+    val shouldCheckpoint: Int => Boolean = (iter) =>
+      sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+    val deletePreviousCheckpointFile: () => Unit = () =>
+      previousCheckpointFile.foreach { file =>
+        try {
+          FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
+        } catch {
+          case e: IOException =>
+            logWarning(s"Cannot delete checkpoint file $file:", e)
+        }
+      }
     if (implicitPrefs) {
       for (iter <- 1 to maxIter) {
         
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
@@ -530,19 +551,30 @@ object ALS extends Logging {
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, 
rank, regParam,
           userLocalIndexEncoder, implicitPrefs, alpha, solver)
         previousItemFactors.unpersist()
-        if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
-          itemFactors.checkpoint()
-        }
         
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
+        // TODO: Generalize PeriodicGraphCheckpointer and use it here.
+        if (shouldCheckpoint(iter)) {
+          itemFactors.checkpoint() // itemFactors gets materialized in 
computeFactors.
+        }
         val previousUserFactors = userFactors
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, 
rank, regParam,
           itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+        if (shouldCheckpoint(iter)) {
+          deletePreviousCheckpointFile()
+          previousCheckpointFile = itemFactors.getCheckpointFile
+        }
         previousUserFactors.unpersist()
       }
     } else {
       for (iter <- 0 until maxIter) {
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, 
rank, regParam,
           userLocalIndexEncoder, solver = solver)
+        if (shouldCheckpoint(iter)) {
+          itemFactors.checkpoint()
+          itemFactors.count() // checkpoint item factors and cut lineage
+          deletePreviousCheckpointFile()
+          previousCheckpointFile = itemFactors.getCheckpointFile
+        }
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, 
rank, regParam,
           itemLocalIndexEncoder, solver = solver)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc92a2e4/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index caacab9..dddefe1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -82,6 +82,9 @@ class ALS private (
   private var intermediateRDDStorageLevel: StorageLevel = 
StorageLevel.MEMORY_AND_DISK
   private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
 
+  /** checkpoint interval */
+  private var checkpointInterval: Int = 10
+
   /**
    * Set the number of blocks for both user blocks and product blocks to 
parallelize the computation
    * into; pass -1 for an auto-configured number of blocks. Default: -1.
@@ -183,6 +186,19 @@ class ALS private (
   }
 
   /**
+   * Set period (in iterations) between checkpoints (default = 10). 
Checkpointing helps with
+   * recovery (when nodes fail) and StackOverflow exceptions caused by long 
lineage. It also helps
+   * with eliminating temporary shuffle files on disk, which can be important 
when there are many
+   * ALS iterations. If the checkpoint directory is not set in 
[[org.apache.spark.SparkContext]],
+   * this setting is ignored.
+   */
+  @DeveloperApi
+  def setCheckpointInterval(checkpointInterval: Int): this.type = {
+    this.checkpointInterval = checkpointInterval
+    this
+  }
+
+  /**
    * Run ALS with the configured parameters on an input RDD of (user, product, 
rating) triples.
    * Returns a MatrixFactorizationModel with feature vectors for each user and 
product.
    */
@@ -212,6 +228,7 @@ class ALS private (
       nonnegative = nonnegative,
       intermediateRDDStorageLevel = intermediateRDDStorageLevel,
       finalRDDStorageLevel = StorageLevel.NONE,
+      checkpointInterval = checkpointInterval,
       seed = seed)
 
     val userFactors = floatUserFactors

http://git-wip-us.apache.org/repos/asf/spark/blob/bc92a2e4/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bb86baf..0bb06e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.recommendation
 
+import java.io.File
 import java.util.Random
 
 import scala.collection.mutable
@@ -32,16 +33,25 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.util.Utils
 
 class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
 
   private var sqlContext: SQLContext = _
+  private var tempDir: File = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
+    tempDir = Utils.createTempDir()
+    sc.setCheckpointDir(tempDir.getAbsolutePath)
     sqlContext = new SQLContext(sc)
   }
 
+  override def afterAll(): Unit = {
+    Utils.deleteRecursively(tempDir)
+    super.afterAll()
+  }
+
   test("LocalIndexEncoder") {
     val random = new Random
     for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
@@ -485,4 +495,11 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext 
with Logging {
       }.count()
     }
   }
+
+  test("als with large number of iterations") {
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 
1)
+    ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, 
numItemBlocks = 2)
+    ALS.train(
+      ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, 
implicitPrefs = true)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to