Repository: spark
Updated Branches:
  refs/heads/master 7e28fabdf -> 909c6d812


[SPARK-16307][ML] Add test to verify the predicted variances of a DT on toy data

## What changes were proposed in this pull request?

The current tests assume that `impurity.calculate()` returns the variance
correctly. It is better to make the tests independent of this assumption,
i.e. to verify that the variance predicted by the model equals the variance
computed by hand on a small tree.

## How was this patch tested?

The patch itself is a test.

Author: MechCoder <[email protected]>

Closes #13981 from MechCoder/dt_variance.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/909c6d81
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/909c6d81
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/909c6d81

Branch: refs/heads/master
Commit: 909c6d812f6ca3a3305e4611a700c8c17905b953
Parents: 7e28fab
Author: MechCoder <[email protected]>
Authored: Wed Jul 6 02:54:44 2016 -0700
Committer: Yanbo Liang <[email protected]>
Committed: Wed Jul 6 02:54:44 2016 -0700

----------------------------------------------------------------------
 .../regression/DecisionTreeRegressorSuite.scala | 20 ++++++++++++++++++++
 .../apache/spark/ml/tree/impl/TreeTests.scala   | 12 ++++++++++++
 2 files changed, 32 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/909c6d81/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 9afb742..15fa26e 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
@@ -96,6 +97,25 @@ class DecisionTreeRegressorSuite
       assert(variance === expectedVariance,
         s"Expected variance $expectedVariance but got $variance.")
     }
+
+    val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+    val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 
0)
+    dt.setMaxDepth(1)
+      .setMaxBins(6)
+      .setSeed(0)
+    val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
+    val calculatedVariances = 
transformVarDF.select(dt.getVarianceCol).collect().map {
+      case Row(variance: Double) => variance
+    }
+
+    // Since max depth is set to 1, the best split separates features (0.0, 1.0, 2.0) from
+    // (3.0, 4.0, 5.0), i.e. labels (1.0, 2.0, 3.0) from (10.0, 12.0, 14.0). The predicted
+    // variance for each data point in the left node is 2/3 (~0.667) and for each data point
+    // in the right node is 8/3 (~2.667).
+    val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+    calculatedVariances.zip(expectedVariances).foreach { case (actual, 
expected) =>
+      assert(actual ~== expected absTol 1e-3)
+    }
   }
 
   test("Feature importance with toy data") {

http://git-wip-us.apache.org/repos/asf/spark/blob/909c6d81/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index d2fa8d0..c90cb8c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -183,6 +183,18 @@ private[ml] object TreeTests extends SparkFunSuite {
   ))
 
   /**
+   * Create some toy data for testing correctness of variance.
+   *
+   * Six points with a single continuous feature taking values 0.0 through 5.0. Points with
+   * features 0.0-2.0 have labels (1.0, 2.0, 3.0) and points with features 3.0-5.0 have
+   * labels (10.0, 12.0, 14.0), so a depth-1 regression tree is expected to split between
+   * features 2.0 and 3.0. The resulting leaf label variances are 2/3 and 8/3; these are
+   * small enough to verify by hand.
+   *
+   * @param sc SparkContext used to parallelize the points
+   * @return RDD of the six labeled points
+   */
+  def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq(
+    new LabeledPoint(1.0, Vectors.dense(Array(0.0))),
+    new LabeledPoint(2.0, Vectors.dense(Array(1.0))),
+    new LabeledPoint(3.0, Vectors.dense(Array(2.0))),
+    new LabeledPoint(10.0, Vectors.dense(Array(3.0))),
+    new LabeledPoint(12.0, Vectors.dense(Array(4.0))),
+    new LabeledPoint(14.0, Vectors.dense(Array(5.0)))
+  ))
+
+  /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as 
save/load.
    * This excludes input columns to simplify some tests.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to