huaxingao commented on a change in pull request #28553:
URL: https://github.com/apache/spark/pull/28553#discussion_r429522403
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
##########
@@ -360,43 +379,47 @@ private[evaluation] object SquaredEuclideanSilhouette
extends Silhouette {
* @param predictionCol The name of the column which contains the predicted cluster id
* for the point.
* @param featuresCol The name of the column which contains the feature vector of the point.
+ * @param weightCol The name of the column which contains the instance weight.
* @return A [[scala.collection.immutable.Map]] which associates each cluster id
* to a [[ClusterStats]] object (which contains the precomputed values `N`,
* `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster).
*/
def computeClusterStats(
df: DataFrame,
predictionCol: String,
- featuresCol: String): Map[Double, ClusterStats] = {
+ featuresCol: String,
+ weightCol: String): Map[Double, ClusterStats] = {
val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
val clustersStatsRDD = df.select(
- col(predictionCol).cast(DoubleType), col(featuresCol),
col("squaredNorm"))
+ col(predictionCol).cast(DoubleType), col(featuresCol),
col("squaredNorm"), col(weightCol))
.rdd
- .map { row => (row.getDouble(0), (row.getAs[Vector](1),
row.getDouble(2))) }
- .aggregateByKey[(DenseVector, Double,
Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))(
+ .map { row => (row.getDouble(0), (row.getAs[Vector](1),
row.getDouble(2), row.getDouble(3))) }
+ .aggregateByKey
+ [(DenseVector, Double, Double)]((Vectors.zeros(numFeatures).toDense,
0.0, 0.0))(
seqOp = {
case (
- (featureSum: DenseVector, squaredNormSum: Double, numOfPoints:
Long),
- (features, squaredNorm)
+ (featureSum: DenseVector, squaredNormSum: Double, weightSum:
Double),
+ (features, squaredNorm, weight)
) =>
- BLAS.axpy(1.0, features, featureSum)
- (featureSum, squaredNormSum + squaredNorm, numOfPoints + 1)
+ require (weight >= 0.0, "illegal weight value: " + weight + "
weight must be >= 0.0")
Review comment:
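I think it's better to validate the weight here, inside `seqOp`, so that rejecting negative weights happens during the same `aggregateByKey` pass and doesn't require an extra pass over the data just to collect all the weights.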
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]