Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/3833#discussion_r23821635
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala ---
    @@ -81,14 +138,101 @@ class LogisticGradient extends Gradient {
           label: Double,
           weights: Vector,
           cumGradient: Vector): Double = {
    -    val margin = -1.0 * dot(data, weights)
    -    val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
    -    axpy(gradientMultiplier, data, cumGradient)
    -    if (label > 0) {
    -      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
    -      MLUtils.log1pExp(margin)
    -    } else {
    -      MLUtils.log1pExp(margin) - margin
    +    assert((weights.size % data.size) == 0)
    +    val dataSize = data.size
    +
    +    // (n + 1) is the number of classes
    +    val n = weights.size / dataSize
    +    n match {
    +      case 1 =>
    +        /**
    +         * For Binary Logistic Regression.
    +         *
    +         * Although the loss and gradient calculation for the multinomial case is more general,
    +         * and the multinomial formulation also covers the binary case, we still implement
    +         * a specialized binary version for performance reasons.
    +         */
    +        val margin = -1.0 * dot(data, weights)
    +        val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
    +        axpy(multiplier, data, cumGradient)
    +        if (label > 0) {
    +          // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
    +          MLUtils.log1pExp(margin)
    +        } else {
    +          MLUtils.log1pExp(margin) - margin
    +        }
    +      case _ =>
    +        /**
    +         * For Multinomial Logistic Regression.
    +         */
    +        val weightsArray = weights match {
    +          case dv: DenseVector => dv.values
    +          case _ =>
    +            throw new IllegalArgumentException(
    +              s"weights only supports dense vector but got type 
${weights.getClass}.")
    +        }
    +        val cumGradientArray = cumGradient match {
    +          case dv: DenseVector => dv.values
    +          case _ =>
    +            throw new IllegalArgumentException(
    +              s"cumGradient only supports dense vector but got type 
${cumGradient.getClass}.")
    +        }
    +
    +        // marginY is margins(label - 1) in the formula.
    +        var marginY = 0.0
    +        var maxMargin = Double.NegativeInfinity
    +        var maxMarginIndex = 0
    +
    +        val margins = (0 until n).map { i =>
    --- End diff --
    
    `Array.tabulate` would be a better fit here: it builds a primitive `Array[Double]` directly, whereas `(0 until n).map` returns an `IndexedSeq[Double]` with boxed elements.
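    
    A minimal sketch of the suggestion, for illustration only: the loop body is truncated in this diff, so `computeMargin` below is a hypothetical stand-in for the per-class margin computation.
    
    ```scala
    object TabulateSketch {
      // Hypothetical stand-in for the truncated per-class margin computation;
      // only the collection-building pattern matters here.
      def computeMargin(i: Int): Double = i.toDouble
    
      val n = 3 // (n + 1) classes, matching the diff's convention
    
      // Form in the diff: Range#map yields an IndexedSeq[Double] with boxed elements.
      val marginsSeq: IndexedSeq[Double] = (0 until n).map(computeMargin)
    
      // Suggested form: Array.tabulate builds a primitive Array[Double] directly,
      // avoiding boxing and an intermediate collection.
      val marginsArr: Array[Double] = Array.tabulate(n)(computeMargin)
    }
    ```
    
    Since `Array.tabulate(n)(f)` takes the same `Int => Double` function, the change stays local to the collection type.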

