Github user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19666#discussion_r149241680

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -976,6 +930,44 @@ private[spark] object RandomForest extends Logging {
           categories
         }

    +  private[tree] def traverseUnorderedSplits[T](
    +      arity: Int,
    +      zeroStats: T,
    +      seqOp: (T, Int) => T,
    +      finalizer: (BitSet, T) => Unit): Unit = {
    +    assert(arity > 1)
    +
    +    // numSplits = (1 << (arity - 1)) - 1
    +    val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
    +    val subSet: BitSet = new BitSet(arity)
    +
    +    // DFS traversal over category subsets
    +    // binIndex: [0, arity)
    +    def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
    +      if (binIndex == arity) {
    +        // recursion exits when binIndex == arity
    +        if (combNumber > 0) {
    +          // we have found a valid unordered split, saved in subSet
    +          finalizer(subSet, stats)
    +        }
    +      } else {
    +        subSet.set(binIndex)
    +        val leftChildCombNumber = combNumber + (1 << binIndex)
    +        // pruning: combNumber only needs to satisfy 1 <= combNumber <= numSplits
    --- End diff --

    If I understand correctly, the check `if (leftChildCombNumber <= numSplits)` helps us ensure that we consider each split only once, right?
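    To sanity-check the question above: since `numSplits = (1 << (arity - 1)) - 1`, the pruning `leftChildCombNumber <= numSplits` rejects any subset mask whose highest bit (the last category) is set, so exactly one of each complementary subset pair is emitted — i.e. each unordered split is visited once. A minimal Python re-implementation of the diff's DFS (the function name and the returned list of frozensets are illustrative, not part of the PR) makes this easy to verify:

    ```python
    def traverse_unordered_splits(arity):
        """Enumerate each unordered split of a categorical feature exactly once.

        Sketch of the DFS in the diff: a split is a nonempty subset of the
        categories {0, ..., arity-1}; a subset and its complement describe the
        same split, so only masks in [1, num_splits] are emitted -- these are
        precisely the masks that exclude the last category.
        """
        num_splits = (1 << (arity - 1)) - 1
        results = []

        def dfs(bin_index, comb_number, subset):
            if bin_index == arity:
                # recursion exit: emit the subset unless it is empty
                if comb_number > 0:
                    results.append(frozenset(subset))
                return
            left = comb_number + (1 << bin_index)
            # pruning: masks above num_splits would duplicate a complement
            if left <= num_splits:
                dfs(bin_index + 1, left, subset | {bin_index})
            dfs(bin_index + 1, comb_number, subset)

        dfs(0, 0, set())
        return results
    ```

    For `arity = 3` this yields the three splits `{0, 1}`, `{0}`, `{1}` (never their complements), matching `num_splits = 3`.
    
    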