Repository: incubator-madlib
Updated Branches:
  refs/heads/master 18b8486ca -> ec60b83d2


RF: Filter NULL dependent values in OOB

JIRA: MADLIB-1097

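Pass the `filter_null` string obtained from decision_tree.py into
`_calculate_oob_prediction` and apply it as a WHERE clause in the
out-of-bag (OOB) prediction query, so that rows with NULL dependent
values are excluded from OOB prediction.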


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/9b45ecaa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/9b45ecaa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/9b45ecaa

Branch: refs/heads/master
Commit: 9b45ecaaadb9e0d4999dc49e72df8a97cb7692d2
Parents: 18b8486
Author: Rahul Iyer <ri...@apache.org>
Authored: Wed May 3 17:07:55 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Wed May 10 15:56:57 2017 -0700

----------------------------------------------------------------------
 .../recursive_partitioning/random_forest.py_in  | 24 ++++++++++++--------
 .../test/random_forest.sql_in                   | 14 +++++++-----
 2 files changed, 23 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/9b45ecaa/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in 
b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index 4b6f2d6..1b5ad88 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -450,7 +450,8 @@ def forest_train(
                     bins['cat_origin']])
 
                 con_splits_table = unique_string()
-                _create_con_splits_table(schema_madlib, con_splits_table, 
grouping_cols, grp_key_to_grp_cols, bins)
+                _create_con_splits_table(schema_madlib, con_splits_table,
+                                         grouping_cols, grp_key_to_grp_cols, 
bins)
 
                 
##################################################################
                 # create views and tables for training (growing) of trees
@@ -600,7 +601,8 @@ def forest_train(
                         con_splits_table, oob_prediction_table, oob_view,
                         sample_id, id_col_name, cat_features, con_features,
                         boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
-                        num_permutations, is_classification, importance, 
num_bins)
+                        num_permutations, is_classification, importance,
+                        num_bins, filter_null)
 
                 
###################################################################
                 # evaluating and summerizing random forest
@@ -626,9 +628,9 @@ def forest_train(
                 # calculated, otherwise we use an empty table which will be 
used later
                 # for an outer join.
                 if importance:
-                    _calculate_variable_importance(schema_madlib,
-                            oob_prediction_table, is_classification,
-                            importance_table, len(cat_features), 
len(con_features))
+                    _calculate_variable_importance(
+                        schema_madlib, oob_prediction_table, is_classification,
+                        importance_table, len(cat_features), len(con_features))
 
                 _create_group_table(schema_madlib, output_table_name,
                                     oob_error_table, importance_table,
@@ -926,7 +928,7 @@ def _calculate_oob_prediction(
         schema_madlib, model_table, cat_features_info_table, con_splits_table,
         oob_prediction_table, oob_view, sample_id, id_col_name, cat_features,
         con_features, boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
-        num_permutations, is_classification, importance, num_bins):
+        num_permutations, is_classification, importance, num_bins, 
filter_null):
     """Calculate predication for out-of-bag sample"""
 
     cat_features_str, con_features_str = get_feature_str(
@@ -1045,6 +1047,7 @@ def _calculate_oob_prediction(
             LEFT OUTER JOIN -- empty if variable importance is disabled
                 {oob_var_dist_view}
             USING (gid)
+            WHERE {filter_null}
     """.format(**locals())
     plpy.notice("sql_oob_predict : " + str(sql_oob_predict))
     plpy.execute(sql_oob_predict)
@@ -1091,11 +1094,14 @@ def _create_con_splits_table(schema_madlib, 
con_splits_table, grouping_cols,
 # 
------------------------------------------------------------------------------
 
 
-def _calculate_variable_importance(schema_madlib, oob_prediction_table,
-        is_classification, importance_table, n_cat, n_con):
+def _calculate_variable_importance(
+        schema_madlib, oob_prediction_table, is_classification,
+        importance_table, n_cat, n_con):
     if not is_classification:
+        # squared error
         score_expression = "-((oob_prediction - dep)^2)".format(**locals())
     else:
+        # misclassification
         score_expression = """
                 CASE WHEN dep = oob_prediction::integer
                     THEN 1.
@@ -1200,7 +1206,7 @@ def _create_summary_table(**kwargs):
 
     kwargs['indep_type'] = ', '.join(kwargs['all_cols_types'][col]
                                      for col in kwargs['cat_features'] +
-                                        kwargs['con_features'])
+                                                kwargs['con_features'])
     kwargs['dep_type'] = _get_dep_type(kwargs['training_table_name'],
                                        kwargs['dependent_variable'])
     kwargs['cat_features_str'] = ','.join(kwargs['cat_features'])

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/9b45ecaa/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in 
b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
index 086e74b..f3ad93c 100644
--- 
a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
+++ 
b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
@@ -13,16 +13,18 @@ INSERT INTO dt_golf 
(id,"OUTLOOK",temperature,humidity,cont_features,windy,class
 (1, 'sunny', 85, 85,ARRAY[85, 85], false, 'Don''t Play'),
 (2, 'sunny', 80, 90,ARRAY[80, 90], true, 'Don''t Play'),
 (3, 'overcast', 83, 78,ARRAY[83, 78], false, 'Play'),
-(4, 'rain', 70, 96,ARRAY[70, 96], false, 'Play'),
+(4, 'rain', 70, NULL,ARRAY[70, 96], false, 'Play'),
 (5, 'rain', 68, 80,ARRAY[68, 80], false, 'Play'),
-(6, 'rain', 65, 70,ARRAY[65, 70], true, 'Don''t Play'),
-(7, 'overcast', 64, 65,ARRAY[64, 65], true, 'Play'),
+(6, 'rain', NULL, 70,ARRAY[65, 70], true, 'Don''t Play'),
+(7, 'overcast', 64, 65,ARRAY[64, 65],NULL, 'Play'),
 (8, 'sunny', 72, 95,ARRAY[72, 95], false, 'Don''t Play'),
 (9, 'sunny', 69, 70,ARRAY[69, 70], false, 'Play'),
 (10, 'rain', 75, 80,ARRAY[75, 80], false, 'Play'),
 (11, 'sunny', 75, 70,ARRAY[75, 70], true, 'Play'),
-(12, 'overcast', 72, 90,ARRAY[72, 90], true, 'Play'),
+(12, 'overcast', 72, 90,ARRAY[72, 90], NULL, 'Play'),
 (13, 'overcast', 81, 75,ARRAY[81, 75], false, 'Play'),
+(15, NULL, 81, 75,ARRAY[81, 75], false, 'Play'),
+(16, 'overcast', NULL, 75,ARRAY[81, 75], false, 'Play'),
 (14, 'rain', 71, 80,ARRAY[71, 80], true, 'Don''t Play');
 
 -------------------------------------------------------------------------
@@ -116,7 +118,7 @@ SELECT forest_train(
                   'dt_golf'::TEXT,         -- source table
                   'train_output'::TEXT,    -- output model table
                   'id'::TEXT,              -- id column
-                  'temperature::double precision'::TEXT,           -- response
+                  'temperature::double precision'::TEXT,   -- response
                   'class, temperature, windy'::TEXT,   -- features
                   NULL::TEXT,        -- exclude columns
                   NULL::TEXT,        -- no grouping
@@ -150,7 +152,7 @@ SELECT forest_train(
                   'temperature::double precision'::TEXT,           -- response
                   'humidity'::TEXT,   -- features
                   NULL::TEXT,        -- exclude columns
-                  'class,windy',          -- grouping
+                  'class',          -- grouping
                   5,                -- num of trees
                   1,                 -- num of random features
                   TRUE::BOOLEAN,     -- importance

Reply via email to