Repository: spark Updated Branches: refs/heads/master 8f0e88df0 -> 7f99a05e6
[SPARK-22422][ML] Add Adjusted R2 to RegressionMetrics ## What changes were proposed in this pull request? I added adjusted R2 as a regression metric, which is implemented in all major statistical analysis tools. In practice, no one looks at R2 alone. The reason is that R2 by itself is misleading. If we add more parameters, R2 will not decrease but only increase (or stay the same). This leads to overfitting. Adjusted R2 addresses this issue by using the number of parameters as a "weight" for the sum of errors. ## How was this patch tested? - Added a new unit test, which passes. - ./dev/run-tests all passed. Author: test <[email protected]> Author: tengpeng <[email protected]> Closes #19638 from tengpeng/master. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f99a05e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f99a05e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f99a05e Branch: refs/heads/master Commit: 7f99a05e6ff258fc2192130451aa8aa1304bfe93 Parents: 8f0e88d Author: test <[email protected]> Authored: Wed Nov 15 10:13:01 2017 -0600 Committer: Sean Owen <[email protected]> Committed: Wed Nov 15 10:13:01 2017 -0600 ---------------------------------------------------------------------- .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ .../spark/ml/regression/LinearRegressionSuite.scala | 6 ++++++ 2 files changed, 21 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7f99a05e/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index df1aa60..da6bcf0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -722,6 +722,21 @@ class LinearRegressionSummary private[regression] ( @Since("1.5.0") val r2: Double = metrics.r2 + /** + * Returns Adjusted R^2^, the adjusted coefficient of determination. + * Reference: <a href="https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2"> + * Wikipedia coefficient of determination</a> + * + * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. + * This will change in later Spark versions. + */ + @Since("2.3.0") + val r2adj: Double = { + val interceptDOF = if (privateModel.getFitIntercept) 1 else 0 + 1 - (1 - r2) * (numInstances - interceptDOF) / + (numInstances - privateModel.coefficients.size - interceptDOF) + } + /** Residuals (label - predicted value) */ @Since("1.5.0") @transient lazy val residuals: DataFrame = { http://git-wip-us.apache.org/repos/asf/spark/blob/7f99a05e/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index f470dca..0e0be58 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -764,6 +764,11 @@ class LinearRegressionSuite (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** V2 4.6982442 0.0011805 3980 <2e-16 *** V3 7.1994344 0.0009044 7961 <2e-16 *** + + # R code for r2adj + lm_fit <- lm(V1 ~ V2 + V3, data = d1) + summary(lm_fit)$adj.r.squared + [1] 0.9998736 --- .... 
@@ -771,6 +776,7 @@ class LinearRegressionSuite assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) + assert(model.summary.r2adj ~== 0.9998736 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". If no regularization is applied or only L2 // regularization is applied, this algorithm uses a direct solver and does not generate an --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
