GitHub user manishamde commented on a diff in the pull request:

    https://github.com/apache/spark/pull/2125#discussion_r16866883
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -1395,96 +1027,24 @@ object DecisionTree extends Serializable with 
Logging {
                           Double.MinValue)
                       } else {
                         new Bin(
    -                      splits(featureIndex)(index - 1),
    -                      splits(featureIndex)(index),
    +                      splits(featureIndex)(splitIndex - 1),
    +                      splits(featureIndex)(splitIndex),
                           Categorical,
                           Double.MinValue)
                       }
                     }
    -                index += 1
    -              }
    -            } else { // ordered feature
    -              /* For a given categorical feature, use a subsample of the 
data
    -               * to choose how to arrange possible splits.
    -               * This examines each category and computes a centroid.
    -               * These centroids are later used to sort the possible 
splits.
    -               * centroidForCategories is a mapping: category (for the 
given feature) --> centroid
    -               */
    -              val centroidForCategories = {
    -                if (isMulticlass) {
    -                  // For categorical variables in multiclass 
classification,
    -                  // each bin is a category. The bins are sorted and they
    -                  // are ordered by calculating the impurity of their 
corresponding labels.
    -                  sampledInput.map(lp => (lp.features(featureIndex), 
lp.label))
    -                   .groupBy(_._1)
    -                   .mapValues(x => x.groupBy(_._2).mapValues(x => 
x.size.toDouble))
    -                   .map(x => (x._1, x._2.values.toArray))
    -                   .map(x => (x._1, metadata.impurity.calculate(x._2, 
x._2.sum)))
    -                } else { // regression or binary classification
    -                  // For categorical variables in regression and binary 
classification,
    -                  // each bin is a category. The bins are sorted and they
    -                  // are ordered by calculating the centroid of their 
corresponding labels.
    -                  sampledInput.map(lp => (lp.features(featureIndex), 
lp.label))
    -                    .groupBy(_._1)
    -                    .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
    -                }
    -              }
    -
    -              logDebug("centroid for categories = " + 
centroidForCategories.mkString(","))
    -
    -              // Check for missing categorical variables and putting them 
last in the sorted list.
    -              val fullCentroidForCategories = 
scala.collection.mutable.Map[Double,Double]()
    -              for (i <- 0 until featureCategories) {
    -                if (centroidForCategories.contains(i)) {
    -                  fullCentroidForCategories(i) = centroidForCategories(i)
    -                } else {
    -                  fullCentroidForCategories(i) = Double.MaxValue
    -                }
    -              }
    -
    -              // bins sorted by centroids
    -              val categoriesSortedByCentroid = 
fullCentroidForCategories.toList.sortBy(_._2)
    -
    -              logDebug("centroid for categorical variable = " + 
categoriesSortedByCentroid)
    -
    -              var categoriesForSplit = List[Double]()
    -              categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
    -                case ((key, value), index) =>
    -                  categoriesForSplit = key :: categoriesForSplit
    -                  splits(featureIndex)(index) = new Split(featureIndex, 
Double.MinValue,
    -                    Categorical, categoriesForSplit)
    -                  bins(featureIndex)(index) = {
    -                    if (index == 0) {
    -                      new Bin(new DummyCategoricalSplit(featureIndex, 
Categorical),
    -                        splits(featureIndex)(0), Categorical, key)
    -                    } else {
    -                      new Bin(splits(featureIndex)(index-1), 
splits(featureIndex)(index),
    -                        Categorical, key)
    -                    }
    -                  }
    +                splitIndex += 1
                   }
    +            } else {
    +              // Ordered features: high-arity features, or not multiclass 
classification
    --- End diff --
    
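    It is unclear what "or not multiclass classification" means here. If it means that every categorical feature is treated as ordered for regression and binary classification (regardless of arity), the comment should say so explicitly.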


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to