Github user imatiach-msft commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21632#discussion_r209840163
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -700,6 +722,82 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
         assert(unprunedTree.numNodes === 5)
         assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) 
=== arr.size)
       }
    +
    +  test("weights at arbitrary scale") {
    +    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
    +    val rddWithUnitWeights = 
sc.parallelize(arr.map(_.asML.toInstance(1.0)))
    +    val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
    +      Instance(inst.label, 0.001, inst.features)
    +    }
    +    val rddWithBigWeights = rddWithUnitWeights.map { inst =>
    +      Instance(inst.label, 1000, inst.features)
    +    }
    +    val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
    +    val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 
3, "all", 42L, None)
    +
    +    val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 
3, "all", 42L, None)
    +    unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, 
smallWeightTree) =>
    +      TreeTests.checkEqual(unitTree, smallWeightTree)
    +    }
    +
    +    val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, 
"all", 42L, None)
    +    unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, 
bigWeightTree) =>
    +      TreeTests.checkEqual(unitTree, bigWeightTree)
    +    }
    +  }
    +
    +  test("minWeightFraction and minInstancesPerNode") {
    +    val data = Array(
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(1.0, 0.1, Vectors.dense(1.0))
    +    )
    +    val rdd = sc.parallelize(data)
    +    val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
    +      minWeightFractionPerNode = 0.5)
    +    val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
    +    assert(tree1.depth == 0)
    +
    +    strategy.minWeightFractionPerNode = 0.0
    +    val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
    +    assert(tree2.depth == 1)
    +
    +    strategy.minInstancesPerNode = 2
    +    val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
    +    assert(tree3.depth == 0)
    +
    +    strategy.minInstancesPerNode = 1
    +    val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
    +    assert(tree4.depth == 1)
    +  }
    +
    +  test("extremely unbalanced weighting with bagging") {
    +    /*
    +    This test verifies that sample weights are taken into account during 
the
    +    bagging process, instead of applied afterwards. If sample weights were 
applied
    +    after the sampling is done, then some of the trees would not contain 
the heavily
    +    weighted example. Here, we verify that all trees predict the correct 
value.
    +     */
    +    val data = Array(
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(0.0, 1.0, Vectors.dense(0.0)),
    +      Instance(1.0, 1e6, Vectors.dense(1.0))
    +    )
    +    val rdd = sc.parallelize(data)
    +    val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
    +    val trees = RandomForest.run(rdd, strategy, 10, "all", 42L, None)
    +    val features = Vectors.dense(1.0)
    +    trees.foreach { tree =>
    +      val predict = tree.rootNode.predictImpl(features).prediction
    +      // TODO: need to investigate why this went to 0 for some trees
    --- End diff --
    
    need to investigate this test


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to