Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7294#discussion_r34757708
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -0,0 +1,1131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import java.io.IOException
    +
    +import scala.collection.mutable
    +import scala.util.Random
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    +import org.apache.spark.ml.regression.DecisionTreeRegressionModel
    +import org.apache.spark.ml.tree._
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, 
Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, 
DecisionTreeMetadata,
    +  TimeTracker}
    +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
    +import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
    +
    +
    +private[ml] object RandomForest extends Logging {
    +
    +  /**
    +   * Train a random forest.
    +   * @param input Training data: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]]
    +   * @return an unweighted set of trees
    +   */
    +  def run(
    +      input: RDD[LabeledPoint],
    +      strategy: OldStrategy,
    +      numTrees: Int,
    +      featureSubsetStrategy: String,
    +      seed: Long,
    +      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
    +
    +    val timer = new TimeTracker()
    +
    +    timer.start("total")
    +
    +    timer.start("init")
    +
    +    val retaggedInput = input.retag(classOf[LabeledPoint])
    +    val metadata =
    +      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, 
numTrees, featureSubsetStrategy)
    +    logDebug("algo = " + strategy.algo)
    +    logDebug("numTrees = " + numTrees)
    +    logDebug("seed = " + seed)
    +    logDebug("maxBins = " + metadata.maxBins)
    +    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
    +    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
    +    logDebug("subsamplingRate = " + strategy.subsamplingRate)
    +
    +    // Find the splits and the corresponding bins (interval between the 
splits) using a sample
    +    // of the input data.
    +    timer.start("findSplitsBins")
    +    val splits = findSplits(retaggedInput, metadata)
    +    timer.stop("findSplitsBins")
    +    logDebug("numBins: feature: number of bins")
    +    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
    +      s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
    +    }.mkString("\n"))
    +
    +    // Bin feature values (TreePoint representation).
    +    // Cache input RDD for speedup during multiple passes.
    +    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, 
metadata)
    +
    +    val withReplacement = if (numTrees > 1) true else false
    +
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 
strategy.subsamplingRate, numTrees,
    +      withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
    +
    +    // depth of the decision tree
    +    val maxDepth = strategy.maxDepth
    +    require(maxDepth <= 30,
    +      s"DecisionTree currently only supports maxDepth <= 30, but was given 
maxDepth = $maxDepth.")
    +
    +    // Max memory usage for aggregates
    +    // TODO: Calculate memory usage more precisely.
    +    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    +    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " 
bytes.")
    +    val maxMemoryPerNode = {
    +      val featureSubset: Option[Array[Int]] = if 
(metadata.subsamplingFeatures) {
    +        // Find numFeaturesPerNode largest bins to get an upper bound on 
memory usage.
    +        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
    +          .take(metadata.numFeaturesPerNode).map(_._2))
    +      } else {
    +        None
    +      }
    +      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    +    }
    +    require(maxMemoryPerNode <= maxMemoryUsage,
    +      s"RandomForest/DecisionTree given maxMemoryInMB = 
${strategy.maxMemoryInMB}," +
    +        " which is too small for the given features." +
    +        s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
    +
    +    timer.stop("init")
    +
    +    /*
    +     * The main idea here is to perform group-wise training of the 
decision tree nodes thus
    +     * reducing the passes over the data from (# nodes) to (# nodes / 
maxNumberOfNodesPerGroup).
    +     * Each data sample is handled by a particular node (or it reaches a 
leaf and is not used
    +     * in lower levels).
    +     */
    +
    +    // Create an RDD of node Id cache.
    +    // At first, all the rows belong to the root nodes (node Id == 1).
    +    val nodeIdCache = if (strategy.useNodeIdCache) {
    +      Some(NodeIdCache.init(
    +        data = baggedInput,
    +        numTrees = numTrees,
    +        checkpointInterval = strategy.checkpointInterval,
    +        initVal = 1))
    +    } else {
    +      None
    +    }
    +
    +    // FIFO queue of nodes to train: (treeIndex, node)
    +    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
    +
    +    val rng = new Random()
    +    rng.setSeed(seed)
    +
    +    // Allocate and queue root nodes.
    +    val topNodes = 
Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
    +    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, 
topNodes(treeIndex))))
    +
    +    while (nodeQueue.nonEmpty) {
    +      // Collect some nodes to split, and choose features for each node 
(if subsampling).
    +      // Each group of nodes may come from one or multiple trees, and at 
multiple levels.
    +      val (nodesForGroup, treeToNodeToIndexInfo) =
    +        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, 
metadata, rng)
    +      // Sanity check (should never occur):
    +      assert(nodesForGroup.nonEmpty,
    +        s"RandomForest selected empty nodesForGroup.  Error for unknown 
reason.")
    +
    +      // Choose node splits, and enqueue new nodes as needed.
    +      timer.start("findBestSplits")
    +      RandomForest.findBestSplits(baggedInput, metadata, topNodes, 
nodesForGroup,
    +        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
    +      timer.stop("findBestSplits")
    +    }
    +
    +    baggedInput.unpersist()
    +
    +    timer.stop("total")
    +
    +    logInfo("Internal timing for DecisionTree:")
    +    logInfo(s"$timer")
    +
    +    // Delete any remaining checkpoints used for node Id cache.
    +    if (nodeIdCache.nonEmpty) {
    +      try {
    +        nodeIdCache.get.deleteAllCheckpoints()
    +      } catch {
    +        case e: IOException =>
    +          logWarning(s"delete all checkpoints failed. Error reason: 
${e.getMessage}")
    +      }
    +    }
    +
    +    parentUID match {
    +      case Some(uid) =>
    +        if (strategy.algo == OldAlgo.Classification) {
    +          topNodes.map(rootNode => new 
DecisionTreeClassificationModel(uid, rootNode.toNode))
    +        } else {
    +          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, 
rootNode.toNode))
    +        }
    +      case None =>
    +        if (strategy.algo == OldAlgo.Classification) {
    +          topNodes.map(rootNode => new 
DecisionTreeClassificationModel(rootNode.toNode))
    +        } else {
    +          topNodes.map(rootNode => new 
DecisionTreeRegressionModel(rootNode.toNode))
    +        }
    +    }
    +  }
    +
    +  /**
    +   * Get the node index corresponding to this data point.
    +   * This function mimics prediction, passing an example from the root 
node down to a leaf
    +   * or unsplit node; that node's index is returned.
    +   *
    +   * @param node  Node in tree from which to classify the given data point.
    +   * @param binnedFeatures  Binned feature vector for data point.
    +   * @param splits possible splits for all features, indexed 
(numFeatures)(numSplits)
    +   * @return  Leaf index if the data point reaches a leaf.
    +   *          Otherwise, last node reachable in tree matching this example.
    +   *          Note: This is the global node index, i.e., the index used in 
the tree.
    +   *                This index is different from the index used during 
training a particular
    +   *                group of nodes on one call to [[findBestSplits()]].
    +   */
    +  private def predictNodeIndex(
    +      node: LearningNode,
    +      binnedFeatures: Array[Int],
    +      splits: Array[Array[Split]]): Int = {
    +    if (node.isLeaf || node.split.isEmpty) {
    +      node.id
    +    } else {
    +      val split = node.split.get
    +      val featureIndex = split.featureIndex
    +      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), 
splits(featureIndex))
    +      if (node.leftChild.isEmpty) {
    +        // Not yet split. Return index from next layer of nodes to train
    +        if (splitLeft) {
    +          LearningNode.leftChildIndex(node.id)
    +        } else {
    +          LearningNode.rightChildIndex(node.id)
    +        }
    +      } else {
    +        if (splitLeft) {
    +          predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
    +        } else {
    +          predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
    +        }
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Helper for binSeqOp, for data which can contain a mix of ordered and 
unordered features.
    +   *
    +   * For ordered features, a single bin is updated.
    +   * For unordered features, bins correspond to subsets of categories; 
either the left or right bin
    +   * for each subset is updated.
    +   *
    +   * @param agg  Array storing aggregate calculation, with a set of 
sufficient statistics for
    +   *             each (feature, bin).
    +   * @param treePoint  Data point being aggregated.
    +   * @param splits possible splits indexed (numFeatures)(numSplits)
    +   * @param unorderedFeatures  Set of indices of unordered features.
    +   * @param instanceWeight  Weight (importance) of instance in dataset.
    +   */
    +  private def mixedBinSeqOp(
    +      agg: DTStatsAggregator,
    +      treePoint: TreePoint,
    +      splits: Array[Array[Split]],
    +      unorderedFeatures: Set[Int],
    +      instanceWeight: Double,
    +      featuresForNode: Option[Array[Int]]): Unit = {
    +    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
    +      // Use subsampled features
    +      featuresForNode.get.length
    +    } else {
    +      // Use all features
    +      agg.metadata.numFeatures
    +    }
    +    // Iterate over features.
    +    var featureIndexIdx = 0
    +    while (featureIndexIdx < numFeaturesPerNode) {
    +      val featureIndex = if (featuresForNode.nonEmpty) {
    +        featuresForNode.get.apply(featureIndexIdx)
    +      } else {
    +        featureIndexIdx
    +      }
    +      if (unorderedFeatures.contains(featureIndex)) {
    +        // Unordered feature
    +        val featureValue = treePoint.binnedFeatures(featureIndex)
    +        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
    +          agg.getLeftRightFeatureOffsets(featureIndexIdx)
    +        // Update the left or right bin for each split.
    +        val numSplits = agg.metadata.numSplits(featureIndex)
    +        val featureSplits = splits(featureIndex)
    +        var splitIndex = 0
    +        while (splitIndex < numSplits) {
    +          if (featureSplits(splitIndex).shouldGoLeft(featureValue, 
featureSplits)) {
    +            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label, instanceWeight)
    +          } else {
    +            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, 
treePoint.label, instanceWeight)
    +          }
    +          splitIndex += 1
    +        }
    +      } else {
    +        // Ordered feature
    +        val binIndex = treePoint.binnedFeatures(featureIndex)
    +        agg.update(featureIndexIdx, binIndex, treePoint.label, 
instanceWeight)
    +      }
    +      featureIndexIdx += 1
    +    }
    +  }
    +
    +  /**
    +   * Helper for binSeqOp, for regression and for classification with only 
ordered features.
    +   *
    +   * For each feature, the sufficient statistics of one bin are updated.
    +   *
    +   * @param agg  Array storing aggregate calculation, with a set of 
sufficient statistics for
    +   *             each (feature, bin).
    +   * @param treePoint  Data point being aggregated.
    +   * @param instanceWeight  Weight (importance) of instance in dataset.
    +   */
    +  private def orderedBinSeqOp(
    +      agg: DTStatsAggregator,
    +      treePoint: TreePoint,
    +      instanceWeight: Double,
    +      featuresForNode: Option[Array[Int]]): Unit = {
    +    val label = treePoint.label
    +
    +    // Iterate over features.
    +    if (featuresForNode.nonEmpty) {
    +      // Use subsampled features
    +      var featureIndexIdx = 0
    +      while (featureIndexIdx < featuresForNode.get.length) {
    +        val binIndex = 
treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
    +        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
    +        featureIndexIdx += 1
    +      }
    +    } else {
    +      // Use all features
    +      val numFeatures = agg.metadata.numFeatures
    +      var featureIndex = 0
    +      while (featureIndex < numFeatures) {
    +        val binIndex = treePoint.binnedFeatures(featureIndex)
    +        agg.update(featureIndex, binIndex, label, instanceWeight)
    +        featureIndex += 1
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Given a group of nodes, this finds the best split for each node.
    +   *
    +   * @param input Training data: RDD of 
[[org.apache.spark.mllib.tree.impl.TreePoint]]
    +   * @param metadata Learning and dataset metadata
    +   * @param topNodes Root node for each tree.  Used for matching instances 
with nodes.
    +   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
    +   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> 
nodeIndexInfo,
    +   *                              where nodeIndexInfo stores the index in 
the group and the
    +   *                              feature subsets (if using feature 
subsets).
    +   * @param splits possible splits for all features, indexed 
(numFeatures)(numSplits)
    +   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, 
node).
    +   *                   Updated with new non-leaf nodes which are created.
    +   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
    +   *                    each value in the array is the data point's node Id
    +   *                    for a corresponding tree. This is used to prevent 
the need
    +   *                    to pass the entire tree to the executors during
    +   *                    the node stat aggregation phase.
    +   */
    +  private[tree] def findBestSplits(
    +      input: RDD[BaggedPoint[TreePoint]],
    +      metadata: DecisionTreeMetadata,
    +      topNodes: Array[LearningNode],
    +      nodesForGroup: Map[Int, Array[LearningNode]],
    +      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
    +      splits: Array[Array[Split]],
    +      nodeQueue: mutable.Queue[(Int, LearningNode)],
    +      timer: TimeTracker = new TimeTracker,
    +      nodeIdCache: Option[NodeIdCache] = None): Unit = {
    +
    +    /*
    +     * The high-level descriptions of the best split optimizations are 
noted here.
    +     *
    +     * *Group-wise training*
    +     * We perform bin calculations for groups of nodes to reduce the 
number of
    +     * passes over the data.  Each iteration requires more computation and 
storage,
    +     * but saves several iterations over the data.
    +     *
    +     * *Bin-wise computation*
    +     * We use a bin-wise best split computation strategy instead of a 
straightforward best split
    +     * computation strategy. Instead of analyzing each sample for 
contribution to the left/right
    +     * child node impurity of every split, we first categorize each 
feature of a sample into a
    +     * bin. We exploit this structure to calculate aggregates for bins and 
then use these aggregates
    +     * to calculate information gain for each split.
    +     *
    +     * *Aggregation over partitions*
    +     * Instead of performing a flatMap/reduceByKey operation, we exploit 
the fact that we know
    +     * the number of splits in advance. Thus, we store the aggregates (at 
the appropriate
    +     * indices) in a single array for all bins and rely upon the RDD 
aggregate method to
    +     * drastically reduce the communication overhead.
    +     */
    +
    +    // numNodes:  Number of nodes in this group
    +    val numNodes = nodesForGroup.values.map(_.length).sum
    +    logDebug("numNodes = " + numNodes)
    +    logDebug("numFeatures = " + metadata.numFeatures)
    +    logDebug("numClasses = " + metadata.numClasses)
    +    logDebug("isMulticlass = " + metadata.isMulticlass)
    +    logDebug("isMulticlassWithCategoricalFeatures = " +
    +      metadata.isMulticlassWithCategoricalFeatures)
    +    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
    +
    +    /**
    +     * Performs a sequential aggregation over a partition for a particular 
tree and node.
    +     *
    +     * For each feature, the aggregate sufficient statistics are updated 
for the relevant
    +     * bins.
    +     *
    +     * @param treeIndex Index of the tree that we want to perform 
aggregation for.
    +     * @param nodeInfo The node info for the tree node.
    +     * @param agg Array storing aggregate calculation, with a set of 
sufficient statistics
    +     *            for each (node, feature, bin).
    +     * @param baggedPoint Data point being aggregated.
    +     */
    +    def nodeBinSeqOp(
    +        treeIndex: Int,
    +        nodeInfo: NodeIndexInfo,
    +        agg: Array[DTStatsAggregator],
    +        baggedPoint: BaggedPoint[TreePoint]): Unit = {
    +      if (nodeInfo != null) {
    +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
    +        val featuresForNode = nodeInfo.featureSubset
    +        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    +        if (metadata.unorderedFeatures.isEmpty) {
    +          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, 
instanceWeight, featuresForNode)
    +        } else {
    +          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
    +            metadata.unorderedFeatures, instanceWeight, featuresForNode)
    +        }
    +      }
    +    }
    +
    +    /**
    +     * Performs a sequential aggregation over a partition.
    +     *
    +     * Each data point contributes to one node. For each feature,
    +     * the aggregate sufficient statistics are updated for the relevant 
bins.
    +     *
    +     * @param agg  Array storing aggregate calculation, with a set of 
sufficient statistics for
    +     *             each (node, feature, bin).
    +     * @param baggedPoint   Data point being aggregated.
    +     * @return  agg
    +     */
    +    def binSeqOp(
    +        agg: Array[DTStatsAggregator],
    +        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
    +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    +        val nodeIndex =
    +          predictNodeIndex(topNodes(treeIndex), 
baggedPoint.datum.binnedFeatures, splits)
    +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, 
null), agg, baggedPoint)
    +      }
    +      agg
    +    }
    +
    +    /**
    +     * Do the same thing as binSeqOp, but with nodeIdCache.
    +     */
    +    def binSeqOpWithNodeIdCache(
    +        agg: Array[DTStatsAggregator],
    +        dataPoint: (BaggedPoint[TreePoint], Array[Int])): 
Array[DTStatsAggregator] = {
    +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    +        val baggedPoint = dataPoint._1
    +        val nodeIdCache = dataPoint._2
    +        val nodeIndex = nodeIdCache(treeIndex)
    +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, 
null), agg, baggedPoint)
    +      }
    +
    +      agg
    +    }
    +
    +    /**
    +     * Get node index in group --> features indices map,
    +     * which is a short cut to find feature indices for a node given node 
index in group.
    +     */
    +    def getNodeToFeatures(
    +        treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): 
Option[Map[Int, Array[Int]]] = {
    +      if (!metadata.subsamplingFeatures) {
    +        None
    +      } else {
    +        val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
    +        treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
    +          nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
    +            assert(nodeIndexInfo.featureSubset.isDefined)
    +            mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = 
nodeIndexInfo.featureSubset.get
    +          }
    +        }
    +        Some(mutableNodeToFeatures.toMap)
    +      }
    +    }
    +
    +    // array of nodes to train indexed by node index in group
    +    val nodes = new Array[LearningNode](numNodes)
    +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
    +      nodesForTree.foreach { node =>
    +        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) 
= node
    +      }
    +    }
    +
    +    // Calculate best splits for all nodes in the group
    +    timer.start("chooseSplits")
    +
    +    // In each partition, iterate all instances and compute aggregate 
stats for each node,
    +    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
    +    // After a `reduceByKey` operation,
    +    // stats of a node will be shuffled to a particular partition and be 
combined together,
    +    // then best splits for nodes are found there.
    +    // Finally, only best Splits for nodes are collected to driver to 
construct decision tree.
    +    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
    +    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
    +
    +    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if 
(nodeIdCache.nonEmpty) {
    +      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { 
points =>
    +        // Construct a nodeStatsAggregators array to hold node aggregate 
stats,
    +        // each node will have a nodeStatsAggregator
    +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
    +          val featuresForNode = nodeToFeaturesBc.value.flatMap { 
nodeToFeatures =>
    +            Some(nodeToFeatures(nodeIndex))
    +          }
    +          new DTStatsAggregator(metadata, featuresForNode)
    +        }
    +
    +        // iterator all instances in current partition and update 
aggregate stats
    +        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
    +
    +        // transform nodeStatsAggregators array to (nodeIndex, 
nodeAggregateStats) pairs,
    +        // which can be combined with other partition using `reduceByKey`
    +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
    +      }
    +    } else {
    +      input.mapPartitions { points =>
    +        // Construct a nodeStatsAggregators array to hold node aggregate 
stats,
    +        // each node will have a nodeStatsAggregator
    +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
    +          val featuresForNode = nodeToFeaturesBc.value.flatMap { 
nodeToFeatures =>
    +            Some(nodeToFeatures(nodeIndex))
    +          }
    +          new DTStatsAggregator(metadata, featuresForNode)
    +        }
    +
    +        // iterator all instances in current partition and update 
aggregate stats
    +        points.foreach(binSeqOp(nodeStatsAggregators, _))
    +
    +        // transform nodeStatsAggregators array to (nodeIndex, 
nodeAggregateStats) pairs,
    +        // which can be combined with other partition using `reduceByKey`
    +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
    +      }
    +    }
    +
    +    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => 
a.merge(b))
    +      .map { case (nodeIndex, aggStats) =>
    +      val featuresForNode = nodeToFeaturesBc.value.flatMap { 
nodeToFeatures =>
    --- End diff --
    
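    indentation: the body of the `.map { case (nodeIndex, aggStats) =>` closure should be indented one more level (two spaces) than the `.map` line, rather than aligned with it.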


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to