[ https://issues.apache.org/jira/browse/FLINK-1979?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15286591#comment-15286591 ]

ASF GitHub Bot commented on FLINK-1979:
---------------------------------------

Github user thvasilo commented on a diff in the pull request:

    https://github.com/apache/flink/pull/1985#discussion_r63519390

--- Diff: flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala ---
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.math.{Vector, BLAS}
+import org.apache.flink.ml.math.Breeze._
+import breeze.linalg.{norm => BreezeNorm}
+
+/** Represents a type of regularization penalty
+  *
+  * Regularization penalties are used to restrict the optimization problem to solutions with
+  * certain desirable characteristics, such as sparsity for the L1 penalty, or penalizing large
+  * weights for the L2 penalty.
+  *
+  * The regularization term, `R(w)`, is added to the objective function, `f(w) = L(w) + lambda*R(w)`,
+  * where `lambda` is the regularization parameter used to tune the amount of regularization applied.
+  */
+trait RegularizationPenalty extends Serializable {
+
+  /** Calculates the new weights based on the gradient and regularization penalty
+    *
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient used to update the weights
+    * @param regularizationConstant The regularization parameter to be applied
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
+    */
+  def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector
+
+  /** Adds regularization to the loss value
+    *
+    * @param oldLoss The loss to be updated
+    * @param weightVector The weight vector used to compute the regularization term
+    * @param regularizationConstant The regularization parameter to be applied
+    * @return Updated loss
+    */
+  def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double): Double
+
+}
+
+
+/** `L_2` regularization penalty.
+  *
+  * The regularization function is the square of the L2 norm `1/2*||w||_2^2`
+  * with `w` being the weight vector. The function penalizes large weights,
+  * favoring solutions with more small weights rather than few large ones.
+  */
+object L2Regularization extends RegularizationPenalty {
+
+  /** Calculates the new weights based on the gradient and L2 regularization penalty
+    *
+    * The updated weight is `w - learningRate * (gradient + lambda * w)` where
+    * `w` is the weight vector, and `lambda` is the regularization parameter.
+    *
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param regularizationConstant The regularization parameter to be applied
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
+    */
+  override def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector = {
+    // add the gradient of the L2 regularization
+    BLAS.axpy(regularizationConstant, weightVector, gradient)
+
+    // update the weights according to the learning rate
+    BLAS.axpy(-learningRate, gradient, weightVector)
+
+    weightVector
+  }
+
+  /** Adds regularization to the loss value
+    *
+    * The updated loss is `l + lambda * 1/2*||w||_2^2` where `l` is the old loss,
--- End diff --

I would recommend spelling out `loss` here, as in some fonts it's hard to tell the letter "l" apart from the number 1.


> Implement Loss Functions
> ------------------------
>
>          Key: FLINK-1979
>          URL: https://issues.apache.org/jira/browse/FLINK-1979
>      Project: Flink
>   Issue Type: Improvement
>   Components: Machine Learning Library
>     Reporter: Johannes Günther
>     Assignee: Johannes Günther
>     Priority: Minor
>       Labels: ML
>
> For convex optimization problems, optimizer methods like SGD rely on a
> pluggable implementation of a loss function and its first derivative.



--
This message was sent by Atlassian JIRA
(v6.3.4#6332)
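
For reference, here is a minimal, self-contained sketch of the update rule `w - learningRate * (gradient + lambda * w)` and the regularized loss `loss + lambda * 1/2*||w||_2^2` described in the docstrings above. It deliberately uses plain Scala Arrays instead of flink-ml's Vector/BLAS; the object and method names are illustrative only and are not part of the PR or the flink-ml API.

    object L2RegularizationSketch {

      // New weights: w - learningRate * (gradient + lambda * w)
      def takeStep(
          weights: Array[Double],
          gradient: Array[Double],
          lambda: Double,
          learningRate: Double): Array[Double] = {
        weights.zip(gradient).map { case (w, g) =>
          w - learningRate * (g + lambda * w)
        }
      }

      // Regularized loss: loss + lambda * 1/2 * ||w||_2^2
      def regLoss(loss: Double, weights: Array[Double], lambda: Double): Double = {
        val squaredL2Norm = weights.map(w => w * w).sum
        loss + lambda * 0.5 * squaredL2Norm
      }

      def main(args: Array[String]): Unit = {
        val w    = Array(0.5, -1.0, 2.0)
        val grad = Array(0.1, 0.2, -0.3)
        println(takeStep(w, grad, lambda = 0.01, learningRate = 0.1).mkString(", "))
        // 1.0 + 0.01 * 0.5 * (0.25 + 1.0 + 4.0) = 1.02625
        println(regLoss(1.0, w, lambda = 0.01))
      }
    }

Note that, as in the quoted diff, the gradient of the regularization term (`lambda * w`) is folded into the loss gradient before the learning-rate step, so L2 regularization only changes the direction and length of the update, not the structure of the optimizer.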