Repository: spark
Updated Branches:
  refs/heads/master 469276965 -> 3f6296fed


[SPARK-8018] [MLLIB] KMeans should accept initial cluster centers as param

This allows KMeans to be initialized using an existing set of cluster centers
provided as a KMeansModel object. This mode of initialization performs a
single run.

Author: FlytxtRnD <[email protected]>

Closes #6737 from FlytxtRnD/Kmeans-8018 and squashes the following commits:

94b56df [FlytxtRnD] style correction
ef95ee2 [FlytxtRnD] style correction
c446c58 [FlytxtRnD] documentation and numRuns warning change
06d13ef [FlytxtRnD] numRuns corrected
d12336e [FlytxtRnD] numRuns variable modifications
07f8554 [FlytxtRnD] remove setRuns from setIntialModel
e721dfe [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into 
Kmeans-8018
242ead1 [FlytxtRnD] corrected == to === in assert
714acb5 [FlytxtRnD] added numRuns
60c8ce2 [FlytxtRnD] ignore runs parameter and initialModel test suite changed
582e6d9 [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into 
Kmeans-8018
3f5fc8e [FlytxtRnD] test case modified and one runs condition added
cd5dc5c [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into 
Kmeans-8018
16f1b53 [FlytxtRnD] Merge branch 'Kmeans-8018', remote-tracking branch 
'upstream/master' into Kmeans-8018
e9c35d7 [FlytxtRnD] Remove getInitialModel and match cluster count criteria
6959861 [FlytxtRnD] Accept initial cluster centers in KMeans


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f6296fe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f6296fe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f6296fe

Branch: refs/heads/master
Commit: 3f6296fed4ee10f53e728eb1e02f13338839b94d
Parents: 4692769
Author: FlytxtRnD <[email protected]>
Authored: Tue Jul 14 23:29:02 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Tue Jul 14 23:29:02 2015 -0700

----------------------------------------------------------------------
 docs/mllib-clustering.md                        |  1 +
 .../apache/spark/mllib/clustering/KMeans.scala  | 41 +++++++++++++++++---
 .../spark/mllib/clustering/KMeansSuite.scala    | 22 +++++++++++
 3 files changed, 58 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3f6296fe/docs/mllib-clustering.md
----------------------------------------------------------------------
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index d72dc20..0fc7036 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run 
multiple times on
 a given dataset, the algorithm returns the best clustering result).
 * *initializationSteps* determines the number of steps in the k-means\|\| 
algorithm.
 * *epsilon* determines the distance threshold within which we consider k-means 
to have converged.
+* *initialModel* is an optional set of cluster centers used for 
initialization. If this parameter is supplied, only one run is performed.
 
 **Examples**
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3f6296fe/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0f8d6a3..6829713 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -156,6 +156,21 @@ class KMeans private (
     this
   }
 
+  // Initial cluster centers can be provided as a KMeansModel object rather 
than using the
+  // random or k-means|| initializationMode
+  private var initialModel: Option[KMeansModel] = None
+
+  /**
+   * Set the initial starting point, bypassing the random initialization or 
k-means||
+   * The condition model.k == this.k must be met, failure results
+   * in an IllegalArgumentException.
+   */
+  def setInitialModel(model: KMeansModel): this.type = {
+    require(model.k == k, "mismatched cluster count")
+    initialModel = Some(model)
+    this
+  }
+
   /**
    * Train a K-means model on the given set of points; `data` should be cached 
for high
    * performance, because this is an iterative algorithm.
@@ -193,20 +208,34 @@ class KMeans private (
 
     val initStartTime = System.nanoTime()
 
-    val centers = if (initializationMode == KMeans.RANDOM) {
-      initRandom(data)
+    // Only one run is allowed when initialModel is given
+    val numRuns = if (initialModel.nonEmpty) {
+      if (runs > 1) logWarning("Ignoring runs; one run is allowed when 
initialModel is given.")
+      1
     } else {
-      initKMeansParallel(data)
+      runs
     }
 
+    val centers = initialModel match {
+      case Some(kMeansCenters) => {
+        Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
+      }
+      case None => {
+        if (initializationMode == KMeans.RANDOM) {
+          initRandom(data)
+        } else {
+          initKMeansParallel(data)
+        }
+      }
+    }
     val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
     logInfo(s"Initialization with $initializationMode took " + 
"%.3f".format(initTimeInSeconds) +
       " seconds.")
 
-    val active = Array.fill(runs)(true)
-    val costs = Array.fill(runs)(0.0)
+    val active = Array.fill(numRuns)(true)
+    val costs = Array.fill(numRuns)(0.0)
 
-    var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
+    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
     var iteration = 0
 
     val iterationStartTime = System.nanoTime()

http://git-wip-us.apache.org/repos/asf/spark/blob/3f6296fe/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 0dbbd71..3003c62 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       }
     }
   }
+
+  test("Initialize using given cluster centers") {
+    val points = Seq(
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(1.0, 0.0),
+      Vectors.dense(0.0, 1.0),
+      Vectors.dense(1.0, 1.0)
+    )
+    val rdd = sc.parallelize(points, 3)
+    // creating an initial model
+    val initialModel = new KMeansModel(Array(points(0), points(2)))
+
+    val returnModel = new KMeans()
+      .setK(2)
+      .setMaxIterations(0)
+      .setInitialModel(initialModel)
+      .run(rdd)
+   // comparing the returned model and the initial model
+    assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
+    assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
+  }
+
 }
 
 object KMeansSuite extends SparkFunSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to