Github user srowen commented on a diff in the pull request: https://github.com/apache/spark/pull/19340#discussion_r161379528 --- Diff: mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala --- @@ -546,10 +577,109 @@ object KMeans { .run(data) } + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +/** + * A vector with its norm for fast distance computation. + * + * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]] + */ +private[clustering] +class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable { + + def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) + + def this(array: Array[Double]) = this(Vectors.dense(array)) + + /** Converts the vector to a dense vector. */ + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +} + + +private[spark] abstract class DistanceMeasure extends Serializable { + /** * Returns the index of the closest center to the given point, as well as the squared distance. */ - private[mllib] def findClosest( + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * Returns the K-means cost of a given point against the given cluster centers. + */ + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = + findClosest(centers, point)._2 + + /** + * Returns whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = + distance(oldCenter, newCenter) <= epsilon + + /** + * Computes the cosine distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + +} + +@Since("2.3.0") --- End diff -- All the "2.3.0" would likely have to change. I don't know if this would get in for 2.3.0.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org