Repository: incubator-madlib Updated Branches: refs/heads/master 01788982d -> 2f1c4b288
RF: Ensure n_random_features always > 0 Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/2f1c4b28 Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/2f1c4b28 Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/2f1c4b28 Branch: refs/heads/master Commit: 2f1c4b28847aa9d95edc0a54a25ad7651b2410ed Parents: 0178898 Author: Rahul Iyer <ri...@apache.org> Authored: Mon Jul 3 10:22:30 2017 -0700 Committer: Rahul Iyer <ri...@apache.org> Committed: Mon Jul 3 10:25:21 2017 -0700 ---------------------------------------------------------------------- .../modules/recursive_partitioning/random_forest.py_in | 11 ++++++----- .../recursive_partitioning/test/random_forest.sql_in | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2f1c4b28/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in index 1b5ad88..05c029e 100644 --- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in @@ -8,7 +8,7 @@ """ import plpy -from math import sqrt +from math import sqrt, ceil from utilities.control import MinWarning from utilities.control import EnableOptimizer @@ -321,9 +321,10 @@ def forest_train( if num_random_features is None: n_all_features = len(features) - num_random_features = (sqrt(n_all_features) if is_classification - else n_all_features / 3) - _assert(num_random_features <= len(features), + num_random_features = int(sqrt(n_all_features) if is_classification + else ceil(float(n_all_features) / 3)) + + _assert(0 < num_random_features <= len(features), "Random forest error: Number of features to be selected " "is more than the actual number of features.") @@ -351,7 +352,7 @@ def forest_train( dep = ("(CASE " + "\n ". join(["WHEN ({dep_col})::text = $${c}$$ THEN {i}". - format(dep_col=dep_col_str, c=c, i=i) + format(dep_col=dep_col_str, c=c, i=i) for i, c in enumerate(dep_list)]) + "\nEND)") dep_n_levels = len(dep_list) http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2f1c4b28/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in index 8aec1f0..37837b0 100644 --- a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in +++ b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in @@ -39,7 +39,7 @@ SELECT forest_train( NULL::TEXT, -- exclude columns NULL::TEXT, -- no grouping 5, -- num of trees - 1, -- num of random features + NULL, -- num of random features TRUE::BOOLEAN, -- importance 1::INTEGER, -- num_permutations 10::INTEGER, -- max depth @@ -65,7 +65,7 @@ SELECT forest_train( NULL::TEXT, -- exclude columns 'class', -- grouping 5, -- num of trees - 1, -- num of random features + NULL, -- num of random features TRUE::BOOLEAN, -- importance 20::INTEGER, -- num_permutations 10::INTEGER, -- max depth