Github user yanboliang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15394#discussion_r83006938
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala ---
    @@ -85,73 +101,193 @@ private[ml] class WeightedLeastSquares(
         val triK = summary.triK
         val wSum = summary.wSum
         val bBar = summary.bBar
    -    val bStd = summary.bStd
    +    val bbBar = summary.bbBar
         val aBar = summary.aBar
    -    val aVar = summary.aVar
    +    val aStd = summary.aStd
         val abBar = summary.abBar
         val aaBar = summary.aaBar
    -    val aaValues = aaBar.values
    -
    -    if (bStd == 0) {
    -      if (fitIntercept) {
    -        logWarning(s"The standard deviation of the label is zero, so the 
coefficients will be " +
    -          s"zeros and the intercept will be the mean of the label; as a 
result, " +
    -          s"training is not needed.")
    -        val coefficients = new DenseVector(Array.ofDim(k-1))
    +    val aaBarValues = aaBar.values
    +    val numFeatures = abBar.size
    +    val rawBStd = summary.bStd
    +    // if b is constant (rawBStd is zero), then b cannot be scaled. In this case,
    +    // setting bStd = abs(bBar) ensures that b is not scaled anymore in the l-bfgs algorithm.
    +    val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd
    +
    +    if (rawBStd == 0) {
    +      if (fitIntercept || bBar == 0.0) {
    +        if (bBar == 0.0) {
    +          logWarning(s"Mean and standard deviation of the label are zero, 
so the coefficients " +
    +            s"and the intercept will all be zero; as a result, training is 
not needed.")
    +        } else {
    +          logWarning(s"The standard deviation of the label is zero, so the 
coefficients will be " +
    +            s"zeros and the intercept will be the mean of the label; as a 
result, " +
    +            s"training is not needed.")
    +        }
    +        val coefficients = new DenseVector(Array.ofDim(numFeatures))
             val intercept = bBar
             val diagInvAtWA = new DenseVector(Array(0D))
    -        return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
    +        return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D))
    +      } else {
    +        require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " +
    +          "zero. Model cannot be regularized with standardization=true")
    +        logWarning(s"The standard deviation of the label is zero. Consider 
setting " +
    +          s"fitIntercept=true.")
    +      }
    +    }
    +
    +    val aBarStd = new Array[Double](numFeatures)
    +    var j = 0
    +    while (j < numFeatures) {
    +      if (aStd(j) == 0.0) {
    +        aBarStd(j) = 0.0
           } else {
    -        require(!(regParam > 0.0 && standardizeLabel),
    -          "The standard deviation of the label is zero. " +
    -            "Model cannot be regularized with standardization=true")
    -        logWarning(s"The standard deviation of the label is zero. " +
    -          "Consider setting fitIntercept=true.")
    +        aBarStd(j) = aBar(j) / aStd(j)
    +      }
    +      j += 1
    +    }
    +
    +    val abBarStd = new Array[Double](numFeatures)
    +    j = 0
    +    while (j < numFeatures) {
    +      if (aStd(j) == 0.0) {
    +        abBarStd(j) = 0.0
    +      } else {
    +        abBarStd(j) = abBar(j) / (aStd(j) * bStd)
    +      }
    +      j += 1
    +    }
    +
    +    val aaBarStd = new Array[Double](triK)
    +    j = 0
    +    var kk = 0
    +    while (j < numFeatures) {
    +      val aStdJ = aStd(j)
    +      var i = 0
    +      while (i <= j) {
    +        val aStdI = aStd(i)
    +        if (aStdJ == 0.0 || aStdI == 0.0) {
    +          aaBarStd(kk) = 0.0
    +        } else {
    +          aaBarStd(kk) = aaBarValues(kk) / (aStdI * aStdJ)
    +        }
    +        kk += 1
    +        i += 1
           }
    +      j += 1
         }
     
    +    val bBarStd = bBar / bStd
    +    val bbBarStd = bbBar / (bStd * bStd)
    +
    +    val effectiveRegParam = regParam / bStd
    +    val effectiveL1RegParam = elasticNetParam * effectiveRegParam
    +    val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam
    +
         // add regularization to diagonals
         var i = 0
    -    var j = 2
    +    j = 2
         while (i < triK) {
    -      var lambda = regParam
    -      if (standardizeFeatures) {
    -        lambda *= aVar(j - 2)
    +      var lambda = effectiveL2RegParam
    +      if (!standardizeFeatures) {
    +        val std = aStd(j - 2)
    +        if (std != 0.0) {
    +          lambda /= (std * std)
    +        } else {
    +          lambda = 0.0
    +        }
           }
    -      if (standardizeLabel && bStd != 0) {
    -        lambda /= bStd
    +      if (!standardizeLabel) {
    +        lambda *= bStd
           }
    -      aaValues(i) += lambda
    +      aaBarStd(i) += lambda
           i += j
           j += 1
         }
    +    val aa = getAtA(aaBarStd, aBarStd)
    +    val ab = getAtB(abBarStd, bBarStd)
     
    -    val aa = if (fitIntercept) {
    -      Array.concat(aaBar.values, aBar.values, Array(1.0))
    -    } else {
    -      aaBar.values
    -    }
    -    val ab = if (fitIntercept) {
    -      Array.concat(abBar.values, Array(bBar))
    +    val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0) ||
    +      (solverType == WeightedLeastSquares.QuasiNewton)) {
    +      (solverType == WeightedLeastSquares.QuasiNewton)) {
    +      val effectiveL1RegFun: Option[(Int) => Double] = if (effectiveL1RegParam != 0.0) {
    +        Some((index: Int) => {
    +            if (fitIntercept && index == numFeatures) {
    +              0.0
    +            } else {
    +              if (standardizeFeatures) {
    +                effectiveL1RegParam
    +              } else {
    +                if (aStd(index) != 0.0) effectiveL1RegParam / aStd(index) else 0.0
    +              }
    +            }
    +          })
    +      } else {
    +        None
    +      }
    +      new QuasiNewtonSolver(fitIntercept, maxIter, tol, effectiveL1RegFun)
         } else {
    -      abBar.values
    +      new CholeskySolver(fitIntercept)
         }
     
    -    val x = CholeskyDecomposition.solve(aa, ab)
    +    val solution = solver match {
    +      case cholesky: CholeskySolver =>
    +        try {
    +          cholesky.solve(bBarStd, bbBarStd, ab, aa, new DenseVector(aBarStd))
    +        } catch {
    +          // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
    +          // quasi-newton solver
    +          case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
    +            logWarning("Cholesky solver failed due to singular covariance 
matrix. " +
    +              "Retrying with Quasi-Newton solver.")
    +            // ab and aa were modified in place, so reconstruct them
    +            val _aa = getAtA(aaBarStd, aBarStd)
    +            val _ab = getAtB(abBarStd, bBarStd)
    +            val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
    +            newSolver.solve(bBarStd, bbBarStd, _ab, _aa, new DenseVector(aBarStd))
    +        }
    +      case qn: QuasiNewtonSolver =>
    +        qn.solve(bBarStd, bbBarStd, ab, aa, new DenseVector(aBarStd))
    +    }
    +    val intercept = solution.intercept * bStd
    +    val coefficients = solution.coefficients
     
    -    val aaInv = CholeskyDecomposition.inverse(aa, k)
    +    // convert the coefficients from the scaled space to the original space
    +    var ii = 0
    --- End diff ---
    
    Nit: I'd like to use ```p``` or ```q``` rather than ```ii```.
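    
    For illustration, here is a minimal sketch of how that unscaling loop could read with ```q``` instead of ```ii```. The loop body is truncated in this diff, so the helper name and the ```bStd / aStd(q)``` rescaling rule below are my assumptions based on the surrounding standardization code, not the PR's actual implementation:
    
    ```scala
    // Hypothetical helper, for illustration only: converts coefficients solved
    // in the standardized space back to the original feature/label scale.
    def unscaleCoefficients(
        coefficients: Array[Double],
        aStd: Array[Double],
        bStd: Double): Unit = {
      var q = 0  // renamed from `ii`
      while (q < coefficients.length) {
        // a feature with zero std was never scaled, so its coefficient stays 0.0
        coefficients(q) *= (if (aStd(q) != 0.0) bStd / aStd(q) else 0.0)
        q += 1
      }
    }
    ```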

