Repository: spark Updated Branches: refs/heads/master 8ac71d62d -> e1e77b22b
[SPARK-11029] [ML] Add computeCost to KMeansModel in spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11029 We should add a method analogous to spark.mllib.clustering.KMeansModel.computeCost to spark.ml.clustering.KMeansModel. This will be a temp fix until we have proper evaluators defined for clustering. Author: Yuhao Yang <hhb...@gmail.com> Author: yuhaoyang <yuhao@zhanglipings-iMac.local> Closes #9073 from hhbyyh/computeCost. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1e77b22 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1e77b22 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1e77b22 Branch: refs/heads/master Commit: e1e77b22b3b577909a12c3aa898eb53be02267fd Parents: 8ac71d6 Author: Yuhao Yang <hhb...@gmail.com> Authored: Sat Oct 17 10:04:19 2015 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Sat Oct 17 10:04:19 2015 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 12 ++++++++++++ .../org/apache/spark/ml/clustering/KMeansSuite.scala | 1 + 2 files changed, 13 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e1e77b22/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f40ab71..509be63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -117,6 +117,18 @@ class KMeansModel private[ml] ( @Since("1.5.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters + + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + */ + // TODO: Replace the temp fix when we have proper evaluators defined for clustering. + @Since("1.6.0") + def computeCost(dataset: DataFrame): Double = { + SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + parentModel.computeCost(data) + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/e1e77b22/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 688b0e3..c05f905 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -104,5 +104,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org