GitHub user manishamde commented on a diff in the pull request:
https://github.com/apache/spark/pull/3374#discussion_r20621257
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala ---
@@ -45,146 +43,92 @@ import org.apache.spark.storage.StorageLevel
* but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
* Running with those losses will likely behave reasonably, but lacks the same guarantees.
*
- * @param boostingStrategy Parameters for the gradient boosting algorithm
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
*/
@Experimental
-class GradientBoosting (
- private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
-
- boostingStrategy.weakLearnerParams.algo = Regression
- boostingStrategy.weakLearnerParams.impurity = impurity.Variance
-
- // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- boostingStrategy.weakLearnerParams.numClassesForClassification =
- boostingStrategy.numClassesForClassification
-
- boostingStrategy.assertValid()
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+ extends Serializable with Logging {
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a gradient boosted trees model that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
+ def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+ val algo = boostingStrategy.treeStrategy.algo
algo match {
- case Regression => GradientBoosting.boost(input, boostingStrategy)
+ case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoosting.boost(remappedInput, boostingStrategy)
+ GradientBoostedTrees.boost(remappedInput, boostingStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
}
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+ */
+ def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+ run(input.rdd)
+ }
}
-object GradientBoosting extends Logging {
+object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model.
*
- * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
- * is recommended to clearly specify regression.
- * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
- * is recommended to clearly specify regression.
- *
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a gradient boosted trees model that can be used for prediction
--- End diff --
Very minor nit: capitalize the first word and end with a period, i.e. "A gradient boosted trees model that can be used for prediction."
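To make the control flow of `run` in the diff concrete without a Spark cluster, here is a minimal plain-Scala sketch. It is not Spark's implementation and rests on hypothetical stand-ins: `Point` replaces `LabeledPoint`, a `Seq` replaces the `RDD`, `boost` merely returns the labels it would train on, and `Unsupported` is an invented `Algo` value that exists only to exercise the error branch. Like the original comment, it assumes binary labels in {0, 1}.

```scala
// Plain-Scala sketch of the dispatch in GradientBoostedTrees.run; no Spark dependency.
object RunDispatchSketch {
  // Hypothetical stand-in for Spark's Algo enumeration.
  sealed trait Algo
  case object Regression extends Algo
  case object Classification extends Algo
  // Hypothetical value, not part of Spark; it only exercises the error branch.
  case object Unsupported extends Algo

  // Hypothetical stand-in for LabeledPoint.
  final case class Point(label: Double, features: Vector[Double])

  // Hypothetical stand-in for boost(...): returns the labels it would train on.
  def boost(input: Seq[Point]): Seq[Double] = input.map(_.label)

  def run(algo: Algo, input: Seq[Point]): Seq[Double] = algo match {
    case Regression => boost(input)
    case Classification =>
      // Map labels {0, 1} to {-1, +1} so binary classification can be treated as regression.
      val remappedInput = input.map(p => Point(p.label * 2 - 1, p.features))
      boost(remappedInput)
    case other =>
      throw new IllegalArgumentException(s"$other is not supported by the gradient boosting.")
  }
}
```

Under `Classification`, labels `0.0` and `1.0` become `-1.0` and `1.0`; under `Regression` they pass through unchanged, and any other algorithm raises `IllegalArgumentException`.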