Repository: spark Updated Branches: refs/heads/master 28ab5bf59 -> 9b9827759
[SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor ## What changes were proposed in this pull request? Provided featureSubsetStrategy to GBTClassifier and GBTRegressor: a) Moved featureSubsetStrategy to TreeEnsembleParams. b) Changed GBTClassifier and GBTRegressor to pass featureSubsetStrategy down to the per-tree training calls, e.g. `val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)`. ## How was this patch tested? a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy to GBTClassifier. b) Added test cases in GBTClassifierSuite and GBTRegressorSuite. Author: Pralabh Kumar <pralabhku...@gmail.com> Closes #18118 from pralabhkumar/develop. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b982775 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b982775 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b982775 Branch: refs/heads/master Commit: 9b9827759af2ca3eea146a6032f9165f640ce152 Parents: 28ab5bf Author: Pralabh Kumar <pralabhku...@gmail.com> Authored: Fri Nov 10 13:17:25 2017 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Fri Nov 10 13:17:25 2017 +0200 ---------------------------------------------------------------------- .../GradientBoostedTreeClassifierExample.scala | 1 + .../spark/ml/classification/GBTClassifier.scala | 9 ++- .../classification/RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 8 +- .../spark/ml/regression/GBTRegressor.scala | 9 ++- .../ml/regression/RandomForestRegressor.scala | 2 +- .../ml/tree/impl/DecisionTreeMetadata.scala | 4 +- .../ml/tree/impl/GradientBoostedTrees.scala | 25 +++--- .../org/apache/spark/ml/tree/treeParams.scala | 82 ++++++++++---------- .../spark/mllib/tree/GradientBoostedTrees.scala | 4 +- .../apache/spark/mllib/tree/RandomForest.scala | 2 +- .../ml/classification/GBTClassifierSuite.scala | 29 +++++++ .../spark/ml/regression/GBTRegressorSuite.scala | 29 +++++++ 
.../tree/impl/GradientBoostedTreesSuite.scala | 4 +- 14 files changed, 146 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 9a39acf..3656773 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample { .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures") .setMaxIter(10) + .setFeatureSubsetStrategy("auto") // Convert indexed labels back to original labels. 
val labelConverter = new IndexToString() http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3da809c..f11bc1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = set(stepSize, value) + /** @group setParam */ + @Since("2.3.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + // Parameters from GBTClassifierParams: /** @group setParam */ @@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") ( val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed)) + $(seed), $(featureSubsetStrategy)) val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab4c235..78a4972 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = - RandomForestParams.supportedFeatureSubsetStrategies + TreeEnsembleParams.supportedFeatureSubsetStrategies @Since("2.0.0") override def load(path: String): RandomForestClassifier = super.load(path) http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 01c5cc1..0291a57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S } /** (private[ml]) Train a decision tree on an RDD */ - private[ml] def train(data: RDD[LabeledPoint], - oldStrategy: OldStrategy): DecisionTreeRegressionModel = { + private[ml] def train( + data: RDD[LabeledPoint], + oldStrategy: OldStrategy, + featureSubsetStrategy: String): DecisionTreeRegressionModel = { val instr = Instrumentation.create(this, data) instr.logParams(params: _*) - val trees = RandomForest.run(data, oldStrategy, numTrees = 1, 
featureSubsetStrategy = "all", + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 08d175c..f41d15b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.3.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed)) + $(seed), $(featureSubsetStrategy)) val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, 
numFeatures) instr.logSuccess(m) m http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a58da50..200b234 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = - RandomForestParams.supportedFeatureSubsetStrategies + TreeEnsembleParams.supportedFeatureSubsetStrategies @Since("2.0.0") override def load(path: String): RandomForestRegressor = super.load(path) http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 8a9dcb4..53189e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -22,7 +22,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.tree.RandomForestParams +import org.apache.spark.ml.tree.TreeEnsembleParams import org.apache.spark.mllib.tree.configuration.Algo._ import 
org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.Strategy @@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging { Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { case Some(value) => math.ceil(value * numFeatures).toInt case _ => throw new IllegalArgumentException(s"Supported values:" + - s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") } } http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index e32447a..bd8c9af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging { def run( input: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, + seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. 
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, - seed) + seed, featureSubsetStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") } @@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging { input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, + validate = true, seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. 
val remappedInput = input.map( @@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging { val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true, seed) + validate = true, seed, featureSubsetStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging { validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, validate: Boolean, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging { val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. 
val treeStrategy = boostingStrategy.treeStrategy.copy val validationTol = boostingStrategy.validationTol @@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") val firstTree = new DecisionTreeRegressor().setSeed(seed) - val firstTreeModel = firstTree.train(input, treeStrategy) + val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight @@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") + val dt = new DecisionTreeRegressor().setSeed(seed + m) - val model = dt.train(data, treeStrategy) + val model = dt.train(data, treeStrategy, featureSubsetStrategy) timer.stop(s"building tree $m") // Update partial model baseLearners(m) = model http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 47079d9..81b6222 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams } } +private[spark] object TreeEnsembleParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) +} + /** * Parameters for Decision Tree-based ensemble algorithms. 
* @@ -359,38 +365,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { oldImpurity: OldImpurity): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) } -} - -/** - * Parameters for Random Forest algorithms. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * - * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) - * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms - * are a bit different. - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) - - setDefault(numTrees -> 20) - - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - - /** @group getParam */ - final def getNumTrees: Int = $(numTrees) /** * The number of features to consider for splits at each tree node. @@ -420,10 +394,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." 
+ - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains( + TreeEnsembleParams.supportedFeatureSubsetStrategies.contains( value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,7 +405,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") /** - * @deprecated This method is deprecated and will be removed in 3.0.0. + * @deprecated This method is deprecated and will be removed in 3.0.0 * @group setParam */ @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") @@ -441,10 +415,38 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } -private[spark] object RandomForestParams { - // These options should be lowercase. - final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) + + +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * + * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) + * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms + * are a bit different. 
+ * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) + + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) } private[ml] trait RandomForestClassifierParams @@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(featureSubsetStrategy -> "all") + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( categoricalFeatures: Map[Int, Int], http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index df2c1b0..d24d8da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] ( val algo = boostingStrategy.treeStrategy.algo val (trees, treeWeights) = NewGBT.run(input.map { point => NewLabeledPoint(point.label, point.features.asML) - }, boostingStrategy, seed.toLong) + }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } @@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] ( NewLabeledPoint(point.label, point.features.asML) }, validationInput.map { point => 
NewLabeledPoint(point.label, point.features.asML) - }, boostingStrategy, seed.toLong) + }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index d1331a5..a8c5286 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams} +import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 8000143..978f89c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(gbt.getPredictionCol === 
"prediction") assert(gbt.getRawPredictionCol === "rawPrediction") assert(gbt.getProbabilityCol === "probability") + assert(gbt.getFeatureSubsetStrategy === "all") val df = trainData.toDF() val model = gbt.fit(df) model.transform(df) @@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.getPredictionCol === "prediction") assert(model.getRawPredictionCol === "rawPrediction") assert(model.getProbabilityCol === "probability") + assert(model.getFeatureSubsetStrategy === "all") assert(model.hasParent) MLTestingUtils.checkCopyAndUids(gbt, model) @@ -357,6 +359,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext } ///////////////////////////////////////////////////////////////////////////// + // Tests of feature subset strategy + ///////////////////////////////////////////////////////////////////////////// + test("Tests of feature subset strategy") { + val numClasses = 2 + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(3) + .setMaxIter(5) + .setFeatureSubsetStrategy("all") + + // In this data, feature 1 is very important. 
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + + // GBT with different featureSubsetStrategy + val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") + val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances + val mostIF = importanceFeatures.argmax + assert(mostImportantFeature !== mostIF) + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 2da25f7..ecbb571 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -166,6 +166,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext } ///////////////////////////////////////////////////////////////////////////// + // Tests of feature subset strategy + ///////////////////////////////////////////////////////////////////////////// + test("Tests of feature subset strategy") { + val numClasses = 2 + val gbt = new GBTRegressor() + .setMaxDepth(3) + .setMaxIter(5) + .setSeed(123) + .setFeatureSubsetStrategy("all") + + // In this data, feature 1 is very important. 
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + + // GBT with different featureSubsetStrategy + val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") + val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances + val mostIF = importanceFeatures.argmax + assert(mostImportantFeature !== mostIF) + } + + + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/9b982775/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index 4109a29..366d5ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val boostingStrategy = new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) val (validateTrees, validateTreeWeights) = GradientBoostedTrees - .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L) + .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all") val numTrees = validateTrees.length assert(numTrees !== numIterations) // Test that it performs better on the validation dataset. 
- val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L) + val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all") val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) { val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org