Updated Branches: refs/heads/branch-0.9 76147a290 -> 03019d106
Merge pull request #459 from srowen/UpdaterL2Regularization

Correct L2 regularized weight update with canonical form

Per thread on the user@ mailing list, and comments from Ameet, I believe the weight update for L2 regularization needs to be corrected. See http://mail-archives.apache.org/mod_mbox/spark-user/201401.mbox/%3CCAH3_EVMetuQuhj3__NdUniDLc4P-FMmmrmxw9TS14or8nT4BNQ%40mail.gmail.com%3E

(cherry picked from commit fe8a3546f40394466a41fc750cb60f6fc73d8bbb)

Signed-off-by: Patrick Wendell <pwend...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/03019d10
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/03019d10
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/03019d10

Branch: refs/heads/branch-0.9
Commit: 03019d106becae3cca95428b462d661c1afac37e
Parents: 76147a2
Author: Patrick Wendell <pwend...@gmail.com>
Authored: Sat Jan 18 16:29:23 2014 -0800
Committer: Patrick Wendell <pwend...@gmail.com>
Committed: Sat Jan 18 16:29:43 2014 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/optimization/Updater.scala | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/03019d10/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 4c51f4f..37124f2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -86,13 +86,17 @@ class L1Updater extends Updater {
 
 /**
  * Updater that adjusts the learning rate and performs L2 regularization
+ *
+ * See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
+ * of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
+ * these slides</a>.
  */
 class SquaredL2Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
       stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
-    val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
+    val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
     (newWeights, pow(newWeights.norm2, 2.0) * regParam)
   }
 }
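To make the changed line in SquaredL2Updater concrete, here is a minimal sketch that puts the removed and added weight updates side by side. It assumes plain Array[Double] in place of jblas DoubleMatrix so that it has no dependencies; the object and function names (SquaredL2UpdateSketch, oldUpdate, newUpdate, regVal) are hypothetical and not part of MLlib. Writing eta = stepSize / sqrt(iter) and lambda = regParam, the added line is an ordinary gradient step on loss + lambda * ||w||^2, whose regularizer gradient is 2 * lambda * w: w - eta * (g + 2 * lambda * w) = (1 - 2 * eta * lambda) * w - eta * g. The removed line instead divided (w - eta * g) by (1 + 2 * eta * lambda), which matches the factor (1 - 2 * eta * lambda) only to first order in eta * lambda.

```scala
// Sketch only: Array[Double] stands in for jblas DoubleMatrix, and all names
// here are illustrative, not MLlib API.
object SquaredL2UpdateSketch {

  // Decaying step size, as in the diff: stepSize / sqrt(iter).
  def iterStepSize(stepSize: Double, iter: Int): Double =
    stepSize / math.sqrt(iter.toDouble)

  // Removed line: (w - eta * g) / (2 * eta * regParam + 1).
  def oldUpdate(w: Array[Double], g: Array[Double],
      stepSize: Double, iter: Int, regParam: Double): Array[Double] = {
    val eta = iterStepSize(stepSize, iter)
    w.zip(g).map { case (wi, gi) => (wi - eta * gi) / (2.0 * eta * regParam + 1.0) }
  }

  // Added line: w * (1 - 2 * eta * regParam) - eta * g, i.e. the plain
  // gradient step w - eta * (g + 2 * regParam * w) on loss + regParam * ||w||^2.
  def newUpdate(w: Array[Double], g: Array[Double],
      stepSize: Double, iter: Int, regParam: Double): Array[Double] = {
    val eta = iterStepSize(stepSize, iter)
    w.zip(g).map { case (wi, gi) => wi * (1.0 - 2.0 * eta * regParam) - eta * gi }
  }

  // Regularization value returned alongside the weights: regParam * ||w||^2.
  def regVal(w: Array[Double], regParam: Double): Double =
    w.map(x => x * x).sum * regParam

  def main(args: Array[String]): Unit = {
    val w = Array(1.0, -2.0)
    val g = Array(0.5, 0.5)
    val (stepSize, iter, regParam) = (0.1, 1, 0.1)
    val before = oldUpdate(w, g, stepSize, iter, regParam)
    val after = newUpdate(w, g, stepSize, iter, regParam)
    println(s"old form: ${before.mkString(", ")}")
    println(s"new form: ${after.mkString(", ")}")
    println(s"regVal (new form): ${regVal(after, regParam)}")
  }
}
```

Running main prints both updates for a toy two-element weight vector, so the small gap between the two shrink factors, 1 / (1 + 2 * eta * lambda) versus 1 - 2 * eta * lambda, can be inspected directly.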