Repository: spark
Updated Branches:
  refs/heads/master 045a4f045 -> 2acdf10b1


[SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel

Also modifies DefaultParamsWriter.saveMetadata to take an optional extra metadata argument.

CC: mengxr yanboliang

Author: Joseph K. Bradley <[email protected]>

Closes #9786 from jkbradley/als-io.
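
A minimal usage sketch of the API this adds (the training DataFrame and the
save paths below are hypothetical, not from the patch):

    import org.apache.spark.ml.recommendation.{ALS, ALSModel}

    val als = new ALS().setMaxIter(5).setRegParam(0.01)
    val model: ALSModel = als.fit(training)
    // Persist both estimator and model; overwrite() replaces existing output.
    als.write.overwrite().save("/tmp/als")
    model.write.overwrite().save("/tmp/als-model")
    // Reload through the companion objects' load() added by this patch.
    val sameAls = ALS.load("/tmp/als")
    val sameModel = ALSModel.load("/tmp/als-model")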


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2acdf10b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2acdf10b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2acdf10b

Branch: refs/heads/master
Commit: 2acdf10b1f3bb1242dba64efa798c672fde9f0d2
Parents: 045a4f0
Author: Joseph K. Bradley <[email protected]>
Authored: Wed Nov 18 13:16:31 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Nov 18 13:16:31 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 75 +++++++++++++++++--
 .../org/apache/spark/ml/util/ReadWrite.scala    | 14 +++-
 .../spark/ml/recommendation/ALSSuite.scala      | 78 +++++++++++++++++---
 3 files changed, 150 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2acdf10b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 535f266..d92514d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s.{DefaultFormats, JValue}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, Partitioner}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
@@ -182,7 +185,7 @@ class ALSModel private[ml] (
     val rank: Int,
     @transient val userFactors: DataFrame,
     @transient val itemFactors: DataFrame)
-  extends Model[ALSModel] with ALSModelParams {
+  extends Model[ALSModel] with ALSModelParams with Writable {
 
   /** @group setParam */
   def setUserCol(value: String): this.type = set(userCol, value)
@@ -220,8 +223,60 @@ class ALSModel private[ml] (
     val copied = new ALSModel(uid, rank, userFactors, itemFactors)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new ALSModel.ALSModelWriter(this)
 }
 
+@Since("1.6.0")
+object ALSModel extends Readable[ALSModel] {
+
+  @Since("1.6.0")
+  override def read: Reader[ALSModel] = new ALSModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): ALSModel = read.load(path)
+
+  private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer {
+
+    override protected def saveImpl(path: String): Unit = {
+      val extraMetadata = render("rank" -> instance.rank)
+      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+      val userPath = new Path(path, "userFactors").toString
+      instance.userFactors.write.format("parquet").save(userPath)
+      val itemPath = new Path(path, "itemFactors").toString
+      instance.itemFactors.write.format("parquet").save(itemPath)
+    }
+  }
+
+  private[recommendation] class ALSModelReader extends Reader[ALSModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.recommendation.ALSModel"
+
+    override def load(path: String): ALSModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      implicit val format = DefaultFormats
+      val rank: Int = metadata.extraMetadata match {
+        case Some(m: JValue) =>
+          (m \ "rank").extract[Int]
+        case None =>
+          throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
+            s" ${metadata.metadataStr}")
+      }
+
+      val userPath = new Path(path, "userFactors").toString
+      val userFactors = sqlContext.read.format("parquet").load(userPath)
+      val itemPath = new Path(path, "itemFactors").toString
+      val itemFactors = sqlContext.read.format("parquet").load(itemPath)
+
+      val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+}
 
 /**
  * :: Experimental ::
@@ -254,7 +309,7 @@ class ALSModel private[ml] (
  * preferences rather than explicit ratings given to items.
  */
 @Experimental
-class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
+class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable {
 
   import org.apache.spark.ml.recommendation.ALS.Rating
 
@@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
   }
 
   override def copy(extra: ParamMap): ALS = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
 }
 
+
 /**
  * :: DeveloperApi ::
  * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
@@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
  * than 2 billion.
  */
 @DeveloperApi
-object ALS extends Logging {
+object ALS extends Readable[ALS] with Logging {
 
   /**
    * :: DeveloperApi ::
@@ -356,6 +415,12 @@ object ALS extends Logging {
   @DeveloperApi
  case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
 
+  @Since("1.6.0")
+  override def read: Reader[ALS] = new DefaultParamsReader[ALS]
+
+  @Since("1.6.0")
+  override def load(path: String): ALS = read.load(path)
+
   /** Trait for least squares solvers applied to the normal equation. */
   private[recommendation] trait LeastSquaresNESolver extends Serializable {
    /** Solves a least squares problem with regularization (possibly with other constraints). */
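
Since ALSModelReader above recovers rank from JSON, a standalone json4s sketch
of that round trip may help (the value 10 is invented):

    import org.json4s.{DefaultFormats, JValue}
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    implicit val format = DefaultFormats
    val extra: JValue = render("rank" -> 10)    // as in ALSModelWriter.saveImpl
    println(compact(extra))                     // prints {"rank":10}
    val rank = (extra \ "rank").extract[Int]    // as in ALSModelReader.load
    assert(rank == 10)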

http://git-wip-us.apache.org/repos/asf/spark/blob/2acdf10b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index dddb72a..d8ce907 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter {
    *  - uid
 *  - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
-  def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      sc: SparkContext,
+      extraMetadata: Option[JValue] = None): Unit = {
     val uid = instance.uid
     val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter {
       ("timestamp" -> System.currentTimeMillis()) ~
       ("sparkVersion" -> sc.version) ~
       ("uid" -> uid) ~
-      ("paramMap" -> jsonParams)
+      ("paramMap" -> jsonParams) ~
+      ("extraMetadata" -> extraMetadata)
     val metadataPath = new Path(path, "metadata").toString
     val metadataJson = compact(render(metadata))
     sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader {
   /**
    * All info from metadata file.
    * @param params  paramMap, as a [[JValue]]
   * @param extraMetadata  Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
    * @param metadataStr  Full metadata file String (for debugging)
    */
   case class Metadata(
@@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader {
       timestamp: Long,
       sparkVersion: String,
       params: JValue,
+      extraMetadata: Option[JValue],
       metadataStr: String)
 
   /**
@@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader {
     val timestamp = (metadata \ "timestamp").extract[Long]
     val sparkVersion = (metadata \ "sparkVersion").extract[String]
     val params = metadata \ "paramMap"
+    val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
     if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
         s" $expectedClassName but found class name $className")
     }
 
-    Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
+    Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
   }
 
   /**
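
With this change the metadata file is still a single compact JSON line, now
carrying an optional extraMetadata field, roughly shaped like this (all
values invented; when extraMetadata is None the field is omitted from the
rendered object):

    {"class":"org.apache.spark.ml.recommendation.ALSModel",
     "timestamp":1447881391000,"sparkVersion":"1.6.0","uid":"als_a1b2c3d4",
     "paramMap":{"predictionCol":"prediction"},"extraMetadata":{"rank":10}}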

http://git-wip-us.apache.org/repos/asf/spark/blob/2acdf10b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index eadc80e..2c3fb84 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.recommendation
 
-import java.io.File
 import java.util.Random
 
 import scala.collection.mutable
@@ -26,28 +25,26 @@ import scala.language.existentials
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
+import org.apache.spark.util.Utils
 import org.apache.spark.{Logging, SparkException, SparkFunSuite}
 import org.apache.spark.ml.recommendation.ALS._
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.{DataFrame, Row}
 
-class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
 
-  private var tempDir: File = _
+class ALSSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    tempDir = Utils.createTempDir()
     sc.setCheckpointDir(tempDir.getAbsolutePath)
   }
 
   override def afterAll(): Unit = {
-    Utils.deleteRecursively(tempDir)
     super.afterAll()
   }
 
@@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
     var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
     var i = 0
-    while (i < compressed.srcIds.size) {
+    while (i < compressed.srcIds.length) {
       var j = compressed.dstPtrs(i)
       while (j < compressed.dstPtrs(i + 1)) {
         val dstEncodedIndex = compressed.dstEncodedIndices(j)
@@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2,
       implicitPrefs = true, seed = 0)
   }
+
+  test("read/write") {
+    import ALSSuite._
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+    val als = new ALS()
+    allEstimatorParamSettings.foreach { case (p, v) =>
+      als.set(als.getParam(p), v)
+    }
+    val sqlContext = this.sqlContext
+    import sqlContext.implicits._
+    val model = als.fit(ratings.toDF())
+
+    // Test Estimator save/load
+    val als2 = testDefaultReadWrite(als)
+    allEstimatorParamSettings.foreach { case (p, v) =>
+      val param = als.getParam(p)
+      assert(als.get(param).get === als2.get(param).get)
+    }
+
+    // Test Model save/load
+    val model2 = testDefaultReadWrite(model)
+    allModelParamSettings.foreach { case (p, v) =>
+      val param = model.getParam(p)
+      assert(model.get(param).get === model2.get(param).get)
+    }
+    assert(model.rank === model2.rank)
+    def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
+      df.select("id", "features").collect().map { case r =>
+        (r.getInt(0), r.getAs[Array[Float]](1))
+      }.toSet
+    }
+    assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+    assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+  }
+}
+
+object ALSSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allModelParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPredictionCol"
+  )
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map(
+    "maxIter" -> 1,
+    "rank" -> 1,
+    "regParam" -> 0.01,
+    "numUserBlocks" -> 2,
+    "numItemBlocks" -> 2,
+    "implicitPrefs" -> true,
+    "alpha" -> 0.9,
+    "nonnegative" -> true,
+    "checkpointInterval" -> 20
+  )
 }
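
The testDefaultReadWrite helper used above round-trips an instance through
save and load; a simplified sketch of the pattern it exercises (helper name
and path are illustrative, not the real implementation):

    import org.apache.spark.ml.param.Params
    import org.apache.spark.ml.util.Writable

    // Save, reload, and confirm the uid survives the round trip. The real
    // helper also writes to a temp dir and verifies overwrite semantics.
    def roundTrip[T <: Params with Writable](instance: T, load: String => T, path: String): T = {
      instance.write.overwrite().save(path)
      val loaded = load(path)
      assert(loaded.uid == instance.uid)
      loaded
    }

    // e.g. val als2 = roundTrip(als, ALS.load, "/tmp/als-rt")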

