Github user pralabhkumar commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18118#discussion_r149886357

    --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala ---
    @@ -166,6 +166,40 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
       }

       /////////////////////////////////////////////////////////////////////////////
    +  // Tests of feature subset strategy
    +  /////////////////////////////////////////////////////////////////////////////
    +  test("Tests of feature subset strategy") {
    +    val numClasses = 2
    +    val gbt = new GBTRegressor()
    +      .setMaxDepth(3)
    +      .setMaxIter(5)
    +      .setSubsamplingRate(1.0)
    +      .setStepSize(0.5)
    +      .setSeed(123)
    +      .setFeatureSubsetStrategy("all")
    +
    +    // In this data, feature 1 is very important.
    +    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
    +    val categoricalFeatures = Map.empty[Int, Int]
    +    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
    +
    +    val importances = gbt.fit(df).featureImportances
    +    val mostImportantFeature = importances.argmax
    +    assert(mostImportantFeature === 1)
    +    assert(importances.toArray.sum === 1.0)
    +    assert(importances.toArray.forall(_ >= 0.0))
    +
    +    // GBT with different featureSubsetStrategy
    +    val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
    +    val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
    +    val mostIF = importanceFeatures.argmax
    +    assert(!(mostImportantFeature === mostIF))
    +    assert(importanceFeatures.toArray.sum === 1.0)
    +    assert(importanceFeatures.toArray.forall(_ >= 0.0))
    +    assert(!(importanceFeatures.toDense.values.deep === importances.toDense.values.deep))
    --- End diff --

    done
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org