Here is some sample code showing how to retrieve the nodes traversed by a given sample:
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier().fit(X, y)

def path(tree, sample):
    # Walk from the root (node 0) down to the leaf reached by `sample`.
    nodes = []
    node = 0
    # Leaves have no children: children_right is -1 there.
    while tree.children_right[node] != -1:
        nodes.append(node)
        if sample[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]
        else:
            node = tree.children_right[node]
    nodes.append(node)  # also record the leaf itself
    return nodes

path(clf.tree_, X[100])  # list of node ids, from the root (0) down to a leaf

Now, to derive statistics such as the number of samples reaching each node, you can iterate over your data X and increment counters, e.g., by doing counters[path(clf.tree_, X[i])] += 1, where counters is a numpy array of size tree_.node_count. For per-class counts, make counters a 2-D array of shape (tree_.node_count, n_classes) and do counters[path(clf.tree_, X[i]), y[i]] += 1 instead.

Hope this helps,
Gilles

On 30 August 2015 at 22:37, Rex X <dnsr...@gmail.com> wrote:
> Jacob, this modification does not seem easy. After fetching the decision
> rules leading to the node of interest, a Pandas groupby script can then
> compute these numbers. Thank you. :)
>
> On Sun, Aug 30, 2015 at 11:54 AM, Jacob Schreiber <jmschreibe...@gmail.com>
> wrote:
>>
>> You would have to modify sklearn/tree/_tree.pyx. See the Tree class near
>> the bottom, and its list of properties. The catch is that the changes
>> would be extensive, since you would also need to modify both the splitter
>> and the criterion objects. If you are doing this for your own personal
>> use, it may be easier to write a small script which successively applies
>> the rules of the tree to your data to see how many points from each class
>> are present.
>>
>> On Sun, Aug 30, 2015 at 10:50 AM, Rex X <dnsr...@gmail.com> wrote:
>>>
>>> Hi Jacob and Trevor,
>>>
>>> Which part of the source code can we modify to add a new attribute to
>>> DecisionTreeClassifier.tree_ that counts the number of samples of each
>>> class within each node?
>>>
>>> Could you point me in the right direction?
>>>
>>> Best,
>>> Rex
>>>
>>> On Sun, Aug 30, 2015 at 8:12 AM, Jacob Schreiber
>>> <jmschreibe...@gmail.com> wrote:
>>>>
>>>> This value is computed while building the tree, but is not kept in the
>>>> tree.
>>>>
>>>> On Sun, Aug 30, 2015 at 7:02 AM, Rex X <dnsr...@gmail.com> wrote:
>>>>>
>>>>> DecisionTreeClassifier.tree_.n_node_samples is the total number of
>>>>> samples, across all classes, in each node, and
>>>>> DecisionTreeClassifier.tree_.value is the computed weight of each class
>>>>> in each node. Only when the sample_weight and class_weight of the
>>>>> DecisionTreeClassifier are all one does this attribute equal the number
>>>>> of samples of each class in a node.
>>>>>
>>>>> But in the general case, with a given sample_weight and class_weight,
>>>>> is there any attribute that gives the number of samples of each class
>>>>> within a node?
>>>>>
>>>>> import pandas as pd
>>>>> from sklearn.datasets import load_iris
>>>>> from sklearn import tree
>>>>>
>>>>> iris = load_iris()
>>>>> clf = tree.DecisionTreeClassifier(class_weight={0: 0.30, 1: 0.3, 2: 0.4},
>>>>>                                   max_features="sqrt")
>>>>> clf.fit(iris.data, iris.target)
>>>>>
>>>>> # the total number of samples in all classes of each node
>>>>> clf.tree_.n_node_samples
>>>>>
>>>>> # the computed weight for each class of each node
>>>>> clf.tree_.value
>>>>>
>>>>> ------------------------------------------------------------------------------
>>>>>
>>>>> _______________________________________________
>>>>> Scikit-learn-general mailing list
>>>>> Scikit-learn-general@lists.sourceforge.net
>>>>> https://lists.sourceforge.net/lists/listinfo/scikit-learn-general
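The counting script that Jacob and Gilles suggest in this thread can be sketched end to end as follows. This is a minimal sketch, not a scikit-learn API: class_counts_per_node is a hypothetical helper, and the path helper here records the leaf as well, so leaf counts are included. It walks every training sample down a tree fitted with the class weights from Rex's question and returns unweighted per-class counts for each node. The vectorized cross-check uses DecisionTreeClassifier.decision_path, which assumes scikit-learn 0.18 or later (newer than this thread).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def path(tree, sample):
    """Return the ids of all nodes visited by `sample`, root (0) to leaf."""
    nodes = []
    node = 0
    # Leaves have children_left == children_right == -1.
    while tree.children_left[node] != -1:
        nodes.append(node)
        if sample[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]
        else:
            node = tree.children_right[node]
    nodes.append(node)  # include the leaf the sample ends up in
    return nodes


def class_counts_per_node(clf, X, y):
    """Hypothetical helper: unweighted count of samples of each class per node.

    Returns an integer array of shape (node_count, n_classes).
    """
    tree = clf.tree_
    # scikit-learn trees work on float32 features internally, so compare
    # float32 values against the thresholds, as the tree itself does.
    X = np.asarray(X, dtype=np.float32)
    # Map the original labels to column indices 0..n_classes-1.
    class_index = np.searchsorted(clf.classes_, y)
    counts = np.zeros((tree.node_count, len(clf.classes_)), dtype=np.int64)
    for sample, k in zip(X, class_index):
        # Node ids on a single path are unique, so fancy-index += is safe.
        counts[path(tree, sample), k] += 1
    return counts


iris = load_iris()
X, y = iris.data, iris.target

# Non-uniform class weights, as in Rex's question.
class_weight = {0: 0.3, 1: 0.3, 2: 0.4}
clf = DecisionTreeClassifier(class_weight=class_weight, random_state=0)
clf.fit(X, y)

counts = class_counts_per_node(clf, X, y)

# Row sums should match the unweighted per-node totals stored by the tree.
print(np.array_equal(counts.sum(axis=1), clf.tree_.n_node_samples))

# Weighting each class count by its class weight should give the weighted totals.
weights = np.array([class_weight[c] for c in clf.classes_])
print(np.allclose(counts @ weights, clf.tree_.weighted_n_node_samples))

# Vectorized cross-check: decision_path returns a sparse (n_samples, node_count)
# indicator of the nodes each sample passes through.
indicator = clf.decision_path(X)
onehot = (y[:, None] == clf.classes_[None, :]).astype(np.int64)
counts_dp = np.asarray(indicator.T @ onehot)
print(np.array_equal(counts, counts_dp))
```

Row 0 (the root) should hold all 50 samples of each iris class. tree_.n_node_samples is unweighted, while tree_.weighted_n_node_samples holds the class-weighted totals, which is why the per-class counts have to be recomputed from the data as Jacob describes.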