Github user thvasilo commented on a diff in the pull request:

    https://github.com/apache/flink/pull/740#discussion_r31336757
  
    --- Diff: flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala ---
    @@ -88,217 +55,296 @@ class GradientDescent() extends IterativeSolver {
       override def optimize(
         data: DataSet[LabeledVector],
         initialWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector] = {
    +
         val numberOfIterations: Int = parameters(Iterations)
         val convergenceThresholdOption: Option[Double] = parameters.get(ConvergenceThreshold)
    +    val lossFunction = parameters(LossFunction)
    +    val learningRate = parameters(LearningRate)
    +    val regularizationConstant = parameters(RegularizationConstant)
     
         // Initialize weights
         val initialWeightsDS: DataSet[WeightVector] = createInitialWeightsDS(initialWeights, data)
     
         // Perform the iterations
    -    val optimizedWeights = convergenceThresholdOption match {
    +    convergenceThresholdOption match {
           // No convergence criterion
           case None =>
    -        initialWeightsDS.iterate(numberOfIterations) {
    -          weightVectorDS => {
    -            SGDStep(data, weightVectorDS)
    -          }
    -        }
    +        optimizeWithoutConvergenceCriterion(
    +          data,
    +          initialWeightsDS,
    +          numberOfIterations,
    +          regularizationConstant,
    +          learningRate,
    +          lossFunction)
           case Some(convergence) =>
    -        // Calculates the regularized loss, from the data and given weights
    -        def lossCalculation(data: DataSet[LabeledVector], weightDS: DataSet[WeightVector]):
    -        DataSet[Double] = {
    -          data
    -            .map {new LossCalculation}.withBroadcastSet(weightDS, WEIGHTVECTOR_BROADCAST)
    -            .reduce {
    -              (left, right) =>
    -                val (leftLoss, leftCount) = left
    -                val (rightLoss, rightCount) = right
    -                (leftLoss + rightLoss, rightCount + leftCount)
    +        optimizeWithConvergenceCriterion(
    +          data,
    +          initialWeightsDS,
    +          numberOfIterations,
    +          regularizationConstant,
    +          learningRate,
    +          convergence,
    +          lossFunction
    +        )
    +    }
    +  }
    +
    +  def optimizeWithConvergenceCriterion(
    +      dataPoints: DataSet[LabeledVector],
    +      initialWeightsDS: DataSet[WeightVector],
    +      numberOfIterations: Int,
    +      regularizationConstant: Double,
    +      learningRate: Double,
    +      convergenceThreshold: Double,
    +      lossFunction: LossFunction)
    +    : DataSet[WeightVector] = {
    +    // We have to calculate for each weight vector the sum of squared residuals,
    +    // and then sum them and apply regularization
    +    val initialLossSumDS = calculateLoss(dataPoints, initialWeightsDS, lossFunction)
    +
    +    // Combine weight vector with the current loss
    +    val initialWeightsWithLossSum = initialWeightsDS.mapWithBcVariable(initialLossSumDS){
    +      (weights, loss) => (weights, loss)
    +    }
    +
    +    val resultWithLoss = initialWeightsWithLossSum.iterateWithTermination(numberOfIterations) {
    +      weightsWithPreviousLossSum =>
    +
    +        // Extract weight vector and loss
    +        val previousWeightsDS = weightsWithPreviousLossSum.map{_._1}
    +        val previousLossSumDS = weightsWithPreviousLossSum.map{_._2}
    +
    +        val currentWeightsDS = SGDStep(
    +          dataPoints,
    +          previousWeightsDS,
    +          lossFunction,
    +          regularizationConstant,
    +          learningRate)
    +
    +        val currentLossSumDS = calculateLoss(dataPoints, currentWeightsDS, lossFunction)
    +
    +        // Check if the relative change in the loss is smaller than the
    +        // convergence threshold. If yes, then terminate i.e. return empty termination data set
    +        val termination = previousLossSumDS.filterWithBcVariable(currentLossSumDS){
    +          (previousLoss, currentLoss) => {
    +            if (previousLoss <= 0) {
    +              false
    +            } else {
    +              scala.math.abs((previousLoss - currentLoss)/previousLoss) >= convergenceThreshold
                 }
    -            .map{new RegularizedLossCalculation}.withBroadcastSet(weightDS, WEIGHTVECTOR_BROADCAST)
    +          }
             }
    -        // We have to calculate for each weight vector the sum of squared residuals,
    -        // and then sum them and apply regularization
    -        val initialLossSumDS = lossCalculation(data, initialWeightsDS)
    -
    -        // Combine weight vector with the current loss
    -        val initialWeightsWithLossSum = initialWeightsDS.
    -          crossWithTiny(initialLossSumDS).setParallelism(1)
    -
    -        val resultWithLoss = initialWeightsWithLossSum.
    -          iterateWithTermination(numberOfIterations) {
    -          weightsWithLossSum =>
    -
    -            // Extract weight vector and loss
    -            val previousWeightsDS = weightsWithLossSum.map{_._1}
    -            val previousLossSumDS = weightsWithLossSum.map{_._2}
    -
    -            val currentWeightsDS = SGDStep(data, previousWeightsDS)
    -
    -            val currentLossSumDS = lossCalculation(data, currentWeightsDS)
    -
    -            // Check if the relative change in the loss is smaller than the
    -            // convergence threshold. If yes, then terminate i.e. return empty termination data set
    -            val termination = previousLossSumDS.crossWithTiny(currentLossSumDS).setParallelism(1).
    -              filter{
    -              pair => {
    -                val (previousLoss, currentLoss) = pair
    -
    -                if (previousLoss <= 0) {
    -                  false
    -                } else {
    -                  math.abs((previousLoss - currentLoss)/previousLoss) >= convergence
    -                }
    -              }
    -            }
     
    -            // Result for new iteration
    -            (currentWeightsDS cross currentLossSumDS, termination)
    -        }
    -        // Return just the weights
    -        resultWithLoss.map{_._1}
    +        // Result for new iteration
    +        (currentWeightsDS.mapWithBcVariable(currentLossSumDS)((w, l) => (w, l)), termination)
         }
    -    optimizedWeights
    +    // Return just the weights
    +    resultWithLoss.map{_._1}
       }
     
    -  /** Calculates the loss value, given a labeled vector and the current weight vector
    +  def optimizeWithoutConvergenceCriterion(
    +      data: DataSet[LabeledVector],
    +      initialWeightsDS: DataSet[WeightVector],
    +      numberOfIterations: Int,
    +      regularizationConstant: Double,
    +      learningRate: Double,
    +      lossFunction: LossFunction)
    +    : DataSet[WeightVector] = {
    +    initialWeightsDS.iterate(numberOfIterations) {
    +      weightVectorDS => {
    +        SGDStep(data, weightVectorDS, lossFunction, regularizationConstant, learningRate)
    +      }
    +    }
    +  }
    +
    +  /** Performs one iteration of Stochastic Gradient Descent using mini batches
         *
    -    * The weight vector is received as a broadcast variable.
    +    * @param data A Dataset of LabeledVector (label, features) pairs
    +    * @param currentWeights A Dataset with the current weights to be optimized as its only element
    +    * @return A Dataset containing the weights after one stochastic gradient descent step
         */
    -  private class LossCalculation extends RichMapFunction[LabeledVector, (Double, Int)] {
    +  private def SGDStep(
    +    data: DataSet[(LabeledVector)],
    +    currentWeights: DataSet[WeightVector],
    +    lossFunction: LossFunction,
    +    regularizationConstant: Double,
    +    learningRate: Double)
    +  : DataSet[WeightVector] = {
    +
    +    data.mapWithBcVariable(currentWeights){
    +      (data, weightVector) => (lossFunction.gradient(data, weightVector), 1)
    +    }.reduce{
    +      (left, right) =>
    +        val (leftGradVector, leftCount) = left
    +        val (rightGradVector, rightCount) = right
    +        // Add the left gradient to the right one
    +        BLAS.axpy(1.0, leftGradVector.weights, rightGradVector.weights)
    +        val gradients = WeightVector(
    +          rightGradVector.weights, leftGradVector.intercept + rightGradVector.intercept)
     
    -    var weightVector: WeightVector = null
    +        (gradients , leftCount + rightCount)
    +    }.mapWithBcVariableIteration(currentWeights){
    +      (gradientCount, weightVector, iteration) => {
    +        val (WeightVector(weights, intercept), count) = gradientCount
     
    -    @throws(classOf[Exception])
    -    override def open(configuration: Configuration): Unit = {
    -      val list = this.getRuntimeContext.
    -        getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
    +        BLAS.scal(1.0/count, weights)
     
    -      weightVector = list.get(0)
    -    }
    +        val gradient = WeightVector(weights, intercept/count)
     
    -    override def map(example: LabeledVector): (Double, Int) = {
    -      val lossFunction = parameters(LossFunction)
    -      val predictionFunction = parameters(PredictionFunction)
    +        val effectiveLearningRate = learningRate/Math.sqrt(iteration)
     
    -      val loss = lossFunction.lossValue(
    -        example,
    -        weightVector,
    -        predictionFunction)
    +        val newWeights = takeStep(
    +          weightVector.weights,
    +          gradient.weights,
    +          regularizationConstant,
    +          effectiveLearningRate)
     
    -      (loss, 1)
    +        WeightVector(
    +          newWeights,
    +          weightVector.intercept - effectiveLearningRate * gradient.intercept)
    +      }
         }
       }
     
    -/** Calculates the regularized loss value, given the loss and the current weight vector
    -  *
    -  * The weight vector is received as a broadcast variable.
    -  */
    -private class RegularizedLossCalculation extends RichMapFunction[(Double, Int), Double] {
    -
    -  var weightVector: WeightVector = null
    -
    -  @throws(classOf[Exception])
    -  override def open(configuration: Configuration): Unit = {
    -    val list = this.getRuntimeContext.
    -      getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
    -
    -    weightVector = list.get(0)
    -  }
    -
    -  override def map(lossAndCount: (Double, Int)): Double = {
    -    val (lossSum, count) = lossAndCount
    -    val regType = parameters(RegularizationType)
    -    val regParameter = parameters(RegularizationParameter)
    -
    -    val regularizedLoss = {
    -      regType.regLoss(
    -        lossSum/count,
    -        weightVector.weights,
    -        regParameter)
    +  /** Calculates the new weights based on the gradient
    +    *
    +    * @param weightVector
    +    * @param gradient
    +    * @param regularizationConstant
    +    * @param learningRate
    +    * @return
    +    */
    +  def takeStep(
    --- End diff --
    
    It would make sense to define this in IterativeSolver instead, so that other gradient-based solvers can share the same weight update step; a rough sketch of what that could look like is below.
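    
    A minimal sketch, assuming IterativeSolver is the common base class of the
    optimizers in org.apache.flink.ml.optimization and that it extends the
    existing Solver base class; the L2-regularized update body is an illustrative
    assumption for the example, not necessarily what this PR implements:
    
        import org.apache.flink.ml.math.{BLAS, Vector}
    
        abstract class IterativeSolver extends Solver {
    
          /** Computes the new weights from the current weights and the gradient
            * of one step.
            *
            * Defined on the base class so that every gradient-based solver can
            * share the update rule and subclasses only have to produce the
            * gradient. NOTE: the L2 update below is an illustrative assumption.
            */
          def takeStep(
              weightVector: Vector,
              gradient: Vector,
              regularizationConstant: Double,
              learningRate: Double)
            : Vector = {
            // w <- w - eta * (gradient + lambda * w)
            // BLAS.axpy(a, x, y) computes y += a * x, mutating y in place
            BLAS.axpy(regularizationConstant, weightVector, gradient)
            BLAS.axpy(-learningRate, gradient, weightVector)
            weightVector
          }
        }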

