Github user manishamde commented on a diff in the pull request:

    https://github.com/apache/spark/pull/3374#discussion_r20621257
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala ---
    @@ -45,146 +43,92 @@ import org.apache.spark.storage.StorageLevel
      *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
      *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
      *
    - * @param boostingStrategy Parameters for the gradient boosting algorithm
    + * @param boostingStrategy Parameters for the gradient boosting algorithm.
      */
     @Experimental
    -class GradientBoosting (
    -    private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
    -
    -  boostingStrategy.weakLearnerParams.algo = Regression
    -  boostingStrategy.weakLearnerParams.impurity = impurity.Variance
    -
    -  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
    -  boostingStrategy.weakLearnerParams.numClassesForClassification =
    -    boostingStrategy.numClassesForClassification
    -
    -  boostingStrategy.assertValid()
    +class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
    +  extends Serializable with Logging {
     
       /**
        * Method to train a gradient boosting model
        * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    -   * @return WeightedEnsembleModel that can be used for prediction
    +   * @return a gradient boosted trees model that can be used for prediction
        */
    -  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
    -    val algo = boostingStrategy.algo
    +  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
    +    val algo = boostingStrategy.treeStrategy.algo
         algo match {
    -      case Regression => GradientBoosting.boost(input, boostingStrategy)
    +      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
           case Classification =>
             // Map labels to -1, +1 so binary classification can be treated as regression.
             val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
    -        GradientBoosting.boost(remappedInput, boostingStrategy)
    +        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
           case _ =>
             throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
         }
       }
     
    +  /**
    +   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
    +   */
    +  def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
    +    run(input.rdd)
    +  }
     }
     
     
    -object GradientBoosting extends Logging {
    +object GradientBoostedTrees extends Logging {
     
       /**
        * Method to train a gradient boosting model.
        *
    -   * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
    -   *       is recommended to clearly specify regression.
    -   *       Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
    -   *       is recommended to clearly specify regression.
    -   *
        * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
        *              For classification, labels should take values {0, 1, ..., numClasses-1}.
        *              For regression, labels are real numbers.
        * @param boostingStrategy Configuration options for the boosting algorithm.
    -   * @return WeightedEnsembleModel that can be used for prediction
    +   * @return a gradient boosted trees model that can be used for prediction
    --- End diff --
    
    Very minor nit: capitalize the first word and end with a period, i.e. "A gradient boosted trees model that can be used for prediction."
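
    For context on the Classification branch quoted in the diff, the label remapping can be sketched without Spark. This is only an illustration: `Point` is a hypothetical stand-in for `LabeledPoint`, and `LabelRemap` is a name made up for this sketch, not part of the pull request.

```scala
// Hypothetical stand-in for org.apache.spark.mllib.regression.LabeledPoint,
// so that the sketch runs without a Spark dependency.
final case class Point(label: Double, features: Vector[Double])

object LabelRemap {
  // Map a binary label in {0, 1} to {-1, +1}: 0 becomes -1 and 1 becomes +1.
  // This mirrors the (x.label * 2) - 1 expression in the quoted diff.
  def remap(p: Point): Point = p.copy(label = p.label * 2 - 1)

  def main(args: Array[String]): Unit = {
    val input = Seq(
      Point(0.0, Vector(1.0, 2.0)),
      Point(1.0, Vector(3.0, 4.0))
    )
    // Features are left untouched; only the label changes.
    input.map(remap).foreach(println)
  }
}
```

    The mapping only makes sense for labels that are exactly 0 or 1, which is why the quoted code applies it solely to binary classification.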

