Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r170412046
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
         val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
         assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
       }
    +
    +  test("[SPARK-3159] tree model redundancy - binary classification") {
    +    val numClasses = 2
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +             root
    +      left1       right1
    +                left2  right2
    +
    +      pred(left1)  = 0
    +      pred(left2)  = 1
    +      pred(right2) = 0
    +     */
    +    assert(dt.rootNode.numDescendants === 4)
    +    assert(dt.rootNode.subtreeDepth === 2)
    +
    +    assert(dt.rootNode.isInstanceOf[InternalNode])
    +
    +    // left 1 prediction test
    +    assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 
0)
    +
    +    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
    +    assert(right1.isInstanceOf[InternalNode])
    +
    +    // left 2 prediction test
    +    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
    +    // right 2 prediction test
    +    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - multiclass classification") {
    +    val numClasses = 4
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +                    root
    +            left1         right1
    +        left2  right2  left3  right3
    +
    +      pred(left2)  = 0
    +      pred(right2)  = 1
    +      pred(left3) = 2
    +      pred(right3) = 1
    +     */
    +    assert(dt.rootNode.numDescendants === 6)
    +    assert(dt.rootNode.subtreeDepth === 2)
    +
    +    assert(dt.rootNode.isInstanceOf[InternalNode])
    +
    +    val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
    +    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
    +
    +    assert(left1.isInstanceOf[InternalNode])
    +
    +    // left 2 prediction test
    +    assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
    +    // right 2 prediction test
    +    assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
    +
    +    assert(right1.isInstanceOf[InternalNode])
    +
    +    // left 3 prediction test
    +    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
    +    // right 3 prediction test
    +    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - regression") {
    +    val numClasses = 2
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance,
    +      maxDepth = 3, maxBins = 10, numClasses = numClasses)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +                                root
    +                    1                             2
    +           1_1             1_2             2_1         2_2
    +      1_1_1   1_1_2   1_2_1   1_2_2   2_1_1   2_1_2
    +
    +      pred(1_1_1)  = 0.5
    +      pred(1_1_2)  = 0.0
    +      pred(1_2_1)  = 0.0
    +      pred(1_2_2)  = 0.25
    +      pred(2_1_1)  = 1.0
    +      pred(2_1_2)  = 0.6666666666666666
    +      pred(2_2)    = 0.5
    +     */
    +
    +    assert(dt.rootNode.numDescendants === 12)
    --- End diff --
    
    Ok, trying to understand these tests. From what I can tell, you've written 
a data generator that generates random points and, somewhat by chance, 
produces redundant tree nodes if the tree is not pruned. You're relying on the 
random seed to give you a tree which should have exactly 12 descendants after 
pruning.
    
    I think these may be overly complicated. IMO, we just need to test the 
situation that causes this by creating a simple dataset that can improve the 
impurity by splitting, but which does not change the prediction. For example, 
for the Gini impurity you might have the following:
    
    ```scala
        val data = Array(
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(0.0, Vectors.dense(0.0)),
          LabeledPoint(0.0, Vectors.dense(0.0)),
          LabeledPoint(1.0, Vectors.dense(0.0))
        )
        val rdd = sc.parallelize(data)
        val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance, maxDepth = 4,
          numClasses = 0, maxBins = 32)
        val tree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
          seed = 42, instr = None).head
        assert(tree.rootNode.numDescendants === 0)
    ```
    Before this patch, you'd get a tree with two descendants. Actually, I think 
it might be nice to introduce a `prune` parameter just for testing, then you 
can do the following:
    
    ```scala
        val data = Array(
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(0.0, Vectors.dense(0.0)),
          LabeledPoint(0.0, Vectors.dense(0.0)),
          LabeledPoint(1.0, Vectors.dense(0.0))
        )
        val rdd = sc.parallelize(data)
        val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance, maxDepth = 4,
          numClasses = 0, maxBins = 32)
        val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
          seed = 42, prune = true, instr = None).head
        val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
          seed = 42, prune = false, instr = None).head
        assert(prunedTree.rootNode.numDescendants === 0)
        assert(unprunedTree.rootNode.numDescendants === 2)
    ```
    That's a sanity check that the patch has actually made the difference. WDYT?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to