Repository: spark Updated Branches: refs/heads/master f42eaf42b -> cf823bead
http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 9d92229..361366f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.ml.tree.impl +import scala.collection.mutable + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.ml.tree._ import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.collection.OpenHashMap @@ -33,6 +39,414 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + ///////////////////////////////////////////////////////////////////////////// + // Tests for split calculation + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification with continuous features: split calculation") { + val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + assert(splits(0).length === 99) + } + + test("Binary classification with binary (ordered) categorical features: split calculation") { + val arr = OldDTSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category: split calculation") { + val arr = OldDTSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + test("find splits for a continuous feature") { + // find splits for normal case + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array.fill(200000)(math.random) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 5) + assert(fakeMetadata.numSplits(0) === 5) + assert(fakeMetadata.numBins(0) === 6) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits should not return identical splits + // when there are not enough split candidates, reduce the number of splits in metadata + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 3) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits when most samples close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 2) + assert(splits(0) === 2.0) + assert(splits(1) === 3.0) + } + + // find splits when most samples close to the maximum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 1) + assert(splits(0) === 1.0) + } + } + + test("Multiclass classification with unordered categorical features: split calculations") { + val arr = OldDTSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + assert(splits(0).length === 3) + assert(metadata.numSplits(0) === 3) + assert(metadata.numBins(0) === 3) + assert(metadata.numSplits(1) === 3) + assert(metadata.numBins(1) === 3) + + // Expecting 2^2 - 1 = 3 splits per feature + def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = { + assert(s.featureIndex === featureIndex) + assert(s.isInstanceOf[CategoricalSplit]) + val s0 = s.asInstanceOf[CategoricalSplit] + assert(s0.leftCategories === leftCategories) + assert(s0.numCategories === 3) // for this unit test + } + // Feature 0 + checkCategoricalSplit(splits(0)(0), 0, Array(0.0)) + checkCategoricalSplit(splits(0)(1), 0, Array(1.0)) + checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0)) + // Feature 1 + checkCategoricalSplit(splits(1)(0), 1, Array(0.0)) + checkCategoricalSplit(splits(1)(1), 1, Array(1.0)) + checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0)) + } + + test("Multiclass classification with ordered categorical features: split calculations") { + val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + // 2^(10-1) - 1 > 100, so categorical features will be ordered + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of other algorithm internals + ///////////////////////////////////////////////////////////////////////////// + + test("extract categories from a number for multiclass classification") { + val l = RandomForest.extractMultiClassCategories(13, 10) + assert(l.length === 3) + assert(Seq(3.0, 2.0, 0.0) === l) + } + + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = RandomForest.findSplits(input, metadata, seed = 42) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false) + + val topNode = LearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.stats !== null) + assert(topNode.stats.impurity > 0.0) + + // set impurity and predict for child nodes + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.stats.impurity === 0.0) + assert(topNode.rightChild.get.stats.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = RandomForest.findSplits(input, metadata, seed = 42) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false) + + val topNode = LearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.stats !== null) + assert(topNode.stats.impurity > 0.0) + + // set impurity and predict for child nodes + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.stats.impurity === 0.0) + assert(topNode.rightChild.get.stats.impurity === 0.0) + } + + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val input = sc.parallelize(arr) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) + + val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 42).head + model.rootNode match { + case n: InternalNode => n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + } + } + } + + test("Second level node building with vs. without groups") { + val arr = OldDTSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + // For tree with 1 group + val strategy1 = + new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000) + // For tree with multiple groups + val strategy2 = + new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0) + + val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", + seed = 42).head + val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", + seed = 42).head + + def getChildren(rootNode: Node): Array[InternalNode] = rootNode match { + case n: InternalNode => + assert(n.leftChild.isInstanceOf[InternalNode]) + assert(n.rightChild.isInstanceOf[InternalNode]) + Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + } + + // Single group second level tree construction. + val children1 = getChildren(tree1.rootNode) + val children2 = getChildren(tree2.rootNode) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).gain > 0) + assert(children2(i).gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).impurity === children2(i).impurity) + assert(children1(i).impurityStats.stats === children2(i).impurityStats.stats) + assert(children1(i).leftChild.impurity === children2(i).leftChild.impurity) + assert(children1(i).rightChild.impurity === children2(i).rightChild.impurity) + assert(children1(i).prediction === children2(i).prediction) + } + } + + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { + val numFeatures = 50 + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) + val rdd = sc.parallelize(arr) + + // Select feature subset for top nodes. Return true if OK. + def checkFeatureSubsetStrategy( + numTrees: Int, + featureSubsetStrategy: String, + numFeaturesPerNode: Int): Unit = { + val seeds = Array(123, 5354, 230, 349867, 23987) + val maxMemoryUsage: Long = 128 * 1024L * 1024L + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + seeds.foreach { seed => + val failString = s"Failed on test with:" + + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees) + Range(0, numTrees).foreach { treeIndex => + topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1) + nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + } + val rng = new scala.util.Random(seed = seed) + val (nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + + assert(nodesForGroup.size === numTrees, failString) + assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree + + if (numFeaturesPerNode == numFeatures) { + // featureSubset values should all be None + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + failString) + } else { + // Check number of features. + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.length === numFeaturesPerNode)), failString) + } + } + } + + checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + } + + test("Binary classification with continuous features: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + test("computeFeatureImportance, featureImportances") { /* Build tree for testing, with this structure: grandParent http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 89b64fc..bb1041b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -18,430 +18,23 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ -import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// - // Tests examining individual elements of training - ///////////////////////////////////////////////////////////////////////////// - - test("Binary classification with continuous features: split and bin calculation") { - val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - assert(splits(0).length === 99) - assert(bins(0).length === 100) - } - - test("Binary classification with binary (ordered) categorical features:" + - " split and bin calculation") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - } - - test("Binary classification with 3-ary (ordered) categorical features," + - " with no samples for one category") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - } - - test("extract categories from a number for multiclass classification") { - val l = DecisionTree.extractMultiClassCategories(13, 10) - assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) - } - - test("find splits for a continuous feature") { - // find splits for normal case - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array.fill(200000)(math.random) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 5) - assert(fakeMetadata.numSplits(0) === 5) - assert(fakeMetadata.numBins(0) === 6) - // check returned splits are distinct - assert(splits.distinct.length === splits.length) - } - - // find splits should not return identical splits - // when there are not enough split candidates, reduce the number of splits in metadata - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 3) - // check returned splits are distinct - assert(splits.distinct.length === splits.length) - } - - // find splits when most samples close to the minimum - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 2) - assert(splits(0) === 2.0) - assert(splits(1) === 3.0) - } - - // find splits when most samples close to the maximum - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 1) - assert(splits(0) === 1.0) - } - } - - test("Multiclass classification with unordered categorical features:" + - " split and bin calculations") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 100, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(metadata.isUnordered(featureIndex = 0)) - assert(metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - assert(splits(0).length === 3) - assert(bins(0).length === 0) - assert(metadata.numSplits(0) === 3) - assert(metadata.numBins(0) === 3) - assert(metadata.numSplits(1) === 3) - assert(metadata.numBins(1) === 3) - - // Expecting 2^2 - 1 = 3 bins/splits - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(0.0)) - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 1) - assert(splits(0)(1).categories.contains(1.0)) - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 1) - assert(splits(1)(1).categories.contains(1.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 2) - assert(splits(0)(2).categories.contains(0.0)) - assert(splits(0)(2).categories.contains(1.0)) - assert(splits(1)(2).feature === 1) - assert(splits(1)(2).threshold === Double.MinValue) - assert(splits(1)(2).featureType === Categorical) - assert(splits(1)(2).categories.length === 2) - assert(splits(1)(2).categories.contains(0.0)) - assert(splits(1)(2).categories.contains(1.0)) - - } - - test("Multiclass classification with ordered categorical features: split and bin calculations") { - val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - assert(arr.length === 3000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 100, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) - // 2^(10-1) - 1 > 100, so categorical features will be ordered - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - } - - test("Avoid aggregation on the last level") { - val arr = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Avoid aggregation if impurity is 0.0") { - val arr = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. - val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val input = sc.parallelize(arr) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) - - val model = new DecisionTree(strategy).run(input) - model.topNode.split.get match { - case Split(_, _, _, categories: List[Double]) => - assert(categories === List(1.0)) - } - } - - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. - val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - - ///////////////////////////////////////////////////////////////////////////// // Tests calling train() ///////////////////////////////////////////////////////////////////////////// @@ -457,22 +50,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories === List(1.0)) assert(split.featureType === Categorical) - assert(split.threshold === Double.MinValue) val stats = rootNode.stats.get assert(stats.gain > 0) @@ -501,7 +83,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(split.categories.length === 1) assert(split.categories.contains(1.0)) assert(split.featureType === Categorical) - assert(split.threshold === Double.MinValue) val stats = rootNode.stats.get assert(stats.gain > 0) @@ -539,18 +120,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Gini") { @@ -563,18 +137,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) assert(rootNode.predict.predict === 1) } @@ -588,18 +154,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) assert(rootNode.predict.predict === 0) } @@ -613,18 +171,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) assert(rootNode.predict.predict === 1) } @@ -718,7 +268,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) DecisionTreeSuite.validateClassifier(model, arr, 0.9) @@ -807,8 +356,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // test when no valid split can be found val rootNode = model.topNode - val gain = rootNode.stats.get - assert(gain == InformationGainStats.invalidInformationGainStats) + assert(rootNode.stats.isEmpty) } test("do not choose split that does not satisfy min instance per node requirements") { @@ -828,9 +376,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get - val gain = rootNode.stats.get + val gainStats = rootNode.stats.get assert(split.feature == 1) - assert(gain != InformationGainStats.invalidInformationGainStats) + assert(gainStats.gain >= 0) + assert(gainStats.impurity >= 0) } test("split must satisfy min info gain requirements") { @@ -852,10 +401,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { } // test when no valid split can be found - val rootNode = model.topNode - - val gain = rootNode.stats.get - assert(gain == InformationGainStats.invalidInformationGainStats) + assert(model.topNode.stats.isEmpty) } ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index c72fc9b..bec61ba 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.mllib.tree -import scala.collection.mutable - import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils @@ -42,7 +39,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) + assert(rf.trees.length === 1) val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -78,7 +75,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) + assert(rf.trees.length === 1) val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -108,80 +105,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { regressionTestWithContinuousFeatures(strategy) } - def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) { - val numFeatures = 50 - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) - val rdd = sc.parallelize(arr) - - // Select feature subset for top nodes. Return true if OK. - def checkFeatureSubsetStrategy( - numTrees: Int, - featureSubsetStrategy: String, - numFeaturesPerNode: Int): Unit = { - val seeds = Array(123, 5354, 230, 349867, 23987) - val maxMemoryUsage: Long = 128 * 1024L * 1024L - val metadata = - DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) - seeds.foreach { seed => - val failString = s"Failed on test with:" + - s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + - s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" - val nodeQueue = new mutable.Queue[(Int, Node)]() - val topNodes: Array[Node] = new Array[Node](numTrees) - Range(0, numTrees).foreach { treeIndex => - topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1) - nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) - } - val rng = new scala.util.Random(seed = seed) - val (nodesForGroup: Map[Int, Array[Node]], - treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) - - assert(nodesForGroup.size === numTrees, failString) - assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree - - if (numFeaturesPerNode == numFeatures) { - // featureSubset values should all be None - assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), - failString) - } else { - // Check number of features. - assert(treeToNodeToIndexInfo.values.forall(_.values.forall( - _.featureSubset.get.size === numFeaturesPerNode)), failString) - } - } - } - - checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) - checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) - checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 1, "log2", - (math.log(numFeatures) / math.log(2)).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) - - checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) - checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "log2", - (math.log(numFeatures) / math.log(2)).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) - } - - test("Binary classification with continuous features: subsampling features") { - val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) - binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) - } - - test("Binary classification with continuous features and node Id cache: subsampling features") { - val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - useNodeIdCache = true) - binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) - } - test("alternating categorical and continuous features with multiclass labels to test indexing") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/python/pyspark/ml/param/_shared_params_code_gen.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 7dd2937..715fa9e 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -164,7 +164,8 @@ if __name__ == "__main__": "split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"), ("minInfoGain", "Minimum information gain for a split to be considered at a tree node.", "TypeConverters.toFloat"), - ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", + ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation. If too small," + + " then 1 node will be split per iteration, and its aggregates may exceed this size.", "TypeConverters.toInt"), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + http://git-wip-us.apache.org/repos/asf/spark/blob/cf823bea/python/pyspark/ml/param/shared.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 83fbd59..d79d55e 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -568,7 +568,7 @@ class DecisionTreeParams(Params): maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.", typeConverter=TypeConverters.toInt) minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.", typeConverter=TypeConverters.toInt) minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.", typeConverter=TypeConverters.toFloat) - maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", typeConverter=TypeConverters.toInt) + maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.", typeConverter=TypeConverters.toInt) cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.", typeConverter=TypeConverters.toBoolean) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
