Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1975#discussion_r16322133
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -556,135 +568,98 @@ object DecisionTree extends Serializable with Logging {
         // shift when more than one group is used at deep tree level
         val groupShift = numNodes * groupIndex
     
    -    /** Find the filters used before reaching the current code. */
    -    def findParentFilters(nodeIndex: Int): List[Filter] = {
    -      if (level == 0) {
    -        List[Filter]()
    -      } else {
    -        val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + 
groupShift
    -        filters(nodeFilterIndex)
    -      }
    -    }
    -
         /**
    -     * Find whether the sample is valid input for the current node, i.e., 
whether it passes through
    -     * all the filters for the current node.
    +     * Get the node index corresponding to this data point.
    +     * This is used during training, mimicking prediction.
    +     * @return  Leaf index if the data point reaches a leaf.
    +     *          Otherwise, last node reachable in tree matching this 
example.
          */
    -    def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): 
Boolean = {
    -      // leaf
    -      if ((level > 0) && (parentFilters.length == 0)) {
    -        return false
    -      }
    -
    -      // Apply each filter and check sample validity. Return false when 
invalid condition found.
    -      parentFilters.foreach { filter =>
    -        val featureIndex = filter.split.feature
    -        val comparison = filter.comparison
    -        val isFeatureContinuous = filter.split.featureType == Continuous
    -        if (isFeatureContinuous) {
    -          val binId = treePoint.binnedFeatures(featureIndex)
    -          val bin = bins(featureIndex)(binId)
    -          val featureValue = bin.highSplit.threshold
    -          val threshold = filter.split.threshold
    -          comparison match {
    -            case -1 => if (featureValue > threshold) return false
    -            case 1 => if (featureValue <= threshold) return false
    +    def predictNodeIndex(node: Node, features: Array[Int]): Int = {
    +      if (node.isLeaf) {
    +        node.id
    +      } else {
    +        val featureIndex = node.split.get.feature
    +        val splitLeft = node.split.get.featureType match {
    +          case Continuous => {
    +            val binIndex = features(featureIndex)
    +            val featureValueUpperBound = 
bins(featureIndex)(binIndex).highSplit.threshold
    +            // bin binIndex has range (bin.lowSplit.threshold, 
bin.highSplit.threshold]
    +            // We do not need to check lowSplit since bins are separated 
by splits.
    +            featureValueUpperBound <= node.split.get.threshold
               }
    -        } else {
    -          val numFeatureCategories = 
strategy.categoricalFeaturesInfo(featureIndex)
    -          val isSpaceSufficientForAllCategoricalSplits =
    -            numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
    -          val isUnorderedFeature =
    -            isMulticlassClassification && 
isSpaceSufficientForAllCategoricalSplits
    -          val featureValue = if (isUnorderedFeature) {
    -            treePoint.binnedFeatures(featureIndex)
    +          case Categorical => {
    +            val featureValue = if 
(unorderedFeatures.contains(featureIndex)) {
    +                features(featureIndex)
    +              } else {
    +                val binIndex = features(featureIndex)
    +                bins(featureIndex)(binIndex).category
    +              }
    +            node.split.get.categories.contains(featureValue)
    +          }
    +          case _ => throw new RuntimeException(s"predictNodeIndex failed 
for unknown reason.")
    +        }
    +        if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
    +          // Return index from next layer of nodes to train
    +          if (splitLeft) {
    +            node.id * 2 + 1 // left
               } else {
    -            val binId = treePoint.binnedFeatures(featureIndex)
    -            bins(featureIndex)(binId).category
    +            node.id * 2 + 2 // right
               }
    -          val containsFeature = 
filter.split.categories.contains(featureValue)
    -          comparison match {
    -            case -1 => if (!containsFeature) return false
    -            case 1 => if (containsFeature) return false
    +        } else {
    +          if (splitLeft) {
    +            predictNodeIndex(node.leftNode.get, features)
    +          } else {
    +            predictNodeIndex(node.rightNode.get, features)
               }
             }
           }
    +    }
     
    -      // Return true when the sample is valid for all filters.
    -      true
    +    def nodeIndexToLevel(idx: Int): Int = {
    +      if (idx == 0) {
    +        0
    +      } else {
    +        math.floor(math.log(idx) / math.log(2)).toInt
    --- End diff --
    
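    If we use 1-indexing (root = 1, children of node i at 2i and 2i + 1), the level is exactly
floor(log2(idx)), which can be computed as `31 - java.lang.Integer.numberOfLeadingZeros(idx)`
(the bit position of `java.lang.Integer.highestOneBit(idx)`); this is much faster than `math.log`.
Note that with the current 0-indexing (children at 2i + 1 and 2i + 2), `floor(log2(idx))` gives
the wrong level for left-most nodes, e.g. it returns 0 for node 1, which is at level 1.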


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to