Also, have a look at the documentation here https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L3205 to understand the structure of the tree_ object.
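To make the layout of tree_ a bit more concrete, here is a rough sketch (not the library's own method) that walks the parallel arrays children_left, children_right, feature and threshold by hand, in the same way as the path function quoted below, and counts how many training samples of each class reach every node. The helper name path_with_leaf, the random_state and the class weights are only illustrative choices, and the sketch assumes the labels are the integers 0, 1, 2 so that they can double as column indices.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target

# Illustrative settings; any positive class weights would do here.
clf = DecisionTreeClassifier(
    class_weight={0: 0.3, 1: 0.3, 2: 0.4}, random_state=0
).fit(X, y)
tree = clf.tree_


def path_with_leaf(tree, sample):
    """Return every node visited by `sample`, from the root down to its leaf."""
    node = 0
    nodes = [node]
    # A leaf is marked by children_left[node] == children_right[node] == -1.
    while tree.children_left[node] != -1:
        if sample[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]
        else:
            node = tree.children_right[node]
        nodes.append(node)
    return nodes


# The tree compares features as float32, so cast before traversing by hand.
X32 = X.astype(np.float32)

# class_counts[node, k] = number of training samples of class k reaching node.
# Assumption: labels are 0..n_classes-1, so a label is also a column index.
class_counts = np.zeros((tree.node_count, len(clf.classes_)), dtype=np.int64)
for sample, label in zip(X32, y):
    class_counts[path_with_leaf(tree, sample), label] += 1

print(class_counts[0])  # unweighted per-class counts at the root
```

Each row of class_counts should sum to the matching entry of tree_.n_node_samples regardless of class_weight, which gives a quick sanity check on the traversal.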

On 31 August 2015 at 08:55, Gilles Louppe <g.lou...@gmail.com> wrote:
> Here is some sample code showing how to retrieve the nodes traversed by a
> given sample:
>
> from sklearn.tree import DecisionTreeClassifier
> from sklearn.datasets import load_iris
>
> iris = load_iris()
> X, y = iris.data, iris.target
>
> clf = DecisionTreeClassifier().fit(X, y)
>
> def path(tree, sample):
>     nodes = []
>     node = 0  # start at the root
>
>     # A leaf has no children (children_right[node] == -1).
>     while tree.children_right[node] != -1:
>         nodes.append(node)
>
>         if sample[tree.feature[node]] <= tree.threshold[node]:
>             node = tree.children_left[node]
>         else:
>             node = tree.children_right[node]
>
>     # Note: the loop stops at the leaf without appending it; add
>     # nodes.append(node) here if the leaf should be included too.
>     return nodes
>
> path(clf.tree_, X[100])
>
> # [0, 2, 12]
>
> Now to derive statistics like the number of samples reaching each
> node, you can iterate over your data X and increment counters, e.g.,
> by doing counters[path(clf.tree_, X[i])] += 1, where counters is a
> numpy array of size tree_.node_count.
>
> Hope this helps,
> Gilles
>
> On 30 August 2015 at 22:37, Rex X <dnsr...@gmail.com> wrote:
>> Jacob, this modification does not seem easy. After fetching the decision
>> rules leading to the node of interest, a subsequent pandas groupby script
>> can compute these numbers. Thank you. :)
>>
>> On Sun, Aug 30, 2015 at 11:54 AM, Jacob Schreiber <jmschreibe...@gmail.com>
>> wrote:
>>>
>>> You would have to modify sklearn/tree/_tree.pyx. See the Tree class near
>>> the bottom, and its list of properties. One issue is that you may have to
>>> modify the code extensively, since both the splitter and criterion
>>> objects would need to change as well. If you are doing this for your own
>>> personal use, it may be easier to write a small script which successively
>>> applies the rules of the tree to your data to see how many points from
>>> each class are present.
>>>
>>> On Sun, Aug 30, 2015 at 10:50 AM, Rex X <dnsr...@gmail.com> wrote:
>>>>
>>>> Hi Jacob and Trevor,
>>>>
>>>> Which part of the source code can we modify to add a new attribute to
>>>> DecisionTreeClassifier.tree_ that counts the number of samples of each
>>>> class within each node?
>>>>
>>>> Could you point me in the right direction?
>>>>
>>>> Best,
>>>> Rex
>>>>
>>>> On Sun, Aug 30, 2015 at 8:12 AM, Jacob Schreiber
>>>> <jmschreibe...@gmail.com> wrote:
>>>>>
>>>>> This value is computed while building the tree, but is not kept in the
>>>>> tree.
>>>>>
>>>>> On Sun, Aug 30, 2015 at 7:02 AM, Rex X <dnsr...@gmail.com> wrote:
>>>>>>
>>>>>> DecisionTreeClassifier.tree_.n_node_samples is the total number of
>>>>>> samples, over all classes, in a node, and
>>>>>> DecisionTreeClassifier.tree_.value is the computed weight for each
>>>>>> class in a node. Only if the sample_weight and class_weight of this
>>>>>> DecisionTreeClassifier are all one does tree_.value equal the number
>>>>>> of samples of each class in a node.
>>>>>>
>>>>>> But for the general case with a given sample_weight and class_weight,
>>>>>> is there any attribute telling us the number of samples of each class
>>>>>> within a node?
>>>>>>
>>>>>> from sklearn.datasets import load_iris
>>>>>> from sklearn import tree
>>>>>>
>>>>>> iris = load_iris()
>>>>>> # max_features was "auto" in the original message; for classifiers it
>>>>>> # meant sqrt(n_features), and newer releases no longer accept "auto".
>>>>>> clf = tree.DecisionTreeClassifier(
>>>>>>     class_weight={0: 0.3, 1: 0.3, 2: 0.4}, max_features="sqrt")
>>>>>> clf.fit(iris.data, iris.target)
>>>>>>
>>>>>> # the total number of samples in all classes of each node
>>>>>> clf.tree_.n_node_samples
>>>>>>
>>>>>> # the computed weight for each class of each node
>>>>>> clf.tree_.value
>>>>>>
>>>>>> ------------------------------------------------------------------------------
>>>>>>
>>>>>> _______________________________________________
>>>>>> Scikit-learn-general mailing list
>>>>>> Scikit-learn-general@lists.sourceforge.net
>>>>>> https://lists.sourceforge.net/lists/listinfo/scikit-learn-general