Github user manishamde commented on a diff in the pull request:

    https://github.com/apache/spark/pull/2125#discussion_r16867878
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -619,661 +662,258 @@ object DecisionTree extends Serializable with 
Logging {
           if (level == 0) {
             0
           } else {
    -        val globalNodeIndex = predictNodeIndex(nodes(0), 
treePoint.binnedFeatures)
    -        // Get index for this (level, group).
    -        globalNodeIndex - levelOffset - groupShift
    -      }
    -    }
    -
    -    /**
    -     * Increment aggregate in location for (node, feature, bin, label).
    -     *
    -     * @param treePoint  Data point being aggregated.
    -     * @param agg  Array storing aggregate calculation, of size:
    -     *             numClasses * numBins * numFeatures * numNodes.
    -     *             Indexed by (node, feature, bin, label) where label is 
the least significant bit.
    -     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 
at start of (level, group).
    -     */
    -    def updateBinForOrderedFeature(
    -        treePoint: TreePoint,
    -        agg: Array[Double],
    -        nodeIndex: Int,
    -        featureIndex: Int): Unit = {
    -      // Update the left or right count for one bin.
    -      val aggIndex =
    -        numClasses * numBins * numFeatures * nodeIndex +
    -        numClasses * numBins * featureIndex +
    -        numClasses * treePoint.binnedFeatures(featureIndex) +
    -        treePoint.label.toInt
    -      agg(aggIndex) += 1
    -    }
    -
    -    /**
    -     * Increment aggregate in location for (nodeIndex, featureIndex, 
[bins], label),
    -     * where [bins] ranges over all bins.
    -     * Updates left or right side of aggregate depending on split.
    -     *
    -     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 
at start of (level, group).
    -     * @param treePoint  Data point being aggregated.
    -     * @param agg  Indexed by (left/right, node, feature, bin, label)
    -     *             where label is the least significant bit.
    -     *             The left/right specifier is a 0/1 index indicating 
left/right child info.
    -     * @param rightChildShift Offset for right side of agg.
    -     */
    -    def updateBinForUnorderedFeature(
    -        nodeIndex: Int,
    -        featureIndex: Int,
    -        treePoint: TreePoint,
    -        agg: Array[Double],
    -        rightChildShift: Int): Unit = {
    -      val featureValue = treePoint.binnedFeatures(featureIndex)
    -      // Update the left or right count for one bin.
    -      val aggShift =
    -        numClasses * numBins * numFeatures * nodeIndex +
    -        numClasses * numBins * featureIndex +
    -        treePoint.label.toInt
    -      // Find all matching bins and increment their values
    -      val featureCategories = metadata.featureArity(featureIndex)
    -      val numCategoricalBins = (1 << featureCategories - 1) - 1
    -      var binIndex = 0
    -      while (binIndex < numCategoricalBins) {
    -        val aggIndex = aggShift + binIndex * numClasses
    -        if 
(bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
    -          agg(aggIndex) += 1
    -        } else {
    -          agg(rightChildShift + aggIndex) += 1
    -        }
    -        binIndex += 1
    -      }
    -    }
    -
    -    /**
    -     * Helper for binSeqOp.
    -     *
    -     * @param agg  Array storing aggregate calculation, of size:
    -     *             numClasses * numBins * numFeatures * numNodes.
    -     *             Indexed by (node, feature, bin, label) where label is 
the least significant bit.
    -     * @param treePoint  Data point being aggregated.
    -     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 
at start of (level, group).
    -     */
    -    def binaryOrNotCategoricalBinSeqOp(
    -        agg: Array[Double],
    -        treePoint: TreePoint,
    -        nodeIndex: Int): Unit = {
    -      // Iterate over all features.
    -      var featureIndex = 0
    -      while (featureIndex < numFeatures) {
    -        updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
    -        featureIndex += 1
    -      }
    -    }
    -
    -    val rightChildShift = numClasses * numBins * numFeatures * numNodes
    -
    -    /**
    -     * Helper for binSeqOp.
    -     *
    -     * @param agg  Array storing aggregate calculation.
    -     *             For ordered features, this is of size:
    -     *               numClasses * numBins * numFeatures * numNodes.
    -     *             For unordered features, this is of size:
    -     *               2 * numClasses * numBins * numFeatures * numNodes.
    -     * @param treePoint   Data point being aggregated.
    -     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 
at start of (level, group).
    -     */
    -    def multiclassWithCategoricalBinSeqOp(
    -        agg: Array[Double],
    -        treePoint: TreePoint,
    -        nodeIndex: Int): Unit = {
    -      val label = treePoint.label
    -      // Iterate over all features.
    -      var featureIndex = 0
    -      while (featureIndex < numFeatures) {
    -        if (metadata.isUnordered(featureIndex)) {
    -          updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, 
agg, rightChildShift)
    -        } else {
    -          updateBinForOrderedFeature(treePoint, agg, nodeIndex, 
featureIndex)
    -        }
    -        featureIndex += 1
    -      }
    -    }
    -
    -    /**
    -     * Performs a sequential aggregation over a partition for regression.
    -     * For l nodes, k features,
    -     * the count, sum, sum of squares of one of the p bins is incremented.
    -     *
    -     * @param agg Array storing aggregate calculation, updated by this 
function.
    -     *            Size: 3 * numBins * numFeatures * numNodes
    -     * @param treePoint   Data point being aggregated.
    -     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 
at start of (level, group).
    -     * @return agg
    -     */
    -    def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, 
nodeIndex: Int): Unit = {
    -      val label = treePoint.label
    -      // Iterate over all features.
    -      var featureIndex = 0
    -      while (featureIndex < numFeatures) {
    -        // Update count, sum, and sum^2 for one bin.
    -        val binIndex = treePoint.binnedFeatures(featureIndex)
    -        val aggIndex =
    -          3 * numBins * numFeatures * nodeIndex +
    -          3 * numBins * featureIndex +
    -          3 * binIndex
    -        agg(aggIndex) += 1
    -        agg(aggIndex + 1) += label
    -        agg(aggIndex + 2) += label * label
    -        featureIndex += 1
    +        val globalNodeIndex =
    +          predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, 
metadata.unorderedFeatures)
    +        globalNodeIndex - globalNodeIndexOffset
           }
         }
     
         /**
          * Performs a sequential aggregation over a partition.
    -     * For l nodes, k features,
    -     *   For classification:
    -     *     Either the left count or the right count of one of the bins is
    -     *     incremented based upon whether the feature is classified as 0 
or 1.
    -     *   For regression:
    -     *     The count, sum, sum of squares of one of the bins is 
incremented.
          *
    -     * @param agg Array storing aggregate calculation, updated by this 
function.
    -     *            Size for classification:
    -     *              numClasses * numBins * numFeatures * numNodes for 
ordered features, or
    -     *              2 * numClasses * numBins * numFeatures * numNodes for 
unordered features.
    -     *            Size for regression:
    -     *              3 * numBins * numFeatures * numNodes.
    +     * Each data point contributes to one node. For each feature,
    +     * the aggregate sufficient statistics are updated for the relevant 
bins.
    +     *
    +     * @param agg  Array storing aggregate calculation, with a set of 
sufficient statistics for
    +     *             each (node, feature, bin).
          * @param treePoint   Data point being aggregated.
          * @return  agg
          */
    -    def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] 
= {
    +    def binSeqOp(
    +        agg: DTStatsAggregator,
    +        treePoint: TreePoint): DTStatsAggregator = {
           val nodeIndex = treePointToNodeIndex(treePoint)
           // If the example does not reach this level, then nodeIndex < 0.
           // If the example reaches this level but is handled in a different 
group,
           //  then either nodeIndex < 0 (previous group) or nodeIndex >= 
numNodes (later group).
           if (nodeIndex >= 0 && nodeIndex < numNodes) {
    -        if (metadata.isClassification) {
    -          if (isMulticlassWithCategoricalFeatures) {
    -            multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
    -          } else {
    -            binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
    -          }
    +        if (metadata.unorderedFeatures.isEmpty) {
    +          orderedBinSeqOp(agg, treePoint, nodeIndex)
             } else {
    -          regressionBinSeqOp(agg, treePoint, nodeIndex)
    +          someUnorderedBinSeqOp(agg, treePoint, nodeIndex, bins, 
metadata.unorderedFeatures)
    --- End diff --
    
    Suggest `mixed` or `mixedOrderedUnordered` instead of `someUnordered`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to