If you can traverse the tree, you can extract the id of the leaf node that each feature vector falls into. This works like a modified predict method that returns the leaf node assigned to a data point instead of that leaf's prediction. The following example code should work:

import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vector

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Walk down the tree and return the leaf node that the feature vector lands in.
def predictImpl(node: Node, features: Vector): Node = {
  if (node.isLeaf) {
    node
  } else {
    if (node.split.get.featureType == Continuous) {
      if (features(node.split.get.feature) <= node.split.get.threshold) {
        predictImpl(node.leftNode.get, features)
      } else {
        predictImpl(node.rightNode.get, features)
      }
    } else {
      if (node.split.get.categories.contains(features(node.split.get.feature))) {
        predictImpl(node.leftNode.get, features)
      } else {
        predictImpl(node.rightNode.get, features)
      }
    }
  }
}

// Pair each data point's leaf node id with that leaf's prediction and the true label.
val nodeIDAndPredsAndLabels = data.map { lp =>
  val node = predictImpl(model.topNode, lp.features)
  (node.id, (node.predict.predict, lp.label))
}

From here, you should be able to perform analysis of the accuracy of each leaf node. Note that in the new Spark ML library a predictNodeIndex is implemented (which is being converted to a predictImpl method) similar to the implementation above. Hopefully that code helps.
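
As a rough sketch of what the per-leaf analysis could look like, the snippet below groups (node id, (prediction, label)) pairs by leaf and computes the fraction of correct predictions in each. It uses plain Scala collections so it runs without a cluster. The LeafAccuracy object, the leafAccuracy helper, and the sample values are hypothetical and made up for illustration; they are not part of MLlib.

```scala
object LeafAccuracy {
  // Hypothetical sample in the same shape as nodeIDAndPredsAndLabels:
  // (leaf node id, (predicted label, true label)). The values are made up.
  val sample: Seq[(Int, (Double, Double))] = Seq(
    (4, (1.0, 1.0)),
    (4, (1.0, 0.0)),
    (5, (0.0, 0.0)),
    (5, (0.0, 0.0)),
    (7, (1.0, 1.0))
  )

  // For each leaf id, return (number of points in the leaf, fraction predicted correctly).
  def leafAccuracy(rows: Seq[(Int, (Double, Double))]): Map[Int, (Long, Double)] =
    rows
      .groupBy { case (nodeId, _) => nodeId }
      .map { case (nodeId, group) =>
        val total = group.size.toLong
        val correct = group.count { case (_, (pred, label)) => pred == label }
        nodeId -> ((total, correct.toDouble / total))
      }

  def main(args: Array[String]): Unit = {
    // Print one line per leaf, ordered by node id.
    leafAccuracy(sample).toSeq.sortBy(_._1).foreach { case (nodeId, (n, acc)) =>
      println(s"leaf $nodeId: n=$n accuracy=$acc")
    }
  }
}
```

On the RDD itself you would probably not collect everything to the driver. One option, assuming the pair RDD built above, is to map each value to a (1L, correct-or-not) pair with mapValues and sum the pairs per node id with reduceByKey, then divide the two counts.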