Repository: spark
Updated Branches:
  refs/heads/master 1db8feab8 -> a1b136d05


[SPARK-14634][ML] Add BisectingKMeansSummary

## What changes were proposed in this pull request?
Add BisectingKMeansSummary

## How was this patch tested?
unit test

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #12394 from zhengruifeng/biKMSummary.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a1b136d0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a1b136d0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a1b136d0

Branch: refs/heads/master
Commit: a1b136d05c6c458ae8211b0844bfc98d7693fa42
Parents: 1db8fea
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Fri Oct 14 04:25:14 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Fri Oct 14 04:25:14 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/clustering/BisectingKMeans.scala   | 74 +++++++++++++++++++-
 .../ml/clustering/BisectingKMeansSuite.scala    | 18 ++++-
 .../ml/clustering/GaussianMixtureSuite.scala    |  2 +-
 .../spark/ml/clustering/KMeansSuite.scala       |  2 +-
 4 files changed, 91 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a1b136d0/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index a97bd0f..add8ee2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] (
 
   @Since("2.0.0")
   override def write: MLWriter = new 
BisectingKMeansModel.BisectingKMeansModelWriter(this)
+
+  private var trainingSummary: Option[BisectingKMeansSummary] = None
+
+  private[clustering] def setSummary(summary: BisectingKMeansSummary): 
this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * Return true if a training summary exists for this model.
+   */
+  @Since("2.1.0")
+  def hasSummary: Boolean = trainingSummary.nonEmpty
+
+  /**
+   * Gets summary of model on training set. A `SparkException` is
+   * thrown if no summary is available (see `hasSummary`).
+   */
+  @Since("2.1.0")
+  def summary: BisectingKMeansSummary = trainingSummary.getOrElse {
+    throw new SparkException(
+      s"No training summary available for the ${this.getClass.getSimpleName}")
+  }
 }
 
 object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
@@ -228,14 +252,22 @@ class BisectingKMeans @Since("2.0.0") (
       case Row(point: Vector) => OldVectors.fromML(point)
     }
 
+    val instr = Instrumentation.create(this, rdd)
+    instr.logParams(featuresCol, predictionCol, k, maxIter, seed, 
minDivisibleClusterSize)
+
     val bkm = new MLlibBisectingKMeans()
       .setK($(k))
       .setMaxIterations($(maxIter))
       .setMinDivisibleClusterSize($(minDivisibleClusterSize))
       .setSeed($(seed))
     val parentModel = bkm.run(rdd)
-    val model = new BisectingKMeansModel(uid, parentModel)
-    copyValues(model.setParent(this))
+    val model = copyValues(new BisectingKMeansModel(uid, 
parentModel).setParent(this))
+    val summary = new BisectingKMeansSummary(
+      model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
+    model.setSummary(summary)
+    val m = model.setSummary(summary)
+    instr.logSuccess(m)
+    m
   }
 
   @Since("2.0.0")
@@ -251,3 +283,41 @@ object BisectingKMeans extends 
DefaultParamsReadable[BisectingKMeans] {
   @Since("2.0.0")
   override def load(path: String): BisectingKMeans = super.load(path)
 }
+
+
+/**
+ * :: Experimental ::
+ * Summary of BisectingKMeans.
+ *
+ * @param predictions  [[DataFrame]] produced by 
[[BisectingKMeansModel.transform()]]
+ * @param predictionCol  Name for column of predicted clusters in `predictions`
+ * @param featuresCol  Name for column of features in `predictions`
+ * @param k  Number of clusters
+ */
+@Since("2.1.0")
+@Experimental
+class BisectingKMeansSummary private[clustering] (
+    @Since("2.1.0") @transient val predictions: DataFrame,
+    @Since("2.1.0") val predictionCol: String,
+    @Since("2.1.0") val featuresCol: String,
+    @Since("2.1.0") val k: Int) extends Serializable {
+
+  /**
+   * Predicted cluster of each data point: the `predictionCol` column of `predictions`.
+   */
+  @Since("2.1.0")
+  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+  /**
+   * Size of (number of data points in) each cluster.
+   */
+  @Since("2.1.0")
+  lazy val clusterSizes: Array[Long] = {
+    val sizes = Array.fill[Long](k)(0)
+    cluster.groupBy(predictionCol).count().select(predictionCol, 
"count").collect().foreach {
+      case Row(cluster: Int, count: Long) => sizes(cluster) = count
+    }
+    sizes
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a1b136d0/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 4f7d441..f2368a9 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -68,7 +68,7 @@ class BisectingKMeansSuite
     }
   }
 
-  test("fit & transform") {
+  test("fit, transform and summary") {
     val predictionColName = "bisecting_kmeans_prediction"
     val bkm = new 
BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
     val model = bkm.fit(dataset)
@@ -85,6 +85,22 @@ class BisectingKMeansSuite
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
+
+    // Check validity of model summary
+    val numRows = dataset.count()
+    assert(model.hasSummary)
+    val summary: BisectingKMeansSummary = model.summary
+    assert(summary.predictionCol === predictionColName)
+    assert(summary.featuresCol === "features")
+    assert(summary.predictions.count() === numRows)
+    for (c <- Array(predictionColName, "features")) {
+      assert(summary.predictions.columns.contains(c))
+    }
+    assert(summary.cluster.columns === Array(predictionColName))
+    val clusterSizes = summary.clusterSizes
+    assert(clusterSizes.length === k)
+    assert(clusterSizes.sum === numRows)
+    assert(clusterSizes.forall(_ >= 0))
   }
 
   test("read/write") {

http://git-wip-us.apache.org/repos/asf/spark/blob/a1b136d0/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 04366f5..003fa6a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -70,7 +70,7 @@ class GaussianMixtureSuite extends SparkFunSuite with 
MLlibTestSparkContext
     }
   }
 
-  test("fit, transform, and summary") {
+  test("fit, transform and summary") {
     val predictionColName = "gm_prediction"
     val probabilityColName = "gm_probability"
     val gm = new 
GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName)

http://git-wip-us.apache.org/repos/asf/spark/blob/a1b136d0/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c9ba5a2..ca39265 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with 
MLlibTestSparkContext with DefaultR
     }
   }
 
-  test("fit, transform, and summary") {
+  test("fit, transform and summary") {
     val predictionColName = "kmeans_prediction"
     val kmeans = new 
KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
     val model = kmeans.fit(dataset)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to