zhengruifeng commented on a change in pull request #29501:
URL: https://github.com/apache/spark/pull/29501#discussion_r474389914
##########
File path: mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
##########
@@ -210,27 +210,26 @@ class KMeans private (
@Since("0.8.0")
def run(data: RDD[Vector]): KMeansModel = {
val instances = data.map(point => (point, 1.0))
- runWithWeight(instances, None)
+ val handlePersistence = data.getStorageLevel == StorageLevel.NONE
+ runWithWeight(instances, handlePersistence, None)
}
private[spark] def runWithWeight(
- data: RDD[(Vector, Double)],
+ instances: RDD[(Vector, Double)],
+ handlePersistence: Boolean,
instr: Option[Instrumentation]): KMeansModel = {
+ val norms = instances.map { case (v, _) => Vectors.norm(v, 2.0) }
+ val vectors = instances.zip(norms)
+ .map { case ((v, w), norm) => new VectorWithNorm(v, norm, w) }
- // Compute squared norms and cache them.
- val norms = data.map { case (v, _) =>
- Vectors.norm(v, 2.0)
- }
-
- val zippedData = data.zip(norms).map { case ((v, w), norm) =>
- new VectorWithNorm(v, norm, w)
- }
-
- if (data.getStorageLevel == StorageLevel.NONE) {
- zippedData.persist(StorageLevel.MEMORY_AND_DISK)
+ if (handlePersistence) {
Review comment:
If we need to persist the training dataset, we persist `vectors` directly;
otherwise, we persist `norms`, per the existing comment: "Compute squared
norms and cache them."
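
To make the suggestion concrete, here is a minimal sketch of the proposed branching, reusing the names already in scope in `runWithWeight` from the diff above (`instances`, `handlePersistence`, `VectorWithNorm`). This is an illustration of the comment, not the merged implementation:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.storage.StorageLevel

val norms = instances.map { case (v, _) => Vectors.norm(v, 2.0) }
val vectors = instances.zip(norms)
  .map { case ((v, w), norm) => new VectorWithNorm(v, norm, w) }

if (handlePersistence) {
  // The caller did not cache the input, so cache the fully prepared
  // training vectors that every iteration will scan.
  vectors.persist(StorageLevel.MEMORY_AND_DISK)
} else {
  // The input is already cached; caching just the norms still avoids
  // recomputing Vectors.norm on every pass, which is what the original
  // comment ("Compute squared norms and cache them.") was there for.
  norms.persist(StorageLevel.MEMORY_AND_DISK)
}
```

Either way only one of the two RDDs is persisted, which keeps memory pressure down while still avoiding recomputation inside the iterations.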