srowen closed pull request #17085: [SPARK-24102][ML][MLLIB] ML Evaluators 
should use weight column - added weight column for regression evaluator
URL: https://github.com/apache/spark/pull/17085
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not show its diff
after the merge, so the diff is reproduced below:

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 031cd0d635bf4..616569bb55e4c 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, 
HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
 @Since("1.4.0")
 @Experimental
 final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val 
uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol with 
DefaultParamsWritable {
+  extends Evaluator with HasPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("regEval"))
@@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") 
(@Since("1.4.0") override val ui
   @Since("1.4.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "rmse")
 
   @Since("2.0.0")
@@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") 
(@Since("1.4.0") override val ui
     SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, 
FloatType))
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val predictionAndLabels = dataset
-      .select(col($(predictionCol)).cast(DoubleType), 
col($(labelCol)).cast(DoubleType))
+    val predictionAndLabelsWithWeights = dataset
+      .select(col($(predictionCol)).cast(DoubleType), 
col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else 
col($(weightCol)))
       .rdd
-      .map { case Row(prediction: Double, label: Double) => (prediction, 
label) }
-    val metrics = new RegressionMetrics(predictionAndLabels)
+      .map { case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight) }
+    val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "rmse" => metrics.rootMeanSquaredError
       case "mse" => metrics.meanSquaredError
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 020676cac5a64..525047973ad5c 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame
 /**
  * Evaluator for regression.
  *
- * @param predictionAndObservations an RDD of (prediction, observation) pairs
+ * @param predAndObsWithOptWeight an RDD of either (prediction, observation, 
weight)
+ *                                                    or (prediction, 
observation) pairs
  * @param throughOrigin True if the regression is through the origin. For 
example, in linear
  *                      regression, it will be true without fitting intercept.
  */
 @Since("1.2.0")
 class RegressionMetrics @Since("2.0.0") (
-    predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
+    predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
     extends Logging {
 
   @Since("1.2.0")
-  def this(predictionAndObservations: RDD[(Double, Double)]) =
+  def this(predictionAndObservations: RDD[_ <: Product]) =
     this(predictionAndObservations, false)
 
   /**
@@ -52,10 +53,13 @@ class RegressionMetrics @Since("2.0.0") (
    * Use MultivariateOnlineSummarizer to calculate summary statistics of 
observations and errors.
    */
   private lazy val summary: MultivariateStatisticalSummary = {
-    val summary: MultivariateStatisticalSummary = 
predictionAndObservations.map {
-      case (prediction, observation) => Vectors.dense(observation, observation 
- prediction)
+    val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
+      case (prediction: Double, observation: Double, weight: Double) =>
+        (Vectors.dense(observation, observation - prediction), weight)
+      case (prediction: Double, observation: Double) =>
+        (Vectors.dense(observation, observation - prediction), 1.0)
     }.treeAggregate(new MultivariateOnlineSummarizer())(
-        (summary, v) => summary.add(v),
+        (summary, sample) => summary.add(sample._1, sample._2),
         (sum1, sum2) => sum1.merge(sum2)
       )
     summary
@@ -63,11 +67,13 @@ class RegressionMetrics @Since("2.0.0") (
 
   private lazy val SSy = math.pow(summary.normL2(0), 2)
   private lazy val SSerr = math.pow(summary.normL2(1), 2)
-  private lazy val SStot = summary.variance(0) * (summary.count - 1)
+  private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
   private lazy val SSreg = {
     val yMean = summary.mean(0)
-    predictionAndObservations.map {
-      case (prediction, _) => math.pow(prediction - yMean, 2)
+    predAndObsWithOptWeight.map {
+      case (prediction: Double, _: Double, weight: Double) =>
+        math.pow(prediction - yMean, 2) * weight
+      case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
     }.sum()
   }
 
@@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") (
    */
   @Since("1.2.0")
   def explainedVariance: Double = {
-    SSreg / summary.count
+    SSreg / summary.weightSum
   }
 
   /**
@@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") (
    */
   @Since("1.2.0")
   def meanAbsoluteError: Double = {
-    summary.normL1(1) / summary.count
+    summary.normL1(1) / summary.weightSum
   }
 
   /**
@@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") (
    */
   @Since("1.2.0")
   def meanSquaredError: Double = {
-    SSerr / summary.count
+    SSerr / summary.weightSum
   }
 
   /**
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 0554b6d8ff5b5..6d510e1633d67 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
   private var totalCnt: Long = 0
   private var totalWeightSum: Double = 0.0
   private var weightSquareSum: Double = 0.0
-  private var weightSum: Array[Double] = _
+  private var currWeightSum: Array[Double] = _
   private var nnz: Array[Long] = _
   private var currMax: Array[Double] = _
   private var currMin: Array[Double] = _
@@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
       currM2n = Array.ofDim[Double](n)
       currM2 = Array.ofDim[Double](n)
       currL1 = Array.ofDim[Double](n)
-      weightSum = Array.ofDim[Double](n)
+      currWeightSum = Array.ofDim[Double](n)
       nnz = Array.ofDim[Long](n)
       currMax = Array.fill[Double](n)(Double.MinValue)
       currMin = Array.fill[Double](n)(Double.MaxValue)
@@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
     val localCurrM2n = currM2n
     val localCurrM2 = currM2
     val localCurrL1 = currL1
-    val localWeightSum = weightSum
+    val localWeightSum = currWeightSum
     val localNumNonzeros = nnz
     val localCurrMax = currMax
     val localCurrMin = currMin
@@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
       weightSquareSum += other.weightSquareSum
       var i = 0
       while (i < n) {
-        val thisNnz = weightSum(i)
-        val otherNnz = other.weightSum(i)
+        val thisNnz = currWeightSum(i)
+        val otherNnz = other.currWeightSum(i)
         val totalNnz = thisNnz + otherNnz
         val totalCnnz = nnz(i) + other.nnz(i)
         if (totalNnz != 0.0) {
@@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
           currMax(i) = math.max(currMax(i), other.currMax(i))
           currMin(i) = math.min(currMin(i), other.currMin(i))
         }
-        weightSum(i) = totalNnz
+        currWeightSum(i) = totalNnz
         nnz(i) = totalCnnz
         i += 1
       }
@@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
       this.totalCnt = other.totalCnt
       this.totalWeightSum = other.totalWeightSum
       this.weightSquareSum = other.weightSquareSum
-      this.weightSum = other.weightSum.clone()
+      this.currWeightSum = other.currWeightSum.clone()
       this.nnz = other.nnz.clone()
       this.currMax = other.currMax.clone()
       this.currMin = other.currMin.clone()
@@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
     val realMean = Array.ofDim[Double](n)
     var i = 0
     while (i < n) {
-      realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
+      realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
       i += 1
     }
     Vectors.dense(realMean)
@@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
       val len = currM2n.length
       while (i < len) {
         // We prevent variance from negative value caused by numerical error.
-        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * 
weightSum(i) *
-          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
+        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * 
currWeightSum(i) *
+          (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 
0.0)
         i += 1
       }
     }
@@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
   @Since("1.1.0")
   override def count: Long = totalCnt
 
+  /**
+   * Sum of weights.
+   */
+  override def weightSum: Double = totalWeightSum
+
   /**
    * Number of nonzero elements in each dimension.
    *
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
index 39a16fb743d64..a4381032f8c0d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
@@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary {
   @Since("1.0.0")
   def count: Long
 
+  /**
+   * Sum of weights.
+   */
+  @Since("3.0.0")
+  def weightSum: Double
+
   /**
    * Number of nonzero elements (including explicitly presented zero values) 
in each column.
    */
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index f1d517383643d..23809777f7d3a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       "root mean squared error mismatch")
     assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
   }
+
+  test("regression metrics with same (1.0) weight samples") {
+    val predictionAndObservationWithWeight = sc.parallelize(
+      Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 
1.0)), 2)
+    val metrics = new RegressionMetrics(predictionAndObservationWithWeight, 
false)
+    assert(metrics.explainedVariance ~== 8.79687 absTol eps,
+      "explained variance regression score mismatch")
+    assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error 
mismatch")
+    assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error 
mismatch")
+    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps,
+      "root mean squared error mismatch")
+    assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch")
+  }
+
+  /**
+   * The following values are hand calculated using the formula:
+   * 
[[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
+   * preds = c(2.25, -0.25, 1.75, 7.75)
+   * obs = c(3.0, -0.5, 2.0, 7.0)
+   * weights = c(0.1, 0.2, 0.15, 0.05)
+   * count = 4
+   *
+   * Weighted metrics can be calculated with MultivariateStatisticalSummary.
+   *             (observations, observations - predictions)
+   * mean        (1.7, 0.05)
+   * variance    (7.3, 0.3)
+   * numNonZeros (0.5, 0.5)
+   * max         (7.0, 0.75)
+   * min         (-0.5, -0.75)
+   * normL2      (2.0, 0.32596)
+   * normL1      (1.05, 0.2)
+   *
+   * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 
5.2425
+   * meanAbsoluteError: normL1(1) / weightedCount = 0.4
+   * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
+   * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
+   * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
+   */
+  test("regression metrics with weighted samples") {
+    val predictionAndObservationWithWeight = sc.parallelize(
+      Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 
0.05)), 2)
+    val metrics = new RegressionMetrics(predictionAndObservationWithWeight, 
false)
+    assert(metrics.explainedVariance ~== 5.2425 absTol eps,
+      "explained variance regression score mismatch")
+    assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error 
mismatch")
+    assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error 
mismatch")
+    assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps,
+      "root mean squared error mismatch")
+    assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch")
+  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b3252d70a80c8..883913332ca1e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -531,7 +531,10 @@ object MimaExcludes {
     
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
     
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
     
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
-    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
+    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),
+
+    // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
+    
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
   ) ++ Seq(
       // [SPARK-17019] Expose on-heap and off-heap memory usage in various 
places
       
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to