Github user avulanov commented on a diff in the pull request:
https://github.com/apache/spark/pull/10806#discussion_r54820188
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala ---
@@ -204,17 +204,54 @@ class KMeans private (
+ " parent RDDs are also uncached.")
}
- // Compute squared norms and cache them.
- val norms = data.map(Vectors.norm(_, 2.0))
- norms.persist()
- val zippedData = data.zip(norms).map { case (v, norm) =>
- new VectorWithNorm(v, norm)
+ val zippedData = data.map { x =>
+ val norm = Vectors.norm(x, 2.0)
+ new VectorWithNorm(x, norm)
}
- val model = runAlgorithm(zippedData)
- norms.unpersist()
+
+ val centers = initialModel match {
+ case Some(kMeansCenters) => {
+ kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))
+ }
+ case None => {
+ if (initializationMode == KMeans.RANDOM) {
+ initRandom(zippedData)
+ } else {
+ initKMeansParallel(zippedData)
+ }
+ }
+ }
+
+ val samplePoint = data.first()
+ val dims = samplePoint.size
+ // TODO: make stack size can be configured.
+ val stackSize = 128
+
+ val blockData = zippedData.mapPartitions { iter =>
+ iter.grouped(stackSize).map { points =>
+ val realSize = points.size
+ val pointsArray = new Array[Double](realSize * dims)
+ val pointsNormArray = new Array[Double](realSize)
+ var numRows = 0
+
+ points.foreach { point =>
+ System.arraycopy(point.vector.toArray, 0, pointsArray, numRows *
dims, dims)
+ pointsNormArray(numRows) = math.pow(point.norm, 2.0)
+ numRows += 1
+ }
+ val pointMatrix = new DenseMatrix(numRows, dims, pointsArray, true)
+ val pointsNormMatrix = new DenseMatrix(numRows, k,
Array.fill(k)(pointsNormArray).flatten)
+
+ (pointMatrix, pointsNormMatrix)
+ }
+ }
+
+ blockData.persist()
+ val model = runAlgorithm(blockData, centers)
+ blockData.unpersist()
// Warn at the end of the run as well, for increased visibility.
- if (data.getStorageLevel == StorageLevel.NONE) {
+ if (blockData.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt
performance if its"
+ " parent RDDs are also uncached.")
--- End diff --
It is hard to understand the difference between the `input data`,
`blockData`, and the original RDD that is provided by the user. The message
should be clearer.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]