Repository: spark
Updated Branches:
  refs/heads/master f42eaf42b -> cf823bead


http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 9d92229..361366f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -17,11 +17,17 @@
 
 package org.apache.spark.ml.tree.impl
 
+import scala.collection.mutable
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
+import org.apache.spark.ml.tree._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.collection.OpenHashMap
@@ -33,6 +39,414 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   import RandomForestSuite.mapToVec
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests for split calculation
+  /////////////////////////////////////////////////////////////////////////////
+
+  test("Binary classification with continuous features: split calculation") {
+    val points = OldDTSuite.generateOrderedLabeledPointsWithLabel1()
+    assert(points.length === 1000)
+    val input = sc.parallelize(points)
+    val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    val candidateSplits = RandomForest.findSplits(input, metadata, seed = 42)
+    assert(candidateSplits.length === 2)
+    assert(candidateSplits(0).length === 99)  // maxBins = 100 yields 99 thresholds
+  }
+
+  test("Binary classification with binary (ordered) categorical features: split calculation") {
+    val arr = OldDTSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
+      maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
+
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+    assert(splits.length === 2)
+    // no splits pre-computed for ordered categorical features (both features are checked)
+    assert(splits.map(_.length).toSeq === Seq(0, 0))
+  }
+
+  test("Binary classification with 3-ary (ordered) categorical features," +
+    " with no samples for one category: split calculation") {
+    val arr = OldDTSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
+      maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+    val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+    assert(splits.length === 2)
+    // no splits pre-computed for ordered categorical features (both features are checked)
+    assert(splits.map(_.length).toSeq === Seq(0, 0))
+  }
+
+  test("find splits for a continuous feature") {
+    // find splits for normal case
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(6), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array.fill(200000)(math.random)
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 5)
+      assert(fakeMetadata.numSplits(0) === 5)
+      assert(fakeMetadata.numBins(0) === 6)
+      // check returned splits are distinct
+      assert(splits.distinct.length === splits.length)
+    }
+
+    // find splits should not return identical splits
+    // when there are not enough split candidates, reduce the number of splits in metadata
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(5), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 3)
+      // check returned splits are distinct
+      assert(splits.distinct.length === splits.length)
+    }
+
+    // find splits when most samples close to the minimum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 2)
+      assert(splits(0) === 2.0)
+      assert(splits(1) === 3.0)
+    }
+
+    // find splits when most samples close to the maximum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 1)
+      assert(splits(0) === 1.0)
+    }
+  }
+
+  test("Multiclass classification with unordered categorical features: split calculations") {
+    val arr = OldDTSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new OldStrategy(
+      OldAlgo.Classification,
+      Gini,
+      maxDepth = 2,
+      numClasses = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
+    val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+    assert(splits.length === 2)
+    assert(splits.map(_.length).toSeq === Seq(3, 3))  // 3 pre-computed splits per feature
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
+
+    // Expecting 2^2 - 1 = 3 splits per feature
+    def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
+      assert(s.featureIndex === featureIndex)
+      assert(s.isInstanceOf[CategoricalSplit])
+      val s0 = s.asInstanceOf[CategoricalSplit]
+      assert(s0.leftCategories === leftCategories)
+      assert(s0.numCategories === 3)  // for this unit test
+    }
+    // Feature 0
+    checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
+    checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
+    checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
+    // Feature 1
+    checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
+    checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
+    checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
+  }
+
+  test("Multiclass classification with ordered categorical features: split calculations") {
+    val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    assert(arr.length === 3000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
+      maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    // 2^(10-1) - 1 > 100, so categorical features will be ordered
+
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+    val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+    assert(splits.length === 2)
+    // no splits pre-computed for ordered categorical features (both features are checked)
+    assert(splits.map(_.length).toSeq === Seq(0, 0))
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of other algorithm internals
+  /////////////////////////////////////////////////////////////////////////////
+
+  test("extract categories from a number for multiclass classification") {
+    val categories = RandomForest.extractMultiClassCategories(13, 10)
+    assert(categories.length === 3)
+    assert(categories === Seq(3.0, 2.0, 0.0))  // 13 == 0b1101, i.e. bits 3, 2 and 0 are set
+  }
+
+  test("Avoid aggregation on the last level") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+    val input = sc.parallelize(arr)
+
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val splits = RandomForest.findSplits(input, metadata, seed = 42)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false)
+
+    val topNode = LearningNode.emptyNode(nodeIndex = 1)
+    assert(!topNode.isLeaf)
+    assert(topNode.stats === null)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+    RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+
+    // maxDepth = 1, so the children are on the last level and must not be enqueued
+    assert(nodeQueue.isEmpty)
+
+    // findBestSplits must have populated the stats (impurity) of topNode
+    assert(topNode.stats !== null)
+    assert(topNode.stats.impurity > 0.0)
+
+    // both children must be pure (impurity 0.0) and predict labels 0.0 and 1.0
+    assert(topNode.leftChild.get.toNode.prediction === 0.0)
+    assert(topNode.rightChild.get.toNode.prediction === 1.0)
+    assert(topNode.leftChild.get.stats.impurity === 0.0)
+    assert(topNode.rightChild.get.stats.impurity === 0.0)
+  }
+
+  test("Avoid aggregation if impurity is 0.0") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+    val input = sc.parallelize(arr)
+
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val splits = RandomForest.findSplits(input, metadata, seed = 42)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false)
+
+    val topNode = LearningNode.emptyNode(nodeIndex = 1)
+    assert(!topNode.isLeaf)
+    assert(topNode.stats === null)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+    RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+
+    // maxDepth = 5 allows deeper nodes, but children with impurity 0.0 must not be enqueued
+    assert(nodeQueue.isEmpty)
+
+    // findBestSplits must have populated the stats (impurity) of topNode
+    assert(topNode.stats !== null)
+    assert(topNode.stats.impurity > 0.0)
+
+    // both children must be pure (impurity 0.0) and predict labels 0.0 and 1.0
+    assert(topNode.leftChild.get.toNode.prediction === 0.0)
+    assert(topNode.rightChild.get.toNode.prediction === 1.0)
+    assert(topNode.leftChild.get.stats.impurity === 0.0)
+    assert(topNode.rightChild.get.stats.impurity === 0.0)
+  }
+
+  test("Use soft prediction for binary classification with ordered categorical features") {
+    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(2.0)))
+    val input = sc.parallelize(arr)
+
+    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+    val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 42).head
+    model.rootNode match {
+      case n: InternalNode => n.split match {
+        case s: CategoricalSplit => assert(s.leftCategories === Array(1.0))
+        case other => fail(s"Expected a CategoricalSplit at the root, but got: $other")
+      }
+      case other => fail(s"Expected the root to be an InternalNode, but got: $other")
+    }
+  }
+
+  test("Second level node building with vs. without groups") {
+    val arr = OldDTSuite.generateOrderedLabeledPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    // For tree with 1 group (positional args: maxDepth = 3, numClasses = 2, maxBins = 100)
+    val strategy1 =
+      new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000)
+    // For tree with multiple groups (same parameters, but maxMemoryInMB = 0)
+    val strategy2 =
+      new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0)
+
+    val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 42).head
+    val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 42).head
+
+    def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
+      case n: InternalNode => Array(n.leftChild, n.rightChild).map {
+        case child: InternalNode => child
+        case other => fail(s"Expected an InternalNode child, but got: $other")
+      }
+      case other => fail(s"Expected the root to be an InternalNode, but got: $other")
+    }
+
+    // Second-level nodes of the single-group (1) and multi-group (2) trees.
+    val children1 = getChildren(tree1.rootNode)
+    val children2 = getChildren(tree2.rootNode)
+
+    // Both level-construction strategies must yield identical second-level nodes.
+    for (i <- 0 until 2) {
+      assert(children1(i).gain > 0)
+      assert(children2(i).gain > 0)
+      assert(children1(i).split === children2(i).split)
+      assert(children1(i).impurity === children2(i).impurity)
+      assert(children1(i).impurityStats.stats === children2(i).impurityStats.stats)
+      assert(children1(i).leftChild.impurity === children2(i).leftChild.impurity)
+      assert(children1(i).rightChild.impurity === children2(i).rightChild.impurity)
+      assert(children1(i).prediction === children2(i).prediction)
+    }
+  }
+
+  def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
+    val numFeatures = 50
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
+    val rdd = sc.parallelize(arr)
+
+    // Select feature subsets for the top nodes and assert each gets numFeaturesPerNode features.
+    def checkFeatureSubsetStrategy(
+        numTrees: Int,
+        featureSubsetStrategy: String,
+        numFeaturesPerNode: Int): Unit = {
+      val seeds = Array(123, 5354, 230, 349867, 23987)
+      val maxMemoryUsage: Long = 128 * 1024L * 1024L
+      val metadata =
+        DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+      seeds.foreach { seed =>
+        val failString = s"Failed on test with: " +
+          s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+          s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+        val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+        val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
+        Range(0, numTrees).foreach { treeIndex =>
+          topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
+          nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+        }
+        val rng = new scala.util.Random(seed = seed)
+        val (nodesForGroup: Map[Int, Array[LearningNode]],
+        treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+          RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+        assert(nodesForGroup.size === numTrees, failString)
+        assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
+
+        if (numFeaturesPerNode == numFeatures) {
+          // featureSubset values should all be None
+          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+            failString)
+        } else {
+          // Check number of features.
+          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+            _.featureSubset.get.length === numFeaturesPerNode)), failString)
+        }
+      }
+    }
+
+    checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 1, "log2",
+      (math.log(numFeatures) / math.log(2)).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+    checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "log2",
+      (math.log(numFeatures) / math.log(2)).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+  }
+
+  test("Binary classification with continuous features: subsampling features") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2,
+      numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+  }
+
+  test("Binary classification with continuous features and node Id cache: subsampling features") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2,
+      numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
+    binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+  }
+
   test("computeFeatureImportance, featureImportances") {
     /* Build tree for testing, with this structure:
           grandParent

http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 89b64fc..bb1041b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,430 +18,23 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
 
 
 class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   /////////////////////////////////////////////////////////////////////////////
-  // Tests examining individual elements of training
-  /////////////////////////////////////////////////////////////////////////////
-
-  test("Binary classification with continuous features: split and bin calculation") {
-    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
-    assert(arr.length === 1000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    assert(!metadata.isUnordered(featureIndex = 0))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(bins.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
-  }
-
-  test("Binary classification with binary (ordered) categorical features:" +
-    " split and bin calculation") {
-    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
-    assert(arr.length === 1000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(
-      Classification,
-      Gini,
-      maxDepth = 2,
-      numClasses = 2,
-      maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
-
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(!metadata.isUnordered(featureIndex = 0))
-    assert(!metadata.isUnordered(featureIndex = 1))
-    assert(splits.length === 2)
-    assert(bins.length === 2)
-    // no bins or splits pre-computed for ordered categorical features
-    assert(splits(0).length === 0)
-    assert(bins(0).length === 0)
-  }
-
-  test("Binary classification with 3-ary (ordered) categorical features," +
-    " with no samples for one category") {
-    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
-    assert(arr.length === 1000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(
-      Classification,
-      Gini,
-      maxDepth = 2,
-      numClasses = 2,
-      maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    assert(!metadata.isUnordered(featureIndex = 0))
-    assert(!metadata.isUnordered(featureIndex = 1))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(bins.length === 2)
-    // no bins or splits pre-computed for ordered categorical features
-    assert(splits(0).length === 0)
-    assert(bins(0).length === 0)
-  }
-
-  test("extract categories from a number for multiclass classification") {
-    val l = DecisionTree.extractMultiClassCategories(13, 10)
-    assert(l.length === 3)
-    assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
-  }
-
-  test("find splits for a continuous feature") {
-    // find splits for normal case
-    {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
-        Map(), Set(),
-        Array(6), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
-      )
-      val featureSamples = Array.fill(200000)(math.random)
-      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 5)
-      assert(fakeMetadata.numSplits(0) === 5)
-      assert(fakeMetadata.numBins(0) === 6)
-      // check returned splits are distinct
-      assert(splits.distinct.length === splits.length)
-    }
-
-    // find splits should not return identical splits
-    // when there are not enough split candidates, reduce the number of splits in metadata
-    {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
-        Map(), Set(),
-        Array(5), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
-      )
-      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
-      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 3)
-      // check returned splits are distinct
-      assert(splits.distinct.length === splits.length)
-    }
-
-    // find splits when most samples close to the minimum
-    {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
-        Map(), Set(),
-        Array(3), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
-      )
-      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
-      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 2)
-      assert(splits(0) === 2.0)
-      assert(splits(1) === 3.0)
-    }
-
-    // find splits when most samples close to the maximum
-    {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
-        Map(), Set(),
-        Array(3), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
-      )
-      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
-      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 1)
-      assert(splits(0) === 1.0)
-    }
-  }
-
-  test("Multiclass classification with unordered categorical features:" +
-      " split and bin calculations") {
-    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
-    assert(arr.length === 1000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(
-      Classification,
-      Gini,
-      maxDepth = 2,
-      numClasses = 100,
-      maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    assert(metadata.isUnordered(featureIndex = 0))
-    assert(metadata.isUnordered(featureIndex = 1))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(bins.length === 2)
-    assert(splits(0).length === 3)
-    assert(bins(0).length === 0)
-    assert(metadata.numSplits(0) === 3)
-    assert(metadata.numBins(0) === 3)
-    assert(metadata.numSplits(1) === 3)
-    assert(metadata.numBins(1) === 3)
-
-    // Expecting 2^2 - 1 = 3 bins/splits
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(0.0))
-    assert(splits(1)(0).feature === 1)
-    assert(splits(1)(0).threshold === Double.MinValue)
-    assert(splits(1)(0).featureType === Categorical)
-    assert(splits(1)(0).categories.length === 1)
-    assert(splits(1)(0).categories.contains(0.0))
-
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 1)
-    assert(splits(0)(1).categories.contains(1.0))
-    assert(splits(1)(1).feature === 1)
-    assert(splits(1)(1).threshold === Double.MinValue)
-    assert(splits(1)(1).featureType === Categorical)
-    assert(splits(1)(1).categories.length === 1)
-    assert(splits(1)(1).categories.contains(1.0))
-
-    assert(splits(0)(2).feature === 0)
-    assert(splits(0)(2).threshold === Double.MinValue)
-    assert(splits(0)(2).featureType === Categorical)
-    assert(splits(0)(2).categories.length === 2)
-    assert(splits(0)(2).categories.contains(0.0))
-    assert(splits(0)(2).categories.contains(1.0))
-    assert(splits(1)(2).feature === 1)
-    assert(splits(1)(2).threshold === Double.MinValue)
-    assert(splits(1)(2).featureType === Categorical)
-    assert(splits(1)(2).categories.length === 2)
-    assert(splits(1)(2).categories.contains(0.0))
-    assert(splits(1)(2).categories.contains(1.0))
-
-  }
-
-  test("Multiclass classification with ordered categorical features: split and bin calculations") {
-    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
-    assert(arr.length === 3000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(
-      Classification,
-      Gini,
-      maxDepth = 2,
-      numClasses = 100,
-      maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
-    // 2^(10-1) - 1 > 100, so categorical features will be ordered
-
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    assert(!metadata.isUnordered(featureIndex = 0))
-    assert(!metadata.isUnordered(featureIndex = 1))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(bins.length === 2)
-    // no bins or splits pre-computed for ordered categorical features
-    assert(splits(0).length === 0)
-    assert(bins(0).length === 0)
-  }
-
-  test("Avoid aggregation on the last level") {
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
-    val input = sc.parallelize(arr)
-
-    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
-      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
-    val topNode = Node.emptyNode(nodeIndex = 1)
-    assert(topNode.predict.predict === Double.MinValue)
-    assert(topNode.impurity === -1.0)
-    assert(topNode.isLeaf === false)
-
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
-    val nodeQueue = new mutable.Queue[(Int, Node)]()
-    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
-      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
-    // don't enqueue leaf nodes into node queue
-    assert(nodeQueue.isEmpty)
-
-    // set impurity and predict for topNode
-    assert(topNode.predict.predict !== Double.MinValue)
-    assert(topNode.impurity !== -1.0)
-
-    // set impurity and predict for child nodes
-    assert(topNode.leftNode.get.predict.predict === 0.0)
-    assert(topNode.rightNode.get.predict.predict === 1.0)
-    assert(topNode.leftNode.get.impurity === 0.0)
-    assert(topNode.rightNode.get.impurity === 0.0)
-  }
-
-  test("Avoid aggregation if impurity is 0.0") {
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
-    val input = sc.parallelize(arr)
-
-    val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 5,
-      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
-    val topNode = Node.emptyNode(nodeIndex = 1)
-    assert(topNode.predict.predict === Double.MinValue)
-    assert(topNode.impurity === -1.0)
-    assert(topNode.isLeaf === false)
-
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
-    val nodeQueue = new mutable.Queue[(Int, Node)]()
-    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
-      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
-    // don't enqueue a node into node queue if its impurity is 0.0
-    assert(nodeQueue.isEmpty)
-
-    // set impurity and predict for topNode
-    assert(topNode.predict.predict !== Double.MinValue)
-    assert(topNode.impurity !== -1.0)
-
-    // set impurity and predict for child nodes
-    assert(topNode.leftNode.get.predict.predict === 0.0)
-    assert(topNode.rightNode.get.predict.predict === 1.0)
-    assert(topNode.leftNode.get.impurity === 0.0)
-    assert(topNode.rightNode.get.impurity === 0.0)
-  }
-
-  test("Use soft prediction for binary classification with ordered categorical 
features") {
-    // The following dataset is set up such that the best split is {1} vs. {0, 
2}.
-    // If the hard prediction is used to order the categories, then {0} vs. 
{1, 2} is chosen.
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(1.0, Vectors.dense(2.0)))
-    val input = sc.parallelize(arr)
-
-    // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
-    val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 1,
-      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-
-    val model = new DecisionTree(strategy).run(input)
-    model.topNode.split.get match {
-      case Split(_, _, _, categories: List[Double]) =>
-        assert(categories === List(1.0))
-    }
-  }
-
-  test("Second level node building with vs. without groups") {
-    val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
-    assert(arr.length === 1000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins.length === 2)
-    assert(bins(0).length === 100)
-
-    // Train a 1-node model
-    val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
-      numClasses = 2, maxBins = 100)
-    val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val rootNode1 = modelOneNode.topNode.deepCopy()
-    val rootNode2 = modelOneNode.topNode.deepCopy()
-    assert(rootNode1.leftNode.nonEmpty)
-    assert(rootNode1.rightNode.nonEmpty)
-
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
-    // Single group second level tree construction.
-    val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, 
rootNode1.rightNode.get)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
-      (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
-    val nodeQueue = new mutable.Queue[(Int, Node)]()
-    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
-      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-    val children1 = new Array[Node](2)
-    children1(0) = rootNode1.leftNode.get
-    children1(1) = rootNode1.rightNode.get
-
-    // Train one second-level node at a time.
-    val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
-    val treeToNodeToIndexInfoA = Map((0, Map(
-      (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
-    nodeQueue.clear()
-    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
-      nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
-    val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
-    val treeToNodeToIndexInfoB = Map((0, Map(
-      (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
-    nodeQueue.clear()
-    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
-      nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
-    val children2 = new Array[Node](2)
-    children2(0) = rootNode2.leftNode.get
-    children2(1) = rootNode2.rightNode.get
-
-    // Verify whether the splits obtained using single group and multiple 
group level
-    // construction strategies are the same.
-    for (i <- 0 until 2) {
-      assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
-      assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
-      assert(children1(i).split === children2(i).split)
-      assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
-      val stats1 = children1(i).stats.get
-      val stats2 = children2(i).stats.get
-      assert(stats1.gain === stats2.gain)
-      assert(stats1.impurity === stats2.impurity)
-      assert(stats1.leftImpurity === stats2.leftImpurity)
-      assert(stats1.rightImpurity === stats2.rightImpurity)
-      assert(children1(i).predict.predict === children2(i).predict.predict)
-    }
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
   // Tests calling train()
   /////////////////////////////////////////////////////////////////////////////
 
@@ -457,22 +50,11 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
 
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
-    assert(!metadata.isUnordered(featureIndex = 0))
-    assert(!metadata.isUnordered(featureIndex = 1))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(bins.length === 2)
-    // no bins or splits pre-computed for ordered categorical features
-    assert(splits(0).length === 0)
-    assert(bins(0).length === 0)
-
     val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.categories === List(1.0))
     assert(split.featureType === Categorical)
-    assert(split.threshold === Double.MinValue)
 
     val stats = rootNode.stats.get
     assert(stats.gain > 0)
@@ -501,7 +83,6 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(split.categories.length === 1)
     assert(split.categories.contains(1.0))
     assert(split.featureType === Categorical)
-    assert(split.threshold === Double.MinValue)
 
     val stats = rootNode.stats.get
     assert(stats.gain > 0)
@@ -539,18 +120,11 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins.length === 2)
-    assert(bins(0).length === 100)
-
     val rootNode = DecisionTree.train(rdd, strategy).topNode
 
-    val stats = rootNode.stats.get
-    assert(stats.gain === 0)
-    assert(stats.leftImpurity === 0)
-    assert(stats.rightImpurity === 0)
+    assert(rootNode.impurity === 0)
+    assert(rootNode.stats.isEmpty)
+    assert(rootNode.predict.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Gini") {
@@ -563,18 +137,10 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins.length === 2)
-    assert(bins(0).length === 100)
-
     val rootNode = DecisionTree.train(rdd, strategy).topNode
 
-    val stats = rootNode.stats.get
-    assert(stats.gain === 0)
-    assert(stats.leftImpurity === 0)
-    assert(stats.rightImpurity === 0)
+    assert(rootNode.impurity === 0)
+    assert(rootNode.stats.isEmpty)
     assert(rootNode.predict.predict === 1)
   }
 
@@ -588,18 +154,10 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins.length === 2)
-    assert(bins(0).length === 100)
-
     val rootNode = DecisionTree.train(rdd, strategy).topNode
 
-    val stats = rootNode.stats.get
-    assert(stats.gain === 0)
-    assert(stats.leftImpurity === 0)
-    assert(stats.rightImpurity === 0)
+    assert(rootNode.impurity === 0)
+    assert(rootNode.stats.isEmpty)
     assert(rootNode.predict.predict === 0)
   }
 
@@ -613,18 +171,10 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    assert(splits.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins.length === 2)
-    assert(bins(0).length === 100)
-
     val rootNode = DecisionTree.train(rdd, strategy).topNode
 
-    val stats = rootNode.stats.get
-    assert(stats.gain === 0)
-    assert(stats.leftImpurity === 0)
-    assert(stats.rightImpurity === 0)
+    assert(rootNode.impurity === 0)
+    assert(rootNode.stats.isEmpty)
     assert(rootNode.predict.predict === 1)
   }
 
@@ -718,7 +268,6 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClasses = 3, maxBins = 100)
     assert(strategy.isMulticlassClassification)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
     val model = DecisionTree.train(rdd, strategy)
     DecisionTreeSuite.validateClassifier(model, arr, 0.9)
@@ -807,8 +356,7 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     // test when no valid split can be found
     val rootNode = model.topNode
 
-    val gain = rootNode.stats.get
-    assert(gain == InformationGainStats.invalidInformationGainStats)
+    assert(rootNode.stats.isEmpty)
   }
 
   test("do not choose split that does not satisfy min instance per node 
requirements") {
@@ -828,9 +376,10 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
-    val gain = rootNode.stats.get
+    val gainStats = rootNode.stats.get
     assert(split.feature == 1)
-    assert(gain != InformationGainStats.invalidInformationGainStats)
+    assert(gainStats.gain >= 0)
+    assert(gainStats.impurity >= 0)
   }
 
   test("split must satisfy min info gain requirements") {
@@ -852,10 +401,7 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     }
 
     // test when no valid split can be found
-    val rootNode = model.topNode
-
-    val gain = rootNode.stats.get
-    assert(gain == InformationGainStats.invalidInformationGainStats)
+    assert(model.topNode.stats.isEmpty)
   }
 
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index c72fc9b..bec61ba 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.mllib.tree
 
-import scala.collection.mutable
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.tree.model.RandomForestModel
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.Utils
 
@@ -42,7 +39,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
     val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
-    assert(rf.trees.size === 1)
+    assert(rf.trees.length === 1)
     val rfTree = rf.trees(0)
 
     val dt = DecisionTree.train(rdd, strategy)
@@ -78,7 +75,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
-    assert(rf.trees.size === 1)
+    assert(rf.trees.length === 1)
     val rfTree = rf.trees(0)
 
     val dt = DecisionTree.train(rdd, strategy)
@@ -108,80 +105,6 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     regressionTestWithContinuousFeatures(strategy)
   }
 
-  def 
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: 
Strategy) {
-    val numFeatures = 50
-    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 
1000)
-    val rdd = sc.parallelize(arr)
-
-    // Select feature subset for top nodes.  Return true if OK.
-    def checkFeatureSubsetStrategy(
-        numTrees: Int,
-        featureSubsetStrategy: String,
-        numFeaturesPerNode: Int): Unit = {
-      val seeds = Array(123, 5354, 230, 349867, 23987)
-      val maxMemoryUsage: Long = 128 * 1024L * 1024L
-      val metadata =
-        DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, 
featureSubsetStrategy)
-      seeds.foreach { seed =>
-        val failString = s"Failed on test with:" +
-          s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," 
+
-          s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
-        val nodeQueue = new mutable.Queue[(Int, Node)]()
-        val topNodes: Array[Node] = new Array[Node](numTrees)
-        Range(0, numTrees).foreach { treeIndex =>
-          topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
-          nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
-        }
-        val rng = new scala.util.Random(seed = seed)
-        val (nodesForGroup: Map[Int, Array[Node]],
-            treeToNodeToIndexInfo: Map[Int, Map[Int, 
RandomForest.NodeIndexInfo]]) =
-          RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, 
rng)
-
-        assert(nodesForGroup.size === numTrees, failString)
-        assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node 
per tree
-
-        if (numFeaturesPerNode == numFeatures) {
-          // featureSubset values should all be None
-          
assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
-            failString)
-        } else {
-          // Check number of features.
-          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
-            _.featureSubset.get.size === numFeaturesPerNode)), failString)
-        }
-      }
-    }
-
-    checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
-    checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
-    checkFeatureSubsetStrategy(numTrees = 1, "sqrt", 
math.sqrt(numFeatures).ceil.toInt)
-    checkFeatureSubsetStrategy(numTrees = 1, "log2",
-      (math.log(numFeatures) / math.log(2)).ceil.toInt)
-    checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 
3.0).ceil.toInt)
-
-    checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
-    checkFeatureSubsetStrategy(numTrees = 2, "auto", 
math.sqrt(numFeatures).ceil.toInt)
-    checkFeatureSubsetStrategy(numTrees = 2, "sqrt", 
math.sqrt(numFeatures).ceil.toInt)
-    checkFeatureSubsetStrategy(numTrees = 2, "log2",
-      (math.log(numFeatures) / math.log(2)).ceil.toInt)
-    checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 
3.0).ceil.toInt)
-  }
-
-  test("Binary classification with continuous features: subsampling features") 
{
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
-    val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 2,
-      numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
-    
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
-  }
-
-  test("Binary classification with continuous features and node Id cache: 
subsampling features") {
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
-    val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 2,
-      numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
-      useNodeIdCache = true)
-    
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
-  }
-
   test("alternating categorical and continuous features with multiclass labels 
to test indexing") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))

http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py 
b/python/pyspark/ml/param/_shared_params_code_gen.py
index 7dd2937..715fa9e 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -164,7 +164,8 @@ if __name__ == "__main__":
          "split will be discarded as invalid. Should be >= 1.", 
"TypeConverters.toInt"),
         ("minInfoGain", "Minimum information gain for a split to be considered 
at a tree node.",
          "TypeConverters.toFloat"),
-        ("maxMemoryInMB", "Maximum memory in MB allocated to histogram 
aggregation.",
+        ("maxMemoryInMB", "Maximum memory in MB allocated to histogram 
aggregation. If too small," +
+         " then 1 node will be split per iteration, and its aggregates may 
exceed this size.",
          "TypeConverters.toInt"),
         ("cacheNodeIds", "If false, the algorithm will pass trees to executors 
to match " +
          "instances with nodes. If true, the algorithm will cache node IDs for 
each instance. " +

http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py 
b/python/pyspark/ml/param/shared.py
index 83fbd59..d79d55e 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -568,7 +568,7 @@ class DecisionTreeParams(Params):
     maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for 
discretizing continuous features.  Must be >=2 and >= number of categories for 
any categorical feature.", typeConverter=TypeConverters.toInt)
     minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", 
"Minimum number of instances each child must have after split. If a split 
causes the left or right child to have fewer than minInstancesPerNode, the 
split will be discarded as invalid. Should be >= 1.", 
typeConverter=TypeConverters.toInt)
     minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information 
gain for a split to be considered at a tree node.", 
typeConverter=TypeConverters.toFloat)
-    maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in 
MB allocated to histogram aggregation.", typeConverter=TypeConverters.toInt)
+    maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in 
MB allocated to histogram aggregation. If too small, then 1 node will be split 
per iteration, and its aggregates may exceed this size.", 
typeConverter=TypeConverters.toInt)
     cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the 
algorithm will pass trees to executors to match instances with nodes. If true, 
the algorithm will cache node IDs for each instance. Caching can speed up 
training of deeper trees. Users can set how often should the cache be 
checkpointed or disable it by setting checkpointInterval.", 
typeConverter=TypeConverters.toBoolean)
     
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to