Repository: spark Updated Branches: refs/heads/master c09e51398 -> e328b69c3
[SPARK-9492][ML][R] LogisticRegression in R should provide model statistics Like ml ```LinearRegression```, ```LogisticRegression``` should provide a training summary including feature names and their coefficients. Author: Yanbo Liang <[email protected]> Closes #9303 from yanboliang/spark-9492. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e328b69c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e328b69c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e328b69c Branch: refs/heads/master Commit: e328b69c31821e4b27673d7ef6182ab3b7a05ca8 Parents: c09e513 Author: Yanbo Liang <[email protected]> Authored: Wed Nov 4 08:28:33 2015 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Wed Nov 4 08:28:33 2015 -0800 ---------------------------------------------------------------------- R/pkg/inst/tests/test_mllib.R | 17 +++++++++++++++++ .../ml/classification/LogisticRegression.scala | 17 +++++++++++++---- .../org/apache/spark/ml/r/SparkRWrappers.scala | 7 ++++--- project/MimaExcludes.scala | 4 +++- 4 files changed, 37 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/R/pkg/inst/tests/test_mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3331ce7..032cfef 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", { as.character(stats$features) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) + +test_that("summary coefficients match with native glm of family 'binomial'", { + df <- createDataFrame(sqlContext, iris) + training <- filter(df, df$Species != "setosa") + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = "binomial")) + coefs <- as.vector(stats$coefficients) + + rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] + rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit")))) + + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) +}) http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a1335e7..f5fca68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String) model.transform(dataset), $(probabilityCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(logRegSummary) } @@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + new BinaryLogisticRegressionSummary( + this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) } /** @@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each instance. */ + /** Field in "predictions" which gives the true label of each instance. */ def labelCol: String + /** Field in "predictions" which gives the features of each instance as a vector. */ + def featuresCol: String + } /** @@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( predictions: DataFrame, probabilityCol: String, labelCol: String, + featuresCol: String, val objectiveHistory: Array[Double]) - extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { } @@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @transient override val predictions: DataFrame, override val probabilityCol: String, - override val labelCol: String) extends LogisticRegressionSummary { + override val labelCol: String, + override val featuresCol: String) extends LogisticRegressionSummary { private val sqlContext = predictions.sqlContext import sqlContext.implicits._ http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 24f76de..5be2f86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -66,9 +66,10 @@ private[r] object SparkRWrappers { val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - case _: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No features names available for LogisticRegressionModel") // SPARK-9492 + case m: LogisticRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ec0e44b..eeef96c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -59,7 +59,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.count") + "org.apache.spark.ml.classification.LogisticAggregator.count"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") ) ++ Seq( // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. // This class is marked as `private` but MiMa still seems to be confused by the change. --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
