You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ji...@apache.org on 2018/07/25 21:21:41 UTC
madlib git commit: RF: Port DT fix for incorrect importance vector length
Repository: madlib
Updated Branches:
refs/heads/master b96366d82 -> 2aac41897
RF: Port DT fix for incorrect importance vector length
JIRA: MADLIB-1254
Commit 0f7834e contained a fix in DT that ensured that the impurity variable
importance vector was of the correct length for each group, even when a single
group eliminated a categorical variable.
This commit applies the same fix for random forest.
Closes #299
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/2aac4189
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/2aac4189
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/2aac4189
Branch: refs/heads/master
Commit: 2aac418977128bd70bfc863fc02282a400427db5
Parents: b96366d
Author: Rahul Iyer <ri...@apache.org>
Authored: Wed Jul 25 10:51:47 2018 -0700
Committer: Jingyi Mei <jm...@pivotal.io>
Committed: Wed Jul 25 14:21:00 2018 -0700
----------------------------------------------------------------------
.../recursive_partitioning/random_forest.py_in | 75 +++++++++++---------
1 file changed, 40 insertions(+), 35 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/2aac4189/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index c06bed8..4d74872 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -10,6 +10,8 @@
import plpy
from math import sqrt, ceil
+from internal.db_utils import quote_literal
+
from utilities.control import MinWarning
from utilities.control import OptimizerControl
from utilities.control import HashaggControl
@@ -504,7 +506,7 @@ def forest_train(
##################################################################
# training random forest
- tree_terminated = None
+ tree_terminated = dict()
for sample_id in range(1, num_trees + 1):
if 1 - sample_ratio < 1e-6:
random_sample_expr = "0.::double precision"
@@ -548,18 +550,22 @@ def forest_train(
""".format(**locals()))
if not grouping_cols:
- tree_state = _tree_train_using_bins(
+ tree = _tree_train_using_bins(
schema_madlib, bins, src_view, cat_features, con_features,
boolean_cats, num_bins, 'poisson_count', dep, min_split,
min_bucket, max_tree_depth, filter_null, dep_n_levels,
is_classification, split_criterion, True,
num_random_features, max_n_surr, null_proxy)
- tree_states = [dict(tree_state=tree_state['tree_state'],
- grp_key='')]
-
- tree_terminated = {'': tree_state['finished']}
-
+ tree['grp_key'] = ''
+ tree['cat_features'] = cat_features
+ tree['con_features'] = con_features
+ if importance:
+ tree.update(_compute_var_importance(
+ schema_madlib, tree,
+ len(cat_features), len(con_features)))
+ tree_states = [tree]
+ tree_terminated = {'': tree['finished']}
else:
tree_states = _tree_train_grps_using_bins(
schema_madlib, bins, src_view, cat_features, con_features,
@@ -576,21 +582,17 @@ def forest_train(
# If a tree for a group is terminated (not finished properly),
# then we do not need to compute other trees, and can just
# stop calculating that group further.
- if tree_terminated is None:
- tree_terminated = dict((item['grp_key'], item['finished'])
- for item in tree_states)
- else:
- for item in tree_states:
- if item['grp_key'] not in tree_terminated:
- tree_terminated[item['grp_key']] = item['finished']
- elif item['finished'] == 2:
- tree_terminated[item['grp_key']] = 2
- if importance:
for tree in tree_states:
- importance_vectors = _compute_var_importance(
- schema_madlib, tree,
- len(cat_features), len(con_features))
- tree.update(**importance_vectors)
+ grp_key = tree['grp_key']
+ tree['cat_features'] = bins['grp_to_cat_features'][grp_key]
+ tree['con_features'] = bins['con_features']
+ tree_terminated[grp_key] = tree['finished']
+ if importance:
+ importance_vectors = _compute_var_importance(
+ schema_madlib, tree,
+ len(tree['cat_features']),
+ len(tree['con_features']))
+ tree.update(**importance_vectors)
_insert_into_result_table(
schema_madlib, tree_states, output_table_name, impurity_imp_table,
@@ -1399,28 +1401,31 @@ def _insert_into_result_table(schema_madlib, tree_states, output_table_name,
""".format(**locals())
sql_plan = plpy.prepare(sql, ['{0}.bytea8[]'.format(schema_madlib)])
plpy.execute(sql_plan, [[tree_state['tree_state'] for tree_state in tree_states]])
+
if importance:
- importance_results = [tree_state['impurity_var_importance']
- for tree_state in tree_states]
- importance_results = py_list_to_sql_string(importance_results,
- 'DOUBLE PRECISION[]',
- True)
- importance_query = """({schema_madlib}.array_unnest_2d_to_1d(
- {importance_results})).unnest_result AS
- impurity_var_importance""".format(**locals())
- plpy.execute("""
+ grp_imp_values = []
+ for tree_state in tree_states:
+ importance_vector = py_list_to_sql_string(
+ tree_state['impurity_var_importance'],
+ 'DOUBLE PRECISION',
+ True)
+ grp_imp_values.append("({0}, {1})".
+ format(quote_literal(tree_state['grp_key']),
+ importance_vector))
+ sql = """
INSERT INTO {impurity_imp_table}
SELECT
{gid_str},
{sample_id} AS sample_id,
impurity_var_importance
FROM (
- SELECT
- {grp_key_sql}
- {importance_query}
- ) grp_key_to_tree
+ VALUES
+ {grp_imp_values_str}
+ ) grp_key_to_importance(grp_key, impurity_var_importance)
{grp_join_sql}
- """.format(**locals()))
+ """.format(grp_imp_values_str=', \n'.join(grp_imp_values),
+ **locals())
+ plpy.execute(sql)
# ------------------------------------------------------------------------------