Github user yanboliang commented on a diff in the pull request: https://github.com/apache/spark/pull/15394#discussion_r83006722 --- Diff: mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala --- @@ -85,73 +101,193 @@ private[ml] class WeightedLeastSquares( val triK = summary.triK val wSum = summary.wSum val bBar = summary.bBar - val bStd = summary.bStd + val bbBar = summary.bbBar val aBar = summary.aBar - val aVar = summary.aVar + val aStd = summary.aStd val abBar = summary.abBar val aaBar = summary.aaBar - val aaValues = aaBar.values - - if (bStd == 0) { - if (fitIntercept) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - val coefficients = new DenseVector(Array.ofDim(k-1)) + val aaBarValues = aaBar.values + val numFeatures = abBar.size + val rawBStd = summary.bStd + // if b is constant (rawBStd is zero), then b cannot be scaled. In this case + // setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm. 
+ val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd + + if (rawBStd == 0) { + if (fitIntercept || bBar == 0.0) { + if (bBar == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + val coefficients = new DenseVector(Array.ofDim(numFeatures)) val intercept = bBar val diagInvAtWA = new DenseVector(Array(0D)) - return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D)) + } else { + require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + + "zero. Model cannot be regularized with standardization=true") + logWarning(s"The standard deviation of the label is zero. Consider setting " + + s"fitIntercept=true.") + } + } + + val aBarStd = new Array[Double](numFeatures) + var j = 0 + while (j < numFeatures) { + if (aStd(j) == 0.0) { + aBarStd(j) = 0.0 } else { - require(!(regParam > 0.0 && standardizeLabel), - "The standard deviation of the label is zero. " + - "Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. " + - "Consider setting fitIntercept=true.") + aBarStd(j) = aBar(j) / aStd(j) + } + j += 1 + } + + val abBarStd = new Array[Double](numFeatures) + j = 0 + while (j < numFeatures) { + if (aStd(j) == 0.0) { + abBarStd(j) = 0.0 + } else { + abBarStd(j) = abBar(j) / (aStd(j) * bStd) + } + j += 1 + } + + val aaBarStd = new Array[Double](triK) + j = 0 + var kk = 0 --- End diff -- I'd like to use ```p``` or ```q``` rather than ```kk```.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org