Repository: spark
Updated Branches:
  refs/heads/master 267237411 -> 2d4e00efe


[SPARK-5986][MLLib] Add save/load for k-means

This PR adds save/load for K-means, as described in SPARK-5986. The Python version will be added in another PR.

Author: Xusen Yin <[email protected]>

Closes #4951 from yinxusen/SPARK-5986 and squashes the following commits:

6dd74a0 [Xusen Yin] rewrite some functions and classes
cd390fd [Xusen Yin] add indexed point
b144216 [Xusen Yin] remove invalid comments
dce7055 [Xusen Yin] add save/load for k-means for SPARK-5986


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d4e00ef
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d4e00ef
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d4e00ef

Branch: refs/heads/master
Commit: 2d4e00efe2cf179935ae108a68f28edf6e5a1628
Parents: 2672374
Author: Xusen Yin <[email protected]>
Authored: Wed Mar 11 00:24:55 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Mar 11 00:24:55 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/clustering/KMeansModel.scala    | 68 +++++++++++++++++++-
 .../spark/mllib/clustering/KMeansSuite.scala    | 44 ++++++++++++-
 2 files changed, 108 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2d4e00ef/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 3b95a9e..707da53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,15 +17,22 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.Row
 
 /**
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
-class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
+class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
 
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
@@ -58,4 +65,59 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
 
   private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
     clusterCenters.map(new VectorWithNorm(_))
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    KMeansModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object KMeansModel extends Loader[KMeansModel] {
+  override def load(sc: SparkContext, path: String): KMeansModel = {
+    KMeansModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private case class Cluster(id: Int, point: Vector)
+
+  private object Cluster {
+    def apply(r: Row): Cluster = {
+      Cluster(r.getInt(0), r.getAs[Vector](1))
+    }
+  }
+
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
+
+    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
+        Cluster(id, point)
+      }.toDF()
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): KMeansModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val k = (metadata \ "k").extract[Int]
+      val centriods = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[Cluster](centriods.schema)
+      val localCentriods = centriods.map(Cluster.apply).collect()
+      assert(k == localCentriods.size)
+      new KMeansModel(localCentriods.sortBy(_.id).map(_.point))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2d4e00ef/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index caee591..7bf250e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,9 +21,10 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class KMeansSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
       assert(predicts(0) != predicts(3))
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    Array(true, false).foreach { case selector =>
+      val model = KMeansSuite.createModel(10, 3, selector)
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = KMeansModel.load(sc, path)
+        KMeansSuite.checkEqual(model, sameModel)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
+}
+
+object KMeansSuite extends FunSuite {
+  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
+    val singlePoint = isSparse match {
+      case true =>
+        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
+      case _ =>
+        Vectors.dense(Array.fill[Double](dim)(0.0))
+    }
+    new KMeansModel(Array.fill[Vector](k)(singlePoint))
+  }
+
+  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
+    assert(a.k === b.k)
+    a.clusterCenters.zip(b.clusterCenters).foreach {
+      case (ca: SparseVector, cb: SparseVector) =>
+        assert(ca === cb)
+      case (ca: DenseVector, cb: DenseVector) =>
+        assert(ca === cb)
+      case _ =>
+        throw new AssertionError("checkEqual failed since the two clusters were not identical.\n")
+    }
+  }
 }
 
 class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to