Github user jrdi commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5967#discussion_r171690470
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---
    @@ -44,45 +53,151 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
     @AlphaComponent
     class LogisticRegression
       extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
    -  with LogisticRegressionParams {
    +  with LogisticRegressionParams with Logging {
     
    -  /** @group setParam */
    +  /**
    +   * Set the regularization parameter.
    +   * Default is 0.0.
    +   * @group setParam
    +   */
       def setRegParam(value: Double): this.type = set(regParam, value)
    +  setDefault(regParam -> 0.0)
     
    -  /** @group setParam */
    +  /**
    +   * Set the ElasticNet mixing parameter.
    +   * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    +   * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
    +   * Default is 0.0 which is an L2 penalty.
    +   * @group setParam
    +   */
    +  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
    +  setDefault(elasticNetParam -> 0.0)
    +
    +  /**
    +   * Set the maximal number of iterations.
    +   * Default is 100.
    +   * @group setParam
    +   */
       def setMaxIter(value: Int): this.type = set(maxIter, value)
    +  setDefault(maxIter -> 100)
    +
    +  /**
    +   * Set the convergence tolerance of iterations.
    +   * Smaller value will lead to higher accuracy with the cost of more iterations.
    +   * Default is 1E-6.
    +   * @group setParam
    +   */
    +  def setTol(value: Double): this.type = set(tol, value)
    +  setDefault(tol -> 1E-6)
     
       /** @group setParam */
       def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
    +  setDefault(fitIntercept -> true)
     
       /** @group setParam */
       def setThreshold(value: Double): this.type = set(threshold, value)
    +  setDefault(threshold -> 0.5)
     
       override protected def train(dataset: DataFrame): LogisticRegressionModel = {
         // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
    -    val oldDataset = extractLabeledPoints(dataset)
    +    val instances = extractLabeledPoints(dataset).map {
    +      case LabeledPoint(label: Double, features: Vector) => (label, features)
    +    }
         val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    -    if (handlePersistence) {
    -      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
    +    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
    +
    +    val (summarizer, labelSummarizer) = instances.treeAggregate(
    +      (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( {
    +        case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
    +        (label: Double, features: Vector)) =>
    +          (summarizer.add(features), labelSummarizer.add(label))
    +      }, {
    +        case ((summarizer1: MultivariateOnlineSummarizer, labelSummarizer1: MultiClassSummarizer),
    +        (summarizer2: MultivariateOnlineSummarizer, labelSummarizer2: MultiClassSummarizer)) =>
    +          (summarizer1.merge(summarizer2), labelSummarizer1.merge(labelSummarizer2))
    +      })
    +
    +    val histogram = labelSummarizer.histogram
    +    val numInvalid = labelSummarizer.countInvalid
    +    val numClasses = histogram.length
    +    val numFeatures = summarizer.mean.size
    +
    +    if (numInvalid != 0) {
    +      logError("Classification labels should be in {0 to " + (numClasses - 
1) + "}. " +
    +        "Found " + numInvalid + " invalid labels.")
    +      throw new SparkException("Input validation failed.")
    +    }
    +
    +    if (numClasses > 2) {
    +      logError("Currently, LogisticRegression with ElasticNet in ML 
package only supports " +
    +        "binary classification. Found " + numClasses + " in the input 
dataset.")
    +      throw new SparkException("Input validation failed.")
         }
     
    -    // Train model
    -    val lr = new LogisticRegressionWithLBFGS()
    -      .setIntercept($(fitIntercept))
    -    lr.optimizer
    -      .setRegParam($(regParam))
    -      .setNumIterations($(maxIter))
    -    val oldModel = lr.run(oldDataset)
    -    val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept)
    +    val featuresMean = summarizer.mean.toArray
    +    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
    +
    +    val regParamL1 = $(elasticNetParam) * $(regParam)
    +    val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
     
    -    if (handlePersistence) {
    -      oldDataset.unpersist()
    +    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
    +      featuresStd, featuresMean, regParamL2)
    +
    +    val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
    +      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
    +    } else {
    +      // Remove the L1 penalization on the intercept
    +      def regParamL1Fun = (index: Int) => {
    +        if (index == numFeatures) 0.0 else regParamL1
    +      }
    +      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
         }
    -    copyValues(lrm)
    +
    +    val initialWeightsWithIntercept =
    +      Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
    +
    +    // TODO: Compute the initial intercept based on the histogram.
    +    if ($(fitIntercept)) initialWeightsWithIntercept.toArray(numFeatures) = 1.0
    +
    +    val states = optimizer.iterations(new CachedDiffFunction(costFun),
    +      initialWeightsWithIntercept.toBreeze.toDenseVector)
    +
    +    var state = states.next()
    +    val lossHistory = mutable.ArrayBuilder.make[Double]
    +
    +    while (states.hasNext) {
    +      lossHistory += state.value
    +      state = states.next()
    +    }
    +    lossHistory += state.value
    +
    +    // The weights are trained in the scaled space; we're converting them back to
    +    // the original space.
    +    val weightsWithIntercept = {
    +      val rawWeights = state.x.toArray.clone()
    +      var i = 0
    +      // Note that the intercept in scaled space and original space is the same;
    +      // as a result, no scaling is needed.
    +      while (i < numFeatures) {
    +        rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
    --- End diff --
    
    @dbtsai why do we need to scale using `/ std` and not `* std`?
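    
    To make the question concrete, here is a tiny self-contained sketch (my own toy example with made-up numbers, not code from this PR) of how I read the back-transform. If the optimizer works on `x / std`, the learned weight seems to need a division by `std` (not a multiplication) for the margin on the raw features to stay the same:
    
    ```scala
    // Toy check (hypothetical values): weights learned on standardized features
    // x' = x / std, then mapped back to the raw feature space.
    object ScalingCheck extends App {
      val std = Array(2.0, 4.0)                               // per-feature standard deviations
      val x = Array(6.0, 8.0)                                 // a raw feature vector
      val xScaled = x.zip(std).map { case (v, s) => v / s }   // what the optimizer sees
    
      val wScaled = Array(0.5, 0.25)                          // weights learned in the scaled space
    
      def margin(w: Array[Double], v: Array[Double]): Double =
        w.zip(v).map { case (a, b) => a * b }.sum
    
      val wDiv = wScaled.zip(std).map { case (w, s) => w / s }  // back-transform with "/ std"
      val wMul = wScaled.zip(std).map { case (w, s) => w * s }  // back-transform with "* std"
    
      println(s"scaled-space margin    = ${margin(wScaled, xScaled)}")  // 2.0
      println(s"raw margin with / std  = ${margin(wDiv, x)}")           // 2.0 (matches)
      println(s"raw margin with * std  = ${margin(wMul, x)}")           // 14.0 (does not)
    }
    ```
    
    So my reading is that dividing by `std` is what keeps the margins, and hence the predictions, identical between the two spaces, but I'd like to confirm that's the intent here.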

