Github user smurching commented on a diff in the pull request:
https://github.com/apache/spark/pull/19666#discussion_r149241680
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -976,6 +930,44 @@ private[spark] object RandomForest extends Logging {
categories
}
+ private[tree] def traverseUnorderedSplits[T](
+ arity: Int,
+ zeroStats: T,
+ seqOp: (T, Int) => T,
+ finalizer: (BitSet, T) => Unit): Unit = {
+ assert(arity > 1)
+
+ // numSplits = (1 << (arity - 1)) - 1
+ val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
+ val subSet: BitSet = new BitSet(arity)
+
+ // DFS traversal over bins
+ // binIndex ranges over [0, arity)
+ def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
+ if (binIndex == arity) {
+ // recursion exit when binIndex == arity
+ if (combNumber > 0) {
+ // we have a valid unordered split, recorded in subSet.
+ finalizer(subSet, stats)
+ }
+ } else {
+ subSet.set(binIndex)
+ val leftChildCombNumber = combNumber + (1 << binIndex)
+ // pruning: we only need combNumber satisfying 1 <= combNumber <= numSplits
--- End diff --
If I understand correctly, the check `if (leftChildCombNumber <=
numSplits)` helps us ensure that we consider each split only once, right?
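To make the question concrete, here is a minimal standalone sketch of the enumeration (not the PR's actual code, which also threads `seqOp`/`finalizer` and a `BitSet`): each unordered split is a subset/complement pair of the `arity` bins, and restricting the subset's bitmask `combNumber` to `[1, numSplits]` with `numSplits = 2^(arity-1) - 1` selects exactly one representative per pair, namely the one that excludes the highest bin. The names `enumerate` and `numUnorderedSplits` below are illustrative, not Spark APIs.

```scala
object UnorderedSplitsDemo {
  // Number of unordered splits for a categorical feature with `arity` values:
  // 2^(arity-1) - 1 subset/complement pairs (the empty set is excluded).
  def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1

  // DFS over bins: at each binIndex we either include the bin (adding
  // 1 << binIndex to combNumber) or skip it. Branches whose combNumber
  // would exceed numSplits are pruned, so the highest bin is never set
  // and each split is visited exactly once.
  def enumerate(arity: Int): Seq[Int] = {
    require(arity > 1)
    val numSplits = numUnorderedSplits(arity)
    val result = scala.collection.mutable.ArrayBuffer.empty[Int]
    def dfs(binIndex: Int, combNumber: Int): Unit = {
      if (binIndex == arity) {
        if (combNumber > 0) result += combNumber // a complete, non-empty subset
      } else {
        val withBin = combNumber + (1 << binIndex)
        if (withBin <= numSplits) dfs(binIndex + 1, withBin) // include this bin
        dfs(binIndex + 1, combNumber)                        // exclude this bin
      }
    }
    dfs(0, 0)
    result.toSeq
  }

  def main(args: Array[String]): Unit = {
    // arity = 4 -> 2^3 - 1 = 7 splits; masks 1..7, each appearing exactly once.
    val splits = enumerate(4)
    println(splits.sorted.mkString(","))
  }
}
```

Running with `arity = 4` yields the masks 1 through 7 with no duplicates, which is the "each split only once" property the `leftChildCombNumber <= numSplits` check provides: any mask containing the highest bin would exceed `numSplits`, and that mask's split is already covered by its complement.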
---