huaxingao commented on a change in pull request #28960:
URL: https://github.com/apache/spark/pull/28960#discussion_r448024045
##########
File path:
mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
##########
@@ -226,45 +226,48 @@ object GradientDescent extends Logging {
var converged = false // indicates whether converged based on
convergenceTol
var i = 1
- while (!converged && i <= numIterations) {
- val bcWeights = data.context.broadcast(weights)
- // Sample a subset (fraction miniBatchFraction) of the total data
- // compute and sum up the subgradients on this subset (this is one
map-reduce)
- val (gradientSum, lossSum, miniBatchSize) = data.sample(false,
miniBatchFraction, 42 + i)
- .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
- seqOp = (c, v) => {
- // c: (grad, loss, count), v: (label, features)
- val l = gradient.compute(v._2, v._1, bcWeights.value,
Vectors.fromBreeze(c._1))
- (c._1, c._2 + l, c._3 + 1)
- },
- combOp = (c1, c2) => {
- // c: (grad, loss, count)
- (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
- })
- bcWeights.destroy()
-
- if (miniBatchSize > 0) {
- /**
- * lossSum is computed using the weights from the previous iteration
- * and regVal is the regularization value computed in the previous
iteration as well.
- */
- stochasticLossHistory += lossSum / miniBatchSize + regVal
- val update = updater.compute(
- weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
- stepSize, i, regParam)
- weights = update._1
- regVal = update._2
-
- previousWeights = currentWeights
- currentWeights = Some(weights)
- if (previousWeights != None && currentWeights != None) {
- converged = isConverged(previousWeights.get,
- currentWeights.get, convergenceTol)
+ breakable {
+ while (i <= numIterations + 1) {
+ val bcWeights = data.context.broadcast(weights)
+ // Sample a subset (fraction miniBatchFraction) of the total data
+ // compute and sum up the subgradients on this subset (this is one
map-reduce)
+ val (gradientSum, lossSum, miniBatchSize) = data.sample(false,
miniBatchFraction, 42 + i)
+ .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
+ seqOp = (c, v) => {
+ // c: (grad, loss, count), v: (label, features)
+ val l = gradient.compute(v._2, v._1, bcWeights.value,
Vectors.fromBreeze(c._1))
+ (c._1, c._2 + l, c._3 + 1)
+ },
+ combOp = (c1, c2) => {
+ // c: (grad, loss, count)
+ (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
+ })
+ bcWeights.destroy()
+
+ if (miniBatchSize > 0) {
+ /**
+ * lossSum is computed using the weights from the previous iteration
+ * and regVal is the regularization value computed in the previous
iteration as well.
+ */
+ stochasticLossHistory += lossSum / miniBatchSize + regVal
+ if (converged || i == (numIterations + 1)) break
Review comment:
Currently, stochasticLossHistory only contains the initial state plus the
states from iterations 1 to n-1, so the state for the last iteration needs to
be added too. After adding the last state, exit the loop.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]