Github user asolimando commented on a diff in the pull request:
https://github.com/apache/spark/pull/20632#discussion_r168925240
--- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -631,6 +654,32 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
+
+ test("[SPARK-3159] tree model redundancy - binary classification") {
+ redundacyTest(2, 20)
+ }
+
+ test("[SPARK-3159] tree model redundancy - multiclass classification") {
+ redundacyTest(4, 20)
+ }
+
+ private def redundacyTest(numClasses: Int, numPoints: Int) = {
+ val strategy = new OldStrategy(algo = Classification, impurity = Gini,
maxDepth = 4,
+ numClasses = numClasses, maxBins = 32)
+
+ val rnd = new Random(100)
+
+ val rdd = RandomForestSuite.generateRedundantPoints(sc,
+ numClasses = numClasses, numPoints = numPoints, rnd)
+
+ val dt = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
+ seed = 42, instr = None).head
+
+ val isRedundant = RandomForestSuite.isRedundant(dt)
--- End diff --
You are right; I have removed isRedundant and rnd.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]