Repository: spark
Updated Branches:
  refs/heads/master 0d1cc4ae4 -> 711356b42


http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
new file mode 100644
index 0000000..866d85a
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.apache.spark.mllib.tree.impurity._
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ */
+private[tree] class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata,
+    val numNodes: Int) extends Serializable {
+
+  /**
+   * [[ImpurityAggregator]] instance specifying the impurity type.
+   */
+  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
+    case Gini => new GiniAggregator(metadata.numClasses)
+    case Entropy => new EntropyAggregator(metadata.numClasses)
+    case Variance => new VarianceAggregator()
+    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: 
${metadata.impurity}")
+  }
+
+  /**
+   * Number of elements (Double values) used for the sufficient statistics of 
each bin.
+   */
+  val statsSize: Int = impurityAggregator.statsSize
+
+  val numFeatures: Int = metadata.numFeatures
+
+  /**
+   * Number of bins for each feature.  This is indexed by the feature index.
+   */
+  val numBins: Array[Int] = metadata.numBins
+
+  /**
+   * Number of splits for the given feature.
+   */
+  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
+
+  /**
+   * Indicator for each feature of whether that feature is an unordered 
feature.
+   * TODO: Is Array[Boolean] any faster?
+   */
+  def isUnordered(featureIndex: Int): Boolean = 
metadata.isUnordered(featureIndex)
+
+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] 
array.
+   */
+  private val featureOffsets: Array[Int] = {
+    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
+      if (isUnordered(featureIndex)) {
+        total + 2 * numBins(featureIndex)
+      } else {
+        total + numBins(featureIndex)
+      }
+    }
+    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * 
_).toArray
+  }
+
+  /**
+   * Number of elements for each node, corresponding to stride between nodes 
in [[allStats]].
+   */
+  private val nodeStride: Int = featureOffsets.last
+
+  /**
+   * Total number of elements stored in this aggregator.
+   */
+  val allStatsSize: Int = numNodes * nodeStride
+
+  /**
+   * Flat array of elements.
+   * Index for start of stats for a (node, feature, bin) is:
+   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + 
binIndex * statsSize
+   * Note: For unordered features, the left child stats have binIndex in [0, 
numBins(featureIndex))
+   *       and the right child stats in [numBins(featureIndex), 2 * 
numBins(featureIndex))
+   */
+  val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  /**
+   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+   * @param nodeFeatureOffset  For ordered features, this is a pre-computed 
(node, feature) offset
+   *                           from [[getNodeFeatureOffset]].
+   *                           For unordered features, this is a pre-computed
+   *                           (node, feature, left/right child) offset from
+   *                           [[getLeftRightNodeFeatureOffsets]].
+   */
+  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): 
ImpurityCalculator = {
+    impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * 
statsSize)
+  }
+
+  /**
+   * Update the stats for a given (node, feature, bin) for ordered features, 
using the given label.
+   */
+  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): 
Unit = {
+    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * 
statsSize
+    impurityAggregator.update(allStats, i, label)
+  }
+
+  /**
+   * Pre-compute node offset for use with [[nodeUpdate]].
+   */
+  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+  /**
+   * Faster version of [[update]].
+   * Update the stats for a given (node, feature, bin) for ordered features, 
using the given label.
+   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
+   */
+  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: 
Double): Unit = {
+    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label)
+  }
+
+  /**
+   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * For ordered features only.
+   */
+  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+    require(!isUnordered(featureIndex),
+      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, 
but was called" +
+      s" for unordered feature $featureIndex.")
+    nodeIndex * nodeStride + featureOffsets(featureIndex)
+  }
+
+  /**
+   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * For unordered features only.
+   */
+  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, 
Int) = {
+    require(isUnordered(featureIndex),
+      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered 
features only," +
+      s" but was called for ordered feature $featureIndex.")
+    val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
+    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+  }
+
+  /**
+   * Faster version of [[update]].
+   * Update the stats for a given (node, feature, bin), using the given label.
+   * @param nodeFeatureOffset  For ordered features, this is a pre-computed 
(node, feature) offset
+   *                           from [[getNodeFeatureOffset]].
+   *                           For unordered features, this is a pre-computed
+   *                           (node, feature, left/right child) offset from
+   *                           [[getLeftRightNodeFeatureOffsets]].
+   */
+  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): 
Unit = {
+    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * 
statsSize, label)
+  }
+
+  /**
+   * For a given (node, feature), merge the stats for two bins.
+   * @param nodeFeatureOffset  For ordered features, this is a pre-computed 
(node, feature) offset
+   *                           from [[getNodeFeatureOffset]].
+   *                           For unordered features, this is a pre-computed
+   *                           (node, feature, left/right child) offset from
+   *                           [[getLeftRightNodeFeatureOffsets]].
+   * @param binIndex  The other bin is merged into this bin.
+   * @param otherBinIndex  This bin is not modified.
+   */
+  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, 
otherBinIndex: Int): Unit = {
+    impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * 
statsSize,
+      nodeFeatureOffset + otherBinIndex * statsSize)
+  }
+
+  /**
+   * Merge this aggregator with another, returning this aggregator.
+   * This method modifies this aggregator in-place.
+   */
+  def merge(other: DTStatsAggregator): DTStatsAggregator = {
+    require(allStatsSize == other.allStatsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same 
length stats vectors."
+      + s" This aggregator is of length $allStatsSize, but the other is 
${other.allStatsSize}.")
+    var i = 0
+    // TODO: Test BLAS.axpy
+    while (i < allStatsSize) {
+      allStats(i) += other.allStats(i)
+      i += 1
+    }
+    this
+  }
+
+}
+
+private[tree] object DTStatsAggregator extends Serializable {
+
+  /**
+   * Combines two aggregates (modifying the first) and returns the combination.
+   */
+  def binCombOp(
+      agg1: DTStatsAggregator,
+      agg2: DTStatsAggregator): DTStatsAggregator = {
+    agg1.merge(agg2)
+  }
+
+}
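
To make the flat indexing concrete: the index formula above can be reproduced with a small
self-contained sketch. The configuration below (statsSize = 2, two ordered features with 3
and 4 bins) is hypothetical and only mirrors the arithmetic of featureOffsets / nodeStride;
it does not call the private DTStatsAggregator class itself.

  object DTStatsIndexingSketch extends App {
    val statsSize = 2                              // e.g. Gini or Entropy with 2 classes
    val numBins = Array(3, 4)                      // bins per feature (both ordered here)
    // Cumulative bin counts per feature, scaled by statsSize (mirrors featureOffsets).
    val featureOffsets = numBins.scanLeft(0)(_ + _).map(_ * statsSize)  // Array(0, 6, 14)
    val nodeStride = featureOffsets.last           // elements stored per node = 14
    def statsIndex(node: Int, feature: Int, bin: Int): Int =
      node * nodeStride + featureOffsets(feature) + bin * statsSize
    println(statsIndex(1, 1, 2))                   // 1 * 14 + 6 + 2 * 2 = 24
  }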

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index d9eda35..e95add7 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.rdd.RDD
 
-
 /**
  * Learning and dataset metadata for DecisionTree.
  *
  * @param numClasses    For classification: labels can take values {0, ..., 
numClasses - 1}.
  *                      For regression: fixed at 0 (no meaning).
+ * @param maxBins  Maximum number of bins, for all features.
  * @param featureArity  Map: categorical feature index --> arity.
  *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins  Number of bins for each feature.
  */
 private[tree] class DecisionTreeMetadata(
     val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
     val maxBins: Int,
     val featureArity: Map[Int, Int],
     val unorderedFeatures: Set[Int],
+    val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy) extends Serializable {
 
@@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata(
 
   def isContinuous(featureIndex: Int): Boolean = 
!featureArity.contains(featureIndex)
 
+  /**
+   * Number of splits for the given feature.
+   * For unordered features, there are 2 bins per split.
+   * For ordered features, there is 1 more bin than split.
+   */
+  def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+    numBins(featureIndex) >> 1
+  } else {
+    numBins(featureIndex) - 1
+  }
+
 }
 
 private[tree] object DecisionTreeMetadata {
 
+  /**
+   * Construct a [[DecisionTreeMetadata]] instance for this dataset and 
parameters.
+   * This computes which categorical features will be ordered vs. unordered,
+   * as well as the number of splits and bins for each feature.
+   */
   def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): 
DecisionTreeMetadata = {
 
     val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
       case Regression => 0
     }
 
-    val maxBins = math.min(strategy.maxBins, numExamples).toInt
-    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+
+    // We check the number of bins here against maxPossibleBins.
+    // This needs to be checked here instead of in Strategy since 
maxPossibleBins can be modified
+    // based on the number of training examples.
+    if (strategy.categoricalFeaturesInfo.nonEmpty) {
+      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+      require(maxCategoriesPerFeature <= maxPossibleBins,
+        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories 
" +
+          s"in categorical features (= $maxCategoriesPerFeature)")
+    }
 
     val unorderedFeatures = new mutable.HashSet[Int]()
+    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
     if (numClasses > 2) {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        if (k - 1 < log2MaxBinsp1) {
-          // Note: The above check is equivalent to checking:
-          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
-          unorderedFeatures.add(f)
+      // Multiclass classification
+      val maxCategoriesForUnorderedFeature =
+        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, 
numCategories) =>
+        // Decide if some categorical features should be treated as unordered 
features,
+        //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+        // We do this check with log values to prevent overflows in case 
numCategories is large.
+        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) 
<= maxBins
+        if (numCategories <= maxCategoriesForUnorderedFeature) {
+          unorderedFeatures.add(featureIndex)
+          numBins(featureIndex) = numUnorderedBins(numCategories)
         } else {
-          // TODO: Allow this case, where we simply will know nothing about 
some categories?
-          require(k < maxBins, s"maxBins (= $maxBins) should be greater than 
max categories " +
-            s"in categorical features (>= $k)")
+          numBins(featureIndex) = numCategories
         }
       }
     } else {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max 
categories " +
-          s"in categorical features (>= $k)")
+      // Binary classification or regression
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, 
numCategories) =>
+        numBins(featureIndex) = numCategories
       }
     }
 
-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
       strategy.impurity, strategy.quantileCalculationStrategy)
   }
 
+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an 
unordered feature.
+   * There is 1 split for every partitioning of categories into 2 disjoint, 
non-empty sets;
+   * there are math.pow(2, arity - 1) - 1 such splits.
+   * Each split has 2 corresponding bins.
+   */
+  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+
 }
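
For intuition, the bin-count arithmetic can be spot-checked directly. The value
maxPossibleBins = 100 below is assumed purely for illustration: a categorical feature is
treated as unordered only while its unordered bin count still fits within that budget.

  // Mirrors numUnorderedBins and maxCategoriesForUnorderedFeature (illustration only).
  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
  val maxPossibleBins = 100
  val maxCategoriesForUnorderedFeature =
    ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt           // = 6
  assert(numUnorderedBins(maxCategoriesForUnorderedFeature) <= maxPossibleBins)      // 62 <= 100
  assert(numUnorderedBins(maxCategoriesForUnorderedFeature + 1) > maxPossibleBins)   // 126 > 100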

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 170e43e..35e361a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -48,54 +48,63 @@ private[tree] object TreePoint {
    * binning feature values in preparation for DecisionTree training.
    * @param input     Input dataset.
    * @param bins      Bins for features, of size (numFeatures, numBins).
-   * @param metadata Learning and dataset metadata
+   * @param metadata  Learning and dataset metadata
    * @return  TreePoint dataset representation
    */
   def convertToTreeRDD(
       input: RDD[LabeledPoint],
       bins: Array[Array[Bin]],
       metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+    // Construct arrays for featureArity and isUnordered for efficiency in the 
inner loop.
+    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+    val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
+    var featureIndex = 0
+    while (featureIndex < metadata.numFeatures) {
+      featureArity(featureIndex) = 
metadata.featureArity.getOrElse(featureIndex, 0)
+      isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
+      featureIndex += 1
+    }
     input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, bins, metadata)
+      TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
     }
   }
 
   /**
    * Convert one LabeledPoint into its TreePoint representation.
    * @param bins      Bins for features, of size (numFeatures, numBins).
+   * @param featureArity  Array indexed by feature, with value 0 for 
continuous and numCategories
+   *                      for categorical features.
+   * @param isUnordered  Array indexed by feature, with value true for unordered 
categorical features.
    */
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
       bins: Array[Array[Bin]],
-      metadata: DecisionTreeMetadata): TreePoint = {
-
+      featureArity: Array[Int],
+      isUnordered: Array[Boolean]): TreePoint = {
     val numFeatures = labeledPoint.features.size
-    val numBins = bins(0).size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      arr(featureIndex) = findBin(featureIndex, labeledPoint, 
metadata.isContinuous(featureIndex),
-        metadata.isUnordered(featureIndex), bins, metadata.featureArity)
+      arr(featureIndex) = findBin(featureIndex, labeledPoint, 
featureArity(featureIndex),
+        isUnordered(featureIndex), bins)
       featureIndex += 1
     }
-
     new TreePoint(labeledPoint.label, arr)
   }
 
   /**
    * Find bin for one (labeledPoint, feature).
    *
+   * @param featureArity  0 for continuous features; number of categories for 
categorical features.
    * @param isUnorderedFeature  (only applies if feature is categorical)
    * @param bins   Bins for features, of size (numFeatures, numBins).
-   * @param categoricalFeaturesInfo  Map over categorical features: feature 
index --> feature arity
    */
   private def findBin(
       featureIndex: Int,
       labeledPoint: LabeledPoint,
-      isFeatureContinuous: Boolean,
+      featureArity: Int,
       isUnorderedFeature: Boolean,
-      bins: Array[Array[Bin]],
-      categoricalFeaturesInfo: Map[Int, Int]): Int = {
+      bins: Array[Array[Bin]]): Int = {
 
     /**
      * Binary search helper method for continuous feature.
@@ -121,44 +130,7 @@ private[tree] object TreePoint {
       -1
     }
 
-    /**
-     * Sequential search helper method to find bin for categorical feature in 
multiclass
-     * classification. The category is returned since each category can belong 
to multiple
-     * splits. The actual left/right child allocation per split is performed 
in the
-     * sequential phase of the bin aggregate operation.
-     */
-    def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): 
Int = {
-      labeledPoint.features(featureIndex).toInt
-    }
-
-    /**
-     * Sequential search helper method to find bin for categorical feature
-     * (for classification and regression).
-     */
-    def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
-      val featureCategories = categoricalFeaturesInfo(featureIndex)
-      val featureValue = labeledPoint.features(featureIndex)
-      var binIndex = 0
-      while (binIndex < featureCategories) {
-        val bin = bins(featureIndex)(binIndex)
-        val categories = bin.highSplit.categories
-        if (categories.contains(featureValue)) {
-          return binIndex
-        }
-        binIndex += 1
-      }
-      if (featureValue < 0 || featureValue >= featureCategories) {
-        throw new IllegalArgumentException(
-          s"DecisionTree given invalid data:" +
-            s" Feature $featureIndex is categorical with values in" +
-            s" {0,...,${featureCategories - 1}," +
-            s" but a data point gives it value $featureValue.\n" +
-            "  Bad data point: " + labeledPoint.toString)
-      }
-      -1
-    }
-
-    if (isFeatureContinuous) {
+    if (featureArity == 0) {
       // Perform binary search for finding bin for continuous features.
       val binIndex = binarySearchForBins()
       if (binIndex == -1) {
@@ -168,18 +140,17 @@ private[tree] object TreePoint {
       }
       binIndex
     } else {
-      // Perform sequential search to find bin for categorical features.
-      val binIndex = if (isUnorderedFeature) {
-          sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
-        } else {
-          sequentialBinSearchForOrderedCategoricalFeature()
-        }
-      if (binIndex == -1) {
-        throw new RuntimeException("No bin was found for categorical feature." 
+
-          " This error can occur when given invalid data values (such as 
NaN)." +
-          s" Feature index: $featureIndex.  Feature value: 
${labeledPoint.features(featureIndex)}")
+      // Categorical feature bins are indexed by feature values.
+      val featureValue = labeledPoint.features(featureIndex)
+      if (featureValue < 0 || featureValue >= featureArity) {
+        throw new IllegalArgumentException(
+          s"DecisionTree given invalid data:" +
+            s" Feature $featureIndex is categorical with values in" +
+            s" {0,...,${featureArity - 1}," +
+            s" but a data point gives it value $featureValue.\n" +
+            "  Bad data point: " + labeledPoint.toString)
       }
-      binIndex
+      featureValue.toInt
     }
   }
 }
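
The net effect of the rewrite is that categorical features no longer need any search: the
bin index is simply the (validated) feature value, and only continuous features keep the
binary search over split thresholds. A minimal sketch of that categorical path, using a
hypothetical helper name rather than the private findBin method:

  // Hypothetical helper mirroring the categorical branch of findBin (illustration only).
  def categoricalBin(featureIndex: Int, featureValue: Double, featureArity: Int): Int = {
    require(featureValue >= 0 && featureValue < featureArity,
      s"Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}}," +
        s" but a data point gives it value $featureValue.")
    featureValue.toInt
  }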

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 96d2471..1c8afc2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -74,3 +74,87 @@ object Entropy extends Impurity {
   def instance = this
 
 }
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views 
of the data.
+ * @param numClasses  Number of classes for label.
+ */
+private[tree] class EntropyAggregator(numClasses: Int)
+  extends ImpurityAggregator(numClasses) with Serializable {
+
+  /**
+   * Update stats for one (node, feature, bin) with the given label.
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+    if (label >= statsSize) {
+      throw new IllegalArgumentException(s"EntropyAggregator given label 
$label" +
+        s" but requires label < numClasses (= $statsSize).")
+    }
+    allStats(offset + label.toInt) += 1
+  }
+
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = 
{
+    new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+  }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[EntropyAggregator]], this class stores its own data and is for a 
specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class EntropyCalculator(stats: Array[Double]) extends 
ImpurityCalculator(stats) {
+
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double = Entropy.calculate(stats, stats.sum)
+
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long = stats.sum.toLong
+
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    indexOfLargestArrayElement(stats)
+  }
+
+  /**
+   * Probability of the label given by [[predict]].
+   */
+  override def prob(label: Double): Double = {
+    val lbl = label.toInt
+    require(lbl < stats.length,
+      s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}).")
+    val cnt = count
+    if (cnt == 0) {
+      0
+    } else {
+      stats(lbl) / cnt
+    }
+  }
+
+  override def toString: String = s"EntropyCalculator(stats = 
[${stats.mkString(", ")}])"
+
+}
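
A small worked example of the calculator semantics, using plain arrays rather than the
class itself: with class counts stats = [3, 7], count = 10, the prediction is class 1 (the
larger count) and prob(1.0) = 0.7.

  val stats = Array(3.0, 7.0)              // sufficient statistics: one count per class
  val count = stats.sum.toLong             // 10
  val predict = stats.indexOf(stats.max)   // 1, index of the largest count
  val prob = stats(predict) / count        // 0.7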

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index d586f44..5cfdf34 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -70,3 +70,87 @@ object Gini extends Impurity {
   def instance = this
 
 }
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views 
of the data.
+ * @param numClasses  Number of classes for label.
+ */
+private[tree] class GiniAggregator(numClasses: Int)
+  extends ImpurityAggregator(numClasses) with Serializable {
+
+  /**
+   * Update stats for one (node, feature, bin) with the given label.
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+    if (label >= statsSize) {
+      throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+        s" but requires label < numClasses (= $statsSize).")
+    }
+    allStats(offset + label.toInt) += 1
+  }
+
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
+    new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+  }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[GiniAggregator]], this class stores its own data and is for a 
specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class GiniCalculator(stats: Array[Double]) extends 
ImpurityCalculator(stats) {
+
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: GiniCalculator = new GiniCalculator(stats.clone())
+
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double = Gini.calculate(stats, stats.sum)
+
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long = stats.sum.toLong
+
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    indexOfLargestArrayElement(stats)
+  }
+
+  /**
+   * Probability of the label given by [[predict]].
+   */
+  override def prob(label: Double): Double = {
+    val lbl = label.toInt
+    require(lbl < stats.length,
+      s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}).")
+    val cnt = count
+    if (cnt == 0) {
+      0
+    } else {
+      stats(lbl) / cnt
+    }
+  }
+
+  override def toString: String = s"GiniCalculator(stats = 
[${stats.mkString(", ")}])"
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 92b0c7b..5a047d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, 
Experimental}
 /**
  * :: Experimental ::
  * Trait for calculating information gain.
+ * This trait is used for
+ *  (a) setting the impurity parameter in 
[[org.apache.spark.mllib.tree.configuration.Strategy]]
+ *  (b) calculating impurity values from sufficient statistics.
  */
 @Experimental
 trait Impurity extends Serializable {
@@ -47,3 +50,127 @@ trait Impurity extends Serializable {
   @DeveloperApi
   def calculate(count: Double, sum: Double, sumSquares: Double): Double
 }
+
+/**
+ * Interface for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views 
of the data.
+ * @param statsSize  Length of the vector of sufficient statistics for one bin.
+ */
+private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends 
Serializable {
+
+  /**
+   * Merge the stats from one bin into another.
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for (node, feature, bin) which is 
modified by the merge.
+   * @param otherOffset  Start index of stats for (node, feature, other bin) 
which is not modified.
+   */
+  def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = {
+    var i = 0
+    while (i < statsSize) {
+      allStats(offset + i) += allStats(otherOffset + i)
+      i += 1
+    }
+  }
+
+  /**
+   * Update stats for one (node, feature, bin) with the given label.
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit
+
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[ImpurityAggregator]], this class stores its own data and is for a 
specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
+
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: ImpurityCalculator
+
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double
+
+  /**
+   * Add the stats from another calculator into this one, modifying and 
returning this calculator.
+   */
+  def add(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.size == other.stats.size,
+      s"Two ImpurityCalculator instances cannot be added with different counts 
sizes." +
+        s"  Sizes are ${stats.size} and ${other.stats.size}.")
+    var i = 0
+    while (i < other.stats.size) {
+      stats(i) += other.stats(i)
+      i += 1
+    }
+    this
+  }
+
+  /**
+   * Subtract the stats from another calculator from this one, modifying and 
returning this
+   * calculator.
+   */
+  def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.size == other.stats.size,
+      s"Two ImpurityCalculator instances cannot be subtracted with different 
counts sizes." +
+      s"  Sizes are ${stats.size} and ${other.stats.size}.")
+    var i = 0
+    while (i < other.stats.size) {
+      stats(i) -= other.stats(i)
+      i += 1
+    }
+    this
+  }
+
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long
+
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double
+
+  /**
+   * Probability of the label given by [[predict]], or -1 if no probability is 
available.
+   */
+  def prob(label: Double): Double = -1
+
+  /**
+   * Return the index of the largest array element.
+   * Fails if the array is empty.
+   */
+  protected def indexOfLargestArrayElement(array: Array[Double]): Int = {
+    val result = array.foldLeft(-1, Double.MinValue, 0) {
+      case ((maxIndex, maxValue, currentIndex), currentValue) =>
+        if (currentValue > maxValue) {
+          (currentIndex, currentValue, currentIndex + 1)
+        } else {
+          (maxIndex, maxValue, currentIndex + 1)
+        }
+    }
+    if (result._1 < 0) {
+      throw new RuntimeException("ImpurityCalculator internal error:" +
+        " indexOfLargestArrayElement failed")
+    }
+    result._1
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index f7d99a4..e9ccecb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -61,3 +61,75 @@ object Variance extends Impurity {
   def instance = this
 
 }
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views 
of the data.
+ */
+private[tree] class VarianceAggregator()
+  extends ImpurityAggregator(statsSize = 3) with Serializable {
+
+  /**
+   * Update stats for one (node, feature, bin) with the given label.
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+    allStats(offset) += 1
+    allStats(offset + 1) += label
+    allStats(offset + 2) += label * label
+  }
+
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator 
= {
+    new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+  }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[VarianceAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class VarianceCalculator(stats: Array[Double]) extends 
ImpurityCalculator(stats) {
+
+  require(stats.size == 3,
+    s"VarianceCalculator requires sufficient statistics array stats to be of 
length 3," +
+    s" but was given array of length ${stats.size}.")
+
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
+
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long = stats(0).toLong
+
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    stats(1) / count
+  }
+
+  override def toString: String = {
+    s"VarianceCalculator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})"
+  }
+
+}
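
A worked example for the regression statistics (illustration only): labels 1.0, 2.0, 3.0
accumulate to stats = [count, sum, sumOfSquares] = [3, 6, 14], so the prediction is the
mean 6 / 3 = 2.0 and the impurity is the (biased) variance 14/3 - 2^2 = 2/3.

  val Array(cnt, sum, sumSq) = Array(3.0, 6.0, 14.0)     // from labels 1.0, 2.0, 3.0
  val predict = sum / cnt                                // 2.0, the mean label
  val impurity = sumSq / cnt - math.pow(sum / cnt, 2)    // ≈ 0.667, the variance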

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index af35d88..0cad473 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
 /**
- * Used for "binning" the features bins for faster best split calculation.
+ * Used for "binning" the feature values for faster best split calculation.
  *
  * For a continuous feature, the bin is determined by a low and a high split,
  *  where an example with featureValue falls into the bin s.t.
@@ -30,13 +30,16 @@ import 
org.apache.spark.mllib.tree.configuration.FeatureType._
  *  bins, splits, and feature values.  The bin is determined by 
category/feature value.
  *  However, the bins are not necessarily ordered by feature value;
  *  they are ordered using impurity.
+ *
  * For unordered categorical features, there is a 1-1 correspondence between 
bins, splits,
  *  where bins and splits correspond to subsets of feature values (in 
highSplit.categories).
+ *  An unordered feature with k categories uses (1 << k - 1) - 1 splits, one for each
+ *  partitioning of categories into 2 disjoint, non-empty sets, with 2 bins per split.
  *
  * @param lowSplit signifying the lower threshold for the continuous feature 
to be
  *                 accepted in the bin
  * @param highSplit signifying the upper threshold for the continuous feature 
to be
- *                 accepted in the bin
+ *                  accepted in the bin
  * @param featureType type of feature -- categorical or continuous
  * @param category categorical label value accepted in the bin for ordered 
features
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 0eee626..5b8a4cb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector
 
 /**
  * :: DeveloperApi ::
- * Node in a decision tree
- * @param id integer node id
+ * Node in a decision tree.
+ *
+ * About node indexing:
+ *   Nodes are indexed from 1.  Node 1 is the root; nodes 2, 3 are the left, 
right children.
+ *   Node index 0 is not used.
+ *
+ * @param id integer node id, from 1
  * @param predict predicted value at the node
 * @param isLeaf whether the node is a leaf
  * @param split split to calculate left and right nodes
@@ -51,17 +56,13 @@ class Node (
    * @param nodes array of nodes
    */
   def build(nodes: Array[Node]): Unit = {
-
-    logDebug("building node " + id + " at level " +
-      (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+    logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)
     logDebug("stats = " + stats)
     logDebug("predict = " + predict)
     if (!isLeaf) {
-      val leftNodeIndex = id * 2 + 1
-      val rightNodeIndex = id * 2 + 2
-      leftNode = Some(nodes(leftNodeIndex))
-      rightNode = Some(nodes(rightNodeIndex))
+      leftNode = Some(nodes(Node.leftChildIndex(id)))
+      rightNode = Some(nodes(Node.rightChildIndex(id)))
       leftNode.get.build(nodes)
       rightNode.get.build(nodes)
     }
@@ -96,24 +97,20 @@ class Node (
    * Get the number of nodes in tree below this node, including leaf nodes.
    * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 
2.
    */
-  private[tree] def numDescendants: Int = {
-    if (isLeaf) {
-      0
-    } else {
-      2 + leftNode.get.numDescendants + rightNode.get.numDescendants
-    }
+  private[tree] def numDescendants: Int = if (isLeaf) {
+    0
+  } else {
+    2 + leftNode.get.numDescendants + rightNode.get.numDescendants
   }
 
   /**
    * Get depth of tree from this node.
    * E.g.: Depth 0 means this is a leaf node.
    */
-  private[tree] def subtreeDepth: Int = {
-    if (isLeaf) {
-      0
-    } else {
-      1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
-    }
+  private[tree] def subtreeDepth: Int = if (isLeaf) {
+    0
+  } else {
+    1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
   }
 
   /**
@@ -148,3 +145,49 @@ class Node (
   }
 
 }
+
+private[tree] object Node {
+
+  /**
+   * Return the index of the left child of this node.
+   */
+  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+  /**
+   * Return the index of the right child of this node.
+   */
+  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+  /**
+   * Get the parent index of the given node, or 0 if it is the root.
+   */
+  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+  /**
+   * Return the level of a tree which the given node is in.
+   */
+  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+    throw new IllegalArgumentException(s"0 is not a valid node index.")
+  } else {
+    
java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+  }
+
+  /**
+   * Returns true if this is a left child.
+   * Note: Returns false for the root.
+   */
+  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+  /**
+   * Return the maximum number of nodes which can be in the given level of the 
tree.
+   * @param level  Level of tree (0 = root).
+   */
+  def maxNodesInLevel(level: Int): Int = 1 << level
+
+  /**
+   * Return the index of the first node in the given level.
+   * @param level  Level of tree (0 = root).
+   */
+  def startIndexInLevel(level: Int): Int = 1 << level
+
+}
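
The 1-based indexing convention can be sanity-checked with a few lines that mirror the new
helpers (illustration only; the real methods live on the private[tree] Node object):

  def leftChildIndex(i: Int): Int = i << 1
  def rightChildIndex(i: Int): Int = (i << 1) + 1
  def parentIndex(i: Int): Int = i >> 1
  def indexToLevel(i: Int): Int =
    java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(i))
  assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)  // root's children are nodes 2 and 3
  assert(parentIndex(5) == 2 && indexToLevel(5) == 2)        // node 5 sits at level 2, under node 2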

http://git-wip-us.apache.org/repos/asf/spark/blob/711356b4/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2f36fd9..8e556c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -21,15 +21,16 @@ import scala.collection.JavaConverters._
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
-import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.regression.LabeledPoint
+
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
@@ -59,12 +60,13 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but 
required $requiredMSE.")
   }
 
-  test("split and bin calculation") {
+  test("Binary classification with continuous features: split and bin 
calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
@@ -72,7 +74,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(bins(0).length === 100)
   }
 
-  test("split and bin calculation for categorical variables") {
+  test("Binary classification with binary (ordered) categorical features:" +
+    " split and bin calculation") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -83,77 +86,20 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     assert(splits.length === 2)
     assert(bins.length === 2)
-    assert(splits(0).length === 99)
-    assert(bins(0).length === 100)
-
-    // Check splits.
-
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(1.0))
-
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(1.0))
-    assert(splits(0)(1).categories.contains(0.0))
-
-    assert(splits(0)(2) === null)
-
-    assert(splits(1)(0).feature === 1)
-    assert(splits(1)(0).threshold === Double.MinValue)
-    assert(splits(1)(0).featureType === Categorical)
-    assert(splits(1)(0).categories.length === 1)
-    assert(splits(1)(0).categories.contains(0.0))
-
-    assert(splits(1)(1).feature === 1)
-    assert(splits(1)(1).threshold === Double.MinValue)
-    assert(splits(1)(1).featureType === Categorical)
-    assert(splits(1)(1).categories.length === 2)
-    assert(splits(1)(1).categories.contains(1.0))
-    assert(splits(1)(1).categories.contains(0.0))
-
-    assert(splits(1)(2) === null)
-
-    // Check bins.
-
-    assert(bins(0)(0).category === 1.0)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(1.0))
-
-    assert(bins(0)(1).category === 0.0)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).lowSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.length === 2)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.contains(0.0))
-
-    assert(bins(0)(2) === null)
-
-    assert(bins(1)(0).category === 0.0)
-    assert(bins(1)(0).lowSplit.categories.length === 0)
-    assert(bins(1)(0).highSplit.categories.length === 1)
-    assert(bins(1)(0).highSplit.categories.contains(0.0))
-
-    assert(bins(1)(1).category === 1.0)
-    assert(bins(1)(1).lowSplit.categories.length === 1)
-    assert(bins(1)(1).lowSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.length === 2)
-    assert(bins(1)(1).highSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.contains(1.0))
-
-    assert(bins(1)(2) === null)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
   }
 
-  test("split and bin calculations for categorical variables with no sample 
for one category") {
+  test("Binary classification with 3-ary (ordered) categorical features," +
+    " with no samples for one category") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -164,104 +110,16 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-
-    // Check splits.
-
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(1.0))
-
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(1.0))
-    assert(splits(0)(1).categories.contains(0.0))
-
-    assert(splits(0)(2).feature === 0)
-    assert(splits(0)(2).threshold === Double.MinValue)
-    assert(splits(0)(2).featureType === Categorical)
-    assert(splits(0)(2).categories.length === 3)
-    assert(splits(0)(2).categories.contains(1.0))
-    assert(splits(0)(2).categories.contains(0.0))
-    assert(splits(0)(2).categories.contains(2.0))
-
-    assert(splits(0)(3) === null)
-
-    assert(splits(1)(0).feature === 1)
-    assert(splits(1)(0).threshold === Double.MinValue)
-    assert(splits(1)(0).featureType === Categorical)
-    assert(splits(1)(0).categories.length === 1)
-    assert(splits(1)(0).categories.contains(0.0))
-
-    assert(splits(1)(1).feature === 1)
-    assert(splits(1)(1).threshold === Double.MinValue)
-    assert(splits(1)(1).featureType === Categorical)
-    assert(splits(1)(1).categories.length === 2)
-    assert(splits(1)(1).categories.contains(1.0))
-    assert(splits(1)(1).categories.contains(0.0))
-
-    assert(splits(1)(2).feature === 1)
-    assert(splits(1)(2).threshold === Double.MinValue)
-    assert(splits(1)(2).featureType === Categorical)
-    assert(splits(1)(2).categories.length === 3)
-    assert(splits(1)(2).categories.contains(1.0))
-    assert(splits(1)(2).categories.contains(0.0))
-    assert(splits(1)(2).categories.contains(2.0))
-
-    assert(splits(1)(3) === null)
-
-    // Check bins.
-
-    assert(bins(0)(0).category === 1.0)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(1.0))
-
-    assert(bins(0)(1).category === 0.0)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).lowSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.length === 2)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.contains(0.0))
-
-    assert(bins(0)(2).category === 2.0)
-    assert(bins(0)(2).lowSplit.categories.length === 2)
-    assert(bins(0)(2).lowSplit.categories.contains(1.0))
-    assert(bins(0)(2).lowSplit.categories.contains(0.0))
-    assert(bins(0)(2).highSplit.categories.length === 3)
-    assert(bins(0)(2).highSplit.categories.contains(1.0))
-    assert(bins(0)(2).highSplit.categories.contains(0.0))
-    assert(bins(0)(2).highSplit.categories.contains(2.0))
-
-    assert(bins(0)(3) === null)
-
-    assert(bins(1)(0).category === 0.0)
-    assert(bins(1)(0).lowSplit.categories.length === 0)
-    assert(bins(1)(0).highSplit.categories.length === 1)
-    assert(bins(1)(0).highSplit.categories.contains(0.0))
-
-    assert(bins(1)(1).category === 1.0)
-    assert(bins(1)(1).lowSplit.categories.length === 1)
-    assert(bins(1)(1).lowSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.length === 2)
-    assert(bins(1)(1).highSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.contains(1.0))
-
-    assert(bins(1)(2).category === 2.0)
-    assert(bins(1)(2).lowSplit.categories.length === 2)
-    assert(bins(1)(2).lowSplit.categories.contains(0.0))
-    assert(bins(1)(2).lowSplit.categories.contains(1.0))
-    assert(bins(1)(2).highSplit.categories.length === 3)
-    assert(bins(1)(2).highSplit.categories.contains(0.0))
-    assert(bins(1)(2).highSplit.categories.contains(1.0))
-    assert(bins(1)(2).highSplit.categories.contains(2.0))
-
-    assert(bins(1)(3) === null)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
   }
 
   test("extract categories from a number for multiclass classification") {
@@ -270,8 +128,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
   }
 
-  test("split and bin calculations for unordered categorical variables with 
multiclass " +
-    "classification") {
+  test("Multiclass classification with unordered categorical features:" +
+      " split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -282,8 +140,15 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    assert(splits(0).length === 3)
+    assert(bins(0).length === 6)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
@@ -321,10 +186,6 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(splits(1)(2).categories.contains(0.0))
     assert(splits(1)(2).categories.contains(1.0))
 
-    assert(splits(0)(3) === null)
-    assert(splits(1)(3) === null)
-
-
     // Check bins.
 
     assert(bins(0)(0).category === Double.MinValue)
@@ -360,13 +221,9 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(bins(1)(2).highSplit.categories.contains(1.0))
     assert(bins(1)(2).highSplit.categories.contains(0.0))
 
-    assert(bins(0)(3) === null)
-    assert(bins(1)(3) === null)
-
   }
 
-  test("split and bin calculations for ordered categorical variables with 
multiclass " +
-    "classification") {
+  test("Multiclass classification with ordered categorical features: split and 
bin calculations") {
     val arr = 
DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     assert(arr.length === 3000)
     val rdd = sc.parallelize(arr)
@@ -377,52 +234,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
+    // 2^10 - 1 > 100, so categorical features will be ordered
+
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-
-    // 2^10 - 1 > 100, so categorical variables will be ordered
-
-    assert(splits(0)(0).feature === 0)
-    assert(splits(0)(0).threshold === Double.MinValue)
-    assert(splits(0)(0).featureType === Categorical)
-    assert(splits(0)(0).categories.length === 1)
-    assert(splits(0)(0).categories.contains(1.0))
-
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(2.0))
-
-    assert(splits(0)(2).feature === 0)
-    assert(splits(0)(2).threshold === Double.MinValue)
-    assert(splits(0)(2).featureType === Categorical)
-    assert(splits(0)(2).categories.length === 3)
-    assert(splits(0)(2).categories.contains(2.0))
-    assert(splits(0)(2).categories.contains(1.0))
-
-    assert(splits(0)(10) === null)
-    assert(splits(1)(10) === null)
-
-
-    // Check bins.
-
-    assert(bins(0)(0).category === 1.0)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).category === 2.0)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).highSplit.categories.length === 2)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(0)(1).highSplit.categories.contains(2.0))
-
-    assert(bins(0)(10) === null)
-
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
   }
 
 
-  test("classification stump with all categorical variables") {
+  test("Binary classification stump with ordered categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -433,15 +259,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    assert(splits.length === 2)
+    assert(bins.length === 2)
+    // no bins or splits pre-computed for ordered categorical features
+    assert(splits(0).length === 0)
+    assert(bins(0).length === 0)
+
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     val split = bestSplits(0)._1
-    assert(split.categories.length === 1)
-    assert(split.categories.contains(1.0))
+    assert(split.categories === List(1.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
 
@@ -452,7 +286,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
   }
 
-  test("regression stump with all categorical variables") {
+  test("Regression stump with 3-ary (ordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -462,10 +296,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     val split = bestSplits(0)._1
@@ -480,7 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
   }
 
-  test("regression stump with categorical variables of arity 2") {
+  test("Regression stump with binary (ordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -490,6 +328,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
 
     val model = DecisionTree.train(rdd, strategy)
     validateRegressor(model, arr, 0.0)
@@ -497,12 +338,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
   }
 
-  test("stump with fixed label 0 for Gini") {
+  test("Binary classification stump with fixed label 0 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+    val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -512,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -521,12 +366,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.rightImpurity === 0)
   }
 
-  test("stump with fixed label 1 for Gini") {
+  test("Binary classification stump with fixed label 1 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+    val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -536,7 +385,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -546,12 +395,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("stump with fixed label 0 for Entropy") {
+  test("Binary classification stump with fixed label 0 for Entropy") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+    val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -561,7 +414,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -571,12 +424,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 0)
   }
 
-  test("stump with fixed label 1 for Entropy") {
+  test("Binary classification stump with fixed label 1 for Entropy") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+    val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+      numClassesForClassification = 2, maxBins = 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
+
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -586,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
       new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
@@ -596,7 +453,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("second level node building with/without groups") {
+  test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -613,12 +470,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     // Train a 1-node model
     val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
     val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val nodes: Array[Node] = new Array[Node](7)
-    nodes(0) = modelOneNode.topNode
-    nodes(0).leftNode = None
-    nodes(0).rightNode = None
+    val nodes: Array[Node] = new Array[Node](8)
+    nodes(1) = modelOneNode.topNode
+    nodes(1).leftNode = None
+    nodes(1).rightNode = None
 
-    val parentImpurities = Array(0.5, 0.5, 0.5)
+    val parentImpurities = Array(0, 0.5, 0.5, 0.5)
 
     // Single group second level tree construction.
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
@@ -648,16 +505,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-  test("stump with categorical variables for multiclass classification") {
+  test("Multiclass classification stump with 3-ary (unordered) categorical 
features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 
-> 3))
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(strategy.isMulticlassClassification)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
+
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -668,7 +528,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.featureType === Categorical)
   }
 
-  test("stump with 1 continuous variable for binary classification, to check 
off-by-1 error") {
+  test("Binary classification stump with 1 continuous feature, to check 
off-by-1 error") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
     arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
@@ -684,26 +544,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
   }
 
-  test("stump with 2 continuous variables for binary classification") {
+  test("Binary classification stump with 2 continuous features") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
     arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
     arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
     arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
 
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 2)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
     assert(model.topNode.split.get.feature === 1)
   }
 
-  test("stump with categorical variables for multiclass classification, with 
just enough bins") {
-    val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow 
unordered features
+  test("Multiclass classification stump with unordered categorical features," +
+    " with just enough bins") {
+    val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to 
allow unordered features
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
@@ -711,6 +572,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
+    assert(metadata.isUnordered(featureIndex = 1))
 
     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
@@ -719,7 +582,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -733,7 +596,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(gain.rightImpurity === 0)
   }
 
-  test("stump with continuous variables for multiclass classification") {
+  test("Multiclass classification stump with continuous features") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -746,7 +609,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -759,20 +622,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   }
 
-  test("stump with continuous + categorical variables for multiclass 
classification") {
+  test("Multiclass classification stump with continuous + unordered 
categorical features") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(metadata.isUnordered(featureIndex = 0))
 
     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -784,17 +648,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.threshold < 2020)
   }
 
-  test("stump with categorical variables for ordered multiclass 
classification") {
+  test("Multiclass classification stump with 10-ary (ordered) categorical 
features") {
     val arr = 
DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 
1 -> 10))
     assert(strategy.isMulticlassClassification)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
+    assert(!metadata.isUnordered(featureIndex = 1))
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
       new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
@@ -805,6 +671,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.featureType === Categorical)
   }
 
+  test("Multiclass classification tree with 10-ary (ordered) categorical 
features," +
+      " with just enough bins") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+      numClassesForClassification = 3, maxBins = 10,
+      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    assert(strategy.isMulticlassClassification)
+
+    val model = DecisionTree.train(rdd, strategy)
+    validateClassifier(model, arr, 0.6)
+  }
 
 }
 
@@ -899,5 +777,4 @@ object DecisionTreeSuite {
     arr
   }
 
-
 }
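
Aside (not part of the diff above): the counts asserted in these tests follow from the unordered-categorical-feature arithmetic — a feature with k categories admits 2^(k-1) - 1 candidate subset splits, each with its own low/high bin pair, and the feature is only treated as unordered when 2 * (2^(k-1) - 1) <= maxBins. A minimal illustrative Scala sketch of that arithmetic (the object and helper names are made up for this note, not Spark API):

    object UnorderedSplitCounts {
      // Number of non-trivial subset splits for an unordered categorical feature
      // with `arity` categories (a subset and its complement count once).
      def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1

      // Each split keeps separate low/high bins, hence twice as many bins as splits.
      def numUnorderedBins(arity: Int): Int = 2 * numUnorderedSplits(arity)

      def main(args: Array[String]): Unit = {
        println(numUnorderedSplits(3))  // 3   -> matches splits(0).length === 3
        println(numUnorderedBins(3))    // 6   -> matches bins(0).length === 6, and maxBins = 6 is "just enough"
        println(numUnorderedSplits(10)) // 511 -> exceeds maxBins = 100, so 10-ary features stay ordered
      }
    }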

