Repository: spark
Updated Branches:
  refs/heads/master d33e3d572 -> 1eda2f10d


[SPARK-14646][ML] Modified Kmeans to store cluster centers with one per row

## What changes were proposed in this pull request?

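Modified KMeans model persistence to store the cluster centers one per row: each saved row now holds a cluster index and its center, instead of a single row holding an array of all centers.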

## How was this patch tested?

Existing tests

Author: Joseph K. Bradley <[email protected]>

Closes #12792 from jkbradley/kmeans-save-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1eda2f10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1eda2f10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1eda2f10

Branch: refs/heads/master
Commit: 1eda2f10d9f7add319e5b271488045c44ea30c03
Parents: d33e3d5
Author: Joseph K. Bradley <[email protected]>
Authored: Fri Apr 29 16:46:25 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri Apr 29 16:46:25 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/KMeans.scala  | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1eda2f10/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index bf2ab98..7c9ac02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] {
   @Since("1.6.0")
   override def load(path: String): KMeansModel = super.load(path)
 
+  /** Helper class for storing model data */
+  private case class Data(clusterIdx: Int, clusterCenter: Vector)
+
   /** [[MLWriter]] instance for [[KMeansModel]] */
   private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
 
-    private case class Data(clusterCenters: Array[Vector])
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sc)
       // Save model data: cluster centers
-      val data = Data(instance.clusterCenters)
+      val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
+        Data(idx, center)
+      }
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
     }
   }
 
@@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] {
     private val className = classOf[KMeansModel].getName
 
     override def load(path: String): KMeansModel = {
+      // Import implicits for Dataset Encoder
+      val sqlContext = super.sqlContext
+      import sqlContext.implicits._
+
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
-      val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+      val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+      val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
       val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
 
       DefaultParamsReader.getAndSetParams(model, metadata)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to