Repository: spark Updated Branches: refs/heads/master b0adb9f54 -> 0d17593b3
[SPARK-14461][ML] GLM training summaries should provide solver ## What changes were proposed in this pull request? GLM training summaries should provide solver. ## How was this patch tested? Unit tests. cc jkbradley Author: Yanbo Liang <yblia...@gmail.com> Closes #12253 from yanboliang/spark-14461. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0d17593b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0d17593b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0d17593b Branch: refs/heads/master Commit: 0d17593b32c12c3e39575430aa85cf20e56fae6a Parents: b0adb9f Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Apr 13 13:20:29 2016 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Wed Apr 13 13:20:29 2016 -0700 ---------------------------------------------------------------------- .../spark/ml/regression/GeneralizedLinearRegression.scala | 10 +++++++--- .../ml/regression/GeneralizedLinearRegressionSuite.scala | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0d17593b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 00cf25d..e92a3e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, wlsModel.diagInvAtWA.toArray, - 1) + 1, + getSolver) return model.setSummary(trainingSummary) } @@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, irlsModel.diagInvAtWA.toArray, - irlsModel.numIterations) + irlsModel.numIterations, + getSolver) model.setSummary(trainingSummary) } @@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr * @param model the model that should be summarized * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration * @param numIterations number of iterations + * @param solver the solver algorithm used for model training */ @Since("2.0.0") @Experimental @@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") val predictionCol: String, @Since("2.0.0") val model: GeneralizedLinearRegressionModel, private val diagInvAtWA: Array[Double], - @Since("2.0.0") val numIterations: Int) extends Serializable { + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) extends Serializable { import GeneralizedLinearRegression._ http://git-wip-us.apache.org/repos/asf/spark/blob/0d17593b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4905f3e..3ecc210 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: binomial family with weight") { @@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: poisson family with weight") { @@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: gamma family with weight") { @@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("read/write") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org