Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/1582#discussion_r15539977
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
@@ -19,48 +19,60 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.DatasetInfo
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.DTParams
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.configuration.QuantileStrategies
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
+
/**
* :: Experimental ::
- * A class that implements a decision tree algorithm for classification and regression. It
- * supports both continuous and categorical features.
- * @param strategy The configuration parameters for the tree algorithm which specify the type
- *                 of algorithm (classification, regression, etc.), feature type (continuous,
- *                 categorical), depth of the tree, quantile calculation strategy, etc.
+ * An abstract class for decision tree algorithms for classification and regression.
+ * It supports both continuous and categorical features.
+ * @param params The configuration parameters for the tree algorithm.
*/
@Experimental
-class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
+private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTParams)
+ extends Serializable with Logging {
- /**
+ protected final val InvalidBinIndex = -1
+
+ // depth of the decision tree
+ protected val maxDepth: Int = params.maxDepth
+
+ protected val maxBins: Int = params.maxBins
+
+ protected val quantileStrategy: QuantileStrategy.QuantileStrategy =
+ QuantileStrategies.strategy(params.quantileStrategy)
+
+ protected val maxMemoryInMB: Int = params.maxMemoryInMB
+
+ /**
* Method to train a decision tree model over an RDD
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
- * @return a DecisionTreeModel that can be used for prediction
+ * @param datasetInfo Dataset metadata.
+ * @return top node of a DecisionTreeModel
*/
- def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+ protected def trainSub(
--- End diff --
Not obvious, IMO. @mengxr and I talked and decided to call the learning
functions "run()" in the class and "train()" in the object. When I tried
naming it "train()" in both the class and object, it caused name conflicts when
calling code from Java. I added the suffix "Sub" here because the code returns
a node instead of a model, so it does not fit the "run" or "train" signature.
I am renaming it to runSub() since it is called from run().
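
For reference, here is a minimal sketch of that naming scheme, using hypothetical stand-in types (`SimpleTree`, `Model`, `Node`, `Params`) rather than the real MLlib classes:

```scala
// Hypothetical sketch of the run()/train()/runSub() naming convention;
// these types are simplified stand-ins, not the actual MLlib signatures.
class Node                        // stand-in for the learned tree's top node
class Model(val topNode: Node)    // stand-in for DecisionTreeModel
class Params                      // stand-in for DTParams

class SimpleTree(params: Params) {
  // Public instance entry point: builds and returns a full model.
  def run(input: Seq[Double]): Model = new Model(runSub(input))

  // Internal helper called from run(); it returns only the top node,
  // so it does not fit the run()/train() signature and gets its own name.
  protected def runSub(input: Seq[Double]): Node = new Node
}

object SimpleTree {
  // Static-style entry point for users; delegates to the instance run().
  // Keeping train() only on the object avoids the Java name conflicts
  // mentioned above.
  def train(input: Seq[Double], params: Params): Model =
    new SimpleTree(params).run(input)
}
```

With this split, Java callers see exactly one static-style train() and one instance run(), so overload resolution stays unambiguous.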