Repository: spark
Updated Branches:
  refs/heads/master 1b50e0e0d -> 4d6d8192c
[SPARK-21268][MLLIB] Move center calculations to a distributed map in KMeans

## What changes were proposed in this pull request?

The scal() call and the creation of the newCenter vector are done in the driver, after a collectAsMap operation, although they could be done in the distributed RDD.
This PR moves this code before the collectAsMap for better efficiency.

## How was this patch tested?

This was tested manually by running the KMeansExample and verifying that the new code ran without error and gave the same output as before.

Author: dardelet <[email protected]>
Author: Guillaume Dardelet <[email protected]>

Closes #18491 from dardelet/move-center-calculation-to-distributed-map-kmean.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4d6d8192
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4d6d8192
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4d6d8192

Branch: refs/heads/master
Commit: 4d6d8192c807006ff89488a1d38bc6f7d41de5cf
Parents: 1b50e0e
Author: dardelet <[email protected]>
Authored: Tue Jul 4 17:58:44 2017 +0100
Committer: Sean Owen <[email protected]>
Committed: Tue Jul 4 17:58:44 2017 +0100

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4d6d8192/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index fa72b72..98e50c5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -272,8 +272,8 @@ class KMeans private (
       val costAccum = sc.doubleAccumulator
       val bcCenters = sc.broadcast(centers)
 
-      // Find the sum and count of points mapping to each center
-      val totalContribs = data.mapPartitions { points =>
+      // Find the new centers
+      val newCenters = data.mapPartitions { points =>
         val thisCenters = bcCenters.value
         val dims = thisCenters.head.vector.size
 
@@ -292,15 +292,16 @@ class KMeans private (
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
+      }.mapValues { case (sum, count) =>
+        scal(1.0 / count, sum)
+        new VectorWithNorm(sum)
       }.collectAsMap()
 
       bcCenters.destroy(blocking = false)
 
       // Update the cluster centers and costs
       converged = true
-      totalContribs.foreach { case (j, (sum, count)) =>
-        scal(1.0 / count, sum)
-        val newCenter = new VectorWithNorm(sum)
+      newCenters.foreach { case (j, newCenter) =>
         if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
           converged = false
         }
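The reordering in this patch can be illustrated without a cluster. The following is a minimal sketch in plain Scala, assuming RDD partitions can be modelled as ordinary collections; it is not Spark code, and the object and helper names (`KMeansCenterSketch`, `partialSums`, `mergeByKey`, `newCenters`, `update`) are hypothetical. It mirrors the patched flow: each partition accumulates per-center sums and counts, the partial results are merged by center index (the role of `reduceByKey`), and the division by the count happens before the result is collected (the role of the new `mapValues` step), so the driver-side loop only receives finished centers.

```scala
// Spark-free sketch of the patched KMeans update step.
// Partitions are modelled as plain Scala collections; all names are hypothetical.
object KMeansCenterSketch {
  type Point = Array[Double]

  // Squared Euclidean distance between two points of equal dimension.
  def sqDist(a: Point, b: Point): Double =
    a.indices.map { i => val d = a(i) - b(i); d * d }.sum

  // Index of the center closest to `p`.
  def closest(centers: IndexedSeq[Point], p: Point): Int =
    centers.indices.minBy(j => sqDist(centers(j), p))

  // Per-partition step: centerIndex -> (sum, count) for centers that received points.
  def partialSums(centers: IndexedSeq[Point], points: Seq[Point]): Map[Int, (Point, Long)] =
    points.groupBy(p => closest(centers, p)).map { case (j, ps) =>
      val sum = ps.reduce((x, y) => x.zip(y).map { case (u, v) => u + v })
      (j, (sum, ps.size.toLong))
    }

  // Plays the role of reduceByKey: merge partial (sum, count) pairs by center index.
  def mergeByKey(parts: Seq[Map[Int, (Point, Long)]]): Map[Int, (Point, Long)] =
    parts.flatten.groupBy(_._1).map { case (j, entries) =>
      val merged = entries.map(_._2).reduce { (a, b) =>
        (a._1.zip(b._1).map { case (u, v) => u + v }, a._2 + b._2)
      }
      (j, merged)
    }

  // Plays the role of the added mapValues step: divide by the count before the
  // result is "collected", so only finished centers reach the driver side.
  def newCenters(centers: IndexedSeq[Point], partitions: Seq[Seq[Point]]): Map[Int, Point] =
    mergeByKey(partitions.map(partialSums(centers, _))).map { case (j, (sum, count)) =>
      (j, sum.map(_ / count))
    }

  // Driver-side step after collection: replace centers and test convergence.
  def update(centers: Array[Point], collected: Map[Int, Point], epsilon: Double): Boolean = {
    var converged = true
    collected.foreach { case (j, c) =>
      if (converged && sqDist(c, centers(j)) > epsilon * epsilon) converged = false
      centers(j) = c
    }
    converged
  }

  def main(args: Array[String]): Unit = {
    val centers = Array(Array(0.0, 0.0), Array(10.0, 10.0))
    val partitions = Seq(
      Seq(Array(1.0, 1.0), Array(9.0, 9.0)),
      Seq(Array(3.0, 3.0), Array(11.0, 11.0))
    )
    val collected = newCenters(centers.toIndexedSeq, partitions)
    collected.toSeq.sortBy(_._1).foreach { case (j, c) =>
      println(s"center $j -> ${c.mkString("[", ", ", "]")}")
    }
    println(s"converged = ${update(centers, collected, 1e-4)}")
  }
}
```

Running `main` prints `center 0 -> [2.0, 2.0]`, `center 1 -> [10.0, 10.0]` and `converged = false`, because center 0 moves from the origin. The arithmetic is the same on either side of the collect; what changes is where the averaging runs. The sketch uses immutable arrays for clarity, whereas the patched Spark code rescales `sum` in place with `scal` and wraps it in a `VectorWithNorm`.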
