Github user asolimando commented on a diff in the pull request:
https://github.com/apache/spark/pull/20632#discussion_r171045897
--- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,10 +634,99 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Tests for pruning of redundant subtrees (generated by a split improving the
+ // impurity measure, but always leading to the same prediction).
+ ///////////////////////////////////////////////////////////////////////////////
+
+ test("[SPARK-3159] tree model redundancy - binary classification") {
+ // The following dataset is set up such that splitting over feature 1 for points having
+ // feature 0 = 0 improves the impurity measure, despite the prediction will always be 0
+ // in both branches.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+ )
+ val rdd = sc.parallelize(arr)
+
+ val numClasses = 2
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+ numClasses = numClasses, maxBins = 32)
+
+ val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None).head
+
+ val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None, prune = false).head
+
+ assert(prunedTree.numNodes === 5)
+ assert(unprunedTree.numNodes === 7)
+ }
+
+ test("[SPARK-3159] tree model redundancy - multiclass classification") {
--- End diff --
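To double-check the claim in the test comment, the impurity arithmetic for the three points with feature 0 = 0 can be reproduced in a small standalone Scala sketch. `GiniSplitCheck` and its helpers are hypothetical (they are not Spark's `Gini` impurity classes), and breaking prediction ties toward the smaller label is an assumption made here to match the "always 0" claim.

```scala
// Hypothetical standalone check, not Spark code: recomputes Gini impurity by hand
// for the subset of test points with feature 0 = 0.
object GiniSplitCheck {
  // Gini impurity: 1 - sum_k p_k^2, where p_k is the fraction of labels equal to class k.
  def gini(labels: Seq[Double]): Double = {
    val n = labels.size.toDouble
    1.0 - labels.groupBy(identity).values.map(g => math.pow(g.size / n, 2)).sum
  }

  // Majority-class prediction. ASSUMPTION: ties are broken toward the smallest label.
  def predict(labels: Seq[Double]): Double = {
    val counts = labels.groupBy(identity).map { case (l, g) => (l, g.size) }
    val best = counts.values.max
    counts.collect { case (l, c) if c == best => l }.reduce((a, b) => math.min(a, b))
  }

  // The three points with feature 0 = 0, as (label, feature 1) pairs.
  val subset: Seq[(Double, Double)] = Seq((0.0, 1.0), (1.0, 1.0), (0.0, 0.0))

  // Split on feature 1: value 1.0 goes right, value 0.0 goes left.
  val (right, left) = subset.partition(_._2 == 1.0)

  val parentImpurity: Double = gini(subset.map(_._1))
  // Children impurities weighted by the fraction of points in each branch.
  val childImpurity: Double =
    (left.size * gini(left.map(_._1)) + right.size * gini(right.map(_._1))) / subset.size
  val gain: Double = parentImpurity - childImpurity
}
```

Under these assumptions the parent impurity is 4/9 and the weighted child impurity is 1/3, so the split on feature 1 has a positive gain of 1/9 even though both branches predict 0. That is exactly the kind of redundant split the pruning is meant to remove.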
Honestly, with the original (and more involved!) version of the "fix", it was less evident that the two cases would behave identically (because it explicitly computed the predictions for a whole subtree), but I understand that with the current version a single test can be sufficient.
If so, which of the two tests (binary or multiclass classification) would you keep?
The binary one is simpler and easier to follow, including by looking at the data points.
However, I am afraid that dropping the multiclass test might raise doubts about test coverage.
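One way to see why a single test may suffice: if pruning only compares the predictions of sibling leaves, the number of classes never enters the decision. The sketch below illustrates that idea on a hypothetical minimal tree type; `Tree`, `Leaf`, `Split` and `Pruning.prune` are illustrative names, not Spark's internal node classes, and the code in this PR may differ in detail.

```scala
// Hypothetical minimal tree ADT, not Spark's LearningNode / Node classes.
sealed trait Tree { def numNodes: Int }
final case class Leaf(prediction: Double) extends Tree { val numNodes: Int = 1 }
final case class Split(feature: Int, left: Tree, right: Tree) extends Tree {
  val numNodes: Int = 1 + left.numNodes + right.numNodes
}

object Pruning {
  // Bottom-up: prune both children first, then collapse a split whose two
  // (possibly just-collapsed) children are leaves with the same prediction.
  // Only predicted labels are compared, so binary and multiclass trees are
  // handled by exactly the same rule.
  def prune(t: Tree): Tree = t match {
    case Split(f, l, r) =>
      (prune(l), prune(r)) match {
        case (Leaf(a), Leaf(b)) if a == b => Leaf(a)
        case (pl, pr)                     => Split(f, pl, pr)
      }
    case leaf: Leaf => leaf
  }
}
```

Applied to the tree shape implied by the binary test, `Split(0, Split(1, Leaf(0.0), Leaf(0.0)), Split(1, Leaf(0.0), Leaf(1.0)))`, this turns a 7-node tree into a 5-node one; the same function collapses, for example, two sibling leaves that both predict class 2.0 in a multiclass tree.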