Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/2063#discussion_r16503947
--- Diff: docs/mllib-decision-tree.md ---
@@ -114,35 +135,122 @@ perform classification using a decision tree using
Gini impurity as an impurity measure and a maximum tree depth of 5. The
training error is calculated to measure the algorithm accuracy.
<div class="codetabs">
+
<div data-lang="scala">
{% highlight scala %}
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.impurity.Gini
+import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file
-val data = sc.textFile("data/mllib/sample_tree_data.csv")
-val parsedData = data.map { line =>
- val parts = line.split(',').map(_.toDouble)
- LabeledPoint(parts(0), Vectors.dense(parts.tail))
-}
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
-// Run training algorithm to build the model
+// Train a DecisionTree model.
+// Empty categoricalFeaturesInfo indicates all features are continuous.
+val numClasses = 2
+val categoricalFeaturesInfo = Map[Int, Int]()
+val impurity = "gini"
val maxDepth = 5
-val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth)
+val maxBins = 100
+
+val model = DecisionTree.trainClassifier(data, numClasses,
categoricalFeaturesInfo, impurity,
+ maxDepth, maxBins)
-// Evaluate model on training examples and compute training error
-val labelAndPreds = parsedData.map { point =>
+// Evaluate model on training instances and compute training error
+val labelAndPreds = data.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
-val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble /
parsedData.count
+val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble /
data.count
println("Training Error = " + trainErr)
+println("Learned classification tree model:\n" + model)
+{% endhighlight %}
+</div>
+
+<div data-lang="java">
+{% highlight java %}
+import java.util.HashMap;
+import scala.Tuple2;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.DecisionTree;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+
+SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(),
datapath).toJavaRDD();
--- End diff --
We cached the binned features in training. But in this example, we visit
the raw features twice. Since the data is read from disk, caching it should
help.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]