Repository: spark
Updated Branches:
  refs/heads/master 487377e69 -> 9e26473c0


[SPARK-3159][ML] Add decision tree pruning

## What changes were proposed in this pull request?

Added subtree pruning in the translation from LearningNode to Node: a learning 
node whose subtree yields a single prediction value for all of its leaves is 
translated into a LeafNode, instead of a (redundant) InternalNode.
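
A condensed sketch of the new conversion logic (adapted from the 
`LearningNode.toNode` change in this patch; the surrounding class and helper 
definitions are omitted, so this is an illustrative excerpt rather than a 
standalone program):

```scala
// Adapted from LearningNode.toNode: convert both children first, then collapse
// a split whose branches agree on the prediction into a single leaf.
(leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
  case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
    // Redundant split: both branches predict the same value, so emit one
    // LeafNode carrying the parent's impurity statistics.
    new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
  case (l, r) =>
    new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
      l, r, split.get, stats.impurityCalculator)
}
```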

## How was this patch tested?

Added two unit tests in 
"mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala" 
(their core assertions are excerpted below):
- test("SPARK-3159 tree model redundancy - classification")
- test("SPARK-3159 tree model redundancy - regression")

Four existing unit tests relying on the tree structure (the existence of a 
specific redundant subtree) had to be adapted, since the subtrees they inspect 
are now pruned from the output tree. This was fixed by adding an extra _prune_ 
parameter which can be used to disable pruning for testing, as shown in the 
excerpt below.
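
For instance, the adapted categorical-split test now requests an unpruned tree 
explicitly (excerpt from the RandomForestSuite change in this patch):

```scala
// Pruning would collapse the redundant subtree this test inspects,
// so the test disables it via the new parameter.
val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
  seed = 42, instr = None, prune = false).head
```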

Author: Alessandro Solimando <18898964+asolima...@users.noreply.github.com>

Closes #20632 from asolimando/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9e26473c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9e26473c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9e26473c

Branch: refs/heads/master
Commit: 9e26473c0f29ee4281519104ac5e182a3bd4bf23
Parents: 487377e
Author: Alessandro Solimando <18898964+asolima...@users.noreply.github.com>
Authored: Fri Mar 2 16:24:29 2018 -0800
Committer: sethah <shendrick...@cloudera.com>
Committed: Fri Mar 2 16:24:29 2018 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/tree/Node.scala   |  22 ++--
 .../spark/ml/tree/impl/RandomForest.scala       |  10 +-
 .../DecisionTreeClassifierSuite.scala           |  38 -------
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 100 +++++++++++++++++--
 .../spark/mllib/tree/DecisionTreeSuite.scala    |  10 +-
 5 files changed, 115 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9e26473c/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 07e98a1..d30be45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -19,8 +19,7 @@ package org.apache.spark.ml.tree
 
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats,
-  InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
 
 /**
  * Decision tree node interface.
@@ -266,15 +265,23 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {
 
+  def toNode: Node = toNode(prune = true)
+
   /**
   * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    */
-  def toNode: Node = {
-    if (leftChild.nonEmpty) {
-      assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+  def toNode(prune: Boolean = true): Node = {
+
+    if (!leftChild.isEmpty || !rightChild.isEmpty) {
+      assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
         "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
-      new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
-        leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
+      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+        case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
+          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+        case (l, r) =>
+          new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+            l, r, split.get, stats.impurityCalculator)
+      }
     } else {
       if (stats.valid) {
         new LeafNode(stats.impurityCalculator.predict, stats.impurity,
@@ -283,7 +290,6 @@ private[tree] class LearningNode(
         // Here we want to keep same behavior with the old mllib.DecisionTreeModel
         new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
       }
-
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9e26473c/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index acfc639..8e514f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging {
       featureSubsetStrategy: String,
       seed: Long,
       instr: Option[Instrumentation[_]],
+      prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
 
     val timer = new TimeTracker()
@@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
               strategy.getNumClasses)
           }
         } else {
           topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
           }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+            new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
               strategy.getNumClasses)
           }
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
+          topNodes.map(rootNode =>
+            new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
         }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9e26473c/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 98c879e..38b265d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite
     dt.fit(df)
   }
 
-  test("Use soft prediction for binary classification with ordered categorical 
features") {
-    // The following dataset is set up such that the best split is {1} vs. {0, 
2}.
-    // If the hard prediction is used to order the categories, then {0} vs. 
{1, 2} is chosen.
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(1.0, Vectors.dense(2.0)))
-    val data = sc.parallelize(arr)
-    val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
-
-    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
-    val dt = new DecisionTreeClassifier()
-      .setImpurity("gini")
-      .setMaxDepth(1)
-      .setMaxBins(3)
-    val model = dt.fit(df)
-    model.rootNode match {
-      case n: InternalNode =>
-        n.split match {
-          case s: CategoricalSplit =>
-            assert(s.leftCategories === Array(1.0))
-          case other =>
-            fail(s"All splits should be categorical, but got 
${other.getClass.getName}: $other.")
-        }
-      case other =>
-        fail(s"Root node should be an internal node, but got 
${other.getClass.getName}: $other.")
-    }
-  }
-
   test("Feature importance with toy data") {
     val dt = new DecisionTreeClassifier()
       .setImpurity("gini")

http://git-wip-us.apache.org/repos/asf/spark/blob/9e26473c/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dbe2ea9..5f0d26e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.tree.impl
 
+import scala.annotation.tailrec
 import scala.collection.mutable
 
 import org.apache.spark.SparkFunSuite
@@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   import RandomForestSuite.mapToVec
 
+  private val seed = 42
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests for split calculation
   /////////////////////////////////////////////////////////////////////////////
@@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)
 
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)
 
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
     val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
+      seed = 42, instr = None, prune = false).head
+
     model.rootNode match {
       case n: InternalNode => n.split match {
         case s: CategoricalSplit =>
@@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Tests for pruning of redundant subtrees (generated by a split improving the
+  // impurity measure, but always leading to the same prediction).
+  ///////////////////////////////////////////////////////////////////////////////
+
+  test("SPARK-3159 tree model redundancy - classification") {
+    // The following dataset is set up such that splitting over feature_1 for points having
+    // feature_0 = 0 improves the impurity measure, even though the prediction will always be 0
+    // in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val numClasses = 2
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+      numClasses = numClasses, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 5)
+    assert(unprunedTree.numNodes === 7)
+
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
+
+  test("SPARK-3159 tree model redundancy - regression") {
+    // The following dataset is set up such that splitting over feature_0 for points having
+    // feature_1 = 1 improves the impurity measure, even though the prediction will always be 0.5
+    // in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
+      numClasses = 0, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 3)
+    assert(unprunedTree.numNodes === 5)
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
 }
 
 private object RandomForestSuite {
-
   def mapToVec(map: Map[Int, Double]): Vector = {
     val size = (map.keys.toSeq :+ 0).max + 1
     val (indices, values) = map.toSeq.sortBy(_._1).unzip
     Vectors.sparse(size, indices.toArray, values.toArray)
   }
+
+  @tailrec
+  private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
+    if (nodes.isEmpty) {
+      acc
+    }
+    else {
+      nodes.head match {
+        case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
+        case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9e26473c/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 441d0f7..bc59f3f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     // if a split does not satisfy min instances per node requirements,
     // this split is invalid, even though the information gain of split is large.
     val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)))
 
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
     Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
-      if (i < 1000) {
+      if (i < 1001) {
         arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       } else if (i < 2000) {
         arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))

