GitHub user dbtsai commented on a diff in the pull request:
https://github.com/apache/spark/pull/7080#discussion_r33532011
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
---
@@ -534,27 +554,39 @@ private class LogisticCostFun(
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
})
- // regVal is the sum of weight squares for L2 regularization
- val norm = if (regParamL2 == 0.0) {
- 0.0
- } else if (fitIntercept) {
- brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
- } else {
- brzNorm(weights, 2.0)
- }
- val regVal = 0.5 * regParamL2 * norm * norm
+ val totalGradientArray = logisticAggregator.gradient.toArray
- val loss = logisticAggregator.loss + regVal
- val gradient = logisticAggregator.gradient
-
- if (fitIntercept) {
- val wArray = w.toArray.clone()
- wArray(wArray.length - 1) = 0.0
- axpy(regParamL2, Vectors.dense(wArray), gradient)
+ // regVal is the sum of weight squares excluding intercept for L2 regularization.
+ val regVal = if (regParamL2 == 0.0) {
+ 0.0
} else {
- axpy(regParamL2, w, gradient)
--- End diff --
I think this will be even faster than the previous version: it no longer
makes a copy of wArray, and it computes the gradient update and the norm in a
single pass, whereas previously we looped through the array twice. Since the
operation is just an axpy, calling BLAS would not gain much performance anyway.
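To make the one-pass vs. two-pass argument concrete, here is a minimal plain-Scala sketch of both approaches. It assumes, as the removed `wArray(wArray.length - 1) = 0.0` line suggests, that the intercept is stored as the last coefficient when `fitIntercept` is true and is excluded from the L2 penalty. `L2RegSketch`, `twoPass`, and `singlePass` are hypothetical names for illustration; this is not the code in the PR.

```scala
// Hypothetical sketch, not Spark's implementation. Assumes the intercept is
// the last coefficient when fitIntercept = true and is not regularized.
// Both methods add regParamL2 * w_j to the gradient and return
// regVal = 0.5 * regParamL2 * sum_j w_j^2 (intercept excluded).
object L2RegSketch {

  /** Old style: clone the coefficients, zero the intercept, do an axpy into
    * the gradient, then compute the squared norm in a second loop. */
  def twoPass(
      coefficients: Array[Double],
      gradient: Array[Double],
      regParamL2: Double,
      fitIntercept: Boolean): Double = {
    if (regParamL2 == 0.0) 0.0
    else {
      val w = coefficients.clone()            // extra copy of the weights
      if (fitIntercept) w(w.length - 1) = 0.0 // exclude the intercept
      // Pass 1: gradient += regParamL2 * w (an axpy).
      var i = 0
      while (i < w.length) { gradient(i) += regParamL2 * w(i); i += 1 }
      // Pass 2: squared L2 norm.
      var sumSq = 0.0
      i = 0
      while (i < w.length) { sumSq += w(i) * w(i); i += 1 }
      0.5 * regParamL2 * sumSq
    }
  }

  /** New style: no copy; one loop over the non-intercept coefficients
    * updates the gradient and accumulates the squared norm together. */
  def singlePass(
      coefficients: Array[Double],
      gradient: Array[Double],
      regParamL2: Double,
      fitIntercept: Boolean): Double = {
    if (regParamL2 == 0.0) 0.0
    else {
      val numFeatures =
        if (fitIntercept) coefficients.length - 1 else coefficients.length
      var sumSq = 0.0
      var i = 0
      while (i < numFeatures) {
        val wi = coefficients(i)
        gradient(i) += regParamL2 * wi
        sumSq += wi * wi
        i += 1
      }
      0.5 * regParamL2 * sumSq
    }
  }
}
```

Both return the same `regVal` and leave the same gradient; the single-pass version simply never allocates the copy and never touches the intercept slot.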