Github user yanboliang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/10306#discussion_r47878122
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala ---
    @@ -250,114 +240,142 @@ class KMeans private (
             }
           }
         }
    +
         val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
         logInfo(s"Initialization with $initializationMode took " + 
"%.3f".format(initTimeInSeconds) +
           " seconds.")
     
    -    val active = Array.fill(numRuns)(true)
    -    val costs = Array.fill(numRuns)(0.0)
    -
    -    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
    +    var costs = 0.0
         var iteration = 0
    -
         val iterationStartTime = System.nanoTime()
    +    val isSparse = data.take(1)(0).vector.isInstanceOf[SparseVector]
     
    -    // Execute iterations of Lloyd's algorithm until all runs have 
converged
    -    while (iteration < maxIterations && !activeRuns.isEmpty) {
    +    // Execute Lloyd's algorithm until converged or reached the max number 
of iterations
    +    while (iteration < maxIterations) {
           type WeightedPoint = (Vector, Long)
           def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint 
= {
             axpy(1.0, x._1, y._1)
             (y._1, x._2 + y._2)
           }
     
    -      val activeCenters = activeRuns.map(r => centers(r)).toArray
    -      val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
    -
    -      val bcActiveCenters = sc.broadcast(activeCenters)
    +      val costAccums = sc.accumulator(0.0)
    +      val bcCenters = sc.broadcast(centers)
     
           // Find the sum and count of points mapping to each center
           val totalContribs = data.mapPartitions { points =>
    -        val thisActiveCenters = bcActiveCenters.value
    -        val runs = thisActiveCenters.length
    -        val k = thisActiveCenters(0).length
    -        val dims = thisActiveCenters(0)(0).vector.size
    +        val thisCenters = bcCenters.value
    +        val k = thisCenters.length
    +        val dims = thisCenters(0).vector.size
    +
    +        val sums = Array.fill(k)(Vectors.zeros(dims))
    +        val counts = Array.fill(k)(0L)
     
    -        val sums = Array.fill(runs, k)(Vectors.zeros(dims))
    -        val counts = Array.fill(runs, k)(0L)
    +        val vectorOfPoints = new ArrayBuffer[Vector]()
    +        val normOfPoints = new ArrayBuffer[Double]()
    +        var numRows = 0
     
    +        // Construct points matrix
             points.foreach { point =>
    -          (0 until runs).foreach { i =>
    -            val (bestCenter, cost) = 
KMeans.findClosest(thisActiveCenters(i), point)
    -            costAccums(i) += cost
    -            val sum = sums(i)(bestCenter)
    -            axpy(1.0, point.vector, sum)
    -            counts(i)(bestCenter) += 1
    +          vectorOfPoints.append(point.vector)
    +          normOfPoints.append(point.norm)
    +          numRows += 1
    +        }
    +
    +        val pointMatrix = if (isSparse) {
    +          val coo = new ArrayBuffer[(Int, Int, Double)]()
    +          vectorOfPoints.zipWithIndex.foreach { v =>
    +            val sv = v._1.asInstanceOf[SparseVector]
    +            sv.indices.indices.foreach { i =>
    +              coo.append((v._2, sv.indices(i), sv.values(i)))
    +            }
               }
    +          SparseMatrix.fromCOO(numRows, dims, coo.toSeq)
    +        } else {
    +          new DenseMatrix(numRows, dims, 
vectorOfPoints.flatMap(_.toArray).toArray, true)
             }
     
    -        val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
    -          ((i, j), (sums(i)(j), counts(i)(j)))
    +        // Construct centers matrix
    +        val vectorOfCenters = new ArrayBuffer[Double]()
    +        val normOfCenters = new ArrayBuffer[Double]()
    +        thisCenters.foreach { center =>
    +          vectorOfCenters.appendAll(center.vector.toArray)
    +          normOfCenters.append(center.norm)
    +        }
    +        val centerMatrix = new DenseMatrix(dims, k, 
vectorOfCenters.toArray)
    +
    +        val a2b2 = new ArrayBuffer[Double]()
    +        val normOfPointsArray = normOfPoints.toArray
    +        val normOfCentersArray = normOfCenters.toArray
    +        for (i <- 0 until k; j <- 0 until numRows) {
    +          a2b2.append(normOfPointsArray(j) * normOfPointsArray(j) +
    +            normOfCentersArray(i) * normOfCentersArray(i))
    +        }
    +
    +        val distanceMatrix = new DenseMatrix(numRows, k, a2b2.toArray)
    +        gemm(-2.0, pointMatrix, centerMatrix, 1.0, distanceMatrix)
    +
    +        val vectorOfPointsArray = vectorOfPoints.toArray
    +        
distanceMatrix.transpose.toArray.grouped(k).toArray.map(_.zipWithIndex.min).zipWithIndex
    --- End diff --
    
    Good point!


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not
working, please contact infrastructure at [email protected] or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to