[ https://issues.apache.org/jira/browse/FLINK-1992?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14553989#comment-14553989 ]

ASF GitHub Bot commented on FLINK-1992:
---------------------------------------

Github user tillrohrmann commented on a diff in the pull request:

    https://github.com/apache/flink/pull/692#discussion_r30787129
  
    --- Diff: flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala ---
    @@ -76,86 +77,163 @@ class GradientDescent(runParameters: ParameterMap) 
extends IterativeSolver {
         }.withBroadcastSet(currentWeights, WEIGHTVECTOR_BROADCAST)
       }
     
    +
    +
       /** Provides a solution for the given optimization problem
         *
         * @param data A Dataset of LabeledVector (label, features) pairs
    -    * @param initWeights The initial weights that will be optimized
    +    * @param initialWeights The initial weights that will be optimized
         * @return The weights, optimized for the provided data.
         */
       override def optimize(
         data: DataSet[LabeledVector],
    -    initWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector] = {
    -    // TODO: Faster way to do this?
    -    val dimensionsDS = data.map(_.vector.size).reduce((a, b) => b)
    -
    -    val numberOfIterations: Int = parameterMap(Iterations)
    +    initialWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector] 
= {
    +    val numberOfIterations: Int = parameters(Iterations)
    +    // TODO(tvas): This looks out of place, why don't we get back an 
Option from
    +    // parameters(ConvergenceThreshold)?
    +    val convergenceThresholdOption = parameters.get(ConvergenceThreshold)
     
         // Initialize weights
    -    val initialWeightsDS: DataSet[WeightVector] = initWeights match {
    -      // Ensure provided weight vector is a DenseVector
    -      case Some(wvDS) => {
    -        wvDS.map{wv => {
    -          val denseWeights = wv.weights match {
    -            case dv: DenseVector => dv
    -            case sv: SparseVector => sv.toDenseVector
    +    val initialWeightsDS: DataSet[WeightVector] = 
createInitialWeightsDS(initialWeights, data)
    +
    +    // Perform the iterations
    +    val optimizedWeights = convergenceThresholdOption match {
    +      // No convergence criterion
    +      case None =>
    +        initialWeightsDS.iterate(numberOfIterations) {
    +          weightVectorDS => {
    +            SGDStep(data, weightVectorDS)
               }
    -          WeightVector(denseWeights, wv.intercept)
             }
    -
    +      case Some(convergence) =>
    +        /** Calculates the regularized loss, from the data and given 
weights **/
    +        def lossCalculation(data: DataSet[LabeledVector], weightDS: 
DataSet[WeightVector]):
    +        DataSet[Double] = {
    +          data.map {
    +            new LossCalculation
    +          }.withBroadcastSet(weightDS, WEIGHTVECTOR_BROADCAST)
    +            .reduce {
    +            (left, right) =>
    +              val (leftLoss, leftCount) = left
    +              val (rightLoss, rightCount) = right
    +              (leftLoss + rightLoss, rightCount + leftCount)
    +          }
    +            .map{new RegularizedLossCalculation}
    +            .withBroadcastSet(weightDS, WEIGHTVECTOR_BROADCAST)
             }
    -      }
    -      case None => createInitialWeightVector(dimensionsDS)
    -    }
    -
    -    // Perform the iterations
    -    // TODO: Enable convergence stopping criterion, as in Multiple Linear 
regression
    -    initialWeightsDS.iterate(numberOfIterations) {
    -      weightVector => {
    -        SGDStep(data, weightVector)
    -      }
    +        // We have to calculate for each weight vector the sum of squared 
residuals,
    +        // and then sum them and apply regularization
    +        val initialLossSumDS = lossCalculation(data, initialWeightsDS)
    +
    +        // Combine weight vector with the current loss
    +        val initialWeightsWithLossSum = initialWeightsDS.
    +          crossWithTiny(initialLossSumDS).setParallelism(1)
    +
    +        val resultWithLoss = initialWeightsWithLossSum.
    +          iterateWithTermination(numberOfIterations) {
    +          weightsWithLossSum =>
    +
    +            // Extract weight vector and loss
    +            val previousWeightsDS = weightsWithLossSum.map{_._1}
    +            val previousLossSumDS = weightsWithLossSum.map{_._2}
    +
    +            val currentWeightsDS = SGDStep(data, previousWeightsDS)
    +
    +            val currentLossSumDS = lossCalculation(data, currentWeightsDS)
    +
    +            // Check if the relative change in the loss is smaller than the
    +            // convergence threshold. If yes, then terminate i.e. return 
empty termination data set
    +            val termination = 
previousLossSumDS.crossWithTiny(currentLossSumDS).setParallelism(1).
    +              filter{
    +              pair => {
    +                val (previousLoss, currentLoss) = pair
    +
    +                if (previousLoss <= 0) {
    +                  false
    +                } else {
    +                  math.abs((previousLoss - currentLoss)/previousLoss) >= 
convergence
    --- End diff --
    
    Fair enough.
    
    We still need to add web documentation that describes this convergence criterion.


> Add convergence criterion to SGD optimizer
> ------------------------------------------
>
>                 Key: FLINK-1992
>                 URL: https://issues.apache.org/jira/browse/FLINK-1992
>             Project: Flink
>          Issue Type: Improvement
>          Components: Machine Learning Library
>            Reporter: Till Rohrmann
>            Assignee: Theodore Vasiloudis
>            Priority: Minor
>              Labels: ML
>             Fix For: 0.9
>
>
> Currently, Flink's SGD optimizer runs for a fixed number of iterations. It 
> would be good to support a dynamic convergence criterion, too.



--
This message was sent by Atlassian JIRA
(v6.3.4#6332)

Reply via email to