Repository: spark
Updated Branches:
  refs/heads/master 8eb2dc713 -> 1cdc42d2b


[SPARK-12331][ML] R^2 for regression through the origin.

Modified the definition of R^2 for regression through the origin. Updated the 
regression metrics tests accordingly.

Author: Imran Younus <iyou...@us.ibm.com>
Author: Imran Younus <imranyou...@gmail.com>

Closes #10384 from iyounus/SPARK_12331_R2_for_regression_through_origin.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1cdc42d2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1cdc42d2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1cdc42d2

Branch: refs/heads/master
Commit: 1cdc42d2b99edfec01066699a7620cca02b61f0e
Parents: 8eb2dc7
Author: Imran Younus <iyou...@us.ibm.com>
Authored: Tue Jan 5 11:48:45 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Tue Jan 5 11:48:45 2016 +0000

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  |   3 +-
 .../mllib/evaluation/RegressionMetrics.scala    |  24 ++-
 .../evaluation/RegressionMetricsSuite.scala     | 156 +++++++++++--------
 3 files changed, 112 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1cdc42d2/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index dee2633..c54e08b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -534,7 +534,8 @@ class LinearRegressionSummary private[regression] (
   @transient private val metrics = new RegressionMetrics(
     predictions
       .select(predictionCol, labelCol)
-      .map { case Row(pred: Double, label: Double) => (pred, label) } )
+      .map { case Row(pred: Double, label: Double) => (pred, label) },
+    !model.getFitIntercept)
 
   /**
    * Returns the explained variance regression score.

http://git-wip-us.apache.org/repos/asf/spark/blob/1cdc42d2/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 34883f2..18c90b2 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -27,11 +27,18 @@ import org.apache.spark.sql.DataFrame
 /**
  * Evaluator for regression.
  *
- * @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ * @param predictionAndObservations an RDD of (prediction, observation) pairs
+ * @param throughOrigin True if the regression is through the origin. For 
example, in linear
+ *                      regression, it will be true without fitting intercept.
  */
 @Since("1.2.0")
-class RegressionMetrics @Since("1.2.0") (
-    predictionAndObservations: RDD[(Double, Double)]) extends Logging {
+class RegressionMetrics @Since("2.0.0") (
+    predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
+    extends Logging {
+
+  @Since("1.2.0")
+  def this(predictionAndObservations: RDD[(Double, Double)]) =
+    this(predictionAndObservations, false)
 
   /**
    * An auxiliary constructor taking a DataFrame.
@@ -53,6 +60,8 @@ class RegressionMetrics @Since("1.2.0") (
       )
     summary
   }
+
+  private lazy val SSy = math.pow(summary.normL2(0), 2)
   private lazy val SSerr = math.pow(summary.normL2(1), 2)
   private lazy val SStot = summary.variance(0) * (summary.count - 1)
   private lazy val SSreg = {
@@ -102,9 +111,16 @@ class RegressionMetrics @Since("1.2.0") (
   /**
    * Returns R^2^, the unadjusted coefficient of determination.
    * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+   * In case of regression through the origin, the definition of R^2^ is to be 
modified.
+   * @see J. G. Eisenhauer, Regression through the Origin. Teaching Statistics 
25, 76-80 (2003)
+   * 
[[https://online.stat.psu.edu/~ajw13/stat501/SpecialTopics/Reg_thru_origin.pdf]]
    */
   @Since("1.2.0")
   def r2: Double = {
-    1 - SSerr / SStot
+    if (throughOrigin) {
+      1 - SSerr / SSy
+    } else {
+      1 - SSerr / SStot
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1cdc42d2/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 4b7f1be..f1d5173 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -22,91 +22,115 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
 class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
+  val obs = List[Double](77, 85, 62, 55, 63, 88, 57, 81, 51)
+  val eps = 1E-5
 
   test("regression metrics for unbiased (includes intercept term) predictor") {
     /* Verify results in R:
-       preds = c(2.25, -0.25, 1.75, 7.75)
-       obs = c(3.0, -0.5, 2.0, 7.0)
-
-       SStot = sum((obs - mean(obs))^2)
-       SSreg = sum((preds - mean(obs))^2)
-       SSerr = sum((obs - preds)^2)
-
-       explainedVariance = SSreg / length(obs)
-       explainedVariance
-       > [1] 8.796875
-       meanAbsoluteError = mean(abs(preds - obs))
-       meanAbsoluteError
-       > [1] 0.5
-       meanSquaredError = mean((preds - obs)^2)
-       meanSquaredError
-       > [1] 0.3125
-       rmse = sqrt(meanSquaredError)
-       rmse
-       > [1] 0.559017
-       r2 = 1 - SSerr / SStot
-       r2
-       > [1] 0.9571734
+        y = c(77, 85, 62, 55, 63, 88, 57, 81, 51)
+        x = c(16, 22, 14, 10, 13, 19, 12, 18, 11)
+        df <- as.data.frame(cbind(x, y))
+        model <- lm(y ~  x, data=df)
+        preds = signif(predict(model), digits = 4)
+        preds
+            1     2     3     4     5     6     7     8     9
+        72.08 91.88 65.48 52.28 62.18 81.98 58.88 78.68 55.58
+        options(digits=8)
+        explainedVariance = mean((preds - mean(y))^2)
+        [1] 157.3
+        meanAbsoluteError = mean(abs(preds - y))
+        meanAbsoluteError
+        [1] 3.7355556
+        meanSquaredError = mean((preds - y)^2)
+        meanSquaredError
+        [1] 17.539511
+        rmse = sqrt(meanSquaredError)
+        rmse
+        [1] 4.18802
+        r2 = summary(model)$r.squared
+        r2
+        [1] 0.89968225
      */
-    val predictionAndObservations = sc.parallelize(
-      Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
+    val preds = List(72.08, 91.88, 65.48, 52.28, 62.18, 81.98, 58.88, 78.68, 
55.58)
+    val predictionAndObservations = sc.parallelize(preds.zip(obs), 2)
     val metrics = new RegressionMetrics(predictionAndObservations)
-    assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
+    assert(metrics.explainedVariance ~== 157.3 absTol eps,
       "explained variance regression score mismatch")
-    assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error 
mismatch")
-    assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared 
error mismatch")
-    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
+    assert(metrics.meanAbsoluteError ~== 3.7355556 absTol eps, "mean absolute 
error mismatch")
+    assert(metrics.meanSquaredError ~== 17.539511 absTol eps, "mean squared 
error mismatch")
+    assert(metrics.rootMeanSquaredError ~== 4.18802 absTol eps,
       "root mean squared error mismatch")
-    assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
+    assert(metrics.r2 ~== 0.89968225 absTol eps, "r2 score mismatch")
   }
 
   test("regression metrics for biased (no intercept term) predictor") {
     /* Verify results in R:
-       preds = c(2.5, 0.0, 2.0, 8.0)
-       obs = c(3.0, -0.5, 2.0, 7.0)
-
-       SStot = sum((obs - mean(obs))^2)
-       SSreg = sum((preds - mean(obs))^2)
-       SSerr = sum((obs - preds)^2)
-
-       explainedVariance = SSreg / length(obs)
-       explainedVariance
-       > [1] 8.859375
-       meanAbsoluteError = mean(abs(preds - obs))
-       meanAbsoluteError
-       > [1] 0.5
-       meanSquaredError = mean((preds - obs)^2)
-       meanSquaredError
-       > [1] 0.375
-       rmse = sqrt(meanSquaredError)
-       rmse
-       > [1] 0.6123724
-       r2 = 1 - SSerr / SStot
-       r2
-       > [1] 0.9486081
+        y = c(77, 85, 62, 55, 63, 88, 57, 81, 51)
+        x = c(16, 22, 14, 10, 13, 19, 12, 18, 11)
+        df <- as.data.frame(cbind(x, y))
+        model <- lm(y ~ 0 + x, data=df)
+        preds = signif(predict(model), digits = 4)
+        preds
+            1     2     3     4     5     6     7     8     9
+        72.12 99.17 63.11 45.08 58.60 85.65 54.09 81.14 49.58
+        options(digits=8)
+        explainedVariance = mean((preds - mean(y))^2)
+        explainedVariance
+        [1] 294.88167
+        meanAbsoluteError = mean(abs(preds - y))
+        meanAbsoluteError
+        [1] 4.5888889
+        meanSquaredError = mean((preds - y)^2)
+        meanSquaredError
+        [1] 39.958711
+        rmse = sqrt(meanSquaredError)
+        rmse
+        [1] 6.3212903
+        r2 = summary(model)$r.squared
+        r2
+        [1] 0.99185395
      */
-    val predictionAndObservations = sc.parallelize(
-      Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
-    val metrics = new RegressionMetrics(predictionAndObservations)
-    assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
+    val preds = List(72.12, 99.17, 63.11, 45.08, 58.6, 85.65, 54.09, 81.14, 
49.58)
+    val predictionAndObservations = sc.parallelize(preds.zip(obs), 2)
+    val metrics = new RegressionMetrics(predictionAndObservations, true)
+    assert(metrics.explainedVariance ~== 294.88167 absTol eps,
       "explained variance regression score mismatch")
-    assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error 
mismatch")
-    assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error 
mismatch")
-    assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
+    assert(metrics.meanAbsoluteError ~== 4.5888889 absTol eps, "mean absolute 
error mismatch")
+    assert(metrics.meanSquaredError ~== 39.958711 absTol eps, "mean squared 
error mismatch")
+    assert(metrics.rootMeanSquaredError ~== 6.3212903 absTol eps,
       "root mean squared error mismatch")
-    assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
+    assert(metrics.r2 ~== 0.99185395 absTol eps, "r2 score mismatch")
   }
 
   test("regression metrics with complete fitting") {
-    val predictionAndObservations = sc.parallelize(
-      Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
+    /* Verify results in R:
+        y = c(77, 85, 62, 55, 63, 88, 57, 81, 51)
+        preds = y
+        explainedVariance = mean((preds - mean(y))^2)
+        explainedVariance
+        [1] 174.8395
+        meanAbsoluteError = mean(abs(preds - y))
+        meanAbsoluteError
+        [1] 0
+        meanSquaredError = mean((preds - y)^2)
+        meanSquaredError
+        [1] 0
+        rmse = sqrt(meanSquaredError)
+        rmse
+        [1] 0
+        r2 = 1 - sum((preds - y)^2)/sum((y - mean(y))^2)
+        r2
+        [1] 1
+     */
+    val preds = obs
+    val predictionAndObservations = sc.parallelize(preds.zip(obs), 2)
     val metrics = new RegressionMetrics(predictionAndObservations)
-    assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
+    assert(metrics.explainedVariance ~== 174.83951 absTol eps,
       "explained variance regression score mismatch")
-    assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error 
mismatch")
-    assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error 
mismatch")
-    assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5,
+    assert(metrics.meanAbsoluteError ~== 0.0 absTol eps, "mean absolute error 
mismatch")
+    assert(metrics.meanSquaredError ~== 0.0 absTol eps, "mean squared error 
mismatch")
+    assert(metrics.rootMeanSquaredError ~== 0.0 absTol eps,
       "root mean squared error mismatch")
-    assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
+    assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to