Repository: spark Updated Branches: refs/heads/master d33e3d572 -> 1eda2f10d
[SPARK-14646][ML] Modified KMeans to store cluster centers one per row ## What changes were proposed in this pull request? Modified KMeans to store cluster centers one per row ## How was this patch tested? Existing tests Author: Joseph K. Bradley <[email protected]> Closes #12792 from jkbradley/kmeans-save-fix. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1eda2f10 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1eda2f10 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1eda2f10 Branch: refs/heads/master Commit: 1eda2f10d9f7add319e5b271488045c44ea30c03 Parents: d33e3d5 Author: Joseph K. Bradley <[email protected]> Authored: Fri Apr 29 16:46:25 2016 -0700 Committer: Joseph K. Bradley <[email protected]> Committed: Fri Apr 29 16:46:25 2016 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/clustering/KMeans.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1eda2f10/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index bf2ab98..7c9ac02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) + /** Helper class for storing model data */ + private case class Data(clusterIdx: Int, clusterCenter: Vector) + /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: 
KMeansModel) extends MLWriter { - private case class Data(clusterCenters: Array[Vector]) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: cluster centers - val data = Data(instance.clusterCenters) + val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => + Data(idx, center) + } val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath) } } @@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] { private val className = classOf[KMeansModel].getName override def load(path: String): KMeansModel = { + // Import implicits for Dataset Encoder + val sqlContext = super.sqlContext + import sqlContext.implicits._ + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() - val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data] + val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter) val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
