Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/19433#discussion_r147036693 --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala --- @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Object exposing methods for local training of decision trees */ +private[ml] object LocalDecisionTree { + + /** + * Fully splits the passed-in node on the provided local dataset, returning + * an InternalNode/LeafNode corresponding to the root of the resulting tree. + * + * @param node LearningNode to use as the root of the subtree fit on the passed-in dataset + * @param metadata learning and dataset metadata for DecisionTree + * @param splits splits(i) = array of splits for feature i + */ + private[ml] def fitNode( + input: Array[TreePoint], + instanceWeights: Array[Double], + node: LearningNode, + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): Node = { + + // The case with 1 node (depth = 0) is handled separately. 
+ // This allows all iterations in the depth > 0 case to use the same code. + // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node (because of + // other parameters). + if (metadata.maxDepth == 0) { + return node.toNode + } + + // Prepare column store. + // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. + val colStoreInit: Array[Array[Int]] + = LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.binnedFeatures)) + val labels = input.map(_.label) + + // Fit a regression model on the dataset, throwing an error if metadata indicates that + // we should train a classifier. + // TODO: Add support for training classifiers + if (metadata.numClasses > 1 && metadata.numClasses <= 32) { + throw new UnsupportedOperationException("Local training of a decision tree classifier is " + + "unsupported; currently, only regression is supported") + } else { + trainRegressor(node, colStoreInit, instanceWeights, labels, metadata, splits) + } + } + + /** + * Locally fits a decision tree regressor. + * TODO(smurching): Logic for fitting a classifier & regressor is the same; only difference + * is impurity metric. Use the same logic for fitting a classifier. + * + * @param rootNode Node to use as root of the tree fit on the passed-in dataset + * @param colStoreInit Array of columns of training data + * @param instanceWeights Array of weights for each training example + * @param metadata learning and dataset metadata for DecisionTree + * @param splits splits(i) = Array of possible splits for feature i + * @return LeafNode or InternalNode representation of rootNode + */ + private[ml] def trainRegressor( + rootNode: LearningNode, + colStoreInit: Array[Array[Int]], + instanceWeights: Array[Double], + labels: Array[Double], + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): Node = { + + // Sort each column by decision tree node. 
+ val colStore: Array[FeatureVector] = colStoreInit.zipWithIndex.map { case (col, featureIndex) => + val featureArity: Int = metadata.featureArity.getOrElse(featureIndex, 0) + FeatureVector(featureIndex, featureArity, col) + } + + val numRows = colStore.headOption match { + case None => 0 + case Some(column) => column.values.length + } + + // Create a new PartitionInfo describing the status of our partially-trained subtree + // at each iteration of training + var trainingInfo: TrainingInfo = TrainingInfo(colStore, instanceWeights, + nodeOffsets = Array[(Int, Int)]((0, numRows)), activeNodes = Array(rootNode)) + + // Iteratively learn, one level of the tree at a time. + // Note: We do not use node IDs. + var currentLevel = 0 + var doneLearning = false + + while (currentLevel < metadata.maxDepth && !doneLearning) { + // Splits each active node if possible, returning an array of new active nodes + val activeNodes: Array[LearningNode] = + computeBestSplits(trainingInfo, labels, metadata, splits) + // Filter active node periphery by impurity. + val estimatedRemainingActive = activeNodes.count(_.stats.impurity > 0.0) --- End diff -- Wait... I checked the code here: `trainingInfo = trainingInfo.update(splits, activeNodes)`. It seems you do not filter the leaf nodes out of the "activeNodes" (which is actually the `nextLevelNode` I mentioned above). So I think `TrainingInfo.activeNodes` may still contain leaf nodes.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org