Github user imatiach-msft commented on a diff in the pull request:
https://github.com/apache/spark/pull/17086#discussion_r182300738
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
---
@@ -75,11 +80,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0")
(@Since("1.5.0") overrid
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(labelCol))
- val predictionAndLabels =
- dataset.select(col($(predictionCol)),
col($(labelCol)).cast(DoubleType)).rdd.map {
- case Row(prediction: Double, label: Double) => (prediction, label)
+ val predictionAndLabelsWithWeights =
+ dataset.select(col($(predictionCol)),
col($(labelCol)).cast(DoubleType),
+ if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else
col($(weightCol)))
+ .rdd.map {
+ case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
}
- val metrics = new MulticlassMetrics(predictionAndLabels)
+ dataset.select(col($(predictionCol)),
col($(labelCol)).cast(DoubleType)).rdd.map {
+ case Row(prediction: Double, label: Double) => (prediction, label)
+ }.values.countByValue()
--- End diff --
Good catch -- that shouldn't be there, and I'm not sure why I added it. Removed.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org