Repository: spark
Updated Branches:
  refs/heads/branch-2.1 7d4596734 -> e8d8e3509


[SPARK-18476][SPARKR][ML] SparkR Logistic Regression should support
outputting the original label.

## What changes were proposed in this pull request?

Similar to SPARK-18401, as a classification algorithm, logistic regression
should output the original label rather than the indexed label.

In this PR, output of the original label is supported; test cases are
modified and added accordingly, and the documentation is updated.
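
For illustration, a minimal SparkR sketch of the new behavior, mirroring the
iris-based test added in this patch (assumes an active SparkR session):

    # Binary classification over two iris species; Species is a string label.
    training <- suppressWarnings(createDataFrame(iris))
    binomial <- training[training$Species %in% c("versicolor", "virginica"), ]
    model <- spark.logit(binomial, Species ~ Sepal_Length + Sepal_Width)

    # The "prediction" column now holds the original string labels
    # ("versicolor"/"virginica") instead of the indexed labels (0.0/1.0).
    head(select(predict(model, binomial), "prediction"))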

## How was this patch tested?

Unit tests.

Author: wm...@hotmail.com <wm...@hotmail.com>

Closes #15910 from wangmiao1981/audit.

(cherry picked from commit 2eb6764fbb23553fc17772d8a4a1cad55ff7ba6e)
Signed-off-by: Yanbo Liang <yblia...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e8d8e350
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e8d8e350
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e8d8e350

Branch: refs/heads/branch-2.1
Commit: e8d8e350998e6e44a6dee7f78dbe2d1aa997c1d6
Parents: 7d45967
Author: wm...@hotmail.com <wm...@hotmail.com>
Authored: Wed Nov 30 20:32:17 2016 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Nov 30 20:33:07 2016 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 | 19 +++++-----
 R/pkg/inst/tests/testthat/test_mllib.R          | 26 +++++++++-----
 .../scala/org/apache/spark/SparkContext.scala   |  2 +-
 .../spark/ml/r/LogisticRegressionWrapper.scala  | 37 ++++++++++++++------
 4 files changed, 54 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e8d8e350/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 02bc645..eed8293 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -712,7 +712,6 @@ setMethod("predict", signature(object = "KMeansModel"),
 #'                        of L1 and L2. Default is 0.0 which is an L2 penalty.
 #' @param maxIter maximum iteration number.
 #' @param tol convergence tolerance of iterations.
-#' @param fitIntercept whether to fit an intercept term.
 #' @param family the name of family which is a description of the label distribution to be used in the model.
 #'               Supported options:
 #'                 \itemize{
@@ -747,11 +746,11 @@ setMethod("predict", signature(object = "KMeansModel"),
 #' \dontrun{
 #' sparkR.session()
 #' # binary logistic regression
-#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0)
-#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
-#' binary_data <- as.data.frame(cbind(label, feature))
+#' label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
+#' features <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
+#' binary_data <- as.data.frame(cbind(label, features))
 #' binary_df <- createDataFrame(binary_data)
-#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)
+#' blr_model <- spark.logit(binary_df, label ~ features, thresholds = 1.0)
 #' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction"))
 #'
 #' # summary of binary logistic regression
@@ -783,7 +782,7 @@ setMethod("predict", signature(object = "KMeansModel"),
 #' @note spark.logit since 2.1.0
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = 
"formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, 
maxIter = 100,
-                   tol = 1E-6, fitIntercept = TRUE, family = "auto", 
standardization = TRUE,
+                   tol = 1E-6, family = "auto", standardization = TRUE,
                    thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
                    probabilityCol = "probability") {
             formula <- paste(deparse(formula), collapse = "")
@@ -795,10 +794,10 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam),
                                 as.numeric(elasticNetParam), as.integer(maxIter),
-                                as.numeric(tol), as.logical(fitIntercept),
-                                as.character(family), as.logical(standardization),
-                                as.array(thresholds), as.character(weightCol),
-                                as.integer(aggregationDepth), as.character(probabilityCol))
+                                as.numeric(tol), as.character(family),
+                                as.logical(standardization), as.array(thresholds),
+                                as.character(weightCol), as.integer(aggregationDepth),
+                                as.character(probabilityCol))
             new("LogisticRegressionModel", jobj = jobj)
           })
 

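Note that the fitIntercept argument is removed from the spark.logit signature
above; the intercept is now inferred from the R formula itself (the Scala
change below reads rFormula.hasIntercept). A hedged sketch, assuming a
hypothetical SparkDataFrame df with columns label and features, and RFormula's
standard handling of "- 1":

    # Intercept handling now follows standard R formula syntax:
    m1 <- spark.logit(df, label ~ features)      # intercept fitted (default)
    m2 <- spark.logit(df, label ~ features - 1)  # "- 1" suppresses the intercept
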
http://git-wip-us.apache.org/repos/asf/spark/blob/e8d8e350/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b05be47..c8f062d 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -646,30 +646,30 @@ test_that("spark.isotonicRegression", {
 
 test_that("spark.logit", {
   # test binary logistic regression
-  label <- c(1.0, 1.0, 1.0, 0.0, 0.0)
+  label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
   feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
   binary_data <- as.data.frame(cbind(label, feature))
   binary_df <- createDataFrame(binary_data)
 
   blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)
   blr_predict <- collect(select(predict(blr_model, binary_df), "prediction"))
-  expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0))
+  expect_equal(blr_predict$prediction, c("0.0", "0.0", "0.0", "0.0", "0.0"))
   blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0)
   blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction"))
-  expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1))
+  expect_equal(blr_predict1$prediction, c("1.0", "1.0", "1.0", "1.0", "1.0"))
 
   # test summary of binary logistic regression
   blr_summary <- summary(blr_model)
   blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure"))
-  expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487),
+  expect_equal(blr_fmeasure$threshold, c(0.6565513, 0.6214563, 0.3325291, 0.2115995, 0.1778653),
                tolerance = 1e-4)
-  expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000),
+  expect_equal(blr_fmeasure$"F-Measure", c(0.6666667, 0.5000000, 0.8000000, 0.6666667, 0.5714286),
                tolerance = 1e-4)
   blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision"))
-  expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000),
+  expect_equal(blr_precision$precision, c(1.0000000, 0.5000000, 0.6666667, 0.5000000, 0.4000000),
                tolerance = 1e-4)
   blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall"))
-  expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000),
+  expect_equal(blr_recall$recall, c(0.5000000, 0.5000000, 1.0000000, 1.0000000, 1.0000000),
                tolerance = 1e-4)
 
   # test model save and read
@@ -683,6 +683,16 @@ test_that("spark.logit", {
   expect_error(summary(blr_model2))
   unlink(modelPath)
 
+  # test prediction label as text
+  training <- suppressWarnings(createDataFrame(iris))
+  binomial_training <- training[training$Species %in% c("versicolor", "virginica"), ]
+  binomial_model <- spark.logit(binomial_training, Species ~ Sepal_Length + Sepal_Width)
+  prediction <- predict(binomial_model, binomial_training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character")
+  expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica",
+                "versicolor", "virginica", "versicolor", "virginica", "versicolor")
+  expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected)
+
   # test multinomial logistic regression
   label <- c(0.0, 1.0, 2.0, 0.0, 0.0)
   feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667)
@@ -694,7 +704,7 @@ test_that("spark.logit", {
 
   model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1))
   predict1 <- collect(select(predict(model, df), "prediction"))
-  expect_equal(predict1$prediction, c(0, 0, 0, 0, 0))
+  expect_equal(predict1$prediction, c("0.0", "0.0", "0.0", "0.0", "0.0"))
   # Summary of multinomial logistic regression is not implemented yet
   expect_error(summary(model))
 })
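
A side note on the all-"0.0"/all-"1.0" expectations above, sketched against
the same binary_df used in the test: with thresholds = 1.0 the positive class
is predicted only when its probability exceeds 1.0, which never happens, so
every row is predicted "0.0"; with thresholds = 0.0 the positive class is, in
effect, always predicted, so every row comes back "1.0".

    m_never  <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)  # all "0.0"
    m_always <- spark.logit(binary_df, label ~ feature, thresholds = 0.0)  # all "1.0"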

http://git-wip-us.apache.org/repos/asf/spark/blob/e8d8e350/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1cb39a4..b8414b5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -19,7 +19,7 @@ package org.apache.spark
 
 import java.io._
 import java.lang.reflect.Constructor
-import java.net.{MalformedURLException, URI}
+import java.net.URI
 import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID}
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference}

http://git-wip-us.apache.org/repos/asf/spark/blob/e8d8e350/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 9b352c9..9fe6202 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
@@ -34,6 +34,8 @@ private[r] class LogisticRegressionWrapper private (
     val features: Array[String],
     val isLoaded: Boolean = false) extends MLWritable {
 
+  import LogisticRegressionWrapper._
+
   private val logisticRegressionModel: LogisticRegressionModel =
     pipeline.stages(1).asInstanceOf[LogisticRegressionModel]
 
@@ -57,7 +59,11 @@ private[r] class LogisticRegressionWrapper private (
   lazy val recallByThreshold: DataFrame = blrSummary.recallByThreshold
 
   def transform(dataset: Dataset[_]): DataFrame = {
-    pipeline.transform(dataset).drop(logisticRegressionModel.getFeaturesCol)
+    pipeline.transform(dataset)
+      .drop(PREDICTED_LABEL_INDEX_COL)
+      .drop(logisticRegressionModel.getFeaturesCol)
+      .drop(logisticRegressionModel.getLabelCol)
+
   }
 
   override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this)
@@ -66,6 +72,9 @@ private[r] class LogisticRegressionWrapper private (
 private[r] object LogisticRegressionWrapper
     extends MLReadable[LogisticRegressionWrapper] {
 
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
   def fit( // scalastyle:ignore
       data: DataFrame,
       formula: String,
@@ -73,7 +82,6 @@ private[r] object LogisticRegressionWrapper
       elasticNetParam: Double,
       maxIter: Int,
       tol: Double,
-      fitIntercept: Boolean,
       family: String,
       standardization: Boolean,
       thresholds: Array[Double],
@@ -84,14 +92,14 @@ private[r] object LogisticRegressionWrapper
 
     val rFormula = new RFormula()
       .setFormula(formula)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+      .setForceIndexLabel(true)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
-    // get feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
+    val fitIntercept = rFormula.hasIntercept
+
+    // get labels and feature names from output schema
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
 
     // assemble and fit the pipeline
     val logisticRegression = new LogisticRegression()
@@ -105,7 +113,9 @@ private[r] object LogisticRegressionWrapper
       .setWeightCol(weightCol)
       .setAggregationDepth(aggregationDepth)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setProbabilityCol(probability)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
 
     if (thresholds.length > 1) {
       logisticRegression.setThresholds(thresholds)
@@ -113,8 +123,13 @@ private[r] object LogisticRegressionWrapper
       logisticRegression.setThreshold(thresholds(0))
     }
 
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+
     val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, logisticRegression))
+      .setStages(Array(rFormulaModel, logisticRegression, idxToStr))
       .fit(data)
 
     new LogisticRegressionWrapper(pipeline, features)

