GitHub user yanboliang commented on a diff in the pull request:
https://github.com/apache/spark/pull/11179#discussion_r53131871
--- Diff: mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala ---
@@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
extends Serializable {
- // beta is the intercept and regression coefficients to the covariates
- private val beta = parameters.slice(1, parameters.length)
+ // the regression coefficients to the covariates
+ private val coefficients = parameters.slice(2, parameters.length)
+ private val intercept = parameters.valueAt(1)
// sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
- private var gradientBetaSum = BDV.zeros[Double](beta.length)
+ private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
+ private var gradientInterceptSum = 0.0
private var gradientLogSigmaSum = 0.0
def count: Long = totalCnt
def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
- // Here we optimize loss function over beta and log(sigma)
+ // Here we optimize loss function over coefficients and log(sigma)
--- End diff --
```beta``` means ```coefficients and intercept```, so the annotation here should read
```coefficients, intercept and log(sigma)```.
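For illustration, here is a minimal sketch of the flat parameter layout implied by the diff: index 0 holds log(sigma), index 1 the intercept, and indices 2 onward the coefficients. This is why the comment should name all three quantities. It assumes plain `Array[Double]` in place of Breeze's `BDV` so the snippet has no dependencies, and `AFTParams`/`unpack` are hypothetical names, not Spark API.

```scala
// Hypothetical helper (not part of Spark): mirrors the slicing in the diff,
// i.e. parameters(0), parameters.valueAt(1) and parameters.slice(2, parameters.length).
final case class AFTParams(sigma: Double, intercept: Double, coefficients: Array[Double])

object AFTParams {
  // The loss is optimized over coefficients, intercept and log(sigma),
  // so all three share one flat vector that the optimizer updates together.
  def unpack(parameters: Array[Double]): AFTParams = {
    require(parameters.length >= 2, "need at least log(sigma) and the intercept")
    AFTParams(
      sigma = math.exp(parameters(0)),                       // scale is stored in log space
      intercept = parameters(1),                             // intercept sits at index 1
      coefficients = parameters.slice(2, parameters.length)  // remaining entries are coefficients
    )
  }
}
```

Storing log(sigma) rather than sigma keeps the scale positive without a constrained optimizer. Under this layout, a gradient vector would need the same ordering.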