Repository: spark
Updated Branches:
  refs/heads/branch-1.1 162fc9512 -> eaa93555a


[SPARK-2197] [mllib] Java DecisionTree bug fix and ease of use

Bug fix: Previously, when an RDD was created in Java and passed to
DecisionTree.train(), the fake ClassTag attached by the Java API caused problems.
* Fix: DecisionTree: Used the new RDD.retag() method to allow passing RDDs from
Java.
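
A plain-Java analogy, not Spark code: a Scala ClassTag plays a role similar to a `Class<?>` object that is used to allocate correctly typed arrays at runtime. The Java API has no real element type to supply, so it passes a placeholder for `Object`. Code that allocates arrays from that tag can then produce an `Object[]` where an array of `LabeledPoint` is expected, and the cast fails. `retag()` avoids this by re-attaching the real class. The class `RetagAnalogy` and its helpers are hypothetical, and `String` stands in for `LabeledPoint`.

```java
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.List;

class RetagAnalogy {
    // Hypothetical helper: allocate an array whose runtime component type is
    // given by `tag`, similar to how a ClassTag drives array creation in Scala.
    static Object[] allocate(Class<?> tag, List<?> elems) {
        Object[] out = (Object[]) Array.newInstance(tag, elems.size());
        for (int i = 0; i < elems.size(); i++) {
            out[i] = elems.get(i);
        }
        return out;
    }

    // Returns true if the array can be used as a String[] (the "real" element type).
    static boolean castsToStringArray(Object[] arr) {
        try {
            String[] typed = (String[]) arr; // array casts are checked at runtime
            return typed != null;
        } catch (ClassCastException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        List<String> data = Arrays.asList("a", "b");
        // "Fake" tag: only Object is known, like the Java API's placeholder.
        Object[] fake = allocate(Object.class, data);
        // "Retagged": the real element class is supplied explicitly.
        Object[] real = allocate(String.class, data);
        System.out.println(castsToStringArray(fake)); // false
        System.out.println(castsToStringArray(real)); // true
    }
}
```

The elements are identical in both arrays; only the runtime array type differs, which is why fixing the tag rather than the data resolves the failure.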

Other improvements to Decision Trees for ease of use from Java:
* impurity classes: Added instance() methods to help with Java interface.
* Strategy: Added Java-friendly constructor
--> Note: I removed quantileCalculationStrategy from the Java-friendly 
constructor since (a) it is a special class and (b) there is only 1 option 
currently.  I suspect we will redo the API before the other options are 
included.
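
Java callers cannot use Scala default arguments, and they cannot easily pass a Scala `Map[Int, Int]`. The Java-friendly constructor therefore accepts a `java.util.Map` and fills in the quantile strategy itself. The overloading pattern can be sketched in plain Java. `StrategySketch` and `QuantileStrategy` are hypothetical stand-ins, not the MLlib classes, and the impurity parameter is omitted for brevity.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class StrategySketch {
    final String algo;
    final int maxDepth;
    final int numClasses;
    final int maxBins;
    final QuantileStrategy quantileStrategy;
    final Map<Integer, Integer> categoricalFeaturesInfo;

    // "Full" constructor, analogous to the primary Scala constructor.
    StrategySketch(String algo, int maxDepth, int numClasses, int maxBins,
                   QuantileStrategy quantileStrategy,
                   Map<Integer, Integer> categoricalFeaturesInfo) {
        this.algo = algo;
        this.maxDepth = maxDepth;
        this.numClasses = numClasses;
        this.maxBins = maxBins;
        this.quantileStrategy = quantileStrategy;
        // Read-only copy, similar in spirit to `.asScala.toMap` building an
        // immutable Scala map: later changes by the caller have no effect.
        this.categoricalFeaturesInfo =
            Collections.unmodifiableMap(new HashMap<>(categoricalFeaturesInfo));
    }

    // Convenience constructor that omits the quantile strategy and fills in
    // the only available option, mirroring the commit's Java-friendly constructor.
    StrategySketch(String algo, int maxDepth, int numClasses, int maxBins,
                   Map<Integer, Integer> categoricalFeaturesInfo) {
        this(algo, maxDepth, numClasses, maxBins, QuantileStrategy.SORT,
             categoricalFeaturesInfo);
    }

    public static void main(String[] args) {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(1, 2); // feature 1 has 2 categories
        StrategySketch s = new StrategySketch("Classification", 4, 2, 100, info);
        System.out.println(s.quantileStrategy);               // SORT
        System.out.println(s.categoricalFeaturesInfo.get(1)); // 2
    }
}

// Hypothetical stand-in for the quantile calculation strategy; the commit
// notes that only one option (Sort) currently exists.
enum QuantileStrategy { SORT }
```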

CC: mengxr

Author: Joseph K. Bradley <[email protected]>

Closes #1740 from jkbradley/dt-java-new and squashes the following commits:

0805dc6 [Joseph K. Bradley] Changed Strategy to use JavaConverters instead of 
JavaConversions
519b1b7 [Joseph K. Bradley] * Organized imports in JavaDecisionTreeSuite.java * 
Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala
f7b5ca1 [Joseph K. Bradley] Improvements to make it easier to run DecisionTree 
from Java. * DecisionTree: Used new RDD.retag() method to allow passing RDDs 
from Java. * impurity classes: Added instance() methods to help with Java 
interface. * Strategy: Added Java-friendly constructor ** Note: I removed 
quantileCalculationStrategy from the Java-friendly constructor since (a) it is 
a special class and (b) there is only 1 option currently.  I suspect we will 
redo the API before the other options are included.
d78ada6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
dt-java
320853f [Joseph K. Bradley] Added JavaDecisionTreeSuite, partly written
13a585e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
dt-java
f1a8283 [Joseph K. Bradley] Added old JavaDecisionTreeSuite, to be updated later
225822f [Joseph K. Bradley] Bug: In DecisionTree, the method 
sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins 
from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is 
the bound for unordered categorical features, not ordered ones. The upper bound 
should be the arity (i.e., max value) of the feature.
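
To make the bound concrete: for a categorical feature with k categories, an unordered split chooses a subset of categories, giving 2^(k-1) - 1 distinct nontrivial binary partitions, whereas the commit states that an ordered feature should only be indexed up to its arity k. The sketch below uses hypothetical helper names and simply compares the two bounds.

```java
class CategoricalBinBounds {
    // Bound for an UNORDERED categorical feature with `arity` categories:
    // 2^(arity - 1) - 1 distinct binary partitions (valid for arity <= 31 with int).
    static int unorderedUpperBound(int arity) {
        return (1 << (arity - 1)) - 1;
    }

    // Bound for an ORDERED categorical feature: per the commit message,
    // the arity (number of category values) of the feature.
    static int orderedUpperBound(int arity) {
        return arity;
    }

    public static void main(String[] args) {
        for (int arity : new int[] {2, 3, 4, 10}) {
            System.out.println("arity=" + arity
                + " unordered=" + unorderedUpperBound(arity)
                + " ordered=" + orderedUpperBound(arity));
        }
    }
}
```

The two bounds coincide at arity 3, which can hide the bug in small tests; at arity 10 the unordered bound is 511 while only 10 bins exist.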

(cherry picked from commit 2998e38a942351974da36cb619e863c6f0316e7a)
Signed-off-by: Xiangrui Meng <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eaa93555
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eaa93555
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eaa93555

Branch: refs/heads/branch-1.1
Commit: eaa93555a7f935b00a2f94a7fa50a12e11578bd7
Parents: 162fc95
Author: Joseph K. Bradley <[email protected]>
Authored: Sun Aug 3 10:36:52 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Sun Aug 3 10:37:05 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  |   8 +-
 .../mllib/tree/configuration/Strategy.scala     |  29 ++++++
 .../spark/mllib/tree/impurity/Entropy.scala     |   7 ++
 .../apache/spark/mllib/tree/impurity/Gini.scala |   7 ++
 .../spark/mllib/tree/impurity/Variance.scala    |   7 ++
 .../spark/mllib/tree/JavaDecisionTreeSuite.java | 102 +++++++++++++++++++
 .../spark/mllib/tree/DecisionTreeSuite.scala    |   6 ++
 7 files changed, 162 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 382e76a..1d03e6e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
     // Cache input RDD for speedup during multiple passes.
-    input.cache()
+    val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
 
@@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
     // num features
-    val numFeatures = input.take(1)(0).features.size
+    val numFeatures = retaggedInput.take(1)(0).features.size
 
     // Calculate level for single group construction
 
@@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+      val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
         strategy, level, filters, splits, bins, maxLevelForSingleGroup)
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {

http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index fdad4f0..4ee4bcd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree.configuration
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -61,4 +63,31 @@ class Strategy (
   val isMulticlassWithCategoricalFeatures
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
+  /**
+   * Java-friendly constructor.
+   *
+   * @param algo classification or regression
+   * @param impurity criterion used for information gain calculation
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * @param numClassesForClassification number of classes for classification. Default value is 2
+   *                                    leads to binary classification
+   * @param maxBins maximum number of bins used for splitting features
+   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+   *                                the number of discrete values they take. For example, an entry
+   *                                (n -> k) implies the feature n is categorical with k categories
+   *                                0, 1, 2, ... , k-1. It's important to note that features are
+   *                                zero-indexed.
+   */
+  def this(
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      numClassesForClassification: Int,
+      maxBins: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
+    this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 9297c20..96d2471 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -66,4 +66,11 @@ object Entropy extends Impurity {
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Entropy.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 2874bcf..d586f44 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -62,4 +62,11 @@ object Gini extends Impurity {
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Gini.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 698a1a2..f7d99a4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -53,4 +53,11 @@ object Variance extends Impurity {
     val squaredLoss = sumSquares - (sum * sum) / count
     squaredLoss / count
   }
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
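
A Scala `object` compiles to a singleton, and `def instance = this` gives Java callers an ordinary accessor, which is how the new JavaDecisionTreeSuite obtains `Gini.instance()`. The pattern, combined with the variance formula from the hunk above, can be sketched in Java. `ImpurityMeasure` and `VarianceSketch` are hypothetical names, not MLlib types.

```java
// Hypothetical singleton mirroring a Scala `object` with an instance() accessor.
final class VarianceSketch implements ImpurityMeasure {
    private static final VarianceSketch INSTANCE = new VarianceSketch();

    private VarianceSketch() {}

    // Java-friendly accessor, analogous to `def instance = this`.
    static VarianceSketch instance() {
        return INSTANCE;
    }

    // Same arithmetic as Variance.calculate in the diff:
    // (sumSquares - sum^2 / count) / count, i.e. the mean squared deviation.
    @Override
    public double calculate(double count, double sum, double sumSquares) {
        double squaredLoss = sumSquares - (sum * sum) / count;
        return squaredLoss / count;
    }

    public static void main(String[] args) {
        // Labels 1, 2, 3: count = 3, sum = 6, sumSquares = 14.
        ImpurityMeasure impurity = VarianceSketch.instance();
        System.out.println(impurity.calculate(3, 6, 14));
        System.out.println(VarianceSketch.instance() == impurity); // true
    }
}

// Hypothetical interface standing in for MLlib's Impurity trait.
interface ImpurityMeasure {
    double calculate(double count, double sum, double sumSquares);
}
```

For the labels 1, 2, 3 the squared loss is 14 - 36/3 = 2, so the impurity is 2/3.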

http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
new file mode 100644
index 0000000..2c281a1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.configuration.Algo;
+import org.apache.spark.mllib.tree.configuration.Strategy;
+import org.apache.spark.mllib.tree.impurity.Gini;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+
+
+public class JavaDecisionTreeSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
+    int numCorrect = 0;
+    for (LabeledPoint point: validationData) {
+      Double prediction = model.predict(point.features());
+      if (prediction == point.label()) {
+        numCorrect++;
+      }
+    }
+    return numCorrect;
+  }
+
+  @Test
+  public void runDTUsingConstructor() {
+    List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+    categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+    int maxDepth = 4;
+    int numClasses = 2;
+    int maxBins = 100;
+    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+        maxBins, categoricalFeaturesInfo);
+
+    DecisionTree learner = new DecisionTree(strategy);
+    DecisionTreeModel model = learner.train(rdd.rdd());
+
+    int numCorrect = validatePrediction(arr, model);
+    Assert.assertTrue(numCorrect == rdd.count());
+  }
+
+  @Test
+  public void runDTUsingStaticMethods() {
+    List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+    categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+    int maxDepth = 4;
+    int numClasses = 2;
+    int maxBins = 100;
+    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+        maxBins, categoricalFeaturesInfo);
+
+    DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
+
+    int numCorrect = validatePrediction(arr, model);
+    Assert.assertTrue(numCorrect == rdd.count());
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/eaa93555/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8665a00..70ca7c8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree
 
+import scala.collection.JavaConverters._
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
@@ -815,6 +817,10 @@ object DecisionTreeSuite {
     arr
   }
 
+  def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = {
+    generateCategoricalDataPoints().toList.asJava
+  }
+
   def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {

