Hi,

Since the last version, scikit-learn provides an `apply` method on the classifier itself, which keeps users from shooting themselves in the foot :)
So basically, you can replace `clf.tree_.apply(X_train)` with `clf.apply(X_train)` and it should work.

Hope this helps,
Gilles

On 23 May 2015 at 17:26, Kittipat Kampa <kitti...@gmail.com> wrote:
> Hi everyone,
>
> I'm using decision tree classifier from the scikit-learn package in python
> 3.4, and I want to get the corresponding leaf node id for each of my input
> data point.
>
> For example, my input data array (three records, 4d each) looks like this:
>
> array([[ 5.1, 3.5, 1.4, 0.2],
>        [ 4.9, 3. , 1.4, 0.2],
>        [ 4.7, 3.2, 1.3, 0.2]])
>
> and let's suppose the corresponding leaf nodes for each record are 16, 5 and
> 45 respectively. I want my output to be:
>
> leaf_node_id = array([16, 5, 45])
>
> I have read through the scikit-learn mailing list and related questions on
> StackOverflow but I still can't get it to work. Here is some hint I found on
> the mailing list, but still does not work.
>
> http://sourceforge.net/p/scikit-learn/mailman/message/31728624/
>
> Below is the code I'm using and it raises the error at the final step. At
> the end of the day, I want to write a function GetLeafNodes(clf, X_input)
> that returns an array of corresponding leaf nodes of the input data X_input
> when clf is the decision tree classifier object. Any suggestion is very
> appreciated.
>
> from sklearn.datasets import load_iris
> from sklearn import tree
>
> # load data and divide it to train and validation
> iris = load_iris()
>
> num_train = 100
> X_train = iris.data[:num_train,:]
> X_valida = iris.data[num_train:,:]
>
> y_train = iris.target[:num_train]
> y_valida = iris.target[num_train:]
>
> # fit the decision tree using the train data set
> clf = tree.DecisionTreeClassifier()
> clf = clf.fit(X_train, y_train)
>
> # Now I want to know the corresponding leaf node id for each of my training data points
> clf.tree_.apply(X_train)
>
> # This gives the error message below:
> ---------------------------------------------------------------------------
> ValueError                                Traceback (most recent call last)
> <ipython-input-17-2ecc95213752> in <module>()
> ----> 1 clf.tree_.apply(X_train)
>
> _tree.pyx in sklearn.tree._tree.Tree.apply (sklearn/tree/_tree.c:19595)()
>
> ValueError: Buffer dtype mismatch, expected 'DTYPE_t' but got 'double'
>
> _______________________________________________
> Scikit-learn-general mailing list
> Scikit-learn-general@lists.sourceforge.net
> https://lists.sourceforge.net/lists/listinfo/scikit-learn-general
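
For what it's worth, here is a minimal sketch of the GetLeafNodes helper asked for in the quoted message, assuming a scikit-learn release in which DecisionTreeClassifier exposes `apply`. The name `get_leaf_nodes` and the `random_state=0` argument are my own choices for illustration and are not part of the original code. The last lines reflect my understanding of why the low-level call failed: the Cython tree works on float32 buffers, while the iris data is float64 ("double").

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def get_leaf_nodes(clf, X_input):
    """Return the id of the leaf node reached by each row of X_input."""
    # The estimator-level apply() validates the input and converts it
    # to the dtype the underlying tree expects before descending.
    return clf.apply(X_input)


# Same split as in the quoted message: the first 100 iris samples for training.
iris = load_iris()
num_train = 100
X_train = iris.data[:num_train, :]
y_train = iris.target[:num_train]

# random_state is fixed only to make the sketch reproducible (an assumption).
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

leaf_node_id = get_leaf_nodes(clf, X_train)
print(leaf_node_id[:3])

# Assumption about the cause of the original error: the low-level tree
# takes float32 input, so an explicit cast makes the direct call work too.
low_level = clf.tree_.apply(X_train.astype(np.float32))
print(np.array_equal(leaf_node_id, low_level))
```

The returned ids index into the arrays of `clf.tree_`, so for example `clf.tree_.value[leaf_node_id]` gives the class statistics stored in each sample's leaf. I would still prefer `clf.apply` over the cast, since `tree_` is a low-level object.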