huaxingao commented on a change in pull request #28553:
URL: https://github.com/apache/spark/pull/28553#discussion_r429522403



##########
File path: 
mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
##########
@@ -360,43 +379,47 @@ private[evaluation] object SquaredEuclideanSilhouette 
extends Silhouette {
    * @param predictionCol The name of the column which contains the predicted 
cluster id
    *                      for the point.
    * @param featuresCol The name of the column which contains the feature 
vector of the point.
+   * @param weightCol The name of the column which contains the instance 
weight.
    * @return A [[scala.collection.immutable.Map]] which associates each 
cluster id
    *         to a [[ClusterStats]] object (which contains the precomputed 
values `N`,
    *         `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster).
    */
   def computeClusterStats(
     df: DataFrame,
     predictionCol: String,
-    featuresCol: String): Map[Double, ClusterStats] = {
+    featuresCol: String,
+    weightCol: String): Map[Double, ClusterStats] = {
     val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
-        col(predictionCol).cast(DoubleType), col(featuresCol), 
col("squaredNorm"))
+        col(predictionCol).cast(DoubleType), col(featuresCol), 
col("squaredNorm"), col(weightCol))
       .rdd
-      .map { row => (row.getDouble(0), (row.getAs[Vector](1), 
row.getDouble(2))) }
-      .aggregateByKey[(DenseVector, Double, 
Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))(
+      .map { row => (row.getDouble(0), (row.getAs[Vector](1), 
row.getDouble(2), row.getDouble(3))) }
+      .aggregateByKey
+        [(DenseVector, Double, Double)]((Vectors.zeros(numFeatures).toDense, 
0.0, 0.0))(
         seqOp = {
           case (
-              (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: 
Long),
-              (features, squaredNorm)
+              (featureSum: DenseVector, squaredNormSum: Double, weightSum: 
Double),
+              (features, squaredNorm, weight)
             ) =>
-            BLAS.axpy(1.0, features, featureSum)
-            (featureSum, squaredNormSum + squaredNorm, numOfPoints + 1)
+            require (weight >= 0.0, "illegal weight value: " + weight + " 
weight must be >= 0.0")

Review comment:
       I think it's better to do the check here, so it doesn't require an extra 
pass over the data just to validate all the weights.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to