Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1975#discussion_r16327359
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -905,144 +821,140 @@ object DecisionTree extends Serializable with Logging {
         }
     
         // Calculate bin aggregates.
    +    timer.start("aggregation")
         val binAggregates = {
    -      binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
    +      input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
         }
         timer.stop("aggregation")
         logDebug("binAggregates.length = " + binAggregates.length)
     
         /**
    -     * Calculates the information gain for all splits based upon left/right split aggregates.
    -     * @param leftNodeAgg left node aggregates
    -     * @param featureIndex feature index
    -     * @param splitIndex split index
    -     * @param rightNodeAgg right node aggregate
    +     * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
    +     * @param leftNodeAgg left node aggregates for this (feature, split)
    +     * @param rightNodeAgg right node aggregate for this (feature, split)
          * @param topImpurity impurity of the parent node
          * @return information gain and statistics for all splits
          */
         def calculateGainForSplit(
    -        leftNodeAgg: Array[Array[Array[Double]]],
    -        featureIndex: Int,
    -        splitIndex: Int,
    -        rightNodeAgg: Array[Array[Array[Double]]],
    +        leftNodeAgg: Array[Double],
    +        rightNodeAgg: Array[Double],
             topImpurity: Double): InformationGainStats = {
    -      strategy.algo match {
    -        case Classification =>
    -          val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex)
    -          val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex)
    -          val leftTotalCount = leftCounts.sum
    -          val rightTotalCount = rightCounts.sum
    -
    -          val impurity = {
    -            if (level > 0) {
    -              topImpurity
    -            } else {
    -              // Calculate impurity for root node.
    -              val rootNodeCounts = new Array[Double](numClasses)
    -              var classIndex = 0
    -              while (classIndex < numClasses) {
    -                rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
    -                classIndex += 1
    -              }
    -              strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
    -            }
    -          }
    +      if (metadata.isClassification) {
    +        val leftTotalCount = leftNodeAgg.sum
    +        val rightTotalCount = rightNodeAgg.sum
     
    -          val totalCount = leftTotalCount + rightTotalCount
    -          if (totalCount == 0) {
    -            // Return arbitrary prediction.
    -            return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
    +        val impurity = {
    +          if (level > 0) {
    +            topImpurity
    +          } else {
    +            // Calculate impurity for root node.
    +            val rootNodeCounts = new Array[Double](numClasses)
    +            var classIndex = 0
    +            while (classIndex < numClasses) {
    +              rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
    +              classIndex += 1
    +            }
    +            metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
               }
    +        }
     
    -          // Sum of count for each label
    -          val leftRightCounts: Array[Double] =
    -            leftCounts.zip(rightCounts).map { case (leftCount, rightCount) =>
    -              leftCount + rightCount
    -            }
    +        val totalCount = leftTotalCount + rightTotalCount
    +        if (totalCount == 0) {
    +          // Return arbitrary prediction.
    +          return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
    +        }
     
    -          def indexOfLargestArrayElement(array: Array[Double]): Int = {
    -            val result = array.foldLeft(-1, Double.MinValue, 0) {
    -              case ((maxIndex, maxValue, currentIndex), currentValue) =>
    -                if (currentValue > maxValue) {
    -                  (currentIndex, currentValue, currentIndex + 1)
    -                } else {
    -                  (maxIndex, maxValue, currentIndex + 1)
    -                }
    -            }
    -            if (result._1 < 0) {
    -              throw new RuntimeException("DecisionTree internal error:" +
    -                " calculateGainForSplit failed in indexOfLargestArrayElement")
    -            }
    -            result._1
    +        // Sum of count for each label
    +        val leftrightNodeAgg: Array[Double] =
    +          leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
    +            leftCount + rightCount
               }
     
    -          val predict = indexOfLargestArrayElement(leftRightCounts)
    -          val prob = leftRightCounts(predict) / totalCount
    -
    -          val leftImpurity = if (leftTotalCount == 0) {
    -            topImpurity
    -          } else {
    -            strategy.impurity.calculate(leftCounts, leftTotalCount)
    +        def indexOfLargestArrayElement(array: Array[Double]): Int = {
    +          val result = array.foldLeft(-1, Double.MinValue, 0) {
    +            case ((maxIndex, maxValue, currentIndex), currentValue) =>
    +              if (currentValue > maxValue) {
    +                (currentIndex, currentValue, currentIndex + 1)
    +              } else {
    +                (maxIndex, maxValue, currentIndex + 1)
    +              }
               }
    -          val rightImpurity = if (rightTotalCount == 0) {
    -            topImpurity
    -          } else {
    -            strategy.impurity.calculate(rightCounts, rightTotalCount)
    +          if (result._1 < 0) {
    +            throw new RuntimeException("DecisionTree internal error:" +
    +              " calculateGainForSplit failed in indexOfLargestArrayElement")
               }
    +          result._1
    +        }
     
    -          val leftWeight = leftTotalCount / totalCount
    -          val rightWeight = rightTotalCount / totalCount
    +        val predict = indexOfLargestArrayElement(leftrightNodeAgg)
    +        val prob = leftrightNodeAgg(predict) / totalCount
     
    -          val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    +        val leftImpurity = if (leftTotalCount == 0) {
    +          topImpurity
    +        } else {
    +          metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
    +        }
    +        val rightImpurity = if (rightTotalCount == 0) {
    +          topImpurity
    +        } else {
    +          metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
    +        }
     
    -          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
    +        val leftWeight = leftTotalCount / totalCount
    +        val rightWeight = rightTotalCount / totalCount
     
    -        case Regression =>
    -          val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
    -          val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
    -          val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
    +        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
     
    -          val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
    -          val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
    -          val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
    +        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
     
    -          val impurity = {
    -            if (level > 0) {
    -              topImpurity
    -            } else {
    -              // Calculate impurity for root node.
    -              val count = leftCount + rightCount
    -              val sum = leftSum + rightSum
    -              val sumSquares = leftSumSquares + rightSumSquares
    -              strategy.impurity.calculate(count, sum, sumSquares)
    -            }
    -          }
    +      } else {
    +        // Regression
     
    -          if (leftCount == 0) {
    -            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
    -              rightSum / rightCount)
    -          }
    -          if (rightCount == 0) {
    -            return new InformationGainStats(0, topImpurity ,topImpurity,
    -              Double.MinValue, leftSum / leftCount)
    +        val leftCount = leftNodeAgg(0)
    +        val leftSum = leftNodeAgg(1)
    +        val leftSumSquares = leftNodeAgg(2)
    +
    +        val rightCount = rightNodeAgg(0)
    +        val rightSum = rightNodeAgg(1)
    +        val rightSumSquares = rightNodeAgg(2)
    +
    +        val impurity = {
    +          if (level > 0) {
    +            topImpurity
    +          } else {
    +            // Calculate impurity for root node.
    +            val count = leftCount + rightCount
    +            val sum = leftSum + rightSum
    +            val sumSquares = leftSumSquares + rightSumSquares
    +            metadata.impurity.calculate(count, sum, sumSquares)
               }
    +        }
     
    -          val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
    -          val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
    +        if (leftCount == 0) {
    +          return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
    +            rightSum / rightCount)
    +        }
    +        if (rightCount == 0) {
    +          return new InformationGainStats(0, topImpurity, topImpurity,
    +            Double.MinValue, leftSum / leftCount)
    +        }
     
    -          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
    -          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
    +        val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
    +        val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
     
    -          val gain = {
    -            if (level > 0) {
    -              impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    -            } else {
    -              impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    -            }
    +        val leftWeight = leftCount.toDouble / (leftCount + rightCount)
    +        val rightWeight = rightCount.toDouble / (leftCount + rightCount)
    +
    +        val gain = {
    +          if (level > 0) {
    +            impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    +          } else {
    +            impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    --- End diff --
    
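    There is no difference between the `if` and `else` branches: both compute `impurity - leftWeight * leftImpurity - rightWeight * rightImpurity`, so the `level > 0` check is redundant and `gain` can be assigned directly.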


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to