Repository: incubator-madlib
Updated Branches:
  refs/heads/master 0cdd644a9 -> 20b115800


DT: Assign memory only for reachable nodes

JIRA: MADLIB-1057

TreeAccumulator allocates a matrix to track the statistics of rows
reaching the last layer of nodes. This matrix assumes a complete
tree and allocates memory for every node in that layer. As the tree
gets deeper, most of those nodes become unreachable, so most of the
memory is wasted. This commit reduces that waste by allocating
memory only for reachable nodes and accessing their rows through a
lookup table.
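
The idea, as a minimal standalone sketch (hypothetical names and
std::vector stand-ins, not the actual MADlib types): scan the leaves of
a complete binary tree once, give each reachable leaf a consecutive row
in a compact stats matrix, and map unreachable leaves to -1 so no rows
are allocated for them.

    // Sketch of the lookup-table scheme, assuming 8 leaves of a
    // depth-4 complete tree with a known reachability pattern.
    #include <cstdio>
    #include <vector>

    int main() {
        // Reachability of each leaf in the complete tree
        // (true = some row can actually land on this leaf).
        std::vector<bool> reachable = {true, false, false, true,
                                       true, true, false, false};

        // Lookup: complete-tree leaf index -> compact stats row
        // (-1 if the leaf is unreachable and gets no storage).
        std::vector<int> stats_lookup(reachable.size(), -1);
        int n_reachable = 0;
        for (size_t i = 0; i < reachable.size(); ++i)
            if (reachable[i])
                stats_lookup[i] = n_reachable++;  // increment after assigning

        // Allocate stats rows only for reachable leaves.
        std::vector<std::vector<double> > node_stats(
            n_reachable, std::vector<double>(4, 0.0));

        // Accumulate a statistic for leaf 3 through the lookup.
        int row = stats_lookup[3];
        if (row >= 0)
            node_stats[row][0] += 1.0;

        std::printf("rows allocated: %d of %zu\n",
                    n_reachable, reachable.size());
        return 0;
    }

With 4 reachable leaves out of 8, only half the rows are allocated; for
a deep, sparse tree the savings grow with depth, since a complete
tree's last layer doubles at every level.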

Closes #120


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/20b11580
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/20b11580
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/20b11580

Branch: refs/heads/master
Commit: 20b115800e8e984553d3239c81c8ff62c64efaa3
Parents: 0cdd644
Author: Rahul Iyer <ri...@apache.org>
Authored: Tue Apr 25 15:00:40 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Apr 25 15:00:40 2017 -0700

----------------------------------------------------------------------
 src/modules/recursive_partitioning/DT_impl.hpp  | 125 +++++++++++--------
 src/modules/recursive_partitioning/DT_proto.hpp |  20 ++-
 .../recursive_partitioning/decision_tree.cpp    |  62 +++++++--
 3 files changed, 143 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/20b11580/src/modules/recursive_partitioning/DT_impl.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_impl.hpp b/src/modules/recursive_partitioning/DT_impl.hpp
index 64d2b88..6d15db5 100644
--- a/src/modules/recursive_partitioning/DT_impl.hpp
+++ b/src/modules/recursive_partitioning/DT_impl.hpp
@@ -475,7 +475,7 @@ DecisionTree<Container>::expand(const Accumulator &state,
                                 const uint16_t &min_split,
                                 const uint16_t &min_bucket,
                                 const uint16_t &max_depth) {
-    uint16_t n_non_leaf_nodes = static_cast<uint16_t>(state.n_leaf_nodes - 1);
+    uint32_t n_non_leaf_nodes = static_cast<uint32_t>(state.n_leaf_nodes - 1);
     bool children_not_allocated = true;
     bool children_wont_split = true;
 
@@ -483,8 +483,11 @@ DecisionTree<Container>::expand(const Accumulator &state,
     for (Index i=0; i < state.n_leaf_nodes; i++) {
         Index current = n_non_leaf_nodes + i;
         if (feature_indices(current) == IN_PROCESS_LEAF) {
+            Index stats_i = static_cast<Index>(state.stats_lookup(i));
+            assert(stats_i >= 0);
+
             // 1. Set the prediction for current node from stats of all rows
-            predictions.row(current) = state.node_stats.row(i);
+            predictions.row(current) = state.node_stats.row(stats_i);
 
             // 2. Compute the best feature to split current node by
 
@@ -502,14 +505,14 @@ DecisionTree<Container>::expand(const Accumulator &state,
                     // each value of feature
                     Index fv_index = state.indexCatStats(f, v, true);
                     double gain = impurityGain(
-                        state.cat_stats.row(i).segment(fv_index, sps * 2), sps);
+                        state.cat_stats.row(stats_i).
+                            segment(fv_index, sps * 2), sps);
                     if (gain > max_impurity_gain){
                         max_impurity_gain = gain;
                         max_feat = f;
                         max_bin = v;
                         max_is_cat = true;
-                        max_stats = state.cat_stats.row(i).segment(fv_index,
-                                                                   sps * 2);
+                        max_stats = state.cat_stats.row(stats_i).segment(fv_index, sps * 2);
                     }
                 }
             }
@@ -519,14 +522,13 @@ DecisionTree<Container>::expand(const Accumulator &state,
                     // each bin of feature
                     Index fb_index = state.indexConStats(f, b, true);
                     double gain = impurityGain(
-                        state.con_stats.row(i).segment(fb_index, sps * 2), sps);
+                        state.con_stats.row(stats_i).segment(fb_index, sps * 2), sps);
                     if (gain > max_impurity_gain){
                         max_impurity_gain = gain;
                         max_feat = f;
                         max_bin = b;
                         max_is_cat = false;
-                        max_stats = state.con_stats.row(i).segment(fb_index,
-                                                                   sps * 2);
+                        max_stats = state.con_stats.row(stats_i).segment(fb_index, sps * 2);
                     }
                 }
             }
@@ -548,7 +550,8 @@ DecisionTree<Container>::expand(const Accumulator &state,
                 }
                 children_wont_split &=
                     updatePrimarySplit(
-                        current, static_cast<int>(max_feat),
+                        current,
+                        static_cast<int>(max_feat),
                         max_threshold, max_is_cat,
                         min_split,
                         max_stats.segment(0, sps),   // true_stats
@@ -626,8 +629,8 @@ DecisionTree<Container>::pickSurrogates(
     Matrix cat_stats_counts(state.cat_stats * cat_agg_matrix);
     Matrix con_stats_counts(state.con_stats * con_agg_matrix);
 
-    // cat_stats_counts size = n_nodes x n_cats*2
-    // con_stats_counts size = n_nodes x n_cons*2
+    // cat_stats_counts size = n_reachable_leaf_nodes x n_cats*2
+    // con_stats_counts size = n_reachable_leaf_nodes x n_cons*2
     // *_stats_counts now contains the agreement count for each split where
     // each even col represents forward surrogate split count and
     // each odd col represents reverse surrogate split count.
@@ -635,12 +638,14 @@ DecisionTree<Container>::pickSurrogates(
     // Number of nodes in a last layer = 2^(tree_depth-1). (since depth starts from 1)
     // For n_surr_nodes, we need number of nodes in 2nd last layer,
     // so we use 2^(tree_depth-2)
-    uint16_t n_surr_nodes = static_cast<uint16_t>(pow(2, tree_depth - 2));
-    uint16_t n_ancestors = static_cast<uint16_t>(n_surr_nodes - 1);
+    uint32_t n_surr_nodes = static_cast<uint32_t>(pow(2, tree_depth - 2));
+    uint32_t n_ancestors = static_cast<uint32_t>(n_surr_nodes - 1);
 
     for (Index i=0; i < n_surr_nodes; i++){
         Index curr_node = n_ancestors + i;
         assert(curr_node >= 0 && curr_node < feature_indices.size());
+        Index stats_i = static_cast<Index>(state.stats_lookup(i));
+        assert(stats_i >= 0);
 
         if (feature_indices(curr_node) >= 0){
             // 1. Compute the max count and corresponding split threshold for
@@ -652,11 +657,11 @@ DecisionTree<Container>::pickSurrogates(
             for (Index each_cat=0; each_cat < n_cats; each_cat++){
                 Index n_levels = state.cat_levels_cumsum(each_cat) - prev_cum_levels;
                 Index max_label;
-                (cat_stats_counts.row(i).segment(
+                (cat_stats_counts.row(stats_i).segment(
                     prev_cum_levels * 2, n_levels * 2)).maxCoeff(&max_label);
                 cat_max_thres(each_cat) = static_cast<double>(max_label / 2);
                 cat_max_count(each_cat) =
-                        cat_stats_counts(i, prev_cum_levels*2 + max_label);
+                        cat_stats_counts(stats_i, prev_cum_levels*2 + max_label);
                 // every odd col is for reverse, hence i % 2 == 1 for reverse index i
                 cat_max_is_reverse(each_cat) = (max_label % 2 == 1) ? 1 : 0;
                 prev_cum_levels = state.cat_levels_cumsum(each_cat);
@@ -667,11 +672,11 @@ DecisionTree<Container>::pickSurrogates(
             IntegerVector con_max_is_reverse = IntegerVector::Zero(n_cons);
             for (Index each_con=0; each_con < n_cons; each_con++){
                 Index max_label;
-                (con_stats_counts.row(i).segment(
+                (con_stats_counts.row(stats_i).segment(
                         each_con*n_bins*2, n_bins*2)).maxCoeff(&max_label);
                 con_max_thres(each_con) = con_splits(each_con, max_label / 2);
                 con_max_count(each_con) =
-                        con_stats_counts(i, each_con*n_bins*2 + max_label);
+                        con_stats_counts(stats_i, each_con*n_bins*2 + max_label);
                 con_max_is_reverse(each_con) = (max_label % 2 == 1) ? 1 : 0;
             }
 
@@ -740,7 +745,7 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
                                 const uint16_t &max_depth,
                                 const int &n_random_features) {
 
-    uint16_t n_non_leaf_nodes = static_cast<uint16_t>(state.n_leaf_nodes - 1);
+    uint32_t n_non_leaf_nodes = static_cast<uint32_t>(state.n_leaf_nodes - 1);
     bool children_not_allocated = true;
     bool children_wont_split = true;
 
@@ -756,9 +761,12 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
 
     for (Index i=0; i < state.n_leaf_nodes; i++) {
         Index current = n_non_leaf_nodes + i;
+        Index stats_i = static_cast<Index>(state.stats_lookup(i));
+        assert(stats_i >= 0);
+
         if (feature_indices(current) == IN_PROCESS_LEAF) {
             // 1. Set the prediction for current node from stats of all rows
-            predictions.row(current) = state.node_stats.row(i);
+            predictions.row(current) = state.node_stats.row(stats_i);
 
             for (int j=0; j<total_cat_con_features; j++) {
                 cat_con_feature_indices[j] = j;
@@ -785,14 +793,16 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
                         // each value of feature
                         Index fv_index = state.indexCatStats(f, v, true);
                         double gain = impurityGain(
-                            state.cat_stats.row(i).segment(fv_index, sps * 2), sps);
+                            state.cat_stats.row(stats_i).
+                                segment(fv_index, sps * 2),
+                            sps);
                         if (gain > max_impurity_gain){
                             max_impurity_gain = gain;
                             max_feat = f;
                             max_bin = v;
                             max_is_cat = true;
-                            max_stats = state.cat_stats.row(i).segment(fv_index,
-                                                                       sps * 2);
+                            max_stats = state.cat_stats.row(stats_i).
+                                            segment(fv_index, sps * 2);
                         }
                     }
 
@@ -804,14 +814,16 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
                         // each bin of feature
                         Index fb_index = state.indexConStats(f, b, true);
                         double gain = impurityGain(
-                            state.con_stats.row(i).segment(fb_index, sps * 2), sps);
+                            state.con_stats.row(stats_i).
+                                segment(fb_index, sps * 2),
+                            sps);
                         if (gain > max_impurity_gain){
                             max_impurity_gain = gain;
                             max_feat = f;
                             max_bin = b;
                             max_is_cat = false;
-                            max_stats = state.con_stats.row(i).segment(fb_index,
-                                                                       sps * 2);
+                            max_stats = state.con_stats.row(stats_i).
+                                            segment(fb_index, sps * 2);
                         }
                     }
                 }
@@ -1061,7 +1073,7 @@ DecisionTree<Container>::recomputeTreeDepth() const{
         return tree_depth;
 
     for(uint16_t depth_counter = 2; depth_counter <= tree_depth; depth_counter++){
-        uint32_t n_leaf_nodes = static_cast<uint16_t>(pow(2, depth_counter - 1));
+        uint32_t n_leaf_nodes = static_cast<uint32_t>(pow(2, depth_counter - 1));
         uint32_t leaf_start_index = n_leaf_nodes - 1;
         bool all_non_existing = true;
         for (uint32_t leaf_index=0; leaf_index < n_leaf_nodes; leaf_index++){
@@ -1125,7 +1137,7 @@ DecisionTree<Container>::displayLeafNode(
                     n_elem = NUM_PER_LINE;
                 } else {
                     // less than NUM_PER_LINE left, avoid reading past the end
-                    n_elem = pred_size - i;
+                    n_elem = static_cast<uint16_t>(pred_size - i);
                 }
                 display_str << predictions.row(id).segment(i, n_elem) << "\n";
             }
@@ -1169,7 +1181,7 @@ DecisionTree<Container>::displayInternalNode(
         size_t to_skip = 0;
         for (Index i=0; i < feature_indices(id); i++)
             to_skip += cat_n_levels[i];
-        const size_t index = to_skip + feature_thresholds(id);
+        const size_t index = to_skip + static_cast<size_t>(feature_thresholds(id));
         label_str << get_text(cat_levels_text, index);
     }
 
@@ -1195,7 +1207,7 @@ DecisionTree<Container>::displayInternalNode(
                     // not overflowing the vector
                     n_elem = NUM_PER_LINE;
                 } else {
-                    n_elem = pred_size - i;
+                    n_elem = static_cast<uint16_t>(pred_size - i);
                 }
                 display_str << predictions.row(id).segment(i, n_elem) << "\n";
             }
@@ -1520,8 +1532,8 @@ TreeAccumulator<Container, DTree>::TreeAccumulator(
  * there is no guarantee yet that the element can indeed be accessed. It is
  * crucial to first check this.
  *
- * Provided that this methods correctly lists all member variables, all other
- * methods can, however, rely on that fact that all variables are correctly
+ * Provided that this method correctly lists all member variables, all other
+ * methods can rely on the fact that all variables are correctly
  * initialized and accessible.
  */
 template <class Container, class DTree>
@@ -1536,6 +1548,7 @@ TreeAccumulator<Container, DTree>::bind(ByteStream_type& inStream) {
              >> n_con_features
              >> total_n_cat_levels
              >> n_leaf_nodes
+             >> n_reachable_leaf_nodes
              >> stats_per_split
              >> weights_as_rows ;
 
@@ -1543,7 +1556,8 @@ TreeAccumulator<Container, DTree>::bind(ByteStream_type& inStream) {
     uint16_t n_cat = 0;
     uint16_t n_con = 0;
     uint32_t tot_levels = 0;
-    uint16_t n_leafs = 0;
+    uint32_t n_leaves = 0;
+    uint32_t n_reachable_leaves = 0;
     uint16_t n_stats = 0;
 
     if (!n_rows.isNull()){
@@ -1551,15 +1565,17 @@ TreeAccumulator<Container, DTree>::bind(ByteStream_type& inStream) {
         n_cat = n_cat_features;
         n_con = n_con_features;
         tot_levels = total_n_cat_levels;
-        n_leafs = n_leaf_nodes;
+        n_leaves = n_leaf_nodes;
+        n_reachable_leaves = n_reachable_leaf_nodes;
         n_stats = stats_per_split;
     }
 
     inStream
         >> cat_levels_cumsum.rebind(n_cat)
-        >> cat_stats.rebind(n_leafs, tot_levels * n_stats * 2)
-        >> con_stats.rebind(n_leafs, n_con * n_bins_tmp * n_stats * 2)
-        >> node_stats.rebind(n_leafs, n_stats);
+        >> cat_stats.rebind(n_reachable_leaves, tot_levels * n_stats * 2)
+        >> con_stats.rebind(n_reachable_leaves, n_con * n_bins_tmp * n_stats * 2)
+        >> node_stats.rebind(n_reachable_leaves, n_stats)
+        >> stats_lookup.rebind(n_leaves);
 }
 // -------------------------------------------------------------------------
 
@@ -1574,7 +1590,8 @@ void
 TreeAccumulator<Container, DTree>::rebind(
         uint16_t in_n_bins, uint16_t in_n_cat_feat,
         uint16_t in_n_con_feat, uint32_t in_n_total_levels,
-        uint16_t tree_depth, uint16_t in_n_stats, bool in_weights_as_rows) {
+        uint16_t tree_depth, uint16_t in_n_stats,
+        bool in_weights_as_rows, uint32_t n_reachable_leaves) {
 
     n_bins = in_n_bins;
     n_cat_features = in_n_cat_feat;
@@ -1582,9 +1599,13 @@ TreeAccumulator<Container, DTree>::rebind(
     total_n_cat_levels = in_n_total_levels;
     weights_as_rows = in_weights_as_rows;
     if (tree_depth > 0)
-        n_leaf_nodes = static_cast<uint16_t>(pow(2, tree_depth - 1));
+        n_leaf_nodes = static_cast<uint32_t>(pow(2, tree_depth - 1));
     else
         n_leaf_nodes = 1;
+    if (n_reachable_leaves >= n_leaf_nodes)
+        n_reachable_leaf_nodes = n_leaf_nodes;
+    else
+        n_reachable_leaf_nodes = n_reachable_leaves;
     stats_per_split = in_n_stats;
     this->resize();
 }
@@ -1618,7 +1639,7 @@ TreeAccumulator<Container, DTree>::operator<<(const tuple_type& inTuple) {
         } else if (n_con_features != static_cast<uint16_t>(con_features.size())) {
             warning("Inconsistent numbers of continuous independent variables.");
         } else{
-            uint16_t n_non_leaf_nodes = static_cast<uint16_t>(n_leaf_nodes - 1);
+            uint32_t n_non_leaf_nodes = static_cast<uint32_t>(n_leaf_nodes - 1);
             Index dt_search_index = dt.search(cat_features, con_features);
             if (dt.feature_indices(dt_search_index) != dt.FINISHED_LEAF &&
                    dt.feature_indices(dt_search_index) != dt.NODE_NON_EXISTING) {
@@ -1687,8 +1708,8 @@ TreeAccumulator<Container, DTree>::operator<<(const surr_tuple_type& inTuple) {
     } else{
         // the accumulator is setup to train for the 2nd last layer
         // hence the n_leaf_nodes is same as n_surr_nodes
-        uint16_t n_surr_nodes = n_leaf_nodes;
-        uint16_t n_non_surr_nodes = static_cast<uint16_t>(n_surr_nodes - 1);
+        uint32_t n_surr_nodes = n_leaf_nodes;
+        uint32_t n_non_surr_nodes = static_cast<uint32_t>(n_surr_nodes - 1);
 
         Index dt_parent_index = dt.parentIndex(dt.search(cat_features, con_features));
 
@@ -1710,8 +1731,7 @@ TreeAccumulator<Container, DTree>::operator<<(const surr_tuple_type& inTuple) {
             if (dt.feature_indices(dt_parent_index) >= 0){
                 Index row_index = dt_parent_index - n_non_surr_nodes;
 
-                assert(row_index >= 0 && row_index < cat_stats.rows() &&
-                       row_index < con_stats.rows());
+                assert(row_index >= 0 && row_index < stats_lookup.rows());
 
                 for (Index i=0; i < n_cat_features; ++i){
                     if (is_primary_cat && i == primary_index)
@@ -1800,7 +1820,8 @@ TreeAccumulator<Container, DTree>::updateNodeStats(bool is_regression,
         stats(static_cast<uint16_t>(response)) = weight;
         stats.tail(1)(0) = n_rows;
     }
-    node_stats.row(node_index) += stats;
+    assert(stats_lookup(node_index) >= 0);
+    node_stats.row(stats_lookup(node_index)) += stats;
 }
 // -------------------------------------------------------------------------
 
@@ -1826,11 +1847,12 @@ TreeAccumulator<Container, DTree>::updateStats(bool is_regression,
         stats(static_cast<uint16_t>(response)) = weight;
         stats.tail(1)(0) = n_rows;
     }
-
+    Index stats_i = stats_lookup(row_index);
+    assert(stats_i >= 0);
     if (is_cat) {
-        cat_stats.row(row_index).segment(stats_index, stats_per_split) += stats;
+        cat_stats.row(stats_i).segment(stats_index, stats_per_split) += stats;
     } else {
-        con_stats.row(row_index).segment(stats_index, stats_per_split) += stats;
+        con_stats.row(stats_i).segment(stats_index, stats_per_split) += stats;
     }
 }
 // -------------------------------------------------------------------------
@@ -1854,10 +1876,12 @@ TreeAccumulator<Container, DTree>::updateSurrStats(
     else
         stats << 0, dup_count;
 
+    Index stats_i = stats_lookup(row_index);
+    assert(stats_i >= 0);
     if (is_cat) {
-        cat_stats.row(row_index).segment(stats_index, stats_per_split) += stats;
+        cat_stats.row(stats_i).segment(stats_index, stats_per_split) += stats;
     } else {
-        con_stats.row(row_index).segment(stats_index, stats_per_split) += stats;
+        con_stats.row(stats_i).segment(stats_index, stats_per_split) += stats;
     }
 }
 // -------------------------------------------------------------------------
@@ -1881,7 +1905,8 @@ TreeAccumulator<Container, DTree>::indexCatStats(Index feature_index,
                                                  int   cat_value,
                                                  bool  is_split_true) const {
     // cat_stats is a matrix
-    //   size = (n_leaf_nodes) x (total_n_cat_levels * stats_per_split * 2)
+    //   size = (n_reachable_leaf_nodes) x
+    //                  (total_n_cat_levels * stats_per_split * 2)
     assert(feature_index < n_cat_features);
     unsigned int cat_cumsum_value = (feature_index == 0) ? 0 : cat_levels_cumsum(feature_index - 1);
     return computeSubIndex(static_cast<Index>(cat_cumsum_value),

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/20b11580/src/modules/recursive_partitioning/DT_proto.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_proto.hpp b/src/modules/recursive_partitioning/DT_proto.hpp
index a2881a5..272fec4 100644
--- a/src/modules/recursive_partitioning/DT_proto.hpp
+++ b/src/modules/recursive_partitioning/DT_proto.hpp
@@ -245,7 +245,8 @@ public:
     void bind(ByteStream_type& inStream);
     void rebind(uint16_t n_bins, uint16_t n_cat_feat,
                 uint16_t n_con_feat, uint32_t n_total_levels,
-                uint16_t tree_depth, uint16_t n_stats, bool weights_as_rows);
+                uint16_t tree_depth, uint16_t n_stats, bool weights_as_rows,
+                uint32_t n_reachable_leaf_nodes);
 
     TreeAccumulator& operator<<(const tuple_type& inTuple);
     TreeAccumulator& operator<<(const surr_tuple_type& inTuple);
@@ -284,7 +285,13 @@ public:
     // sum of num of levels in each categorical variable
     uint32_type total_n_cat_levels;
     // n_leaf_nodes = 2^{dt.tree_depth-1} for dt.tree_depth > 0
-    uint16_type n_leaf_nodes;
+    uint32_type n_leaf_nodes;
+
+    // Not all "leaf" nodes at a tree level are reachable. A leaf becomes
+    // non-reachable when one of its ancestor is itself a leaf.
+    // For a full tree, n_leaf_nodes = n_reachable_leaf_nodes
+    uint32_type n_reachable_leaf_nodes;
+
     // For regression, stats_per_split = 4, i.e. (w, w*y, w*y^2, 1)
     // For classification, stats_per_split = (number of class labels + 1)
     // i.e. (w_1, w_2, ..., w_c, 1)
@@ -305,10 +312,11 @@ public:
     // con_stats and cat_stats are matrices that contain the statistics used
     // during training.
     // cat_stats is a matrix of size:
-    // (n_leaf_nodes) x (total_n_cat_levels * stats_per_split * 2)
+    // (n_reachable_leaf_nodes) x (total_n_cat_levels * stats_per_split * 2)
     Matrix_type cat_stats;
+
     // con_stats is a matrix:
-    // (n_leaf_nodes) x (n_con_features * n_bins * stats_per_split * 2)
+    // (n_reachable_leaf_nodes) x (n_con_features * n_bins * stats_per_split * 2)
     Matrix_type con_stats;
 
     // node_stats is used to keep a statistic of all the rows that land on a
@@ -317,6 +325,10 @@ public:
     // cat_stats/con_stats. In the presence of NULL value, the stats could be
     // different.
     Matrix_type node_stats;
+
+    // The above stats matrices are used as pseudo-sparse matrices, since
+    // not all leaf nodes are reachable (especially as the tree gets deeper).
+    IntegerVector_type stats_lookup;
 };
 // ------------------------------------------------------------------------
 

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/20b11580/src/modules/recursive_partitioning/decision_tree.cpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/decision_tree.cpp b/src/modules/recursive_partitioning/decision_tree.cpp
index b298df8..b85923a 100644
--- a/src/modules/recursive_partitioning/decision_tree.cpp
+++ b/src/modules/recursive_partitioning/decision_tree.cpp
@@ -154,6 +154,24 @@ compute_leaf_stats_transition::run(AnyType & args){
     }
 
     if (state.empty()){
+        // To initialize the accumulator, first find which of the leaf nodes
+        // in the current tree are actually reachable.
+        // The lookup vector maps the leaf node index in a (fictional) complete
+        // tree to the index in the actual tree.
+        ColumnVector leaf_feature_indices =
+            dt.feature_indices.tail(dt.feature_indices.size()/2 + 1).cast<double>();
+        ColumnVector leaf_node_lookup(leaf_feature_indices.size());
+        size_t n_leaves_not_finished = 0;
+        for (Index i=0; i < leaf_feature_indices.size(); i++){
+            if ((leaf_feature_indices(i) != dt.NODE_NON_EXISTING) &&
+                    (leaf_feature_indices(i) != dt.FINISHED_LEAF)){
+                leaf_node_lookup(i) = n_leaves_not_finished++;  // increment after assigning
+            }
+            else{
+                leaf_node_lookup(i) = -1;
+            }
+        }
+
         // For classification, we store for each split the number of weighted
         // tuples for each possible response value and the number of unweighted
         // tuples landing on that node.
@@ -167,22 +185,27 @@ compute_leaf_stats_transition::run(AnyType & args){
                      static_cast<uint32_t>(cat_levels.sum()),
                      static_cast<uint16_t>(dt.tree_depth),
                      stats_per_split,
-                     weights_as_rows
+                     weights_as_rows,
+                     static_cast<uint32_t>(n_leaves_not_finished)
                     );
+        for (Index i=0; i < state.stats_lookup.size(); i++)
+            state.stats_lookup(i) = leaf_node_lookup(i);
+
         // compute cumulative sum of the levels of the categorical variables
         int current_sum = 0;
         for (Index i=0; i < state.n_cat_features; ++i){
-            // We assume that the levels of each categorical variable are sorted
-            //  by the entropy for predicting the response. We then create splits
-            //  of the form 'A <= t', where A has N levels and t in [0, N-2].
+            // Assuming that the levels of each categorical variable are ordered,
+            //    create splits of the form 'A <= t', where A has N levels
+            //    and t in [0, N-2].
             // This split places all levels <= t on true node and
-            //  others on false node. We only check till N-2 since we want at
-            //  least 1 level falling to the false node.
-            // We keep a variable with just 1 level to ensure alignment,
-            //  even though that variable will not be used as a split feature.
+            //    others on false node. Checking till N-2 instead of N-1
+            //    since at least 1 level should go to false node.
+            // Variable with just 1 level is maintained to ensure alignment,
+            //    even though the variable will not be used as a split feature.
             current_sum += cat_levels(i);
             state.cat_levels_cumsum(i) = current_sum;
         }
+
     }
 
     state << MutableLevelState::tuple_type(dt, cat_features, con_features,
@@ -236,6 +259,7 @@ dt_apply::run(AnyType & args){
         return_code = TERMINATED;  // indicates termination due to error
     }
 
+
     AnyType output_tuple;
     output_tuple << dt.storage()
                  << return_code
@@ -292,6 +316,21 @@ compute_surr_stats_transition::run(AnyType & args){
    // the root be an internal node i.e. we need the tree_depth to be more than 1.
     if (dt.tree_depth > 1){
         if (state.empty()){
+            // To initialize the accumulator, first find which internal
+            // nodes in the last level are actually reachable.
+            ColumnVector final_internal_feature_indices =
+                dt.feature_indices.segment(dt.feature_indices.size()/4,
+                                           dt.feature_indices.size()/4 + 1).cast<double>();
+            ColumnVector index_lookup(final_internal_feature_indices.size());
+            Index n_internal_nodes_reachable = 0;
+            for (Index i=0; i < final_internal_feature_indices.size(); i++){
+                if (final_internal_feature_indices(i) >= 0){
+                    index_lookup(i) = n_internal_nodes_reachable++;  // increment after assigning
+                }
+                else{
+                    index_lookup(i) = -1;
+                }
+            }
             // 1. We need to compute stats for parent of each leaf.
             //      Hence the tree_depth is decremented by 1.
             // 2. We store 2 values for each surrogate split
@@ -303,11 +342,14 @@ compute_surr_stats_transition::run(AnyType & args){
                          static_cast<uint32_t>(cat_levels.sum()),
                          static_cast<uint16_t>(dt.tree_depth - 1),
                          2,
-                         false // dummy, only used in compute_leaf_stat
+                         false, // dummy, only used in compute_leaf_stat
+                         n_internal_nodes_reachable
                         );
+            for (Index i = 0; i < state.stats_lookup.size(); i++)
+                state.stats_lookup(i) = index_lookup(i);
             // compute cumulative sum of the levels of the categorical variables
             int current_sum = 0;
-            for (Index i=0; i < state.n_cat_features; ++i){
+            for (Index i=0; i < state.n_cat_features; i++){
                 current_sum += cat_levels(i);
                 state.cat_levels_cumsum(i) = current_sum;
             }
