srowen commented on a change in pull request #28553:
URL: https://github.com/apache/spark/pull/28553#discussion_r429642262
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
##########
@@ -360,43 +379,47 @@ private[evaluation] object SquaredEuclideanSilhouette
extends Silhouette {
* @param predictionCol The name of the column which contains the predicted
cluster id
* for the point.
* @param featuresCol The name of the column which contains the feature
vector of the point.
+ * @param weightCol The name of the column which contains the instance
weight.
* @return A [[scala.collection.immutable.Map]] which associates each
cluster id
* to a [[ClusterStats]] object (which contains the precomputed
values `N`,
* `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster).
*/
def computeClusterStats(
df: DataFrame,
predictionCol: String,
- featuresCol: String): Map[Double, ClusterStats] = {
+ featuresCol: String,
+ weightCol: String): Map[Double, ClusterStats] = {
val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
val clustersStatsRDD = df.select(
- col(predictionCol).cast(DoubleType), col(featuresCol),
col("squaredNorm"))
+ col(predictionCol).cast(DoubleType), col(featuresCol),
col("squaredNorm"), col(weightCol))
.rdd
- .map { row => (row.getDouble(0), (row.getAs[Vector](1),
row.getDouble(2))) }
- .aggregateByKey[(DenseVector, Double,
Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))(
+ .map { row => (row.getDouble(0), (row.getAs[Vector](1),
row.getDouble(2), row.getDouble(3))) }
+ .aggregateByKey
+ [(DenseVector, Double, Double)]((Vectors.zeros(numFeatures).toDense,
0.0, 0.0))(
seqOp = {
case (
- (featureSum: DenseVector, squaredNormSum: Double, numOfPoints:
Long),
- (features, squaredNorm)
+ (featureSum: DenseVector, squaredNormSum: Double, weightSum:
Double),
+ (features, squaredNorm, weight)
) =>
- BLAS.axpy(1.0, features, featureSum)
- (featureSum, squaredNormSum + squaredNorm, numOfPoints + 1)
+ require (weight >= 0.0, "illegal weight value: " + weight + "
weight must be >= 0.0")
 Review comment:
   Looks good; it's consistent with your other change. Really minor: use
string interpolation (e.g. `s"illegal weight value: $weight. weight must be >= 0.0"`) instead of `+` concatenation?
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
##########
@@ -555,30 +583,35 @@ private[evaluation] object CosineSilhouette extends
Silhouette {
* about a cluster which are needed by the algorithm.
*
* @param df The DataFrame which contains the input data
+ * @param featuresCol The name of the column which contains the feature
vector of the point.
* @param predictionCol The name of the column which contains the predicted
cluster id
* for the point.
+ * @param weightCol The name of the column which contains the instance
weight.
 * @return A [[scala.collection.immutable.Map]] which associates each
cluster id to
 * its statistics (i.e. the precomputed values `N` and
`$\Omega_{\Gamma}$`).
*/
def computeClusterStats(
df: DataFrame,
featuresCol: String,
- predictionCol: String): Map[Double, (Vector, Long)] = {
+ predictionCol: String,
+ weightCol: String): Map[Double, (Vector, Double)] = {
val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
val clustersStatsRDD = df.select(
- col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName))
+ col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName),
col(weightCol))
.rdd
- .map { row => (row.getDouble(0), row.getAs[Vector](1)) }
- .aggregateByKey[(DenseVector,
Long)]((Vectors.zeros(numFeatures).toDense, 0L))(
+ .map { row => (row.getDouble(0), (row.getAs[Vector](1),
row.getDouble(2))) }
+ .aggregateByKey[(DenseVector,
Double)]((Vectors.zeros(numFeatures).toDense, 0.0))(
seqOp = {
- case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long),
(normalizedFeatures)) =>
- BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum)
- (normalizedFeaturesSum, numOfPoints + 1)
+ case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
+ (normalizedFeatures, weight)) =>
+ require (weight >= 0.0, "illegal weight value: " + weight + " weight
must be >= 0.0")
 Review comment:
   Same string-interpolation suggestion here. Also a nit: remove the space after `require` — write `require(...)`, not `require (...)`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]