Github user asolimando commented on a diff in the pull request: https://github.com/apache/spark/pull/20632#discussion_r170734558 --- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala --- @@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + test("[SPARK-3159] tree model redundancy - binary classification") { + val numClasses = 2 + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy) + + /* Expected tree structure tested below: + root + left1 right1 + left2 right2 + + pred(left1) = 0 + pred(left2) = 1 + pred(right2) = 0 + */ + assert(dt.rootNode.numDescendants === 4) + assert(dt.rootNode.subtreeDepth === 2) + + assert(dt.rootNode.isInstanceOf[InternalNode]) + + // left 1 prediction test + assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 0) + + val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild + assert(right1.isInstanceOf[InternalNode]) + + // left 2 prediction test + assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1) + // right 2 prediction test + assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0) + } + + test("[SPARK-3159] tree model redundancy - multiclass classification") { + val numClasses = 4 + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy) + + /* Expected tree structure tested below: + root + left1 right1 + left2 right2 left3 right3 + + pred(left2) = 0 + pred(right2) = 1 + pred(left3) = 2 + pred(right3) = 1 + */ + assert(dt.rootNode.numDescendants === 6) + 
assert(dt.rootNode.subtreeDepth === 2) + + assert(dt.rootNode.isInstanceOf[InternalNode]) + + val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild + val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild + + assert(left1.isInstanceOf[InternalNode]) + + // left 2 prediction test + assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0) + // right 2 prediction test + assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1) + + assert(right1.isInstanceOf[InternalNode]) + + // left 3 prediction test + assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2) + // right 3 prediction test + assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1) + } + + test("[SPARK-3159] tree model redundancy - regression") { + val numClasses = 2 + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, + maxDepth = 3, maxBins = 10, numClasses = numClasses) + + val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy) + + /* Expected tree structure tested below: + root + 1 2 + 1_1 1_2 2_1 2_2 + 1_1_1 1_1_2 1_2_1 1_2_2 2_1_1 2_1_2 + + pred(1_1_1) = 0.5 + pred(1_1_2) = 0.0 + pred(1_2_1) = 0.0 + pred(1_2_2) = 0.25 + pred(2_1_1) = 1.0 + pred(2_1_2) = 0.6666666666666666 + pred(2_2) = 0.5 + */ + + assert(dt.rootNode.numDescendants === 12) --- End diff -- I agree on the whole line, and here are the modifications I have made to align to the comments: - I have simplified the tests as much as possible (minimal input data, resulting into minimal trees) - I have removed the data generator methods (now the input data are explicitly given) - Conditions to be tested are simple (just the number of nodes as you suggested), and they test the pruned vs unpruned version - I have added the extra debugging parameter _prune_ (to _RandomForest.run(...)_ and _Node.toNode()_) - Now that we have _prune_, I have reverted back the adaptation of "Use soft prediction for binary classification with ordered 
categorical features", because now it is sufficient to add `prune = false` to the invocation of _RandomForest.run(...)_
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org