Github user manishamde commented on a diff in the pull request:
https://github.com/apache/spark/pull/2125#discussion_r16866883
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
@@ -1395,96 +1027,24 @@ object DecisionTree extends Serializable with Logging {
Double.MinValue)
} else {
new Bin(
- splits(featureIndex)(index - 1),
- splits(featureIndex)(index),
+ splits(featureIndex)(splitIndex - 1),
+ splits(featureIndex)(splitIndex),
Categorical,
Double.MinValue)
}
}
- index += 1
- }
- } else { // ordered feature
- /* For a given categorical feature, use a subsample of the data
- * to choose how to arrange possible splits.
- * This examines each category and computes a centroid.
- * These centroids are later used to sort the possible splits.
- * centroidForCategories is a mapping: category (for the given feature) --> centroid
- */
- val centroidForCategories = {
- if (isMulticlass) {
- // For categorical variables in multiclass
classification,
- // each bin is a category. The bins are sorted and they
- // are ordered by calculating the impurity of their
corresponding labels.
- sampledInput.map(lp => (lp.features(featureIndex),
lp.label))
- .groupBy(_._1)
- .mapValues(x => x.groupBy(_._2).mapValues(x =>
x.size.toDouble))
- .map(x => (x._1, x._2.values.toArray))
- .map(x => (x._1, metadata.impurity.calculate(x._2,
x._2.sum)))
- } else { // regression or binary classification
- // For categorical variables in regression and binary
classification,
- // each bin is a category. The bins are sorted and they
- // are ordered by calculating the centroid of their
corresponding labels.
- sampledInput.map(lp => (lp.features(featureIndex),
lp.label))
- .groupBy(_._1)
- .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
- }
- }
-
- logDebug("centroid for categories = " +
centroidForCategories.mkString(","))
-
- // Check for missing categorical variables and putting them
last in the sorted list.
- val fullCentroidForCategories =
scala.collection.mutable.Map[Double,Double]()
- for (i <- 0 until featureCategories) {
- if (centroidForCategories.contains(i)) {
- fullCentroidForCategories(i) = centroidForCategories(i)
- } else {
- fullCentroidForCategories(i) = Double.MaxValue
- }
- }
-
- // bins sorted by centroids
- val categoriesSortedByCentroid =
fullCentroidForCategories.toList.sortBy(_._2)
-
- logDebug("centroid for categorical variable = " +
categoriesSortedByCentroid)
-
- var categoriesForSplit = List[Double]()
- categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
- case ((key, value), index) =>
- categoriesForSplit = key :: categoriesForSplit
- splits(featureIndex)(index) = new Split(featureIndex,
Double.MinValue,
- Categorical, categoriesForSplit)
- bins(featureIndex)(index) = {
- if (index == 0) {
- new Bin(new DummyCategoricalSplit(featureIndex,
Categorical),
- splits(featureIndex)(0), Categorical, key)
- } else {
- new Bin(splits(featureIndex)(index-1),
splits(featureIndex)(index),
- Categorical, key)
- }
- }
+ splitIndex += 1
}
+ } else {
+ // Ordered features: high-arity features, or not multiclass classification
--- End diff --
It is unclear what "or not multiclass classification" means in this comment.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it to be, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]