Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r170412046

--- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("[SPARK-3159] tree model redundancy - binary classification") {
+    val numClasses = 2
+
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+      numClasses = numClasses, maxBins = 32)
+
+    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy)
+
+    /* Expected tree structure tested below:
+                root
+          left1      right1
+                 left2    right2
+
+       pred(left1) = 0
+       pred(left2) = 1
+       pred(right2) = 0
+     */
+    assert(dt.rootNode.numDescendants === 4)
+    assert(dt.rootNode.subtreeDepth === 2)
+
+    assert(dt.rootNode.isInstanceOf[InternalNode])
+
+    // left 1 prediction test
+    assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 0)
+
+    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+    assert(right1.isInstanceOf[InternalNode])
+
+    // left 2 prediction test
+    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
+    // right 2 prediction test
+    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
+  }
+
+  test("[SPARK-3159] tree model redundancy - multiclass classification") {
+    val numClasses = 4
+
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+      numClasses = numClasses, maxBins = 32)
+
+    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy)
+
+    /* Expected tree structure tested below:
+                    root
+           left1           right1
+       left2   right2   left3   right3
+
+       pred(left2) = 0
+       pred(right2) = 1
+       pred(left3) = 2
+       pred(right3) = 1
+     */
+    assert(dt.rootNode.numDescendants === 6)
+    assert(dt.rootNode.subtreeDepth === 2)
+
+    assert(dt.rootNode.isInstanceOf[InternalNode])
+
+    val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
+    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+
+    assert(left1.isInstanceOf[InternalNode])
+
+    // left 2 prediction test
+    assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
+    // right 2 prediction test
+    assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+
+    assert(right1.isInstanceOf[InternalNode])
+
+    // left 3 prediction test
+    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
+    // right 3 prediction test
+    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+  }
+
+  test("[SPARK-3159] tree model redundancy - regression") {
+    val numClasses = 2
+
+    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance,
+      maxDepth = 3, maxBins = 10, numClasses = numClasses)
+
+    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy)
+
+    /* Expected tree structure tested below:
+                              root
+                   1                       2
+             1_1        1_2          2_1       2_2
+         1_1_1 1_1_2 1_2_1 1_2_2  2_1_1 2_1_2
+
+       pred(1_1_1) = 0.5
+       pred(1_1_2) = 0.0
+       pred(1_2_1) = 0.0
+       pred(1_2_2) = 0.25
+       pred(2_1_1) = 1.0
+       pred(2_1_2) = 0.6666666666666666
+       pred(2_2) = 0.5
+     */
+
+    assert(dt.rootNode.numDescendants === 12)
--- End diff --

Ok, trying to understand these tests. From what I can tell, you've written a data generator that generates random points and, somewhat by chance, generates redundant tree nodes if the tree is not pruned.
You're relying on the random seed to give you a tree which should have exactly 12 descendants after pruning. I think these tests may be overly complicated. IMO, we just need to test the situation that causes this by creating a simple dataset that can improve the impurity by splitting, but which does not change the prediction. For example, for the Gini impurity you might have the following:

```scala
val data = Array(
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0))
)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini,
  maxDepth = 4, numClasses = 2, maxBins = 32)
val tree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
  seed = 42, instr = None).head
assert(tree.rootNode.numDescendants === 0)
```

Splitting on the feature improves the Gini impurity, but both children still predict 0, so the tree should collapse to a single leaf. Before this patch, you'd get a tree with two descendants.

Actually, I think it might be nice to introduce a `prune` parameter just for testing, then you can do the following:

```scala
val data = Array(
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0))
)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini,
  maxDepth = 4, numClasses = 2, maxBins = 32)
val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
  seed = 42, prune = true, instr = None).head
val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
  seed = 42, prune = false, instr = None).head
assert(prunedTree.rootNode.numDescendants === 0)
assert(unprunedTree.rootNode.numDescendants === 2)
```

That's a sanity check that the patch has actually made the difference. WDYT?
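For context, here is a minimal standalone sketch of the kind of bottom-up pruning under discussion: a split whose pruned children are both leaves with the same prediction is redundant and can be collapsed into a single leaf. The `Node`/`Leaf`/`Internal` types below are illustrative stand-ins, not Spark's `ml.tree` classes, and this is not the PR's actual implementation:

```scala
// Illustrative stand-ins, not Spark's ml.tree classes.
sealed trait Node { def prediction: Double }
case class Leaf(prediction: Double) extends Node
case class Internal(prediction: Double, left: Node, right: Node) extends Node

// Bottom-up pruning: collapse any split whose pruned children are both
// leaves with identical predictions, since the split cannot change the output.
def prune(node: Node): Node = node match {
  case Internal(pred, left, right) =>
    (prune(left), prune(right)) match {
      case (Leaf(lp), Leaf(rp)) if lp == rp => Leaf(pred)
      case (pl, pr)                         => Internal(pred, pl, pr)
    }
  case leaf => leaf
}

// The right subtree splits but both of its children predict 0.0, so pruning
// merges them: 4 descendants before, 2 after.
val tree   = Internal(0.0, Leaf(1.0), Internal(0.0, Leaf(0.0), Leaf(0.0)))
val pruned = prune(tree)  // Internal(0.0, Leaf(1.0), Leaf(0.0))
```

A real implementation would typically compare the children's full impurity statistics rather than just the predicted values, but the shape of the recursion is the same.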