Repository: spark
Updated Branches:
  refs/heads/master c09e51398 -> e328b69c3


[SPARK-9492][ML][R] LogisticRegression in R should provide model statistics

Like ml `LinearRegression`, `LogisticRegression` should provide a
training summary including feature names and their coefficients.

Author: Yanbo Liang <[email protected]>

Closes #9303 from yanboliang/spark-9492.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e328b69c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e328b69c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e328b69c

Branch: refs/heads/master
Commit: e328b69c31821e4b27673d7ef6182ab3b7a05ca8
Parents: c09e513
Author: Yanbo Liang <[email protected]>
Authored: Wed Nov 4 08:28:33 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Nov 4 08:28:33 2015 -0800

----------------------------------------------------------------------
 R/pkg/inst/tests/test_mllib.R                      | 17 +++++++++++++++++
 .../ml/classification/LogisticRegression.scala     | 17 +++++++++++++----
 .../org/apache/spark/ml/r/SparkRWrappers.scala     |  7 ++++---
 project/MimaExcludes.scala                         |  4 +++-
 4 files changed, 37 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/R/pkg/inst/tests/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3331ce7..032cfef 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", {
     as.character(stats$features) ==
     c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
 })
+
+test_that("summary coefficients match with native glm of family 'binomial'", {
+  df <- createDataFrame(sqlContext, iris)
+  training <- filter(df, df$Species != "setosa")
+  stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
+    family = "binomial"))
+  coefs <- as.vector(stats$coefficients)
+
+  rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
+  rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
+    family = binomial(link = "logit"))))
+
+  expect_true(all(abs(rCoefs - coefs) < 1e-4))
+  expect_true(all(
+    as.character(stats$features) ==
+    c("(Intercept)", "Sepal_Length", "Sepal_Width")))
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a1335e7..f5fca68 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String)
       model.transform(dataset),
       $(probabilityCol),
       $(labelCol),
+      $(featuresCol),
       objectiveHistory)
     model.setSummary(logRegSummary)
   }
@@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] (
    */
   // TODO: decide on a good name before exposing to public API
   private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
-    new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+    new BinaryLogisticRegressionSummary(
+      this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
   }
 
   /**
@@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable {
   /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
   def probabilityCol: String
 
-  /** Field in "predictions" which gives the the true label of each instance. */
+  /** Field in "predictions" which gives the true label of each instance. */
   def labelCol: String
 
+  /** Field in "predictions" which gives the features of each instance as a vector. */
+  def featuresCol: String
+
 }
 
 /**
@@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable {
  * @param probabilityCol field in "predictions" which gives the calibrated probability of
  *                       each instance as a vector.
  * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
  * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
  */
 @Experimental
@@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
     predictions: DataFrame,
     probabilityCol: String,
     labelCol: String,
+    featuresCol: String,
     val objectiveHistory: Array[Double])
-  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
   with LogisticRegressionTrainingSummary {
 
 }
@@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
  * @param probabilityCol field in "predictions" which gives the calibrated probability of
  *                       each instance.
  * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
  */
 @Experimental
 class BinaryLogisticRegressionSummary private[classification] (
     @transient override val predictions: DataFrame,
     override val probabilityCol: String,
-    override val labelCol: String) extends LogisticRegressionSummary {
+    override val labelCol: String,
+    override val featuresCol: String) extends LogisticRegressionSummary {
 
   private val sqlContext = predictions.sqlContext
   import sqlContext.implicits._

http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 24f76de..5be2f86 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -66,9 +66,10 @@ private[r] object SparkRWrappers {
         val attrs = AttributeGroup.fromStructField(
           m.summary.predictions.schema(m.summary.featuresCol))
         Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
-      case _: LogisticRegressionModel =>
-        throw new UnsupportedOperationException(
-          "No features names available for LogisticRegressionModel")  // SPARK-9492
+      case m: LogisticRegressionModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e328b69c/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index ec0e44b..eeef96c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -59,7 +59,9 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.ml.classification.LogisticAggregator.add"),
         ProblemFilters.exclude[MissingMethodProblem](
-          "org.apache.spark.ml.classification.LogisticAggregator.count")
+          "org.apache.spark.ml.classification.LogisticAggregator.count"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol")
       ) ++ Seq(
         // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message.
         // This class is marked as `private` but MiMa still seems to be confused by the change.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to