Repository: spark Updated Branches: refs/heads/master 8eb2dc713 -> 1cdc42d2b
[SPARK-12331][ML] R^2 for regression through the origin. Modified the definition of R^2 for regression through origin. Added modified test for regression metrics. Author: Imran Younus <iyou...@us.ibm.com> Author: Imran Younus <imranyou...@gmail.com> Closes #10384 from iyounus/SPARK_12331_R2_for_regression_through_origin. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1cdc42d2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1cdc42d2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1cdc42d2 Branch: refs/heads/master Commit: 1cdc42d2b99edfec01066699a7620cca02b61f0e Parents: 8eb2dc7 Author: Imran Younus <iyou...@us.ibm.com> Authored: Tue Jan 5 11:48:45 2016 +0000 Committer: Sean Owen <so...@cloudera.com> Committed: Tue Jan 5 11:48:45 2016 +0000 ---------------------------------------------------------------------- .../spark/ml/regression/LinearRegression.scala | 3 +- .../mllib/evaluation/RegressionMetrics.scala | 24 ++- .../evaluation/RegressionMetricsSuite.scala | 156 +++++++++++-------- 3 files changed, 112 insertions(+), 71 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1cdc42d2/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index dee2633..c54e08b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -534,7 +534,8 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions .select(predictionCol, labelCol) - .map { case Row(pred: Double, label: Double) => (pred, label) } ) + .map { case Row(pred: Double, label: Double) => (pred, label) }, + !model.getFitIntercept) /** * Returns the explained variance regression score. http://git-wip-us.apache.org/repos/asf/spark/blob/1cdc42d2/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 34883f2..18c90b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,11 +27,18 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param throughOrigin True if the regression is through the origin. For example, in linear + * regression, it will be true without fitting intercept. */ @Since("1.2.0") -class RegressionMetrics @Since("1.2.0") ( - predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("2.0.0") ( + predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + extends Logging { + + @Since("1.2.0") + def this(predictionAndObservations: RDD[(Double, Double)]) = + this(predictionAndObservations, false) /** * An auxiliary constructor taking a DataFrame. @@ -53,6 +60,8 @@ class RegressionMetrics @Since("1.2.0") ( ) summary } + + private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) private lazy val SStot = summary.variance(0) * (summary.count - 1) private lazy val SSreg = { @@ -102,9 +111,16 @@ class RegressionMetrics @Since("1.2.0") ( /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * In case of regression through the origin, the definition of R^2^ is to be modified. + * @see J. G. Eisenhauer, Regression through the Origin. Teaching Statistics 25, 76-80 (2003) + * [[https://online.stat.psu.edu/~ajw13/stat501/SpecialTopics/Reg_thru_origin.pdf]] */ @Since("1.2.0") def r2: Double = { - 1 - SSerr / SStot + if (throughOrigin) { + 1 - SSerr / SSy + } else { + 1 - SSerr / SStot + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/1cdc42d2/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 4b7f1be..f1d5173 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -22,91 +22,115 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + val obs = List[Double](77, 85, 62, 55, 63, 88, 57, 81, 51) + val eps = 1E-5 test("regression metrics for unbiased (includes intercept term) predictor") { /* Verify results in R: - preds = c(2.25, -0.25, 1.75, 7.75) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.796875 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.3125 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.559017 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9571734 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.08 91.88 65.48 52.28 62.18 81.98 58.88 78.68 55.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + [1] 157.3 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 3.7355556 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 17.539511 + rmse = sqrt(meanSquaredError) + rmse + [1] 4.18802 + r2 = summary(model)$r.squared + r2 + [1] 0.89968225 */ - val predictionAndObservations = sc.parallelize( - Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val preds = List(72.08, 91.88, 65.48, 52.28, 62.18, 81.98, 58.88, 78.68, 55.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + assert(metrics.explainedVariance ~== 157.3 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 3.7355556 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 17.539511 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 4.18802 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.89968225 absTol eps, "r2 score mismatch") } test("regression metrics for biased (no intercept term) predictor") { /* Verify results in R: - preds = c(2.5, 0.0, 2.0, 8.0) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.859375 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.375 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.6123724 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9486081 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ 0 + x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.12 99.17 63.11 45.08 58.60 85.65 54.09 81.14 49.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 294.88167 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 4.5888889 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 39.958711 + rmse = sqrt(meanSquaredError) + rmse + [1] 6.3212903 + r2 = summary(model)$r.squared + r2 + [1] 0.99185395 */ - val predictionAndObservations = sc.parallelize( - Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) - val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, + val preds = List(72.12, 99.17, 63.11, 45.08, 58.6, 85.65, 54.09, 81.14, 49.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) + val metrics = new RegressionMetrics(predictionAndObservations, true) + assert(metrics.explainedVariance ~== 294.88167 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 4.5888889 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 39.958711 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 6.3212903 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.99185395 absTol eps, "r2 score mismatch") } test("regression metrics with complete fitting") { - val predictionAndObservations = sc.parallelize( - Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) + /* Verify results in R: + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + preds = y + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 174.8395 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 0 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 0 + rmse = sqrt(meanSquaredError) + rmse + [1] 0 + r2 = 1 - sum((preds - y)^2)/sum((y - mean(y))^2) + r2 + [1] 1 + */ + val preds = obs + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, + assert(metrics.explainedVariance ~== 174.83951 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 0.0 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org