Repository: incubator-madlib
Updated Branches:
  refs/heads/master 01788982d -> 2f1c4b288


RF: Ensure num_random_features is always > 0


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/2f1c4b28
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/2f1c4b28
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/2f1c4b28

Branch: refs/heads/master
Commit: 2f1c4b28847aa9d95edc0a54a25ad7651b2410ed
Parents: 0178898
Author: Rahul Iyer <ri...@apache.org>
Authored: Mon Jul 3 10:22:30 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Mon Jul 3 10:25:21 2017 -0700

----------------------------------------------------------------------
 .../modules/recursive_partitioning/random_forest.py_in   | 11 ++++++-----
 .../recursive_partitioning/test/random_forest.sql_in     |  4 ++--
 2 files changed, 8 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2f1c4b28/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in 
b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index 1b5ad88..05c029e 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -8,7 +8,7 @@
 """
 
 import plpy
-from math import sqrt
+from math import sqrt, ceil
 
 from utilities.control import MinWarning
 from utilities.control import EnableOptimizer
@@ -321,9 +321,10 @@ def forest_train(
 
                 if num_random_features is None:
                     n_all_features = len(features)
-                    num_random_features = (sqrt(n_all_features) if 
is_classification
-                                           else n_all_features / 3)
-                _assert(num_random_features <= len(features),
+                    num_random_features = int(sqrt(n_all_features) if 
is_classification
+                                              else ceil(float(n_all_features) 
/ 3))
+
+                _assert(0 < num_random_features <= len(features),
                         "Random forest error: Number of features to be 
selected "
                         "is more than the actual number of features.")
 
@@ -351,7 +352,7 @@ def forest_train(
                     dep = ("(CASE " +
                            "\n               ".
                            join(["WHEN ({dep_col})::text = $${c}$$ THEN {i}".
-                                    format(dep_col=dep_col_str, c=c, i=i)
+                                format(dep_col=dep_col_str, c=c, i=i)
                                  for i, c in enumerate(dep_list)]) +
                            "\nEND)")
                     dep_n_levels = len(dep_list)

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2f1c4b28/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
----------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in 
b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
index 8aec1f0..37837b0 100644
--- 
a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
+++ 
b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
@@ -39,7 +39,7 @@ SELECT forest_train(
                   NULL::TEXT,        -- exclude columns
                   NULL::TEXT,        -- no grouping
                   5,                -- num of trees
-                  1,                 -- num of random features
+                  NULL,                 -- num of random features
                   TRUE::BOOLEAN,    -- importance
                   1::INTEGER,       -- num_permutations
                   10::INTEGER,       -- max depth
@@ -65,7 +65,7 @@ SELECT forest_train(
                   NULL::TEXT,        -- exclude columns
                   'class',          -- grouping
                   5,                -- num of trees
-                  1,                 -- num of random features
+                  NULL,                 -- num of random features
                   TRUE::BOOLEAN,     -- importance
                   20::INTEGER,         -- num_permutations
                   10::INTEGER,       -- max depth

Reply via email to