You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2017/08/29 20:41:50 UTC
[13/50] [abbrv] incubator-madlib git commit: RF: Ensure n_random_features always > 0
RF: Ensure n_random_features always > 0
Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/2f1c4b28
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/2f1c4b28
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/2f1c4b28
Branch: refs/heads/latest_release
Commit: 2f1c4b28847aa9d95edc0a54a25ad7651b2410ed
Parents: 0178898
Author: Rahul Iyer <ri...@apache.org>
Authored: Mon Jul 3 10:22:30 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Mon Jul 3 10:25:21 2017 -0700
----------------------------------------------------------------------
.../modules/recursive_partitioning/random_forest.py_in | 11 ++++++-----
.../recursive_partitioning/test/random_forest.sql_in | 4 ++--
2 files changed, 8 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2f1c4b28/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index 1b5ad88..05c029e 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -8,7 +8,7 @@
"""
import plpy
-from math import sqrt
+from math import sqrt, ceil
from utilities.control import MinWarning
from utilities.control import EnableOptimizer
@@ -321,9 +321,10 @@ def forest_train(
if num_random_features is None:
n_all_features = len(features)
- num_random_features = (sqrt(n_all_features) if is_classification
- else n_all_features / 3)
- _assert(num_random_features <= len(features),
+ num_random_features = int(sqrt(n_all_features) if is_classification
+ else ceil(float(n_all_features) / 3))
+
+ _assert(0 < num_random_features <= len(features),
"Random forest error: Number of features to be selected "
"is more than the actual number of features.")
@@ -351,7 +352,7 @@ def forest_train(
dep = ("(CASE " +
"\n ".
join(["WHEN ({dep_col})::text = $${c}$$ THEN {i}".
- format(dep_col=dep_col_str, c=c, i=i)
+ format(dep_col=dep_col_str, c=c, i=i)
for i, c in enumerate(dep_list)]) +
"\nEND)")
dep_n_levels = len(dep_list)
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2f1c4b28/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
index 8aec1f0..37837b0 100644
--- a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
@@ -39,7 +39,7 @@ SELECT forest_train(
NULL::TEXT, -- exclude columns
NULL::TEXT, -- no grouping
5, -- num of trees
- 1, -- num of random features
+ NULL, -- num of random features
TRUE::BOOLEAN, -- importance
1::INTEGER, -- num_permutations
10::INTEGER, -- max depth
@@ -65,7 +65,7 @@ SELECT forest_train(
NULL::TEXT, -- exclude columns
'class', -- grouping
5, -- num of trees
- 1, -- num of random features
+ NULL, -- num of random features
TRUE::BOOLEAN, -- importance
20::INTEGER, -- num_permutations
10::INTEGER, -- max depth