Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19433#discussion_r147317401
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala ---
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.tree.model.ImpurityStats
+
+/** Object exposing methods for local training of decision trees */
+private[ml] object LocalDecisionTree {
+
+  /**
+   * Fully splits the passed-in node on the provided local dataset, returning
+   * an InternalNode/LeafNode corresponding to the root of the resulting tree.
+   *
+   * @param node LearningNode to use as the root of the subtree fit on the passed-in dataset
+   * @param metadata learning and dataset metadata for DecisionTree
+   * @param splits splits(i) = array of splits for feature i
+   */
+  private[ml] def fitNode(
+      input: Array[TreePoint],
+      instanceWeights: Array[Double],
+      node: LearningNode,
+      metadata: DecisionTreeMetadata,
+      splits: Array[Array[Split]]): Node = {
+
+    // The case with 1 node (depth = 0) is handled separately.
+    // This allows all iterations in the depth > 0 case to use the same code.
+    // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node
+    // (because of other parameters).
+    if (metadata.maxDepth == 0) {
+      return node.toNode
+    }
+
+    // Prepare column store.
+    // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue.
+    val colStoreInit: Array[Array[Int]] =
+      LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.binnedFeatures))
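+    // rowToColumnStoreDense transposes the row-major binned features into one array per
+    // feature (e.g. rows [[1, 2], [3, 4]] become columns [[1, 3], [2, 4]]), so each
+    // feature's values can be scanned contiguously when evaluating splits.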
+    val labels = input.map(_.label)
+
+    // Fit a regression model on the dataset, throwing an error if metadata indicates
+    // that we should train a classifier.
+    // TODO: Add support for training classifiers
+    if (metadata.numClasses > 1) {
+      throw new UnsupportedOperationException("Local training of a decision tree " +
+        "classifier is unsupported; currently, only regression is supported")
+    } else {
+      trainRegressor(node, colStoreInit, instanceWeights, labels, metadata, splits)
+    }
+  }
+
+  /**
+   * Locally fits a decision tree regressor.
+   * TODO(smurching): Logic for fitting a classifier & regressor is the same; the only
+   * difference is the impurity metric. Use the same logic for fitting a classifier.
+   *
+   * @param rootNode Node to use as root of the tree fit on the passed-in dataset
+   * @param colStoreInit Array of columns of training data
+   * @param instanceWeights Array of weights for each training example
+   * @param metadata learning and dataset metadata for DecisionTree
+   * @param splits splits(i) = Array of possible splits for feature i
+   * @return LeafNode or InternalNode representation of rootNode
+   */
+  private[ml] def trainRegressor(
+      rootNode: LearningNode,
+      colStoreInit: Array[Array[Int]],
+      instanceWeights: Array[Double],
+      labels: Array[Double],
+      metadata: DecisionTreeMetadata,
+      splits: Array[Array[Split]]): Node = {
+
+    // Sort each column by decision tree node.
+    val colStore: Array[FeatureVector] = colStoreInit.zipWithIndex.map {
+      case (col, featureIndex) =>
+        val featureArity: Int = metadata.featureArity.getOrElse(featureIndex, 0)
+        FeatureVector(featureIndex, featureArity, col)
+    }
+
+    val numRows = colStore.headOption match {
+      case None => 0
+      case Some(column) => column.values.length
+    }
+
+    // Create a new TrainingInfo describing the status of our partially-trained subtree
+    // at each iteration of training.
+    var trainingInfo: TrainingInfo = TrainingInfo(colStore,
+      nodeOffsets = Array[(Int, Int)]((0, numRows)), currentLevelActiveNodes = Array(rootNode))
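+    // Initially the root is the only active node, and its (start, end) offsets
+    // (0, numRows) cover every row of the column store.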
+
+    // Iteratively learn, one level of the tree at a time.
+    // Note: We do not use node IDs.
+    var currentLevel = 0
+    var doneLearning = false
+
+    while (currentLevel < metadata.maxDepth && !doneLearning) {
+      // Split each active node if possible, obtaining an array of new active nodes.
+      val nextLevelNodes: Array[LearningNode] =
+        computeBestSplits(trainingInfo, instanceWeights, labels, metadata, splits)
+      // Count the number of non-leaf nodes in the next level.
+      val estimatedRemainingActive = nextLevelNodes.count(!_.isLeaf)
+      // TODO: Check to make sure we split something, and stop otherwise.
+      doneLearning = currentLevel + 1 >= metadata.maxDepth || estimatedRemainingActive == 0
+      if (!doneLearning) {
+        // Obtain a new TrainingInfo instance describing our current training status.
+        trainingInfo = trainingInfo.update(splits, nextLevelNodes)
+      }
+      currentLevel += 1
+    }
+
+    // Done with learning.
+    rootNode.toNode
+  }
+
+  /**
+   * Iterates over the feature values and labels for a specific (node, feature) pair,
+   * updating the stats aggregator for the current node.
+   */
+  private[impl] def updateAggregator(
+      statsAggregator: DTStatsAggregator,
+      col: FeatureVector,
+      instanceWeights: Array[Double],
+      labels: Array[Double],
+      from: Int,
+      to: Int,
+      featureIndexIdx: Int,
+      featureSplits: Array[Split]): Unit = {
+    val metadata = statsAggregator.metadata
+    if (metadata.isUnordered(col.featureIndex)) {
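+      // Unordered categorical features: candidate splits are subsets of categories, so
+      // the aggregator's stats are updated against each candidate split in featureSplits.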
+      from.until(to).foreach { idx =>
+        val rowIndex = col.indices(idx)
+        AggUpdateUtils.updateUnorderedFeature(statsAggregator, col.values(idx), labels(rowIndex),
+          featureIndex = col.featureIndex, featureIndexIdx, featureSplits,
+          instanceWeight = instanceWeights(rowIndex))
+      }
+    } else {
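+      // Ordered features: the binned value identifies a single bin, so only that bin's
+      // stats are updated.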
+      from.until(to).foreach { idx =>
+        val rowIndex = col.indices(idx)
+        AggUpdateUtils.updateOrderedFeature(statsAggregator, col.values(idx), labels(rowIndex),
+          featureIndex = col.featureIndex, featureIndexIdx,
+          instanceWeight = instanceWeights(rowIndex))
+      }
+    }
+  }
+
+  /**
+   * Finds the best splits for all active nodes.
+   *
+   * @param trainingInfo Contains node offset info for the current set of active nodes
+   * @return Array of new active nodes formed by splitting the current set of active nodes.
+   */
+  private def computeBestSplits(
+      trainingInfo: TrainingInfo,
+      instanceWeights: Array[Double],
+      labels: Array[Double],
+      metadata: DecisionTreeMetadata,
+      splits: Array[Array[Split]]): Array[LearningNode] = {
+    // For each node, select the best split across all features.
+    trainingInfo match {
+      case TrainingInfo(columns: Array[FeatureVector],
+          nodeOffsets: Array[(Int, Int)], activeNodes: Array[LearningNode]) => {
--- End diff --
`activeNodes` ==> `currentLevelNodes`
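
For clarity, here is how the suggested rename would read at the quoted `match` (just a sketch; the field types are copied from the diff above and the body is elided):

```scala
// Sketch of the suggested rename; TrainingInfo fields as shown in the diff above.
trainingInfo match {
  case TrainingInfo(columns: Array[FeatureVector],
      nodeOffsets: Array[(Int, Int)], currentLevelNodes: Array[LearningNode]) => {
    // ... select the best split for each node in currentLevelNodes ...
  }
}
```

Since this is only a pattern binding, the change is cosmetic, but the new name makes the level-by-level training loop easier to follow.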