Github user smurching commented on a diff in the pull request:
https://github.com/apache/spark/pull/19433#discussion_r147307553
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala ---
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.tree.model.ImpurityStats
+
+/** Object exposing methods for local training of decision trees */
+private[ml] object LocalDecisionTree {
+
+  /**
+   * Fully splits the passed-in node on the provided local dataset, returning
+   * an InternalNode/LeafNode corresponding to the root of the resulting tree.
+   *
+   * @param input Array of training data points (one TreePoint per instance)
+   * @param instanceWeights Array of weights for each training example
+   * @param node LearningNode to use as the root of the subtree fit on the passed-in dataset
+   * @param metadata learning and dataset metadata for DecisionTree
+   * @param splits splits(i) = array of splits for feature i
+   */
+ private[ml] def fitNode(
+ input: Array[TreePoint],
+ instanceWeights: Array[Double],
+ node: LearningNode,
+ metadata: DecisionTreeMetadata,
+ splits: Array[Array[Split]]): Node = {
+
+    // The case with 1 node (depth = 0) is handled separately.
+    // This allows all iterations in the depth > 0 case to use the same code.
+    // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node
+    // (because of other parameters).
+ if (metadata.maxDepth == 0) {
+ return node.toNode
+ }
+
+    // Prepare column store.
+    // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue.
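+    // The result colStoreInit(featureIndex) holds the binned values of that feature for
+    // every row, i.e. the row-major input transposed into column-major form.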
+    val colStoreInit: Array[Array[Int]] =
+      LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.binnedFeatures))
+    val labels = input.map(_.label)
+
+    // Fit a regression model on the dataset, throwing an error if metadata indicates that
+    // we should train a classifier.
+    // TODO: Add support for training classifiers
+    if (metadata.numClasses > 1) {
+      throw new UnsupportedOperationException("Local training of a decision tree classifier " +
+        "is unsupported; currently, only regression is supported")
+    } else {
+      trainRegressor(node, colStoreInit, instanceWeights, labels, metadata, splits)
+    }
+ }
+
+  /**
+   * Locally fits a decision tree regressor.
+   * TODO(smurching): The logic for fitting a classifier and a regressor is the same; the only
+   * difference is the impurity metric. Use the same logic to fit a classifier.
+   *
+   * @param rootNode Node to use as the root of the tree fit on the passed-in dataset
+   * @param colStoreInit Array of columns of training data
+   * @param instanceWeights Array of weights for each training example
+   * @param labels Array of labels for each training example
+   * @param metadata learning and dataset metadata for DecisionTree
+   * @param splits splits(i) = Array of possible splits for feature i
+   * @return LeafNode or InternalNode representation of rootNode
+   */
+ private[ml] def trainRegressor(
+ rootNode: LearningNode,
+ colStoreInit: Array[Array[Int]],
+ instanceWeights: Array[Double],
+ labels: Array[Double],
+ metadata: DecisionTreeMetadata,
+ splits: Array[Array[Split]]): Node = {
+
+ // Sort each column by decision tree node.
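+    // A featureArity of 0 marks a continuous feature; a positive arity marks a categorical
+    // feature with that many categories (per DecisionTreeMetadata.featureArity).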
+    val colStore: Array[FeatureVector] = colStoreInit.zipWithIndex.map {
+      case (col, featureIndex) =>
+        val featureArity: Int = metadata.featureArity.getOrElse(featureIndex, 0)
+        FeatureVector(featureIndex, featureArity, col)
+    }
+
+ val numRows = colStore.headOption match {
+ case None => 0
+ case Some(column) => column.values.length
+ }
+
+    // Create a new TrainingInfo describing the status of our partially-trained subtree
+    // at each iteration of training.
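+    // nodeOffsets(i) = (start, end) is the contiguous range of sorted rows belonging to
+    // activeNodes(i); initially a single range (0, numRows) covering all rows at the root.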
+    var trainingInfo: TrainingInfo = TrainingInfo(colStore, instanceWeights,
+      nodeOffsets = Array[(Int, Int)]((0, numRows)), activeNodes = Array(rootNode))
+
+ // Iteratively learn, one level of the tree at a time.
+ // Note: We do not use node IDs.
+ var currentLevel = 0
+ var doneLearning = false
+
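+    // Learning stops once we reach maxDepth or once every node has become a leaf.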
+ while (currentLevel < metadata.maxDepth && !doneLearning) {
+      // Splits each active node if possible, returning an array of new active nodes.
+      val activeNodes: Array[LearningNode] =
+        computeBestSplits(trainingInfo, labels, metadata, splits)
+      // Filter active node periphery by impurity.
+      val estimatedRemainingActive = activeNodes.count(_.stats.impurity > 0.0)
--- End diff --
Oh true -- I'll update the doc for `currentLevelActiveNodes` to say:
```
 * @param currentLevelActiveNodes Nodes which are active (could still be split).
 *                                Inactive nodes are known to be leaves in the final tree.
```
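To make that doc concrete, here's roughly the invariant I have in mind (an illustrative sketch, not code from this PR -- `partitionByActivity` is a hypothetical helper, though `isLeaf` and `stats.impurity` are the `LearningNode` fields already used in the diff above):
```
// Sketch: at each level, nodes divide into active nodes (still candidates for
// splitting) and nodes that are known to be leaves in the final tree.
def partitionByActivity(
    currentLevelNodes: Array[LearningNode]): (Array[LearningNode], Array[LearningNode]) = {
  currentLevelNodes.partition { node =>
    // A node stays active only while it is impure; a pure node
    // (impurity == 0.0) cannot benefit from further splitting.
    !node.isLeaf && node.stats.impurity > 0.0
  }
}
```
With that framing, `currentLevelActiveNodes` is the first half of the partition, and everything in the second half is guaranteed to appear as a `LeafNode` in the returned tree.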