This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0d37f79 [SPARK-30820][SPARKR][ML] Add FMClassifier to SparkR
0d37f79 is described below
commit 0d37f794ef9451199e1b757c1015fc7a8b3931a5
Author: zero323 <[email protected]>
AuthorDate: Tue Apr 7 09:01:45 2020 -0500
[SPARK-30820][SPARKR][ML] Add FMClassifier to SparkR
### What changes were proposed in this pull request?
This pull request adds SparkR wrapper for `FMClassifier`:
- Supporting `org.apache.spark.ml.r.FMClassifierWrapper`.
- `FMClassificationModel` S4 class.
- Corresponding `spark.fmClassifier`, `predict`, `summary` and `write.ml`
generics.
- Corresponding docs and tests.
### Why are the changes needed?
Feature parity.
### Does this PR introduce any user-facing change?
No (new API).
### How was this patch tested?
New unit tests.
Closes #27570 from zero323/SPARK-30820.
Authored-by: zero323 <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
---
R/pkg/NAMESPACE | 3 +-
R/pkg/R/generics.R | 4 +
R/pkg/R/mllib_classification.R | 157 ++++++++++++++++++
R/pkg/R/mllib_utils.R | 2 +
R/pkg/tests/fulltests/test_mllib_classification.R | 34 ++++
R/pkg/vignettes/sparkr-vignettes.Rmd | 20 +++
docs/ml-classification-regression.md | 72 +++++----
docs/sparkr.md | 6 +-
examples/src/main/r/ml/fmClassifier.R | 45 ++++++
.../apache/spark/ml/r/FMClassifierWrapper.scala | 175 +++++++++++++++++++++
.../scala/org/apache/spark/ml/r/RWrappers.scala | 2 +
11 files changed, 484 insertions(+), 36 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index fb879e4..18e4570 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -72,7 +72,8 @@ exportMethods("glm",
"spark.freqItemsets",
"spark.associationRules",
"spark.findFrequentSequentialPatterns",
- "spark.assignClusters")
+ "spark.assignClusters",
+ "spark.fmClassifier")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index d924b2a..d36496f 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1479,6 +1479,10 @@ setGeneric("spark.als", function(data, ...) {
standardGeneric("spark.als") })
setGeneric("spark.bisectingKmeans",
function(data, formula, ...) {
standardGeneric("spark.bisectingKmeans") })
+#' @rdname spark.fmClassifier
+setGeneric("spark.fmClassifier",
+ function(data, formula, ...) {
standardGeneric("spark.fmClassifier") })
+
#' @rdname spark.gaussianMixture
setGeneric("spark.gaussianMixture",
function(data, formula, ...) {
standardGeneric("spark.gaussianMixture") })
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 3ad824e..fc5ac9f 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -42,6 +42,12 @@ setClass("MultilayerPerceptronClassificationModel",
representation(jobj = "jobj"
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))
+#' S4 class that represents a FMClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala FMClassifierWrapper
+#' @note FMClassificationModel since 3.1.0
+setClass("FMClassificationModel", representation(jobj = "jobj"))
+
#' Linear SVM Model
#'
#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071
package.
@@ -649,3 +655,154 @@ setMethod("write.ml", signature(object =
"NaiveBayesModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
+
+#' Factorization Machines Classification Model
+#'
+#' \code{spark.fmClassifier} fits a factorization machines classification model against
a SparkDataFrame.
+#' Users can call \code{summary} to print a summary of the fitted model,
\code{predict} to make
+#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load
fitted models.
+#' Only categorical data is supported.
+#'
+#' @param data a \code{SparkDataFrame} of observations and labels for model
fitting.
+#' @param formula a symbolic description of the model to be fitted. Currently
only a few formula
+#' operators are supported, including '~', '.', ':', '+', and
'-'.
+#' @param factorSize dimensionality of the factors.
+#' @param fitLinear whether to fit linear term. # TODO Can we express this
with formula?
+#' @param regParam the regularization parameter.
+#' @param miniBatchFraction the mini-batch fraction parameter.
+#' @param initStd the standard deviation of initial coefficients.
+#' @param maxIter maximum iteration number.
+#' @param stepSize stepSize parameter.
+#' @param tol convergence tolerance of iterations.
+#' @param solver solver parameter, supported options: "gd" (minibatch gradient
descent) or "adamW".
+#' @param thresholds in binary classification, in range [0, 1]. If the
estimated probability of
+#' class label 1 is > threshold, then predict 1, else 0. A
high threshold
+#' encourages the model to predict 0 more often; a low
threshold encourages the
+#' model to predict 1 more often. Note: Setting this with
threshold p is
+#' equivalent to setting thresholds c(1-p, p).
+#' @param seed seed parameter for weights initialization.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL
values) in features and
+#' label column of string type.
+#' Supported options: "skip" (filter out rows with
invalid data),
+#' "error" (throw an error), "keep"
(put invalid data in
+#' a special additional bucket, at
index numLabels). Default
+#' is "error".
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.fmClassifier} returns a fitted Factorization Machines
Classification Model.
+#' @rdname spark.fmClassifier
+#' @aliases spark.fmClassifier,SparkDataFrame,formula-method
+#' @name spark.fmClassifier
+#' @seealso \link{read.ml}
+#' @examples
+#' \dontrun{
+#' df <- read.df("data/mllib/sample_binary_classification_data.txt", source =
"libsvm")
+#'
+#' # fit Factorization Machines Classification Model
+#' model <- spark.fmClassifier(
+#' df, label ~ features,
+#' regParam = 0.01, maxIter = 10, fitLinear = TRUE
+#' )
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.fmClassifier since 3.1.0
+setMethod("spark.fmClassifier", signature(data = "SparkDataFrame", formula =
"formula"),
+ function(data, formula, factorSize = 8, fitLinear = TRUE, regParam =
0.0,
+ miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100,
stepSize=1.0,
+ tol = 1e-6, solver = c("adamW", "gd"), thresholds = NULL,
seed = NULL,
+ handleInvalid = c("error", "keep", "skip")) {
+
+ formula <- paste(deparse(formula), collapse = "")
+
+ if (!is.null(seed)) {
+ seed <- as.character(as.integer(seed))
+ }
+
+ if (!is.null(thresholds)) {
+ thresholds <- as.list(thresholds)
+ }
+
+ solver <- match.arg(solver)
+ handleInvalid <- match.arg(handleInvalid)
+
+ jobj <- callJStatic("org.apache.spark.ml.r.FMClassifierWrapper",
+ "fit",
+ data@sdf,
+ formula,
+ as.integer(factorSize),
+ as.logical(fitLinear),
+ as.numeric(regParam),
+ as.numeric(miniBatchFraction),
+ as.numeric(initStd),
+ as.integer(maxIter),
+ as.numeric(stepSize),
+ as.numeric(tol),
+ solver,
+ seed,
+ thresholds,
+ handleInvalid)
+ new("FMClassificationModel", jobj = jobj)
+ })
+
+# Returns the summary of a FM Classification model produced by
\code{spark.fmClassifier}
+
+#' @param object a FM Classification model fitted by \code{spark.fmClassifier}.
+#' @return \code{summary} returns summary information of the fitted model,
which is a list.
+#' @rdname spark.fmClassifier
+#' @note summary(FMClassificationModel) since 3.1.0
+setMethod("summary", signature(object = "FMClassificationModel"),
+ function(object) {
+ jobj <- object@jobj
+ features <- callJMethod(jobj, "rFeatures")
+ coefficients <- callJMethod(jobj, "rCoefficients")
+ coefficients <- as.matrix(unlist(coefficients))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ numClasses <- callJMethod(jobj, "numClasses")
+ numFeatures <- callJMethod(jobj, "numFeatures")
+ raw_factors <- unlist(callJMethod(jobj, "rFactors"))
+ factor_size <- callJMethod(jobj, "factorSize")
+
+ list(
+ coefficients = coefficients,
+ factors = matrix(raw_factors, ncol = factor_size),
+ numClasses = numClasses, numFeatures = numFeatures,
+ factorSize = factor_size
+ )
+ })
+
+# Predicted values based on an FMClassificationModel model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a FM
Classification model.
+#' @rdname spark.fmClassifier
+#' @aliases predict,FMClassificationModel,SparkDataFrame-method
+#' @note predict(FMClassificationModel) since 3.1.0
+setMethod("predict", signature(object = "FMClassificationModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+# Save fitted FMClassificationModel to the input path
+
+#' @param path The directory where the model is saved.
+#' @param overwrite Overwrites or not if the output path already exists.
Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname spark.fmClassifier
+#' @aliases write.ml,FMClassificationModel,character-method
+#' @note write.ml(FMClassificationModel, character) since 3.1.0
+setMethod("write.ml", signature(object = "FMClassificationModel", path =
"character"),
+ function(object, path, overwrite = FALSE) {
+ write_internal(object, path, overwrite)
+ })
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index 7d04bff..f5643aa 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -123,6 +123,8 @@ read.ml <- function(path) {
new("LinearSVCModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
new("FPGrowthModel", jobj = jobj)
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FMClassifierWrapper")) {
+ new("FMClassificationModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}
diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R
b/R/pkg/tests/fulltests/test_mllib_classification.R
index 2da3a02..9dd275a 100644
--- a/R/pkg/tests/fulltests/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -488,4 +488,38 @@ test_that("spark.naiveBayes", {
expect_equal(class(collect(predictions)$clicked[1]), "character")
})
+test_that("spark.fmClassifier", {
+ df <- withColumn(
+ suppressWarnings(createDataFrame(iris)),
+ "Species", otherwise(when(column("Species") == "Setosa", "Setosa"),
"Not-Setosa")
+ )
+
+ model1 <- spark.fmClassifier(
+ df, Species ~ .,
+ regParam = 0.01, maxIter = 10, fitLinear = TRUE, factorSize = 3
+ )
+
+ prediction1 <- predict(model1, df)
+ expect_is(prediction1, "SparkDataFrame")
+ expect_equal(summary(model1)$factorSize, 3)
+
+ # Test model save/load
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-fmclassifier", fileext = ".tmp")
+ write.ml(model1, modelPath)
+ model2 <- read.ml(modelPath)
+
+ expect_is(model2, "FMClassificationModel")
+
+ expect_equal(summary(model1), summary(model2))
+
+ prediction2 <- predict(model2, df)
+ expect_equal(
+ collect(drop(prediction1, c("rawPrediction", "probability"))),
+ collect(drop(prediction2, c("rawPrediction", "probability")))
+ )
+ unlink(modelPath)
+ }
+})
+
sparkR.session.stop()
diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd
b/R/pkg/vignettes/sparkr-vignettes.Rmd
index 9e48ae3..6b62a66 100644
--- a/R/pkg/vignettes/sparkr-vignettes.Rmd
+++ b/R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -523,6 +523,8 @@ SparkR supports the following machine learning models and
algorithms.
* Naive Bayes
+* Factorization Machines (FM) Classifier
+
#### Regression
* Accelerated Failure Time (AFT) Survival Model
@@ -705,6 +707,24 @@ naiveBayesPrediction <- predict(naiveBayesModel, titanicDF)
head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived",
"prediction"))
```
+#### Factorization Machines Classifier
+
+Factorization Machines for classification problems.
+
+For background and details about the implementation of factorization machines,
+refer to the [Factorization Machines
section](https://spark.apache.org/docs/latest/ml-classification-regression.html#factorization-machines).
+
+```{r}
+t <- as.data.frame(Titanic)
+training <- createDataFrame(t)
+
+model <- spark.fmClassifier(training, Survived ~ Age + Sex)
+summary(model)
+
+predictions <- predict(model, training)
+head(select(predictions, predictions$prediction))
+```
+
#### Accelerated Failure Time Survival Model
Survival analysis studies the expected duration of time until an event
happens, and often the relationship with risk factors or treatment taken on the
subject. In contrast to standard regression analysis, survival modeling has to
deal with special characteristics in the data including non-negative survival
time and censoring.
diff --git a/docs/ml-classification-regression.md
b/docs/ml-classification-regression.md
index 9d53880..0359456 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -9,9 +9,9 @@ license: |
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
-
+
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -59,11 +59,11 @@ parameter to select between these two algorithms, or leave
it unset and Spark wi
### Binomial logistic regression
-For more background and more details about the implementation of binomial
logistic regression, refer to the documentation of [logistic regression in
`spark.mllib`](mllib-linear-methods.html#logistic-regression).
+For more background and more details about the implementation of binomial
logistic regression, refer to the documentation of [logistic regression in
`spark.mllib`](mllib-linear-methods.html#logistic-regression).
**Examples**
-The following example shows how to train binomial and multinomial logistic
regression
+The following example shows how to train binomial and multinomial logistic
regression
models for binary classification with elastic net regularization.
`elasticNetParam` corresponds to
$\alpha$ and `regParam` corresponds to $\lambda$.
@@ -156,7 +156,7 @@ classes and $J$ is the number of features. If the algorithm
is fit with an inter
intercepts is available.
> Multinomial coefficients are available as `coefficientMatrix` and
intercepts are available as `interceptVector`.
-
+
> `coefficients` and `intercept` methods on a logistic regression model
trained with multinomial family are not supported. Use `coefficientMatrix` and
`interceptVector` instead.
The conditional probabilities of the outcome classes $k \in \{1, 2, ..., K\}$
are modeled using the softmax function.
@@ -175,7 +175,7 @@ For a detailed derivation please see
[here](https://en.wikipedia.org/wiki/Multin
**Examples**
-The following example shows how to train a multiclass logistic regression
+The following example shows how to train a multiclass logistic regression
model with elastic net regularization, as well as extract the multiclass
training summary for evaluating the model.
@@ -291,7 +291,7 @@ Refer to the [R API docs](api/R/spark.randomForest.html)
for more details.
## Gradient-boosted tree classifier
-Gradient-boosted trees (GBTs) are a popular classification and regression
method using ensembles of decision trees.
+Gradient-boosted trees (GBTs) are a popular classification and regression
method using ensembles of decision trees.
More information about the `spark.ml` implementation can be found further in
the [section on GBTs](#gradient-boosted-trees-gbts).
**Examples**
@@ -332,10 +332,10 @@ Refer to the [R API docs](api/R/spark.gbt.html) for more
details.
## Multilayer perceptron classifier
-Multilayer perceptron classifier (MLPC) is a classifier based on the
[feedforward artificial neural
network](https://en.wikipedia.org/wiki/Feedforward_neural_network).
-MLPC consists of multiple layers of nodes.
-Each layer is fully connected to the next layer in the network. Nodes in the
input layer represent the input data. All other nodes map inputs to outputs
-by a linear combination of the inputs with the node's weights `$\wv$` and bias
`$\bv$` and applying an activation function.
+Multilayer perceptron classifier (MLPC) is a classifier based on the
[feedforward artificial neural
network](https://en.wikipedia.org/wiki/Feedforward_neural_network).
+MLPC consists of multiple layers of nodes.
+Each layer is fully connected to the next layer in the network. Nodes in the
input layer represent the input data. All other nodes map inputs to outputs
+by a linear combination of the inputs with the node's weights `$\wv$` and bias
`$\bv$` and applying an activation function.
This can be written in matrix form for MLPC with `$K+1$` layers as follows:
`\[
\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T
\x+b_1)+b_2)...+b_K)
@@ -348,7 +348,7 @@ Nodes in the output layer use softmax function:
`\[
\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}
\]`
-The number of nodes `$N$` in the output layer corresponds to the number of
classes.
+The number of nodes `$N$` in the output layer corresponds to the number of
classes.
MLPC employs backpropagation for learning the model. We use the logistic loss
function for optimization and L-BFGS as an optimization routine.
@@ -393,7 +393,7 @@ or set of hyperplanes in a high- or infinite-dimensional
space, which can be use
regression, or other tasks. Intuitively, a good separation is achieved by the
hyperplane that has
the largest distance to the nearest training-data points of any class
(so-called functional margin),
since in general the larger the margin the lower the generalization error of
the classifier. LinearSVC
-in Spark ML supports binary classification with linear SVM. Internally, it
optimizes the
+in Spark ML supports binary classification with linear SVM. Internally, it
optimizes the
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss) using OWLQN optimizer.
@@ -469,8 +469,8 @@ Refer to the [Python API
docs](api/python/pyspark.ml.html#pyspark.ml.classificat
## Naive Bayes
-[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier)
are a family of simple
-probabilistic, multiclass classifiers based on applying Bayes' theorem with
strong (naive) independence
+[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier)
are a family of simple
+probabilistic, multiclass classifiers based on applying Bayes' theorem with
strong (naive) independence
assumptions between every pair of features.
Naive Bayes can be trained very efficiently. With a single pass over the
training data,
@@ -494,7 +494,7 @@ For document classification, the input feature vectors
should usually be sparse
Since the training data is only used once, it is not necessary to cache it.
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be
used by
-setting the parameter $\lambda$ (default to $1.0$).
+setting the parameter $\lambda$ (default to $1.0$).
**Examples**
@@ -563,6 +563,15 @@ Refer to the [Python API
docs](api/python/pyspark.ml.html#pyspark.ml.classificat
{% include_example python/ml/fm_classifier_example.py %}
</div>
+<div data-lang="r" markdown="1">
+
+Refer to the [R API docs](api/R/spark.fmClassifier.html) for more details.
+
+Note: At the moment SparkR doesn't support feature scaling.
+
+{% include_example r/ml/fmClassifier.R %}
+</div>
+
</div>
@@ -620,7 +629,7 @@ Currently in `spark.ml`, only a subset of the exponential
family distributions a
**NOTE**: Spark currently only supports up to 4096 features through its
`GeneralizedLinearRegression`
interface, and will throw an exception if this constraint is exceeded. See the
[advanced section](ml-advanced) for more details.
- Still, for linear and logistic regression, models with an increased number of
features can be trained
+ Still, for linear and logistic regression, models with an increased number of
features can be trained
using the `LinearRegression` and `LogisticRegression` estimators.
GLMs require exponential family distributions that can be written in their
"canonical" or "natural" form, aka
@@ -840,7 +849,7 @@ Refer to the [R API docs](api/R/spark.randomForest.html)
for more details.
## Gradient-boosted tree regression
-Gradient-boosted trees (GBTs) are a popular regression method using ensembles
of decision trees.
+Gradient-boosted trees (GBTs) are a popular regression method using ensembles
of decision trees.
More information about the `spark.ml` implementation can be found further in
the [section on GBTs](#gradient-boosted-trees-gbts).
**Examples**
@@ -883,16 +892,16 @@ Refer to the [R API docs](api/R/spark.gbt.html) for more
details.
## Survival regression
-In `spark.ml`, we implement the [Accelerated failure time
(AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model)
-model which is a parametric survival regression model for censored data.
-It describes a model for the log of survival time, so it's often called a
+In `spark.ml`, we implement the [Accelerated failure time
(AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model)
+model which is a parametric survival regression model for censored data.
+It describes a model for the log of survival time, so it's often called a
log-linear model for survival analysis. Different from a
[Proportional
hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model
-designed for the same purpose, the AFT model is easier to parallelize
+designed for the same purpose, the AFT model is easier to parallelize
because each instance contributes to the objective function independently.
-Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of
-subjects i = 1, ..., n, with possible right-censoring,
+Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of
+subjects i = 1, ..., n, with possible right-censoring,
the likelihood function under the AFT model is given as:
`\[
L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
@@ -906,8 +915,8 @@ assumes the form:
Where $S_{0}(\epsilon_{i})$ is the baseline survivor function,
and $f_{0}(\epsilon_{i})$ is the corresponding density function.
-The most commonly used AFT model is based on the Weibull distribution of the
survival time.
-The Weibull distribution for lifetime corresponds to the extreme value
distribution for the
+The most commonly used AFT model is based on the Weibull distribution of the
survival time.
+The Weibull distribution for lifetime corresponds to the extreme value
distribution for the
log of the lifetime, and the $S_{0}(\epsilon)$ function is:
`\[
S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
@@ -926,15 +935,15 @@ The gradient functions for $\beta$ and $\log\sigma$
respectively are:
`\[
\frac{\partial (-\iota)}{\partial
\beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma}
\]`
-`\[
+`\[
\frac{\partial (-\iota)}{\partial
(\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
\]`
-The AFT model can be formulated as a convex optimization problem,
-i.e. the task of finding a minimizer of a convex function
$-\iota(\beta,\sigma)$
+The AFT model can be formulated as a convex optimization problem,
+i.e. the task of finding a minimizer of a convex function
$-\iota(\beta,\sigma)$
that depends on the coefficients vector $\beta$ and the log of scale parameter
$\log\sigma$.
The optimization algorithm underlying the implementation is L-BFGS.
-The implementation matches the result from R's survival function
+The implementation matches the result from R's survival function
[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html)
> When fitting AFTSurvivalRegressionModel without intercept on dataset with
constant nonzero column, Spark MLlib outputs zero coefficients for constant
nonzero columns. This behavior is different from R survival::survreg.
@@ -1174,7 +1183,7 @@ The main differences between this API and the [original
MLlib Decision Tree API]
The Pipelines API for Decision Trees offers a bit more functionality than the
original API.
-In particular, for classification, users can get the predicted probability of
each class (a.k.a. class conditional probabilities);
+In particular, for classification, users can get the predicted probability of
each class (a.k.a. class conditional probabilities);
for regression, users can get the biased sample variance of prediction.
Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described
below in the [Tree ensembles section](#tree-ensembles).
@@ -1420,4 +1429,3 @@ Note that `GBTClassifier` currently only supports binary
labels.
</table>
In the future, `GBTClassifier` will also output columns for `rawPrediction`
and `probability`, just as `RandomForestClassifier` does.
-
diff --git a/docs/sparkr.md b/docs/sparkr.md
index 24fa3b4..d816549 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -9,9 +9,9 @@ license: |
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
-
+
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -540,6 +540,7 @@ SparkR supports the following machine learning algorithms
currently:
* [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron
(MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier)
* [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive
Bayes`](ml-classification-regression.html#naive-bayes)
* [`spark.svmLinear`](api/R/spark.svmLinear.html): [`Linear Support Vector
Machine`](ml-classification-regression.html#linear-support-vector-machine)
+* [`spark.fmClassifier`](api/R/spark.fmClassifier.html): [`Factorization Machines
classifier`](ml-classification-regression.html#factorization-machines-classifier)
#### Regression
@@ -756,4 +757,3 @@ You can inspect the search path in R with
[`search()`](https://stat.ethz.ch/R-ma
# Migration Guide
The migration guide is now archived [on this
page](sparkr-migration-guide.html).
-
diff --git a/examples/src/main/r/ml/fmClassifier.R
b/examples/src/main/r/ml/fmClassifier.R
new file mode 100644
index 0000000..3f9df91
--- /dev/null
+++ b/examples/src/main/r/ml/fmClassifier.R
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# To run this example use
+# ./bin/spark-submit examples/src/main/r/ml/fmClassifier.R
+
+# Load SparkR library into your R session
+library(SparkR)
+
+# Initialize SparkSession
+sparkR.session(appName = "SparkR-ML-fmclasfier-example")
+
+# $example on$
+# Load training data
+df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")
+training <- df
+test <- df
+
+# Fit a FM classification model
+model <- spark.fmClassifier(training, label ~ features)
+
+# Model summary
+summary(model)
+
+# Prediction
+predictions <- predict(model, test)
+head(predictions)
+
+# $example off$
+
+sparkR.session.stop()
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala
b/mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala
new file mode 100644
index 0000000..a6c6ad6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{FMClassificationModel, FMClassifier}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.r.RWrapperUtils._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class FMClassifierWrapper private (
+ val pipeline: PipelineModel,
+ val features: Array[String],
+ val labels: Array[String]) extends MLWritable {
+ import FMClassifierWrapper._
+
+ private val fmClassificationModel: FMClassificationModel =
+ pipeline.stages(1).asInstanceOf[FMClassificationModel]
+
+ lazy val rFeatures: Array[String] = if
(fmClassificationModel.getFitIntercept) {
+ Array("(Intercept)") ++ features
+ } else {
+ features
+ }
+
+ lazy val rCoefficients: Array[Double] = if
(fmClassificationModel.getFitIntercept) {
+ Array(fmClassificationModel.intercept) ++
fmClassificationModel.linear.toArray
+ } else {
+ fmClassificationModel.linear.toArray
+ }
+
+ lazy val rFactors = fmClassificationModel.factors.toArray
+
+ lazy val numClasses: Int = fmClassificationModel.numClasses
+
+ lazy val numFeatures: Int = fmClassificationModel.numFeatures
+
+ lazy val factorSize: Int = fmClassificationModel.getFactorSize
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(fmClassificationModel.getFeaturesCol)
+ .drop(fmClassificationModel.getLabelCol)
+ }
+
+ override def write: MLWriter = new
FMClassifierWrapper.FMClassifierWrapperWriter(this)
+}
+
+private[r] object FMClassifierWrapper
+ extends MLReadable[FMClassifierWrapper] {
+
+ val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+ val PREDICTED_LABEL_COL = "prediction"
+
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ formula: String,
+ factorSize: Int,
+ fitLinear: Boolean,
+ regParam: Double,
+ miniBatchFraction: Double,
+ initStd: Double,
+ maxIter: Int,
+ stepSize: Double,
+ tol: Double,
+ solver: String,
+ seed: String,
+ thresholds: Array[Double],
+ handleInvalid: String): FMClassifierWrapper = {
+
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ .setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
+ checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
+
+ val fitIntercept = rFormula.hasIntercept
+
+ // get labels and feature names from output schema
+ val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
+
+ // assemble and fit the pipeline
+ val fmc = new FMClassifier()
+ .setFactorSize(factorSize)
+ .setFitIntercept(fitIntercept)
+ .setFitLinear(fitLinear)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
+ .setInitStd(initStd)
+ .setMaxIter(maxIter)
+ .setStepSize(stepSize)
+ .setTol(tol)
+ .setSolver(solver)
+ .setFeaturesCol(rFormula.getFeaturesCol)
+ .setLabelCol(rFormula.getLabelCol)
+ .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+
+ if (seed != null && seed.length > 0) {
+ fmc.setSeed(seed.toLong)
+ }
+
+ if (thresholds != null) {
+ fmc.setThresholds(thresholds)
+ }
+
+ val idxToStr = new IndexToString()
+ .setInputCol(PREDICTED_LABEL_INDEX_COL)
+ .setOutputCol(PREDICTED_LABEL_COL)
+ .setLabels(labels)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, fmc, idxToStr))
+ .fit(data)
+
+ new FMClassifierWrapper(pipeline, features, labels)
+ }
+
+ override def read: MLReader[FMClassifierWrapper] = new
FMClassifierWrapperReader
+
+ class FMClassifierWrapperWriter(instance: FMClassifierWrapper) extends
MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("features" -> instance.features.toSeq) ~
+ ("labels" -> instance.labels.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class FMClassifierWrapperReader extends MLReader[FMClassifierWrapper] {
+
+ override def load(path: String): FMClassifierWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val features = (rMetadata \ "features").extract[Array[String]]
+ val labels = (rMetadata \ "labels").extract[Array[String]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new FMClassifierWrapper(pipeline, features, labels)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index ba6445a..68f7c8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -74,6 +74,8 @@ private[r] object RWrappers extends MLReader[Object] {
LinearSVCWrapper.load(path)
case "org.apache.spark.ml.r.FPGrowthWrapper" =>
FPGrowthWrapper.load(path)
+ case "org.apache.spark.ml.r.FMClassifierWrapper" =>
+ FMClassifierWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load
$className")
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]