Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/702#discussion_r12501401
--- Diff: docs/mllib-optimization.md ---
@@ -163,3 +171,108 @@ each iteration, to compute the gradient direction.
Available algorithms for gradient descent:
* [GradientDescent.runMiniBatchSGD](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent)
+
+### Limited-memory BFGS
+L-BFGS is currently only a low-level optimization primitive in `MLlib`. If you want to use L-BFGS in various
+ML algorithms such as linear regression and logistic regression, you have to pass the gradient of the objective
+function and the updater into the optimizer yourself, instead of using training APIs like
+[LogisticRegression.LogisticRegressionWithSGD](api/mllib/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD).
+See the example below. This will be addressed in the next release.
+
+L1 regularization using
+[L1Updater](api/mllib/index.html#org.apache.spark.mllib.optimization.L1Updater) will not work, since the
+soft-thresholding logic in L1Updater is designed for gradient descent. See the developer's note.
+
+The L-BFGS method
+[LBFGS.runLBFGS](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS)
+has the following parameters:
+
+* `gradient` is a class that computes the gradient of the objective function
+being optimized, i.e., with respect to a single training example, at the
+current parameter value. MLlib includes gradient classes for common loss
+functions, e.g., hinge, logistic, least-squares. The gradient class takes as
+input a training example, its label, and the current parameter value.
+* `updater` is a class that computes the gradient and loss of the regularization part
+of the objective function for L-BFGS. MLlib includes updaters for the case without
+regularization, as well as for L2 regularization.
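+* `numCorrections` is the number of corrections used in the L-BFGS update. 10 is
+recommended.
+* `convergenceTol` controls how much relative change is still allowed when L-BFGS is
+considered to have converged. Lower values are less tolerant and generally cause more
+iterations to be run.
+* `maxNumIterations` is the maximum number of iterations that L-BFGS can run.
+* `regParam` is the regularization parameter when using regularization.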
+
+
+The return value is a tuple containing two elements. The first element is a vector
+containing weights for every feature, and the second element is an array containing
+the loss computed for every iteration.
+
+Here is an example that trains binary logistic regression with L2 regularization using the
+L-BFGS optimizer.
+{% highlight scala %}
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.classification.LogisticRegressionModel
+import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}
+
+val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
+val numFeatures = data.take(1)(0).features.size
+
+// Split data into training (60%) and test (40%).
+val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+
+// Append 1 to each feature vector as the intercept term.
+val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache()
+
+val test = splits(1)
+
+// Run training algorithm to build the model
+val numCorrections = 10
+val convergenceTol = 1e-4
+val maxNumIterations = 20
+val regParam = 0.1
+val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1))
+
+val (weightsWithIntercept, loss) = LBFGS.runLBFGS(
+  training,
+  new LogisticGradient(),
+  new SquaredL2Updater(),
+  numCorrections,
+  convergenceTol,
+  maxNumIterations,
+  regParam,
+  initialWeightsWithIntercept)
+
+val model = new LogisticRegressionModel(
+  Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size)),
--- End diff --
`appendBias` puts `1.0` at the end of the vector, so the slicing here needs to be updated: the feature weights are the first `size - 1` entries, and the intercept is the last entry.
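For reference, a minimal sketch of the corrected extraction, assuming `weightsWithIntercept` is laid out as `[w_1, ..., w_n, b]` because `appendBias` adds the constant `1.0` as the last feature. It works on a plain `Array[Double]` and uses a hypothetical helper name, `splitWeightsAndIntercept`, so it runs without Spark; in the doc example the first half would go into `Vectors.dense(...)` and the second into the intercept argument of `LogisticRegressionModel`.

```scala
// Hypothetical helper (not part of MLlib): split an L-BFGS solution whose
// last entry is the intercept learned for the appended 1.0 column.
object SplitIntercept {
  /** Returns (weights, intercept), assuming the bias feature was appended last,
    * which is where MLUtils.appendBias puts the constant 1.0. */
  def splitWeightsAndIntercept(weightsWithIntercept: Array[Double]): (Array[Double], Double) = {
    require(weightsWithIntercept.nonEmpty, "expected at least the intercept entry")
    val n = weightsWithIntercept.length
    // Feature weights are every entry except the last one.
    val weights = weightsWithIntercept.slice(0, n - 1)
    // The intercept is the coefficient of the appended constant feature.
    val intercept = weightsWithIntercept(n - 1)
    (weights, intercept)
  }

  def main(args: Array[String]): Unit = {
    val (w, b) = splitWeightsAndIntercept(Array(0.5, -1.25, 2.0, 0.75))
    println(w.mkString(", ")) // 0.5, -1.25, 2.0
    println(b)                // 0.75
  }
}
```

Equivalently, `weightsWithIntercept.toArray.init` and `weightsWithIntercept.toArray.last` give the same two pieces.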