srowen commented on a change in pull request #28960: URL: https://github.com/apache/spark/pull/28960#discussion_r454062744
########## File path: mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala ########## @@ -226,45 +239,48 @@ object GradientDescent extends Logging { var converged = false // indicates whether converged based on convergenceTol var i = 1 - while (!converged && i <= numIterations) { - val bcWeights = data.context.broadcast(weights) - // Sample a subset (fraction miniBatchFraction) of the total data - // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) - .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))( - seqOp = (c, v) => { - // c: (grad, loss, count), v: (label, features) - val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1)) - (c._1, c._2 + l, c._3 + 1) - }, - combOp = (c1, c2) => { - // c: (grad, loss, count) - (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) - }) - bcWeights.destroy() - - if (miniBatchSize > 0) { - /** - * lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. 
- */ - stochasticLossHistory += lossSum / miniBatchSize + regVal - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), - stepSize, i, regParam) - weights = update._1 - regVal = update._2 - - previousWeights = currentWeights - currentWeights = Some(weights) - if (previousWeights != None && currentWeights != None) { - converged = isConverged(previousWeights.get, - currentWeights.get, convergenceTol) + breakable { + while (i <= numIterations + 1) { + val bcWeights = data.context.broadcast(weights) + // Sample a subset (fraction miniBatchFraction) of the total data + // compute and sum up the subgradients on this subset (this is one map-reduce) + val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) + .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))( + seqOp = (c, v) => {

Review comment: Yeah, it's a little unusual unless it significantly simplifies the code. Can `!converged` be added back to the while condition, and the `if (X) break` condition below be turned into `if (!X) { ... code that follows ... }`? That should behave the same, since `i` will increment and end the loop right after anyway.

----------------------------------------------------------------
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org