You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ji...@apache.org on 2018/07/25 21:21:41 UTC

madlib git commit: RF: Port DT fix for incorrect importance vector length

Repository: madlib
Updated Branches:
  refs/heads/master b96366d82 -> 2aac41897


RF: Port DT fix for incorrect importance vector length

JIRA: MADLIB-1254

Commit 0f7834e contained a fix in DT (decision tree) that ensured the
impurity variable importance vector had the correct length for each group,
even when a single group eliminated a categorical variable.

This commit applies the same fix for random forest.

Closes #299


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/2aac4189
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/2aac4189
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/2aac4189

Branch: refs/heads/master
Commit: 2aac418977128bd70bfc863fc02282a400427db5
Parents: b96366d
Author: Rahul Iyer <ri...@apache.org>
Authored: Wed Jul 25 10:51:47 2018 -0700
Committer: Jingyi Mei <jm...@pivotal.io>
Committed: Wed Jul 25 14:21:00 2018 -0700

----------------------------------------------------------------------
 .../recursive_partitioning/random_forest.py_in  | 75 +++++++++++---------
 1 file changed, 40 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/2aac4189/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index c06bed8..4d74872 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -10,6 +10,8 @@
 import plpy
 from math import sqrt, ceil
 
+from internal.db_utils import quote_literal
+
 from utilities.control import MinWarning
 from utilities.control import OptimizerControl
 from utilities.control import HashaggControl
@@ -504,7 +506,7 @@ def forest_train(
 
                 ##################################################################
                 # training random forest
-                tree_terminated = None
+                tree_terminated = dict()
                 for sample_id in range(1, num_trees + 1):
                     if 1 - sample_ratio < 1e-6:
                         random_sample_expr = "0.::double precision"
@@ -548,18 +550,22 @@ def forest_train(
                         """.format(**locals()))
 
                     if not grouping_cols:
-                        tree_state = _tree_train_using_bins(
+                        tree = _tree_train_using_bins(
                             schema_madlib, bins, src_view, cat_features, con_features,
                             boolean_cats, num_bins, 'poisson_count', dep, min_split,
                             min_bucket, max_tree_depth, filter_null, dep_n_levels,
                             is_classification, split_criterion, True,
                             num_random_features, max_n_surr, null_proxy)
 
-                        tree_states = [dict(tree_state=tree_state['tree_state'],
-                                            grp_key='')]
-
-                        tree_terminated = {'': tree_state['finished']}
-
+                        tree['grp_key'] = ''
+                        tree['cat_features'] = cat_features
+                        tree['con_features'] = con_features
+                        if importance:
+                            tree.update(_compute_var_importance(
+                                schema_madlib, tree,
+                                len(cat_features), len(con_features)))
+                        tree_states = [tree]
+                        tree_terminated = {'': tree['finished']}
                     else:
                         tree_states = _tree_train_grps_using_bins(
                             schema_madlib, bins, src_view, cat_features, con_features,
@@ -576,21 +582,17 @@ def forest_train(
                         # If a tree for a group is terminated (not finished properly),
                         # then we do not need to compute other trees, and can just
                         # stop calculating that group further.
-                        if tree_terminated is None:
-                            tree_terminated = dict((item['grp_key'], item['finished'])
-                                                   for item in tree_states)
-                        else:
-                            for item in tree_states:
-                                if item['grp_key'] not in tree_terminated:
-                                    tree_terminated[item['grp_key']] = item['finished']
-                                elif item['finished'] == 2:
-                                    tree_terminated[item['grp_key']] = 2
-                    if importance:
                         for tree in tree_states:
-                            importance_vectors = _compute_var_importance(
-                                schema_madlib, tree,
-                                len(cat_features), len(con_features))
-                            tree.update(**importance_vectors)
+                            grp_key = tree['grp_key']
+                            tree['cat_features'] = bins['grp_to_cat_features'][grp_key]
+                            tree['con_features'] = bins['con_features']
+                            tree_terminated[grp_key] = tree['finished']
+                            if importance:
+                                importance_vectors = _compute_var_importance(
+                                    schema_madlib, tree,
+                                    len(tree['cat_features']),
+                                    len(tree['con_features']))
+                                tree.update(**importance_vectors)
 
                     _insert_into_result_table(
                         schema_madlib, tree_states, output_table_name, impurity_imp_table,
@@ -1399,28 +1401,31 @@ def _insert_into_result_table(schema_madlib, tree_states, output_table_name,
         """.format(**locals())
     sql_plan = plpy.prepare(sql, ['{0}.bytea8[]'.format(schema_madlib)])
     plpy.execute(sql_plan, [[tree_state['tree_state'] for tree_state in tree_states]])
+
     if importance:
-        importance_results = [tree_state['impurity_var_importance']
-                              for tree_state in tree_states]
-        importance_results = py_list_to_sql_string(importance_results,
-                                                   'DOUBLE PRECISION[]',
-                                                   True)
-        importance_query = """({schema_madlib}.array_unnest_2d_to_1d(
-                                    {importance_results})).unnest_result AS
-                                 impurity_var_importance""".format(**locals())
-        plpy.execute("""
+        grp_imp_values = []
+        for tree_state in tree_states:
+            importance_vector = py_list_to_sql_string(
+                tree_state['impurity_var_importance'],
+                'DOUBLE PRECISION',
+                True)
+            grp_imp_values.append("({0}, {1})".
+                                  format(quote_literal(tree_state['grp_key']),
+                                         importance_vector))
+        sql = """
         INSERT INTO {impurity_imp_table}
         SELECT
             {gid_str},
             {sample_id} AS sample_id,
             impurity_var_importance
         FROM (
-            SELECT
-                {grp_key_sql}
-                {importance_query}
-        ) grp_key_to_tree
+            VALUES
+                {grp_imp_values_str}
+        ) grp_key_to_importance(grp_key, impurity_var_importance)
         {grp_join_sql}
-        """.format(**locals()))
+        """.format(grp_imp_values_str=', \n'.join(grp_imp_values),
+                   **locals())
+        plpy.execute(sql)
 
 # ------------------------------------------------------------------------------