Github user asolimando commented on a diff in the pull request:
https://github.com/apache/spark/pull/20632#discussion_r170734558
--- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
+
+ test("[SPARK-3159] tree model redundancy - binary classification") {
+ val numClasses = 2
+
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+ numClasses = numClasses, maxBins = 32)
+
+ val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy)
+
+ /* Expected tree structure tested below:
+            root
+      left1      right1
+              left2    right2
+
+ pred(left1) = 0
+ pred(left2) = 1
+ pred(right2) = 0
+ */
+ assert(dt.rootNode.numDescendants === 4)
+ assert(dt.rootNode.subtreeDepth === 2)
+
+ assert(dt.rootNode.isInstanceOf[InternalNode])
+
+ // left 1 prediction test
+ assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 0)
+
+ val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+ assert(right1.isInstanceOf[InternalNode])
+
+ // left 2 prediction test
+ assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
+ // right 2 prediction test
+ assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
+ }
+
+ test("[SPARK-3159] tree model redundancy - multiclass classification") {
+ val numClasses = 4
+
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+ numClasses = numClasses, maxBins = 32)
+
+ val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy)
+
+ /* Expected tree structure tested below:
+                 root
+         left1          right1
+     left2   right2  left3   right3
+
+ pred(left2) = 0
+ pred(right2) = 1
+ pred(left3) = 2
+ pred(right3) = 1
+ */
+ assert(dt.rootNode.numDescendants === 6)
+ assert(dt.rootNode.subtreeDepth === 2)
+
+ assert(dt.rootNode.isInstanceOf[InternalNode])
+
+ val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
+ val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
+
+ assert(left1.isInstanceOf[InternalNode])
+
+ // left 2 prediction test
+ assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
+ // right 2 prediction test
+ assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+
+ assert(right1.isInstanceOf[InternalNode])
+
+ // left 3 prediction test
+ assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
+ // right 3 prediction test
+ assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
+ }
+
+ test("[SPARK-3159] tree model redundancy - regression") {
+ val numClasses = 2
+
+ val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance,
+ maxDepth = 3, maxBins = 10, numClasses = numClasses)
+
+ val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy)
+
+ /* Expected tree structure tested below:
+                       root
+              1                      2
+        1_1       1_2          2_1       2_2
+    1_1_1 1_1_2 1_2_1 1_2_2  2_1_1 2_1_2
+
+ pred(1_1_1) = 0.5
+ pred(1_1_2) = 0.0
+ pred(1_2_1) = 0.0
+ pred(1_2_2) = 0.25
+ pred(2_1_1) = 1.0
+ pred(2_1_2) = 0.6666666666666666
+ pred(2_2) = 0.5
+ */
+
+ assert(dt.rootNode.numDescendants === 12)
--- End diff ---
I agree on all points, and here are the modifications I have made to align with the comments:
- I have simplified the tests as much as possible (minimal input data, resulting in minimal trees)
- I have removed the data generator methods (the input data is now given explicitly)
- The conditions being tested are simple (just the number of nodes, as you suggested), and they compare the pruned and unpruned versions (see the sketch after this list)
- I have added the extra debugging parameter _prune_ (to _RandomForest.run(...)_ and _Node.toNode()_)
- Now that we have _prune_, I have reverted the adaptation of "Use soft prediction for binary classification with ordered categorical features", because it is now sufficient to add `prune = false` to the invocation of _RandomForest.run(...)_
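
For illustration, here is a minimal sketch of the shape these tests now take, assuming the `prune` flag introduced in this PR; the input data, the argument values, and the exact _RandomForest.run(...)_ signature below are placeholders rather than the actual test code:

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.Gini

// Placeholder fixture: the actual tests use their own explicit input data.
// `sc` is the SparkContext provided by MLlibTestSparkContext.
val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(1.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(2.0)),
  LabeledPoint(0.0, Vectors.dense(3.0))))

val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini,
  maxDepth = 2, numClasses = 2, maxBins = 32)

// Train the same tree twice, with and without the debugging flag added in
// this PR. All argument names except `prune` are assumed from the existing
// RandomForest.run(...) signature and may differ from the final code.
val Array(pruned) = RandomForest.run(data, strategy, numTrees = 1,
  featureSubsetStrategy = "all", seed = 42, instr = None, prune = true)
val Array(unpruned) = RandomForest.run(data, strategy, numTrees = 1,
  featureSubsetStrategy = "all", seed = 42, instr = None, prune = false)

// Only the node count is checked: pruning merges sibling leaves that share a
// prediction, so it can never produce more nodes than the unpruned tree.
assert(pruned.rootNode.numDescendants <= unpruned.rootNode.numDescendants)
```

Since pruning is now the default behaviour, only the unpruned run needs the extra argument.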
---