GitHub user thunterdb commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9907#discussion_r49002633
  
    --- Diff: mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala ---
    @@ -109,4 +109,55 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
           "root mean squared error mismatch")
         assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
       }
    +
    +  test("regression metrics with same(1.0) weight samples") {
    +    val predictionAndObservationWithWeight = sc.parallelize(
    +      Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
    +    val metrics = new RegressionMetrics(predictionAndObservationWithWeight)
    +    assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
    +      "explained variance regression score mismatch")
    +    assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
    +    assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
    +    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
    +      "root mean squared error mismatch")
    +    assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
    +  }
    +
    +
    +  /**
    +    * The following values are hand calculated using the formula:
    +    * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
    +    * preds = c(2.25, -0.25, 1.75, 7.75)
    +    * obs = c(3.0, -0.5, 2.0, 7.0)
    +    * weights = c(0.1, 0.2, 0.15, 0.05)
    +    * count = 4
    +    *
    +    * Weighted metrics can be calculated with MultivariateStatisticalSummary.
    +    *             (observations, observations - predictions)
    +    * mean        (1.7, 0.05)
    +    * variance    (7.3, 0.3)
    +    * numNonZeros (0.5, 0.5)
    +    * max         (7.0, 0.75)
    +    * min         (-0.5, -0.75)
    +    * normL2      (2.0, 0.32596)
    +    * normL1      (1.05, 0.2)
    +    *
    +    * explainedVariance: sum((preds - 1.7)^2) / count = 10.1775
    +    * meanAbsoluteError: normL1(1) / count = 0.05
    +    * meanSquaredError: normL2(1)^2 / count = 0.02656
    +    * rootMeanSquaredError: sqrt(meanSquaredError) = 0.16298
    +    * r2: 1 - normL2(1)^2 / (variance(0) * (count - 1)) = 0.9951484
    --- End diff --
    
    Same thing for this formula.
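The following is a plain-Scala sketch (no Spark needed) that re-derives the hand-calculated numbers in the Scaladoc, to make the formula easier to check. Assumptions: the weighted mean, variance, normL1 and normL2 use the reliability-weights convention from the cited Wikipedia page (variance denominator sum(w) - sum(w^2)/sum(w)), which reproduces the summary table above; the metric formulas are transcribed verbatim from the comment, including the division by the raw `count`. `WeightedMetricsCheck` and its methods are hypothetical names, not MLlib API.

```scala
// Hypothetical helper (not part of MLlib): re-derives the hand-calculated
// values from the Scaladoc with plain Scala collections, no Spark required.
object WeightedMetricsCheck {

  type Sample = (Double, Double, Double) // (prediction, observation, weight)

  /** Weighted mean: sum(w * x) / sum(w). */
  def weightedMean(xs: Seq[Double], ws: Seq[Double]): Double =
    xs.zip(ws).map { case (x, w) => w * x }.sum / ws.sum

  /** Unbiased variance with reliability weights (assumed convention):
    * sum(w * (x - mean)^2) / (sum(w) - sum(w^2) / sum(w)). */
  def weightedVariance(xs: Seq[Double], ws: Seq[Double]): Double = {
    val mean = weightedMean(xs, ws)
    val sw = ws.sum
    val sw2 = ws.map(w => w * w).sum
    xs.zip(ws).map { case (x, w) => w * (x - mean) * (x - mean) }.sum / (sw - sw2 / sw)
  }

  /** Weighted L1 norm: sum(w * |x|). */
  def weightedNormL1(xs: Seq[Double], ws: Seq[Double]): Double =
    xs.zip(ws).map { case (x, w) => w * math.abs(x) }.sum

  /** Weighted L2 norm: sqrt(sum(w * x^2)). */
  def weightedNormL2(xs: Seq[Double], ws: Seq[Double]): Double =
    math.sqrt(xs.zip(ws).map { case (x, w) => w * x * x }.sum)

  /** The five metrics, using the Scaladoc formulas as written (divided by count). */
  def metrics(samples: Seq[Sample]): Map[String, Double] = {
    val preds = samples.map(_._1)
    val obs = samples.map(_._2)
    val ws = samples.map(_._3)
    val residuals = obs.zip(preds).map { case (o, p) => o - p }
    val count = samples.size
    val obsMean = weightedMean(obs, ws)
    // Weighted sum of squared residuals, i.e. normL2(residuals)^2.
    val sse = math.pow(weightedNormL2(residuals, ws), 2)
    val mse = sse / count
    Map(
      "explainedVariance" -> preds.map(p => (p - obsMean) * (p - obsMean)).sum / count,
      "meanAbsoluteError" -> weightedNormL1(residuals, ws) / count,
      "meanSquaredError" -> mse,
      "rootMeanSquaredError" -> math.sqrt(mse),
      "r2" -> (1 - sse / (weightedVariance(obs, ws) * (count - 1))))
  }

  // Data from the first test (all weights 1.0).
  val unitWeights: Seq[Sample] =
    Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0))

  // Data from the Scaladoc (weights 0.1, 0.2, 0.15, 0.05).
  val mixedWeights: Seq[Sample] =
    Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05))
}
```

Run against `unitWeights`, these formulas give the expected values of the first test (8.79687, 0.5, 0.3125, 0.55901, 0.95717); against `mixedWeights` they give the Scaladoc's values (10.1775, 0.05, 0.02656, 0.16298, 0.9951484). Since every normalization lives in `metrics`, an alternative divisor (for example the sum of weights instead of `count`) can be tried in one place.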

