Hi all,
Suppose I have the following binary tree, which I want to traverse level by
level (breadth-first) using a recursive algorithm:
.
└── 1/
    ├── 2/
    │   ├── 3/
    │   │   ├── 4
    │   │   └── 9
    │   └── 30
    └── 71/
        ├── 72
        └── 99
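For experimenting without fitting a model, the diagram can be encoded with the two attributes that sklearn's `model.tree_` exposes, `children_left` and `children_right`, where -1 marks a missing child. This is only a sketch: the names `ToyTree`, `example_tree` and `node_depths` are hypothetical, and the node labels from the diagram are used as dictionary keys (sklearn itself uses array indices). Listing each node's depth shows which nodes make up each level.

```python
from collections import namedtuple

# Hypothetical stand-in for sklearn's model.tree_, exposing only the two
# attributes the traversal code reads. Keys are the labels from the diagram;
# -1 means "no child", the same marker sklearn uses for leaves.
ToyTree = namedtuple("ToyTree", ["children_left", "children_right"])

example_tree = ToyTree(
    children_left={1: 2, 2: 3, 3: 4, 30: -1, 71: 72, 4: -1, 9: -1, 72: -1, 99: -1},
    children_right={1: 71, 2: 30, 3: 9, 30: -1, 71: 99, 4: -1, 9: -1, 72: -1, 99: -1},
)


def node_depths(tree, root):
    """Map every node reachable from `root` to its depth (root has depth 0)."""
    depths = {}
    stack = [(root, 0)]
    while stack:
        node, depth = stack.pop()
        if node == -1:  # missing child
            continue
        depths[node] = depth
        stack.append((tree.children_left[node], depth + 1))
        stack.append((tree.children_right[node], depth + 1))
    return depths


depths = node_depths(example_tree, 1)
for level in range(max(depths.values()) + 1):
    # prints: 0 [1] / 1 [2, 71] / 2 [3, 30, 72, 99] / 3 [4, 9]
    print(level, sorted(n for n, d in depths.items() if d == level))
```

Keeping depths 0 to 2 gives the nodes 1; 2, 71; 3, 30, 72, 99.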
I wrote the following algorithm, inspired by the level-order traversal of a
tree; it stops at a given input depth:
def get_subtree_from_rt(subtree, root_start, max_depth):
    # Stop once the requested depth is exhausted
    if max_depth == 0:
        return []
    nodes = [root_start]
    # -1 marks a missing child in sklearn's tree structure
    if root_start == -1:
        return []
    else:
        # List this node's two children right after the node itself
        nodes.extend([subtree.children_left[root_start],
                      subtree.children_right[root_start]])
    print(nodes)
    # Recurse into the left subtree, skipping nodes already listed
    nodes.extend(child for child in
                 get_subtree_from_rt(subtree, subtree.children_left[root_start],
                                     max_depth - 1)
                 if child not in list(filter(lambda a: a != -1, nodes)))
    # Recurse into the right subtree, skipping nodes already listed
    nodes.extend(child for child in
                 get_subtree_from_rt(subtree, subtree.children_right[root_start],
                                     max_depth - 1)
                 if child not in list(filter(lambda a: a != -1, nodes)))
    return nodes
The algorithm does traverse the tree, but in an unwanted order: for the tree
above, the returned result was:
[1, 2, 71, 3, 30, 4, 9]
whereas the correct result should have been:
[1, 2, 71, 3, 30, 72, 99]
As far as I can tell, root_start is not the same for the two recursive calls,
since the first recursive call alters its value.
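One note on this diagnosis: in Python every call gets its own local `root_start`, so the first recursive call cannot change the value the second call sees. The ordering comes from the recursion being depth-first: the whole left subtree is expanded before the right subtree is visited. In addition, each call lists its two children before recursing, so a call made with `max_depth == 1` still adds nodes one level deeper than intended. The stripped-down sketch below (hypothetical `preorder` function on a hypothetical dictionary-based stand-in for `tree_`, -1 for a missing child) returns exactly the expected set of nodes, but in depth-first order.

```python
from collections import namedtuple

# Hypothetical dictionary-based stand-in for model.tree_ (-1 = no child).
ToyTree = namedtuple("ToyTree", ["children_left", "children_right"])
example_tree = ToyTree(
    children_left={1: 2, 2: 3, 3: 4, 30: -1, 71: 72, 4: -1, 9: -1, 72: -1, 99: -1},
    children_right={1: 71, 2: 30, 3: 9, 30: -1, 71: 99, 4: -1, 9: -1, 72: -1, 99: -1},
)


def preorder(tree, node, max_depth):
    """Depth-first (pre-order) listing of the first `max_depth` levels."""
    if node == -1 or max_depth == 0:
        return []
    # `node` is local to this call: the recursive calls cannot change it,
    # so both children are looked up on the same node.
    return ([node]
            + preorder(tree, tree.children_left[node], max_depth - 1)
            + preorder(tree, tree.children_right[node], max_depth - 1))


print(preorder(example_tree, 1, 3))  # [1, 2, 3, 30, 71, 72, 99]
```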
My question is: how can I obtain the expected result without the second
recursive call running on a different root_start value?
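One way to keep a recursive algorithm but get level order is to recurse over whole levels instead of single nodes: each call receives the list of nodes on the current level, emits them, and recurses once on the list of their children. With only one recursive call per level, there is no left-versus-right ordering to get wrong. This is a sketch under the same assumptions as before: `level_order` is a hypothetical name, the tree is a hypothetical dictionary-based stand-in for `tree_`, and -1 marks a missing child.

```python
from collections import namedtuple

# Hypothetical dictionary-based stand-in for model.tree_ (-1 = no child).
ToyTree = namedtuple("ToyTree", ["children_left", "children_right"])
example_tree = ToyTree(
    children_left={1: 2, 2: 3, 3: 4, 30: -1, 71: 72, 4: -1, 9: -1, 72: -1, 99: -1},
    children_right={1: 71, 2: 30, 3: 9, 30: -1, 71: 99, 4: -1, 9: -1, 72: -1, 99: -1},
)


def level_order(tree, frontier, max_depth):
    """Recursive level-order traversal limited to `max_depth` levels.

    Each call handles one whole level: `frontier` holds the node ids on the
    current level, and the function recurses once on the next level.
    """
    # Drop missing children (-1) carried over from the previous level
    frontier = [node for node in frontier if node != -1]
    if max_depth == 0 or not frontier:
        return []
    next_level = []
    for node in frontier:
        # int() turns NumPy integers (as in sklearn's arrays) into plain ints
        next_level.append(int(tree.children_left[node]))
        next_level.append(int(tree.children_right[node]))
    return frontier + level_order(tree, next_level, max_depth - 1)


print(level_order(example_tree, [1], 3))  # [1, 2, 71, 3, 30, 72, 99]
```

With the fitted model, the equivalent call would be `level_order(tree_stucture, [1], 3)`, starting from a one-element list that holds the root node id.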
Usage: tree_stucture (the fitted model's tree_) is passed as the subtree argument:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

dataset = pd.read_csv("anydatasetPath")
# Column 9 is the target; all other columns are features
x = dataset.drop(dataset.columns[9], axis=1)
y = dataset.iloc[:, 9]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,
                                                    random_state=28)
model = DecisionTreeRegressor(random_state=0)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
# Low-level structure of the fitted tree (children_left, children_right, ...)
tree_stucture = model.tree_
print(get_subtree_from_rt(tree_stucture, 1, 3))
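If recursion is not a hard requirement, a plain breadth-first loop with a queue (`collections.deque`) is the usual way to do this. The sketch below stores the tree the way sklearn does, as parallel arrays indexed by node id; here two plain lists, numbered in pre-order for this example, stand in for `model.tree_`, and `labels` (hypothetical) maps each id back to the numbers in the diagram. `level_order_iterative` and `sk_like_tree` are likewise hypothetical names.

```python
from collections import deque
from types import SimpleNamespace

# Hypothetical stand-in for model.tree_: parallel lists indexed by node id
# (numbered in pre-order here), with -1 meaning "no child" as in sklearn.
sk_like_tree = SimpleNamespace(
    children_left=[1, 2, 3, -1, -1, -1, 7, -1, -1],
    children_right=[6, 5, 4, -1, -1, -1, 8, -1, -1],
)
# Hypothetical mapping from node id to the label used in the diagram.
labels = [1, 2, 3, 4, 9, 30, 71, 72, 99]


def level_order_iterative(tree, root, max_depth):
    """Return node ids level by level, keeping depths 0 .. max_depth - 1."""
    result = []
    queue = deque([(root, 0)])
    while queue:
        node, depth = queue.popleft()
        # Skip missing children and anything at or past the depth limit
        if node == -1 or depth >= max_depth:
            continue
        result.append(int(node))
        queue.append((tree.children_left[node], depth + 1))
        queue.append((tree.children_right[node], depth + 1))
    return result


ids = level_order_iterative(sk_like_tree, 0, 3)
print([labels[i] for i in ids])  # [1, 2, 71, 3, 30, 72, 99]
```

On the fitted model, the call would be `level_order_iterative(tree_stucture, 1, 3)`, which returns sklearn node ids rather than the labels used in the diagram.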
With many thanks,
_______________________________________________
scikit-learn mailing list
scikit-learn@python.org
https://mail.python.org/mailman/listinfo/scikit-learn