Github user asolimando commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r170734558
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
         val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
         assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
       }
    +
    +  test("[SPARK-3159] tree model redundancy - binary classification") {
    +    val numClasses = 2
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +             root
    +      left1       right1
    +                left2  right2
    +
    +      pred(left1)  = 0
    +      pred(left2)  = 1
    +      pred(right2) = 0
    +     */
    +    assert(dt.rootNode.numDescendants === 4)
    +    assert(dt.rootNode.subtreeDepth === 2)
    +
    +    assert(dt.rootNode.isInstanceOf[InternalNode])
    +
    +    // left 1 prediction test
    +    assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 
0)
    +
    +    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
    +    assert(right1.isInstanceOf[InternalNode])
    +
    +    // left 2 prediction test
    +    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
    +    // right 2 prediction test
    +    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - multiclass classification") {
    +    val numClasses = 4
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +                    root
    +            left1         right1
    +        left2  right2  left3  right3
    +
    +      pred(left2)  = 0
    +      pred(right2)  = 1
    +      pred(left3) = 2
    +      pred(right3) = 1
    +     */
    +    assert(dt.rootNode.numDescendants === 6)
    +    assert(dt.rootNode.subtreeDepth === 2)
    +
    +    assert(dt.rootNode.isInstanceOf[InternalNode])
    +
    +    val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
    +    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
    +
    +    assert(left1.isInstanceOf[InternalNode])
    +
    +    // left 2 prediction test
    +    assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
    +    // right 2 prediction test
    +    assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
    +
    +    assert(right1.isInstanceOf[InternalNode])
    +
    +    // left 3 prediction test
    +    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
    +    // right 3 prediction test
    +    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - regression") {
    +    val numClasses = 2
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance,
    +      maxDepth = 3, maxBins = 10, numClasses = numClasses)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +                                root
    +                    1                             2
    +           1_1             1_2             2_1         2_2
    +      1_1_1   1_1_2   1_2_1   1_2_2   2_1_1   2_1_2
    +
    +      pred(1_1_1)  = 0.5
    +      pred(1_1_2)  = 0.0
    +      pred(1_2_1)  = 0.0
    +      pred(1_2_2)  = 0.25
    +      pred(2_1_1)  = 1.0
    +      pred(2_1_2)  = 0.6666666666666666
    +      pred(2_2)    = 0.5
    +     */
    +
    +    assert(dt.rootNode.numDescendants === 12)
    --- End diff --
    
    I agree on all points; here are the modifications I have made to 
address the comments: 
    - I have simplified the tests as much as possible (minimal input data, 
resulting in minimal trees)
    - I have removed the data generator methods (now the input data are 
explicitly given)
    - The conditions tested are simple (just the number of nodes, as you 
suggested), and they compare the pruned and unpruned versions of the tree
    - I have added the extra debugging parameter _prune_ (to 
_RandomForest.run(...)_ and _Node.toNode()_)
    - Now that we have _prune_, I have reverted the adaptation of "Use 
soft prediction for binary classification with ordered categorical features", 
because it is now sufficient to add `prune = false` to the invocation of 
_RandomForest.run(...)_



---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to