Hi everyone,

I'm using the decision tree classifier from the scikit-learn package in
Python 3.4, and I want to get the corresponding leaf node id for each of my
input data points.

For example, my input data array (three records, 4d each) looks like this:

array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2]])

and let's suppose the corresponding leaf nodes for each record are 16, 5
and 45 respectively. I want my output to be:

leaf_node_id = array([16, 5, 45])

I have read through the scikit-learn mailing list and related questions on
StackOverflow, but I still can't get it to work. Here is a hint I found on
the mailing list, but it still does not work for me:

http://sourceforge.net/p/scikit-learn/mailman/message/31728624/

Below is the code I'm using; it raises an error at the final step.
Ultimately, I want to write a function GetLeafNodes(clf, X_input) that
returns an array of the leaf node ids for the input data X_input, where clf
is the fitted decision tree classifier object. Any suggestions are much
appreciated.

from sklearn.datasets import load_iris
from sklearn import tree
# load the data and split it into training and validation sets
iris = load_iris()

num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]

y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]
# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)
# Now I want to know the corresponding leaf node id for each of my
# training data points
clf.tree_.apply(X_train)
# This gives the error message below:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-17-2ecc95213752> in <module>()
----> 1 clf.tree_.apply(X_train)

_tree.pyx in sklearn.tree._tree.Tree.apply (sklearn/tree/_tree.c:19595)()

ValueError: Buffer dtype mismatch, expected 'DTYPE_t' but got 'double'
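
One possible reading of the error: the low-level Tree.apply appears to expect
32-bit floats (DTYPE_t), while iris.data is float64 ('double'). Below is a
minimal sketch of GetLeafNodes under that assumption; it simply casts the
input to float32 before calling clf.tree_.apply. The random_state argument and
the fit on the first 100 rows are only there to make the example
self-contained. Newer scikit-learn releases also provide a public clf.apply(X)
method that handles the conversion itself, which may be preferable where it is
available.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree


def GetLeafNodes(clf, X_input):
    """Return the id of the leaf node reached by each row of X_input."""
    # Assumption: Tree.apply wants a C-contiguous float32 array, so cast
    # the (float64) input before handing it to the low-level tree object.
    X32 = np.ascontiguousarray(X_input, dtype=np.float32)
    return clf.tree_.apply(X32)


iris = load_iris()
X_train, y_train = iris.data[:100, :], iris.target[:100]

clf = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# One leaf node id per training record
leaf_node_id = GetLeafNodes(clf, X_train)
print(leaf_node_id[:3])
```

The returned array has one entry per input row, and each entry indexes into
the arrays stored on clf.tree_ (for example clf.tree_.value).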
_______________________________________________
Scikit-learn-general mailing list
Scikit-learn-general@lists.sourceforge.net
https://lists.sourceforge.net/lists/listinfo/scikit-learn-general
