This is an automated email from the ASF dual-hosted git repository. huaxingao pushed a commit to branch branch-3.2 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push: new d5e90cf [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K d5e90cf is described below commit d5e90cf5ecf287eb53234e25e3a4cc37794360f2 Author: Ruifeng Zheng <ruife...@foxmail.com> AuthorDate: Wed Mar 2 11:51:06 2022 -0800 [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K ### What changes were proposed in this pull request? SPARK-31007 introduced auxiliary statistics to speed up computation in KMeans. However, they require an array of size `k * (k + 1) / 2`, which may cause overflow or OOM when k is too large, so we should skip this optimization in that case. ### Why are the changes needed? To avoid overflow or OOM when k is too large (e.g. 50,000). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test suites. Closes #35457 from zhengruifeng/kmean_k_limit. Authored-by: Ruifeng Zheng <ruife...@foxmail.com> Signed-off-by: huaxingao <huaxin_...@apple.com> (cherry picked from commit ad5427ebe644fc01a9b4c19a48f902f584245edf) Signed-off-by: huaxingao <huaxin_...@apple.com> --- .../spark/mllib/clustering/DistanceMeasure.scala | 23 ++++++++++++++++++++++ .../org/apache/spark/mllib/clustering/KMeans.scala | 15 ++++++++++---- .../spark/mllib/clustering/KMeansModel.scala | 11 +++++++++-- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala index 9ac473a..e4c29a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -118,6 +118,24 @@ private[spark] abstract class DistanceMeasure extends Serializable { } /** + * @param centers the clustering centers + * @param statistics optional statistics to accelerate the computation, which should not + * 
change the result. + * @param point given point + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: Array[VectorWithNorm], + statistics: Option[Array[Double]], + point: VectorWithNorm): (Int, Double) = { + if (statistics.nonEmpty) { + findClosest(centers, statistics.get, point) + } else { + findClosest(centers, point) + } + } + + /** * @return the index of the closest center to the given point, as well as the cost. */ def findClosest( @@ -253,6 +271,11 @@ object DistanceMeasure { case _ => false } } + + private[clustering] def shouldComputeStatistics(k: Int): Boolean = k < 1000 + + private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean = + k.toLong * k * numFeatures < 1000000 } private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 76e2928..c140b1b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -269,15 +269,22 @@ class KMeans private ( instr.foreach(_.logNumFeatures(numFeatures)) - val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L + val shouldComputeStats = + DistanceMeasure.shouldComputeStatistics(centers.length) + val shouldComputeStatsLocally = + DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures) // Execute iterations of Lloyd's algorithm until converged while (iteration < maxIterations && !converged) { val bcCenters = sc.broadcast(centers) - val stats = if (shouldDistributed) { - distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters) + val stats = if (shouldComputeStats) { + if (shouldComputeStatsLocally) { + Some(distanceMeasureInstance.computeStatistics(centers)) + } else { + 
Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)) + } } else { - distanceMeasureInstance.computeStatistics(centers) + None } val bcStats = sc.broadcast(stats) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a24493b..64b3521 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -50,9 +50,16 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], // TODO: computation of statistics may take seconds, so save it to KMeansModel in training @transient private lazy val statistics = if (clusterCenters == null) { - null + None } else { - distanceMeasureInstance.computeStatistics(clusterCentersWithNorm) + val k = clusterCenters.length + val numFeatures = clusterCenters.head.size + if (DistanceMeasure.shouldComputeStatistics(k) && + DistanceMeasure.shouldComputeStatisticsLocally(k, numFeatures)) { + Some(distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)) + } else { + None + } } @Since("2.4.0") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org