GitHub user manishamde commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1975#discussion_r16328186
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -531,160 +526,121 @@ object DecisionTree extends Serializable with Logging {
     
         // numNodes:  Number of nodes in this (level of tree, group),
         //            where nodes at deeper (larger) levels may be divided 
into groups.
    -    val numNodes = math.pow(2, level).toInt / numGroups
    +    val numNodes = (1 << level) / numGroups
         logDebug("numNodes = " + numNodes)
     
         // Find the number of features by looking at the first sample.
    -    val numFeatures = input.first().binnedFeatures.size
    +    val numFeatures = metadata.numFeatures
         logDebug("numFeatures = " + numFeatures)
     
         // numBins:  Number of bins = 1 + number of possible splits
         val numBins = bins(0).length
         logDebug("numBins = " + numBins)
     
    -    val numClasses = strategy.numClassesForClassification
    +    val numClasses = metadata.numClasses
         logDebug("numClasses = " + numClasses)
     
    -    val isMulticlassClassification = strategy.isMulticlassClassification
    -    logDebug("isMulticlassClassification = " + isMulticlassClassification)
    +    val isMulticlass = metadata.isMulticlass
    +    logDebug("isMulticlass = " + isMulticlass)
     
    -    val isMulticlassClassificationWithCategoricalFeatures
    -      = strategy.isMulticlassWithCategoricalFeatures
    -    logDebug("isMultiClassWithCategoricalFeatures = " +
    -      isMulticlassClassificationWithCategoricalFeatures)
    +    val isMulticlassWithCategoricalFeatures = 
metadata.isMulticlassWithCategoricalFeatures
    +    logDebug("isMultiClassWithCategoricalFeatures = " + 
isMulticlassWithCategoricalFeatures)
     
         // shift when more than one group is used at deep tree level
         val groupShift = numNodes * groupIndex
     
    -    /** Find the filters used before reaching the current code. */
    -    def findParentFilters(nodeIndex: Int): List[Filter] = {
    -      if (level == 0) {
    -        List[Filter]()
    -      } else {
    -        val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + 
groupShift
    -        filters(nodeFilterIndex)
    -      }
    -    }
    -
         /**
    -     * Find whether the sample is valid input for the current node, i.e., 
whether it passes through
    -     * all the filters for the current node.
    +     * Get the node index corresponding to this data point.
    +     * This is used during training, mimicking prediction.
    +     * @return  Leaf index if the data point reaches a leaf.
    +     *          Otherwise, last node reachable in tree matching this 
example.
          */
    -    def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): 
Boolean = {
    -      // leaf
    -      if ((level > 0) && (parentFilters.length == 0)) {
    -        return false
    -      }
    -
    -      // Apply each filter and check sample validity. Return false when 
invalid condition found.
    -      parentFilters.foreach { filter =>
    -        val featureIndex = filter.split.feature
    -        val comparison = filter.comparison
    -        val isFeatureContinuous = filter.split.featureType == Continuous
    -        if (isFeatureContinuous) {
    -          val binId = treePoint.binnedFeatures(featureIndex)
    -          val bin = bins(featureIndex)(binId)
    -          val featureValue = bin.highSplit.threshold
    -          val threshold = filter.split.threshold
    -          comparison match {
    -            case -1 => if (featureValue > threshold) return false
    -            case 1 => if (featureValue <= threshold) return false
    +    def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
    +      if (node.isLeaf) {
    +        node.id
    +      } else {
    +        val featureIndex = node.split.get.feature
    +        val splitLeft = node.split.get.featureType match {
    +          case Continuous => {
    +            val binIndex = binnedFeatures(featureIndex)
    +            val featureValueUpperBound = 
bins(featureIndex)(binIndex).highSplit.threshold
    +            // bin binIndex has range (bin.lowSplit.threshold, 
bin.highSplit.threshold]
    +            // We do not need to check lowSplit since bins are separated 
by splits.
    +            featureValueUpperBound <= node.split.get.threshold
               }
    -        } else {
    -          val numFeatureCategories = 
strategy.categoricalFeaturesInfo(featureIndex)
    -          val isSpaceSufficientForAllCategoricalSplits =
    -            numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
    -          val isUnorderedFeature =
    -            isMulticlassClassification && 
isSpaceSufficientForAllCategoricalSplits
    -          val featureValue = if (isUnorderedFeature) {
    -            treePoint.binnedFeatures(featureIndex)
    +          case Categorical => {
    +            val featureValue = if (metadata.isUnordered(featureIndex)) {
    +                binnedFeatures(featureIndex)
    +              } else {
    +                val binIndex = binnedFeatures(featureIndex)
    +                bins(featureIndex)(binIndex).category
    +              }
    +            node.split.get.categories.contains(featureValue)
    +          }
    +          case _ => throw new RuntimeException(s"predictNodeIndex failed 
for unknown reason.")
    +        }
    +        if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
    +          // Return index from next layer of nodes to train
    +          if (splitLeft) {
    +            node.id * 2 + 1 // left
               } else {
    -            val binId = treePoint.binnedFeatures(featureIndex)
    -            bins(featureIndex)(binId).category
    +            node.id * 2 + 2 // right
               }
    -          val containsFeature = 
filter.split.categories.contains(featureValue)
    -          comparison match {
    -            case -1 => if (!containsFeature) return false
    -            case 1 => if (containsFeature) return false
    +        } else {
    +          if (splitLeft) {
    +            predictNodeIndex(node.leftNode.get, binnedFeatures)
    +          } else {
    +            predictNodeIndex(node.rightNode.get, binnedFeatures)
               }
             }
           }
    +    }
     
    -      // Return true when the sample is valid for all filters.
    -      true
    +    def nodeIndexToLevel(idx: Int): Int = {
    +      if (idx == 0) {
    +        0
    +      } else {
    +        math.floor(math.log(idx) / math.log(2)).toInt
    +      }
         }
     
    +    // Used for treePointToNodeIndex
    +    val levelOffset = (1 << level) - 1
    +
         /**
    -     * Finds bins for all nodes (and all features) at a given level.
    -     * For l nodes, k features the storage is as follows:
    -     * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. 
, b_lk,
    -     * where b_ij is an integer between 0 and numBins - 1 for regressions 
and binary
    -     * classification and the categorical feature value in  multiclass 
classification.
    -     * Invalid sample is denoted by noting bin for feature 1 as -1.
    -     *
    -     * For unordered features, the "bin index" returned is actually the 
feature value (category).
    -     *
    -     * @return  Array of size 1 + numFeatures * numNodes, where
    -     *          arr(0) = label for labeledPoint, and
    -     *          arr(1 + numFeatures * nodeIndex + featureIndex) =
    -     *            bin index for this labeledPoint
    -     *            (or InvalidBinIndex if labeledPoint is not handled by 
this node)
    +     * Find the node (indexed from 0 at the start of this level) for the 
given example.
    --- End diff --
    
    Maybe the comment should reflect that the indexing starts from the groupShift offset at a given level, rather than from the start of the level.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to