Repository: spark
Updated Branches:
  refs/heads/master 586e716e4 -> d88f6be44


http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bcb1187..5961a61 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.Filter
 import org.apache.spark.mllib.tree.model.Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
@@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(bins.length === 2)
@@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       Classification,
       Gini,
       maxDepth = 3,
+      numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
@@ -130,8 +131,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       Classification,
       Gini,
       maxDepth = 3,
+      numClassesForClassification = 2,
       maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
 
     // Check splits.
@@ -231,6 +233,162 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(1)(3) === null)
   }
 
+  test("extract categories from a number for multiclass classification") {
+    val l = DecisionTree.extractMultiClassCategories(13, 10)
+    assert(l.length === 3)
+    assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
+  }
+
+  test("split and bin calculations for unordered categorical variables with multiclass " +
+    "classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Classification,
+      Gini,
+      maxDepth = 3,
+      numClassesForClassification = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+    // Expecting 2^2 - 1 = 3 bins/splits
+    assert(splits(0)(0).feature === 0)
+    assert(splits(0)(0).threshold === Double.MinValue)
+    assert(splits(0)(0).featureType === Categorical)
+    assert(splits(0)(0).categories.length === 1)
+    assert(splits(0)(0).categories.contains(0.0))
+    assert(splits(1)(0).feature === 1)
+    assert(splits(1)(0).threshold === Double.MinValue)
+    assert(splits(1)(0).featureType === Categorical)
+    assert(splits(1)(0).categories.length === 1)
+    assert(splits(1)(0).categories.contains(0.0))
+
+    assert(splits(0)(1).feature === 0)
+    assert(splits(0)(1).threshold === Double.MinValue)
+    assert(splits(0)(1).featureType === Categorical)
+    assert(splits(0)(1).categories.length === 1)
+    assert(splits(0)(1).categories.contains(1.0))
+    assert(splits(1)(1).feature === 1)
+    assert(splits(1)(1).threshold === Double.MinValue)
+    assert(splits(1)(1).featureType === Categorical)
+    assert(splits(1)(1).categories.length === 1)
+    assert(splits(1)(1).categories.contains(1.0))
+
+    assert(splits(0)(2).feature === 0)
+    assert(splits(0)(2).threshold === Double.MinValue)
+    assert(splits(0)(2).featureType === Categorical)
+    assert(splits(0)(2).categories.length === 2)
+    assert(splits(0)(2).categories.contains(0.0))
+    assert(splits(0)(2).categories.contains(1.0))
+    assert(splits(1)(2).feature === 1)
+    assert(splits(1)(2).threshold === Double.MinValue)
+    assert(splits(1)(2).featureType === Categorical)
+    assert(splits(1)(2).categories.length === 2)
+    assert(splits(1)(2).categories.contains(0.0))
+    assert(splits(1)(2).categories.contains(1.0))
+
+    assert(splits(0)(3) === null)
+    assert(splits(1)(3) === null)
+
+
+    // Check bins.
+
+    assert(bins(0)(0).category === Double.MinValue)
+    assert(bins(0)(0).lowSplit.categories.length === 0)
+    assert(bins(0)(0).highSplit.categories.length === 1)
+    assert(bins(0)(0).highSplit.categories.contains(0.0))
+    assert(bins(1)(0).category === Double.MinValue)
+    assert(bins(1)(0).lowSplit.categories.length === 0)
+    assert(bins(1)(0).highSplit.categories.length === 1)
+    assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+    assert(bins(0)(1).category === Double.MinValue)
+    assert(bins(0)(1).lowSplit.categories.length === 1)
+    assert(bins(0)(1).lowSplit.categories.contains(0.0))
+    assert(bins(0)(1).highSplit.categories.length === 1)
+    assert(bins(0)(1).highSplit.categories.contains(1.0))
+    assert(bins(1)(1).category === Double.MinValue)
+    assert(bins(1)(1).lowSplit.categories.length === 1)
+    assert(bins(1)(1).lowSplit.categories.contains(0.0))
+    assert(bins(1)(1).highSplit.categories.length === 1)
+    assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+    assert(bins(0)(2).category === Double.MinValue)
+    assert(bins(0)(2).lowSplit.categories.length === 1)
+    assert(bins(0)(2).lowSplit.categories.contains(1.0))
+    assert(bins(0)(2).highSplit.categories.length === 2)
+    assert(bins(0)(2).highSplit.categories.contains(1.0))
+    assert(bins(0)(2).highSplit.categories.contains(0.0))
+    assert(bins(1)(2).category === Double.MinValue)
+    assert(bins(1)(2).lowSplit.categories.length === 1)
+    assert(bins(1)(2).lowSplit.categories.contains(1.0))
+    assert(bins(1)(2).highSplit.categories.length === 2)
+    assert(bins(1)(2).highSplit.categories.contains(1.0))
+    assert(bins(1)(2).highSplit.categories.contains(0.0))
+
+    assert(bins(0)(3) === null)
+    assert(bins(1)(3) === null)
+
+  }
+
+  test("split and bin calculations for ordered categorical variables with multiclass " +
+    "classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    assert(arr.length === 3000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Classification,
+      Gini,
+      maxDepth = 3,
+      numClassesForClassification = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+    // 2^10 - 1 > 100, so categorical variables will be ordered
+
+    assert(splits(0)(0).feature === 0)
+    assert(splits(0)(0).threshold === Double.MinValue)
+    assert(splits(0)(0).featureType === Categorical)
+    assert(splits(0)(0).categories.length === 1)
+    assert(splits(0)(0).categories.contains(1.0))
+
+    assert(splits(0)(1).feature === 0)
+    assert(splits(0)(1).threshold === Double.MinValue)
+    assert(splits(0)(1).featureType === Categorical)
+    assert(splits(0)(1).categories.length === 2)
+    assert(splits(0)(1).categories.contains(2.0))
+
+    assert(splits(0)(2).feature === 0)
+    assert(splits(0)(2).threshold === Double.MinValue)
+    assert(splits(0)(2).featureType === Categorical)
+    assert(splits(0)(2).categories.length === 3)
+    assert(splits(0)(2).categories.contains(2.0))
+    assert(splits(0)(2).categories.contains(1.0))
+
+    assert(splits(0)(10) === null)
+    assert(splits(1)(10) === null)
+
+
+    // Check bins.
+
+    assert(bins(0)(0).category === 1.0)
+    assert(bins(0)(0).lowSplit.categories.length === 0)
+    assert(bins(0)(0).highSplit.categories.length === 1)
+    assert(bins(0)(0).highSplit.categories.contains(1.0))
+    assert(bins(0)(1).category === 2.0)
+    assert(bins(0)(1).lowSplit.categories.length === 1)
+    assert(bins(0)(1).highSplit.categories.length === 2)
+    assert(bins(0)(1).highSplit.categories.contains(1.0))
+    assert(bins(0)(1).highSplit.categories.contains(2.0))
+
+    assert(bins(0)(10) === null)
+
+  }
+
+
   test("classification stump with all categorical variables") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
@@ -238,6 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategy = new Strategy(
       Classification,
       Gini,
+      numClassesForClassification = 2,
       maxDepth = 3,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -253,8 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = bestSplits(0)._2
     assert(stats.gain > 0)
-    assert(stats.predict > 0.5)
-    assert(stats.predict < 0.7)
+    assert(stats.predict === 1)
+    assert(stats.prob == 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -280,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val stats = bestSplits(0)._2
     assert(stats.gain > 0)
-    assert(stats.predict > 0.5)
-    assert(stats.predict < 0.7)
+    assert(stats.predict == 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -289,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -312,7 +470,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -336,7 +494,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -360,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -380,11 +538,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("test second level node building with/without groups") {
+  test("second level node building with/without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -426,6 +584,82 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  test("stump with categorical variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+    assert(bestSplit.feature === 0)
+    assert(bestSplit.categories.length === 1)
+    assert(bestSplit.categories.contains(1))
+    assert(bestSplit.featureType === Categorical)
+  }
+
+  test("stump with continuous variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3)
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+
+    assert(bestSplit.feature === 1)
+    assert(bestSplit.featureType === Continuous)
+    assert(bestSplit.threshold > 1980)
+    assert(bestSplit.threshold < 2020)
+
+  }
+
+  test("stump with continuous + categorical variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+
+    assert(bestSplit.feature === 1)
+    assert(bestSplit.featureType === Continuous)
+    assert(bestSplit.threshold > 1980)
+    assert(bestSplit.threshold < 2020)
+  }
+
+  test("stump with categorical variables for ordered multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+    assert(bestSplit.feature === 0)
+    assert(bestSplit.categories.length === 1)
+    assert(bestSplit.categories.contains(1.0))
+    assert(bestSplit.featureType === Categorical)
+  }
+
+
 }
 
 object DecisionTreeSuite {
@@ -473,4 +707,47 @@ object DecisionTreeSuite {
     }
     arr
   }
+
+  def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 1000) {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      } else if (i < 2000) {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+      } else {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      }
+    }
+    arr
+  }
+
+  def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 2000) {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i))
+      } else {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i))
+      }
+    }
+    arr
+  }
+
+  def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
+    Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 1000) {
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      } else if (i < 2000) {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+      } else {
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
+      }
+    }
+    arr
+  }
+
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d88f6be4/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3487f7c..e0f433b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -82,7 +82,15 @@ object MimaExcludes {
       MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
       MimaBuild.excludeSparkClass("storage.Values") ++
       MimaBuild.excludeSparkClass("storage.Entry") ++
-      MimaBuild.excludeSparkClass("storage.MemoryStore$Entry")
+      MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
+      Seq(
+        ProblemFilters.exclude[IncompatibleMethTypeProblem](
+          "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
+        ProblemFilters.exclude[IncompatibleMethTypeProblem](
+          "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
+        ProblemFilters.exclude[IncompatibleMethTypeProblem](
+          "org.apache.spark.mllib.tree.impurity.Variance.calculate")
+      )
     case v if v.startsWith("1.0") =>
       Seq(
         MimaBuild.excludeSparkPackage("api.java"),

Reply via email to