Hi all,

Suppose I have the following binary tree, which I want to traverse level by
level (breadth-first) using a recursive algorithm:

  .
└── 1/
    ├── 2/
    │   ├── 3/
    │   │   ├── 4
    │   │   └── 9
    │   └── 30
    └── 71/
        ├── 72
        └── 99

I wrote the function below, inspired by level-order tree traversal. It is
meant to stop once it reaches a given maximum depth:

def get_subtree_from_rt(subtree, root_start, max_depth):
    if max_depth == 0:
        return []
    nodes = [root_start]
    if root_start == -1:
        return []
    else:
        nodes.extend([subtree.children_left[root_start],
                      subtree.children_right[root_start]])
        print(nodes)
    nodes.extend(child for child in get_subtree_from_rt(
                     subtree, subtree.children_left[root_start], max_depth - 1)
                 if child not in list(filter(lambda a: a != -1, nodes)))

    nodes.extend(child for child in get_subtree_from_rt(
                     subtree, subtree.children_right[root_start], max_depth - 1)
                 if child not in list(filter(lambda a: a != -1, nodes)))
    return nodes

The algorithm does traverse the tree, but in the wrong order. For the tree
above, it returned:

[1, 2, 71, 3, 30, 4, 9]

The expected result is:

[1, 2, 71, 3, 30, 72, 99]

It looks as if root_start is not the same for both recursive calls, as
though the first recursive call had altered its value.


My question: how can I obtain the expected result above while making sure
the second recursive call is not made with a different root_start value?
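
One possible direction, as a minimal sketch rather than a tested solution:
recurse on a whole level at a time instead of on a single node, so that every
node of one level is emitted before any node of the next. The helper name
level_order and the toy arrays below are hypothetical; the only assumption
about the tree object is that it exposes children_left and children_right
arrays in which -1 marks a missing child, as in get_subtree_from_rt above.

```python
from types import SimpleNamespace


def level_order(tree, level, max_depth):
    """Return node ids level by level, starting from the nodes in `level`."""
    # Drop missing children (-1) before processing this level.
    level = [n for n in level if n != -1]
    if max_depth == 0 or not level:
        return []
    # Collect the children of every node on this level, left to right.
    next_level = []
    for n in level:
        next_level.append(tree.children_left[n])
        next_level.append(tree.children_right[n])
    # Emit this level first, then recurse on the next level as a whole.
    return level + level_order(tree, next_level, max_depth - 1)


# Hypothetical stand-in for the tree drawn above, indexed by node id.
size = 100
left = [-1] * size
right = [-1] * size
for parent, (l, r) in {1: (2, 71), 2: (3, 30), 3: (4, 9), 71: (72, 99)}.items():
    left[parent], right[parent] = l, r
toy = SimpleNamespace(children_left=left, children_right=right)

print(level_order(toy, [1], 3))  # [1, 2, 71, 3, 30, 72, 99]
```

The second argument is a list of nodes rather than a single root, so the
initial call wraps the starting node in a list. There is only one recursive
call per level, which sidesteps the question of which root_start each call
receives.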

To reproduce, pass tree_structure (built below) as the subtree argument:

  import pandas as pd
  from sklearn.model_selection import train_test_split
  from sklearn.tree import DecisionTreeRegressor

  dataset = pd.read_csv("anydatasetPath")
  # Column 9 is the target; every other column is a feature.
  x = dataset.drop(dataset.columns[9], axis=1)
  y = dataset.iloc[:, 9]

  x_train, x_test, y_train, y_test = train_test_split(
      x, y, test_size=0.2, random_state=28)

  model = DecisionTreeRegressor(random_state=0)
  model.fit(x_train, y_train)
  y_pred = model.predict(x_test)

  # Low-level tree object exposing children_left and children_right.
  tree_structure = model.tree_

  print(get_subtree_from_rt(tree_structure, 1, 3))



Many thanks
_______________________________________________
scikit-learn mailing list
scikit-learn@python.org
https://mail.python.org/mailman/listinfo/scikit-learn
