Github user imatiach-msft commented on a diff in the pull request: https://github.com/apache/spark/pull/21632#discussion_r209840163 --- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala --- @@ -700,6 +722,82 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(unprunedTree.numNodes === 5) assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) } + + test("weights at arbitrary scale") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10) + val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance(1.0))) + val rddWithSmallWeights = rddWithUnitWeights.map { inst => + Instance(inst.label, 0.001, inst.features) + } + val rddWithBigWeights = rddWithUnitWeights.map { inst => + Instance(inst.label, 1000, inst.features) + } + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2) + val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None) + + val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None) + unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) => + TreeTests.checkEqual(unitTree, smallWeightTree) + } + + val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None) + unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) => + TreeTests.checkEqual(unitTree, bigWeightTree) + } + } + + test("minWeightFraction and minInstancesPerNode") { + val data = Array( + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(1.0, 0.1, Vectors.dense(1.0)) + ) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, + minWeightFractionPerNode = 0.5) + val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree1.depth == 0) + + 
strategy.minWeightFractionPerNode = 0.0 + val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree2.depth == 1) + + strategy.minInstancesPerNode = 2 + val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree3.depth == 0) + + strategy.minInstancesPerNode = 1 + val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree4.depth == 1) + } + + test("extremely unbalanced weighting with bagging") { + /* + This test verifies that sample weights are taken into account during the + bagging process, instead of applied afterwards. If sample weights were applied + after the sampling is done, then some of the trees would not contain the heavily + weighted example. Here, we verify that all trees predict the correct value. + */ + val data = Array( + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(1.0, 1e6, Vectors.dense(1.0)) + ) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2) + val trees = RandomForest.run(rdd, strategy, 10, "all", 42L, None) + val features = Vectors.dense(1.0) + trees.foreach { tree => + val predict = tree.rootNode.predictImpl(features).prediction + // TODO: need to investigate why this went to 0 for some trees --- End diff -- need to investigate this test
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org