Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r171899289
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -631,10 +634,70 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
         val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
         assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
       }
    +
    +  
///////////////////////////////////////////////////////////////////////////////
    +  // Tests for pruning of redundant subtrees (generated by a split 
improving the
    +  // impurity measure, but always leading to the same prediction).
    +  
///////////////////////////////////////////////////////////////////////////////
    +
    +  test("SPARK-3159 tree model redundancy - binary classification") {
    +    // The following dataset is set up such that splitting over feature 1 
for points having
     +    // feature 0 = 0 improves the impurity measure, even though the 
prediction will always be 0
     +    // in both branches.
    +    val arr = Array(
    +      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
    +      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
    +      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
    +      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
    +    )
    +    val rdd = sc.parallelize(arr)
    +
    +    val numClasses = 2
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
    +      seed = 42, instr = None).head
    +
    +    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
    +      seed = 42, instr = None, prune = false).head
    +
    +    assert(prunedTree.numNodes === 5)
    +    assert(unprunedTree.numNodes === 7)
    +  }
    +
    +  test("SPARK-3159 tree model redundancy - regression") {
    +    // The following dataset is set up such that splitting over feature 0 
for points having
     +    // feature 1 = 1 improves the impurity measure, even though the 
prediction will always be 0.5
     +    // in both branches.
    +    val arr = Array(
    +      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
    +      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
    +      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
    +      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
    +      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
    +      LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
    +    )
    +    val rdd = sc.parallelize(arr)
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance, maxDepth = 4,
    +      numClasses = 0, maxBins = 32)
    +
    +    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
    +      seed = 42, instr = None).head
    +
    +    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
    +      seed = 42, instr = None, prune = false).head
    +
    --- End diff --
    
    Would you mind adding a check in both tests to make sure that the counts of 
all the leaf nodes sum to the total number of training points (i.e. 6 for the 
classification test and 7 for the regression test)? That way we make sure we 
don't lose information when merging the leaves. You can do it via 
`leafNode.impurityStats.count`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to