Repository: spark
Updated Branches:
  refs/heads/master 442287ae2 -> ad09e4ca0


[MINOR][SPARKR][ML] Join coefficients with intercept for SparkR linear SVM
summary.

## What changes were proposed in this pull request?
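Join the intercept with the coefficients in the SparkR linear SVM summary.
When the model is fit with an intercept, `summary()` now reports it as the
first row (`(Intercept)`) of the `coefficients` matrix; the separate
`intercept` element is removed from the returned list.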

## How was this patch tested?
Existing tests (the expected coefficients in `test_mllib_classification.R`
were updated to include the intercept).

Author: Yanbo Liang <yblia...@gmail.com>

Closes #18035 from yanboliang/svm-r.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ad09e4ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ad09e4ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ad09e4ca

Branch: refs/heads/master
Commit: ad09e4ca045715d053a672c2ba23f598f06085d8
Parents: 442287a
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Tue May 23 16:16:14 2017 +0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Tue May 23 16:16:14 2017 +0800

----------------------------------------------------------------------
 R/pkg/R/mllib_classification.R                  | 38 ++++++++------------
 .../tests/testthat/test_mllib_classification.R  |  3 +-
 .../apache/spark/ml/r/LinearSVCWrapper.scala    | 12 +++++--
 3 files changed, 26 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ad09e4ca/R/pkg/R/mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 4db9cc3..306a9b8 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", 
representation(jobj = "jobj"
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))
 
-#' linear SVM Model
+#' Linear SVM Model
 #'
-#' Fits an linear SVM model against a SparkDataFrame. It is a binary 
classifier, similar to svm in glmnet package
+#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 
package.
+#' Currently only supports binary classification model with linear kernel.
 #' Users can print, make predictions on the produced model and save the model 
to the input path.
 #'
 #' @param data SparkDataFrame for training.
 #' @param formula A symbolic description of the model to be fitted. Currently 
only a few formula
 #'                operators are supported, including '~', '.', ':', '+', and 
'-'.
-#' @param regParam The regularization parameter.
+#' @param regParam The regularization parameter. Only supports L2 
regularization currently.
 #' @param maxIter Maximum iteration number.
 #' @param tol Convergence tolerance of iterations.
 #' @param standardization Whether to standardize the training features before 
fitting the model. The coefficients
@@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = 
"SparkDataFrame", formula = "formu
             new("LinearSVCModel", jobj = jobj)
           })
 
-#  Predicted values based on an LinearSVCModel model
+#  Predicted values based on a LinearSVCModel model
 
 #' @param newData a SparkDataFrame for testing.
-#' @return \code{predict} returns the predicted values based on an 
LinearSVCModel.
+#' @return \code{predict} returns the predicted values based on a 
LinearSVCModel.
 #' @rdname spark.svmLinear
 #' @aliases predict,LinearSVCModel,SparkDataFrame-method
 #' @export
@@ -124,13 +125,12 @@ setMethod("predict", signature(object = "LinearSVCModel"),
             predict_internal(object, newData)
           })
 
-#  Get the summary of an LinearSVCModel
+#  Get the summary of a LinearSVCModel
 
-#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}.
+#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
 #' @return \code{summary} returns summary information of the fitted model, 
which is a list.
 #'         The list includes \code{coefficients} (coefficients of the fitted 
model),
-#'         \code{intercept} (intercept of the fitted model), \code{numClasses} 
(number of classes),
-#'         \code{numFeatures} (number of features).
+#'         \code{numClasses} (number of classes), \code{numFeatures} (number 
of features).
 #' @rdname spark.svmLinear
 #' @aliases summary,LinearSVCModel-method
 #' @export
@@ -138,22 +138,14 @@ setMethod("predict", signature(object = "LinearSVCModel"),
 setMethod("summary", signature(object = "LinearSVCModel"),
           function(object) {
             jobj <- object@jobj
-            features <- callJMethod(jobj, "features")
-            labels <- callJMethod(jobj, "labels")
-            coefficients <- callJMethod(jobj, "coefficients")
-            nCol <- length(coefficients) / length(features)
-            coefficients <- matrix(unlist(coefficients), ncol = nCol)
-            intercept <- callJMethod(jobj, "intercept")
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            coefficients <- as.matrix(unlist(coefficients))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
             numClasses <- callJMethod(jobj, "numClasses")
             numFeatures <- callJMethod(jobj, "numFeatures")
-            if (nCol == 1) {
-              colnames(coefficients) <- c("Estimate")
-            } else {
-              colnames(coefficients) <- unlist(labels)
-            }
-            rownames(coefficients) <- unlist(features)
-            list(coefficients = coefficients, intercept = intercept,
-                 numClasses = numClasses, numFeatures = numFeatures)
+            list(coefficients = coefficients, numClasses = numClasses, 
numFeatures = numFeatures)
           })
 
 #  Save fitted LinearSVCModel to the input path

http://git-wip-us.apache.org/repos/asf/spark/blob/ad09e4ca/R/pkg/inst/tests/testthat/test_mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R 
b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index abf8bb2..c1c7468 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -38,9 +38,8 @@ test_that("spark.svmLinear", {
   expect_true(class(summary$coefficients[, 1]) == "numeric")
 
   coefs <- summary$coefficients[, "Estimate"]
-  expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085)
+  expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085)
   expect_true(all(abs(coefs - expected_coefs) < 0.1))
-  expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2)
 
   # Test prediction with string label
   prediction <- predict(model, training)

http://git-wip-us.apache.org/repos/asf/spark/blob/ad09e4ca/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
index cfd043b..0dd1f11 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -38,9 +38,17 @@ private[r] class LinearSVCWrapper private (
   private val svcModel: LinearSVCModel =
     pipeline.stages(1).asInstanceOf[LinearSVCModel]
 
-  lazy val coefficients: Array[Double] = svcModel.coefficients.toArray
+  lazy val rFeatures: Array[String] = if (svcModel.getFitIntercept) {
+    Array("(Intercept)") ++ features
+  } else {
+    features
+  }
 
-  lazy val intercept: Double = svcModel.intercept
+  lazy val rCoefficients: Array[Double] = if (svcModel.getFitIntercept) {
+    Array(svcModel.intercept) ++ svcModel.coefficients.toArray
+  } else {
+    svcModel.coefficients.toArray
+  }
 
   lazy val numClasses: Int = svcModel.numClasses
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to