GitHub user manishamde commented on a diff in the pull request:
https://github.com/apache/spark/pull/2435#discussion_r17889907
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala ---
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impurity.Impurities
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+/**
+ * :: Experimental ::
+ * A class which implements a random forest learning algorithm for classification and regression.
+ * It supports both continuous and categorical features.
+ *
+ * @param strategy The configuration parameters for the random forest algorithm which specify
+ *                 the type of algorithm (classification, regression, etc.), feature type
+ *                 (continuous, categorical), depth of the tree, quantile calculation strategy,
+ *                 etc.
+ * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ *                              If "auto" is set, this parameter is set based on numTrees:
+ *                              if numTrees == 1, then featureSubsetStrategy = "all";
+ *                              if numTrees > 1, then featureSubsetStrategy = "sqrt".
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ */
+@Experimental
+private class RandomForest (
+ private val strategy: Strategy,
+ private val numTrees: Int,
+ featureSubsetStrategy: String,
+ private val seed: Int)
+ extends Serializable with Logging {
+
+ strategy.assertValid()
+ require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+   s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
+   s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+
+ /**
+ * Method to train a random forest model over an RDD.
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @return RandomForestModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): RandomForestModel = {
+
+ val timer = new TimeTracker()
+
+ timer.start("total")
+
+ timer.start("init")
+
+ val retaggedInput = input.retag(classOf[LabeledPoint])
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+ logDebug("algo = " + strategy.algo)
+ logDebug("numTrees = " + numTrees)
+ logDebug("seed = " + seed)
+ logDebug("maxBins = " + metadata.maxBins)
+ logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
+ logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplitsBins")
+ val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
+ timer.stop("findSplitsBins")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePoint representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
+ val baggedInput = (if (numTrees > 1) {
+ BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
+ } else {
+ BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ }).persist(StorageLevel.MEMORY_AND_DISK)
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given
maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug("max memory usage for aggregates = " + maxMemoryUsage + "
bytes.")
+ val maxMemoryPerNode = {
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
+ Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
+ .take(metadata.numFeaturesPerNode).map(_._2))
+ } else {
+ None
+ }
+ RandomForest.numElementsForNode(metadata, featureSubset) * 8L
+ }
+ require(maxMemoryPerNode <= maxMemoryUsage,
+ s"RandomForest/DecisionTree given maxMemoryInMB =
${strategy.maxMemoryInMB}," +
+ " which is too small for the given features." +
+ s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
+ // TODO: Calculate memory usage more precisely.
+
+ timer.stop("init")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ // FIFO queue of nodes to train: (treeIndex, node)
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+
+ val rng = new scala.util.Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
+ Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+
+ while (nodeQueue.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ val (nodesForGroup: Map[Int, Array[Node]],
--- End diff --
Minor: One could remove the explicit types and add a comment to explain
what data each variable holds.
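For illustration, the suggested change might look like the sketch below. This is a hedged sketch, not a definitive rewrite: the second tuple element and the selectNodesToSplit call are assumptions inferred from context, since the diff is truncated at this point.

    // nodesForGroup: treeIndex -> nodes to split in this iteration's group.
    // treeToNodeToIndexInfo: assumed name for the truncated second element;
    // per-tree, per-node bookkeeping such as the sampled feature subsets.
    val (nodesForGroup, treeToNodeToIndexInfo) =
      RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)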
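As background for the featureSubsetStrategy scaladoc quoted above, the "auto" resolution it documents amounts to something like the following. This is an illustrative, self-contained sketch of the documented rule; the helper name resolveFeatureSubsetStrategy is hypothetical and not part of the PR.

    // Resolve "auto" as documented: a single tree considers all features,
    // while a multi-tree forest samples sqrt(numFeatures) at each node.
    def resolveFeatureSubsetStrategy(requested: String, numTrees: Int): String =
      requested match {
        case "auto" => if (numTrees == 1) "all" else "sqrt"
        case other => other // "all", "sqrt", "log2", "onethird" pass through
      }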