DT: Update error message for invalid num_splits

Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/c4fd91e1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/c4fd91e1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/c4fd91e1

Branch: refs/heads/latest_release
Commit: c4fd91e16827a5f8be4051eb3ea0d311d3e957f2
Parents: a3d54be
Author: Rahul Iyer <ri...@apache.org>
Authored: Thu Apr 27 12:12:48 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Apr 27 12:12:48 2017 -0700

----------------------------------------------------------------------
 src/modules/recursive_partitioning/feature_encoding.cpp |  8 ++++++--
 .../recursive_partitioning/test/decision_tree.sql_in    | 12 +++++++++---
 2 files changed, 15 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/c4fd91e1/src/modules/recursive_partitioning/feature_encoding.cpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/feature_encoding.cpp b/src/modules/recursive_partitioning/feature_encoding.cpp
index 20856e2..3b0a452 100644
--- a/src/modules/recursive_partitioning/feature_encoding.cpp
+++ b/src/modules/recursive_partitioning/feature_encoding.cpp
@@ -39,7 +39,7 @@ dst_compute_con_splits_transition::run(AnyType &args){
     if (!state.empty() && state.num_rows >= state.buff_size) {
         return args[0];
     }
-    // NULL-handling is done in python to make sure consistency b/w
+    // NULLs are handled by caller to ensure consistency between
     // feature encoding and tree training
     MappedColumnVector con_features = args[1].getAs<MappedColumnVector>();
 
@@ -71,8 +71,12 @@ dst_compute_con_splits_final::run(AnyType &args){
 
     if (state.num_rows <= state.num_splits) {
         std::stringstream error_msg;
+        // In message below, add 1 to state.num_splits since the meaning of
+        // "splits" for the caller is the number of quantiles, whereas
+        // "splits" in this function is the number of values dividing the data
+        // into quantiles.
         error_msg << "Decision tree error: Number of splits ("
-            << state.num_splits
+            << state.num_splits + 1
             << ") is larger than the number of records ("
             << state.num_rows << ")";
         throw std::runtime_error(error_msg.str());

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/c4fd91e1/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
index 28a4647..dd861a0 100644
--- a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
@@ -287,7 +287,7 @@ SELECT tree_train('dt_golf'::text,         -- source table
                          'train_output'::text,    -- output model table
                          'id'::text,              -- id column
                          'temperature::double precision'::text,           -- response
-                         'humidity, windy'::text,   -- features
+                         '"OUTLOOK", humidity, windy'::text,   -- features
                          NULL::text,        -- exclude columns
                          'gini'::text,      -- split criterion
                          'class'::text,     -- grouping
@@ -301,13 +301,19 @@ SELECT tree_train('dt_golf'::text,         -- source table
 
 SELECT _print_decision_tree(tree) from train_output;
 SELECT tree_display('train_output', False);
-SELECT tree_predict('train_output', 'dt_golf', 'predict_output');
+
+CREATE TABLE dt_golf2 as
+SELECT * FROM dt_golf
+UNION
+SELECT 15 as id, 'humid' as "OUTLOOK", 71 as temperature, 80 as humidity,
+        true as windy, 'Don''t Play' as class;
+SELECT tree_predict('train_output', 'dt_golf2', 'predict_output');
 \x off
 SELECT *
 FROM
     predict_output
 JOIN
-    dt_golf
+    dt_golf2
 USING (id);
 \x on
 select * from train_output;

Reply via email to