This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d37f79  [SPARK-30820][SPARKR][ML] Add FMClassifier to SparkR
0d37f79 is described below

commit 0d37f794ef9451199e1b757c1015fc7a8b3931a5
Author: zero323 <mszymkiew...@gmail.com>
AuthorDate: Tue Apr 7 09:01:45 2020 -0500

    [SPARK-30820][SPARKR][ML] Add FMClassifier to SparkR
    
    ### What changes were proposed in this pull request?
    
    This pull request adds a SparkR wrapper for `FMClassifier`:
    
    - Supporting `org.apache.spark.ml.r.FMClassifierWrapper`.
    - `FMClassificationModel` S4 class.
    - Corresponding `spark.fmClassifier`, `predict`, `summary`, and `write.ml` generics.
    - Corresponding docs and tests.
    
    ### Why are the changes needed?
    
    Feature parity.
    
    ### Does this PR introduce any user-facing change?
    
    No (new API).
    
    ### How was this patch tested?
    
    New unit tests.
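
    For illustration, a minimal end-to-end use of the new API (a sketch distilled
    from the roxygen examples added in this patch; the save path below is a
    placeholder):

        library(SparkR)
        sparkR.session()

        # Binary-classification data in LIBSVM format, bundled with Spark
        df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")

        # Fit a Factorization Machines classification model
        model <- spark.fmClassifier(df, label ~ features,
                                    factorSize = 4, regParam = 0.01, maxIter = 10)

        # Coefficients, factor matrix, and predictions
        summary(model)
        head(predict(model, df))

        # Round-trip the fitted model through write.ml/read.ml
        write.ml(model, "/tmp/fmClassifierModel", overwrite = TRUE)
        savedModel <- read.ml("/tmp/fmClassifierModel")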
    
    Closes #27570 from zero323/SPARK-30820.
    
    Authored-by: zero323 <mszymkiew...@gmail.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 R/pkg/NAMESPACE                                    |   3 +-
 R/pkg/R/generics.R                                 |   4 +
 R/pkg/R/mllib_classification.R                     | 157 ++++++++++++++++++
 R/pkg/R/mllib_utils.R                              |   2 +
 R/pkg/tests/fulltests/test_mllib_classification.R  |  34 ++++
 R/pkg/vignettes/sparkr-vignettes.Rmd               |  20 +++
 docs/ml-classification-regression.md               |  72 +++++----
 docs/sparkr.md                                     |   6 +-
 examples/src/main/r/ml/fmClassifier.R              |  45 ++++++
 .../apache/spark/ml/r/FMClassifierWrapper.scala    | 175 +++++++++++++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala    |   2 +
 11 files changed, 484 insertions(+), 36 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index fb879e4..18e4570 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -72,7 +72,8 @@ exportMethods("glm",
               "spark.freqItemsets",
               "spark.associationRules",
               "spark.findFrequentSequentialPatterns",
-              "spark.assignClusters")
+              "spark.assignClusters",
+              "spark.fmClassifier")
 
 # Job group lifecycle management methods
 export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index d924b2a..d36496f 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1479,6 +1479,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
 setGeneric("spark.bisectingKmeans",
            function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") })
 
+#' @rdname spark.fmClassifier
+setGeneric("spark.fmClassifier",
+           function(data, formula, ...) { standardGeneric("spark.fmClassifier") })
+
 #' @rdname spark.gaussianMixture
 setGeneric("spark.gaussianMixture",
            function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 3ad824e..fc5ac9f 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -42,6 +42,12 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))
 
+#' S4 class that represents an FMClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala FMClassifierWrapper
+#' @note FMClassificationModel since 3.1.0
+setClass("FMClassificationModel", representation(jobj = "jobj"))
+
 #' Linear SVM Model
 #'
 #' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
@@ -649,3 +655,154 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
           function(object, path, overwrite = FALSE) {
             write_internal(object, path, overwrite)
           })
+
+#' Factorization Machines Classification Model
+#'
+#' \code{spark.fmClassifier} fits a Factorization Machines classification model against a SparkDataFrame.
+#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
+#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#' Only categorical data is supported.
+#'
+#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param factorSize dimensionality of the factors.
+#' @param fitLinear whether to fit linear term.  # TODO Can we express this with formula?
+#' @param regParam the regularization parameter.
+#' @param miniBatchFraction the mini-batch fraction parameter.
+#' @param initStd the standard deviation of initial coefficients.
+#' @param maxIter maximum iteration number.
+#' @param stepSize stepSize parameter.
+#' @param tol convergence tolerance of iterations.
+#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "adamW".
+#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of
+#'                   class label 1 is > threshold, then predict 1, else 0. A high threshold
+#'                   encourages the model to predict 0 more often; a low threshold encourages the
+#'                   model to predict 1 more often. Note: Setting this with threshold p is
+#'                   equivalent to setting thresholds c(1-p, p).
+#' @param seed seed parameter for weights initialization.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and
+#'                      label column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                                         "error" (throw an error), "keep" (put invalid data in
+#'                                         a special additional bucket, at index numLabels). Default
+#'                                         is "error".
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.fmClassifier} returns a fitted Factorization Machines Classification Model.
+#' @rdname spark.fmClassifier
+#' @aliases spark.fmClassifier,SparkDataFrame,formula-method
+#' @name spark.fmClassifier
+#' @seealso \link{read.ml}
+#' @examples
+#' \dontrun{
+#' df <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm")
+#'
+#' # fit Factorization Machines Classification Model
+#' model <- spark.fmClassifier(
+#'            df, label ~ features,
+#'            regParam = 0.01, maxIter = 10, fitLinear = TRUE
+#'          )
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.fmClassifier since 3.1.0
+setMethod("spark.fmClassifier", signature(data = "SparkDataFrame", formula = 
"formula"),
+          function(data, formula, factorSize = 8, fitLinear = TRUE, regParam = 
0.0,
+                   miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100, 
stepSize=1.0,
+                   tol = 1e-6, solver = c("adamW", "gd"), thresholds = NULL, 
seed = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
+
+            formula <- paste(deparse(formula), collapse = "")
+
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+
+            if (!is.null(thresholds)) {
+              thresholds <- as.list(thresholds)
+            }
+
+            solver <- match.arg(solver)
+            handleInvalid <- match.arg(handleInvalid)
+
+            jobj <- callJStatic("org.apache.spark.ml.r.FMClassifierWrapper",
+                                "fit",
+                                data@sdf,
+                                formula,
+                                as.integer(factorSize),
+                                as.logical(fitLinear),
+                                as.numeric(regParam),
+                                as.numeric(miniBatchFraction),
+                                as.numeric(initStd),
+                                as.integer(maxIter),
+                                as.numeric(stepSize),
+                                as.numeric(tol),
+                                solver,
+                                seed,
+                                thresholds,
+                                handleInvalid)
+            new("FMClassificationModel", jobj = jobj)
+          })
+
+#  Returns the summary of an FM Classification model produced by \code{spark.fmClassifier}
+
+#' @param object an FM Classification model fitted by \code{spark.fmClassifier}.
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#' @rdname spark.fmClassifier
+#' @note summary(FMClassificationModel) since 3.1.0
+setMethod("summary", signature(object = "FMClassificationModel"),
+          function(object) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            coefficients <- as.matrix(unlist(coefficients))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
+            numClasses <- callJMethod(jobj, "numClasses")
+            numFeatures <- callJMethod(jobj, "numFeatures")
+            raw_factors <- unlist(callJMethod(jobj, "rFactors"))
+            factor_size <- callJMethod(jobj, "factorSize")
+
+            list(
+              coefficients = coefficients,
+              factors = matrix(raw_factors, ncol = factor_size),
+              numClasses = numClasses, numFeatures = numFeatures,
+              factorSize = factor_size
+            )
+          })
+
+#  Predicted values based on an FMClassificationModel
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on an FM Classification model.
+#' @rdname spark.fmClassifier
+#' @aliases predict,FMClassificationModel,SparkDataFrame-method
+#' @note predict(FMClassificationModel) since 3.1.0
+setMethod("predict", signature(object = "FMClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#  Save fitted FMClassificationModel to the input path
+
+#' @param path The directory where the model is saved.
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname spark.fmClassifier
+#' @aliases write.ml,FMClassificationModel,character-method
+#' @note write.ml(FMClassificationModel, character) since 3.1.0
+setMethod("write.ml", signature(object = "FMClassificationModel", path = 
"character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index 7d04bff..f5643aa 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -123,6 +123,8 @@ read.ml <- function(path) {
     new("LinearSVCModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
     new("FPGrowthModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FMClassifierWrapper")) {
+    new("FMClassificationModel", jobj = jobj)
   } else {
     stop("Unsupported model: ", jobj)
   }
diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R
index 2da3a02..9dd275a 100644
--- a/R/pkg/tests/fulltests/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -488,4 +488,38 @@ test_that("spark.naiveBayes", {
   expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
+test_that("spark.fmClassifier", {
+  df <- withColumn(
+    suppressWarnings(createDataFrame(iris)),
+    "Species", otherwise(when(column("Species") == "Setosa", "Setosa"), 
"Not-Setosa")
+  )
+
+  model1 <- spark.fmClassifier(
+    df, Species ~ .,
+    regParam = 0.01, maxIter = 10, fitLinear = TRUE, factorSize = 3
+  )
+
+  prediction1 <- predict(model1, df)
+  expect_is(prediction1, "SparkDataFrame")
+  expect_equal(summary(model1)$factorSize, 3)
+
+  # Test model save/load
+  if (windows_with_hadoop()) {
+    modelPath <- tempfile(pattern = "spark-fmclassifier", fileext = ".tmp")
+    write.ml(model1, modelPath)
+    model2 <- read.ml(modelPath)
+
+    expect_is(model2, "FMClassificationModel")
+
+    expect_equal(summary(model1), summary(model2))
+
+    prediction2 <- predict(model2, df)
+    expect_equal(
+      collect(drop(prediction1, c("rawPrediction", "probability"))),
+      collect(drop(prediction2, c("rawPrediction", "probability")))
+    )
+    unlink(modelPath)
+  }
+})
+
 sparkR.session.stop()
diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd
index 9e48ae3..6b62a66 100644
--- a/R/pkg/vignettes/sparkr-vignettes.Rmd
+++ b/R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -523,6 +523,8 @@ SparkR supports the following machine learning models and algorithms.
 
 * Naive Bayes
 
+* Factorization Machines (FM) Classifier
+
 #### Regression
 
 * Accelerated Failure Time (AFT) Survival Model
@@ -705,6 +707,24 @@ naiveBayesPrediction <- predict(naiveBayesModel, titanicDF)
 head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction"))
 ```
 
+#### Factorization Machines Classifier
+
+Factorization Machines for classification problems.
+
+For background and details about the implementation of factorization machines,
+refer to the [Factorization Machines section](https://spark.apache.org/docs/latest/ml-classification-regression.html#factorization-machines).
+
+```{r}
+t <- as.data.frame(Titanic)
+training <- createDataFrame(t)
+
+model <- spark.fmClassifier(training, Survived ~ Age + Sex)
+summary(model)
+
+predictions <- predict(model, training)
+head(select(predictions, predictions$prediction))
+```
+
 #### Accelerated Failure Time Survival Model
 
 Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring.
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index 9d53880..0359456 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -9,9 +9,9 @@ license: |
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License.  You may obtain a copy of the License at
- 
+
      http://www.apache.org/licenses/LICENSE-2.0
- 
+
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -59,11 +59,11 @@ parameter to select between these two algorithms, or leave it unset and Spark wi
 
 ### Binomial logistic regression
 
-For more background and more details about the implementation of binomial logistic regression, refer to the documentation of [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). 
+For more background and more details about the implementation of binomial logistic regression, refer to the documentation of [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression).
 
 **Examples**
 
-The following example shows how to train binomial and multinomial logistic regression 
+The following example shows how to train binomial and multinomial logistic regression
 models for binary classification with elastic net regularization. `elasticNetParam` corresponds to
 $\alpha$ and `regParam` corresponds to $\lambda$.
 
@@ -156,7 +156,7 @@ classes and $J$ is the number of features. If the algorithm is fit with an inter
 intercepts is available.
 
  > Multinomial coefficients are available as `coefficientMatrix` and intercepts are available as `interceptVector`.
- 
+
  > `coefficients` and `intercept` methods on a logistic regression model trained with multinomial family are not supported. Use `coefficientMatrix` and `interceptVector` instead.
 
 The conditional probabilities of the outcome classes $k \in \{1, 2, ..., K\}$ are modeled using the softmax function.
@@ -175,7 +175,7 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
 
 **Examples**
 
-The following example shows how to train a multiclass logistic regression 
+The following example shows how to train a multiclass logistic regression
 model with elastic net regularization, as well as extract the multiclass
 training summary for evaluating the model.
 
@@ -291,7 +291,7 @@ Refer to the [R API docs](api/R/spark.randomForest.html) for more details.
 
 ## Gradient-boosted tree classifier
 
-Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. 
+Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees.
 More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts).
 
 **Examples**
@@ -332,10 +332,10 @@ Refer to the [R API docs](api/R/spark.gbt.html) for more details.
 
 ## Multilayer perceptron classifier
 
-Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). 
-MLPC consists of multiple layers of nodes. 
-Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs 
-by a linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. 
+Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network).
+MLPC consists of multiple layers of nodes.
+Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs
+by a linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function.
 This can be written in matrix form for MLPC with `$K+1$` layers as follows:
 `\[
 \mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K)
@@ -348,7 +348,7 @@ Nodes in the output layer use softmax function:
 `\[
 \mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}
 \]`
-The number of nodes `$N$` in the output layer corresponds to the number of classes. 
+The number of nodes `$N$` in the output layer corresponds to the number of classes.
 
 MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine.
 
@@ -393,7 +393,7 @@ or set of hyperplanes in a high- or infinite-dimensional space, which can be use
 regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has
 the largest distance to the nearest training-data points of any class (so-called functional margin),
 since in general the larger the margin the lower the generalization error of the classifier. LinearSVC
-in Spark ML supports binary classification with linear SVM. Internally, it optimizes the 
+in Spark ML supports binary classification with linear SVM. Internally, it optimizes the
 [Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss) using OWLQN optimizer.
 
 
@@ -469,8 +469,8 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat
 
 ## Naive Bayes
 
-[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple 
-probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence 
+[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple
+probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence
 assumptions between every pair of features.
 
 Naive Bayes can be trained very efficiently. With a single pass over the training data,
@@ -494,7 +494,7 @@ For document classification, the input feature vectors should usually be sparse
 Since the training data is only used once, it is not necessary to cache it.
 
 [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
-setting the parameter $\lambda$ (default to $1.0$). 
+setting the parameter $\lambda$ (default to $1.0$).
 
 **Examples**
 
@@ -563,6 +563,15 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat
 {% include_example python/ml/fm_classifier_example.py %}
 </div>
 
+<div data-lang="r" markdown="1">
+
+Refer to the [R API docs](api/R/spark.fmClassifier.html) for more details.
+
+Note: At the moment SparkR doesn't support feature scaling.
+
+{% include_example r/ml/fmClassifier.R %}
+</div>
+
 </div>
 
 
@@ -620,7 +629,7 @@ Currently in `spark.ml`, only a subset of the exponential family distributions a
 
 **NOTE**: Spark currently only supports up to 4096 features through its `GeneralizedLinearRegression`
 interface, and will throw an exception if this constraint is exceeded. See the [advanced section](ml-advanced) for more details.
- Still, for linear and logistic regression, models with an increased number of features can be trained 
+ Still, for linear and logistic regression, models with an increased number of features can be trained
  using the `LinearRegression` and `LogisticRegression` estimators.
 
 GLMs require exponential family distributions that can be written in their "canonical" or "natural" form, aka
@@ -840,7 +849,7 @@ Refer to the [R API docs](api/R/spark.randomForest.html) for more details.
 
 ## Gradient-boosted tree regression
 
-Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. 
+Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees.
 More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts).
 
 **Examples**
@@ -883,16 +892,16 @@ Refer to the [R API docs](api/R/spark.gbt.html) for more details.
 ## Survival regression
 
 
-In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) 
-model which is a parametric survival regression model for censored data. 
-It describes a model for the log of survival time, so it's often called a 
+In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model)
+model which is a parametric survival regression model for censored data.
+It describes a model for the log of survival time, so it's often called a
 log-linear model for survival analysis. Different from a
 [Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model
-designed for the same purpose, the AFT model is easier to parallelize 
+designed for the same purpose, the AFT model is easier to parallelize
 because each instance contributes to the objective function independently.
 
-Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of 
-subjects i = 1, ..., n, with possible right-censoring, 
+Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of
+subjects i = 1, ..., n, with possible right-censoring,
 the likelihood function under the AFT model is given as:
 `\[
 L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
@@ -906,8 +915,8 @@ assumes the form:
 Where $S_{0}(\epsilon_{i})$ is the baseline survivor function,
 and $f_{0}(\epsilon_{i})$ is the corresponding density function.
 
-The most commonly used AFT model is based on the Weibull distribution of the survival time. 
-The Weibull distribution for lifetime corresponds to the extreme value distribution for the 
+The most commonly used AFT model is based on the Weibull distribution of the survival time.
+The Weibull distribution for lifetime corresponds to the extreme value distribution for the
 log of the lifetime, and the $S_{0}(\epsilon)$ function is:
 `\[   
 S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
@@ -926,15 +935,15 @@ The gradient functions for $\beta$ and $\log\sigma$ respectively are:
 `\[   
 \frac{\partial (-\iota)}{\partial \beta}=\sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma}
 \]`
-`\[ 
+`\[
 \frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
 \]`
 
-The AFT model can be formulated as a convex optimization problem, 
-i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ 
+The AFT model can be formulated as a convex optimization problem,
+i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$
 that depends on the coefficients vector $\beta$ and the log of scale parameter $\log\sigma$.
 The optimization algorithm underlying the implementation is L-BFGS.
-The implementation matches the result from R's survival function 
+The implementation matches the result from R's survival function
 [survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html)
 
  > When fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg.
@@ -1174,7 +1183,7 @@ The main differences between this API and the [original MLlib Decision Tree API]
 
 
 The Pipelines API for Decision Trees offers a bit more functionality than the original API.  
-In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities); 
+In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities);
 for regression, users can get the biased sample variance of prediction.
 
 Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles).
@@ -1420,4 +1429,3 @@ Note that `GBTClassifier` currently only supports binary labels.
 </table>
 
 In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does.
-
diff --git a/docs/sparkr.md b/docs/sparkr.md
index 24fa3b4..d816549 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -9,9 +9,9 @@ license: |
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License.  You may obtain a copy of the License at
- 
+
      http://www.apache.org/licenses/LICENSE-2.0
- 
+
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -540,6 +540,7 @@ SparkR supports the following machine learning algorithms currently:
 * [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron (MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier)
 * [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive Bayes`](ml-classification-regression.html#naive-bayes)
 * [`spark.svmLinear`](api/R/spark.svmLinear.html): [`Linear Support Vector Machine`](ml-classification-regression.html#linear-support-vector-machine)
+* [`spark.fmClassifier`](api/R/spark.fmClassifier.html): [`Factorization Machines classifier`](ml-classification-regression.html#factorization-machines-classifier)
 
 #### Regression
 
@@ -756,4 +757,3 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma
 # Migration Guide
 
 The migration guide is now archived [on this page](sparkr-migration-guide.html).
-
diff --git a/examples/src/main/r/ml/fmClassifier.R b/examples/src/main/r/ml/fmClassifier.R
new file mode 100644
index 0000000..3f9df91
--- /dev/null
+++ b/examples/src/main/r/ml/fmClassifier.R
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# To run this example use
+# ./bin/spark-submit examples/src/main/r/ml/fmClassifier.R
+
+# Load SparkR library into your R session
+library(SparkR)
+
+# Initialize SparkSession
+sparkR.session(appName = "SparkR-ML-fmClassifier-example")
+
+# $example on$
+# Load training data
+df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")
+training <- df
+test <- df
+
+# Fit an FM classification model
+model <- spark.fmClassifier(training, label ~ features)
+
+# Model summary
+summary(model)
+
+# Prediction
+predictions <- predict(model, test)
+head(predictions)
+
+# $example off$
+
+sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala
new file mode 100644
index 0000000..a6c6ad6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{FMClassificationModel, FMClassifier}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.r.RWrapperUtils._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class FMClassifierWrapper private (
+    val pipeline: PipelineModel,
+    val features: Array[String],
+    val labels: Array[String]) extends MLWritable {
+  import FMClassifierWrapper._
+
+  private val fmClassificationModel: FMClassificationModel =
+    pipeline.stages(1).asInstanceOf[FMClassificationModel]
+
+  lazy val rFeatures: Array[String] = if (fmClassificationModel.getFitIntercept) {
+    Array("(Intercept)") ++ features
+  } else {
+    features
+  }
+
+  lazy val rCoefficients: Array[Double] = if (fmClassificationModel.getFitIntercept) {
+    Array(fmClassificationModel.intercept) ++ fmClassificationModel.linear.toArray
+  } else {
+    fmClassificationModel.linear.toArray
+  }
+
+  lazy val rFactors = fmClassificationModel.factors.toArray
+
+  lazy val numClasses: Int = fmClassificationModel.numClasses
+
+  lazy val numFeatures: Int = fmClassificationModel.numFeatures
+
+  lazy val factorSize: Int = fmClassificationModel.getFactorSize
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset)
+      .drop(PREDICTED_LABEL_INDEX_COL)
+      .drop(fmClassificationModel.getFeaturesCol)
+      .drop(fmClassificationModel.getLabelCol)
+  }
+
+  override def write: MLWriter = new FMClassifierWrapper.FMClassifierWrapperWriter(this)
+}
+
+private[r] object FMClassifierWrapper
+  extends MLReadable[FMClassifierWrapper] {
+
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      factorSize: Int,
+      fitLinear: Boolean,
+      regParam: Double,
+      miniBatchFraction: Double,
+      initStd: Double,
+      maxIter: Int,
+      stepSize: Double,
+      tol: Double,
+      solver: String,
+      seed: String,
+      thresholds: Array[Double],
+      handleInvalid: String): FMClassifierWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+      .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
+    checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    val fitIntercept = rFormula.hasIntercept
+
+    // get labels and feature names from output schema
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
+
+    // assemble and fit the pipeline
+    val fmc = new FMClassifier()
+      .setFactorSize(factorSize)
+      .setFitIntercept(fitIntercept)
+      .setFitLinear(fitLinear)
+      .setRegParam(regParam)
+      .setMiniBatchFraction(miniBatchFraction)
+      .setInitStd(initStd)
+      .setMaxIter(maxIter)
+      .setStepSize(stepSize)
+      .setTol(tol)
+      .setSolver(solver)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+
+    if (seed != null && seed.length > 0) {
+      fmc.setSeed(seed.toLong)
+    }
+
+    if (thresholds != null) {
+      fmc.setThresholds(thresholds)
+    }
+
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, fmc, idxToStr))
+      .fit(data)
+
+    new FMClassifierWrapper(pipeline, features, labels)
+  }
+
+  override def read: MLReader[FMClassifierWrapper] = new FMClassifierWrapperReader
+
+  class FMClassifierWrapperWriter(instance: FMClassifierWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq) ~
+        ("labels" -> instance.labels.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class FMClassifierWrapperReader extends MLReader[FMClassifierWrapper] {
+
+    override def load(path: String): FMClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+      val labels = (rMetadata \ "labels").extract[Array[String]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new FMClassifierWrapper(pipeline, features, labels)
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index ba6445a..68f7c8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -74,6 +74,8 @@ private[r] object RWrappers extends MLReader[Object] {
         LinearSVCWrapper.load(path)
       case "org.apache.spark.ml.r.FPGrowthWrapper" =>
         FPGrowthWrapper.load(path)
+      case "org.apache.spark.ml.r.FMClassifierWrapper" =>
+        FMClassifierWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load $className")
     }

