Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/20632#discussion_r170410747
--- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -402,20 +405,40 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       LabeledPoint(1.0, Vectors.dense(2.0)))
     val input = sc.parallelize(arr)
+    val seed = 42
+    val numTrees = 1
+
     // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
     val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-    val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
-    model.rootNode match {
-      case n: InternalNode => n.split match {
-        case s: CategoricalSplit =>
-          assert(s.leftCategories === Array(1.0))
-        case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit")
-      }
-      case _ => throw new AssertionError("model.rootNode was not an InternalNode")
-    }
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, numTrees = numTrees,
+      featureSubsetStrategy = "all")
+    val splits = RandomForest.findSplits(input, metadata, seed = seed)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
+      strategy.subsamplingRate, numTrees, false, seed = seed)
+
+    val topNode = LearningNode.emptyNode(nodeIndex = 1)
+    assert(topNode.isLeaf === false)
+    assert(topNode.stats === null)
+
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
+    val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
+    val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
+
+    assert(topNode.split.isDefined, "rootNode does not have a split")
--- End diff --
I'm a fan of just calling `foreach` like:
```scala
topNode.split.foreach { split =>
assert(split.isInstanceOf[CategoricalSplit])
assert(split.toOld.categories === Array(1.0))
}
```
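For reference, a minimal sketch of how that could sit alongside the existing `isDefined` assertion (reusing the `topNode`, `CategoricalSplit`, and `leftCategories` names from the diff above; a sketch only, not the final wording of the test):
```scala
// Sketch: keep the explicit existence check so a missing split fails loudly,
// then assert on the split itself inside foreach.
assert(topNode.split.isDefined, "rootNode does not have a split")
topNode.split.foreach { split =>
  assert(split.isInstanceOf[CategoricalSplit])
  // Same category check the previous version of the test made on the split.
  assert(split.asInstanceOf[CategoricalSplit].leftCategories === Array(1.0))
}
```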
---