Github user imatiach-msft commented on a diff in the pull request:
https://github.com/apache/spark/pull/21632#discussion_r220413827
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -700,6 +722,82 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
assert(unprunedTree.numNodes === 5)
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode))
=== arr.size)
}
+
+ test("weights at arbitrary scale") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
+ val rddWithUnitWeights =
sc.parallelize(arr.map(_.asML.toInstance(1.0)))
+ val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 0.001, inst.features)
+ }
+ val rddWithBigWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 1000, inst.features)
+ }
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
+ val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy,
3, "all", 42L, None)
+
+ val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy,
3, "all", 42L, None)
+ unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree,
smallWeightTree) =>
+ TreeTests.checkEqual(unitTree, smallWeightTree)
+ }
+
+ val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3,
"all", 42L, None)
+ unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree,
bigWeightTree) =>
+ TreeTests.checkEqual(unitTree, bigWeightTree)
+ }
+ }
+
+ test("minWeightFraction and minInstancesPerNode") {
+ val data = Array(
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(1.0, 0.1, Vectors.dense(1.0))
+ )
+ val rdd = sc.parallelize(data)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
+ minWeightFractionPerNode = 0.5)
+ val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree1.depth == 0)
+
+ strategy.minWeightFractionPerNode = 0.0
+ val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree2.depth == 1)
+
+ strategy.minInstancesPerNode = 2
+ val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree3.depth == 0)
+
+ strategy.minInstancesPerNode = 1
+ val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree4.depth == 1)
+ }
+
+ test("extremely unbalanced weighting with bagging") {
+ /*
+ This test verifies that sample weights are taken into account during
the
+ bagging process, instead of applied afterwards. If sample weights were
applied
+ after the sampling is done, then some of the trees would not contain
the heavily
+ weighted example. Here, we verify that all trees predict the correct
value.
+ */
+ val data = Array(
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(1.0, 1e6, Vectors.dense(1.0))
+ )
+ val rdd = sc.parallelize(data)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
+ val trees = RandomForest.run(rdd, strategy, 10, "all", 42L, None)
+ val features = Vectors.dense(1.0)
+ trees.foreach { tree =>
+ val predict = tree.rootNode.predictImpl(features).prediction
+ // TODO: need to investigate why this went to 0 for some trees
--- End diff --
investigated — this comment was added by mistake: in the original PR it was added
and then removed, but it was re-added during the big merge to the latest code.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]