Repository: incubator-madlib
Updated Branches:
  refs/heads/master 9e632f1aa -> d43fc29d0


DT: Include NULL rows in count for termination check

An optimization was added to check, after training a node, whether its
children would split further by analyzing the tuples going to
those child nodes. There were two issues with this optimization:

1. To decide if a node won't split, it was checking if a node is too
pure (i.e. all responses are similar) *and* the number of tuples in the
node was less than a threshold (min_split). The check, however, should
be *or* since either one of those conditions should stop a node from
splitting. This commit corrects the logic.

2. When the primary split feature for a node is computed, the statistics
of rows going to the true and false sides don't include the rows that have
a NULL value for this split feature. These "NULL" rows can only be
included in the statistics during the next pass when the child node is
trained. This commit ensures that in the presence of NULL rows, this
optimization is disabled so that we don't terminate prematurely
after comparing with a lower count.

Closes #142


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/d43fc29d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/d43fc29d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/d43fc29d

Branch: refs/heads/master
Commit: d43fc29d08c1219e7d55b1f7e1bca6f5c56df18e
Parents: 9e632f1
Author: Rahul Iyer <ri...@apache.org>
Authored: Tue Jun 20 23:31:06 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Jun 27 10:36:27 2017 -0700

----------------------------------------------------------------------
 src/modules/recursive_partitioning/DT_impl.hpp  | 147 ++++++++++---------
 src/modules/recursive_partitioning/DT_proto.hpp |   8 +-
 .../recursive_partitioning/decision_tree.sql_in |  12 +-
 3 files changed, 85 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/d43fc29d/src/modules/recursive_partitioning/DT_impl.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_impl.hpp 
b/src/modules/recursive_partitioning/DT_impl.hpp
index 6d15db5..27bc647 100644
--- a/src/modules/recursive_partitioning/DT_impl.hpp
+++ b/src/modules/recursive_partitioning/DT_impl.hpp
@@ -446,21 +446,21 @@ DecisionTree<Container>::updatePrimarySplit(
     predictions.row(falseChild(node_index)) = false_stats;
 
     // true_stats and false_stats only include the tuples for which the primary
-    // split is NULL. The number of tuples in these stats need to be stored to
+    // split is not NULL. The number of tuples in these stats need to be 
stored to
     // compute a majority branch during surrogate training.
     uint64_t true_count = statCount(true_stats);
     uint64_t false_count = statCount(false_stats);
-    nonnull_split_count(node_index*2) = static_cast<double>(true_count);
-    nonnull_split_count(node_index*2 + 1) = static_cast<double>(false_count);
-
-    // current node's children won't split if,
-    // 1. children are pure (responses are too similar to split further)
-    // 2. children are too small to split further (count < min_split)
-    bool children_wont_split = (isChildPure(true_stats) &&
-                                isChildPure(false_stats) &&
-                                true_count < min_split &&
-                                false_count < min_split
-                                );
+    nonnull_split_count(trueChild(node_index)) = 
static_cast<double>(true_count);
+    nonnull_split_count(falseChild(node_index)) = 
static_cast<double>(false_count);
+
+    // current node's each child won't split if,
+    //      1. child is pure (responses are too similar to split further)
+    //      OR
+    //      2. child is too small to split (count < min_split)
+    bool children_wont_split = ((isChildPure(true_stats) ||
+                                    true_count < min_split) &&
+                                (isChildPure(false_stats) ||
+                                    false_count < min_split));
     return children_wont_split;
 }
 // -------------------------------------------------------------------------
@@ -486,10 +486,23 @@ DecisionTree<Container>::expand(const Accumulator &state,
             Index stats_i = static_cast<Index>(state.stats_lookup(i));
             assert(stats_i >= 0);
 
-            // 1. Set the prediction for current node from stats of all rows
-            predictions.row(current) = state.node_stats.row(stats_i);
+            // 1. Update predictions if necessary
+            if (statCount(predictions.row(current)) !=
+                    statCount(state.node_stats.row(stats_i))){
+                // Predictions for each node is set by its parent using stats
+                // recorded while training parent node. These stats do not
+                // include rows that had a NULL value for the primary split
+                // feature. The NULL count is included in 'node_stats' while
+                // training current node.
+                predictions.row(current) = state.node_stats.row(stats_i);
+
+                // Presence of NULL rows indicate that stats used for deciding
+                // 'children_wont_split' are inaccurate. Hence avoid using the
+                // flag to decide termination.
+                children_wont_split = false;
+            }
 
-            // 2. Compute the best feature to split current node by
+            // 2. Compute the best feature to split current node
 
             // if a leaf node exists, compute the gain in impurity for each 
split
             // pick split  with maximum gain and update node with split value
@@ -533,9 +546,13 @@ DecisionTree<Container>::expand(const Accumulator &state,
                 }
             }
 
-            // 3. Create and update child nodes if splitting current
+            // 3. Create and update children if splitting current
+            uint64_t true_count = statCount(max_stats.segment(0, sps));
+            uint64_t false_count = statCount(max_stats.segment(sps, sps));
+            uint64_t total_count = statCount(predictions.row(current));
             if (max_impurity_gain > 0 &&
-                    shouldSplit(max_stats, min_split, min_bucket, sps, 
max_depth)) {
+                    shouldSplit(total_count, true_count, false_count,
+                                min_split, min_bucket, sps, max_depth)) {
 
                 double max_threshold;
                 if (max_is_cat)
@@ -761,12 +778,22 @@ DecisionTree<Container>::expand_by_sampling(const 
Accumulator &state,
 
     for (Index i=0; i < state.n_leaf_nodes; i++) {
         Index current = n_non_leaf_nodes + i;
-        Index stats_i = static_cast<Index>(state.stats_lookup(i));
-        assert(stats_i >= 0);
-
         if (feature_indices(current) == IN_PROCESS_LEAF) {
-            // 1. Set the prediction for current node from stats of all rows
-            predictions.row(current) = state.node_stats.row(stats_i);
+            Index stats_i = static_cast<Index>(state.stats_lookup(i));
+            assert(stats_i >= 0);
+
+            if (statCount(predictions.row(current)) !=
+                    statCount(state.node_stats.row(stats_i))){
+                // Predictions for each node is set by its parent using stats
+                // recorded while training parent node. These stats do not 
include
+                // rows that had a NULL value for the primary split feature.
+                // The NULL count is included in the 'node_stats' while 
training
+                // current node. Further, presence of NULL rows indicate that
+                // stats used for deciding 'children_wont_split' are 
inaccurate.
+                // Hence avoid using the flag to decide termination.
+                predictions.row(current) = state.node_stats.row(stats_i);
+                children_wont_split = false;
+            }
 
             for (int j=0; j<total_cat_con_features; j++) {
                 cat_con_feature_indices[j] = j;
@@ -829,9 +856,14 @@ DecisionTree<Container>::expand_by_sampling(const 
Accumulator &state,
                 }
             }
 
-            // create and update child nodes if splitting current
+            // Create and update child nodes if splitting current
+            uint64_t true_count = statCount(max_stats.segment(0, sps));
+            uint64_t false_count = statCount(max_stats.segment(sps, sps));
+            uint64_t total_count = statCount(predictions.row(current));
+
             if (max_impurity_gain > 0 &&
-                    shouldSplit(max_stats, min_split, min_bucket, sps, 
max_depth)) {
+                    shouldSplit(total_count, true_count, false_count,
+                                min_split, min_bucket, sps, max_depth)) {
 
                 double max_threshold;
                 if (max_is_cat)
@@ -1024,19 +1056,19 @@ DecisionTree<Container>::isChildPure(const ColumnVector 
&stats) const{
 template <class Container>
 inline
 bool
-DecisionTree<Container>::shouldSplit(const ColumnVector &combined_stats,
-                                      const uint16_t &min_split,
-                                      const uint16_t &min_bucket,
-                                      const uint16_t &stats_per_split,
-                                      const uint16_t &max_depth) const {
-
-    // combined_stats is assumed to be of size = stats_per_split
-    // we always want at least 1 tuple going into a child node. Hence the
+DecisionTree<Container>::shouldSplit(const uint64_t &total_count,
+                                     const uint64_t &true_count,
+                                     const uint64_t &false_count,
+                                     const uint16_t &min_split,
+                                     const uint16_t &min_bucket,
+                                     const uint16_t &stats_per_split,
+                                     const uint16_t &max_depth) const {
+    // total_count != true_count + false_count if there are rows with NULL 
values
+
+    // Always want at least 1 tuple going into a child node. Hence the
     // minimum value for min_bucket is 1
     uint64_t thresh_min_bucket = (min_bucket == 0) ? 1u : min_bucket;
-    uint64_t true_count = statCount(combined_stats.segment(0, 
stats_per_split));
-    uint64_t false_count = statCount(combined_stats.segment(stats_per_split, 
stats_per_split));
-    return ((true_count + false_count) >= min_split &&
+    return (total_count >= min_split &&
             true_count >= thresh_min_bucket &&
             false_count >= thresh_min_bucket &&
             tree_depth <= max_depth + 1);
@@ -1045,28 +1077,6 @@ DecisionTree<Container>::shouldSplit(const ColumnVector 
&combined_stats,
 
 template <class Container>
 inline
-bool
-DecisionTree<Container>::shouldSplitWeights(const ColumnVector &combined_stats,
-                                      const uint16_t &min_split,
-                                      const uint16_t &min_bucket,
-                                      const uint16_t &stats_per_split) const {
-
-    // combined_stats is assumed to be of size = stats_per_split
-    // number of tuples landing on a node is equal to the sum of weights for
-    // that node. we therefore use statWeightedCount
-    // we always want at least 1 tuple going into a child node. Hence the
-    // minimum value for min_bucket is 1
-    uint64_t thresh_min_bucket = (min_bucket == 0) ? 1u : min_bucket;
-    double true_count = statWeightedCount(combined_stats.segment(0, 
stats_per_split));
-    double false_count = 
statWeightedCount(combined_stats.segment(stats_per_split, stats_per_split));
-    return ((true_count + false_count) >= min_split &&
-            true_count >= thresh_min_bucket &&
-            false_count >= thresh_min_bucket);
-}
-// ------------------------------------------------------------------------
-
-template <class Container>
-inline
 uint16_t
 DecisionTree<Container>::recomputeTreeDepth() const{
     if (feature_indices.size() <= 1 || tree_depth <= 1)
@@ -1131,15 +1141,12 @@ DecisionTree<Container>::displayLeafNode(
             // can be ignored
             const Index pred_size = predictions.row(id).size() - 1;
             for (Index i = 0; i < pred_size; i += NUM_PER_LINE){
-                uint16_t n_elem;
-                if (i + NUM_PER_LINE <= pred_size) {
+                if (i + NUM_PER_LINE >= pred_size) {
                     // not overflowing the vector
-                    n_elem = NUM_PER_LINE;
+                    display_str << predictions.row(id).segment(i, pred_size - 
i);
                 } else {
-                    // less than NUM_PER_LINE left, avoid reading past the end
-                    n_elem = static_cast<uint16_t>(pred_size - i);
+                    display_str << predictions.row(id).segment(i, 
NUM_PER_LINE) << "\n";
                 }
-                display_str << predictions.row(id).segment(i, n_elem) << "\n";
             }
             display_str << "]";
 
@@ -1174,7 +1181,7 @@ DecisionTree<Container>::displayInternalNode(
         label_str << escape_quotes(feature_name) << " <= " << 
feature_thresholds(id);
     } else {
         feature_name = get_text(cat_features_str, feature_indices(id));
-        label_str << escape_quotes(feature_name) << " <= ";
+        label_str << escape_quotes(feature_name) << " = ";
 
         // Text for all categoricals are stored in a flat array 
(cat_levels_text);
         // find the appropriate index for this node
@@ -1202,14 +1209,12 @@ DecisionTree<Container>::displayInternalNode(
             // be ignored
             const Index pred_size = predictions.row(id).size() - 1;
             for (Index i = 0; i < pred_size; i += NUM_PER_LINE){
-                uint16_t n_elem;
-                if (i + NUM_PER_LINE <= pred_size) {
+                if (i + NUM_PER_LINE > pred_size) {
                     // not overflowing the vector
-                    n_elem = NUM_PER_LINE;
+                    display_str << predictions.row(id).segment(i, pred_size - 
i);
                 } else {
-                    n_elem = static_cast<uint16_t>(pred_size - i);
+                    display_str << predictions.row(id).segment(i, 
NUM_PER_LINE) << "\n";
                 }
-                display_str << predictions.row(id).segment(i, n_elem) << "\n";
             }
             display_str << "]";
         }
@@ -1414,7 +1419,7 @@ DecisionTree<Container>::getCatLabels(Index cat_index,
                                       Index end_value,
                                       ArrayHandle<text*> &cat_levels_text,
                                       ArrayHandle<int> &cat_n_levels) {
-    Index MAX_LABELS = 5;
+    Index MAX_LABELS = 2;
     size_t to_skip = 0;
     for (Index i=0; i < cat_index; i++) {
         to_skip += cat_n_levels[i];
@@ -1431,7 +1436,7 @@ DecisionTree<Container>::getCatLabels(Index cat_index,
             break;
         }
     }
-    cat_levels << get_text(cat_levels_text, index) << "}";
+    cat_levels << get_text(cat_levels_text, to_skip + end_value) << "}";
     return cat_levels.str();
 }
 // -------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/d43fc29d/src/modules/recursive_partitioning/DT_proto.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_proto.hpp 
b/src/modules/recursive_partitioning/DT_proto.hpp
index 272fec4..33446a1 100644
--- a/src/modules/recursive_partitioning/DT_proto.hpp
+++ b/src/modules/recursive_partitioning/DT_proto.hpp
@@ -96,16 +96,14 @@ public:
     double impurity(const ColumnVector & stats) const;
     double impurityGain(const ColumnVector &combined_stats,
                         const uint16_t &stats_per_split) const;
-    bool shouldSplit(const ColumnVector & stats,
+    bool shouldSplit(const uint64_t &total_count,
+                     const uint64_t &true_count,
+                     const uint64_t &false_count,
                      const uint16_t &min_split,
                      const uint16_t &min_bucket,
                      const uint16_t &stats_per_split,
                      const uint16_t &max_depth) const;
 
-    bool shouldSplitWeights(const ColumnVector & stats,
-                            const uint16_t &min_split,
-                            const uint16_t &min_bucket,
-                            const uint16_t &stats_per_split) const;
     template <class Accumulator>
     void pickSurrogates(const Accumulator &state,
                         const MappedMatrix &con_splits);

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/d43fc29d/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in 
b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
index ac08466..92a123d 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
@@ -596,16 +596,16 @@ Result:
         For each leaf, the prediction is given after the '-->'
 &nbsp;-------------------------------------
 (0)[5 9]  "OUTLOOK" in {overcast}
-   (1)[0 4]  * --> "'Play'"
+   (1)[0 4]  * \-\-> "'Play'"
    (2)[5 5]  "Cont_features"[1] <= 75
       (5)[3 5]  "Cont_features"[1] <= 65
-         (11)[1 0]  * --> "'Don't Play'"
+         (11)[1 0]  * \-\-> "'Don't Play'"
          (12)[2 5]  "Cont_features"[1] <= 70
-            (25)[0 3]  * --> "'Play'"
+            (25)[0 3]  * \-\-> "'Play'"
             (26)[2 2]  "Cont_features"[1] <= 72
-               (53)[2 0]  * --> "'Don't Play'"
-               (54)[0 2]  * --> "'Play'"
-      (6)[2 0]  * --> "'Don't Play'"
+               (53)[2 0]  * \-\-> "'Don't Play'"
+               (54)[0 2]  * \-\-> "'Play'"
+      (6)[2 0]  * \-\-> "'Don't Play'"
 &nbsp;-------------------------------------
 </pre>
 Here are some more details on how to interpret the tree display above...

Reply via email to