This is an automated email from the ASF dual-hosted git repository.
WeichenXu123 pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new f67a8551e0b7 [SPARK-34591][ML] Add decision tree pruning as a parameter
f67a8551e0b7 is described below
commit f67a8551e0b707dcf9adeaed9aba2ad14f0a0915
Author: WeichenXu <[email protected]>
AuthorDate: Wed May 13 11:25:52 2026 +0800
[SPARK-34591][ML] Add decision tree pruning as a parameter
This PR adds a parameter to enable/disable a feature where LearningNodes
are merged after an RF model is trained.
This PR takes over https://github.com/apache/spark/pull/32813
2 Reasons:
1. In addition to basic classification, another use case for decision trees
is the probabilities associated with predictions.
Once a tree is pruned, these probabilities are lost, which makes the
trees/predictions challenging to work with, if not unusable.
2. Always pruning trees is not in line with the default behavior in sklearn.
In sklearn, the trees are left unpruned by default.
Please see the Jira ticket (SPARK-34591) for more explanation.
**New params:**
Adds a parameter `pruneTree` that is exposed to the tree-based classifiers.
Will add tests here to ensure parameter is exposed correctly.
Unit tests.
Closes #55763 from WeichenXu123/SPARK-34591.
Lead-authored-by: WeichenXu <[email protected]>
Co-authored-by: bribiescas-carlos <[email protected]>
Co-authored-by: Carlos Bribiescas <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
(cherry picked from commit 436291ec35b30b634bcde9ef35ddeb402dc70c06)
Signed-off-by: Weichen Xu <[email protected]>
---
.../ml/classification/DecisionTreeClassifier.scala | 8 +-
.../ml/classification/RandomForestClassifier.scala | 7 +-
.../apache/spark/ml/tree/impl/RandomForest.scala | 657 ++++++++++++---------
.../org/apache/spark/ml/tree/treeParams.scala | 19 +-
.../spark/mllib/tree/configuration/Strategy.scala | 5 +-
.../spark/ml/tree/impl/RandomForestSuite.scala | 411 +++++++++----
project/MimaExcludes.scala | 4 +-
python/pyspark/ml/classification.py | 28 +-
python/pyspark/ml/tree.py | 18 +
9 files changed, 759 insertions(+), 398 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 887d8277d311..d5564f6a3fbd 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -74,6 +74,10 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ /** @group setParam */
+ @Since("4.3.0")
+ def setPruneTree(value: Boolean): this.type = set(pruneTree, value)
+
/** @group expertSetParam */
@Since("1.4.0")
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
@@ -134,9 +138,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
val strategy = getOldStrategy(categoricalFeatures, numClasses)
require(!strategy.bootstrap, "DecisionTreeClassifier does not need
bootstrap sampling")
+ strategy.pruneTree = $(pruneTree)
+
instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol,
rawPredictionCol,
- probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode,
minInfoGain,
+ probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode,
minInfoGain, pruneTree,
maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed,
thresholds)
val trees = RandomForest.run(instances, strategy, numTrees = 1,
featureSubsetStrategy = "all",
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index fb61358536d0..2c22ca5b4230 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -76,6 +76,10 @@ class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ /** @group setParam */
+ @Since("4.3.0")
+ def setPruneTree(value: Boolean): this.type = set(pruneTree, value)
+
/** @group expertSetParam */
@Since("1.4.0")
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
@@ -159,10 +163,11 @@ class RandomForestClassifier @Since("1.4.0") (
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses,
OldAlgo.Classification, getOldImpurity)
strategy.bootstrap = $(bootstrap)
+ strategy.pruneTree = $(pruneTree)
instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol,
probabilityCol,
rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy,
maxDepth, maxBins,
- maxMemoryInMB, minInfoGain, minInstancesPerNode,
minWeightFractionPerNode, seed,
+ maxMemoryInMB, minInfoGain, pruneTree, minInstancesPerNode,
minWeightFractionPerNode, seed,
subsamplingRate, thresholds, cacheNodeIds, checkpointInterval, bootstrap)
val trees = RandomForest
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index cabbc497571b..b3ca7f04c3de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -41,7 +41,6 @@ import org.apache.spark.util.SizeEstimator
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
-
/**
* ALGORITHM
*
@@ -97,8 +96,9 @@ private[spark] object RandomForest extends Logging with
Serializable {
numTrees: Int,
featureSubsetStrategy: String,
seed: Long): Array[DecisionTreeModel] = {
- val instances = input.map { case LabeledPoint(label, features) =>
- Instance(label, 1.0, features.asML)
+ val instances = input.map {
+ case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
}
run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
}
@@ -124,7 +124,6 @@ private[spark] object RandomForest extends Logging with
Serializable {
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
- prune: Boolean = true, // exposed for testing only, real trees are
always pruned
parentUID: Option[String] = None,
earlyStopModelSizeThresholdInBytes: Long = 0): Array[DecisionTreeModel]
= {
lastEarlyStoppedModelSize = 0
@@ -151,7 +150,8 @@ private[spark] object RandomForest extends Logging with
Serializable {
// depth of the decision tree
val maxDepth = strategy.maxDepth
- require(maxDepth <= 30,
+ require(
+ maxDepth <= 30,
s"DecisionTree currently only supports maxDepth <= 30, but was given
maxDepth = $maxDepth.")
// Max memory usage for aggregates
@@ -203,9 +203,10 @@ private[spark] object RandomForest extends Logging with
Serializable {
// Collect some nodes to split, and choose features for each node (if
subsampling).
// Each group of nodes may come from one or multiple trees, and at
multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
- RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+ RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata,
rng)
// Sanity check (should never occur):
- assert(nodesForGroup.nonEmpty,
+ assert(
+ nodesForGroup.nonEmpty,
s"RandomForest selected empty nodesForGroup. Error for unknown
reason.")
// Only send trees to worker if they contain nodes being split this
iteration.
@@ -214,8 +215,16 @@ private[spark] object RandomForest extends Logging with
Serializable {
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
- val bestSplit = RandomForest.findBestSplits(baggedInput, metadata,
topNodesForGroup,
- nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer,
nodeIds,
+ val bestSplit = RandomForest.findBestSplits(
+ baggedInput,
+ metadata,
+ topNodesForGroup,
+ nodesForGroup,
+ treeToNodeToIndexInfo,
+ bcSplits,
+ nodeStack,
+ timer,
+ nodeIds,
outputBestSplits = strategy.useNodeIdCache)
if (strategy.useNodeIdCache) {
nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit)
@@ -225,7 +234,7 @@ private[spark] object RandomForest extends Logging with
Serializable {
timer.stop("findBestSplits")
if (earlyStopModelSizeThresholdInBytes > 0) {
- val nodes = topNodes.map(_.toNode(prune))
+ val nodes = topNodes.map(_.toNode(strategy.pruneTree))
val estimatedSize = SizeEstimator.estimate(nodes)
if (estimatedSize > earlyStopModelSizeThresholdInBytes){
earlyStop = true
@@ -258,23 +267,28 @@ private[spark] object RandomForest extends Logging with
Serializable {
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(uid, rootNode.toNode(prune),
numFeatures,
+ new DecisionTreeClassificationModel(
+ uid,
+ rootNode.toNode(strategy.pruneTree),
+ numFeatures,
strategy.getNumClasses())
}
} else {
topNodes.map { rootNode =>
- new DecisionTreeRegressionModel(uid, rootNode.toNode(prune),
numFeatures)
+ new DecisionTreeRegressionModel(uid,
rootNode.toNode(strategy.pruneTree), numFeatures)
}
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(rootNode.toNode(prune),
numFeatures,
+ new DecisionTreeClassificationModel(
+ rootNode.toNode(strategy.pruneTree),
+ numFeatures,
strategy.getNumClasses())
}
} else {
topNodes.map(rootNode =>
- new DecisionTreeRegressionModel(rootNode.toNode(prune),
numFeatures))
+ new
DecisionTreeRegressionModel(rootNode.toNode(strategy.pruneTree), numFeatures))
}
}
}
@@ -293,7 +307,6 @@ private[spark] object RandomForest extends Logging with
Serializable {
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
- prune: Boolean = true, // exposed for testing only, real trees are
always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val earlyStopModelSizeThresholdInBytes =
TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
val timer = new TimeTracker()
@@ -311,9 +324,12 @@ private[spark] object RandomForest extends Logging with
Serializable {
val splits = findSplits(retaggedInput, metadata, seed)
timer.stop("findSplits")
logDebug("numBins: feature: number of bins")
- logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
- s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
- }.mkString("\n"))
+ logDebug(
+ Range(0, metadata.numFeatures)
+ .map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }
+ .mkString("\n"))
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
@@ -321,14 +337,26 @@ private[spark] object RandomForest extends Logging with
Serializable {
val bcSplits = input.sparkContext.broadcast(splits)
val baggedInput = BaggedPoint
- .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,
strategy.bootstrap,
- (tp: TreePoint) => tp.weight, seed = seed)
+ .convertToBaggedRDD(
+ treeInput,
+ strategy.subsamplingRate,
+ numTrees,
+ strategy.bootstrap,
+ (tp: TreePoint) => tp.weight,
+ seed = seed)
.persist(StorageLevel.MEMORY_AND_DISK)
.setName("bagged tree points")
- val trees = runBagged(baggedInput = baggedInput, metadata = metadata,
bcSplits = bcSplits,
- strategy = strategy, numTrees = numTrees, featureSubsetStrategy =
featureSubsetStrategy,
- seed = seed, instr = instr, prune = prune, parentUID = parentUID,
+ val trees = runBagged(
+ baggedInput = baggedInput,
+ metadata = metadata,
+ bcSplits = bcSplits,
+ strategy = strategy,
+ numTrees = numTrees,
+ featureSubsetStrategy = featureSubsetStrategy,
+ seed = seed,
+ instr = instr,
+ parentUID = parentUID,
earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)
baggedInput.unpersist()
@@ -346,26 +374,27 @@ private[spark] object RandomForest extends Logging with
Serializable {
bcSplits: Broadcast[Array[Array[Split]]],
bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = {
require(nodeIds != null && bestSplits != null)
- input.zip(nodeIds).map { case (point, ids) =>
- var treeId = 0
- while (treeId < bestSplits.length) {
- val bestSplitsInTree = bestSplits(treeId)
- if (bestSplitsInTree != null) {
- val nodeId = ids(treeId)
- bestSplitsInTree.get(nodeId).foreach { bestSplit =>
- val featureId = bestSplit.featureIndex
- val bin = point.datum.binnedFeatures(featureId)
- val newNodeId = if (bestSplit.shouldGoLeft(bin,
bcSplits.value(featureId))) {
- LearningNode.leftChildIndex(nodeId)
- } else {
- LearningNode.rightChildIndex(nodeId)
+ input.zip(nodeIds).map {
+ case (point, ids) =>
+ var treeId = 0
+ while (treeId < bestSplits.length) {
+ val bestSplitsInTree = bestSplits(treeId)
+ if (bestSplitsInTree != null) {
+ val nodeId = ids(treeId)
+ bestSplitsInTree.get(nodeId).foreach { bestSplit =>
+ val featureId = bestSplit.featureIndex
+ val bin = point.datum.binnedFeatures(featureId)
+ val newNodeId = if (bestSplit.shouldGoLeft(bin,
bcSplits.value(featureId))) {
+ LearningNode.leftChildIndex(nodeId)
+ } else {
+ LearningNode.rightChildIndex(nodeId)
+ }
+ ids(treeId) = newNodeId
}
- ids(treeId) = newNodeId
}
+ treeId += 1
}
- treeId += 1
- }
- ids
+ ids
}
}
@@ -417,7 +446,11 @@ private[spark] object RandomForest extends Logging with
Serializable {
var splitIndex = 0
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue,
featureSplits)) {
- agg.featureUpdate(leftNodeFeatureOffset, splitIndex,
treePoint.label, numSamples,
+ agg.featureUpdate(
+ leftNodeFeatureOffset,
+ splitIndex,
+ treePoint.label,
+ numSamples,
sampleWeight)
}
splitIndex += 1
@@ -532,8 +565,9 @@ private[spark] object RandomForest extends Logging with
Serializable {
logDebug(s"numFeatures = ${metadata.numFeatures}")
logDebug(s"numClasses = ${metadata.numClasses}")
logDebug(s"isMulticlass = ${metadata.isMulticlass}")
- logDebug(s"isMulticlassWithCategoricalFeatures = " +
- s"${metadata.isMulticlassWithCategoricalFeatures}")
+ logDebug(
+ s"isMulticlassWithCategoricalFeatures = " +
+ s"${metadata.isMulticlassWithCategoricalFeatures}")
logDebug(s"using nodeIdCache = $useNodeIdCache")
/*
@@ -560,11 +594,21 @@ private[spark] object RandomForest extends Logging with
Serializable {
val numSamples = baggedPoint.subsampleCounts(treeIndex)
val sampleWeight = baggedPoint.sampleWeight
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples,
sampleWeight,
+ orderedBinSeqOp(
+ agg(aggNodeIndex),
+ baggedPoint.datum,
+ numSamples,
+ sampleWeight,
featuresForNode)
} else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
- metadata.unorderedFeatures, numSamples, sampleWeight,
featuresForNode)
+ mixedBinSeqOp(
+ agg(aggNodeIndex),
+ baggedPoint.datum,
+ splits,
+ metadata.unorderedFeatures,
+ numSamples,
+ sampleWeight,
+ featuresForNode)
}
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples,
sampleWeight)
}
@@ -585,11 +629,16 @@ private[spark] object RandomForest extends Logging with
Serializable {
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint],
splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
- val nodeIndex =
-
topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures,
splits)
- nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
- agg, baggedPoint, splits)
+ treeToNodeToIndexInfo.foreach {
+ case (treeIndex, nodeIndexToInfo) =>
+ val nodeIndex =
+
topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures,
splits)
+ nodeBinSeqOp(
+ treeIndex,
+ nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg,
+ baggedPoint,
+ splits)
}
agg
}
@@ -601,12 +650,17 @@ private[spark] object RandomForest extends Logging with
Serializable {
agg: Array[DTStatsAggregator],
dataPoint: (BaggedPoint[TreePoint], Array[Int]),
splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
- val baggedPoint = dataPoint._1
- val nodeIdCache = dataPoint._2
- val nodeIndex = nodeIdCache(treeIndex)
- nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
- agg, baggedPoint, splits)
+ treeToNodeToIndexInfo.foreach {
+ case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(
+ treeIndex,
+ nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg,
+ baggedPoint,
+ splits)
}
agg
}
@@ -615,8 +669,8 @@ private[spark] object RandomForest extends Logging with
Serializable {
* Get node index in group --> features indices map,
* which is a short cut to find feature indices for a node given node
index in group.
*/
- def getNodeToFeatures(
- treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]):
Option[Map[Int, Array[Int]]] = {
+ def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int,
NodeIndexInfo]])
+ : Option[Map[Int, Array[Int]]] = {
if (!metadata.subsamplingFeatures) {
None
} else {
@@ -624,7 +678,8 @@ private[spark] object RandomForest extends Logging with
Serializable {
treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
assert(nodeIndexInfo.featureSubset.isDefined)
- mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) =
nodeIndexInfo.featureSubset.get
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) =
+ nodeIndexInfo.featureSubset.get
}
}
Some(mutableNodeToFeatures.toMap)
@@ -633,10 +688,11 @@ private[spark] object RandomForest extends Logging with
Serializable {
// array of nodes to train indexed by node index in group
val nodes = new Array[LearningNode](numNodes)
- nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
- nodesForTree.foreach { node =>
- nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) =
node
- }
+ nodesForGroup.foreach {
+ case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) =
node
+ }
}
// Calculate best splits for all nodes in the group
@@ -690,17 +746,20 @@ private[spark] object RandomForest extends Logging with
Serializable {
}
}
- val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) =>
a.merge(b)).map {
- case (nodeIndex, aggStats) =>
- val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures
=>
- Some(nodeToFeatures(nodeIndex))
- }
+ val nodeToBestSplits = partitionAggregates
+ .reduceByKey((a, b) => a.merge(b))
+ .map {
+ case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap {
nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
- // find best split for each node
- val (split: Split, stats: ImpurityStats) =
- binsToBestSplit(aggStats, bcSplits.value, featuresForNode,
nodes(nodeIndex))
- (nodeIndex, (split, stats))
- }.collectAsMap()
+ // find best split for each node
+ val (split: Split, stats: ImpurityStats) =
+ binsToBestSplit(aggStats, bcSplits.value, featuresForNode,
nodes(nodeIndex))
+ (nodeIndex, (split, stats))
+ }
+ .collectAsMap()
nodeToFeaturesBc.destroy()
timer.stop("chooseSplits")
@@ -712,55 +771,64 @@ private[spark] object RandomForest extends Logging with
Serializable {
}
// Iterate over all nodes in this group.
- nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
- nodesForTree.foreach { node =>
- val nodeIndex = node.id
- val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val (split: Split, stats: ImpurityStats) =
- nodeToBestSplits(aggNodeIndex)
- logDebug(s"best split = $split")
-
- // Extract info for this node. Create children if not leaf.
- val isLeaf =
- (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) ==
metadata.maxDepth)
- node.isLeaf = isLeaf
- node.stats = stats
- logDebug(s"Node = $node")
-
- if (!isLeaf) {
- node.split = Some(split)
- val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) ==
metadata.maxDepth
- val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) <
Utils.EPSILON)
- val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity)
< Utils.EPSILON)
- node.leftChild =
Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
- leftChildIsLeaf,
ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
- node.rightChild =
Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
- rightChildIsLeaf,
ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
-
- if (outputBestSplits) {
- val bestSplitsInTree = bestSplits(treeIndex)
- if (bestSplitsInTree == null) {
- bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex ->
split)
- } else {
- bestSplitsInTree.update(nodeIndex, split)
+ nodesForGroup.foreach {
+ case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: ImpurityStats) =
+ nodeToBestSplits(aggNodeIndex)
+ logDebug(s"best split = $split")
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) ==
metadata.maxDepth)
+ node.isLeaf = isLeaf
+ node.stats = stats
+ logDebug(s"Node = $node")
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) ==
metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity)
< Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf ||
(math.abs(stats.rightImpurity) < Utils.EPSILON)
+ node.leftChild = Some(
+ LearningNode(
+ LearningNode.leftChildIndex(nodeIndex),
+ leftChildIsLeaf,
+
ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
+ node.rightChild = Some(
+ LearningNode(
+ LearningNode.rightChildIndex(nodeIndex),
+ rightChildIsLeaf,
+
ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
+
+ if (outputBestSplits) {
+ val bestSplitsInTree = bestSplits(treeIndex)
+ if (bestSplitsInTree == null) {
+ bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex ->
split)
+ } else {
+ bestSplitsInTree.update(nodeIndex, split)
+ }
}
- }
- // enqueue left child and right child if they are not leaves
- if (!leftChildIsLeaf) {
- nodeStack.prepend((treeIndex, node.leftChild.get))
- }
- if (!rightChildIsLeaf) {
- nodeStack.prepend((treeIndex, node.rightChild.get))
- }
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.rightChild.get))
+ }
- logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
- s", impurity = ${stats.leftImpurity}")
- logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
- s", impurity = ${stats.rightImpurity}")
+ logDebug(
+ s"leftChildIndex = ${node.leftChild.get.id}" +
+ s", impurity = ${stats.leftImpurity}")
+ logDebug(
+ s"rightChildIndex = ${node.rightChild.get.id}" +
+ s", impurity = ${stats.rightImpurity}")
+ }
}
- }
}
if (outputBestSplits) {
@@ -830,8 +898,12 @@ private[spark] object RandomForest extends Logging with
Serializable {
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
- new ImpurityStats(gain, impurity, parentImpurityCalculator,
- leftImpurityCalculator, rightImpurityCalculator)
+ new ImpurityStats(
+ gain,
+ impurity,
+ parentImpurityCalculator,
+ leftImpurityCalculator,
+ rightImpurityCalculator)
}
/**
@@ -855,130 +927,156 @@ private[spark] object RandomForest extends Logging with
Serializable {
}
val validFeatureSplits =
- Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map {
featureIndexIdx =>
- featuresForNode.map(features => (featureIndexIdx,
features(featureIndexIdx)))
- .getOrElse((featureIndexIdx, featureIndexIdx))
- }.withFilter { case (_, featureIndex) =>
- binAggregates.metadata.numSplits(featureIndex) != 0
- }
+ Iterator
+ .range(0, binAggregates.metadata.numFeaturesPerNode)
+ .map { featureIndexIdx =>
+ featuresForNode
+ .map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }
+ .withFilter {
+ case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
// For each (feature, split), calculate the gain, and select the best
(feature, split).
val splitsAndImpurityInfo =
- validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
- val numSplits = binAggregates.metadata.numSplits(featureIndex)
- if (binAggregates.metadata.isContinuous(featureIndex)) {
- // Cumulative sum (scanLeft) of bin statistics.
- // Afterwards, binAggregates for a bin is the sum of aggregates for
- // that bin + all preceding bins.
- val nodeFeatureOffset =
binAggregates.getFeatureOffset(featureIndexIdx)
- var splitIndex = 0
- while (splitIndex < numSplits) {
- binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1,
splitIndex)
- splitIndex += 1
- }
- // Find best split.
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIdx =>
- val leftChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset,
splitIdx)
- val rightChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset,
numSplits)
- rightChildStats.subtract(leftChildStats)
- gainAndImpurityStats =
calculateImpurityStats(gainAndImpurityStats,
- leftChildStats, rightChildStats, binAggregates.metadata)
- (splitIdx, gainAndImpurityStats)
- }.maxBy(_._2.gain)
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else if (binAggregates.metadata.isUnordered(featureIndex)) {
- // Unordered categorical feature
- val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIndex =>
- val leftChildStats =
binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats = binAggregates.getParentImpurityCalculator()
- .subtract(leftChildStats)
- gainAndImpurityStats =
calculateImpurityStats(gainAndImpurityStats,
- leftChildStats, rightChildStats, binAggregates.metadata)
- (splitIndex, gainAndImpurityStats)
- }.maxBy(_._2.gain)
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else {
- // Ordered categorical feature
- val nodeFeatureOffset =
binAggregates.getFeatureOffset(featureIndexIdx)
- val numCategories = binAggregates.metadata.numBins(featureIndex)
-
- /* Each bin is one category (feature value).
- * The bins are ordered based on centroidForCategories, and this
ordering determines which
- * splits are considered. (With K categories, we consider K - 1
possible splits.)
- *
+ validFeatureSplits.map {
+ case (featureIndexIdx, featureIndex) =>
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset =
binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1,
splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits)
+ .map { splitIdx =>
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset,
splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset,
numSplits)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(
+ gainAndImpurityStats,
+ leftChildStats,
+ rightChildStats,
+ binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
+ }
+ .maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val leftChildOffset =
binAggregates.getFeatureOffset(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits)
+ .map { splitIndex =>
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(leftChildOffset,
splitIndex)
+ val rightChildStats = binAggregates
+ .getParentImpurityCalculator()
+ .subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(
+ gainAndImpurityStats,
+ leftChildStats,
+ rightChildStats,
+ binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }
+ .maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset =
binAggregates.getFeatureOffset(featureIndexIdx)
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this
ordering determines
+ * which splits are considered. (With K categories, we
+ * consider K - 1 possible splits.)
+ *
* centroidForCategories is a list: (category, centroid)
- */
- val centroidForCategories = Range(0, numCategories).map {
featureValue =>
- val categoryStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset,
featureValue)
- val centroid = if (categoryStats.count != 0) {
- if (binAggregates.metadata.isMulticlass) {
- // multiclass classification
- // For categorical variables in multiclass classification,
- // the bins are ordered by the impurity of their corresponding
labels.
- categoryStats.calculate()
- } else if (binAggregates.metadata.isClassification) {
- // binary classification
- // For categorical variables in binary classification,
- // the bins are ordered by the count of class 1.
- categoryStats.stats(1)
+ */
+ val centroidForCategories = Range(0, numCategories).map {
featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset,
featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their
corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // regression
+ // For categorical variables in regression and binary
classification,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
} else {
- // regression
- // For categorical variables in regression and binary
classification,
- // the bins are ordered by the prediction.
- categoryStats.predict
+ Double.MaxValue
}
- } else {
- Double.MaxValue
+ (featureValue, centroid)
}
- (featureValue, centroid)
- }
- logDebug(s"Centroids for categorical variable: " +
- s"${centroidForCategories.mkString(",")}")
-
- // bins sorted by centroids
- val categoriesSortedByCentroid =
centroidForCategories.toList.sortBy(_._2)
-
- logDebug(s"Sorted centroids for categorical variable = " +
- s"${categoriesSortedByCentroid.mkString(",")}")
-
- // Cumulative sum (scanLeft) of bin statistics.
- // Afterwards, binAggregates for a bin is the sum of aggregates for
- // that bin + all preceding bins.
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val currentCategory = categoriesSortedByCentroid(splitIndex)._1
- val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
- binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory,
currentCategory)
- splitIndex += 1
+ logDebug(
+ s"Centroids for categorical variable: " +
+ s"${centroidForCategories.mkString(",")}")
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid =
centroidForCategories.toList.sortBy(_._2)
+
+ logDebug(
+ s"Sorted centroids for categorical variable = " +
+ s"${categoriesSortedByCentroid.mkString(",")}")
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory,
currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this
(node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits)
+ .map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset,
featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset,
lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(
+ gainAndImpurityStats,
+ leftChildStats,
+ rightChildStats,
+ binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }
+ .maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0,
bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray,
numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
}
- // lastCategory = index of bin with total aggregates for this (node,
feature)
- val lastCategory = categoriesSortedByCentroid.last._1
- // Find best split.
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIndex =>
- val featureValue = categoriesSortedByCentroid(splitIndex)._1
- val leftChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset,
featureValue)
- val rightChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset,
lastCategory)
- rightChildStats.subtract(leftChildStats)
- gainAndImpurityStats =
calculateImpurityStats(gainAndImpurityStats,
- leftChildStats, rightChildStats, binAggregates.metadata)
- (splitIndex, gainAndImpurityStats)
- }.maxBy(_._2.gain)
- val categoriesForSplit =
- categoriesSortedByCentroid.map(_._1.toDouble).slice(0,
bestFeatureSplitIndex + 1)
- val bestFeatureSplit =
- new CategoricalSplit(featureIndex, categoriesForSplit.toArray,
numCategories)
- (bestFeatureSplit, bestFeatureGainStats)
- }
}
val (bestSplit, bestSplitStats) =
@@ -989,11 +1087,13 @@ private[spark] object RandomForest extends Logging with
Serializable {
val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
val parentImpurityCalculator =
binAggregates.getParentImpurityCalculator()
if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
- (new ContinuousSplit(dummyFeatureIndex, 0),
+ (
+ new ContinuousSplit(dummyFeatureIndex, 0),
ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
} else {
val numCategories =
binAggregates.metadata.featureArity(dummyFeatureIndex)
- (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+ (
+ new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
}
} else {
@@ -1066,27 +1166,34 @@ private[spark] object RandomForest extends Logging with
Serializable {
// being spun up that will definitely do no work.
val numPartitions = math.min(continuousFeatures.length,
input.partitions.length)
- input.flatMap { point =>
- continuousFeatures.iterator
- .map(idx => (idx, (point.features(idx), point.weight)))
- .filter(_._2._1 != 0.0)
- }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
- seqOp = { case ((map, c), (v, w)) =>
- map.changeValue(v, w, _ + w)
- (map, c + 1L)
- },
- combOp = { case ((map1, c1), (map2, c2)) =>
- map2.foreach { case (v, w) =>
- map1.changeValue(v, w, _ + w)
- }
- (map1, c1 + c2)
+ input
+ .flatMap { point =>
+ continuousFeatures.iterator
+ .map(idx => (idx, (point.features(idx), point.weight)))
+ .filter(_._2._1 != 0.0)
}
- ).map { case (idx, (map, c)) =>
- val thresholds = findSplitsForContinuousFeature(map.toMap, c,
metadata, idx)
- val splits: Array[Split] = thresholds.map(thresh => new
ContinuousSplit(idx, thresh))
- logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
- (idx, splits)
- }.collectAsMap()
+ .aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
+ seqOp = {
+ case ((map, c), (v, w)) =>
+ map.changeValue(v, w, _ + w)
+ (map, c + 1L)
+ },
+ combOp = {
+ case ((map1, c1), (map2, c2)) =>
+ map2.foreach {
+ case (v, w) =>
+ map1.changeValue(v, w, _ + w)
+ }
+ (map1, c1 + c2)
+ })
+ .map {
+ case (idx, (map, c)) =>
+ val thresholds = findSplitsForContinuousFeature(map.toMap, c,
metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new
ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }
+ .collectAsMap()
} else Map.empty[Int, Array[Split]]
val numFeatures = metadata.numFeatures
@@ -1157,9 +1264,10 @@ private[spark] object RandomForest extends Logging with
Serializable {
featureIndex: Int): Array[Double] = {
val valueWeights = new OpenHashMap[Double, Double]
var count = 0L
- featureSamples.foreach { case (weight, value) =>
- valueWeights.changeValue(value, weight, _ + weight)
- count += 1L
+ featureSamples.foreach {
+ case (weight, value) =>
+ valueWeights.changeValue(value, weight, _ + weight)
+ count += 1L
}
findSplitsForContinuousFeature(valueWeights.toMap, count, metadata,
featureIndex)
}
@@ -1182,7 +1290,8 @@ private[spark] object RandomForest extends Logging with
Serializable {
count: Long,
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
- require(metadata.isContinuous(featureIndex),
+ require(
+ metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a
continuous feature.")
val splits = if (partValueWeights.isEmpty) {
@@ -1256,7 +1365,8 @@ private[spark] object RandomForest extends Logging with
Serializable {
private[tree] class NodeIndexInfo(
val nodeIndexInGroup: Int,
- val featureSubset: Option[Array[Int]]) extends Serializable
+ val featureSubset: Option[Array[Int]])
+ extends Serializable
/**
* Pull nodes off of the queue, and collect a group of nodes to be split on
this iteration.
@@ -1294,8 +1404,13 @@ private[spark] object RandomForest extends Logging with
Serializable {
val (treeIndex, node) = nodeStack.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if
(metadata.subsamplingFeatures) {
- Some(SamplingUtils.reservoirSampleAndCount(Range(0,
- metadata.numFeatures).iterator, metadata.numFeaturesPerNode,
rng.nextLong())._1)
+ Some(
+ SamplingUtils
+ .reservoirSampleAndCount(
+ Range(0, metadata.numFeatures).iterator,
+ metadata.numFeaturesPerNode,
+ rng.nextLong())
+ ._1)
} else {
None
}
@@ -1303,11 +1418,13 @@ private[spark] object RandomForest extends Logging with
Serializable {
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata,
featureSubset) * 8L
if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
nodeStack.remove(0)
- mutableNodesForGroup.getOrElseUpdate(treeIndex, new
mutable.ArrayBuffer[LearningNode]()) +=
+ mutableNodesForGroup.getOrElseUpdate(
+ treeIndex,
+ new mutable.ArrayBuffer[LearningNode]()) +=
node
mutableTreeToNodeToIndexInfo
- .getOrElseUpdate(treeIndex, new mutable.HashMap[Int,
NodeIndexInfo]())(node.id)
- = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int,
NodeIndexInfo]())(node.id) =
+ new NodeIndexInfo(numNodesInGroup, featureSubset)
numNodesInGroup += 1
memUsage += nodeMemUsage
} else {
@@ -1355,8 +1472,7 @@ private[spark] object RandomForest extends Logging with
Serializable {
* @param metadata decision tree metadata
* @return subsample fraction
*/
- private def samplesFractionForFindSplits(
- metadata: DecisionTreeMetadata): Double = {
+ private def samplesFractionForFindSplits(metadata: DecisionTreeMetadata):
Double = {
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
if (requiredSamples < metadata.numExamples) {
@@ -1365,4 +1481,5 @@ private[spark] object RandomForest extends Logging with
Serializable {
1.0
}
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 768e14f4b74e..2244d49b2a35 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params {
(value: String) =>
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
- setDefault(impurity -> "gini")
+ /**
+ * If true, the trained tree will undergo a pruning process after training,
in which sibling
+ * leaf nodes with the same prediction are merged into their parent. The
resulting tree will
+ * be smaller and have faster predictions. Class probabilities remain
available after pruning.
+ * If false, no pruning is applied after training.
+ * (default = true)
+ * @group param
+ */
+ final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+ "If true, the trained tree will undergo a pruning process after training,
in which sibling" +
+ " leaf nodes with the same prediction are merged into their parent. The
resulting tree will" +
+ " be smaller and have faster predictions. Class probabilities remain
available after pruning." +
+ " If false, no pruning is applied after training."
+ )
+
+ setDefault(impurity -> "gini", pruneTree -> true)
/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT)
+ /** @group getParam */
+ final def getPruneTree: Boolean = $(pruneTree)
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 200d10130eed..c1f82157e82f 100644
---
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -65,6 +65,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini,
Impurity, Variance}
* E.g. 10 means that the cache will get
checkpointed every 10 updates. If
* the checkpoint directory is not set in
* [[org.apache.spark.SparkContext]], this setting
is ignored.
+ * @param pruneTree If true, the trained tree is pruned after training: sibling leaf nodes
+ * with the same prediction are merged into their parent.
*/
@Since("1.0.0")
class Strategy @Since("1.3.0") (
@@ -82,6 +84,7 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
@Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0,
+ @Since("4.3.0") @BeanProperty var pruneTree: Boolean = true,
@BeanProperty private[spark] var bootstrap: Boolean = false) extends
Serializable {
/**
@@ -201,7 +204,7 @@ class Strategy @Since("1.3.0") (
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo,
minInstancesPerNode,
minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
- checkpointInterval, minWeightFractionPerNode)
+ checkpointInterval, minWeightFractionPerNode, pruneTree)
}
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 62f25474e947..0c6044181315 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -72,8 +72,9 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
assert(splits(0).length === 0)
}
- test("Binary classification with 3-ary (ordered) categorical features," +
- " with no samples for one category: split calculation") {
+ test(
+ "Binary classification with 3-ary (ordered) categorical features," +
+ " with no samples for one category: split calculation") {
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000)
val rdd = sc.parallelize(arr.toImmutableArraySeq)
@@ -108,16 +109,29 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// SPARK-16957: Use midpoints for split values.
{
- val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0.0, 0, 0
- )
+ val fakeMetadata = new DecisionTreeMetadata(
+ 1,
+ 8,
+ 8.0,
+ 0,
+ 0,
+ Map(),
+ Set(),
+ Array(3),
+ Gini,
+ QuantileStrategy.Sort,
+ 0,
+ 0,
+ 0.0,
+ 0.0,
+ 0,
+ 0)
// possibleSplits <= numSplits
{
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1)
- .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
+ .map(x => (1.0, x.toDouble))
+ .filter(_._2 != 0.0)
val splits =
RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2)
assert(splits === expectedSplits)
@@ -126,7 +140,8 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// possibleSplits > numSplits
{
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3)
- .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
+ .map(x => (1.0, x.toDouble))
+ .filter(_._2 != 0.0)
val splits =
RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits)
@@ -136,11 +151,23 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits
in metadata
{
- val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0,
- Map(), Set(),
- Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0.0, 0, 0
- )
+ val fakeMetadata = new DecisionTreeMetadata(
+ 1,
+ 12,
+ 12.0,
+ 0,
+ 0,
+ Map(),
+ Set(),
+ Array(5),
+ Gini,
+ QuantileStrategy.Sort,
+ 0,
+ 0,
+ 0.0,
+ 0.0,
+ 0,
+ 0)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x =>
(1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
@@ -151,11 +178,23 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// find splits when most samples close to the minimum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0.0, 0, 0
- )
+ val fakeMetadata = new DecisionTreeMetadata(
+ 1,
+ 18,
+ 18.0,
+ 0,
+ 0,
+ Map(),
+ Set(),
+ Array(3),
+ Gini,
+ QuantileStrategy.Sort,
+ 0,
+ 0,
+ 0.0,
+ 0.0,
+ 0,
+ 0)
val featureSamples =
Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x =>
(1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
@@ -165,11 +204,23 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// find splits when most samples close to the maximum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0,
- Map(), Set(),
- Array(2), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0.0, 0, 0
- )
+ val fakeMetadata = new DecisionTreeMetadata(
+ 1,
+ 17,
+ 17.0,
+ 0,
+ 0,
+ Map(),
+ Set(),
+ Array(2),
+ Gini,
+ QuantileStrategy.Sort,
+ 0,
+ 0,
+ 0.0,
+ 0.0,
+ 0,
+ 0)
val featureSamples =
Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x =>
(1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
@@ -199,11 +250,23 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// find splits when most weight is close to the minimum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0.0, 0, 0
- )
+ val fakeMetadata = new DecisionTreeMetadata(
+ 1,
+ 0,
+ 0.0,
+ 0,
+ 0,
+ Map(),
+ Set(),
+ Array(3),
+ Gini,
+ QuantileStrategy.Sort,
+ 0,
+ 0,
+ 0.0,
+ 0.0,
+ 0,
+ 0)
val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1,
6)).map {
case (w, x) => (w.toDouble, x.toDouble)
}
@@ -217,10 +280,10 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data.toImmutableArraySeq)
- val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2,
- maxBins = 5)
- withClue("DecisionTree requires number of features > 0," +
- " but was given an empty features vector") {
+ val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2,
maxBins = 5)
+ withClue(
+ "DecisionTree requires number of features > 0," +
+ " but was given an empty features vector") {
intercept[IllegalArgumentException] {
RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
}
@@ -232,23 +295,19 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val data = Array.fill(5)(instance)
val rdd = sc.parallelize(data.toImmutableArraySeq)
val strategy = new OldStrategy(
- OldAlgo.Classification,
- Gini,
- maxDepth = 2,
- numClasses = 2,
- maxBins = 5,
- categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
+ OldAlgo.Classification,
+ Gini,
+ maxDepth = 2,
+ numClasses = 2,
+ maxBins = 5,
+ categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr =
None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
assert(tree.rootNode.prediction === instance.label)
// Test with no categorical features
- val strategy2 = new OldStrategy(
- OldAlgo.Regression,
- Variance,
- maxDepth = 2,
- maxBins = 5)
+ val strategy2 = new OldStrategy(OldAlgo.Regression, Variance, maxDepth =
2, maxBins = 5)
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr =
None)
assert(tree2.rootNode.impurity === -1.0)
assert(tree2.depth === 0)
@@ -279,12 +338,15 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
assert(metadata.numBins(1) === 3)
// Expecting 2^2 - 1 = 3 splits per feature
- def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories:
Array[Double]): Unit = {
+ def checkCategoricalSplit(
+ s: Split,
+ featureIndex: Int,
+ leftCategories: Array[Double]): Unit = {
assert(s.featureIndex === featureIndex)
assert(s.isInstanceOf[CategoricalSplit])
val s0 = s.asInstanceOf[CategoricalSplit]
assert(s0.leftCategories === leftCategories)
- assert(s0.numCategories === 3) // for this unit test
+ assert(s0.numCategories === 3) // for this unit test
}
// Feature 0
checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
@@ -297,7 +359,8 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
test("Multiclass classification with ordered categorical features: split
calculations") {
- val arr =
OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ val arr = OldDTSuite
+ .generateCategoricalDataPointsForMulticlassForOrderedFeatures()
.map(_.asML.toInstance)
assert(arr.length === 3000)
val rdd = sc.parallelize(arr.toImmutableArraySeq)
@@ -332,8 +395,12 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq)
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 1,
+ numClasses = 2,
+ categoricalFeaturesInfo = Map(0 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val splits = RandomForest.findSplits(input, metadata, seed = 42)
val bcSplits = input.sparkContext.broadcast(splits)
@@ -346,12 +413,17 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
assert(topNode.stats === null)
val nodesForGroup = Map(0 -> Array(topNode))
- val treeToNodeToIndexInfo = Map(0 -> Map(
- topNode.id -> new RandomForest.NodeIndexInfo(0, None)
- ))
+ val treeToNodeToIndexInfo =
+ Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None)))
val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
- RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
- nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack)
+ RandomForest.findBestSplits(
+ baggedInput,
+ metadata,
+ Map(0 -> topNode),
+ nodesForGroup,
+ treeToNodeToIndexInfo,
+ bcSplits,
+ nodeStack)
bcSplits.destroy()
// don't enqueue leaf nodes into node queue
@@ -376,8 +448,12 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq)
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 5,
+ numClasses = 2,
+ categoricalFeaturesInfo = Map(0 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val splits = RandomForest.findSplits(input, metadata, seed = 42)
val bcSplits = input.sparkContext.broadcast(splits)
@@ -390,12 +466,17 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
assert(topNode.stats === null)
val nodesForGroup = Map(0 -> Array(topNode))
- val treeToNodeToIndexInfo = Map(0 -> Map(
- topNode.id -> new RandomForest.NodeIndexInfo(0, None)
- ))
+ val treeToNodeToIndexInfo =
+ Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None)))
val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
- RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
- nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack)
+ RandomForest.findBestSplits(
+ baggedInput,
+ metadata,
+ Map(0 -> topNode),
+ nodesForGroup,
+ treeToNodeToIndexInfo,
+ bcSplits,
+ nodeStack)
bcSplits.destroy()
// don't enqueue a node into node queue if its impurity is 0.0
@@ -431,18 +512,32 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq)
// Must set maxBins s.t. the feature will be treated as an ordered
categorical feature.
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-
- val model = RandomForest.run(input, strategy, numTrees = 1,
featureSubsetStrategy = "all",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 1,
+ numClasses = 2,
+ categoricalFeaturesInfo = Map(0 -> 3),
+ maxBins = 3)
+
+ strategy.pruneTree = false
+ val model = RandomForest
+ .run(
+ input,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "all",
+ seed = 42,
+ instr = None)
+ .head
model.rootNode match {
- case n: InternalNode => n.split match {
- case s: CategoricalSplit =>
- assert(s.leftCategories === Array(1.0))
- case _ => fail("model.rootNode.split was not a CategoricalSplit")
- }
+ case n: InternalNode =>
+ n.split match {
+ case s: CategoricalSplit =>
+ assert(s.leftCategories === Array(1.0))
+ case _ => fail("model.rootNode.split was not a CategoricalSplit")
+ }
case _ => fail("model.rootNode was not an InternalNode")
}
}
@@ -458,18 +553,21 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val strategy2 =
new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100,
maxMemoryInMB = 0)
- val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1,
featureSubsetStrategy = "all",
- seed = 42, instr = None).head
- val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1,
featureSubsetStrategy = "all",
- seed = 42, instr = None).head
-
- def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
- case n: InternalNode =>
- assert(n.leftChild.isInstanceOf[InternalNode])
- assert(n.rightChild.isInstanceOf[InternalNode])
- Array(n.leftChild.asInstanceOf[InternalNode],
n.rightChild.asInstanceOf[InternalNode])
- case _ => fail("rootNode was not an InternalNode")
- }
+ val tree1 = RandomForest
+ .run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", seed =
42, instr = None)
+ .head
+ val tree2 = RandomForest
+ .run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", seed =
42, instr = None)
+ .head
+
+ def getChildren(rootNode: Node): Array[InternalNode] =
+ rootNode match {
+ case n: InternalNode =>
+ assert(n.leftChild.isInstanceOf[InternalNode])
+ assert(n.rightChild.isInstanceOf[InternalNode])
+ Array(n.leftChild.asInstanceOf[InternalNode],
n.rightChild.asInstanceOf[InternalNode])
+ case _ => fail("rootNode was not an InternalNode")
+ }
// Single group second level tree construction.
val children1 = getChildren(tree1.rootNode)
@@ -515,8 +613,9 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
nodeStack.prepend((treeIndex, topNodes(treeIndex)))
}
val rng = new scala.util.Random(seed = seed)
- val (nodesForGroup: Map[Int, Array[LearningNode]],
- treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]])
=
+ val (
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int,
RandomForest.NodeIndexInfo]]) =
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata,
rng)
assert(nodesForGroup.size === numTrees, failString)
@@ -524,12 +623,15 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
if (numFeaturesPerNode == numFeatures) {
// featureSubset values should all be None
-
assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+ assert(
+
treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
failString)
} else {
// Check number of features.
- assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
- _.featureSubset.get.length === numFeaturesPerNode)), failString)
+ assert(
+ treeToNodeToIndexInfo.values.forall(
+ _.values.forall(_.featureSubset.get.length ===
numFeaturesPerNode)),
+ failString)
}
}
}
@@ -537,7 +639,9 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 1, "sqrt",
math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 1, "log2",
+ checkFeatureSubsetStrategy(
+ numTrees = 1,
+ "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures /
3.0).ceil.toInt)
@@ -555,7 +659,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1",
"0")
for (invalidStrategy <- invalidStrategies) {
- intercept[IllegalArgumentException]{
+ intercept[IllegalArgumentException] {
val metadata =
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1,
invalidStrategy)
}
@@ -564,7 +668,9 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto",
math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt",
math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "log2",
+ checkFeatureSubsetStrategy(
+ numTrees = 2,
+ "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures /
3.0).ceil.toInt)
@@ -578,7 +684,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
}
for (invalidStrategy <- invalidStrategies) {
- intercept[IllegalArgumentException]{
+ intercept[IllegalArgumentException] {
val metadata =
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2,
invalidStrategy)
}
@@ -587,15 +693,23 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
test("Binary classification with continuous features: subsampling features")
{
val categoricalFeaturesInfo = Map.empty[Int, Int]
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 2,
- numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 2,
+ numClasses = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}
test("Binary classification with continuous features and node Id cache:
subsampling features") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 2,
- numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 2,
+ numClasses = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo,
useNodeIdCache = true)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}
@@ -648,7 +762,8 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
- val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
+ val expected = Vectors.dense(
+ (1.0 + feature0importance / tree2norm) / 2.0,
(feature1importance / tree2norm) / 2.0)
assert(importances ~== expected relTol 0.01)
}
@@ -682,18 +797,45 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val rdd = sc.parallelize(arr.toImmutableArraySeq)
val numClasses = 2
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 4,
- numClasses = numClasses, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 4,
+ numClasses = numClasses,
+ maxBins = 32)
+
+ strategy.pruneTree = true
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None)
+ .head
+
+ strategy.pruneTree = false
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None)
+ .head
+
+ strategy.pruneTree = true
+ val defaultBehaviorTree = RandomForest
+ .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed =
42, instr = None)
+ .head
assert(prunedTree.numNodes === 5)
assert(unprunedTree.numNodes === 7)
+ assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
+
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) ===
arr.length)
}
@@ -712,17 +854,45 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
)
val rdd = sc.parallelize(arr.toImmutableArraySeq)
- val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity =
Variance, maxDepth = 4,
- numClasses = 0, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Regression,
+ impurity = Variance,
+ maxDepth = 4,
+ numClasses = 0,
+ maxBins = 32)
+
+ strategy.pruneTree = true
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None)
+ .head
+
+ strategy.pruneTree = false
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None)
+ .head
+
+ strategy.pruneTree = true
+ val defaultBehaviorTree = RandomForest
+ .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed =
42, instr = None)
+ .head
assert(prunedTree.numNodes === 3)
assert(unprunedTree.numNodes === 5)
+
+ assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
+
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) ===
arr.length)
}
@@ -739,13 +909,15 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3,
"all", 42L, None)
val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3,
"all", 42L, None)
- unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree,
smallWeightTree) =>
- TreeTests.checkEqual(unitTree, smallWeightTree)
+ unitWeightTrees.zip(smallWeightTrees).foreach {
+ case (unitTree, smallWeightTree) =>
+ TreeTests.checkEqual(unitTree, smallWeightTree)
}
val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3,
"all", 42L, None)
- unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree,
bigWeightTree) =>
- TreeTests.checkEqual(unitTree, bigWeightTree)
+ unitWeightTrees.zip(bigWeightTrees).foreach {
+ case (unitTree, bigWeightTree) =>
+ TreeTests.checkEqual(unitTree, bigWeightTree)
}
}
@@ -778,6 +950,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
private object RandomForestSuite {
+
def mapToVec(map: Map[Int, Double]): Vector = {
val size = (map.keys.toSeq :+ 0).max + 1
val (indices, values) = map.toSeq.sortBy(_._1).unzip
@@ -788,12 +961,12 @@ private object RandomForestSuite {
private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
if (nodes.isEmpty) {
acc
- }
- else {
+ } else {
nodes.head match {
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild
:: nodes.tail, acc)
case l: LeafNode => getSumLeafCounters(nodes.tail, acc +
l.impurityStats.rawCount)
}
}
}
+
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b5434efee090..bf2984ba8c6d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,7 +55,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.copy"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo$"),
// [SPARK-56330][CORE] Add TaskInterruptListener to TaskContext for interrupt notifications
-
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.addTaskInterruptListener")
+
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.addTaskInterruptListener"),
+ // [SPARK-34591][ML] Add pruneTree parameter to Strategy
+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this")
)
// Exclude rules for 4.1.x from 4.0.0
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f69ecf115f5a..4b7f2e4da209 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1678,6 +1678,7 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams,
_TreeClassifierParams):
maxBins=32,
minInstancesPerNode=1,
minInfoGain=0.0,
+ pruneTree=True,
maxMemoryInMB=256,
cacheNodeIds=False,
checkpointInterval=10,
@@ -1789,6 +1790,7 @@ class DecisionTreeClassifier(
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
+ pruneTree: bool = True,
maxMemoryInMB: int = 256,
cacheNodeIds: bool = False,
checkpointInterval: int = 10,
@@ -1801,7 +1803,7 @@ class DecisionTreeClassifier(
"""
__init__(self, \\*, featuresCol="features", labelCol="label",
predictionCol="prediction", \
probabilityCol="probability",
rawPredictionCol="rawPrediction", \
- maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, pruneTree=True, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", \
seed=None, weightCol=None, leafCol="",
minWeightFractionPerNode=0.0)
"""
@@ -1826,6 +1828,7 @@ class DecisionTreeClassifier(
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
+ pruneTree: bool = True,
maxMemoryInMB: int = 256,
cacheNodeIds: bool = False,
checkpointInterval: int = 10,
@@ -1838,7 +1841,7 @@ class DecisionTreeClassifier(
"""
setParams(self, \\*, featuresCol="features", labelCol="label",
predictionCol="prediction", \
probabilityCol="probability",
rawPredictionCol="rawPrediction", \
- maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, pruneTree=True, \
maxMemoryInMB=256, cacheNodeIds=False,
checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="",
minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeClassifier.
@@ -1861,6 +1864,13 @@ class DecisionTreeClassifier(
"""
return self._set(maxBins=value)
+ @since("4.3.0")
+ def setPruneTree(self, value: bool) -> "DecisionTreeClassifier":
+ """
+ Sets the value of :py:attr:`pruneTree`.
+ """
+ return self._set(pruneTree=value)
+
def setMinInstancesPerNode(self, value: int) -> "DecisionTreeClassifier":
"""
Sets the value of :py:attr:`minInstancesPerNode`.
@@ -1972,6 +1982,7 @@ class _RandomForestClassifierParams(_RandomForestParams,
_TreeClassifierParams):
maxBins=32,
minInstancesPerNode=1,
minInfoGain=0.0,
+ pruneTree=True,
maxMemoryInMB=256,
cacheNodeIds=False,
checkpointInterval=10,
@@ -2081,6 +2092,7 @@ class RandomForestClassifier(
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
+ pruneTree: bool = True,
maxMemoryInMB: int = 256,
cacheNodeIds: bool = False,
checkpointInterval: int = 10,
@@ -2097,7 +2109,7 @@ class RandomForestClassifier(
"""
__init__(self, \\*, featuresCol="features", labelCol="label",
predictionCol="prediction", \
probabilityCol="probability",
rawPredictionCol="rawPrediction", \
- maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, pruneTree=True, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None,
subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0, weightCol=None,
bootstrap=True)
@@ -2123,6 +2135,7 @@ class RandomForestClassifier(
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
+ pruneTree: bool = True,
maxMemoryInMB: int = 256,
cacheNodeIds: bool = False,
checkpointInterval: int = 10,
@@ -2139,7 +2152,7 @@ class RandomForestClassifier(
"""
setParams(self, featuresCol="features", labelCol="label",
predictionCol="prediction", \
probabilityCol="probability",
rawPredictionCol="rawPrediction", \
- maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, pruneTree=True, \
maxMemoryInMB=256, cacheNodeIds=False,
checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto",
subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0, weightCol=None,
bootstrap=True)
@@ -2163,6 +2176,13 @@ class RandomForestClassifier(
"""
return self._set(maxBins=value)
+ @since("4.3.0")
+ def setPruneTree(self, value: bool) -> "RandomForestClassifier":
+ """
+ Sets the value of :py:attr:`pruneTree`.
+ """
+ return self._set(pruneTree=value)
+
def setMinInstancesPerNode(self, value: int) -> "RandomForestClassifier":
"""
Sets the value of :py:attr:`minInstancesPerNode`.
diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py
index 63f58272aeef..92692ec225a7 100644
--- a/python/pyspark/ml/tree.py
+++ b/python/pyspark/ml/tree.py
@@ -415,6 +415,17 @@ class _TreeClassifierParams(Params):
typeConverter=TypeConverters.toString,
)
+ pruneTree: Param[bool] = Param(
+ Params._dummy(),
+ "pruneTree",
+ "If true, the trained tree will undergo a pruning process after
training, in which "
+ + "sibling leaf nodes with the same prediction are merged into their
parent. The "
+ + "resulting tree will be smaller and have faster predictions. Class
probabilities "
+ + "remain available after pruning. "
+ + "If false, no pruning is applied after training.",
+ typeConverter=TypeConverters.toBoolean,
+ )
+
def __init__(self) -> None:
super().__init__()
@@ -425,6 +436,13 @@ class _TreeClassifierParams(Params):
"""
return self.getOrDefault(self.impurity)
+ @since("4.3.0")
+ def getPruneTree(self) -> bool:
+ """
+ Gets the value of pruneTree or its default value.
+ """
+ return self.getOrDefault(self.pruneTree)
+
class _TreeRegressorParams(_HasVarianceImpurity):
"""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]